-
Notifications
You must be signed in to change notification settings - Fork 82
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add TrainKNN Runner/Operation for Benchmarking Approximate KNN Algorithms #556
Changes from all commits
3d138cb
13da7b8
9b206c0
9fac862
16e493b
b43cbc6
a78351e
95d445e
904ae3f
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -47,7 +47,7 @@ | |
from osbenchmark.utils import convert | ||
from osbenchmark.client import RequestContextHolder | ||
# Mapping from operation type to specific runner | ||
from osbenchmark.utils.parse import parse_int_parameter, parse_string_parameter | ||
from osbenchmark.utils.parse import parse_int_parameter, parse_string_parameter, parse_float_parameter | ||
|
||
__RUNNERS = {} | ||
|
||
|
@@ -105,7 +105,8 @@ def register_default_runners(): | |
register_runner(workload.OperationType.DeleteMlModel, Retry(DeleteMlModel()), async_runner=True) | ||
register_runner(workload.OperationType.RegisterMlModel, Retry(RegisterMlModel()), async_runner=True) | ||
register_runner(workload.OperationType.DeployMlModel, Retry(DeployMlModel()), async_runner=True) | ||
|
||
register_runner(workload.OperationType.TrainKnnModel, Retry(TrainKnnModel()), async_runner=True) | ||
register_runner(workload.OperationType.DeleteKnnModel, Retry(DeleteKnnModel()), async_runner=True) | ||
|
||
def runner_for(operation_type): | ||
try: | ||
|
@@ -652,6 +653,184 @@ def __repr__(self, *args, **kwargs): | |
return "bulk-index" | ||
|
||
|
||
class DeleteKnnModel(Runner):
    """
    Deletes the K-NN model named model_id.
    """

    NAME = "delete-knn-model"
    MODEL_DOES_NOT_EXIST_STATUS_CODE = 404

    async def __call__(self, opensearch, params):
        """
        Delete one K-NN model through the k-NN plugin's model endpoint.

        :param opensearch: The OpenSearch client.
        :param params: A hash with all parameters. See below for details.
        :return: A hash with the keys ``weight``, ``unit`` and ``success``. ``success`` is True
                 when the model was deleted, or when it did not exist and
                 ``ignore-if-model-does-not-exist`` is truthy; otherwise it is False.

        It expects a parameter dict with the following keys:

        * ``model_id``: ID of the model to delete (read without a default value).
        * ``ignore-if-model-does-not-exist``: Optional, defaults to False. When truthy, a 404
          response (model not found) is reported as a success.
        """
        model_id = parse_string_parameter("model_id", params)
        ignore_if_model_does_not_exist = params.get(
            "ignore-if-model-does-not-exist", False
        )

        method = "DELETE"
        model_uri = f"/_plugins/_knn/models/{model_id}"

        request_context_holder.on_client_request_start()
        try:
            # 404 indicates the model has not been created. In that case, the runner's
            # response depends on ignore_if_model_does_not_exist.
            response = await opensearch.transport.perform_request(
                method,
                model_uri,
                params={"ignore": [self.MODEL_DOES_NOT_EXIST_STATUS_CODE]},
            )
        finally:
            # Record the end of the client request even when the transport call raises.
            request_context_holder.on_client_request_end()

        # success condition.
        if response.get("result") == "deleted":
            self.logger.debug("Model [%s] deleted successfully.", model_id)
            return {"weight": 1, "unit": "ops", "success": True}

        if "error" not in response:
            self.logger.warning(
                "Request to delete model [%s] failed but no error, response: [%s]",
                model_id,
                response,
            )
            return {"weight": 1, "unit": "ops", "success": False}

        # An error body without a "status" key is reported as a failure instead of
        # raising KeyError.
        status = response.get("status")
        if status != self.MODEL_DOES_NOT_EXIST_STATUS_CODE:
            self.logger.warning(
                "Request to delete model [%s] failed with status [%s] and response: [%s]",
                model_id,
                status,
                response,
            )
            return {"weight": 1, "unit": "ops", "success": False}

        # From here on the response is a 404: the model does not exist.
        if ignore_if_model_does_not_exist:
            self.logger.debug(
                (
                    "Model [%s] does not exist so it could not be deleted, "
                    "however ignore-if-model-does-not-exist is True so the "
                    "DeleteKnnModel operation succeeded."
                ),
                model_id,
            )
            return {"weight": 1, "unit": "ops", "success": True}

        self.logger.warning(
            (
                "Request to delete model [%s] failed because the model does not exist "
                "and ignore-if-model-does-not-exist was set to False. Response: [%s]"
            ),
            model_id,
            response,
        )
        return {"weight": 1, "unit": "ops", "success": False}

    def __repr__(self, *args, **kwargs):
        return self.NAME
