From d920a451ed24fd222e55ab4fad8b17bfa5abac69 Mon Sep 17 00:00:00 2001 From: Daniel Widdis Date: Sat, 2 Dec 2023 21:04:33 -0800 Subject: [PATCH] Add Delete Agent to MLClient Signed-off-by: Daniel Widdis --- .../ml/client/MachineLearningClient.java | 13 +++++++++ .../ml/client/MachineLearningNodeClient.java | 10 +++++++ .../ml/client/MachineLearningClientTest.java | 10 +++++++ .../client/MachineLearningNodeClientTest.java | 27 +++++++++++++++++++ 4 files changed, 60 insertions(+) diff --git a/client/src/main/java/org/opensearch/ml/client/MachineLearningClient.java b/client/src/main/java/org/opensearch/ml/client/MachineLearningClient.java index 201ada8804..252748e179 100644 --- a/client/src/main/java/org/opensearch/ml/client/MachineLearningClient.java +++ b/client/src/main/java/org/opensearch/ml/client/MachineLearningClient.java @@ -392,4 +392,17 @@ default ActionFuture registerAgent(MLAgent mlAgent) { * @param mlAgent Register agent input, refer: https://opensearch.org/docs/latest/ml-commons-plugin/api/#register-agent */ void registerAgent(MLAgent mlAgent, ActionListener listener); + + /** + * Delete agent + * @param agentId The id of the agent to delete + * @return the result future + */ + default ActionFuture deleteAgent(String agentId) { + PlainActionFuture actionFuture = PlainActionFuture.newFuture(); + deleteAgent(agentId, actionFuture); + return actionFuture; + } + + void deleteAgent(String agentId, ActionListener listener); } diff --git a/client/src/main/java/org/opensearch/ml/client/MachineLearningNodeClient.java b/client/src/main/java/org/opensearch/ml/client/MachineLearningNodeClient.java index 9a0075dd4b..ea6fba7e9a 100644 --- a/client/src/main/java/org/opensearch/ml/client/MachineLearningNodeClient.java +++ b/client/src/main/java/org/opensearch/ml/client/MachineLearningNodeClient.java @@ -33,6 +33,8 @@ import org.opensearch.ml.common.input.parameter.MLAlgoParams; import org.opensearch.ml.common.output.MLOutput; import org.opensearch.ml.common.transport.MLTaskResponse; +import org.opensearch.ml.common.transport.agent.MLAgentDeleteAction; +import org.opensearch.ml.common.transport.agent.MLAgentDeleteRequest; import org.opensearch.ml.common.transport.agent.MLRegisterAgentAction; import org.opensearch.ml.common.transport.agent.MLRegisterAgentRequest; import org.opensearch.ml.common.transport.agent.MLRegisterAgentResponse; @@ -282,6 +284,14 @@ public void registerAgent(MLAgent mlAgent, ActionListener listener) { + MLAgentDeleteRequest agentDeleteRequest = new MLAgentDeleteRequest(agentId); + client.execute(MLAgentDeleteAction.INSTANCE, agentDeleteRequest, ActionListener.wrap(deleteResponse -> { + listener.onResponse(deleteResponse); + }, listener::onFailure)); + } + private ActionListener getMLRegisterAgentResponseActionListener( ActionListener listener ) { diff --git a/client/src/test/java/org/opensearch/ml/client/MachineLearningClientTest.java b/client/src/test/java/org/opensearch/ml/client/MachineLearningClientTest.java index 52c83cdaad..d5a77f0f86 100644 --- a/client/src/test/java/org/opensearch/ml/client/MachineLearningClientTest.java +++ b/client/src/test/java/org/opensearch/ml/client/MachineLearningClientTest.java @@ -202,6 +202,11 @@ public void getTool(String toolName, ActionListener listener) { public void registerAgent(MLAgent mlAgent, ActionListener listener) { listener.onResponse(registerAgentResponse); } + + @Override + public void deleteAgent(String agentId, ActionListener listener) { + listener.onResponse(deleteResponse); + } }; } @@ -405,4 +410,9 @@ public void testRegisterAgent() { MLAgent mlAgent = MLAgent.builder().name("Agent name").build(); assertEquals(registerAgentResponse, machineLearningClient.registerAgent(mlAgent).actionGet()); } + + @Test + public void deleteAgent() { + assertEquals(deleteResponse, machineLearningClient.deleteAgent("agentId").actionGet()); + } } diff --git a/client/src/test/java/org/opensearch/ml/client/MachineLearningNodeClientTest.java b/client/src/test/java/org/opensearch/ml/client/MachineLearningNodeClientTest.java index faa427a3d0..75a40e3c03 100644 --- a/client/src/test/java/org/opensearch/ml/client/MachineLearningNodeClientTest.java +++ b/client/src/test/java/org/opensearch/ml/client/MachineLearningNodeClientTest.java @@ -70,6 +70,8 @@ import org.opensearch.ml.common.output.MLPredictionOutput; import org.opensearch.ml.common.output.MLTrainingOutput; import org.opensearch.ml.common.transport.MLTaskResponse; +import org.opensearch.ml.common.transport.agent.MLAgentDeleteAction; +import org.opensearch.ml.common.transport.agent.MLAgentDeleteRequest; import org.opensearch.ml.common.transport.agent.MLRegisterAgentAction; import org.opensearch.ml.common.transport.agent.MLRegisterAgentRequest; import org.opensearch.ml.common.transport.agent.MLRegisterAgentResponse; @@ -174,6 +176,9 @@ public class MachineLearningNodeClientTest { @Mock ActionListener registerAgentResponseActionListener; + @Mock + ActionListener deleteAgentActionListener; + @InjectMocks MachineLearningNodeClient machineLearningNodeClient; @@ -768,6 +773,28 @@ public void testRegisterAgent() { assertEquals(agentId, (argumentCaptor.getValue()).getAgentId()); } + @Test + public void deleteAgent() { + + String agentId = "agentId"; + + doAnswer(invocation -> { + ActionListener actionListener = invocation.getArgument(2); + ShardId shardId = new ShardId(new Index("indexName", "uuid"), 1); + DeleteResponse output = new DeleteResponse(shardId, agentId, 1, 1, 1, true); + actionListener.onResponse(output); + return null; + }).when(client).execute(eq(MLAgentDeleteAction.INSTANCE), any(), any()); + + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(DeleteResponse.class); + + machineLearningNodeClient.deleteAgent(agentId, deleteAgentActionListener); + + verify(client).execute(eq(MLAgentDeleteAction.INSTANCE), isA(MLAgentDeleteRequest.class), any()); + verify(deleteAgentActionListener).onResponse(argumentCaptor.capture()); + assertEquals(agentId, (argumentCaptor.getValue()).getId()); + } + private SearchResponse createSearchResponse(ToXContentObject o) throws IOException { XContentBuilder content = o.toXContent(XContentFactory.jsonBuilder(), ToXContent.EMPTY_PARAMS);