-
Notifications
You must be signed in to change notification settings - Fork 82
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add TrainKNN Runner/Operation for Benchmarking Approximate KNN Algorithms #556
Changes from 2 commits
3d138cb
13da7b8
9b206c0
9fac862
16e493b
b43cbc6
a78351e
95d445e
904ae3f
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -47,7 +47,7 @@ | |
from osbenchmark.utils import convert | ||
from osbenchmark.client import RequestContextHolder | ||
# Mapping from operation type to specific runner | ||
from osbenchmark.utils.parse import parse_int_parameter, parse_string_parameter | ||
from osbenchmark.utils.parse import parse_int_parameter, parse_string_parameter, parse_float_parameter | ||
|
||
__RUNNERS = {} | ||
|
||
|
@@ -105,7 +105,7 @@ def register_default_runners(): | |
register_runner(workload.OperationType.DeleteMlModel, Retry(DeleteMlModel()), async_runner=True) | ||
register_runner(workload.OperationType.RegisterMlModel, Retry(RegisterMlModel()), async_runner=True) | ||
register_runner(workload.OperationType.DeployMlModel, Retry(DeployMlModel()), async_runner=True) | ||
|
||
register_runner(workload.OperationType.TrainKnnModel, Retry(TrainKnnModel()), async_runner=True) | ||
|
||
def runner_for(operation_type): | ||
try: | ||
|
@@ -652,6 +652,77 @@ def __repr__(self, *args, **kwargs): | |
return "bulk-index" | ||
|
||
|
||
class TrainKnnModel(Runner): | ||
""" | ||
Trains model named model_id until training is complete or retries are exhausted. | ||
""" | ||
|
||
NAME = "train-knn-model" | ||
|
||
async def __call__(self, opensearch, params): | ||
""" | ||
Create and train one model named model_id. | ||
|
||
:param opensearch: The OpenSearch client. | ||
:param params: A hash with all parameters. See below for details. | ||
:return: A hash with meta data for this bulk operation. See below for details. | ||
:raises: Exception if training fails, times out, or a different error occurs. | ||
It expects a parameter dict with the following mandatory keys: | ||
|
||
* ``body``: containing parameters to pass on to the train engine. | ||
See https://opensearch.org/docs/latest/search-plugins/knn/api/#train-a-model for information. | ||
* ``retries``: Maximum number of retries allowed for the training to complete (seconds). | ||
* ``polling-interval``: Polling interval to see if the model has been trained yet (seconds). | ||
* ``model_id``: ID of the model to train. | ||
""" | ||
body = params["body"] | ||
model_id = parse_string_parameter("model_id", params) | ||
max_retries = parse_int_parameter("retries", params) | ||
poll_period = parse_float_parameter("poll_period", params) | ||
|
||
method = "POST" | ||
model_uri = f"/_plugins/_knn/models/{model_id}" | ||
request_context_holder.on_client_request_start() | ||
await opensearch.transport.perform_request(method, f"{model_uri}/_train", body=body) | ||
|
||
current_number_retries = 0 | ||
while True: | ||
model_response = await opensearch.transport.perform_request("GET", model_uri) | ||
|
||
if 'state' not in model_response.keys(): | ||
request_context_holder.on_client_request_end() | ||
self.logger.error( | ||
"Failed to create model [%s] with error response: [%s]", model_id, model_response) | ||
raise Exception( | ||
f"Failed to create model {model_id} with error response: {model_response}") | ||
|
||
if current_number_retries > max_retries: | ||
request_context_holder.on_client_request_end() | ||
self.logger.error( | ||
"Failed to create model [%s] within [%i] retries.", model_id, max_retries) | ||
raise TimeoutError( | ||
f'Failed to create model: {model_id} within {max_retries} retries') | ||
|
||
if model_response['state'] == 'training': | ||
current_number_retries += 1 | ||
await asyncio.sleep(poll_period) | ||
continue | ||
|
||
# at this point, training either failed or finished. | ||
request_context_holder.on_client_request_end() | ||
if model_response['state'] == 'created': | ||
self.logger.info( | ||
"Training model [%s] was completed successfully.", model_id) | ||
break | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. nit: return instead of break |
||
|
||
# training failed. | ||
self.logger.error( | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. IMO, we should add model_response['state'] == 'failed' to validate failure condition, and, raise exception if state contains unexpected value. This will help us in case knn decided to add new status |
||
"Training for model [%s] failed. Response: [%s]", model_id, model_response) | ||
raise Exception(f"Failed to create model: {model_response}") | ||
|
||
def __repr__(self, *args, **kwargs): | ||
return self.NAME | ||
|
||
# TODO: Add retry logic to BulkIndex, so that we can remove BulkVectorDataSet and use BulkIndex. | ||
class BulkVectorDataSet(Runner): | ||
""" | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
request_context_holder.on_client_request_start() is missing