Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add TrainKNN Runner/Operation for Benchmarking Approximate KNN Algorithms #556

Merged
merged 9 commits into from
Jul 18, 2024
183 changes: 181 additions & 2 deletions osbenchmark/worker_coordinator/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@
from osbenchmark.utils import convert
from osbenchmark.client import RequestContextHolder
# Mapping from operation type to specific runner
from osbenchmark.utils.parse import parse_int_parameter, parse_string_parameter
from osbenchmark.utils.parse import parse_int_parameter, parse_string_parameter, parse_float_parameter

__RUNNERS = {}

Expand Down Expand Up @@ -105,7 +105,8 @@ def register_default_runners():
register_runner(workload.OperationType.DeleteMlModel, Retry(DeleteMlModel()), async_runner=True)
register_runner(workload.OperationType.RegisterMlModel, Retry(RegisterMlModel()), async_runner=True)
register_runner(workload.OperationType.DeployMlModel, Retry(DeployMlModel()), async_runner=True)

register_runner(workload.OperationType.TrainKnnModel, Retry(TrainKnnModel()), async_runner=True)
register_runner(workload.OperationType.DeleteKnnModel, Retry(DeleteKnnModel()), async_runner=True)

def runner_for(operation_type):
try:
Expand Down Expand Up @@ -652,6 +653,184 @@ def __repr__(self, *args, **kwargs):
return "bulk-index"


class DeleteKnnModel(Runner):
    """
    Deletes the K-NN model named model_id.

    It expects a parameter dict with the following keys:

    * ``model_id`` (mandatory): ID of the model to delete.
    * ``ignore-if-model-does-not-exist`` (optional, defaults to ``False``): when truthy, a 404
      response (the model does not exist) is reported as a success instead of a failure.
    """

    NAME = "delete-knn-model"
    MODEL_DOES_NOT_EXIST_STATUS_CODE = 404

    @staticmethod
    def _result(success):
        # Every outcome of this runner is reported as exactly one operation.
        return {"weight": 1, "unit": "ops", "success": success}

    async def __call__(self, opensearch, params):
        """
        Delete one model and translate the response into a result hash.

        :param opensearch: The OpenSearch client.
        :param params: A hash with all parameters. See the class documentation for details.
        :return: A hash with ``weight`` (always 1), ``unit`` (always "ops") and ``success``.
        """
        # No default is passed, so the helper treats model_id as mandatory and
        # raises an error when it is not provided.
        model_id = parse_string_parameter("model_id", params)
        ignore_if_model_does_not_exist = params.get(
            "ignore-if-model-does-not-exist", False
        )

        model_uri = f"/_plugins/_knn/models/{model_id}"

        request_context_holder.on_client_request_start()
        # 404 indicates the model has not been created. The status is passed in "ignore" so the
        # response can be inspected below; the outcome then depends on ignore_if_model_does_not_exist.
        response = await opensearch.transport.perform_request(
            "DELETE",
            model_uri,
            params={"ignore": [self.MODEL_DOES_NOT_EXIST_STATUS_CODE]},
        )
        request_context_holder.on_client_request_end()

        # success condition.
        if response.get("result") == "deleted":
            self.logger.debug("Model [%s] deleted successfully.", model_id)
            return self._result(True)

        if "error" not in response:
            self.logger.warning(
                "Request to delete model [%s] failed but no error, response: [%s]",
                model_id,
                response,
            )
            return self._result(False)

        # An error payload without a "status" key is treated like any other
        # non-404 failure instead of raising a KeyError.
        status = response.get("status")
        if status != self.MODEL_DOES_NOT_EXIST_STATUS_CODE:
            self.logger.warning(
                "Request to delete model [%s] failed with status [%s] and response: [%s]",
                model_id,
                status,
                response,
            )
            return self._result(False)

        # From here on the response is a 404: the model does not exist.
        if ignore_if_model_does_not_exist:
            self.logger.debug(
                (
                    "Model [%s] does not exist so it could not be deleted, "
                    "however ignore-if-model-does-not-exist is True so the "
                    "DeleteKnnModel operation succeeded."
                ),
                model_id,
            )
            return self._result(True)

        self.logger.warning(
            (
                "Request to delete model [%s] failed because the model does not exist "
                "and ignore-if-model-does-not-exist was set to False. Response: [%s]"
            ),
            model_id,
            response,
        )
        return self._result(False)

    def __repr__(self, *args, **kwargs):
        return self.NAME


