From 753480ab03ba7d0d64ce68d1cde8534c01245a03 Mon Sep 17 00:00:00 2001 From: Dhrubo Saha Date: Fri, 15 Dec 2023 21:39:00 -0700 Subject: [PATCH] adding mlmodeltool and agent tool with tests (#1768) * adding mlmodeltool and agent tool with tests Signed-off-by: Dhrubo Saha * updating tests Signed-off-by: Dhrubo Saha * removed connector Signed-off-by: Dhrubo Saha --------- Signed-off-by: Dhrubo Saha --- .../opensearch/ml/common/FunctionName.java | 3 +- .../input/execute/agent/AgentMLInput.java | 75 +++++++++ .../prediction/MLPredictionTaskRequest.java | 4 + .../execute/agent/AgentMLInputTests.java | 112 ++++++++++++++ .../MLPredictionTaskRequestTest.java | 13 +- .../opensearch/ml/engine/tools/AgentTool.java | 128 ++++++++++++++++ .../ml/engine/tools/MLModelTool.java | 143 ++++++++++++++++++ .../ml/engine/tools/AgentToolTests.java | 127 ++++++++++++++++ .../ml/engine/tools/MLModelToolTests.java | 127 ++++++++++++++++ 9 files changed, 730 insertions(+), 2 deletions(-) create mode 100644 common/src/main/java/org/opensearch/ml/common/input/execute/agent/AgentMLInput.java create mode 100644 common/src/test/java/org/opensearch/ml/common/input/execute/agent/AgentMLInputTests.java create mode 100644 ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/AgentTool.java create mode 100644 ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/MLModelTool.java create mode 100644 ml-algorithms/src/test/java/org/opensearch/ml/engine/tools/AgentToolTests.java create mode 100644 ml-algorithms/src/test/java/org/opensearch/ml/engine/tools/MLModelToolTests.java diff --git a/common/src/main/java/org/opensearch/ml/common/FunctionName.java b/common/src/main/java/org/opensearch/ml/common/FunctionName.java index 72810459a4..6eff55156d 100644 --- a/common/src/main/java/org/opensearch/ml/common/FunctionName.java +++ b/common/src/main/java/org/opensearch/ml/common/FunctionName.java @@ -24,7 +24,8 @@ public enum FunctionName { SPARSE_ENCODING, SPARSE_TOKENIZE, METRICS_CORRELATION, - REMOTE; + REMOTE, + AGENT; public static FunctionName from(String value) { try { diff --git a/common/src/main/java/org/opensearch/ml/common/input/execute/agent/AgentMLInput.java b/common/src/main/java/org/opensearch/ml/common/input/execute/agent/AgentMLInput.java new file mode 100644 index 0000000000..3aa3ac382b --- /dev/null +++ b/common/src/main/java/org/opensearch/ml/common/input/execute/agent/AgentMLInput.java @@ -0,0 +1,75 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.common.input.execute.agent; + +import lombok.Builder; +import lombok.Getter; +import lombok.Setter; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.core.xcontent.XContentParser; +import org.opensearch.ml.common.FunctionName; +import org.opensearch.ml.common.dataset.MLInputDataset; +import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet; +import org.opensearch.ml.common.input.MLInput; +import org.opensearch.ml.common.utils.StringUtils; + +import java.io.IOException; +import java.util.Map; + +import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; + + +@org.opensearch.ml.common.annotation.MLInput(functionNames = {FunctionName.AGENT}) +public class AgentMLInput extends MLInput { + public static final String AGENT_ID_FIELD = "agent_id"; + public static final String PARAMETERS_FIELD = "parameters"; + + @Getter @Setter + private String agentId; + + @Builder(builderMethodName = "AgentMLInputBuilder") + public AgentMLInput(String agentId, FunctionName functionName, MLInputDataset inputDataset) { + this.agentId = agentId; + this.algorithm = functionName; + this.inputDataset = inputDataset; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + super.writeTo(out); + out.writeString(agentId); + } + + public AgentMLInput(StreamInput in) throws IOException { + super(in); + this.agentId = in.readString(); + } + + public AgentMLInput(XContentParser parser, FunctionName functionName) throws IOException { + super(); + this.algorithm = functionName; + ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser); + while (parser.nextToken() != XContentParser.Token.END_OBJECT) { + String fieldName = parser.currentName(); + parser.nextToken(); + + switch (fieldName) { + case AGENT_ID_FIELD: + agentId = parser.text(); + break; + case PARAMETERS_FIELD: + Map parameters = StringUtils.getParameterMap(parser.map()); + inputDataset = new RemoteInferenceInputDataSet(parameters); + break; + default: + parser.skipChildren(); + break; + } + } + } + +} diff --git a/common/src/main/java/org/opensearch/ml/common/transport/prediction/MLPredictionTaskRequest.java b/common/src/main/java/org/opensearch/ml/common/transport/prediction/MLPredictionTaskRequest.java index 963892215f..8060b1c6af 100644 --- a/common/src/main/java/org/opensearch/ml/common/transport/prediction/MLPredictionTaskRequest.java +++ b/common/src/main/java/org/opensearch/ml/common/transport/prediction/MLPredictionTaskRequest.java @@ -47,6 +47,10 @@ public MLPredictionTaskRequest(String modelId, MLInput mlInput, boolean dispatch this.user = user; } + public MLPredictionTaskRequest(String modelId, MLInput mlInput) { + this(modelId, mlInput, true, null); + } + public MLPredictionTaskRequest(String modelId, MLInput mlInput, User user) { this(modelId, mlInput, true, user); } diff --git a/common/src/test/java/org/opensearch/ml/common/input/execute/agent/AgentMLInputTests.java b/common/src/test/java/org/opensearch/ml/common/input/execute/agent/AgentMLInputTests.java new file mode 100644 index 0000000000..36235adffe --- /dev/null +++ b/common/src/test/java/org/opensearch/ml/common/input/execute/agent/AgentMLInputTests.java @@ -0,0 +1,112 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.common.input.execute.agent; + +import org.junit.Test; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.core.xcontent.XContentParser; +import org.opensearch.ml.common.FunctionName; +import org.opensearch.ml.common.dataset.MLInputDataset; +import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet; + +import java.io.IOException; +import java.util.HashMap; +import java.util.Map; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertTrue; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +public class AgentMLInputTests { + + @Test + public void testConstructorWithAgentIdFunctionNameAndDataset() { + // Arrange + String agentId = "testAgentId"; + FunctionName functionName = FunctionName.AGENT; // Assuming FunctionName is an enum or similar + MLInputDataset dataset = mock(MLInputDataset.class); // Mock the MLInputDataset + + // Act + AgentMLInput input = new AgentMLInput(agentId, functionName, dataset); + + // Assert + assertEquals(agentId, input.getAgentId()); + assertEquals(functionName, input.getAlgorithm()); + assertEquals(dataset, input.getInputDataset()); + } + + @Test + public void testWriteTo() throws IOException { + // Arrange + String agentId = "testAgentId"; + AgentMLInput input = new AgentMLInput(agentId, FunctionName.AGENT, null); + StreamOutput out = mock(StreamOutput.class); + + // Act + input.writeTo(out); + + // Assert + verify(out).writeString(agentId); + } + + @Test + public void testConstructorWithStreamInput() throws IOException { + // Arrange + String agentId = "testAgentId"; + StreamInput in = mock(StreamInput.class); + when(in.readString()).thenReturn(agentId); + + // Act + AgentMLInput input = new AgentMLInput(in); + + // Assert + assertEquals(agentId, input.getAgentId()); + } + + @Test + public void testConstructorWithXContentParser() throws IOException { + // Arrange + XContentParser parser = mock(XContentParser.class); + + // Simulate parser behavior for START_OBJECT token + when(parser.currentToken()).thenReturn(XContentParser.Token.START_OBJECT); + when(parser.nextToken()).thenReturn(XContentParser.Token.FIELD_NAME) + .thenReturn(XContentParser.Token.VALUE_STRING) + .thenReturn(XContentParser.Token.FIELD_NAME) // For PARAMETERS_FIELD + .thenReturn(XContentParser.Token.START_OBJECT) // Start of PARAMETERS_FIELD map + .thenReturn(XContentParser.Token.FIELD_NAME) // Key in PARAMETERS_FIELD map + .thenReturn(XContentParser.Token.VALUE_STRING) // Value in PARAMETERS_FIELD map + .thenReturn(XContentParser.Token.END_OBJECT) // End of PARAMETERS_FIELD map + .thenReturn(XContentParser.Token.END_OBJECT); // End of the main object + + // Simulate parser behavior for agent_id + when(parser.currentName()).thenReturn("agent_id") + .thenReturn("parameters") + .thenReturn("paramKey"); + when(parser.text()).thenReturn("testAgentId") + .thenReturn("paramValue"); + + // Simulate parser behavior for parameters + Map paramMap = new HashMap<>(); + paramMap.put("paramKey", "paramValue"); + when(parser.map()).thenReturn(paramMap); + + // Act + AgentMLInput input = new AgentMLInput(parser, FunctionName.AGENT); + + // Assert + assertEquals("testAgentId", input.getAgentId()); + assertNotNull(input.getInputDataset()); + assertTrue(input.getInputDataset() instanceof RemoteInferenceInputDataSet); + // Additional assertions for RemoteInferenceInputDataSet + RemoteInferenceInputDataSet dataset = (RemoteInferenceInputDataSet) input.getInputDataset(); + assertEquals("paramValue", dataset.getParameters().get("paramKey")); + } +} diff --git a/common/src/test/java/org/opensearch/ml/common/transport/prediction/MLPredictionTaskRequestTest.java b/common/src/test/java/org/opensearch/ml/common/transport/prediction/MLPredictionTaskRequestTest.java index ce96aa56c1..b9cbe7d700 100644 --- a/common/src/test/java/org/opensearch/ml/common/transport/prediction/MLPredictionTaskRequestTest.java +++ b/common/src/test/java/org/opensearch/ml/common/transport/prediction/MLPredictionTaskRequestTest.java @@ -16,6 +16,7 @@ import org.opensearch.action.ActionRequest; import org.opensearch.action.ActionRequestValidationException; import org.opensearch.common.io.stream.BytesStreamOutput; +import org.opensearch.commons.authuser.User; import org.opensearch.core.common.io.stream.StreamOutput; import org.opensearch.index.query.MatchAllQueryBuilder; import org.opensearch.ml.common.dataframe.ColumnType; @@ -53,9 +54,11 @@ public void setUp() { @Test public void writeTo_Success() throws IOException { + User user = User.parse("admin|role-1|all_access"); MLPredictionTaskRequest request = MLPredictionTaskRequest.builder() .mlInput(mlInput) + .user(user) .build(); BytesStreamOutput bytesStreamOutput = new BytesStreamOutput(); request.writeTo(bytesStreamOutput); @@ -73,13 +76,18 @@ public void writeTo_Success() throws IOException { assertEquals(1, dataFrame.getRow(0).size()); assertEquals(2.00, dataFrame.getRow(0).getValue(0).getValue()); + User userExpect = request.getUser(); + assertEquals(user.getName(), userExpect.getName()); + assertNull(request.getModelId()); } @Test public void validate_Success() { + User user = User.parse("admin|role-1|all_access"); MLPredictionTaskRequest request = MLPredictionTaskRequest.builder() .mlInput(mlInput) + .user(user) .build(); assertNull(request.validate()); @@ -133,8 +141,10 @@ public void fromActionRequest_Success_WithNonMLPredictionTaskRequest_SearchQuery } private void fromActionRequest_Success_WithNonMLPredictionTaskRequest(MLInput mlInput) { + User user = User.parse("admin|role-1|all_access"); MLPredictionTaskRequest request = MLPredictionTaskRequest.builder() .mlInput(mlInput) + .user(user) .build(); ActionRequest actionRequest = new ActionRequest() { @Override @@ -151,6 +161,7 @@ public void writeTo(StreamOutput out) throws IOException { assertNotSame(result, request); assertEquals(request.getMlInput().getAlgorithm(), result.getMlInput().getAlgorithm()); assertEquals(request.getMlInput().getInputDataset().getInputDataType(), result.getMlInput().getInputDataset().getInputDataType()); + assertEquals(request.getUser().getName(), request.getUser().getName()); } @Test(expected = UncheckedIOException.class) @@ -168,4 +179,4 @@ public void writeTo(StreamOutput out) throws IOException { }; MLPredictionTaskRequest.fromActionRequest(actionRequest); } -} \ No newline at end of file +} diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/AgentTool.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/AgentTool.java new file mode 100644 index 0000000000..a4a3982505 --- /dev/null +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/AgentTool.java @@ -0,0 +1,128 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.engine.tools; + +import java.util.Map; + +import org.opensearch.action.ActionRequest; +import org.opensearch.client.Client; +import org.opensearch.core.action.ActionListener; +import org.opensearch.ml.common.FunctionName; +import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet; +import org.opensearch.ml.common.input.execute.agent.AgentMLInput; +import org.opensearch.ml.common.output.model.ModelTensorOutput; +import org.opensearch.ml.common.spi.tools.Tool; +import org.opensearch.ml.common.spi.tools.ToolAnnotation; +import org.opensearch.ml.common.transport.execute.MLExecuteTaskAction; +import org.opensearch.ml.common.transport.execute.MLExecuteTaskRequest; +import org.opensearch.ml.repackage.com.google.common.annotations.VisibleForTesting; + +import lombok.Getter; +import lombok.Setter; +import lombok.extern.log4j.Log4j2; + +/** + * This tool supports running any Agent. + */ +@Log4j2 +@ToolAnnotation(AgentTool.TYPE) +public class AgentTool implements Tool { + public static final String TYPE = "AgentTool"; + private final Client client; + + private String agentId; + @Setter + @Getter + private String name = TYPE; + + @VisibleForTesting + static String DEFAULT_DESCRIPTION = "Use this tool to run any agent."; + @Getter + @Setter + private String description = DEFAULT_DESCRIPTION; + + public AgentTool(Client client, String agentId) { + this.client = client; + this.agentId = agentId; + } + + @Override + public void run(Map parameters, ActionListener listener) { + AgentMLInput agentMLInput = AgentMLInput + .AgentMLInputBuilder() + .agentId(agentId) + .functionName(FunctionName.AGENT) + .inputDataset(RemoteInferenceInputDataSet.builder().parameters(parameters).build()) + .build(); + ActionRequest request = new MLExecuteTaskRequest(FunctionName.AGENT, agentMLInput, false); + client.execute(MLExecuteTaskAction.INSTANCE, request, ActionListener.wrap(r -> { + ModelTensorOutput output = (ModelTensorOutput) r.getOutput(); + listener.onResponse((T) output); + }, e -> { + log.error("Failed to run agent " + agentId, e); + listener.onFailure(e); + })); + + } + + @Override + public String getType() { + return TYPE; + } + + @Override + public String getVersion() { + return null; + } + + @Override + public String getName() { + return this.name; + } + + @Override + public void setName(String s) { + this.name = s; + } + + @Override + public boolean validate(Map parameters) { + return true; + } + + public static class Factory implements Tool.Factory { + private Client client; + + private static Factory INSTANCE; + + public static Factory getInstance() { + if (INSTANCE != null) { + return INSTANCE; + } + synchronized (AgentTool.class) { + if (INSTANCE != null) { + return INSTANCE; + } + INSTANCE = new Factory(); + return INSTANCE; + } + } + + public void init(Client client) { + this.client = client; + } + + @Override + public AgentTool create(Map map) { + return new AgentTool(client, (String) map.get("agent_id")); + } + + @Override + public String getDefaultDescription() { + return DEFAULT_DESCRIPTION; + } + } +} diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/MLModelTool.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/MLModelTool.java new file mode 100644 index 0000000000..4b941e6333 --- /dev/null +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/MLModelTool.java @@ -0,0 +1,143 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.engine.tools; + +import java.util.List; +import java.util.Map; + +import org.opensearch.action.ActionRequest; +import org.opensearch.client.Client; +import org.opensearch.core.action.ActionListener; +import org.opensearch.ml.common.FunctionName; +import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet; +import org.opensearch.ml.common.input.MLInput; +import org.opensearch.ml.common.output.model.ModelTensorOutput; +import org.opensearch.ml.common.output.model.ModelTensors; +import org.opensearch.ml.common.spi.tools.Parser; +import org.opensearch.ml.common.spi.tools.Tool; +import org.opensearch.ml.common.spi.tools.ToolAnnotation; +import org.opensearch.ml.common.transport.prediction.MLPredictionTaskAction; +import org.opensearch.ml.common.transport.prediction.MLPredictionTaskRequest; +import org.opensearch.ml.repackage.com.google.common.annotations.VisibleForTesting; + +import lombok.Getter; +import lombok.Setter; +import lombok.extern.log4j.Log4j2; + +/** + * This tool supports running any ml-commons model. + */ +@Log4j2 +@ToolAnnotation(MLModelTool.TYPE) +public class MLModelTool implements Tool { + public static final String TYPE = "MLModelTool"; + + @Setter + @Getter + private String name = TYPE; + @VisibleForTesting + static String DEFAULT_DESCRIPTION = "Use this tool to run any model."; + @Getter + @Setter + private String description = DEFAULT_DESCRIPTION; + @Getter + private Client client; + @Getter + private String modelId; + @Setter + private Parser inputParser; + @Setter + @Getter + private Parser outputParser; + + public MLModelTool(Client client, String modelId) { + this.client = client; + this.modelId = modelId; + + outputParser = o -> { + List mlModelOutputs = (List) o; + return mlModelOutputs.get(0).getMlModelTensors().get(0).getDataAsMap().get("response"); + }; + } + + @Override + public void run(Map parameters, ActionListener listener) { + RemoteInferenceInputDataSet inputDataSet = RemoteInferenceInputDataSet.builder().parameters(parameters).build(); + ActionRequest request = new MLPredictionTaskRequest( + modelId, + MLInput.builder().algorithm(FunctionName.REMOTE).inputDataset(inputDataSet).build() + ); + client.execute(MLPredictionTaskAction.INSTANCE, request, ActionListener.wrap(r -> { + ModelTensorOutput modelTensorOutput = (ModelTensorOutput) r.getOutput(); + modelTensorOutput.getMlModelOutputs(); + listener.onResponse((T) outputParser.parse(modelTensorOutput.getMlModelOutputs())); + }, e -> { + log.error("Failed to run model " + modelId, e); + listener.onFailure(e); + })); + } + + @Override + public String getType() { + return TYPE; + } + + @Override + public String getVersion() { + return null; + } + + @Override + public String getName() { + return this.name; + } + + @Override + public void setName(String s) { + this.name = s; + } + + @Override + public boolean validate(Map parameters) { + if (parameters == null || parameters.size() == 0) { + return false; + } + return true; + } + + public static class Factory implements Tool.Factory { + private Client client; + + private static Factory INSTANCE; + + public static Factory getInstance() { + if (INSTANCE != null) { + return INSTANCE; + } + synchronized (MLModelTool.class) { + if (INSTANCE != null) { + return INSTANCE; + } + INSTANCE = new Factory(); + return INSTANCE; + } + } + + public void init(Client client) { + this.client = client; + } + + @Override + public MLModelTool create(Map map) { + return new MLModelTool(client, (String) map.get("model_id")); + } + + @Override + public String getDefaultDescription() { + return DEFAULT_DESCRIPTION; + } + } +} diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/tools/AgentToolTests.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/tools/AgentToolTests.java new file mode 100644 index 0000000000..431e609bba --- /dev/null +++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/tools/AgentToolTests.java @@ -0,0 +1,127 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.engine.tools; + +import static org.junit.Assert.*; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.verify; +import static org.opensearch.ml.engine.tools.AgentTool.DEFAULT_DESCRIPTION; + +import java.util.Arrays; +import java.util.Collections; +import java.util.HashMap; +import java.util.Map; + +import org.junit.Before; +import org.junit.Test; +import org.mockito.Mock; +import org.mockito.MockitoAnnotations; +import org.opensearch.client.Client; +import org.opensearch.core.action.ActionListener; +import org.opensearch.ml.common.FunctionName; +import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet; +import org.opensearch.ml.common.input.execute.agent.AgentMLInput; +import org.opensearch.ml.common.output.model.ModelTensor; +import org.opensearch.ml.common.output.model.ModelTensorOutput; +import org.opensearch.ml.common.output.model.ModelTensors; +import org.opensearch.ml.common.spi.tools.Parser; +import org.opensearch.ml.common.spi.tools.Tool; +import org.opensearch.ml.common.transport.execute.MLExecuteTaskAction; +import org.opensearch.ml.common.transport.execute.MLExecuteTaskResponse; +import org.opensearch.ml.repackage.com.google.common.collect.ImmutableMap; + +public class AgentToolTests { + + @Mock + private Client client; + private Map indicesParams; + private Map otherParams; + private Map emptyParams; + @Mock + private Parser mockOutputParser; + + @Mock + private MLExecuteTaskResponse mockResponse; + + @Mock + private ActionListener listener; + + private AgentTool agentTool; + + @Before + public void setup() { + MockitoAnnotations.openMocks(this); + AgentTool.Factory.getInstance().init(client); + + indicesParams = Map.of("index", "[\"foo\"]"); + otherParams = Map.of("other", "[\"bar\"]"); + emptyParams = Collections.emptyMap(); + } + + @Test + public void testAgenttestRunMethod() { + Map parameters = new HashMap<>(); + parameters.put("testKey", "testValue"); + AgentMLInput agentMLInput = AgentMLInput + .AgentMLInputBuilder() + .agentId("agentId") + .functionName(FunctionName.AGENT) + .inputDataset(RemoteInferenceInputDataSet.builder().parameters(parameters).build()) + .build(); + + ModelTensor modelTensor = ModelTensor.builder().dataAsMap(ImmutableMap.of("thought", "thought 1", "action", "action1")).build(); + ModelTensors modelTensors = ModelTensors.builder().mlModelTensors(Arrays.asList(modelTensor)).build(); + ModelTensorOutput mlModelTensorOutput = ModelTensorOutput.builder().mlModelOutputs(Arrays.asList(modelTensors)).build(); + + Tool tool = AgentTool.Factory.getInstance().create(Map.of("agent_id", "modelId")); + + doAnswer(invocation -> { + ActionListener actionListener = invocation.getArgument(2); + actionListener.onResponse(MLExecuteTaskResponse.builder().functionName(FunctionName.AGENT).output(mlModelTensorOutput).build()); + return null; + }).when(client).execute(eq(MLExecuteTaskAction.INSTANCE), any(), any()); + + tool.run(parameters, listener); + + // Verify interactions + verify(client).execute(any(), any(), any()); + verify(listener).onResponse(mlModelTensorOutput); + } + + @Test + public void testRunWithError() { + Map parameters = new HashMap<>(); + parameters.put("testKey", "testValue"); + + // Mocking the client.execute to simulate an error + doAnswer(invocation -> { + ActionListener actionListener = invocation.getArgument(2); + actionListener.onFailure(new RuntimeException("Test Exception")); + return null; + }).when(client).execute(any(), any(), any()); + + // Running the test + Tool tool = AgentTool.Factory.getInstance().create(Map.of("agent_id", "modelId")); + tool.run(parameters, listener); + + // Verifying that onFailure was called + verify(listener).onFailure(any(RuntimeException.class)); + } + + @Test + public void testTool() { + Tool tool = AgentTool.Factory.getInstance().create(Collections.emptyMap()); + assertEquals(AgentTool.TYPE, tool.getName()); + assertEquals(AgentTool.TYPE, tool.getType()); + assertNull(tool.getVersion()); + assertTrue(tool.validate(indicesParams)); + assertTrue(tool.validate(otherParams)); + assertTrue(tool.validate(emptyParams)); + assertEquals(DEFAULT_DESCRIPTION, tool.getDescription()); + } +} diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/tools/MLModelToolTests.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/tools/MLModelToolTests.java new file mode 100644 index 0000000000..e4bcb9db5d --- /dev/null +++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/tools/MLModelToolTests.java @@ -0,0 +1,127 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.engine.tools; + +import static org.junit.Assert.*; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.*; +import static org.opensearch.ml.engine.tools.MLModelTool.DEFAULT_DESCRIPTION; + +import java.util.*; + +import org.junit.Before; +import org.junit.Test; +import org.mockito.ArgumentCaptor; +import org.mockito.Mock; +import org.mockito.MockitoAnnotations; +import org.opensearch.client.Client; +import org.opensearch.core.action.ActionListener; +import org.opensearch.ml.common.output.model.ModelTensor; +import org.opensearch.ml.common.output.model.ModelTensorOutput; +import org.opensearch.ml.common.output.model.ModelTensors; +import org.opensearch.ml.common.spi.tools.Parser; +import org.opensearch.ml.common.spi.tools.Tool; +import org.opensearch.ml.common.transport.MLTaskResponse; +import org.opensearch.ml.common.transport.prediction.MLPredictionTaskAction; +import org.opensearch.ml.repackage.com.google.common.collect.ImmutableMap; + +public class MLModelToolTests { + + @Mock + private Client client; + private Map indicesParams; + private Map otherParams; + private Map emptyParams; + @Mock + private Parser mockOutputParser; + + @Mock + private ActionListener listener; + + @Before + public void setup() { + MockitoAnnotations.openMocks(this); + MLModelTool.Factory.getInstance().init(client); + + indicesParams = Map.of("index", "[\"foo\"]"); + otherParams = Map.of("other", "[\"bar\"]"); + emptyParams = Collections.emptyMap(); + } + + @Test + public void testMLModelsWithOutputParser() { + ModelTensor modelTensor = ModelTensor.builder().dataAsMap(ImmutableMap.of("thought", "thought 1", "action", "action1")).build(); + ModelTensors modelTensors = ModelTensors.builder().mlModelTensors(Arrays.asList(modelTensor)).build(); + ModelTensorOutput mlModelTensorOutput = ModelTensorOutput.builder().mlModelOutputs(Arrays.asList(modelTensors)).build(); + doAnswer(invocation -> { + + ActionListener actionListener = invocation.getArgument(2); + + actionListener.onResponse(MLTaskResponse.builder().output(mlModelTensorOutput).build()); + return null; + }).when(client).execute(eq(MLPredictionTaskAction.INSTANCE), any(), any()); + + Tool tool = MLModelTool.Factory.getInstance().create(Map.of("model_id", "modelId")); + tool.setOutputParser(mockOutputParser); + tool.run(otherParams, listener); + + verify(client).execute(any(), any(), any()); + verify(mockOutputParser).parse(any()); + ArgumentCaptor dataFrameArgumentCaptor = ArgumentCaptor.forClass(ModelTensorOutput.class); + verify(listener).onResponse(dataFrameArgumentCaptor.capture()); + } + + @Test + public void testOutputParserLambda() { + // Create a mock ModelTensors object + ModelTensor modelTensor = ModelTensor.builder().dataAsMap(ImmutableMap.of("response", "testResponse", "action", "action1")).build(); + ModelTensors modelTensors = ModelTensors.builder().mlModelTensors(Arrays.asList(modelTensor)).build(); + ModelTensorOutput mlModelTensorOutput = ModelTensorOutput.builder().mlModelOutputs(Arrays.asList(modelTensors)).build(); + + // Create the lambda expression for outputParser + Parser outputParser = o -> { + List outputs = (List) o; + return outputs.get(0).getMlModelTensors().get(0).getDataAsMap().get("response"); + }; + + // Invoke the lambda with the mock data + Object result = outputParser.parse(mlModelTensorOutput.getMlModelOutputs()); + + // Assert that the result matches the expected response + assertEquals("testResponse", result); + } + + @Test + public void testRunWithError() { + // Mocking the client.execute to simulate an error + doAnswer(invocation -> { + ActionListener actionListener = invocation.getArgument(2); + actionListener.onFailure(new RuntimeException("Test Exception")); + return null; + }).when(client).execute(any(), any(), any()); + + // Running the test + Tool tool = MLModelTool.Factory.getInstance().create(Map.of("model_id", "modelId")); + tool.setOutputParser(mockOutputParser); + tool.run(otherParams, listener); + + // Verifying that onFailure was called + verify(listener).onFailure(any(RuntimeException.class)); + } + + @Test + public void testTool() { + Tool tool = MLModelTool.Factory.getInstance().create(Collections.emptyMap()); + assertEquals(MLModelTool.TYPE, tool.getName()); + assertEquals(MLModelTool.TYPE, tool.getType()); + assertNull(tool.getVersion()); + assertTrue(tool.validate(indicesParams)); + assertTrue(tool.validate(otherParams)); + assertFalse(tool.validate(emptyParams)); + assertEquals(DEFAULT_DESCRIPTION, tool.getDescription()); + } +}