Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ML-Commons train api functionality #310

Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
Show all changes
32 commits
Select commit Hold shift + click to select a range
c182e5c
add ml-commons train support
rawwar Oct 7, 2023
edafa9b
update __all__
rawwar Oct 7, 2023
443b9b0
fix test cases
rawwar Oct 8, 2023
e3c4e64
sleep after bulk insert
rawwar Oct 8, 2023
f0ce236
fix formatting
rawwar Oct 8, 2023
7e18884
remove unused imports
rawwar Oct 8, 2023
e7ead98
remove duplicate conftest
rawwar Oct 11, 2023
ac14c98
delete duplicate conftest
rawwar Oct 11, 2023
2a52fac
include pytest plugins
rawwar Oct 11, 2023
e9365a5
revert pandas version
rawwar Oct 13, 2023
cf5074a
include license
rawwar Oct 13, 2023
4112a2f
fix formatting
rawwar Oct 13, 2023
2b14a14
fix imports order
rawwar Oct 13, 2023
e09c697
fix imports order
rawwar Oct 13, 2023
7c57d2f
lint fix
rawwar Oct 13, 2023
2610dda
update changelog
rawwar Oct 13, 2023
31fac86
Merge branch 'opensearch-project:main' into kalyan/286-ml-commons-add…
rawwar Oct 18, 2023
eab5a29
Merge branch 'opensearch-project:main' into kalyan/286-ml-commons-add…
rawwar Oct 27, 2023
4e159c2
revert testcases
rawwar Oct 31, 2023
55f35c5
remove fixtures
rawwar Oct 31, 2023
7a58dc0
updated test cases
rawwar Oct 31, 2023
43e5a7e
lint fixes
rawwar Oct 31, 2023
70f0a75
update fixture
rawwar Oct 31, 2023
404c0b3
revert
rawwar Oct 31, 2023
f19e36e
Merge branch 'main' of https://github.com/opensearch-project/opensear…
rawwar Nov 3, 2023
3cb7fb6
include train in MLCommons class as a func
rawwar Nov 3, 2023
a8306fe
remove model train
rawwar Nov 3, 2023
45d7aeb
fix tests
rawwar Nov 3, 2023
9431d8f
revert nox
rawwar Nov 3, 2023
6faf462
add tests to model_train
rawwar Nov 3, 2023
eaf4bdc
fix lint
rawwar Nov 3, 2023
a6f0969
fix lint
rawwar Nov 3, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
66 changes: 66 additions & 0 deletions conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
import pytest
rawwar marked this conversation as resolved.
Show resolved Hide resolved
import time
import os
from opensearchpy import OpenSearch, helpers
from sklearn.datasets import load_iris

@pytest.fixture
def opensearch_client():
opensearch_host = os.environ.get("OPENSEARCH_HOST", "https://localhost:9200")
opensearch_admin_user = os.environ.get("OPENSEARCH_ADMIN_USER", "admin")
opensearch_admin_password = os.environ.get("OPENSEARCH_ADMIN_PASSWORD", "admin")
client = OpenSearch(
hosts=[opensearch_host],
http_auth=(opensearch_admin_user, opensearch_admin_password),
verify_certs=False,
)
yield client

# tear down
client.transport.close()


@pytest.fixture
def iris_index_client(opensearch_client: OpenSearch):
index_name = "test__index__iris_data"
index_mapping = {
"mappings": {
"properties": {
"sepal_length": {"type": "float"},
"sepal_width": {"type": "float"},
"petal_length": {"type": "float"},
"petal_width": {"type": "float"},
"species": {"type": "keyword"}
}
}
}

if opensearch_client.indices.exists(index=index_name):
opensearch_client.indices.delete(index=index_name)
opensearch_client.indices.create(index=index_name, body=index_mapping)

iris = load_iris()
iris_data = iris.data
iris_target = iris.target
iris_species = [iris.target_names[i] for i in iris_target]

actions = [
{ '_index': index_name,
"_source":{
"sepal_length": sepal_length,
"sepal_width": sepal_width,
"petal_length": petal_length,
"petal_width": petal_width,
"species": species
}
}
for (sepal_length, sepal_width, petal_length, petal_width), species in zip(iris_data, iris_species)
]

helpers.bulk(opensearch_client, actions)
# without the sleep, test is failing.
time.sleep(2)

yield opensearch_client, index_name

opensearch_client.indices.delete(index=index_name)
3 changes: 2 additions & 1 deletion opensearch_py_ml/ml_commons/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

