From deb51f6163d884bff5db3082ab661bc53b426677 Mon Sep 17 00:00:00 2001 From: Jackie Han Date: Mon, 8 Jan 2024 17:33:26 -0800 Subject: [PATCH] Add GetTool API and ListTools API (#1818) * Add GetTool API and ListTools API Signed-off-by: Jackie Han * rename externalTools parameter as toolMetadataList Signed-off-by: Jackie Han * spotless apply Signed-off-by: Jackie Han * add more unit tests Signed-off-by: Jackie Han * tweak unit test cases Signed-off-by: Jackie Han * increase test coverage Signed-off-by: Jackie Han * increase test coverage Signed-off-by: Jackie Han * add more tests Signed-off-by: Jackie Han * Include Type and Version in GetTool and ListTools API responses Signed-off-by: Jackie Han * tweak ListTools result format Signed-off-by: Jackie Han * change term no version found to undefined Signed-off-by: Jackie Han --------- Signed-off-by: Jackie Han --- .../ml/client/MachineLearningClient.java | 38 ++++++ .../ml/client/MachineLearningNodeClient.java | 44 +++++++ .../ml/client/MachineLearningClientTest.java | 32 +++++ .../client/MachineLearningNodeClientTest.java | 63 ++++++++++ .../opensearch/ml/common/ToolMetadata.java | 118 ++++++++++++++++++ .../transport/tools/MLGetToolAction.java | 16 +++ .../transport/tools/MLListToolsAction.java | 16 +++ .../transport/tools/MLToolGetRequest.java | 84 +++++++++++++ .../transport/tools/MLToolGetResponse.java | 67 ++++++++++ .../transport/tools/MLToolsListRequest.java | 72 +++++++++++ .../transport/tools/MLToolsListResponse.java | 75 +++++++++++ .../ml/common/ToolMetadataTests.java | 92 ++++++++++++++ .../tools/MLToolGetRequestTests.java | 98 +++++++++++++++ .../tools/MLToolGetResponseTests.java | 92 ++++++++++++++ .../tools/MLToolsListRequestTests.java | 111 ++++++++++++++++ .../tools/MLToolsListResponseTests.java | 102 +++++++++++++++ .../opensearch/ml/engine/tools/AgentTool.java | 10 ++ .../ml/engine/tools/CatIndexTool.java | 10 ++ .../ml/engine/tools/MLModelTool.java | 10 ++ .../ml/plugin/MachineLearningPlugin.java | 16 ++- .../ml/rest/RestMLGetToolAction.java | 88 +++++++++++++ .../ml/rest/RestMLListToolsAction.java | 85 +++++++++++++ .../ml/tools/GetToolTransportAction.java | 56 +++++++++ .../ml/tools/ListToolsTransportAction.java | 47 +++++++ .../opensearch/ml/plugin/DummyWrongTool.java | 10 ++ .../ml/rest/RestMLGetToolActionTests.java | 118 ++++++++++++++++++ .../ml/rest/RestMLListToolsActionTests.java | 113 +++++++++++++++++ .../ml/tools/GetToolTransportActionTests.java | 73 +++++++++++ .../tools/ListToolsTransportActionTests.java | 77 ++++++++++++ .../opensearch/ml/common/spi/tools/Tool.java | 12 ++ 30 files changed, 1843 insertions(+), 2 deletions(-) create mode 100644 common/src/main/java/org/opensearch/ml/common/ToolMetadata.java create mode 100644 common/src/main/java/org/opensearch/ml/common/transport/tools/MLGetToolAction.java create mode 100644 common/src/main/java/org/opensearch/ml/common/transport/tools/MLListToolsAction.java create mode 100644 common/src/main/java/org/opensearch/ml/common/transport/tools/MLToolGetRequest.java create mode 100644 common/src/main/java/org/opensearch/ml/common/transport/tools/MLToolGetResponse.java create mode 100644 common/src/main/java/org/opensearch/ml/common/transport/tools/MLToolsListRequest.java create mode 100644 common/src/main/java/org/opensearch/ml/common/transport/tools/MLToolsListResponse.java create mode 100644 common/src/test/java/org/opensearch/ml/common/ToolMetadataTests.java create mode 100644 common/src/test/java/org/opensearch/ml/common/transport/tools/MLToolGetRequestTests.java create mode 100644 common/src/test/java/org/opensearch/ml/common/transport/tools/MLToolGetResponseTests.java create mode 100644 common/src/test/java/org/opensearch/ml/common/transport/tools/MLToolsListRequestTests.java create mode 100644 common/src/test/java/org/opensearch/ml/common/transport/tools/MLToolsListResponseTests.java create mode 100644 plugin/src/main/java/org/opensearch/ml/rest/RestMLGetToolAction.java create mode 100644 plugin/src/main/java/org/opensearch/ml/rest/RestMLListToolsAction.java create mode 100644 plugin/src/main/java/org/opensearch/ml/tools/GetToolTransportAction.java create mode 100644 plugin/src/main/java/org/opensearch/ml/tools/ListToolsTransportAction.java create mode 100644 plugin/src/test/java/org/opensearch/ml/rest/RestMLGetToolActionTests.java create mode 100644 plugin/src/test/java/org/opensearch/ml/rest/RestMLListToolsActionTests.java create mode 100644 plugin/src/test/java/org/opensearch/ml/tools/GetToolTransportActionTests.java create mode 100644 plugin/src/test/java/org/opensearch/ml/tools/ListToolsTransportActionTests.java diff --git a/client/src/main/java/org/opensearch/ml/client/MachineLearningClient.java b/client/src/main/java/org/opensearch/ml/client/MachineLearningClient.java index 678d6d0f14..b115eb91c9 100644 --- a/client/src/main/java/org/opensearch/ml/client/MachineLearningClient.java +++ b/client/src/main/java/org/opensearch/ml/client/MachineLearningClient.java @@ -5,6 +5,7 @@ package org.opensearch.ml.client; +import java.util.List; import java.util.Map; import org.opensearch.action.delete.DeleteResponse; @@ -17,6 +18,7 @@ import org.opensearch.ml.common.FunctionName; import org.opensearch.ml.common.MLModel; import org.opensearch.ml.common.MLTask; +import org.opensearch.ml.common.ToolMetadata; import org.opensearch.ml.common.agent.MLAgent; import org.opensearch.ml.common.input.Input; import org.opensearch.ml.common.input.MLInput; @@ -390,4 +392,40 @@ default ActionFuture deleteAgent(String agentId) { void deleteAgent(String agentId, ActionListener listener); + /** + * Get a list of ToolMetadata and return ActionFuture. + * For more info on list tools, refer: https://opensearch.org/docs/latest/ml-commons-plugin/api/#list-tools + * @return ActionFuture of a list of tool metadata + */ + default ActionFuture> listTools() { + PlainActionFuture> actionFuture = PlainActionFuture.newFuture(); + listTools(actionFuture); + return actionFuture; + } + + /** + * List ToolMetadata and return a list of ToolMetadata in listener + * For more info on get tools, refer: https://opensearch.org/docs/latest/ml-commons-plugin/api/#list-tools + * @param listener action listener + */ + void listTools(ActionListener> listener); + + /** + * Get ToolMetadata and return ActionFuture. + * For more info on get tool, refer: https://opensearch.org/docs/latest/ml-commons-plugin/api/#get-tool + * @return ActionFuture of tool metadata + */ + default ActionFuture getTool(String toolName) { + PlainActionFuture actionFuture = PlainActionFuture.newFuture(); + getTool(toolName, actionFuture); + return actionFuture; + } + + /** + * Get ToolMetadata and return ToolMetadata in listener + * For more info on get tool, refer: https://opensearch.org/docs/latest/ml-commons-plugin/api/#get-tool + * @param listener action listener + */ + void getTool(String toolName, ActionListener listener); + } diff --git a/client/src/main/java/org/opensearch/ml/client/MachineLearningNodeClient.java b/client/src/main/java/org/opensearch/ml/client/MachineLearningNodeClient.java index 5c550ef9d5..acf171872d 100644 --- a/client/src/main/java/org/opensearch/ml/client/MachineLearningNodeClient.java +++ b/client/src/main/java/org/opensearch/ml/client/MachineLearningNodeClient.java @@ -14,6 +14,7 @@ import static org.opensearch.ml.common.input.InputHelper.getAction; import static org.opensearch.ml.common.input.InputHelper.getFunctionName; +import java.util.List; import java.util.Map; import java.util.function.Function; @@ -26,6 +27,7 @@ import org.opensearch.ml.common.FunctionName; import org.opensearch.ml.common.MLModel; import org.opensearch.ml.common.MLTask; +import org.opensearch.ml.common.ToolMetadata; import org.opensearch.ml.common.agent.MLAgent; import org.opensearch.ml.common.input.Input; import org.opensearch.ml.common.input.MLInput; @@ -71,6 +73,12 @@ import org.opensearch.ml.common.transport.task.MLTaskGetRequest; import org.opensearch.ml.common.transport.task.MLTaskGetResponse; import org.opensearch.ml.common.transport.task.MLTaskSearchAction; +import org.opensearch.ml.common.transport.tools.MLGetToolAction; +import org.opensearch.ml.common.transport.tools.MLListToolsAction; +import org.opensearch.ml.common.transport.tools.MLToolGetRequest; +import org.opensearch.ml.common.transport.tools.MLToolGetResponse; +import org.opensearch.ml.common.transport.tools.MLToolsListRequest; +import org.opensearch.ml.common.transport.tools.MLToolsListResponse; import org.opensearch.ml.common.transport.training.MLTrainingTaskAction; import org.opensearch.ml.common.transport.training.MLTrainingTaskRequest; import org.opensearch.ml.common.transport.trainpredict.MLTrainAndPredictionTaskAction; @@ -287,6 +295,42 @@ public void deleteAgent(String agentId, ActionListener listener) }, listener::onFailure)); } + @Override + public void listTools(ActionListener> listener) { + MLToolsListRequest mlToolsListRequest = MLToolsListRequest.builder().build(); + + client.execute(MLListToolsAction.INSTANCE, mlToolsListRequest, getMlListToolsResponseActionListener(listener)); + } + + @Override + public void getTool(String toolName, ActionListener listener) { + MLToolGetRequest mlToolGetRequest = MLToolGetRequest.builder().toolName(toolName).build(); + + client.execute(MLGetToolAction.INSTANCE, mlToolGetRequest, getMlGetToolResponseActionListener(listener)); + } + + private ActionListener getMlListToolsResponseActionListener(ActionListener> listener) { + ActionListener internalListener = ActionListener.wrap(mlModelListResponse -> { + listener.onResponse(mlModelListResponse.getToolMetadataList()); + }, listener::onFailure); + ActionListener actionListener = wrapActionListener(internalListener, res -> { + MLToolsListResponse getResponse = MLToolsListResponse.fromActionResponse(res); + return getResponse; + }); + return actionListener; + } + + private ActionListener getMlGetToolResponseActionListener(ActionListener listener) { + ActionListener internalListener = ActionListener.wrap(mlModelGetResponse -> { + listener.onResponse(mlModelGetResponse.getToolMetadata()); + }, listener::onFailure); + ActionListener actionListener = wrapActionListener(internalListener, res -> { + MLToolGetResponse getResponse = MLToolGetResponse.fromActionResponse(res); + return getResponse; + }); + return actionListener; + } + private ActionListener getMLRegisterAgentResponseActionListener( ActionListener listener ) { diff --git a/client/src/test/java/org/opensearch/ml/client/MachineLearningClientTest.java b/client/src/test/java/org/opensearch/ml/client/MachineLearningClientTest.java index 52ce5a2ef2..ccc0e050e9 100644 --- a/client/src/test/java/org/opensearch/ml/client/MachineLearningClientTest.java +++ b/client/src/test/java/org/opensearch/ml/client/MachineLearningClientTest.java @@ -31,6 +31,7 @@ import org.opensearch.ml.common.FunctionName; import org.opensearch.ml.common.MLModel; import org.opensearch.ml.common.MLTask; +import org.opensearch.ml.common.ToolMetadata; import org.opensearch.ml.common.agent.MLAgent; import org.opensearch.ml.common.dataframe.DataFrame; import org.opensearch.ml.common.dataset.DataFrameInputDataset; @@ -100,6 +101,8 @@ public class MachineLearningClientTest { private String modekId = "test_model_id"; private MLModel mlModel; private MLTask mlTask; + private ToolMetadata toolMetadata; + private List toolsList = new ArrayList<>(); @Before public void setUp() { @@ -111,6 +114,15 @@ public void setUp() { String modelContent = "test content"; mlModel = MLModel.builder().algorithm(FunctionName.KMEANS).name("test").content(modelContent).build(); + toolMetadata = ToolMetadata + .builder() + .name("MathTool") + .description("Use this tool to calculate any math problem.") + .type("MathTool") + .version(null) + .build(); + toolsList.add(toolMetadata); + machineLearningClient = new MachineLearningClient() { @Override public void predict(String modelId, MLInput mlInput, ActionListener listener) { @@ -192,6 +204,16 @@ public void deleteConnector(String connectorId, ActionListener l listener.onResponse(deleteResponse); } + @Override + public void listTools(ActionListener> listener) { + listener.onResponse(toolsList); + } + + @Override + public void getTool(String toolName, ActionListener listener) { + listener.onResponse(toolMetadata); + } + public void registerModelGroup( MLRegisterModelGroupInput mlRegisterModelGroupInput, ActionListener listener @@ -470,4 +492,14 @@ public void testRegisterAgent() { public void deleteAgent() { assertEquals(deleteResponse, machineLearningClient.deleteAgent("agentId").actionGet()); } + + @Test + public void getTool() { + assertEquals(toolMetadata, machineLearningClient.getTool("MathTool").actionGet()); + } + + @Test + public void listTools() { + assertEquals(toolMetadata, machineLearningClient.listTools().actionGet().get(0)); + } } diff --git a/client/src/test/java/org/opensearch/ml/client/MachineLearningNodeClientTest.java b/client/src/test/java/org/opensearch/ml/client/MachineLearningNodeClientTest.java index 34c13f16e0..f81b20747f 100644 --- a/client/src/test/java/org/opensearch/ml/client/MachineLearningNodeClientTest.java +++ b/client/src/test/java/org/opensearch/ml/client/MachineLearningNodeClientTest.java @@ -60,6 +60,7 @@ import org.opensearch.ml.common.MLTask; import org.opensearch.ml.common.MLTaskState; import org.opensearch.ml.common.MLTaskType; +import org.opensearch.ml.common.ToolMetadata; import org.opensearch.ml.common.agent.MLAgent; import org.opensearch.ml.common.dataframe.DataFrame; import org.opensearch.ml.common.dataset.MLInputDataset; @@ -116,6 +117,12 @@ import org.opensearch.ml.common.transport.task.MLTaskGetRequest; import org.opensearch.ml.common.transport.task.MLTaskGetResponse; import org.opensearch.ml.common.transport.task.MLTaskSearchAction; +import org.opensearch.ml.common.transport.tools.MLGetToolAction; +import org.opensearch.ml.common.transport.tools.MLListToolsAction; +import org.opensearch.ml.common.transport.tools.MLToolGetRequest; +import org.opensearch.ml.common.transport.tools.MLToolGetResponse; +import org.opensearch.ml.common.transport.tools.MLToolsListRequest; +import org.opensearch.ml.common.transport.tools.MLToolsListResponse; import org.opensearch.ml.common.transport.training.MLTrainingTaskAction; import org.opensearch.ml.common.transport.training.MLTrainingTaskRequest; import org.opensearch.ml.common.transport.trainpredict.MLTrainAndPredictionTaskAction; @@ -192,6 +199,12 @@ public class MachineLearningNodeClientTest { @Mock ActionListener deleteAgentActionListener; + @Mock + ActionListener> listToolsActionListener; + + @Mock + ActionListener getToolActionListener; + @InjectMocks MachineLearningNodeClient machineLearningNodeClient; @@ -887,6 +900,56 @@ public void deleteAgent() { assertEquals(agentId, (argumentCaptor.getValue()).getId()); } + @Test + public void getTool() { + ToolMetadata toolMetadata = ToolMetadata + .builder() + .name("MathTool") + .description("Use this tool to calculate any math problem.") + .build(); + + doAnswer(invocation -> { + ActionListener actionListener = invocation.getArgument(2); + MLToolGetResponse output = MLToolGetResponse.builder().toolMetadata(toolMetadata).build(); + actionListener.onResponse(output); + return null; + }).when(client).execute(eq(MLGetToolAction.INSTANCE), any(), any()); + + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(ToolMetadata.class); + machineLearningNodeClient.getTool("MathTool", getToolActionListener); + + verify(client).execute(eq(MLGetToolAction.INSTANCE), isA(MLToolGetRequest.class), any()); + verify(getToolActionListener).onResponse(argumentCaptor.capture()); + assertEquals("MathTool", argumentCaptor.getValue().getName()); + assertEquals("Use this tool to calculate any math problem.", argumentCaptor.getValue().getDescription()); + } + + @Test + public void listTools() { + List toolMetadataList = new ArrayList<>(); + ToolMetadata wikipediaTool = ToolMetadata + .builder() + .name("WikipediaTool") + .description("Use this tool to search general knowledge on wikipedia.") + .build(); + toolMetadataList.add(wikipediaTool); + + doAnswer(invocation -> { + ActionListener actionListener = invocation.getArgument(2); + MLToolsListResponse output = MLToolsListResponse.builder().toolMetadata(toolMetadataList).build(); + actionListener.onResponse(output); + return null; + }).when(client).execute(eq(MLListToolsAction.INSTANCE), any(), any()); + + ArgumentCaptor> argumentCaptor = ArgumentCaptor.forClass(List.class); + machineLearningNodeClient.listTools(listToolsActionListener); + + verify(client).execute(eq(MLListToolsAction.INSTANCE), isA(MLToolsListRequest.class), any()); + verify(listToolsActionListener).onResponse(argumentCaptor.capture()); + assertEquals("WikipediaTool", argumentCaptor.getValue().get(0).getName()); + assertEquals("Use this tool to search general knowledge on wikipedia.", argumentCaptor.getValue().get(0).getDescription()); + } + private SearchResponse createSearchResponse(ToXContentObject o) throws IOException { XContentBuilder content = o.toXContent(XContentFactory.jsonBuilder(), ToXContent.EMPTY_PARAMS); diff --git a/common/src/main/java/org/opensearch/ml/common/ToolMetadata.java b/common/src/main/java/org/opensearch/ml/common/ToolMetadata.java new file mode 100644 index 0000000000..fa9c29ead5 --- /dev/null +++ b/common/src/main/java/org/opensearch/ml/common/ToolMetadata.java @@ -0,0 +1,118 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ +package org.opensearch.ml.common; + +import lombok.Builder; +import lombok.Getter; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.core.common.io.stream.Writeable; +import org.opensearch.core.xcontent.ToXContent; +import org.opensearch.core.xcontent.ToXContentObject; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.core.xcontent.XContentParser; + +import java.io.IOException; + +import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; + + +public class ToolMetadata implements ToXContentObject, Writeable { + + public static final String TOOL_NAME_FIELD = "name"; + public static final String TOOL_DESCRIPTION_FIELD = "description"; + public static final String TOOL_TYPE_FIELD = "type"; + public static final String TOOL_VERSION_FIELD = "version"; + + + @Getter + private String name; + @Getter + private String description; + @Getter + private String type; + @Getter + private String version; + + @Builder(toBuilder = true) + public ToolMetadata(String name, String description, String type, String version) { + this.name = name; + this.description = description; + this.type = type; + this.version = version; + } + + public ToolMetadata(StreamInput input) throws IOException { + name = input.readString(); + description = input.readString(); + type = input.readString(); + version = input.readOptionalString(); + } + + public void writeTo(StreamOutput output) throws IOException { + output.writeString(name); + output.writeString(description); + output.writeString(type); + output.writeOptionalString(version); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, ToXContent.Params params) throws IOException { + builder.startObject(); + if (name != null) { + builder.field(TOOL_NAME_FIELD, name); + } + if (description != null) { + builder.field(TOOL_DESCRIPTION_FIELD, description); + } + if (type != null) { + builder.field(TOOL_TYPE_FIELD, type); + } + builder.field(TOOL_VERSION_FIELD, version != null ? version : "undefined"); + builder.endObject(); + return builder; + } + + public static ToolMetadata parse(XContentParser parser) throws IOException { + String name = null; + String description = null; + String type = null; + String version = null; + + ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser); + while (parser.nextToken() != XContentParser.Token.END_OBJECT) { + String fieldName = parser.currentName(); + parser.nextToken(); + + switch (fieldName) { + case TOOL_NAME_FIELD: + name = parser.text(); + break; + case TOOL_DESCRIPTION_FIELD: + description = parser.text(); + break; + case TOOL_TYPE_FIELD: + type = parser.text(); + break; + case TOOL_VERSION_FIELD: + version = parser.text(); + default: + parser.skipChildren(); + break; + } + } + return ToolMetadata.builder() + .name(name) + .description(description) + .type(type) + .version(version) + .build(); + } + + public static ToolMetadata fromStream(StreamInput in) throws IOException { + ToolMetadata toolMetadata = new ToolMetadata(in); + return toolMetadata; + } +} \ No newline at end of file diff --git a/common/src/main/java/org/opensearch/ml/common/transport/tools/MLGetToolAction.java b/common/src/main/java/org/opensearch/ml/common/transport/tools/MLGetToolAction.java new file mode 100644 index 0000000000..468d53d34a --- /dev/null +++ b/common/src/main/java/org/opensearch/ml/common/transport/tools/MLGetToolAction.java @@ -0,0 +1,16 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ +package org.opensearch.ml.common.transport.tools; + +import org.opensearch.action.ActionType; + +public class MLGetToolAction extends ActionType { + public static final MLGetToolAction INSTANCE = new MLGetToolAction(); + public static final String NAME = "cluster:admin/opensearch/ml/tools/get"; + + public MLGetToolAction() { + super(NAME, MLToolGetResponse::new); + } +} \ No newline at end of file diff --git a/common/src/main/java/org/opensearch/ml/common/transport/tools/MLListToolsAction.java b/common/src/main/java/org/opensearch/ml/common/transport/tools/MLListToolsAction.java new file mode 100644 index 0000000000..3ec6b4c99e --- /dev/null +++ b/common/src/main/java/org/opensearch/ml/common/transport/tools/MLListToolsAction.java @@ -0,0 +1,16 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ +package org.opensearch.ml.common.transport.tools; + +import org.opensearch.action.ActionType; + +public class MLListToolsAction extends ActionType { + public static final MLListToolsAction INSTANCE = new MLListToolsAction(); + public static final String NAME = "cluster:admin/opensearch/ml/tools/list"; + + public MLListToolsAction() { + super(NAME, MLToolsListResponse::new); + } +} \ No newline at end of file diff --git a/common/src/main/java/org/opensearch/ml/common/transport/tools/MLToolGetRequest.java b/common/src/main/java/org/opensearch/ml/common/transport/tools/MLToolGetRequest.java new file mode 100644 index 0000000000..e89e506fe3 --- /dev/null +++ b/common/src/main/java/org/opensearch/ml/common/transport/tools/MLToolGetRequest.java @@ -0,0 +1,84 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ +package org.opensearch.ml.common.transport.tools; + +import lombok.AccessLevel; +import lombok.Builder; +import lombok.Getter; +import lombok.ToString; +import lombok.experimental.FieldDefaults; +import org.opensearch.action.ActionRequest; +import org.opensearch.action.ActionRequestValidationException; +import org.opensearch.core.common.io.stream.InputStreamStreamInput; +import org.opensearch.core.common.io.stream.OutputStreamStreamOutput; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.ml.common.ToolMetadata; + +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.UncheckedIOException; +import java.util.List; + +import static org.opensearch.action.ValidateActions.addValidationError; + +@Getter +@FieldDefaults(makeFinal = true, level = AccessLevel.PRIVATE) +@ToString +public class MLToolGetRequest extends ActionRequest { + + String toolName; + + List toolMetadataList; + + @Builder + public MLToolGetRequest(String toolName, List toolMetadataList) { + this.toolName = toolName; + this.toolMetadataList = toolMetadataList; + } + + public MLToolGetRequest(StreamInput in) throws IOException { + super(in); + this.toolName = in.readString(); + this.toolMetadataList = in.readList(ToolMetadata::new); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + super.writeTo(out); + out.writeString(this.toolName); + out.writeList(this.toolMetadataList); + } + + @Override + public ActionRequestValidationException validate() { + ActionRequestValidationException exception = null; + + if (this.toolName == null) { + exception = addValidationError("Tool name can't be null", exception); + } + + return exception; + } + + public static MLToolGetRequest fromActionRequest(ActionRequest actionRequest) { + if (actionRequest instanceof MLToolGetRequest) { + return (MLToolGetRequest)actionRequest; + } + + try (ByteArrayOutputStream baos = new ByteArrayOutputStream(); + OutputStreamStreamOutput osso = new OutputStreamStreamOutput(baos)) { + actionRequest.writeTo(osso); + try (StreamInput input = new InputStreamStreamInput(new ByteArrayInputStream(baos.toByteArray()))) { + return new MLToolGetRequest(input); + } + } catch (IOException e) { + throw new UncheckedIOException("failed to parse ActionRequest into MLToolGetRequest", e); + } + } + + +} \ No newline at end of file diff --git a/common/src/main/java/org/opensearch/ml/common/transport/tools/MLToolGetResponse.java b/common/src/main/java/org/opensearch/ml/common/transport/tools/MLToolGetResponse.java new file mode 100644 index 0000000000..d4623039c8 --- /dev/null +++ b/common/src/main/java/org/opensearch/ml/common/transport/tools/MLToolGetResponse.java @@ -0,0 +1,67 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.common.transport.tools; + +import lombok.Builder; +import lombok.Getter; +import lombok.ToString; +import org.opensearch.core.action.ActionResponse; +import org.opensearch.core.common.io.stream.InputStreamStreamInput; +import org.opensearch.core.common.io.stream.OutputStreamStreamOutput; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.core.xcontent.ToXContent; +import org.opensearch.core.xcontent.ToXContentObject; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.ml.common.ToolMetadata; + +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.UncheckedIOException; + +@Getter +@ToString +public class MLToolGetResponse extends ActionResponse implements ToXContentObject { + + ToolMetadata toolMetadata; + + @Builder + public MLToolGetResponse(ToolMetadata toolMetadata) { + this.toolMetadata = toolMetadata; + } + + public MLToolGetResponse(StreamInput in) throws IOException { + super(in); + toolMetadata = toolMetadata.fromStream(in); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + toolMetadata.writeTo(out); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, ToXContent.Params params) throws IOException { + return toolMetadata.toXContent(builder, params); + } + + public static MLToolGetResponse fromActionResponse(ActionResponse actionResponse) { + if (actionResponse instanceof MLToolGetResponse) { + return (MLToolGetResponse) actionResponse; + } + + try (ByteArrayOutputStream baos = new ByteArrayOutputStream(); + OutputStreamStreamOutput osso = new OutputStreamStreamOutput(baos)) { + actionResponse.writeTo(osso); + try (StreamInput input = new InputStreamStreamInput(new ByteArrayInputStream(baos.toByteArray()))) { + return new MLToolGetResponse(input); + } + } catch (IOException e) { + throw new UncheckedIOException("failed to parse ActionResponse into MLToolGetResponse", e); + } + } +} \ No newline at end of file diff --git a/common/src/main/java/org/opensearch/ml/common/transport/tools/MLToolsListRequest.java b/common/src/main/java/org/opensearch/ml/common/transport/tools/MLToolsListRequest.java new file mode 100644 index 0000000000..49575aaac2 --- /dev/null +++ b/common/src/main/java/org/opensearch/ml/common/transport/tools/MLToolsListRequest.java @@ -0,0 +1,72 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.common.transport.tools; + +import lombok.AccessLevel; +import lombok.Builder; +import lombok.Getter; +import lombok.ToString; +import lombok.experimental.FieldDefaults; + +import org.opensearch.action.ActionRequest; +import org.opensearch.action.ActionRequestValidationException; +import org.opensearch.core.common.io.stream.InputStreamStreamInput; +import org.opensearch.core.common.io.stream.OutputStreamStreamOutput; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.ml.common.ToolMetadata; + +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.UncheckedIOException; +import java.util.List; + +@Getter +@FieldDefaults(makeFinal = true, level = AccessLevel.PRIVATE) +@ToString +public class MLToolsListRequest extends ActionRequest { + + List toolMetadataList; + + @Builder + public MLToolsListRequest(List toolMetadataList) { + this.toolMetadataList = toolMetadataList; + } + + public MLToolsListRequest(StreamInput in) throws IOException { + super(in); + this.toolMetadataList = in.readList(ToolMetadata::new); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + super.writeTo(out); + out.writeList(this.toolMetadataList); + } + + @Override + public ActionRequestValidationException validate() { + return null; + } + + public static MLToolsListRequest fromActionRequest(ActionRequest actionRequest) { + if (actionRequest instanceof MLToolsListRequest) { + return (MLToolsListRequest)actionRequest; + } + + try (ByteArrayOutputStream baos = new ByteArrayOutputStream(); + OutputStreamStreamOutput osso = new OutputStreamStreamOutput(baos)) { + actionRequest.writeTo(osso); + try (StreamInput input = new InputStreamStreamInput(new ByteArrayInputStream(baos.toByteArray()))) { + return new MLToolsListRequest(input); + } + } catch (IOException e) { + throw new UncheckedIOException("failed to parse ActionRequest into MLToolsListRequest", e); + } + } + +} \ No newline at end of file diff --git a/common/src/main/java/org/opensearch/ml/common/transport/tools/MLToolsListResponse.java b/common/src/main/java/org/opensearch/ml/common/transport/tools/MLToolsListResponse.java new file mode 100644 index 0000000000..840981174e --- /dev/null +++ b/common/src/main/java/org/opensearch/ml/common/transport/tools/MLToolsListResponse.java @@ -0,0 +1,75 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ +package org.opensearch.ml.common.transport.tools; + +import lombok.Builder; +import lombok.Getter; +import lombok.ToString; +import org.opensearch.core.action.ActionResponse; +import org.opensearch.core.common.io.stream.InputStreamStreamInput; +import org.opensearch.core.common.io.stream.OutputStreamStreamOutput; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.core.xcontent.ToXContent; +import org.opensearch.core.xcontent.ToXContentObject; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.ml.common.ToolMetadata; + +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.UncheckedIOException; +import java.util.List; + +@Getter +@ToString +public class MLToolsListResponse extends ActionResponse implements ToXContentObject { + + List toolMetadataList; + + @Builder + public MLToolsListResponse(List toolMetadata) { + this.toolMetadataList = toolMetadata; + } + public MLToolsListResponse(StreamInput in) throws IOException { + super(in); + this.toolMetadataList = in.readList(ToolMetadata::new); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeList(toolMetadataList); + } + + @Override + public XContentBuilder toXContent(XContentBuilder xContentBuilder, ToXContent.Params params) throws IOException { + for (ToolMetadata toolMetadata : toolMetadataList) { + xContentBuilder.startObject(); + xContentBuilder.field(ToolMetadata.TOOL_NAME_FIELD, toolMetadata.getName()); + xContentBuilder.field(ToolMetadata.TOOL_DESCRIPTION_FIELD, toolMetadata.getDescription()); + xContentBuilder.field(ToolMetadata.TOOL_TYPE_FIELD, toolMetadata.getType()); + xContentBuilder.field(ToolMetadata.TOOL_VERSION_FIELD, toolMetadata.getVersion() != null ? toolMetadata.getVersion() : "undefined"); + xContentBuilder.endObject(); + } + return xContentBuilder; + } + + public static MLToolsListResponse fromActionResponse(ActionResponse actionResponse) { + if (actionResponse instanceof MLToolsListResponse) { + return (MLToolsListResponse) actionResponse; + } + + try (ByteArrayOutputStream baos = new ByteArrayOutputStream(); + OutputStreamStreamOutput osso = new OutputStreamStreamOutput(baos)) { + actionResponse.writeTo(osso); + try (StreamInput input = new InputStreamStreamInput(new ByteArrayInputStream(baos.toByteArray()))) { + return new MLToolsListResponse(input); + } + } + catch (IOException e) { + throw new UncheckedIOException("failed to parse ActionResponse into MLToolsListResponse", e); + } + } +} \ No newline at end of file diff --git a/common/src/test/java/org/opensearch/ml/common/ToolMetadataTests.java b/common/src/test/java/org/opensearch/ml/common/ToolMetadataTests.java new file mode 100644 index 0000000000..02234757b3 --- /dev/null +++ b/common/src/test/java/org/opensearch/ml/common/ToolMetadataTests.java @@ -0,0 +1,92 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ +package org.opensearch.ml.common; + +import org.junit.Before; +import org.junit.Test; +import org.opensearch.common.io.stream.BytesStreamOutput; +import org.opensearch.common.xcontent.LoggingDeprecationHandler; +import org.opensearch.common.xcontent.XContentType; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.core.xcontent.ToXContent; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.core.xcontent.XContentParser; + +import java.io.IOException; +import java.util.function.Function; + +import static org.junit.Assert.assertEquals; +import static org.opensearch.core.xcontent.ToXContent.EMPTY_PARAMS; + +public class ToolMetadataTests { + ToolMetadata toolMetadata; + + Function function; + + @Before + public void setUp() { + toolMetadata = ToolMetadata.builder() + .name("MathTool") + .description("Use this tool to calculate any math problem.") + .type("MathTool") + .version("test") + .build(); + + function = parser -> { + try { + return ToolMetadata.parse(parser); + } catch (IOException e) { + throw new RuntimeException(e); + } + }; + } + + @Test + public void toXContent() throws IOException { + XContentBuilder builder = XContentBuilder.builder(XContentType.JSON.xContent()); + toolMetadata.toXContent(builder, EMPTY_PARAMS); + String toolMetadataString = TestHelper.xContentBuilderToString(builder); + assertEquals(toolMetadataString, "{\"name\":\"MathTool\",\"description\":\"Use this tool to calculate any math problem.\",\"type\":\"MathTool\",\"version\":\"test\"}"); + } + + @Test + public void toXContent_nullValue() throws IOException { + ToolMetadata emptyToolMetadata = ToolMetadata.builder().build(); + XContentBuilder builder = XContentBuilder.builder(XContentType.JSON.xContent()); + emptyToolMetadata.toXContent(builder, ToXContent.EMPTY_PARAMS); + String toolMetadataString = TestHelper.xContentBuilderToString(builder); + assertEquals("{\"version\":\"undefined\"}", toolMetadataString); + } + + @Test + public void parse() throws IOException { + XContentBuilder builder = XContentBuilder.builder(XContentType.JSON.xContent()); + toolMetadata.toXContent(builder, EMPTY_PARAMS); + String toolMetadataString = TestHelper.xContentBuilderToString(builder); + XContentParser parser = XContentType.JSON.xContent().createParser(NamedXContentRegistry.EMPTY, + LoggingDeprecationHandler.INSTANCE, toolMetadataString); + parser.nextToken(); + toolMetadata.equals(function.apply(parser)); + } + + + @Test + public void readInputStream_Success() throws IOException { + readInputStream(toolMetadata); + } + + private void readInputStream(ToolMetadata toolMetadata) throws IOException { + BytesStreamOutput bytesStreamOutput = new BytesStreamOutput(); + toolMetadata.writeTo(bytesStreamOutput); + + StreamInput streamInput = bytesStreamOutput.bytes().streamInput(); + ToolMetadata parsedToolMetadata = new ToolMetadata(streamInput); + assertEquals(toolMetadata.getName(), parsedToolMetadata.getName()); + assertEquals(toolMetadata.getDescription(), parsedToolMetadata.getDescription()); + assertEquals(toolMetadata.getType(), parsedToolMetadata.getType()); + assertEquals(toolMetadata.getVersion(), parsedToolMetadata.getVersion()); + } +} diff --git a/common/src/test/java/org/opensearch/ml/common/transport/tools/MLToolGetRequestTests.java b/common/src/test/java/org/opensearch/ml/common/transport/tools/MLToolGetRequestTests.java new file mode 100644 index 0000000000..6e62d99507 --- /dev/null +++ b/common/src/test/java/org/opensearch/ml/common/transport/tools/MLToolGetRequestTests.java @@ -0,0 +1,98 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ +package org.opensearch.ml.common.transport.tools; + +import org.junit.Before; +import org.junit.Test; +import org.opensearch.action.ActionRequest; +import org.opensearch.action.ActionRequestValidationException; +import org.opensearch.common.io.stream.BytesStreamOutput; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.ml.common.ToolMetadata; + +import java.io.IOException; +import java.io.UncheckedIOException; +import java.util.ArrayList; +import java.util.List; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotSame; + +public class MLToolGetRequestTests { + private List toolMetadataList; + + @Before + public void setUp() { + toolMetadataList = new ArrayList<>(); + ToolMetadata wikipediaTool = ToolMetadata.builder() + .name("MathTool") + .description("Use this tool to search general knowledge on wikipedia.") + .type("MathTool") + .version("test") + .build(); + toolMetadataList.add(wikipediaTool); + } + + @Test + public void writeTo_success() throws IOException { + MLToolGetRequest mlToolGetRequest = MLToolGetRequest.builder() + .toolName("MathTool") + .toolMetadataList(toolMetadataList) + .build(); + + BytesStreamOutput bytesStreamOutput = new BytesStreamOutput(); + mlToolGetRequest.writeTo(bytesStreamOutput); + MLToolGetRequest parsedToolMetadata = new MLToolGetRequest(bytesStreamOutput.bytes().streamInput()); + assertEquals(parsedToolMetadata.getToolName(), "MathTool"); + assertEquals(parsedToolMetadata.getToolMetadataList().get(0).getName(), toolMetadataList.get(0).getName()); + } + + @Test + public void fromActionRequest_success() { + MLToolGetRequest mlToolGetRequest = MLToolGetRequest.builder() + .toolName("MathTool") + .toolMetadataList(toolMetadataList) + .build(); + ActionRequest actionRequest = new ActionRequest() { + @Override + public ActionRequestValidationException validate() { + return null; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + mlToolGetRequest.writeTo(out); + } + }; + MLToolGetRequest result = MLToolGetRequest.fromActionRequest(actionRequest); + assertNotSame(result, mlToolGetRequest); + assertEquals(result.getToolName(), "MathTool"); + assertEquals(result.getToolMetadataList().get(0).getName(), mlToolGetRequest.getToolMetadataList().get(0).getName()); + } + + @Test(expected = UncheckedIOException.class) + public void fromActionRequest_IOException() { + ActionRequest actionRequest = new ActionRequest() { + @Override + public ActionRequestValidationException validate() { + return null; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + throw new IOException("test"); + } + }; + MLToolGetRequest.fromActionRequest(actionRequest); + } + + @Test + public void validate_Exception_NullToolName() { + MLToolGetRequest mlToolGetRequest = MLToolGetRequest.builder().build(); + ActionRequestValidationException exception = mlToolGetRequest.validate(); + assertEquals("Validation Failed: 1: Tool name can't be null;", exception.getMessage()); + + } +} \ No newline at end of file diff --git a/common/src/test/java/org/opensearch/ml/common/transport/tools/MLToolGetResponseTests.java b/common/src/test/java/org/opensearch/ml/common/transport/tools/MLToolGetResponseTests.java new file mode 100644 index 0000000000..6ec682dcc2 --- /dev/null +++ b/common/src/test/java/org/opensearch/ml/common/transport/tools/MLToolGetResponseTests.java @@ -0,0 +1,92 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ +package org.opensearch.ml.common.transport.tools; + +import org.junit.Before; +import org.junit.Test; +import org.opensearch.core.action.ActionResponse; +import org.opensearch.core.common.Strings; +import org.opensearch.common.io.stream.BytesStreamOutput; +import org.opensearch.common.xcontent.XContentFactory; +import org.opensearch.common.xcontent.XContentType; +import org.opensearch.core.common.Strings; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.core.xcontent.MediaTypeRegistry; +import org.opensearch.core.xcontent.ToXContent; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.ml.common.ToolMetadata; +import org.opensearch.ml.common.transport.model.MLModelGetResponse; + +import java.io.IOException; +import java.io.UncheckedIOException; + +import static org.junit.Assert.*; + +public class MLToolGetResponseTests { + ToolMetadata toolMetadata; + + MLToolGetResponse mlToolGetResponse; + + @Before + public void setUp() { + toolMetadata = ToolMetadata.builder() + .name("MathTool") + .description("Use this tool to calculate any math problem.") + .type("MathTool") + .version(null) + .build(); + + mlToolGetResponse = MLToolGetResponse.builder().toolMetadata(toolMetadata).build(); + } + + @Test + public void writeTo_success() throws IOException { + BytesStreamOutput bytesStreamOutput = new BytesStreamOutput(); + MLToolGetResponse response = MLToolGetResponse.builder().toolMetadata(toolMetadata).build(); + response.writeTo(bytesStreamOutput); + MLToolGetResponse parsedResponse = new MLToolGetResponse(bytesStreamOutput.bytes().streamInput()); + assertNotEquals(response.toolMetadata, parsedResponse.toolMetadata); + assertEquals(response.toolMetadata.getName(), parsedResponse.getToolMetadata().getName()); + } + + @Test + public void toXContentTest() throws IOException { + XContentBuilder builder = MediaTypeRegistry.contentBuilder(XContentType.JSON); + mlToolGetResponse.toXContent(builder, ToXContent.EMPTY_PARAMS); + assertNotNull(builder); + String jsonStr = builder.toString(); + assertEquals("{\"name\":\"MathTool\",\"description\":\"Use this tool to calculate any math problem.\",\"type\":\"MathTool\",\"version\":\"undefined\"}", jsonStr); + } + + @Test + public void fromActionResponseWithMLToolGetResponse_Success() { + MLToolGetResponse mlToolGetResponseFromActionResponse = MLToolGetResponse.fromActionResponse(mlToolGetResponse); + assertSame(mlToolGetResponse, mlToolGetResponseFromActionResponse); + assertEquals(mlToolGetResponse.getToolMetadata().getName(), mlToolGetResponseFromActionResponse.getToolMetadata().getName()); + } + + @Test + public void fromActionResponse_Success() { + ActionResponse actionResponse = new ActionResponse() { + @Override + public void writeTo(StreamOutput out) throws IOException { + mlToolGetResponse.writeTo(out); + } + }; + MLToolGetResponse mlToolGetResponseFromActionResponse = MLToolGetResponse.fromActionResponse(actionResponse); + assertEquals(mlToolGetResponse.getToolMetadata().getName(), mlToolGetResponseFromActionResponse.getToolMetadata().getName()); + } + + @Test(expected = UncheckedIOException.class) + public void fromActionResponse_IOException() { + ActionResponse actionResponse = new ActionResponse() { + @Override + public void writeTo(StreamOutput out) throws IOException { + throw new IOException(); + } + }; + MLToolGetResponse.fromActionResponse(actionResponse); + } +} \ No newline at end of file diff --git a/common/src/test/java/org/opensearch/ml/common/transport/tools/MLToolsListRequestTests.java b/common/src/test/java/org/opensearch/ml/common/transport/tools/MLToolsListRequestTests.java new file mode 100644 index 0000000000..8aedf99970 --- /dev/null +++ b/common/src/test/java/org/opensearch/ml/common/transport/tools/MLToolsListRequestTests.java @@ -0,0 +1,111 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.common.transport.tools; + +import org.junit.Before; +import org.junit.Test; +import org.opensearch.action.ActionRequest; +import org.opensearch.action.ActionRequestValidationException; +import org.opensearch.common.io.stream.BytesStreamOutput; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.ml.common.ToolMetadata; +import java.io.IOException; +import java.io.UncheckedIOException; +import java.util.ArrayList; +import java.util.List; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotSame; +import static org.junit.Assert.assertNull; + +public class MLToolsListRequestTests { + private List toolMetadataList; + + @Before + public void setUp() { + toolMetadataList = new ArrayList<>(); + ToolMetadata wikipediaTool = ToolMetadata.builder() + .name("WikipediaTool") + .description("Use this tool to search general knowledge on wikipedia.") + .type("WikipediaTool") + .version(null) + .build(); + toolMetadataList.add(wikipediaTool); + } + @Test + public void writeTo_success() throws IOException { + + MLToolsListRequest mlToolsListRequest = MLToolsListRequest.builder() + .toolMetadataList(toolMetadataList) + .build(); + BytesStreamOutput bytesStreamOutput = new BytesStreamOutput(); + mlToolsListRequest.writeTo(bytesStreamOutput); + MLToolsListRequest parsedToolMetadata = new MLToolsListRequest(bytesStreamOutput.bytes().streamInput()); + assertEquals(parsedToolMetadata.getToolMetadataList().get(0).getName(), toolMetadataList.get(0).getName()); + assertEquals(parsedToolMetadata.getToolMetadataList().get(0).getDescription(), toolMetadataList.get(0).getDescription()); + } + + @Test + public void fromActionRequest_success() { + MLToolsListRequest mlToolsListRequest = MLToolsListRequest.builder().toolMetadataList(toolMetadataList).build(); + ActionRequest actionRequest = new ActionRequest() { + @Override + public ActionRequestValidationException validate() { + return null; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + mlToolsListRequest.writeTo(out); + } + }; + MLToolsListRequest result = MLToolsListRequest.fromActionRequest(actionRequest); + assertNotSame(result, mlToolsListRequest); + assertEquals(result.getToolMetadataList().get(0).getName(), mlToolsListRequest.getToolMetadataList().get(0).getName()); + } + + @Test(expected = UncheckedIOException.class) + public void fromActionRequest_IOException() { + ActionRequest actionRequest = new ActionRequest() { + @Override + public ActionRequestValidationException validate() { + return null; + } + @Override + public void writeTo(StreamOutput out) throws IOException { + throw new IOException("test"); + } + }; + MLToolsListRequest.fromActionRequest(actionRequest); + } + + @Test + public void fromActionRequest_Success() { + MLToolsListRequest mlToolsListRequest = MLToolsListRequest.builder() + .toolMetadataList(toolMetadataList).build(); + + ActionRequest actionRequest = new ActionRequest() { + @Override + public ActionRequestValidationException validate() { + return null; + } + + @Override + public void writeTo(StreamOutput output) throws IOException { + mlToolsListRequest.writeTo(output); + } + }; + + MLToolsListRequest result = MLToolsListRequest.fromActionRequest(actionRequest); + assertNotSame(result, mlToolsListRequest); + assertEquals(result.getToolMetadataList().get(0).getName(), mlToolsListRequest.getToolMetadataList().get(0).getName()); + } + + @Test + public void testValidate() { + MLToolsListRequest request = MLToolsListRequest.builder().build(); + assertNull(request.validate()); + } +} \ No newline at end of file diff --git a/common/src/test/java/org/opensearch/ml/common/transport/tools/MLToolsListResponseTests.java b/common/src/test/java/org/opensearch/ml/common/transport/tools/MLToolsListResponseTests.java new file mode 100644 index 0000000000..aa41c53ffb --- /dev/null +++ b/common/src/test/java/org/opensearch/ml/common/transport/tools/MLToolsListResponseTests.java @@ -0,0 +1,102 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.common.transport.tools; + +import org.junit.Before; +import org.junit.Test; +// import org.opensearch.common.Strings; +import org.opensearch.common.io.stream.BytesStreamOutput; +import org.opensearch.common.xcontent.XContentFactory; +import org.opensearch.common.xcontent.XContentType; +import org.opensearch.core.action.ActionResponse; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.core.xcontent.MediaTypeRegistry; +import org.opensearch.core.xcontent.ToXContent; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.ml.common.ToolMetadata; +import java.io.IOException; +import java.io.UncheckedIOException; +import java.util.ArrayList; +import java.util.List; + +import static org.junit.Assert.*; + +public class MLToolsListResponseTests { + List toolMetadataList; + + MLToolsListResponse mlToolsListResponse; + + @Before + public void setUp() { + toolMetadataList = new ArrayList<>(); + ToolMetadata searchWikipediaTool = ToolMetadata.builder() + .name("SearchWikipediaTool") + .description("Useful when you need to use this tool to search general knowledge on wikipedia.") + .type("SearchWikipediaTool") + .version(null) + .build(); + ToolMetadata toolMetadata = ToolMetadata.builder() + .name("MathTool") + .description("Use this tool to calculate any math problem.") + .type("MathTool") + .version("test") + .build(); + + toolMetadataList.add(searchWikipediaTool); + toolMetadataList.add(toolMetadata); + mlToolsListResponse = MLToolsListResponse.builder().toolMetadata(toolMetadataList).build(); + } + + @Test + public void writeTo_success() throws IOException { + BytesStreamOutput bytesStreamOutput = new BytesStreamOutput(); + MLToolsListResponse response = MLToolsListResponse.builder().toolMetadata(toolMetadataList).build(); + response.writeTo(bytesStreamOutput); + MLToolsListResponse parsedResponse = new MLToolsListResponse(bytesStreamOutput.bytes().streamInput()); + assertNotEquals(response.toolMetadataList, parsedResponse.toolMetadataList); + assertEquals(response.toolMetadataList.get(0).getName(), parsedResponse.toolMetadataList.get(0).getName()); + assertEquals(response.toolMetadataList.get(0).getDescription(), parsedResponse.toolMetadataList.get(0).getDescription()); + } + + @Test + public void toXContentTest() throws IOException { + XContentBuilder builder = MediaTypeRegistry.contentBuilder(XContentType.JSON); + mlToolsListResponse.toXContent(builder, ToXContent.EMPTY_PARAMS); + assertNotNull(builder); + String jsonStr = builder.toString(); + assertEquals("{\"name\":\"SearchWikipediaTool\",\"description\":\"Useful when you need to use this tool to search general knowledge on wikipedia.\",\"type\":\"SearchWikipediaTool\",\"version\":\"undefined\"} {\"name\":\"MathTool\",\"description\":\"Use this tool to calculate any math problem.\",\"type\":\"MathTool\",\"version\":\"test\"}", jsonStr); + } + + @Test + public void fromActionResponseWithMLToolsListResponse_Success() { + MLToolsListResponse mlToolsListResponseFromActionResponse = MLToolsListResponse.fromActionResponse(mlToolsListResponse); + assertSame(mlToolsListResponse, mlToolsListResponseFromActionResponse); + assertEquals(mlToolsListResponse.getToolMetadataList().get(0).getName(), mlToolsListResponseFromActionResponse.getToolMetadataList().get(0).getName()); + } + + @Test + public void fromActionResponse_Success() { + ActionResponse actionResponse = new ActionResponse() { + @Override + public void writeTo(StreamOutput out) throws IOException { + mlToolsListResponse.writeTo(out); + } + }; + MLToolsListResponse mlToolsListResponseFromActionResponse = MLToolsListResponse.fromActionResponse(actionResponse); + assertEquals(mlToolsListResponse.getToolMetadataList().get(0).getName(), mlToolsListResponseFromActionResponse.getToolMetadataList().get(0).getName()); + } + + @Test(expected = UncheckedIOException.class) + public void fromActionResponse_IOException() { + ActionResponse actionResponse = new ActionResponse() { + @Override + public void writeTo(StreamOutput out) throws IOException { + throw new IOException(); + } + }; + MLToolsListResponse.fromActionResponse(actionResponse); + } +} \ No newline at end of file diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/AgentTool.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/AgentTool.java index a4a3982505..f048c62dc8 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/AgentTool.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/AgentTool.java @@ -124,5 +124,15 @@ public AgentTool create(Map map) { public String getDefaultDescription() { return DEFAULT_DESCRIPTION; } + + @Override + public String getDefaultType() { + return TYPE; + } + + @Override + public String getDefaultVersion() { + return null; + } } } diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/CatIndexTool.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/CatIndexTool.java index 16cec3870d..78ebd2f3bd 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/CatIndexTool.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/CatIndexTool.java @@ -347,6 +347,16 @@ public CatIndexTool create(Map map) { public String getDefaultDescription() { return DEFAULT_DESCRIPTION; } + + @Override + public String getDefaultType() { + return TYPE; + } + + @Override + public String getDefaultVersion() { + return null; + } } private Table getTableWithHeader() { diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/MLModelTool.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/MLModelTool.java index 4b941e6333..cba0d2ee6a 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/MLModelTool.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/MLModelTool.java @@ -139,5 +139,15 @@ public MLModelTool create(Map map) { public String getDefaultDescription() { return DEFAULT_DESCRIPTION; } + + @Override + public String getDefaultType() { + return TYPE; + } + + @Override + public String getDefaultVersion() { + return null; + } } } diff --git a/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java b/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java index fb469b3354..78309ce0a1 100644 --- a/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java +++ b/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java @@ -140,6 +140,8 @@ import org.opensearch.ml.common.transport.task.MLTaskDeleteAction; import org.opensearch.ml.common.transport.task.MLTaskGetAction; import org.opensearch.ml.common.transport.task.MLTaskSearchAction; +import org.opensearch.ml.common.transport.tools.MLGetToolAction; +import org.opensearch.ml.common.transport.tools.MLListToolsAction; import org.opensearch.ml.common.transport.training.MLTrainingTaskAction; import org.opensearch.ml.common.transport.trainpredict.MLTrainAndPredictionTaskAction; import org.opensearch.ml.common.transport.undeploy.MLUndeployModelAction; @@ -211,6 +213,8 @@ import org.opensearch.ml.rest.RestMLGetModelControllerAction; import org.opensearch.ml.rest.RestMLGetModelGroupAction; import org.opensearch.ml.rest.RestMLGetTaskAction; +import org.opensearch.ml.rest.RestMLGetToolAction; +import org.opensearch.ml.rest.RestMLListToolsAction; import org.opensearch.ml.rest.RestMLPredictionAction; import org.opensearch.ml.rest.RestMLProfileAction; import org.opensearch.ml.rest.RestMLRegisterAgentAction; @@ -257,6 +261,8 @@ import org.opensearch.ml.task.MLTaskManager; import org.opensearch.ml.task.MLTrainAndPredictTaskRunner; import org.opensearch.ml.task.MLTrainingTaskRunner; +import org.opensearch.ml.tools.GetToolTransportAction; +import org.opensearch.ml.tools.ListToolsTransportAction; import org.opensearch.ml.utils.IndexUtils; import org.opensearch.monitor.jvm.JvmService; import org.opensearch.monitor.os.OsService; @@ -394,7 +400,9 @@ public class MachineLearningPlugin extends Plugin implements ActionPlugin, Searc new ActionHandler<>(MLAgentDeleteAction.INSTANCE, DeleteAgentTransportAction.class), new ActionHandler<>(UpdateConversationAction.INSTANCE, UpdateConversationTransportAction.class), new ActionHandler<>(UpdateInteractionAction.INSTANCE, UpdateInteractionTransportAction.class), - new ActionHandler<>(GetTracesAction.INSTANCE, GetTracesTransportAction.class) + new ActionHandler<>(GetTracesAction.INSTANCE, GetTracesTransportAction.class), + new ActionHandler<>(MLListToolsAction.INSTANCE, ListToolsTransportAction.class), + new ActionHandler<>(MLGetToolAction.INSTANCE, GetToolTransportAction.class) ); } @@ -698,6 +706,8 @@ public List getRestHandlers( RestMemoryUpdateInteractionAction restMemoryUpdateInteractionAction = new RestMemoryUpdateInteractionAction(); RestMemoryGetTracesAction restMemoryGetTracesAction = new RestMemoryGetTracesAction(); RestMLSearchAgentAction restMLSearchAgentAction = new RestMLSearchAgentAction(); + RestMLListToolsAction restMLListToolsAction = new RestMLListToolsAction(toolFactories); + RestMLGetToolAction restMLGetToolAction = new RestMLGetToolAction(toolFactories); return ImmutableList .of( restMLStatsAction, @@ -747,7 +757,9 @@ public List getRestHandlers( restMemoryUpdateConversationAction, restMemoryUpdateInteractionAction, restMemoryGetTracesAction, - restMLSearchAgentAction + restMLSearchAgentAction, + restMLListToolsAction, + restMLGetToolAction ); } diff --git a/plugin/src/main/java/org/opensearch/ml/rest/RestMLGetToolAction.java b/plugin/src/main/java/org/opensearch/ml/rest/RestMLGetToolAction.java new file mode 100644 index 0000000000..6fe5a20328 --- /dev/null +++ b/plugin/src/main/java/org/opensearch/ml/rest/RestMLGetToolAction.java @@ -0,0 +1,88 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ +package org.opensearch.ml.rest; + +import static org.opensearch.ml.plugin.MachineLearningPlugin.ML_BASE_URI; +import static org.opensearch.ml.utils.RestActionUtils.*; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.List; +import java.util.Locale; +import java.util.Map; + +import org.opensearch.client.node.NodeClient; +import org.opensearch.ml.common.ToolMetadata; +import org.opensearch.ml.common.spi.tools.Tool; +import org.opensearch.ml.common.transport.tools.MLGetToolAction; +import org.opensearch.ml.common.transport.tools.MLToolGetRequest; +import org.opensearch.rest.BaseRestHandler; +import org.opensearch.rest.RestRequest; +import org.opensearch.rest.action.RestToXContentListener; + +import com.google.common.annotations.VisibleForTesting; +import com.google.common.collect.ImmutableList; + +public class RestMLGetToolAction extends BaseRestHandler { + + private static final String ML_GET_TOOL_ACTION = "ml_get_tool_action"; + + private Map toolFactories; + + public RestMLGetToolAction(Map toolFactories) { + this.toolFactories = toolFactories; + } + + @Override + public String getName() { + return ML_GET_TOOL_ACTION; + } + + @Override + public List routes() { + return ImmutableList + .of(new Route(RestRequest.Method.GET, String.format(Locale.ROOT, "%s/tools/{%s}", ML_BASE_URI, PARAMETER_TOOL_NAME))); + } + + /** + * Prepare the request for execution. Implementations should consume all request params before + * returning the runnable for actual execution. Unconsumed params will immediately terminate + * execution of the request. However, some params are only used in processing the response; + * implementations can override {@link BaseRestHandler#responseParams()} to indicate such + * params. + * + * @param request the request to execute + * @param client client for executing actions on the local node + * @return the action to execute + * @throws IOException if an I/O exception occurred parsing the request and preparing for + * execution + */ + @Override + protected RestChannelConsumer prepareRequest(RestRequest request, NodeClient client) throws IOException { + MLToolGetRequest mlToolGetRequest = getRequest(request); + return channel -> client.execute(MLGetToolAction.INSTANCE, mlToolGetRequest, new RestToXContentListener<>(channel)); + } + + @VisibleForTesting + MLToolGetRequest getRequest(RestRequest request) throws IOException { + List toolList = new ArrayList<>(); + toolFactories + .forEach( + (key, value) -> toolList + .add( + ToolMetadata + .builder() + .name(key) + .description(value.getDefaultDescription()) + .type(value.getDefaultType()) + .version(value.getDefaultVersion()) + .build() + ) + ); + String toolName = getParameterId(request, PARAMETER_TOOL_NAME); + MLToolGetRequest mlToolGetRequest = MLToolGetRequest.builder().toolName(toolName).toolMetadataList(toolList).build(); + return mlToolGetRequest; + } +} diff --git a/plugin/src/main/java/org/opensearch/ml/rest/RestMLListToolsAction.java b/plugin/src/main/java/org/opensearch/ml/rest/RestMLListToolsAction.java new file mode 100644 index 0000000000..f79da82b45 --- /dev/null +++ b/plugin/src/main/java/org/opensearch/ml/rest/RestMLListToolsAction.java @@ -0,0 +1,85 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.rest; + +import static org.opensearch.ml.plugin.MachineLearningPlugin.ML_BASE_URI; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.List; +import java.util.Locale; +import java.util.Map; + +import org.opensearch.client.node.NodeClient; +import org.opensearch.ml.common.ToolMetadata; +import org.opensearch.ml.common.spi.tools.Tool; +import org.opensearch.ml.common.transport.tools.MLListToolsAction; +import org.opensearch.ml.common.transport.tools.MLToolsListRequest; +import org.opensearch.rest.BaseRestHandler; +import org.opensearch.rest.RestRequest; +import org.opensearch.rest.action.RestToXContentListener; + +import com.google.common.annotations.VisibleForTesting; +import com.google.common.collect.ImmutableList; + +public class RestMLListToolsAction extends BaseRestHandler { + private static final String ML_LIST_TOOLS_ACTION = "ml_list_tools_action"; + + private Map toolFactories; + + public RestMLListToolsAction(Map toolFactories) { + this.toolFactories = toolFactories; + } + + @Override + public String getName() { + return ML_LIST_TOOLS_ACTION; + } + + @Override + public List routes() { + return ImmutableList.of(new Route(RestRequest.Method.GET, String.format(Locale.ROOT, "%s/tools", ML_BASE_URI))); + } + + /** + * Prepare the request for execution. Implementations should consume all request params before + * returning the runnable for actual execution. Unconsumed params will immediately terminate + * execution of the request. However, some params are only used in processing the response; + * implementations can override {@link BaseRestHandler#responseParams()} to indicate such + * params. + * + * @param request the request to execute + * @param client client for executing actions on the local node + * @return the action to execute + * @throws IOException if an I/O exception occurred parsing the request and preparing for + * execution + */ + @Override + protected RestChannelConsumer prepareRequest(RestRequest request, NodeClient client) throws IOException { + MLToolsListRequest mlToolsListRequest = getRequest(request); + return channel -> client.execute(MLListToolsAction.INSTANCE, mlToolsListRequest, new RestToXContentListener<>(channel)); + } + + @VisibleForTesting + MLToolsListRequest getRequest(RestRequest request) throws IOException { + List toolList = new ArrayList<>(); + toolFactories + .forEach( + (key, value) -> toolList + .add( + ToolMetadata + .builder() + .name(key) + .description(value.getDefaultDescription()) + .type(value.getDefaultType()) + .version(value.getDefaultVersion()) + .build() + ) + ); + MLToolsListRequest mlToolsGetRequest = MLToolsListRequest.builder().toolMetadataList(toolList).build(); + return mlToolsGetRequest; + } +} diff --git a/plugin/src/main/java/org/opensearch/ml/tools/GetToolTransportAction.java b/plugin/src/main/java/org/opensearch/ml/tools/GetToolTransportAction.java new file mode 100644 index 0000000000..39a400007d --- /dev/null +++ b/plugin/src/main/java/org/opensearch/ml/tools/GetToolTransportAction.java @@ -0,0 +1,56 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.tools; + +import java.util.List; +import java.util.NoSuchElementException; + +import org.opensearch.action.ActionRequest; +import org.opensearch.action.support.ActionFilters; +import org.opensearch.action.support.HandledTransportAction; +import org.opensearch.common.inject.Inject; +import org.opensearch.core.action.ActionListener; +import org.opensearch.ml.common.ToolMetadata; +import org.opensearch.ml.common.transport.tools.*; +import org.opensearch.tasks.Task; +import org.opensearch.transport.TransportService; + +import lombok.AccessLevel; +import lombok.experimental.FieldDefaults; +import lombok.extern.log4j.Log4j2; + +@Log4j2 +@FieldDefaults(makeFinal = true, level = AccessLevel.PRIVATE) +public class GetToolTransportAction extends HandledTransportAction { + @Inject + public GetToolTransportAction(TransportService transportService, ActionFilters actionFilters) { + super(MLGetToolAction.NAME, transportService, actionFilters, MLToolGetRequest::new); + } + + /** + * @param task the Task + * @param request the MLToolGetRequest request + * @param listener action listener + */ + @Override + protected void doExecute(Task task, ActionRequest request, ActionListener listener) { + MLToolGetRequest mlToolGetRequest = MLToolGetRequest.fromActionRequest(request); + String toolName = mlToolGetRequest.getToolName(); + try { + List toolsList = mlToolGetRequest.getToolMetadataList(); + ToolMetadata theTool = toolsList + .stream() + .filter(tool -> tool.getName().equals(toolName)) + .findFirst() + .orElseThrow(NoSuchElementException::new); + listener.onResponse(MLToolGetResponse.builder().toolMetadata(theTool).build()); + } catch (Exception e) { + log.error("Failed to get tool", e); + listener.onFailure(e); + } + + } +} diff --git a/plugin/src/main/java/org/opensearch/ml/tools/ListToolsTransportAction.java b/plugin/src/main/java/org/opensearch/ml/tools/ListToolsTransportAction.java new file mode 100644 index 0000000000..9ff6de0978 --- /dev/null +++ b/plugin/src/main/java/org/opensearch/ml/tools/ListToolsTransportAction.java @@ -0,0 +1,47 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.tools; + +import java.util.List; + +import org.opensearch.action.ActionRequest; +import org.opensearch.action.support.ActionFilters; +import org.opensearch.action.support.HandledTransportAction; +import org.opensearch.common.inject.Inject; +import org.opensearch.core.action.ActionListener; +import org.opensearch.ml.common.ToolMetadata; +import org.opensearch.ml.common.transport.tools.MLListToolsAction; +import org.opensearch.ml.common.transport.tools.MLToolsListRequest; +import org.opensearch.ml.common.transport.tools.MLToolsListResponse; +import org.opensearch.tasks.Task; +import org.opensearch.transport.TransportService; + +import lombok.AccessLevel; +import lombok.experimental.FieldDefaults; +import lombok.extern.log4j.Log4j2; + +@Log4j2 +@FieldDefaults(makeFinal = true, level = AccessLevel.PRIVATE) +public class ListToolsTransportAction extends HandledTransportAction { + @Inject + public ListToolsTransportAction(TransportService transportService, ActionFilters actionFilters) { + super(MLListToolsAction.NAME, transportService, actionFilters, MLToolsListRequest::new); + } + + @Override + protected void doExecute(Task task, ActionRequest request, ActionListener listener) { + MLToolsListRequest mlToolsGetRequest = MLToolsListRequest.fromActionRequest(request); + + List toolsList = mlToolsGetRequest.getToolMetadataList(); + + try { + listener.onResponse(MLToolsListResponse.builder().toolMetadata(toolsList).build()); + } catch (Exception e) { + log.error("Failed to get tools list", e); + listener.onFailure(e); + } + } +} diff --git a/plugin/src/test/java/org/opensearch/ml/plugin/DummyWrongTool.java b/plugin/src/test/java/org/opensearch/ml/plugin/DummyWrongTool.java index 09efe7537d..2f21c532a7 100644 --- a/plugin/src/test/java/org/opensearch/ml/plugin/DummyWrongTool.java +++ b/plugin/src/test/java/org/opensearch/ml/plugin/DummyWrongTool.java @@ -94,5 +94,15 @@ public DummyWrongTool create(Map map) { public String getDefaultDescription() { return DEFAULT_DESCRIPTION; } + + @Override + public String getDefaultType() { + return TYPE; + } + + @Override + public String getDefaultVersion() { + return null; + } } } diff --git a/plugin/src/test/java/org/opensearch/ml/rest/RestMLGetToolActionTests.java b/plugin/src/test/java/org/opensearch/ml/rest/RestMLGetToolActionTests.java new file mode 100644 index 0000000000..137df48c7a --- /dev/null +++ b/plugin/src/test/java/org/opensearch/ml/rest/RestMLGetToolActionTests.java @@ -0,0 +1,118 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ +package org.opensearch.ml.rest; + +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.spy; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.opensearch.ml.utils.RestActionUtils.PARAMETER_TOOL_NAME; + +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import org.junit.Before; +import org.junit.Rule; +import org.junit.rules.ExpectedException; +import org.mockito.ArgumentCaptor; +import org.mockito.Mock; +import org.mockito.Mockito; +import org.opensearch.client.node.NodeClient; +import org.opensearch.common.settings.Settings; +import org.opensearch.core.action.ActionListener; +import org.opensearch.core.common.Strings; +import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.ml.common.spi.tools.Tool; +import org.opensearch.ml.common.transport.tools.MLGetToolAction; +import org.opensearch.ml.common.transport.tools.MLToolGetRequest; +import org.opensearch.ml.common.transport.tools.MLToolGetResponse; +import org.opensearch.ml.engine.tools.CatIndexTool; +import org.opensearch.rest.RestChannel; +import org.opensearch.rest.RestHandler; +import org.opensearch.rest.RestRequest; +import org.opensearch.test.OpenSearchTestCase; +import org.opensearch.test.rest.FakeRestRequest; +import org.opensearch.threadpool.TestThreadPool; +import org.opensearch.threadpool.ThreadPool; + +public class RestMLGetToolActionTests extends OpenSearchTestCase { + + @Rule + public ExpectedException thrown = ExpectedException.none(); + + @Mock + private RestChannel channel; + + private RestMLGetToolAction restMLGetToolAction; + private NodeClient nodeClient; + private ThreadPool threadPool; + private Map toolFactories = new HashMap<>(); + private Tool.Factory mockFactory = Mockito.mock(Tool.Factory.class); + + @Before + public void setup() { + Mockito.when(mockFactory.getDefaultDescription()).thenReturn("Mocked Description"); + Mockito.when(mockFactory.getDefaultType()).thenReturn("Mocked type"); + Mockito.when(mockFactory.getDefaultVersion()).thenReturn("Mocked version"); + + Tool tool = CatIndexTool.Factory.getInstance().create(Collections.emptyMap()); + Mockito.when(mockFactory.create(Mockito.any())).thenReturn(tool); + toolFactories.put("mockTool", mockFactory); + + restMLGetToolAction = new RestMLGetToolAction(toolFactories); + + threadPool = new TestThreadPool(this.getClass().getSimpleName() + "ThreadPool"); + nodeClient = spy(new NodeClient(Settings.EMPTY, threadPool)); + + doAnswer(invocation -> { + ActionListener actionListener = invocation.getArgument(2); + return null; + }).when(nodeClient).execute(eq(MLGetToolAction.INSTANCE), any(), any()); + } + + @Override + public void tearDown() throws Exception { + super.tearDown(); + threadPool.shutdown(); + nodeClient.close(); + } + + public void testGetName() { + String actionName = restMLGetToolAction.getName(); + assertFalse(Strings.isNullOrEmpty(actionName)); + assertEquals("ml_get_tool_action", actionName); + } + + public void testRoutes() { + List routes = restMLGetToolAction.routes(); + assertNotNull(routes); + assertFalse(routes.isEmpty()); + RestHandler.Route route = routes.get(0); + assertEquals(RestRequest.Method.GET, route.getMethod()); + assertEquals("/_plugins/_ml/tools/{tool_name}", route.getPath()); + } + + public void test_PrepareRequest() throws Exception { + RestRequest request = getRestRequest(); + restMLGetToolAction.handleRequest(request, channel, nodeClient); + + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(MLToolGetRequest.class); + verify(nodeClient, times(1)).execute(eq(MLGetToolAction.INSTANCE), argumentCaptor.capture(), any()); + String name = argumentCaptor.getValue().getToolName(); + assertEquals(name, "name"); + } + + private RestRequest getRestRequest() { + Map params = new HashMap<>(); + params.put(PARAMETER_TOOL_NAME, "name"); + RestRequest request = new FakeRestRequest.Builder(NamedXContentRegistry.EMPTY).withParams(params).build(); + return request; + } + +} diff --git a/plugin/src/test/java/org/opensearch/ml/rest/RestMLListToolsActionTests.java b/plugin/src/test/java/org/opensearch/ml/rest/RestMLListToolsActionTests.java new file mode 100644 index 0000000000..7ce9f8dedf --- /dev/null +++ b/plugin/src/test/java/org/opensearch/ml/rest/RestMLListToolsActionTests.java @@ -0,0 +1,113 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ +package org.opensearch.ml.rest; + +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.spy; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; + +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import org.junit.Before; +import org.junit.Rule; +import org.junit.rules.ExpectedException; +import org.mockito.ArgumentCaptor; +import org.mockito.Mock; +import org.mockito.Mockito; +import org.opensearch.client.node.NodeClient; +import org.opensearch.common.settings.Settings; +import org.opensearch.core.action.ActionListener; +import org.opensearch.core.common.Strings; +import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.ml.common.spi.tools.Tool; +import org.opensearch.ml.common.transport.tools.MLListToolsAction; +import org.opensearch.ml.common.transport.tools.MLToolsListRequest; +import org.opensearch.ml.common.transport.tools.MLToolsListResponse; +import org.opensearch.ml.engine.tools.CatIndexTool; +import org.opensearch.rest.RestChannel; +import org.opensearch.rest.RestHandler; +import org.opensearch.rest.RestRequest; +import org.opensearch.test.OpenSearchTestCase; +import org.opensearch.test.rest.FakeRestRequest; +import org.opensearch.threadpool.TestThreadPool; +import org.opensearch.threadpool.ThreadPool; + +public class RestMLListToolsActionTests extends OpenSearchTestCase { + @Rule + public ExpectedException thrown = ExpectedException.none(); + + @Mock + private RestChannel channel; + + private RestMLListToolsAction restMLListToolsAction; + private NodeClient nodeClient; + private ThreadPool threadPool; + private Map toolFactories = new HashMap<>(); + private Tool.Factory mockFactory = Mockito.mock(Tool.Factory.class); + + @Before + public void setup() { + Mockito.when(mockFactory.getDefaultDescription()).thenReturn("Mocked Description"); + Mockito.when(mockFactory.getDefaultType()).thenReturn("Mocked type"); + Mockito.when(mockFactory.getDefaultVersion()).thenReturn("Mocked version"); + + Tool tool = CatIndexTool.Factory.getInstance().create(Collections.emptyMap()); + Mockito.when(mockFactory.create(Mockito.any())).thenReturn(tool); + toolFactories.put("mockTool", mockFactory); + restMLListToolsAction = new RestMLListToolsAction(toolFactories); + + threadPool = new TestThreadPool(this.getClass().getSimpleName() + "ThreadPool"); + nodeClient = spy(new NodeClient(Settings.EMPTY, threadPool)); + + doAnswer(invocation -> { + ActionListener actionListener = invocation.getArgument(2); + return null; + }).when(nodeClient).execute(eq(MLListToolsAction.INSTANCE), any(), any()); + } + + @Override + public void tearDown() throws Exception { + super.tearDown(); + threadPool.shutdown(); + nodeClient.close(); + } + + public void testGetName() { + String actionName = restMLListToolsAction.getName(); + assertFalse(Strings.isNullOrEmpty(actionName)); + assertEquals("ml_list_tools_action", actionName); + } + + public void testRoutes() { + List routes = restMLListToolsAction.routes(); + assertNotNull(routes); + assertFalse(routes.isEmpty()); + RestHandler.Route route = routes.get(0); + assertEquals(RestRequest.Method.GET, route.getMethod()); + assertEquals("/_plugins/_ml/tools", route.getPath()); + } + + public void test_PrepareRequest() throws Exception { + RestRequest request = getRestRequest(); + restMLListToolsAction.handleRequest(request, channel, nodeClient); + + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(MLToolsListRequest.class); + verify(nodeClient, times(1)).execute(eq(MLListToolsAction.INSTANCE), argumentCaptor.capture(), any()); + String name = argumentCaptor.getValue().getToolMetadataList().get(0).getName(); + assertEquals(name, "mockTool"); + } + + private RestRequest getRestRequest() { + Map params = new HashMap<>(); + RestRequest request = new FakeRestRequest.Builder(NamedXContentRegistry.EMPTY).build(); + return request; + } +} diff --git a/plugin/src/test/java/org/opensearch/ml/tools/GetToolTransportActionTests.java b/plugin/src/test/java/org/opensearch/ml/tools/GetToolTransportActionTests.java new file mode 100644 index 0000000000..b80727f01b --- /dev/null +++ b/plugin/src/test/java/org/opensearch/ml/tools/GetToolTransportActionTests.java @@ -0,0 +1,73 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.tools; + +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.doThrow; +import static org.mockito.Mockito.spy; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.List; + +import org.junit.Before; +import org.mockito.Mock; +import org.mockito.MockitoAnnotations; +import org.opensearch.action.support.ActionFilters; +import org.opensearch.core.action.ActionListener; +import org.opensearch.ml.common.ToolMetadata; +import org.opensearch.ml.common.transport.tools.MLToolGetRequest; +import org.opensearch.ml.common.transport.tools.MLToolGetResponse; +import org.opensearch.test.OpenSearchTestCase; +import org.opensearch.transport.TransportService; + +public class GetToolTransportActionTests extends OpenSearchTestCase { + @Mock + TransportService transportService; + + @Mock + ActionFilters actionFilters; + + @Mock + ActionListener actionListener; + + GetToolTransportAction getToolTransportAction; + MLToolGetRequest mlToolGetRequest; + private List toolMetadataList; + private RuntimeException exceptionToThrow; + + @Before + public void setup() throws IOException { + MockitoAnnotations.openMocks(this); + toolMetadataList = new ArrayList<>(); + ToolMetadata wikipediaTool = ToolMetadata + .builder() + .name("WikipediaTool") + .description("Use this tool to search general knowledge on wikipedia.") + .build(); + toolMetadataList.add(wikipediaTool); + mlToolGetRequest = MLToolGetRequest.builder().toolMetadataList(toolMetadataList).toolName("WikipediaTool").build(); + exceptionToThrow = new RuntimeException("Failed to get tool"); + + getToolTransportAction = spy(new GetToolTransportAction(transportService, actionFilters)); + } + + public void testGetTool_Success() { + getToolTransportAction.doExecute(null, mlToolGetRequest, actionListener); + verify(actionListener, times(1)).onResponse(any()); + } + + public void testListTools_Failure() { + doThrow(exceptionToThrow).when(actionListener).onResponse(any(MLToolGetResponse.class)); + + getToolTransportAction.doExecute(null, mlToolGetRequest, actionListener); + + verify(actionListener, times(1)).onFailure(exceptionToThrow); + } + +} diff --git a/plugin/src/test/java/org/opensearch/ml/tools/ListToolsTransportActionTests.java b/plugin/src/test/java/org/opensearch/ml/tools/ListToolsTransportActionTests.java new file mode 100644 index 0000000000..c25ee54adc --- /dev/null +++ b/plugin/src/test/java/org/opensearch/ml/tools/ListToolsTransportActionTests.java @@ -0,0 +1,77 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.tools; + +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.doThrow; +import static org.mockito.Mockito.spy; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.List; + +import org.junit.Before; +import org.mockito.Mock; +import org.mockito.MockitoAnnotations; +import org.opensearch.action.support.ActionFilters; +import org.opensearch.core.action.ActionListener; +import org.opensearch.ml.common.ToolMetadata; +import org.opensearch.ml.common.transport.tools.MLToolsListRequest; +import org.opensearch.ml.common.transport.tools.MLToolsListResponse; +import org.opensearch.test.OpenSearchTestCase; +import org.opensearch.transport.TransportService; + +public class ListToolsTransportActionTests extends OpenSearchTestCase { + @Mock + TransportService transportService; + + @Mock + ActionFilters actionFilters; + + @Mock + ActionListener actionListener; + + ListToolsTransportAction listToolsTransportAction; + MLToolsListRequest mlToolsListRequest; + private List toolMetadataList; + + private RuntimeException exceptionToThrow; + + @Before + public void setup() throws IOException { + MockitoAnnotations.openMocks(this); + toolMetadataList = new ArrayList<>(); + ToolMetadata wikipediaTool = ToolMetadata + .builder() + .name("WikipediaTool") + .description("Use this tool to search general knowledge on wikipedia.") + .type("forTestingPurpose") + .version("test") + .build(); + toolMetadataList.add(wikipediaTool); + mlToolsListRequest = MLToolsListRequest.builder().toolMetadataList(toolMetadataList).build(); + + exceptionToThrow = new RuntimeException("Failed to get tools list"); + + listToolsTransportAction = spy(new ListToolsTransportAction(transportService, actionFilters)); + } + + public void testListTools_Success() { + listToolsTransportAction.doExecute(null, mlToolsListRequest, actionListener); + verify(actionListener, times(1)).onResponse(any()); + } + + public void testListTools_Failure() { + doThrow(exceptionToThrow).when(actionListener).onResponse(any(MLToolsListResponse.class)); + + listToolsTransportAction.doExecute(null, mlToolsListRequest, actionListener); + + verify(actionListener, times(1)).onFailure(exceptionToThrow); + } + +} diff --git a/spi/src/main/java/org/opensearch/ml/common/spi/tools/Tool.java b/spi/src/main/java/org/opensearch/ml/common/spi/tools/Tool.java index c271ce2050..ce30d384aa 100644 --- a/spi/src/main/java/org/opensearch/ml/common/spi/tools/Tool.java +++ b/spi/src/main/java/org/opensearch/ml/common/spi/tools/Tool.java @@ -115,5 +115,17 @@ interface Factory { * @return the default description */ String getDefaultDescription(); + + /** + * Get the default type of this tool. + * @return the default tool type + */ + String getDefaultType(); + + /** + * Get the default version of this tool + * @return the default tool version + */ + String getDefaultVersion(); } }