Skip to content

Commit

Permalink
Implement predict method and add unit tests (#425)
Browse files Browse the repository at this point in the history
* added dynamic predict method for various ML models

Signed-off-by: Yerzhaisang Taskali <[email protected]>

* fixed the CI

Signed-off-by: Yerzhaisang Taskali <[email protected]>

* added unit test for predict api functionality

Signed-off-by: Yerzhaisang Taskali <[email protected]>

* updated CHANGELOG.md

Signed-off-by: Yerzhaisang Taskali <[email protected]>

* made use of existing method

Signed-off-by: Yerzhaisang Taskali <[email protected]>

* added unit tests

Signed-off-by: Yerzhaisang Taskali <[email protected]>

* made the parameter optional

Signed-off-by: Yerzhaisang Taskali <[email protected]>

---------

Signed-off-by: Yerzhaisang Taskali <[email protected]>
  • Loading branch information
Yerzhaisang authored Dec 12, 2024
1 parent 8c925c4 commit 155bdc3
Show file tree
Hide file tree
Showing 3 changed files with 97 additions and 0 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ Inspired from [Keep a Changelog](https://keepachangelog.com/en/1.0.0/)
- Support for security default admin credential changes in 2.12.0 in ([#365](https://github.com/opensearch-project/opensearch-py-ml/pull/365))
- adding cross encoder models in the pre-trained traced list ([#378](https://github.com/opensearch-project/opensearch-py-ml/pull/378))
- Add workflows and scripts for sparse encoding model tracing and uploading process by @conggguan in ([#394](https://github.com/opensearch-project/opensearch-py-ml/pull/394))
- Implemented `predict` method and added unit tests by @yerzhaisang ([#425](https://github.com/opensearch-project/opensearch-py-ml/pull/425))

### Changed
- Add a parameter for customize the upload folder prefix ([#398](https://github.com/opensearch-project/opensearch-py-ml/pull/398))
Expand Down
20 changes: 20 additions & 0 deletions opensearch_py_ml/ml_commons/ml_commons_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -568,6 +568,26 @@ def unload_model(self, model_id: str, node_ids: List[str] = []) -> object:
body=API_BODY,
)

def predict(
    self, model_id: str, predict_object: dict, algorithm_name: str = None
) -> dict:
    """
    Run a prediction against a deployed ML model.

    :param model_id: Unique identifier of the deployed model
    :type model_id: str
    :param predict_object: JSON object containing the input data and parameters for prediction
    :type predict_object: dict
    :param algorithm_name: Optional algorithm name (e.g., 'kmeans', 'text_embedding');
        accepted for API compatibility but not referenced by this implementation
    :type algorithm_name: str
    :return: Prediction response from the ML model
    :rtype: dict
    """
    # Delegate to the shared inference helper, which performs the POST
    # request to the ML Commons predict endpoint.
    return self.generate_model_inference(model_id, predict_object)

def undeploy_model(self, model_id: str, node_ids: List[str] = []) -> object:
"""
This method undeploys a model from all the nodes or from the given list of nodes (using ml commons _undeploy api)
Expand Down
76 changes: 76 additions & 0 deletions tests/ml_commons/test_ml_commons_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -356,6 +356,44 @@ def test_DEPRECATED_integration_model_train_upload_full_cycle():
else:
assert len(embedding_result.get("inference_results")) == 2

try:
text_docs = ["First test sentence", "Second test sentence"]
return_number = True
target_response = ["sentence_embedding"]
algorithm_name = "text_embedding"
prediction_result = ml_client.predict(
algorithm_name=algorithm_name,
model_id=model_id,
predict_object={
"text_docs": text_docs,
"return_number": return_number,
"target_response": target_response,
},
)
prediction_result_2 = ml_client.generate_model_inference(
model_id=model_id,
predict_object={
"text_docs": text_docs,
"return_number": return_number,
"target_response": target_response,
},
)
inference_results = prediction_result.get("inference_results")
output_1 = inference_results[0].get("output")
output_2 = inference_results[1].get("output")
except Exception as ex: # noqa: E722
pytest.fail(f"Exception occurred when predicting: {ex}")
else:
assert prediction_result == prediction_result_2
assert output_1[0].get("name") == "sentence_embedding"
assert output_1[0].get("data_type") == "FLOAT32"
assert output_1[0].get("shape")[0] == 384
assert isinstance(output_1[0].get("data"), list)
assert output_2[0].get("name") == "sentence_embedding"
assert output_2[0].get("data_type") == "FLOAT32"
assert output_2[0].get("shape")[0] == 384
assert isinstance(output_2[0].get("data"), list)

try:
delete_task_obj = ml_client.delete_task(task_id)
except Exception as ex: # noqa: E722
Expand Down Expand Up @@ -471,6 +509,44 @@ def test_integration_model_train_register_full_cycle():
else:
assert len(embedding_result.get("inference_results")) == 2

try:
text_docs = ["First test sentence", "Second test sentence"]
return_number = True
target_response = ["sentence_embedding"]
algorithm_name = "text_embedding"
prediction_result = ml_client.predict(
algorithm_name=algorithm_name,
model_id=model_id,
predict_object={
"text_docs": text_docs,
"return_number": return_number,
"target_response": target_response,
},
)
prediction_result_2 = ml_client.generate_model_inference(
model_id=model_id,
predict_object={
"text_docs": text_docs,
"return_number": return_number,
"target_response": target_response,
},
)
inference_results = prediction_result.get("inference_results")
output_1 = inference_results[0].get("output")
output_2 = inference_results[1].get("output")
except Exception as ex: # noqa: E722
pytest.fail(f"Exception occurred when predicting: {ex}")
else:
assert prediction_result == prediction_result_2
assert output_1[0].get("name") == "sentence_embedding"
assert output_1[0].get("data_type") == "FLOAT32"
assert output_1[0].get("shape")[0] == 384
assert isinstance(output_1[0].get("data"), list)
assert output_2[0].get("name") == "sentence_embedding"
assert output_2[0].get("data_type") == "FLOAT32"
assert output_2[0].get("shape")[0] == 384
assert isinstance(output_2[0].get("data"), list)

try:
delete_task_obj = ml_client.delete_task(task_id)
except Exception as ex:
Expand Down

0 comments on commit 155bdc3

Please sign in to comment.