|
||
|
||
class TrainKnnModel(Runner):
    """
    Trains model named model_id until training is complete or retries are exhausted.
    """

    NAME = "train-knn-model"
    DEFAULT_RETRIES = 1000
    DEFAULT_POLL_PERIOD = 0.5

    async def __call__(self, opensearch, params):
        """
        Create and train one model named model_id.

        :param opensearch: The OpenSearch client.
        :param params: A hash with all parameters. See below for details.
        :return: None once the model has reached the ``created`` state.
        :raises TimeoutError: if the model is still training after the allowed number of retries.
        :raises Exception: if training fails, the model response has no ``state``,
                 or the model is in an unknown state.

        It expects a parameter dict with the following keys:

        * ``body``: Mandatory. Parameters to pass on to the train engine.
          See https://opensearch.org/docs/latest/search-plugins/knn/api/#train-a-model for information.
        * ``model_id``: ID of the model to train (read without a default value).
        * ``retries``: Optional, defaults to ``DEFAULT_RETRIES``. Maximum number of times the
          model state is re-polled while it is still training.
        * ``poll_period``: Optional, defaults to ``DEFAULT_POLL_PERIOD``. Time to wait between
          two polls of the model state (seconds).
        """
        body = params["body"]
        model_id = parse_string_parameter("model_id", params)
        max_retries = parse_int_parameter("retries", params, self.DEFAULT_RETRIES)
        poll_period = parse_float_parameter(
            "poll_period", params, self.DEFAULT_POLL_PERIOD
        )

        method = "POST"
        model_uri = f"/_plugins/_knn/models/{model_id}"

        # Client-side timing starts before the train request, so the measured duration
        # covers both the request and the polling that follows it.
        request_context_holder.on_client_request_start()
        try:
            await opensearch.transport.perform_request(
                method, f"{model_uri}/_train", body=body
            )

            current_number_retries = 0
            while True:
                model_response = await opensearch.transport.perform_request(
                    "GET", model_uri
                )

                if "state" not in model_response:
                    self.logger.error(
                        "Failed to create model [%s] with error response: [%s]",
                        model_id,
                        model_response,
                    )
                    raise Exception(
                        f"Failed to create model {model_id} with error response: {model_response}"
                    )

                state = model_response["state"]

                # Terminal states are evaluated before the retry budget so that a model
                # which finishes on the very last poll is not reported as a timeout.
                if state == "created":
                    self.logger.info(
                        "Training model [%s] was completed successfully.", model_id
                    )
                    return

                if state == "failed":
                    self.logger.error(
                        "Training for model [%s] failed. Response: [%s]",
                        model_id,
                        model_response,
                    )
                    raise Exception(f"Failed to create model {model_id}: {model_response}")

                # Anything other than training/created/failed is treated as an error, so a
                # state newly introduced by the k-NN plugin is surfaced instead of ignored.
                if state != "training":
                    self.logger.error(
                        "Model [%s] in unknown state [%s], response: [%s]",
                        model_id,
                        state,
                        model_response,
                    )
                    raise Exception(
                        f"Model {model_id} in unknown state {state}, response: {model_response}"
                    )

                # Still training: give up once the retry budget is exhausted.
                if current_number_retries > max_retries:
                    self.logger.error(
                        "Failed to create model [%s] within [%i] retries.",
                        model_id,
                        max_retries,
                    )
                    raise TimeoutError(
                        f"Failed to create model: {model_id} within {max_retries} retries"
                    )

                current_number_retries += 1
                await asyncio.sleep(poll_period)
        finally:
            # Always record the end of the client request, including when a transport
            # call or one of the checks above raises.
            request_context_holder.on_client_request_end()

    def __repr__(self, *args, **kwargs):
        return self.NAME
|
||
|
||
# TODO: Add retry logic to BulkIndex, so that we can remove BulkVectorDataSet and use BulkIndex. | ||
class BulkVectorDataSet(Runner): | ||
""" | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Did you check whether we can use mandatory?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes, these helpers treat a parameter as mandatory when no third (default) argument is passed, so if model_id is not provided, parse_string_parameter will throw an error.