diff --git a/CHANGELOG.md b/CHANGELOG.md
index ba733fea..504de060 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -15,6 +15,7 @@ Inspired from [Keep a Changelog](https://keepachangelog.com/en/1.0.0/)
 - Support for security default admin credential changes in 2.12.0 in ([#365](https://github.com/opensearch-project/opensearch-py-ml/pull/365))
 - adding cross encoder models in the pre-trained traced list ([#378](https://github.com/opensearch-project/opensearch-py-ml/pull/378))
 - Add workflows and scripts for sparse encoding model tracing and uploading process by @conggguan in ([#394](https://github.com/opensearch-project/opensearch-py-ml/pull/394))
+- Implemented `predict` method and added unit tests by @yerzhaisang in ([#425](https://github.com/opensearch-project/opensearch-py-ml/pull/425))
 
 ### Changed
 - Add a parameter for customize the upload folder prefix ([#398](https://github.com/opensearch-project/opensearch-py-ml/pull/398))
diff --git a/opensearch_py_ml/ml_commons/ml_commons_client.py b/opensearch_py_ml/ml_commons/ml_commons_client.py
index 5509eaa7..0287d5d1 100644
--- a/opensearch_py_ml/ml_commons/ml_commons_client.py
+++ b/opensearch_py_ml/ml_commons/ml_commons_client.py
@@ -568,6 +568,26 @@ def unload_model(self, model_id: str, node_ids: List[str] = []) -> object:
             body=API_BODY,
         )
 
+    def predict(
+        self, model_id: str, predict_object: dict, algorithm_name: str = None
+    ) -> dict:
+        """
+        Generalized predict method to make predictions using different ML algorithms.
+
+        :param algorithm_name: The name of the algorithm, e.g., 'kmeans', 'text_embedding'
+        :type algorithm_name: str
+        :param model_id: Unique identifier of the deployed model
+        :type model_id: str
+        :param predict_object: JSON object containing the input data and parameters for prediction
+        :type predict_object: dict
+        :return: Prediction response from the ML model
+        :rtype: dict
+        """
+        # Make the POST request to the prediction API
+        response = self.generate_model_inference(model_id, predict_object)
+
+        return response
+
     def undeploy_model(self, model_id: str, node_ids: List[str] = []) -> object:
         """
         This method undeploys a model from all the nodes or from the given list of nodes (using ml commons _undeploy api)
diff --git a/tests/ml_commons/test_ml_commons_client.py b/tests/ml_commons/test_ml_commons_client.py
index 418fa3c8..5fcd2459 100644
--- a/tests/ml_commons/test_ml_commons_client.py
+++ b/tests/ml_commons/test_ml_commons_client.py
@@ -356,6 +356,44 @@ def test_DEPRECATED_integration_model_train_upload_full_cycle():
     else:
         assert len(embedding_result.get("inference_results")) == 2
 
+    try:
+        text_docs = ["First test sentence", "Second test sentence"]
+        return_number = True
+        target_response = ["sentence_embedding"]
+        algorithm_name = "text_embedding"
+        prediction_result = ml_client.predict(
+            algorithm_name=algorithm_name,
+            model_id=model_id,
+            predict_object={
+                "text_docs": text_docs,
+                "return_number": return_number,
+                "target_response": target_response,
+            },
+        )
+        prediction_result_2 = ml_client.generate_model_inference(
+            model_id=model_id,
+            predict_object={
+                "text_docs": text_docs,
+                "return_number": return_number,
+                "target_response": target_response,
+            },
+        )
+        inference_results = prediction_result.get("inference_results")
+        output_1 = inference_results[0].get("output")
+        output_2 = inference_results[1].get("output")
+    except Exception as ex:  # noqa: E722
+        pytest.fail(f"Exception occurred when predicting: {ex}")
+    else:
+        assert prediction_result == prediction_result_2
+        assert output_1[0].get("name") == "sentence_embedding"
+        assert output_1[0].get("data_type") == "FLOAT32"
+        assert output_1[0].get("shape")[0] == 384
+        assert isinstance(output_1[0].get("data"), list)
+        assert output_2[0].get("name") == "sentence_embedding"
+        assert output_2[0].get("data_type") == "FLOAT32"
+        assert output_2[0].get("shape")[0] == 384
+        assert isinstance(output_2[0].get("data"), list)
+
     try:
         delete_task_obj = ml_client.delete_task(task_id)
     except Exception as ex:  # noqa: E722
@@ -471,6 +509,44 @@ def test_integration_model_train_register_full_cycle():
     else:
         assert len(embedding_result.get("inference_results")) == 2
 
+    try:
+        text_docs = ["First test sentence", "Second test sentence"]
+        return_number = True
+        target_response = ["sentence_embedding"]
+        algorithm_name = "text_embedding"
+        prediction_result = ml_client.predict(
+            algorithm_name=algorithm_name,
+            model_id=model_id,
+            predict_object={
+                "text_docs": text_docs,
+                "return_number": return_number,
+                "target_response": target_response,
+            },
+        )
+        prediction_result_2 = ml_client.generate_model_inference(
+            model_id=model_id,
+            predict_object={
+                "text_docs": text_docs,
+                "return_number": return_number,
+                "target_response": target_response,
+            },
+        )
+        inference_results = prediction_result.get("inference_results")
+        output_1 = inference_results[0].get("output")
+        output_2 = inference_results[1].get("output")
+    except Exception as ex:  # noqa: E722
+        pytest.fail(f"Exception occurred when predicting: {ex}")
+    else:
+        assert prediction_result == prediction_result_2
+        assert output_1[0].get("name") == "sentence_embedding"
+        assert output_1[0].get("data_type") == "FLOAT32"
+        assert output_1[0].get("shape")[0] == 384
+        assert isinstance(output_1[0].get("data"), list)
+        assert output_2[0].get("name") == "sentence_embedding"
+        assert output_2[0].get("data_type") == "FLOAT32"
+        assert output_2[0].get("shape")[0] == 384
+        assert isinstance(output_2[0].get("data"), list)
+
     try:
         delete_task_obj = ml_client.delete_task(task_id)
     except Exception as ex: