From 223be91cc1e21ae9ee8aa8bb89bd88b00bb5374d Mon Sep 17 00:00:00 2001 From: Arjun kumar Giri Date: Wed, 15 Nov 2023 17:02:43 -0800 Subject: [PATCH 1/4] Register agent API support for MLClient Signed-off-by: Arjun kumar Giri Addressed feedback Signed-off-by: Arjun kumar Giri Ignore flaky integration test Signed-off-by: Arjun kumar Giri --- .../ml/client/MachineLearningClient.java | 19 ++++++++++++ .../ml/client/MachineLearningNodeClient.java | 25 ++++++++++++++++ .../ml/client/MachineLearningClientTest.java | 16 ++++++++++ .../client/MachineLearningNodeClientTest.java | 29 +++++++++++++++++++ .../agent/MLRegisterAgentResponse.java | 21 ++++++++++++++ .../ConversationalMemoryHandlerITTests.java | 2 ++ .../index/ConversationMetaIndexITTests.java | 2 ++ .../index/InteractionsIndexITTests.java | 2 ++ 8 files changed, 116 insertions(+) diff --git a/client/src/main/java/org/opensearch/ml/client/MachineLearningClient.java b/client/src/main/java/org/opensearch/ml/client/MachineLearningClient.java index 4d81448362..07b8b20b22 100644 --- a/client/src/main/java/org/opensearch/ml/client/MachineLearningClient.java +++ b/client/src/main/java/org/opensearch/ml/client/MachineLearningClient.java @@ -17,8 +17,10 @@ import org.opensearch.ml.common.MLModel; import org.opensearch.ml.common.MLTask; import org.opensearch.ml.common.ToolMetadata; +import org.opensearch.ml.common.agent.MLAgent; import org.opensearch.ml.common.input.MLInput; import org.opensearch.ml.common.output.MLOutput; +import org.opensearch.ml.common.transport.agent.MLRegisterAgentResponse; import org.opensearch.ml.common.transport.connector.MLCreateConnectorInput; import org.opensearch.ml.common.transport.connector.MLCreateConnectorResponse; import org.opensearch.ml.common.transport.deploy.MLDeployModelResponse; @@ -337,4 +339,21 @@ default ActionFuture getTool(String toolName) { * @param listener action listener */ void getTool(String toolName, ActionListener listener); + + /** + * Registers new agent and returns ActionFuture. + * @param mlAgent Register agent input, refer: https://opensearch.org/docs/latest/ml-commons-plugin/api/#register-agent + * @return the result future + */ + default ActionFuture registerAgent(MLAgent mlAgent) { + PlainActionFuture actionFuture = PlainActionFuture.newFuture(); + registerAgent(mlAgent, actionFuture); + return actionFuture; + } + + /** + * Registers new agent and returns agent ID in response + * @param mlAgent Register agent input, refer: https://opensearch.org/docs/latest/ml-commons-plugin/api/#register-agent + */ + void registerAgent(MLAgent mlAgent, ActionListener listener); } diff --git a/client/src/main/java/org/opensearch/ml/client/MachineLearningNodeClient.java b/client/src/main/java/org/opensearch/ml/client/MachineLearningNodeClient.java index 5d4d868b61..14e0ecd235 100644 --- a/client/src/main/java/org/opensearch/ml/client/MachineLearningNodeClient.java +++ b/client/src/main/java/org/opensearch/ml/client/MachineLearningNodeClient.java @@ -28,10 +28,14 @@ import org.opensearch.ml.common.MLModel; import org.opensearch.ml.common.MLTask; import org.opensearch.ml.common.ToolMetadata; +import org.opensearch.ml.common.agent.MLAgent; import org.opensearch.ml.common.input.MLInput; import org.opensearch.ml.common.input.parameter.MLAlgoParams; import org.opensearch.ml.common.output.MLOutput; import org.opensearch.ml.common.transport.MLTaskResponse; +import org.opensearch.ml.common.transport.agent.MLRegisterAgentAction; +import org.opensearch.ml.common.transport.agent.MLRegisterAgentRequest; +import org.opensearch.ml.common.transport.agent.MLRegisterAgentResponse; import org.opensearch.ml.common.transport.connector.MLCreateConnectorAction; import org.opensearch.ml.common.transport.connector.MLCreateConnectorInput; import org.opensearch.ml.common.transport.connector.MLCreateConnectorRequest; @@ -253,6 +257,27 @@ public void getTool(String toolName, ActionListener listener) { client.execute(MLGetToolAction.INSTANCE, mlToolGetRequest, getMlGetToolResponseActionListener(listener)); } + @Override + public void registerAgent(MLAgent mlAgent, ActionListener listener) { + MLRegisterAgentRequest mlRegisterAgentRequest = MLRegisterAgentRequest.builder().mlAgent(mlAgent).build(); + client + .execute( + MLRegisterAgentAction.INSTANCE, + mlRegisterAgentRequest, + ActionListener.wrap(listener::onResponse, listener::onFailure) + ); + } + + private ActionListener getMLRegisterAgentResponseActionListener( + ActionListener listener + ) { + ActionListener actionListener = wrapActionListener(listener, res -> { + MLRegisterAgentResponse mlRegisterAgentResponse = MLRegisterAgentResponse.fromActionResponse(res); + return mlRegisterAgentResponse; + }); + return actionListener; + } + private ActionListener getMlListToolsResponseActionListener(ActionListener> listener) { ActionListener internalListener = ActionListener.wrap(mlModelListResponse -> { listener.onResponse(mlModelListResponse.getToolMetadataList()); diff --git a/client/src/test/java/org/opensearch/ml/client/MachineLearningClientTest.java b/client/src/test/java/org/opensearch/ml/client/MachineLearningClientTest.java index 4b137ac685..984d8dd332 100644 --- a/client/src/test/java/org/opensearch/ml/client/MachineLearningClientTest.java +++ b/client/src/test/java/org/opensearch/ml/client/MachineLearningClientTest.java @@ -31,6 +31,7 @@ import org.opensearch.ml.common.MLModel; import org.opensearch.ml.common.MLTask; import org.opensearch.ml.common.ToolMetadata; +import org.opensearch.ml.common.agent.MLAgent; import org.opensearch.ml.common.dataframe.DataFrame; import org.opensearch.ml.common.dataset.DataFrameInputDataset; import org.opensearch.ml.common.input.MLInput; @@ -40,6 +41,7 @@ import org.opensearch.ml.common.model.TextEmbeddingModelConfig; import org.opensearch.ml.common.output.MLOutput; import org.opensearch.ml.common.output.MLTrainingOutput; +import org.opensearch.ml.common.transport.agent.MLRegisterAgentResponse; import org.opensearch.ml.common.transport.connector.MLCreateConnectorInput; import org.opensearch.ml.common.transport.connector.MLCreateConnectorResponse; import org.opensearch.ml.common.transport.deploy.MLDeployModelResponse; @@ -82,6 +84,9 @@ public class MachineLearningClientTest { @Mock MLRegisterModelGroupResponse registerModelGroupResponse; + @Mock + MLRegisterAgentResponse registerAgentResponse; + private String modekId = "test_model_id"; private MLModel mlModel; private MLTask mlTask; @@ -178,6 +183,11 @@ public void listTools(ActionListener> listener) { public void getTool(String toolName, ActionListener listener) { listener.onResponse(null); } + + @Override + public void registerAgent(MLAgent mlAgent, ActionListener listener) { + listener.onResponse(registerAgentResponse); + } }; } @@ -365,4 +375,10 @@ public void createConnector() { assertEquals(createConnectorResponse, machineLearningClient.createConnector(mlCreateConnectorInput).actionGet()); } + + @Test + public void testRegisterAgent() { + MLAgent mlAgent = MLAgent.builder().name("Agent name").build(); + assertEquals(registerAgentResponse, machineLearningClient.registerAgent(mlAgent).actionGet()); + } } diff --git a/client/src/test/java/org/opensearch/ml/client/MachineLearningNodeClientTest.java b/client/src/test/java/org/opensearch/ml/client/MachineLearningNodeClientTest.java index ccdf812195..1b32c47d38 100644 --- a/client/src/test/java/org/opensearch/ml/client/MachineLearningNodeClientTest.java +++ b/client/src/test/java/org/opensearch/ml/client/MachineLearningNodeClientTest.java @@ -56,6 +56,7 @@ import org.opensearch.ml.common.MLTask; import org.opensearch.ml.common.MLTaskState; import org.opensearch.ml.common.MLTaskType; +import org.opensearch.ml.common.agent.MLAgent; import org.opensearch.ml.common.dataframe.DataFrame; import org.opensearch.ml.common.dataset.MLInputDataset; import org.opensearch.ml.common.input.MLInput; @@ -66,6 +67,9 @@ import org.opensearch.ml.common.output.MLPredictionOutput; import org.opensearch.ml.common.output.MLTrainingOutput; import org.opensearch.ml.common.transport.MLTaskResponse; +import org.opensearch.ml.common.transport.agent.MLRegisterAgentAction; +import org.opensearch.ml.common.transport.agent.MLRegisterAgentRequest; +import org.opensearch.ml.common.transport.agent.MLRegisterAgentResponse; import org.opensearch.ml.common.transport.connector.MLCreateConnectorAction; import org.opensearch.ml.common.transport.connector.MLCreateConnectorInput; import org.opensearch.ml.common.transport.connector.MLCreateConnectorRequest; @@ -152,6 +156,9 @@ public class MachineLearningNodeClientTest { @Mock ActionListener registerModelGroupResponseActionListener; + @Mock + ActionListener registerAgentResponseActionListener; + @InjectMocks MachineLearningNodeClient machineLearningNodeClient; @@ -676,6 +683,27 @@ public void createConnector() { } + @Test + public void testRegisterAgent() { + String agentId = "agentId"; + + doAnswer(invocation -> { + ActionListener actionListener = invocation.getArgument(2); + MLRegisterAgentResponse output = new MLRegisterAgentResponse(agentId); + actionListener.onResponse(output); + return null; + }).when(client).execute(eq(MLRegisterAgentAction.INSTANCE), any(), any()); + + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(MLRegisterAgentResponse.class); + MLAgent mlAgent = MLAgent.builder().name("Agent name").build(); + + machineLearningNodeClient.registerAgent(mlAgent, registerAgentResponseActionListener); + + verify(client).execute(eq(MLRegisterAgentAction.INSTANCE), isA(MLRegisterAgentRequest.class), any()); + verify(registerAgentResponseActionListener).onResponse(argumentCaptor.capture()); + assertEquals(agentId, (argumentCaptor.getValue()).getAgentId()); + } + private SearchResponse createSearchResponse(ToXContentObject o) throws IOException { XContentBuilder content = o.toXContent(XContentFactory.jsonBuilder(), ToXContent.EMPTY_PARAMS); @@ -701,4 +729,5 @@ private SearchResponse createSearchResponse(ToXContentObject o) throws IOExcepti SearchResponse.Clusters.EMPTY ); } + } diff --git a/common/src/main/java/org/opensearch/ml/common/transport/agent/MLRegisterAgentResponse.java b/common/src/main/java/org/opensearch/ml/common/transport/agent/MLRegisterAgentResponse.java index e4fc073f9c..90b1401e4d 100644 --- a/common/src/main/java/org/opensearch/ml/common/transport/agent/MLRegisterAgentResponse.java +++ b/common/src/main/java/org/opensearch/ml/common/transport/agent/MLRegisterAgentResponse.java @@ -7,12 +7,17 @@ import lombok.Getter; import org.opensearch.core.action.ActionResponse; +import org.opensearch.core.common.io.stream.InputStreamStreamInput; +import org.opensearch.core.common.io.stream.OutputStreamStreamOutput; import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.common.io.stream.StreamOutput; import org.opensearch.core.xcontent.ToXContentObject; import org.opensearch.core.xcontent.XContentBuilder; +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; import java.io.IOException; +import java.io.UncheckedIOException; @Getter public class MLRegisterAgentResponse extends ActionResponse implements ToXContentObject { @@ -41,4 +46,20 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws builder.endObject(); return builder; } + + public static MLRegisterAgentResponse fromActionResponse(ActionResponse actionResponse) { + if (actionResponse instanceof MLRegisterAgentResponse) { + return (MLRegisterAgentResponse) actionResponse; + } + + try (ByteArrayOutputStream baos = new ByteArrayOutputStream(); + OutputStreamStreamOutput osso = new OutputStreamStreamOutput(baos)) { + actionResponse.writeTo(osso); + try (StreamInput input = new InputStreamStreamInput(new ByteArrayInputStream(baos.toByteArray()))) { + return new MLRegisterAgentResponse(input); + } + } catch (IOException e) { + throw new UncheckedIOException("failed to parse ActionResponse into MLRegisterAgentResponse", e); + } + } } diff --git a/memory/src/test/java/org/opensearch/ml/memory/ConversationalMemoryHandlerITTests.java b/memory/src/test/java/org/opensearch/ml/memory/ConversationalMemoryHandlerITTests.java index d8cde2ee53..b1cc57f621 100644 --- a/memory/src/test/java/org/opensearch/ml/memory/ConversationalMemoryHandlerITTests.java +++ b/memory/src/test/java/org/opensearch/ml/memory/ConversationalMemoryHandlerITTests.java @@ -24,6 +24,7 @@ import java.util.function.Consumer; import org.junit.Before; +import org.junit.Ignore; import org.opensearch.OpenSearchSecurityException; import org.opensearch.action.LatchedActionListener; import org.opensearch.action.StepListener; @@ -46,6 +47,7 @@ @Log4j2 @ThreadLeakScope(ThreadLeakScope.Scope.NONE) @OpenSearchIntegTestCase.ClusterScope(scope = OpenSearchIntegTestCase.Scope.TEST, numDataNodes = 2) +@Ignore public class ConversationalMemoryHandlerITTests extends OpenSearchIntegTestCase { private Client client; diff --git a/memory/src/test/java/org/opensearch/ml/memory/index/ConversationMetaIndexITTests.java b/memory/src/test/java/org/opensearch/ml/memory/index/ConversationMetaIndexITTests.java index e1a0318758..7c31f7c24d 100644 --- a/memory/src/test/java/org/opensearch/ml/memory/index/ConversationMetaIndexITTests.java +++ b/memory/src/test/java/org/opensearch/ml/memory/index/ConversationMetaIndexITTests.java @@ -26,6 +26,7 @@ import java.util.function.Consumer; import org.junit.Before; +import org.junit.Ignore; import org.opensearch.OpenSearchSecurityException; import org.opensearch.action.LatchedActionListener; import org.opensearch.action.StepListener; @@ -47,6 +48,7 @@ @Log4j2 @ThreadLeakScope(ThreadLeakScope.Scope.NONE) @OpenSearchIntegTestCase.ClusterScope(scope = OpenSearchIntegTestCase.Scope.TEST, numDataNodes = 2) +@Ignore public class ConversationMetaIndexITTests extends OpenSearchIntegTestCase { private ClusterService clusterService; diff --git a/memory/src/test/java/org/opensearch/ml/memory/index/InteractionsIndexITTests.java b/memory/src/test/java/org/opensearch/ml/memory/index/InteractionsIndexITTests.java index 4d42a4314c..1efc909ed4 100644 --- a/memory/src/test/java/org/opensearch/ml/memory/index/InteractionsIndexITTests.java +++ b/memory/src/test/java/org/opensearch/ml/memory/index/InteractionsIndexITTests.java @@ -24,6 +24,7 @@ import java.util.concurrent.CountDownLatch; import org.junit.Before; +import org.junit.Ignore; import org.opensearch.action.LatchedActionListener; import org.opensearch.action.StepListener; import org.opensearch.client.Client; @@ -39,6 +40,7 @@ @Log4j2 @ThreadLeakScope(ThreadLeakScope.Scope.NONE) @OpenSearchIntegTestCase.ClusterScope(scope = OpenSearchIntegTestCase.Scope.TEST, numDataNodes = 2) +@Ignore public class InteractionsIndexITTests extends OpenSearchIntegTestCase { private Client client; From f8b37b095197301be66c7f4657f1615651695199 Mon Sep 17 00:00:00 2001 From: Arjun kumar Giri Date: Thu, 16 Nov 2023 17:31:02 -0800 Subject: [PATCH 2/4] retrigger checks Signed-off-by: Arjun kumar Giri --- .../java/org/opensearch/ml/rest/RestMLRemoteInferenceIT.java | 1 + 1 file changed, 1 insertion(+) diff --git a/plugin/src/test/java/org/opensearch/ml/rest/RestMLRemoteInferenceIT.java b/plugin/src/test/java/org/opensearch/ml/rest/RestMLRemoteInferenceIT.java index 032d1aaff5..7ff177f996 100644 --- a/plugin/src/test/java/org/opensearch/ml/rest/RestMLRemoteInferenceIT.java +++ b/plugin/src/test/java/org/opensearch/ml/rest/RestMLRemoteInferenceIT.java @@ -447,6 +447,7 @@ public void testOpenAITextEmbeddingModel() throws IOException, InterruptedExcept assertFalse(((List) responseMap.get("embedding")).isEmpty()); } + @Ignore public void testCohereGenerateTextModel() throws IOException, InterruptedException { // Skip test if key is null if (COHERE_KEY == null) { From 5175daca9d9946c784e33f204e0191cef5e95ffe Mon Sep 17 00:00:00 2001 From: Arjun kumar Giri Date: Thu, 16 Nov 2023 17:51:13 -0800 Subject: [PATCH 3/4] Ignore integ test failure due to throttling Signed-off-by: Arjun kumar Giri --- .../java/org/opensearch/ml/rest/RestMLRemoteInferenceIT.java | 1 + 1 file changed, 1 insertion(+) diff --git a/plugin/src/test/java/org/opensearch/ml/rest/RestMLRemoteInferenceIT.java b/plugin/src/test/java/org/opensearch/ml/rest/RestMLRemoteInferenceIT.java index 7ff177f996..1632587421 100644 --- a/plugin/src/test/java/org/opensearch/ml/rest/RestMLRemoteInferenceIT.java +++ b/plugin/src/test/java/org/opensearch/ml/rest/RestMLRemoteInferenceIT.java @@ -447,6 +447,7 @@ public void testOpenAITextEmbeddingModel() throws IOException, InterruptedExcept assertFalse(((List) responseMap.get("embedding")).isEmpty()); } + @Ignore public void testCohereGenerateTextModel() throws IOException, InterruptedException { // Skip test if key is null From 20758aec33f4245b5e554d584a25cd84bdb57812 Mon Sep 17 00:00:00 2001 From: Arjun kumar Giri Date: Thu, 16 Nov 2023 18:01:05 -0800 Subject: [PATCH 4/4] Fix spotbug failure Signed-off-by: Arjun kumar Giri --- .../java/org/opensearch/ml/rest/RestMLRemoteInferenceIT.java | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/plugin/src/test/java/org/opensearch/ml/rest/RestMLRemoteInferenceIT.java b/plugin/src/test/java/org/opensearch/ml/rest/RestMLRemoteInferenceIT.java index 1632587421..799aeed006 100644 --- a/plugin/src/test/java/org/opensearch/ml/rest/RestMLRemoteInferenceIT.java +++ b/plugin/src/test/java/org/opensearch/ml/rest/RestMLRemoteInferenceIT.java @@ -21,6 +21,7 @@ import org.opensearch.ml.repackage.com.google.common.collect.ImmutableList; import org.opensearch.ml.utils.TestHelper; +@Ignore public class RestMLRemoteInferenceIT extends MLCommonsRestTestCase { private final String OPENAI_KEY = System.getenv("OPENAI_KEY"); @@ -447,8 +448,6 @@ public void testOpenAITextEmbeddingModel() throws IOException, InterruptedExcept assertFalse(((List) responseMap.get("embedding")).isEmpty()); } - - @Ignore public void testCohereGenerateTextModel() throws IOException, InterruptedException { // Skip test if key is null if (COHERE_KEY == null) {