From 228b0daaa809f400731ea9165977cdb01307456c Mon Sep 17 00:00:00 2001 From: Yaliang Wu Date: Sun, 17 Dec 2023 11:39:01 -0800 Subject: [PATCH] add more method to client Signed-off-by: Yaliang Wu --- .../ml/client/MachineLearningClient.java | 76 +++++-- .../ml/client/MachineLearningNodeClient.java | 76 +++++-- .../ml/client/MachineLearningClientTest.java | 100 ++++----- .../client/MachineLearningNodeClientTest.java | 189 ++++++++++-------- 4 files changed, 270 insertions(+), 171 deletions(-) diff --git a/client/src/main/java/org/opensearch/ml/client/MachineLearningClient.java b/client/src/main/java/org/opensearch/ml/client/MachineLearningClient.java index 55c3178ad8..da1498f7de 100644 --- a/client/src/main/java/org/opensearch/ml/client/MachineLearningClient.java +++ b/client/src/main/java/org/opensearch/ml/client/MachineLearningClient.java @@ -11,22 +11,23 @@ import org.opensearch.action.search.SearchRequest; import org.opensearch.action.search.SearchResponse; import org.opensearch.action.support.PlainActionFuture; +import org.opensearch.common.Nullable; import org.opensearch.common.action.ActionFuture; import org.opensearch.core.action.ActionListener; -import org.opensearch.ml.common.FunctionName; import org.opensearch.ml.common.MLModel; import org.opensearch.ml.common.MLTask; -import org.opensearch.ml.common.input.Input; +import org.opensearch.ml.common.agent.MLAgent; import org.opensearch.ml.common.input.MLInput; import org.opensearch.ml.common.output.MLOutput; +import org.opensearch.ml.common.transport.agent.MLRegisterAgentResponse; import org.opensearch.ml.common.transport.connector.MLCreateConnectorInput; import org.opensearch.ml.common.transport.connector.MLCreateConnectorResponse; import org.opensearch.ml.common.transport.deploy.MLDeployModelResponse; -import org.opensearch.ml.common.transport.execute.MLExecuteTaskResponse; import org.opensearch.ml.common.transport.model_group.MLRegisterModelGroupInput; import org.opensearch.ml.common.transport.model_group.MLRegisterModelGroupResponse; import org.opensearch.ml.common.transport.register.MLRegisterModelInput; import org.opensearch.ml.common.transport.register.MLRegisterModelResponse; +import org.opensearch.ml.common.transport.undeploy.MLUndeployModelsResponse; /** * A client to provide interfaces for machine learning jobs. This will be used by other plugins. @@ -254,7 +255,7 @@ default ActionFuture register(MLRegisterModelInput mlIn /** * Deploy model - * For additional info on deploy, refer: https://opensearch.org/docs/latest/ml-commons-plugin/api/#deploying-a-model + * For additional info on deploy, refer: https://opensearch.org/docs/latest/ml-commons-plugin/api/model-apis/deploy-model/ * @param modelId the model id */ default ActionFuture deploy(String modelId) { @@ -265,12 +266,33 @@ default ActionFuture deploy(String modelId) { /** * Deploy model - * For additional info on deploy, refer: https://opensearch.org/docs/latest/ml-commons-plugin/api/#deploying-a-model + * For additional info on deploy, refer: https://opensearch.org/docs/latest/ml-commons-plugin/api/model-apis/deploy-model/ * @param modelId the model id * @param listener a listener to be notified of the result */ void deploy(String modelId, ActionListener listener); + /** + * Undeploy models + * For additional info on undeploy, refer: https://opensearch.org/docs/latest/ml-commons-plugin/api/model-apis/undeploy-model/ + * @param modelIds the model ids + * @param nodeIds the node ids. May be null for all nodes. + */ + default ActionFuture undeploy(String[] modelIds, @Nullable String[] nodeIds) { + PlainActionFuture actionFuture = PlainActionFuture.newFuture(); + undeploy(modelIds, nodeIds, actionFuture); + return actionFuture; + } + + /** + * Undeploy model + * For additional info on deploy, refer: https://opensearch.org/docs/latest/ml-commons-plugin/api/model-apis/undeploy-model/ + * @param modelIds the model ids + * @param modelIds the node ids. May be null for all nodes. + * @param listener a listener to be notified of the result + */ + void undeploy(String[] modelIds, String[] nodeIds, ActionListener listener); + /** * Create connector for remote model * @param mlCreateConnectorInput Create Connector Input, refer: https://opensearch.org/docs/latest/ml-commons-plugin/extensibility/connectors/ @@ -284,6 +306,19 @@ default ActionFuture createConnector(MLCreateConnecto void createConnector(MLCreateConnectorInput mlCreateConnectorInput, ActionListener listener); + /** + * Delete connector for remote model + * @param connectorId The id of the connector to delete + * @return the result future + */ + default ActionFuture deleteConnector(String connectorId) { + PlainActionFuture actionFuture = PlainActionFuture.newFuture(); + deleteConnector(connectorId, actionFuture); + return actionFuture; + } + + void deleteConnector(String connectorId, ActionListener listener); + /** * Register model group * For additional info on model group, refer: https://opensearch.org/docs/latest/ml-commons-plugin/model-access-control#registering-a-model-group @@ -304,21 +339,32 @@ default ActionFuture registerModelGroup(MLRegister void registerModelGroup(MLRegisterModelGroupInput mlRegisterModelGroupInput, ActionListener listener); /** - * Execute an algorithm - * @param name algorithm function name - * @param input input + * Registers new agent and returns ActionFuture. + * @param mlAgent Register agent input, refer: https://opensearch.org/docs/latest/ml-commons-plugin/api/#register-agent * @return the result future */ - default ActionFuture execute(FunctionName name, Input input) { - PlainActionFuture actionFuture = PlainActionFuture.newFuture(); - execute(name, input, actionFuture); + default ActionFuture registerAgent(MLAgent mlAgent) { + PlainActionFuture actionFuture = PlainActionFuture.newFuture(); + registerAgent(mlAgent, actionFuture); return actionFuture; } /** - * Execute an algorithm - * @param input an algorithm input - * @param listener a listener to be notified of the result + * Registers new agent and returns agent ID in response + * @param mlAgent Register agent input, refer: https://opensearch.org/docs/latest/ml-commons-plugin/api/#register-agent */ - void execute(FunctionName name, Input input, ActionListener listener); + void registerAgent(MLAgent mlAgent, ActionListener listener); + + /** + * Delete agent + * @param agentId The id of the agent to delete + * @return the result future + */ + default ActionFuture deleteAgent(String agentId) { + PlainActionFuture actionFuture = PlainActionFuture.newFuture(); + deleteAgent(agentId, actionFuture); + return actionFuture; + } + + void deleteAgent(String agentId, ActionListener listener); } diff --git a/client/src/main/java/org/opensearch/ml/client/MachineLearningNodeClient.java b/client/src/main/java/org/opensearch/ml/client/MachineLearningNodeClient.java index 1c66a94403..91df04e4fc 100644 --- a/client/src/main/java/org/opensearch/ml/client/MachineLearningNodeClient.java +++ b/client/src/main/java/org/opensearch/ml/client/MachineLearningNodeClient.java @@ -26,11 +26,18 @@ import org.opensearch.ml.common.FunctionName; import org.opensearch.ml.common.MLModel; import org.opensearch.ml.common.MLTask; -import org.opensearch.ml.common.input.Input; +import org.opensearch.ml.common.agent.MLAgent; import org.opensearch.ml.common.input.MLInput; import org.opensearch.ml.common.input.parameter.MLAlgoParams; import org.opensearch.ml.common.output.MLOutput; import org.opensearch.ml.common.transport.MLTaskResponse; +import org.opensearch.ml.common.transport.agent.MLAgentDeleteAction; +import org.opensearch.ml.common.transport.agent.MLAgentDeleteRequest; +import org.opensearch.ml.common.transport.agent.MLRegisterAgentAction; +import org.opensearch.ml.common.transport.agent.MLRegisterAgentRequest; +import org.opensearch.ml.common.transport.agent.MLRegisterAgentResponse; +import org.opensearch.ml.common.transport.connector.MLConnectorDeleteAction; +import org.opensearch.ml.common.transport.connector.MLConnectorDeleteRequest; import org.opensearch.ml.common.transport.connector.MLCreateConnectorAction; import org.opensearch.ml.common.transport.connector.MLCreateConnectorInput; import org.opensearch.ml.common.transport.connector.MLCreateConnectorRequest; @@ -38,9 +45,6 @@ import org.opensearch.ml.common.transport.deploy.MLDeployModelAction; import org.opensearch.ml.common.transport.deploy.MLDeployModelRequest; import org.opensearch.ml.common.transport.deploy.MLDeployModelResponse; -import org.opensearch.ml.common.transport.execute.MLExecuteTaskAction; -import org.opensearch.ml.common.transport.execute.MLExecuteTaskRequest; -import org.opensearch.ml.common.transport.execute.MLExecuteTaskResponse; import org.opensearch.ml.common.transport.model.MLModelDeleteAction; import org.opensearch.ml.common.transport.model.MLModelDeleteRequest; import org.opensearch.ml.common.transport.model.MLModelGetAction; @@ -66,6 +70,9 @@ import org.opensearch.ml.common.transport.training.MLTrainingTaskAction; import org.opensearch.ml.common.transport.training.MLTrainingTaskRequest; import org.opensearch.ml.common.transport.trainpredict.MLTrainAndPredictionTaskAction; +import org.opensearch.ml.common.transport.undeploy.MLUndeployModelsAction; +import org.opensearch.ml.common.transport.undeploy.MLUndeployModelsRequest; +import org.opensearch.ml.common.transport.undeploy.MLUndeployModelsResponse; import lombok.AccessLevel; import lombok.RequiredArgsConstructor; @@ -191,19 +198,6 @@ public void registerModelGroup( ); } - /** - * Execute an algorithm - * - * @param name function name - * @param input an algorithm input - * @param listener a listener to be notified of the result - */ - @Override - public void execute(FunctionName name, Input input, ActionListener listener) { - MLExecuteTaskRequest mlExecuteTaskRequest = new MLExecuteTaskRequest(name, input); - client.execute(MLExecuteTaskAction.INSTANCE, mlExecuteTaskRequest, listener); - } - @Override public void getTask(String taskId, ActionListener listener) { MLTaskGetRequest mlTaskGetRequest = MLTaskGetRequest.builder().taskId(taskId).build(); @@ -242,12 +236,50 @@ public void deploy(String modelId, ActionListener listene client.execute(MLDeployModelAction.INSTANCE, deployModelRequest, getMlDeployModelResponseActionListener(listener)); } + @Override + public void undeploy(String[] modelIds, String[] nodeIds, ActionListener listener) { + MLUndeployModelsRequest undeployModelRequest = new MLUndeployModelsRequest(modelIds, nodeIds); + client.execute(MLUndeployModelsAction.INSTANCE, undeployModelRequest, getMlUndeployModelsResponseActionListener(listener)); + } + @Override public void createConnector(MLCreateConnectorInput mlCreateConnectorInput, ActionListener listener) { MLCreateConnectorRequest createConnectorRequest = new MLCreateConnectorRequest(mlCreateConnectorInput); client.execute(MLCreateConnectorAction.INSTANCE, createConnectorRequest, getMlCreateConnectorResponseActionListener(listener)); } + @Override + public void deleteConnector(String connectorId, ActionListener listener) { + MLConnectorDeleteRequest connectorDeleteRequest = new MLConnectorDeleteRequest(connectorId); + client.execute(MLConnectorDeleteAction.INSTANCE, connectorDeleteRequest, ActionListener.wrap(deleteResponse -> { + listener.onResponse(deleteResponse); + }, listener::onFailure)); + } + + @Override + public void registerAgent(MLAgent mlAgent, ActionListener listener) { + MLRegisterAgentRequest mlRegisterAgentRequest = MLRegisterAgentRequest.builder().mlAgent(mlAgent).build(); + client.execute(MLRegisterAgentAction.INSTANCE, mlRegisterAgentRequest, getMLRegisterAgentResponseActionListener(listener)); + } + + @Override + public void deleteAgent(String agentId, ActionListener listener) { + MLAgentDeleteRequest agentDeleteRequest = new MLAgentDeleteRequest(agentId); + client.execute(MLAgentDeleteAction.INSTANCE, agentDeleteRequest, ActionListener.wrap(deleteResponse -> { + listener.onResponse(deleteResponse); + }, listener::onFailure)); + } + + private ActionListener getMLRegisterAgentResponseActionListener( + ActionListener listener + ) { + ActionListener actionListener = wrapActionListener(listener, res -> { + MLRegisterAgentResponse mlRegisterAgentResponse = MLRegisterAgentResponse.fromActionResponse(res); + return mlRegisterAgentResponse; + }); + return actionListener; + } + private ActionListener getMLTaskResponseActionListener(ActionListener listener) { ActionListener internalListener = ActionListener .wrap(getResponse -> { listener.onResponse(getResponse.getMlTask()); }, listener::onFailure); @@ -266,6 +298,16 @@ private ActionListener getMlDeployModelResponseActionList return actionListener; } + private ActionListener getMlUndeployModelsResponseActionListener( + ActionListener listener + ) { + ActionListener actionListener = wrapActionListener(listener, response -> { + MLUndeployModelsResponse deployModelResponse = MLUndeployModelsResponse.fromActionResponse(response); + return deployModelResponse; + }); + return actionListener; + } + private ActionListener getMlCreateConnectorResponseActionListener( ActionListener listener ) { diff --git a/client/src/test/java/org/opensearch/ml/client/MachineLearningClientTest.java b/client/src/test/java/org/opensearch/ml/client/MachineLearningClientTest.java index 2dda397016..c3933edda6 100644 --- a/client/src/test/java/org/opensearch/ml/client/MachineLearningClientTest.java +++ b/client/src/test/java/org/opensearch/ml/client/MachineLearningClientTest.java @@ -12,7 +12,6 @@ import static org.opensearch.ml.common.input.Constants.KMEANS; import static org.opensearch.ml.common.input.Constants.TRAIN; -import java.util.ArrayList; import java.util.Arrays; import java.util.HashMap; import java.util.List; @@ -31,25 +30,25 @@ import org.opensearch.ml.common.FunctionName; import org.opensearch.ml.common.MLModel; import org.opensearch.ml.common.MLTask; +import org.opensearch.ml.common.agent.MLAgent; import org.opensearch.ml.common.dataframe.DataFrame; import org.opensearch.ml.common.dataset.DataFrameInputDataset; -import org.opensearch.ml.common.input.Input; import org.opensearch.ml.common.input.MLInput; -import org.opensearch.ml.common.input.execute.metricscorrelation.MetricsCorrelationInput; import org.opensearch.ml.common.input.parameter.MLAlgoParams; import org.opensearch.ml.common.model.MLModelConfig; import org.opensearch.ml.common.model.MLModelFormat; import org.opensearch.ml.common.model.TextEmbeddingModelConfig; import org.opensearch.ml.common.output.MLOutput; import org.opensearch.ml.common.output.MLTrainingOutput; +import org.opensearch.ml.common.transport.agent.MLRegisterAgentResponse; import org.opensearch.ml.common.transport.connector.MLCreateConnectorInput; import org.opensearch.ml.common.transport.connector.MLCreateConnectorResponse; import org.opensearch.ml.common.transport.deploy.MLDeployModelResponse; -import org.opensearch.ml.common.transport.execute.MLExecuteTaskResponse; import org.opensearch.ml.common.transport.model_group.MLRegisterModelGroupInput; import org.opensearch.ml.common.transport.model_group.MLRegisterModelGroupResponse; import org.opensearch.ml.common.transport.register.MLRegisterModelInput; import org.opensearch.ml.common.transport.register.MLRegisterModelResponse; +import org.opensearch.ml.common.transport.undeploy.MLUndeployModelsResponse; public class MachineLearningClientTest { @@ -79,6 +78,9 @@ public class MachineLearningClientTest { @Mock MLDeployModelResponse deployModelResponse; + @Mock + MLUndeployModelsResponse undeployModelsResponse; + @Mock MLCreateConnectorResponse createConnectorResponse; @@ -86,7 +88,7 @@ public class MachineLearningClientTest { MLRegisterModelGroupResponse registerModelGroupResponse; @Mock - MLExecuteTaskResponse mlExecuteTaskResponse; + MLRegisterAgentResponse registerAgentResponse; private String modekId = "test_model_id"; private MLModel mlModel; @@ -163,14 +165,19 @@ public void deploy(String modelId, ActionListener listene listener.onResponse(deployModelResponse); } + @Override + public void undeploy(String[] modelIds, String[] nodeIds, ActionListener listener) { + listener.onResponse(undeployModelsResponse); + } + @Override public void createConnector(MLCreateConnectorInput mlCreateConnectorInput, ActionListener listener) { listener.onResponse(createConnectorResponse); } @Override - public void execute(FunctionName name, Input input, ActionListener listener) { - listener.onResponse(mlExecuteTaskResponse); + public void deleteConnector(String connectorId, ActionListener listener) { + listener.onResponse(deleteResponse); } public void registerModelGroup( @@ -179,6 +186,16 @@ public void registerModelGroup( ) { listener.onResponse(registerModelGroupResponse); } + + @Override + public void registerAgent(MLAgent mlAgent, ActionListener listener) { + listener.onResponse(registerAgentResponse); + } + + @Override + public void deleteAgent(String agentId, ActionListener listener) { + listener.onResponse(deleteResponse); + } }; } @@ -344,6 +361,11 @@ public void deploy() { assertEquals(deployModelResponse, machineLearningClient.deploy("modelId").actionGet()); } + @Test + public void undeploy() { + assertEquals(undeployModelsResponse, machineLearningClient.undeploy(new String[] { "modelId" }, null).actionGet()); + } + @Test public void createConnector() { Map params = Map.ofEntries(Map.entry("endpoint", "endpoint"), Map.entry("temp", "7")); @@ -368,56 +390,18 @@ public void createConnector() { } @Test - public void executeMetricsCorrelation() { - List inputData = new ArrayList<>( - Arrays - .asList( - new float[] { - 0.89451003f, - 4.2006273f, - 0.3697659f, - 2.2458954f, - -4.671612f, - -1.5076426f, - 1.635445f, - -1.1394824f, - -0.7503817f, - 0.98424894f, - -0.38896716f, - 1.0328646f, - 1.9543738f, - -0.5236269f, - 0.14298044f, - 3.2963762f, - 8.1641035f, - 5.717064f, - 7.4869685f, - 2.5987444f, - 11.018798f, - 9.151356f, - 5.7354255f, - 6.862203f, - 3.0524514f, - 4.431755f, - 5.1481285f, - 7.9548607f, - 7.4519925f, - 6.09533f, - 7.634116f, - 8.898271f, - 3.898491f, - 9.447067f, - 8.197385f, - 5.8284273f, - 5.804283f, - 7.089733f, - 9.140584f } - ) - ); - Input metricsCorrelationInput = MetricsCorrelationInput.builder().inputData(inputData).build(); - assertEquals( - mlExecuteTaskResponse, - machineLearningClient.execute(FunctionName.METRICS_CORRELATION, metricsCorrelationInput).actionGet() - ); + public void deleteConnector() { + assertEquals(deleteResponse, machineLearningClient.deleteConnector("connectorId").actionGet()); + } + + @Test + public void testRegisterAgent() { + MLAgent mlAgent = MLAgent.builder().name("Agent name").build(); + assertEquals(registerAgentResponse, machineLearningClient.registerAgent(mlAgent).actionGet()); + } + + @Test + public void deleteAgent() { + assertEquals(deleteResponse, machineLearningClient.deleteAgent("agentId").actionGet()); } } diff --git a/client/src/test/java/org/opensearch/ml/client/MachineLearningNodeClientTest.java b/client/src/test/java/org/opensearch/ml/client/MachineLearningNodeClientTest.java index 591797b758..75a40e3c03 100644 --- a/client/src/test/java/org/opensearch/ml/client/MachineLearningNodeClientTest.java +++ b/client/src/test/java/org/opensearch/ml/client/MachineLearningNodeClientTest.java @@ -6,6 +6,7 @@ package org.opensearch.ml.client; import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; import static org.junit.Assert.assertTrue; import static org.mockito.Answers.RETURNS_DEEP_STUBS; import static org.mockito.ArgumentMatchers.any; @@ -23,7 +24,6 @@ import static org.opensearch.ml.common.input.Constants.TRAINANDPREDICT; import java.io.IOException; -import java.util.ArrayList; import java.util.Arrays; import java.util.Collections; import java.util.HashMap; @@ -44,6 +44,7 @@ import org.opensearch.action.search.SearchResponse; import org.opensearch.action.search.ShardSearchFailure; import org.opensearch.client.node.NodeClient; +import org.opensearch.cluster.ClusterName; import org.opensearch.common.xcontent.XContentFactory; import org.opensearch.core.action.ActionListener; import org.opensearch.core.common.bytes.BytesReference; @@ -58,22 +59,24 @@ import org.opensearch.ml.common.MLTask; import org.opensearch.ml.common.MLTaskState; import org.opensearch.ml.common.MLTaskType; +import org.opensearch.ml.common.agent.MLAgent; import org.opensearch.ml.common.dataframe.DataFrame; import org.opensearch.ml.common.dataset.MLInputDataset; -import org.opensearch.ml.common.input.Input; import org.opensearch.ml.common.input.MLInput; -import org.opensearch.ml.common.input.execute.metricscorrelation.MetricsCorrelationInput; import org.opensearch.ml.common.model.MLModelConfig; import org.opensearch.ml.common.model.MLModelFormat; import org.opensearch.ml.common.model.TextEmbeddingModelConfig; import org.opensearch.ml.common.output.MLOutput; import org.opensearch.ml.common.output.MLPredictionOutput; import org.opensearch.ml.common.output.MLTrainingOutput; -import org.opensearch.ml.common.output.Output; -import org.opensearch.ml.common.output.execute.metrics_correlation.MCorrModelTensor; -import org.opensearch.ml.common.output.execute.metrics_correlation.MCorrModelTensors; -import org.opensearch.ml.common.output.execute.metrics_correlation.MetricsCorrelationOutput; import org.opensearch.ml.common.transport.MLTaskResponse; +import org.opensearch.ml.common.transport.agent.MLAgentDeleteAction; +import org.opensearch.ml.common.transport.agent.MLAgentDeleteRequest; +import org.opensearch.ml.common.transport.agent.MLRegisterAgentAction; +import org.opensearch.ml.common.transport.agent.MLRegisterAgentRequest; +import org.opensearch.ml.common.transport.agent.MLRegisterAgentResponse; +import org.opensearch.ml.common.transport.connector.MLConnectorDeleteAction; +import org.opensearch.ml.common.transport.connector.MLConnectorDeleteRequest; import org.opensearch.ml.common.transport.connector.MLCreateConnectorAction; import org.opensearch.ml.common.transport.connector.MLCreateConnectorInput; import org.opensearch.ml.common.transport.connector.MLCreateConnectorRequest; @@ -81,9 +84,6 @@ import org.opensearch.ml.common.transport.deploy.MLDeployModelAction; import org.opensearch.ml.common.transport.deploy.MLDeployModelRequest; import org.opensearch.ml.common.transport.deploy.MLDeployModelResponse; -import org.opensearch.ml.common.transport.execute.MLExecuteTaskAction; -import org.opensearch.ml.common.transport.execute.MLExecuteTaskRequest; -import org.opensearch.ml.common.transport.execute.MLExecuteTaskResponse; import org.opensearch.ml.common.transport.model.MLModelDeleteAction; import org.opensearch.ml.common.transport.model.MLModelDeleteRequest; import org.opensearch.ml.common.transport.model.MLModelGetAction; @@ -109,6 +109,10 @@ import org.opensearch.ml.common.transport.training.MLTrainingTaskAction; import org.opensearch.ml.common.transport.training.MLTrainingTaskRequest; import org.opensearch.ml.common.transport.trainpredict.MLTrainAndPredictionTaskAction; +import org.opensearch.ml.common.transport.undeploy.MLUndeployModelNodesResponse; +import org.opensearch.ml.common.transport.undeploy.MLUndeployModelsAction; +import org.opensearch.ml.common.transport.undeploy.MLUndeployModelsRequest; +import org.opensearch.ml.common.transport.undeploy.MLUndeployModelsResponse; import org.opensearch.search.SearchHit; import org.opensearch.search.SearchHits; import org.opensearch.search.aggregations.InternalAggregations; @@ -157,14 +161,23 @@ public class MachineLearningNodeClientTest { @Mock ActionListener deployModelActionListener; + @Mock + ActionListener undeployModelsActionListener; + @Mock ActionListener createConnectorActionListener; + @Mock + ActionListener deleteConnectorActionListener; + @Mock ActionListener registerModelGroupResponseActionListener; @Mock - ActionListener executeTaskResponseActionListener; + ActionListener registerAgentResponseActionListener; + + @Mock + ActionListener deleteAgentActionListener; @InjectMocks MachineLearningNodeClient machineLearningNodeClient; @@ -649,6 +662,33 @@ public void deploy() { assertEquals(status, (argumentCaptor.getValue()).getStatus()); } + @Test + public void undeploy() { + ClusterName clusterName = new ClusterName("clusterName"); + String[] modelIds = new String[] { "modelId" }; + String[] nodeIds = new String[] { "nodeId" }; + doAnswer(invocation -> { + ActionListener actionListener = invocation.getArgument(2); + MLUndeployModelNodesResponse mlUndeployModelNodesResponse = new MLUndeployModelNodesResponse( + clusterName, + Collections.emptyList(), + Collections.emptyList() + ); + MLUndeployModelsResponse output = new MLUndeployModelsResponse(mlUndeployModelNodesResponse); + actionListener.onResponse(output); + return null; + }).when(client).execute(eq(MLUndeployModelsAction.INSTANCE), any(), any()); + + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(MLUndeployModelsResponse.class); + machineLearningNodeClient.undeploy(modelIds, nodeIds, undeployModelsActionListener); + + verify(client).execute(eq(MLUndeployModelsAction.INSTANCE), isA(MLUndeployModelsRequest.class), any()); + verify(undeployModelsActionListener).onResponse(argumentCaptor.capture()); + assertEquals(clusterName, (argumentCaptor.getValue()).getResponse().getClusterName()); + assertTrue((argumentCaptor.getValue()).getResponse().getNodes().isEmpty()); + assertFalse((argumentCaptor.getValue()).getResponse().hasFailures()); + } + @Test public void createConnector() { @@ -691,82 +731,68 @@ public void createConnector() { } @Test - public void executeMetricsCorrelation() { - Output metricsCorrelationOutput; - List outputs = new ArrayList<>(); - MCorrModelTensor mCorrModelTensor = MCorrModelTensor - .builder() - .event_pattern(new float[] { 1.0f, 2.0f, 3.0f }) - .event_window(new float[] { 4.0f, 5.0f, 6.0f }) - .suspected_metrics(new long[] { 1, 2 }) - .build(); - List mlModelTensors = Arrays.asList(mCorrModelTensor); - MCorrModelTensors modelTensors = MCorrModelTensors.builder().mCorrModelTensors(mlModelTensors).build(); - outputs.add(modelTensors); - metricsCorrelationOutput = MetricsCorrelationOutput.builder().modelOutput(outputs).build(); + public void deleteConnector() { + + String connectorId = "connectorId"; doAnswer(invocation -> { - ActionListener actionListener = invocation.getArgument(2); - MLExecuteTaskResponse output = new MLExecuteTaskResponse(FunctionName.METRICS_CORRELATION, metricsCorrelationOutput); + ActionListener actionListener = invocation.getArgument(2); + ShardId shardId = new ShardId(new Index("indexName", "uuid"), 1); + DeleteResponse output = new DeleteResponse(shardId, connectorId, 1, 1, 1, true); actionListener.onResponse(output); return null; - }).when(client).execute(eq(MLExecuteTaskAction.INSTANCE), any(), any()); - - ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(MLExecuteTaskResponse.class); - - List inputData = new ArrayList<>( - Arrays - .asList( - new float[] { - 0.89451003f, - 4.2006273f, - 0.3697659f, - 2.2458954f, - -4.671612f, - -1.5076426f, - 1.635445f, - -1.1394824f, - -0.7503817f, - 0.98424894f, - -0.38896716f, - 1.0328646f, - 1.9543738f, - -0.5236269f, - 0.14298044f, - 3.2963762f, - 8.1641035f, - 5.717064f, - 7.4869685f, - 2.5987444f, - 11.018798f, - 9.151356f, - 5.7354255f, - 6.862203f, - 3.0524514f, - 4.431755f, - 5.1481285f, - 7.9548607f, - 7.4519925f, - 6.09533f, - 7.634116f, - 8.898271f, - 3.898491f, - 9.447067f, - 8.197385f, - 5.8284273f, - 5.804283f, - 7.089733f, - 9.140584f } - ) - ); - Input metricsCorrelationInput = MetricsCorrelationInput.builder().inputData(inputData).build(); + }).when(client).execute(eq(MLConnectorDeleteAction.INSTANCE), any(), any()); + + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(DeleteResponse.class); + + machineLearningNodeClient.deleteConnector(connectorId, deleteConnectorActionListener); - machineLearningNodeClient.execute(FunctionName.METRICS_CORRELATION, metricsCorrelationInput, executeTaskResponseActionListener); + verify(client).execute(eq(MLConnectorDeleteAction.INSTANCE), isA(MLConnectorDeleteRequest.class), any()); + verify(deleteConnectorActionListener).onResponse(argumentCaptor.capture()); + assertEquals(connectorId, (argumentCaptor.getValue()).getId()); + } + + @Test + public void testRegisterAgent() { + String agentId = "agentId"; + + doAnswer(invocation -> { + ActionListener actionListener = invocation.getArgument(2); + MLRegisterAgentResponse output = new MLRegisterAgentResponse(agentId); + actionListener.onResponse(output); + return null; + }).when(client).execute(eq(MLRegisterAgentAction.INSTANCE), any(), any()); + + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(MLRegisterAgentResponse.class); + MLAgent mlAgent = MLAgent.builder().name("Agent name").build(); + + machineLearningNodeClient.registerAgent(mlAgent, registerAgentResponseActionListener); + + verify(client).execute(eq(MLRegisterAgentAction.INSTANCE), isA(MLRegisterAgentRequest.class), any()); + verify(registerAgentResponseActionListener).onResponse(argumentCaptor.capture()); + assertEquals(agentId, (argumentCaptor.getValue()).getAgentId()); + } + + @Test + public void deleteAgent() { + + String agentId = "agentId"; + + doAnswer(invocation -> { + ActionListener actionListener = invocation.getArgument(2); + ShardId shardId = new ShardId(new Index("indexName", "uuid"), 1); + DeleteResponse output = new DeleteResponse(shardId, agentId, 1, 1, 1, true); + actionListener.onResponse(output); + return null; + }).when(client).execute(eq(MLAgentDeleteAction.INSTANCE), any(), any()); + + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(DeleteResponse.class); + + machineLearningNodeClient.deleteAgent(agentId, deleteAgentActionListener); - verify(client).execute(eq(MLExecuteTaskAction.INSTANCE), isA(MLExecuteTaskRequest.class), any()); - verify(executeTaskResponseActionListener).onResponse(argumentCaptor.capture()); - assertEquals(FunctionName.METRICS_CORRELATION, argumentCaptor.getValue().getFunctionName()); - assertTrue(argumentCaptor.getValue().getOutput() instanceof MetricsCorrelationOutput); + verify(client).execute(eq(MLAgentDeleteAction.INSTANCE), isA(MLAgentDeleteRequest.class), any()); + verify(deleteAgentActionListener).onResponse(argumentCaptor.capture()); + assertEquals(agentId, (argumentCaptor.getValue()).getId()); } private SearchResponse createSearchResponse(ToXContentObject o) throws IOException { @@ -794,4 +820,5 @@ private SearchResponse createSearchResponse(ToXContentObject o) throws IOExcepti SearchResponse.Clusters.EMPTY ); } + }