Skip to content

Commit

Permalink
add more method to client
Browse files Browse the repository at this point in the history
Signed-off-by: Yaliang Wu <[email protected]>
  • Loading branch information
ylwu-amzn committed Dec 17, 2023
1 parent 0335a8f commit 228b0da
Show file tree
Hide file tree
Showing 4 changed files with 270 additions and 171 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -11,22 +11,23 @@
import org.opensearch.action.search.SearchRequest;
import org.opensearch.action.search.SearchResponse;
import org.opensearch.action.support.PlainActionFuture;
import org.opensearch.common.Nullable;
import org.opensearch.common.action.ActionFuture;
import org.opensearch.core.action.ActionListener;
import org.opensearch.ml.common.FunctionName;
import org.opensearch.ml.common.MLModel;
import org.opensearch.ml.common.MLTask;
import org.opensearch.ml.common.input.Input;
import org.opensearch.ml.common.agent.MLAgent;
import org.opensearch.ml.common.input.MLInput;
import org.opensearch.ml.common.output.MLOutput;
import org.opensearch.ml.common.transport.agent.MLRegisterAgentResponse;
import org.opensearch.ml.common.transport.connector.MLCreateConnectorInput;
import org.opensearch.ml.common.transport.connector.MLCreateConnectorResponse;
import org.opensearch.ml.common.transport.deploy.MLDeployModelResponse;
import org.opensearch.ml.common.transport.execute.MLExecuteTaskResponse;
import org.opensearch.ml.common.transport.model_group.MLRegisterModelGroupInput;
import org.opensearch.ml.common.transport.model_group.MLRegisterModelGroupResponse;
import org.opensearch.ml.common.transport.register.MLRegisterModelInput;
import org.opensearch.ml.common.transport.register.MLRegisterModelResponse;
import org.opensearch.ml.common.transport.undeploy.MLUndeployModelsResponse;