class TrainKnnModel(Runner):
    """
    Trains model named model_id until training is complete or retries are exhausted.
    """

    NAME = "train-knn-model"
    DEFAULT_RETRIES = 1000
    DEFAULT_POLL_PERIOD = 0.5

    async def __call__(self, opensearch, params):
        """
        Create and train one model named model_id.

        :param opensearch: The OpenSearch client.
        :param params: A hash with all parameters. See below for details.
        :return: ``None`` once the model has reached the ``created`` state.
        :raises: TimeoutError if the model is still training after the allowed number of retries.
                 Exception if training fails, the model response has no state, or the state is unknown.

        It expects a parameter dict with the following mandatory keys:

        * ``body``: containing parameters to pass on to the train engine.
                    See https://opensearch.org/docs/latest/search-plugins/knn/api/#train-a-model for information.
        * ``model_id``: ID of the model to train.

        and the following optional keys:

        * ``retries``: Maximum number of times the model state is polled again while the model is
                       still training. Defaults to ``DEFAULT_RETRIES``.
        * ``poll_period``: Time to wait between two polls of the model state (seconds).
                           Defaults to ``DEFAULT_POLL_PERIOD``.
        """
        body = params["body"]
        model_id = parse_string_parameter("model_id", params)
        max_retries = parse_int_parameter("retries", params, self.DEFAULT_RETRIES)
        poll_period = parse_float_parameter(
            "poll_period", params, self.DEFAULT_POLL_PERIOD
        )

        model_uri = f"/_plugins/_knn/models/{model_id}"

        # The client-side timer covers the train request and all polling that follows it.
        request_context_holder.on_client_request_start()
        await opensearch.transport.perform_request(
            "POST", f"{model_uri}/_train", body=body
        )

        current_number_retries = 0
        while True:
            model_response = await opensearch.transport.perform_request(
                "GET", model_uri
            )

            if "state" not in model_response:
                request_context_holder.on_client_request_end()
                self.logger.error(
                    "Failed to create model [%s] with error response: [%s]",
                    model_id,
                    model_response,
                )
                raise Exception(
                    f"Failed to create model {model_id} with error response: {model_response}"
                )

            state = model_response["state"]
            if state != "training":
                # training either failed or finished (or reached an unknown state).
                break

            # The retry budget is checked only after the latest state has been evaluated and
            # before sleeping, so no poll result is ever thrown away and no extra request is
            # sent once the budget is spent.
            if current_number_retries >= max_retries:
                request_context_holder.on_client_request_end()
                self.logger.error(
                    "Failed to create model [%s] within [%i] retries.",
                    model_id,
                    max_retries,
                )
                raise TimeoutError(
                    f"Failed to create model: {model_id} within {max_retries} retries"
                )

            current_number_retries += 1
            await asyncio.sleep(poll_period)

        request_context_holder.on_client_request_end()

        if state == "created":
            self.logger.info(
                "Training model [%s] was completed successfully.", model_id
            )
            return

        if state == "failed":
            self.logger.error(
                "Training for model [%s] failed. Response: [%s]",
                model_id,
                model_response,
            )
            raise Exception(f"Failed to create model {model_id}: {model_response}")

        # Any state other than training/created/failed is rejected explicitly so that a state
        # this runner does not know about is not mistaken for success.
        self.logger.error(
            "Model [%s] in unknown state [%s], response: [%s]",
            model_id,
            state,
            model_response,
        )
        raise Exception(
            f"Model {model_id} in unknown state {state}, response: {model_response}"
        )

    def __repr__(self, *args, **kwargs):
        return self.NAME


# TODO: Add retry logic to BulkIndex, so that we can remove BulkVectorDataSet and use BulkIndex.
class BulkVectorDataSet(Runner):
"""
Expand Down
6 changes: 6 additions & 0 deletions osbenchmark/workload/workload.py
Original file line number Diff line number Diff line change
Expand Up @@ -601,6 +601,8 @@ class OperationType(Enum):
ListAllPointInTime = 16
VectorSearch = 17
BulkVectorDataSet = 18
TrainKnnModel = 19
DeleteKnnModel = 20

# administrative actions
ForceMerge = 1001
Expand Down Expand Up @@ -746,6 +748,10 @@ def from_hyphenated_string(cls, v):
return OperationType.RegisterMlModel
elif v == "deploy-ml-model":
return OperationType.DeployMlModel
elif v == "train-knn-model":
return OperationType.TrainKnnModel
elif v == "delete-knn-model":
return OperationType.DeleteKnnModel
else:
raise KeyError(f"No enum value for [{v}]")

Expand Down
Loading
Loading