From 7a58dc04671905900dfb33e817af7b94cf82de5a Mon Sep 17 00:00:00 2001 From: kalyan Date: Tue, 31 Oct 2023 07:54:24 +0530 Subject: [PATCH] updated test cases Signed-off-by: kalyan --- opensearch_py_ml/ml_commons/model_train.py | 2 +- tests/ml_commons/test_ml_commons_client.py | 1 - tests/ml_commons/test_model_train.py | 68 +++++++++++++++++++--- 3 files changed, 60 insertions(+), 11 deletions(-) diff --git a/opensearch_py_ml/ml_commons/model_train.py b/opensearch_py_ml/ml_commons/model_train.py index e10d24ffe..3c5ef576c 100644 --- a/opensearch_py_ml/ml_commons/model_train.py +++ b/opensearch_py_ml/ml_commons/model_train.py @@ -38,7 +38,7 @@ def _train( return self._client.transport.perform_request( method="POST", - url=f"{ML_BASE_URI}/{API_ENDPOINT}/{algorithm_name}", + url=f"{ML_BASE_URI}/{ModelTrain.API_ENDPOINT}/{algorithm_name}", body=input_json, params=params, ) diff --git a/tests/ml_commons/test_ml_commons_client.py b/tests/ml_commons/test_ml_commons_client.py index 10be2c164..cbc923542 100644 --- a/tests/ml_commons/test_ml_commons_client.py +++ b/tests/ml_commons/test_ml_commons_client.py @@ -515,4 +515,3 @@ def test_integration_model_train_register_full_cycle(): assert raised == False, "Raised Exception in deleting model" -test_integration_model_train_register_full_cycle() diff --git a/tests/ml_commons/test_model_train.py b/tests/ml_commons/test_model_train.py index 5ea2cd405..90488761d 100644 --- a/tests/ml_commons/test_model_train.py +++ b/tests/ml_commons/test_model_train.py @@ -5,21 +5,71 @@ # Any modifications Copyright OpenSearch Contributors. See # GitHub history for details. +import pytest +from opensearchpy import OpenSearch, helpers +from sklearn.datasets import load_iris +import time +from opensearch_py_ml.ml_commons import MLCommonClient, ModelTrain +from tests import OPENSEARCH_TEST_CLIENT -from opensearchpy import OpenSearch +ml_client = MLCommonClient(OPENSEARCH_TEST_CLIENT) -from opensearch_py_ml.ml_commons import MLCommonClient, ModelTrain + +@pytest.fixture +def iris_index(): + index_name = "test__index__iris_data" + index_mapping = { + "mappings": { + "properties": { + "sepal_length": {"type": "float"}, + "sepal_width": {"type": "float"}, + "petal_length": {"type": "float"}, + "petal_width": {"type": "float"}, + "species": {"type": "keyword"}, + } + } + } + + if ml_client._client.indices.exists(index=index_name): + ml_client._client.indices.delete(index=index_name) + ml_client._client.indices.create(index=index_name, body=index_mapping) + + iris = load_iris() + iris_data = iris.data + iris_target = iris.target + iris_species = [iris.target_names[i] for i in iris_target] + + actions = [ + { + "_index": index_name, + "_source": { + "sepal_length": sepal_length, + "sepal_width": sepal_width, + "petal_length": petal_length, + "petal_width": petal_width, + "species": species, + }, + } + for (sepal_length, sepal_width, petal_length, petal_width), species in zip( + iris_data, iris_species + ) + ] + + helpers.bulk(ml_client._client, actions) + # without the sleep, test is failing. + time.sleep(2) + + yield index_name + + ml_client._client.indices.delete(index=index_name) -def test_init(opensearch_client): - ml_client = MLCommonClient(opensearch_client) +def test_init(): assert isinstance(ml_client._client, OpenSearch) assert isinstance(ml_client._model_train, ModelTrain) -def test_train(iris_index_client): - client, test_index_name = iris_index_client - ml_client = MLCommonClient(client) +def test_train(iris_index): algorithm_name = "kmeans" input_json_sync = { "parameters": {"centroids": 3, "iterations": 10, "distance_type": "COSINE"}, @@ -27,7 +77,7 @@ def test_train(iris_index_client): "_source": ["petal_length", "petal_width"], "size": 10000, }, - "input_index": [test_index_name], + "input_index": [iris_index], } response = ml_client.train_model(algorithm_name, input_json_sync) assert isinstance(response, dict) @@ -41,7 +91,7 @@ def test_train(iris_index_client): "_source": ["petal_length", "petal_width"], "size": 10000, }, - "input_index": [test_index_name], + "input_index": [iris_index], } response = ml_client.train_model(algorithm_name, input_json_async, is_async=True)