diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/MLModelTool.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/MLModelTool.java index cba0d2ee6a..8a33edca7d 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/MLModelTool.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/MLModelTool.java @@ -34,6 +34,9 @@ @ToolAnnotation(MLModelTool.TYPE) public class MLModelTool implements Tool { public static final String TYPE = "MLModelTool"; + public static final String RESPONSE_FIELD = "response_field"; + public static final String MODEL_ID_FIELD = "model_id"; + public static final String DEFAULT_RESPONSE_FIELD = "response"; @Setter @Getter @@ -52,14 +55,18 @@ public class MLModelTool implements Tool { @Setter @Getter private Parser outputParser; + @Setter + @Getter + private String responseField; - public MLModelTool(Client client, String modelId) { + public MLModelTool(Client client, String modelId, String responseField) { this.client = client; this.modelId = modelId; + this.responseField = responseField; outputParser = o -> { List mlModelOutputs = (List) o; - return mlModelOutputs.get(0).getMlModelTensors().get(0).getDataAsMap().get("response"); + return mlModelOutputs.get(0).getMlModelTensors().get(0).getDataAsMap().get(responseField); }; } @@ -132,7 +139,11 @@ public void init(Client client) { @Override public MLModelTool create(Map map) { - return new MLModelTool(client, (String) map.get("model_id")); + return new MLModelTool( + client, + (String) map.get(MODEL_ID_FIELD), + (String) map.getOrDefault(RESPONSE_FIELD, DEFAULT_RESPONSE_FIELD) + ); } @Override diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/tools/MLModelToolTests.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/tools/MLModelToolTests.java index e4bcb9db5d..f6b54b56be 100644 --- a/ml-algorithms/src/test/java/org/opensearch/ml/engine/tools/MLModelToolTests.java +++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/tools/MLModelToolTests.java @@ -5,13 +5,22 @@ package org.opensearch.ml.engine.tools; -import static org.junit.Assert.*; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertNull; +import static org.junit.Assert.assertTrue; import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.eq; -import static org.mockito.Mockito.*; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.verify; import static org.opensearch.ml.engine.tools.MLModelTool.DEFAULT_DESCRIPTION; -import java.util.*; +import java.util.Arrays; +import java.util.Collections; +import java.util.List; +import java.util.Map; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.ExecutionException; import org.junit.Before; import org.junit.Test; @@ -53,7 +62,73 @@ public void setup() { } @Test - public void testMLModelsWithOutputParser() { + public void testMLModelsWithDefaultOutputParserAndDefaultResponseField() throws ExecutionException, InterruptedException { + ModelTensor modelTensor = ModelTensor.builder().dataAsMap(ImmutableMap.of("response", "response 1", "action", "action1")).build(); + ModelTensors modelTensors = ModelTensors.builder().mlModelTensors(Arrays.asList(modelTensor)).build(); + ModelTensorOutput mlModelTensorOutput = ModelTensorOutput.builder().mlModelOutputs(Arrays.asList(modelTensors)).build(); + doAnswer(invocation -> { + + ActionListener actionListener = invocation.getArgument(2); + + actionListener.onResponse(MLTaskResponse.builder().output(mlModelTensorOutput).build()); + return null; + }).when(client).execute(eq(MLPredictionTaskAction.INSTANCE), any(), any()); + + Tool tool = MLModelTool.Factory.getInstance().create(Map.of("model_id", "modelId")); + final CompletableFuture future = new CompletableFuture<>(); + ActionListener listener = ActionListener.wrap(r -> { future.complete(r); }, e -> { future.completeExceptionally(e); }); + tool.run(null, listener); + + future.join(); + assertEquals("response 1", future.get()); + } + + @Test + public void testMLModelsWithDefaultOutputParserAndCustomizedResponseField() throws ExecutionException, InterruptedException { + ModelTensor modelTensor = ModelTensor.builder().dataAsMap(ImmutableMap.of("response", "response 1", "action", "action1")).build(); + ModelTensors modelTensors = ModelTensors.builder().mlModelTensors(Arrays.asList(modelTensor)).build(); + ModelTensorOutput mlModelTensorOutput = ModelTensorOutput.builder().mlModelOutputs(Arrays.asList(modelTensors)).build(); + doAnswer(invocation -> { + + ActionListener actionListener = invocation.getArgument(2); + + actionListener.onResponse(MLTaskResponse.builder().output(mlModelTensorOutput).build()); + return null; + }).when(client).execute(eq(MLPredictionTaskAction.INSTANCE), any(), any()); + + Tool tool = MLModelTool.Factory.getInstance().create(Map.of("model_id", "modelId", "response_field", "action")); + final CompletableFuture future = new CompletableFuture<>(); + ActionListener listener = ActionListener.wrap(r -> { future.complete(r); }, e -> { future.completeExceptionally(e); }); + tool.run(null, listener); + + future.join(); + assertEquals("action1", future.get()); + } + + @Test + public void testMLModelsWithDefaultOutputParserAndMalformedResponseField() throws ExecutionException, InterruptedException { + ModelTensor modelTensor = ModelTensor.builder().dataAsMap(ImmutableMap.of("response", "response 1", "action", "action1")).build(); + ModelTensors modelTensors = ModelTensors.builder().mlModelTensors(Arrays.asList(modelTensor)).build(); + ModelTensorOutput mlModelTensorOutput = ModelTensorOutput.builder().mlModelOutputs(Arrays.asList(modelTensors)).build(); + doAnswer(invocation -> { + + ActionListener actionListener = invocation.getArgument(2); + + actionListener.onResponse(MLTaskResponse.builder().output(mlModelTensorOutput).build()); + return null; + }).when(client).execute(eq(MLPredictionTaskAction.INSTANCE), any(), any()); + + Tool tool = MLModelTool.Factory.getInstance().create(Map.of("model_id", "modelId", "response_field", "malformed field")); + final CompletableFuture future = new CompletableFuture<>(); + ActionListener listener = ActionListener.wrap(r -> { future.complete(r); }, e -> { future.completeExceptionally(e); }); + tool.run(null, listener); + + future.join(); + assertEquals(null, future.get()); + } + + @Test + public void testMLModelsWithCustomizedOutputParser() { ModelTensor modelTensor = ModelTensor.builder().dataAsMap(ImmutableMap.of("thought", "thought 1", "action", "action1")).build(); ModelTensors modelTensors = ModelTensors.builder().mlModelTensors(Arrays.asList(modelTensor)).build(); ModelTensorOutput mlModelTensorOutput = ModelTensorOutput.builder().mlModelOutputs(Arrays.asList(modelTensors)).build();