Skip to content

Commit

Permalink
updated test cases
Browse files Browse the repository at this point in the history
Signed-off-by: kalyan <[email protected]>
  • Loading branch information
rawwar committed Oct 31, 2023
1 parent 55f35c5 commit 7a58dc0
Show file tree
Hide file tree
Showing 3 changed files with 60 additions and 11 deletions.
2 changes: 1 addition & 1 deletion opensearch_py_ml/ml_commons/model_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ def _train(

return self._client.transport.perform_request(
method="POST",
url=f"{ML_BASE_URI}/{API_ENDPOINT}/{algorithm_name}",
url=f"{ML_BASE_URI}/{ModelTrain.API_ENDPOINT}/{algorithm_name}",
body=input_json,
params=params,
)
1 change: 0 additions & 1 deletion tests/ml_commons/test_ml_commons_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -515,4 +515,3 @@ def test_integration_model_train_register_full_cycle():
assert raised == False, "Raised Exception in deleting model"


test_integration_model_train_register_full_cycle()
68 changes: 59 additions & 9 deletions tests/ml_commons/test_model_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,29 +5,79 @@
# Any modifications Copyright OpenSearch Contributors. See
# GitHub history for details.

import pytest
from opensearchpy import OpenSearch, helpers
from sklearn.datasets import load_iris
import time
from opensearch_py_ml.ml_commons import MLCommonClient, ModelTrain
from tests import OPENSEARCH_TEST_CLIENT

from opensearchpy import OpenSearch
ml_client = MLCommonClient(OPENSEARCH_TEST_CLIENT)

from opensearch_py_ml.ml_commons import MLCommonClient, ModelTrain

@pytest.fixture
def iris_index():
index_name = "test__index__iris_data"
index_mapping = {
"mappings": {
"properties": {
"sepal_length": {"type": "float"},
"sepal_width": {"type": "float"},
"petal_length": {"type": "float"},
"petal_width": {"type": "float"},
"species": {"type": "keyword"},
}
}
}

if ml_client._client.indices.exists(index=index_name):
ml_client._client.indices.delete(index=index_name)
ml_client._client.indices.create(index=index_name, body=index_mapping)

iris = load_iris()
iris_data = iris.data
iris_target = iris.target
iris_species = [iris.target_names[i] for i in iris_target]

actions = [
{
"_index": index_name,
"_source": {
"sepal_length": sepal_length,
"sepal_width": sepal_width,
"petal_length": petal_length,
"petal_width": petal_width,
"species": species,
},
}
for (sepal_length, sepal_width, petal_length, petal_width), species in zip(
iris_data, iris_species
)
]

helpers.bulk(ml_client._client, actions)
# without the sleep, test is failing.
time.sleep(2)

yield index_name

ml_client._client.indices.delete(index=index_name)


def test_init(opensearch_client):
ml_client = MLCommonClient(opensearch_client)
def test_init():
assert isinstance(ml_client._client, OpenSearch)
assert isinstance(ml_client._model_train, ModelTrain)


def test_train(iris_index_client):
client, test_index_name = iris_index_client
ml_client = MLCommonClient(client)
def test_train(iris_index):
algorithm_name = "kmeans"
input_json_sync = {
"parameters": {"centroids": 3, "iterations": 10, "distance_type": "COSINE"},
"input_query": {
"_source": ["petal_length", "petal_width"],
"size": 10000,
},
"input_index": [test_index_name],
"input_index": [iris_index],
}
response = ml_client.train_model(algorithm_name, input_json_sync)
assert isinstance(response, dict)
Expand All @@ -41,7 +91,7 @@ def test_train(iris_index_client):
"_source": ["petal_length", "petal_width"],
"size": 10000,
},
"input_index": [test_index_name],
"input_index": [iris_index],
}
response = ml_client.train_model(algorithm_name, input_json_async, is_async=True)

Expand Down

0 comments on commit 7a58dc0

Please sign in to comment.