diff --git a/opensearch_py_ml/ml_commons/ml_commons_client.py b/opensearch_py_ml/ml_commons/ml_commons_client.py index 72e2e158b..83d36d8c5 100644 --- a/opensearch_py_ml/ml_commons/ml_commons_client.py +++ b/opensearch_py_ml/ml_commons/ml_commons_client.py @@ -518,6 +518,34 @@ def generate_embedding(self, model_id: str, sentences: List[str]) -> object: body=API_BODY, ) + def predict(self, model_id: str, algo_name: str, input_json): + + API_URL = f"{ML_BASE_URI}/_predict/{algo_name}/{model_id}" + + if isinstance(input_json, str): + try: + json_obj = json.loads(input_json) + if not isinstance(json_obj, dict): + return "Invalid JSON object passed as argument." + API_BODY = json.dumps(json_obj) + except json.JSONDecodeError: + return "Invalid JSON string passed as argument." + elif isinstance(input_json, dict): + API_BODY = json.dumps(input_json) + else: + return "Invalid JSON object passed as argument." + + return self._client.transport.perform_request( + method="POST", + url=API_URL, + body=API_BODY, + ) + + + + + + @deprecated( reason="Since OpenSearch 2.7.0, you can use undeploy_model instead", version="2.7.0", diff --git a/tests/ml_commons/test_ml_commons_client.py b/tests/ml_commons/test_ml_commons_client.py index 27cd79dc9..e3ce4ecdf 100644 --- a/tests/ml_commons/test_ml_commons_client.py +++ b/tests/ml_commons/test_ml_commons_client.py @@ -237,6 +237,64 @@ def test_DEPRECATED_integration_pretrained_model_upload_unload_delete(): raised = True assert raised == False, "Raised Exception in deleting pretrained model" +def test_predict(): + input_json = { + { + "input_query": { + "_source": ["petal_length_in_cm", "petal_width_in_cm"], + "size": 10000 + }, + "input_index": [ + "iris_data" + ] + } + } + + raised = False + model_id = ml_client.register_pretrained_model( + model_name=PRETRAINED_MODEL_NAME, + model_version=PRETRAINED_MODEL_VERSION, + model_format=PRETRAINED_MODEL_FORMAT, + deploy_model=True, + wait_until_deployed=True, + ) + + try: + predict_obj = ml_client.predict( + model_id=model_id, algo_name="kmeans",input_json=input_json + ) + assert predict_obj["status"] == "COMPLETED" + except: # noqa: E722 + raised = True + assert raised == False, "Raised Exception in training and predicting task" + + raised = False + try: + predict_obj = ml_client.predict( + model_id=model_id, algo_name="something else",input_json=input_json + ) + assert predict_obj == "Invalid algorithm name passed as argument." + except: # noqa: E722 + raised = True + assert raised == False, "Raised Exception in training and predicting task" + + try: + predict_obj = ml_client.predict( + model_id=model_id, algo_name="something else",input_json="15" + ) + assert predict_obj == "Invalid JSON object passed as argument." + except: # noqa: E722 + raised = True + assert raised == False, "Raised Exception in training and predicting task" + + try: + predict_obj = ml_client.predict( + model_id=model_id, algo_name="something else",input_json=15 + ) + assert predict_obj == "Invalid JSON object passed as argument." + except: # noqa: E722 + raised = True + assert raised == False, "Raised Exception in training and predicting task" def test_integration_pretrained_model_register_undeploy_delete(): raised = False