from opensearch_py_ml.ml_commons.ml_commons_client import MLCommonClient
from opensearch_py_ml.ml_commons.model_execute import ModelExecute
from opensearch_py_ml.ml_commons.model_train import ModelTrain
from opensearch_py_ml.ml_commons.model_uploader import ModelUploader

__all__ = ["MLCommonClient", "ModelExecute", "ModelUploader"]
__all__ = ["MLCommonClient", "ModelExecute", "ModelUploader", "ModelTrain"]
13 changes: 12 additions & 1 deletion opensearch_py_ml/ml_commons/ml_commons_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

import json
import time
from typing import Any, List, Union
from typing import Any, List, Optional, Union

from deprecated.sphinx import deprecated
from opensearchpy import OpenSearch
Expand All @@ -22,6 +22,7 @@
TIMEOUT,
)
from opensearch_py_ml.ml_commons.model_execute import ModelExecute
from opensearch_py_ml.ml_commons.model_train import ModelTrain
from opensearch_py_ml.ml_commons.model_uploader import ModelUploader


Expand All @@ -35,6 +36,7 @@ def __init__(self, os_client: OpenSearch):
self._client = os_client
self._model_uploader = ModelUploader(os_client)
self._model_execute = ModelExecute(os_client)
self._model_train = ModelTrain(os_client)

def execute(self, algorithm_name: str, input_json: dict) -> dict:
"""
Expand Down Expand Up @@ -580,3 +582,12 @@ def delete_task(self, task_id: str) -> object:
method="DELETE",
url=API_URL,
)

def train_model(
self, algorithm_name: str, input_json: dict, is_async: Optional[bool] = False
) -> dict:
"""
This method trains an ML model
"""

return self._model_train._train(algorithm_name, input_json, is_async)
44 changes: 44 additions & 0 deletions opensearch_py_ml/ml_commons/model_train.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
# SPDX-License-Identifier: Apache-2.0
# The OpenSearch Contributors require contributions made to
# this file be licensed under the Apache-2.0 license or a
# compatible open source license.
# Any modifications Copyright OpenSearch Contributors. See
# GitHub history for details.

import json
from typing import Optional

from opensearchpy import OpenSearch

from opensearch_py_ml.ml_commons.ml_common_utils import ML_BASE_URI


class ModelTrain:
rawwar marked this conversation as resolved.
Show resolved Hide resolved
"""
Class for training models using ML Commons train API.
"""

API_ENDPOINT = "models/_train"
rawwar marked this conversation as resolved.
Show resolved Hide resolved

def __init__(self, os_client: OpenSearch):
self._client = os_client

def _train(
self, algorithm_name: str, input_json: dict, is_async: Optional[bool] = True
) -> dict:
"""
This method trains an ML model
"""

params = {}
if not isinstance(input_json, dict):
input_json = json.loads(input_json)
if is_async:
params["async"] = "true"

return self._client.transport.perform_request(
method="POST",
url=f"{ML_BASE_URI}/_train/{algorithm_name}",
body=input_json,
params=params,
)
51 changes: 51 additions & 0 deletions tests/ml_commons/test_model_train.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
# SPDX-License-Identifier: Apache-2.0
# The OpenSearch Contributors require contributions made to
# this file be licensed under the Apache-2.0 license or a
# compatible open source license.
# Any modifications Copyright OpenSearch Contributors. See
# GitHub history for details.


from opensearchpy import OpenSearch

from opensearch_py_ml.ml_commons import MLCommonClient, ModelTrain


def test_init(opensearch_client):
ml_client = MLCommonClient(opensearch_client)
assert isinstance(ml_client._client, OpenSearch)
assert isinstance(ml_client._model_train, ModelTrain)


def test_train(iris_index_client):
client, test_index_name = iris_index_client
ml_client = MLCommonClient(client)
algorithm_name = "kmeans"
input_json_sync = {
"parameters": {"centroids": 3, "iterations": 10, "distance_type": "COSINE"},
"input_query": {
"_source": ["petal_length", "petal_width"],
"size": 10000,
},
"input_index": [test_index_name],
}
response = ml_client.train_model(algorithm_name, input_json_sync)
assert isinstance(response, dict)
assert "model_id" in response
assert "status" in response
assert response["status"] == "COMPLETED"

input_json_async = {
"parameters": {"centroids": 3, "iterations": 10, "distance_type": "COSINE"},
"input_query": {
"_source": ["petal_length", "petal_width"],
"size": 10000,
},
"input_index": [test_index_name],
}
response = ml_client.train_model(algorithm_name, input_json_async, is_async=True)

assert isinstance(response, dict)
assert "task_id" in response
assert "status" in response
assert response["status"] == "CREATED"