/**
* A client to provide interfaces for machine learning jobs. This will be used by other plugins.
Expand Down Expand Up @@ -254,7 +255,7 @@ default ActionFuture<MLRegisterModelResponse> register(MLRegisterModelInput mlIn

/**
* Deploy model
* For additional info on deploy, refer: https://opensearch.org/docs/latest/ml-commons-plugin/api/#deploying-a-model
* For additional info on deploy, refer: https://opensearch.org/docs/latest/ml-commons-plugin/api/model-apis/deploy-model/
* @param modelId the model id
*/
default ActionFuture<MLDeployModelResponse> deploy(String modelId) {
Expand All @@ -265,12 +266,33 @@ default ActionFuture<MLDeployModelResponse> deploy(String modelId) {

/**
* Deploy model
* For additional info on deploy, refer: https://opensearch.org/docs/latest/ml-commons-plugin/api/#deploying-a-model
* For additional info on deploy, refer: https://opensearch.org/docs/latest/ml-commons-plugin/api/model-apis/deploy-model/
* @param modelId the model id
* @param listener a listener to be notified of the result
*/
void deploy(String modelId, ActionListener<MLDeployModelResponse> listener);

/**
* Undeploy models
* For additional info on undeploy, refer: https://opensearch.org/docs/latest/ml-commons-plugin/api/model-apis/undeploy-model/
* @param modelIds the model ids
* @param nodeIds the node ids. May be null for all nodes.
*/
default ActionFuture<MLUndeployModelsResponse> undeploy(String[] modelIds, @Nullable String[] nodeIds) {
PlainActionFuture<MLUndeployModelsResponse> actionFuture = PlainActionFuture.newFuture();
undeploy(modelIds, nodeIds, actionFuture);
return actionFuture;
}

/**
* Undeploy model
* For additional info on deploy, refer: https://opensearch.org/docs/latest/ml-commons-plugin/api/model-apis/undeploy-model/
* @param modelIds the model ids
* @param modelIds the node ids. May be null for all nodes.
* @param listener a listener to be notified of the result
*/
void undeploy(String[] modelIds, String[] nodeIds, ActionListener<MLUndeployModelsResponse> listener);

/**
* Create connector for remote model
* @param mlCreateConnectorInput Create Connector Input, refer: https://opensearch.org/docs/latest/ml-commons-plugin/extensibility/connectors/
Expand All @@ -284,6 +306,19 @@ default ActionFuture<MLCreateConnectorResponse> createConnector(MLCreateConnecto

void createConnector(MLCreateConnectorInput mlCreateConnectorInput, ActionListener<MLCreateConnectorResponse> listener);

/**
* Delete connector for remote model
* @param connectorId The id of the connector to delete
* @return the result future
*/
default ActionFuture<DeleteResponse> deleteConnector(String connectorId) {
PlainActionFuture<DeleteResponse> actionFuture = PlainActionFuture.newFuture();
deleteConnector(connectorId, actionFuture);
return actionFuture;
}

void deleteConnector(String connectorId, ActionListener<DeleteResponse> listener);

/**
* Register model group
* For additional info on model group, refer: https://opensearch.org/docs/latest/ml-commons-plugin/model-access-control#registering-a-model-group
Expand All @@ -304,21 +339,32 @@ default ActionFuture<MLRegisterModelGroupResponse> registerModelGroup(MLRegister
void registerModelGroup(MLRegisterModelGroupInput mlRegisterModelGroupInput, ActionListener<MLRegisterModelGroupResponse> listener);

/**
* Execute an algorithm
* @param name algorithm function name
* @param input input
* Registers new agent and returns ActionFuture.
* @param mlAgent Register agent input, refer: https://opensearch.org/docs/latest/ml-commons-plugin/api/#register-agent
* @return the result future
*/
default ActionFuture<MLExecuteTaskResponse> execute(FunctionName name, Input input) {
PlainActionFuture<MLExecuteTaskResponse> actionFuture = PlainActionFuture.newFuture();
execute(name, input, actionFuture);
default ActionFuture<MLRegisterAgentResponse> registerAgent(MLAgent mlAgent) {
PlainActionFuture<MLRegisterAgentResponse> actionFuture = PlainActionFuture.newFuture();
registerAgent(mlAgent, actionFuture);
return actionFuture;
}

/**
* Execute an algorithm
* @param input an algorithm input
* @param listener a listener to be notified of the result
* Registers new agent and returns agent ID in response
* @param mlAgent Register agent input, refer: https://opensearch.org/docs/latest/ml-commons-plugin/api/#register-agent
*/
void execute(FunctionName name, Input input, ActionListener<MLExecuteTaskResponse> listener);
void registerAgent(MLAgent mlAgent, ActionListener<MLRegisterAgentResponse> listener);

/**
* Delete agent
* @param agentId The id of the agent to delete
* @return the result future
*/
default ActionFuture<DeleteResponse> deleteAgent(String agentId) {
PlainActionFuture<DeleteResponse> actionFuture = PlainActionFuture.newFuture();
deleteAgent(agentId, actionFuture);
return actionFuture;
}

void deleteAgent(String agentId, ActionListener<DeleteResponse> listener);
}
Original file line number Diff line number Diff line change
Expand Up @@ -26,21 +26,25 @@
import org.opensearch.ml.common.FunctionName;
import org.opensearch.ml.common.MLModel;
import org.opensearch.ml.common.MLTask;
import org.opensearch.ml.common.input.Input;
import org.opensearch.ml.common.agent.MLAgent;
import org.opensearch.ml.common.input.MLInput;
import org.opensearch.ml.common.input.parameter.MLAlgoParams;
import org.opensearch.ml.common.output.MLOutput;
import org.opensearch.ml.common.transport.MLTaskResponse;
import org.opensearch.ml.common.transport.agent.MLAgentDeleteAction;
import org.opensearch.ml.common.transport.agent.MLAgentDeleteRequest;
import org.opensearch.ml.common.transport.agent.MLRegisterAgentAction;
import org.opensearch.ml.common.transport.agent.MLRegisterAgentRequest;
import org.opensearch.ml.common.transport.agent.MLRegisterAgentResponse;
import org.opensearch.ml.common.transport.connector.MLConnectorDeleteAction;
import org.opensearch.ml.common.transport.connector.MLConnectorDeleteRequest;
import org.opensearch.ml.common.transport.connector.MLCreateConnectorAction;
import org.opensearch.ml.common.transport.connector.MLCreateConnectorInput;
import org.opensearch.ml.common.transport.connector.MLCreateConnectorRequest;
import org.opensearch.ml.common.transport.connector.MLCreateConnectorResponse;
import org.opensearch.ml.common.transport.deploy.MLDeployModelAction;
import org.opensearch.ml.common.transport.deploy.MLDeployModelRequest;
import org.opensearch.ml.common.transport.deploy.MLDeployModelResponse;
import org.opensearch.ml.common.transport.execute.MLExecuteTaskAction;
import org.opensearch.ml.common.transport.execute.MLExecuteTaskRequest;
import org.opensearch.ml.common.transport.execute.MLExecuteTaskResponse;
import org.opensearch.ml.common.transport.model.MLModelDeleteAction;
import org.opensearch.ml.common.transport.model.MLModelDeleteRequest;
import org.opensearch.ml.common.transport.model.MLModelGetAction;
Expand All @@ -66,6 +70,9 @@
import org.opensearch.ml.common.transport.training.MLTrainingTaskAction;
import org.opensearch.ml.common.transport.training.MLTrainingTaskRequest;
import org.opensearch.ml.common.transport.trainpredict.MLTrainAndPredictionTaskAction;
import org.opensearch.ml.common.transport.undeploy.MLUndeployModelsAction;
import org.opensearch.ml.common.transport.undeploy.MLUndeployModelsRequest;
import org.opensearch.ml.common.transport.undeploy.MLUndeployModelsResponse;

import lombok.AccessLevel;
import lombok.RequiredArgsConstructor;
Expand Down Expand Up @@ -191,19 +198,6 @@ public void registerModelGroup(
);
}

/**
* Execute an algorithm
*
* @param name function name
* @param input an algorithm input
* @param listener a listener to be notified of the result
*/
@Override
public void execute(FunctionName name, Input input, ActionListener<MLExecuteTaskResponse> listener) {
MLExecuteTaskRequest mlExecuteTaskRequest = new MLExecuteTaskRequest(name, input);
client.execute(MLExecuteTaskAction.INSTANCE, mlExecuteTaskRequest, listener);
}

@Override
public void getTask(String taskId, ActionListener<MLTask> listener) {
MLTaskGetRequest mlTaskGetRequest = MLTaskGetRequest.builder().taskId(taskId).build();
Expand Down Expand Up @@ -242,12 +236,50 @@ public void deploy(String modelId, ActionListener<MLDeployModelResponse> listene
client.execute(MLDeployModelAction.INSTANCE, deployModelRequest, getMlDeployModelResponseActionListener(listener));
}

@Override
public void undeploy(String[] modelIds, String[] nodeIds, ActionListener<MLUndeployModelsResponse> listener) {
MLUndeployModelsRequest undeployModelRequest = new MLUndeployModelsRequest(modelIds, nodeIds);
client.execute(MLUndeployModelsAction.INSTANCE, undeployModelRequest, getMlUndeployModelsResponseActionListener(listener));
}

@Override
public void createConnector(MLCreateConnectorInput mlCreateConnectorInput, ActionListener<MLCreateConnectorResponse> listener) {
MLCreateConnectorRequest createConnectorRequest = new MLCreateConnectorRequest(mlCreateConnectorInput);
client.execute(MLCreateConnectorAction.INSTANCE, createConnectorRequest, getMlCreateConnectorResponseActionListener(listener));
}

@Override
public void deleteConnector(String connectorId, ActionListener<DeleteResponse> listener) {
MLConnectorDeleteRequest connectorDeleteRequest = new MLConnectorDeleteRequest(connectorId);
client.execute(MLConnectorDeleteAction.INSTANCE, connectorDeleteRequest, ActionListener.wrap(deleteResponse -> {
listener.onResponse(deleteResponse);
}, listener::onFailure));
}

@Override
public void registerAgent(MLAgent mlAgent, ActionListener<MLRegisterAgentResponse> listener) {
MLRegisterAgentRequest mlRegisterAgentRequest = MLRegisterAgentRequest.builder().mlAgent(mlAgent).build();
client.execute(MLRegisterAgentAction.INSTANCE, mlRegisterAgentRequest, getMLRegisterAgentResponseActionListener(listener));
}

@Override
public void deleteAgent(String agentId, ActionListener<DeleteResponse> listener) {
MLAgentDeleteRequest agentDeleteRequest = new MLAgentDeleteRequest(agentId);
client.execute(MLAgentDeleteAction.INSTANCE, agentDeleteRequest, ActionListener.wrap(deleteResponse -> {
listener.onResponse(deleteResponse);
}, listener::onFailure));
}

private ActionListener<MLRegisterAgentResponse> getMLRegisterAgentResponseActionListener(
ActionListener<MLRegisterAgentResponse> listener
) {
ActionListener<MLRegisterAgentResponse> actionListener = wrapActionListener(listener, res -> {
MLRegisterAgentResponse mlRegisterAgentResponse = MLRegisterAgentResponse.fromActionResponse(res);
return mlRegisterAgentResponse;
});
return actionListener;
}

private ActionListener<MLTaskGetResponse> getMLTaskResponseActionListener(ActionListener<MLTask> listener) {
ActionListener<MLTaskGetResponse> internalListener = ActionListener
.wrap(getResponse -> { listener.onResponse(getResponse.getMlTask()); }, listener::onFailure);
Expand All @@ -266,6 +298,16 @@ private ActionListener<MLDeployModelResponse> getMlDeployModelResponseActionList
return actionListener;
}

private ActionListener<MLUndeployModelsResponse> getMlUndeployModelsResponseActionListener(
ActionListener<MLUndeployModelsResponse> listener
) {
ActionListener<MLUndeployModelsResponse> actionListener = wrapActionListener(listener, response -> {
MLUndeployModelsResponse deployModelResponse = MLUndeployModelsResponse.fromActionResponse(response);
return deployModelResponse;
});
return actionListener;
}

private ActionListener<MLCreateConnectorResponse> getMlCreateConnectorResponseActionListener(
ActionListener<MLCreateConnectorResponse> listener
) {
Expand Down
Loading

0 comments on commit 228b0da

Please sign in to comment.