diff --git a/common/src/test/java/org/opensearch/ml/common/transport/connector/MLExecuteConnectorRequestTests.java b/common/src/test/java/org/opensearch/ml/common/transport/connector/MLExecuteConnectorRequestTests.java index bdbe0bebfc..f95b236259 100644 --- a/common/src/test/java/org/opensearch/ml/common/transport/connector/MLExecuteConnectorRequestTests.java +++ b/common/src/test/java/org/opensearch/ml/common/transport/connector/MLExecuteConnectorRequestTests.java @@ -31,13 +31,11 @@ public class MLExecuteConnectorRequestTests { private MLExecuteConnectorRequest mlExecuteConnectorRequest; private MLInput mlInput; private String connectorId; - private String action; @Before public void setUp(){ MLInputDataset inputDataSet = RemoteInferenceInputDataSet.builder().parameters(Map.of("input", "hello")).build(); connectorId = "test_connector"; - action = "execute"; mlInput = RemoteInferenceMLInput.builder().inputDataset(inputDataSet).algorithm(FunctionName.CONNECTOR).build(); mlExecuteConnectorRequest = MLExecuteConnectorRequest.builder().connectorId(connectorId).mlInput(mlInput).build(); } diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/ConnectorUtilsTest.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/ConnectorUtilsTest.java index cb7f8a4fe8..31b0f5e420 100644 --- a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/ConnectorUtilsTest.java +++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/ConnectorUtilsTest.java @@ -7,6 +7,7 @@ import static org.mockito.ArgumentMatchers.any; import static org.mockito.Mockito.when; +import static org.opensearch.ml.common.connector.ConnectorAction.ActionType.PREDICT; import static org.opensearch.ml.common.utils.StringUtils.gson; import java.io.IOException; @@ -56,7 +57,7 @@ public void setUp() { public void processInput_NullInput() { exceptionRule.expect(IllegalArgumentException.class); exceptionRule.expectMessage("Input is null"); - ConnectorUtils.processInput(null, null, new HashMap<>(), null); + ConnectorUtils.processInput(PREDICT.name(), null, null, new HashMap<>(), null); } @Test @@ -66,7 +67,7 @@ public void processInput_TextDocsInputDataSet_NoPreprocessFunction() { ConnectorAction predictAction = ConnectorAction .builder() - .actionType(ConnectorAction.ActionType.PREDICT) + .actionType(PREDICT) .method("POST") .url("http://test.com/mock") .requestBody("{\"input\": \"${parameters.input}\"}") @@ -78,7 +79,7 @@ public void processInput_TextDocsInputDataSet_NoPreprocessFunction() { .protocol("http") .actions(Arrays.asList(predictAction)) .build(); - ConnectorUtils.processInput(mlInput, connector, new HashMap<>(), scriptService); + ConnectorUtils.processInput(PREDICT.name(), mlInput, connector, new HashMap<>(), scriptService); } @Test @@ -120,7 +121,7 @@ private void processInput_RemoteInferenceInputDataSet(String input, String expec ConnectorAction predictAction = ConnectorAction .builder() - .actionType(ConnectorAction.ActionType.PREDICT) + .actionType(PREDICT) .method("POST") .url("http://test.com/mock") .requestBody("{\"input\": \"${parameters.input}\"}") @@ -132,7 +133,7 @@ private void processInput_RemoteInferenceInputDataSet(String input, String expec .protocol("http") .actions(Arrays.asList(predictAction)) .build(); - ConnectorUtils.processInput(mlInput, connector, new HashMap<>(), scriptService); + ConnectorUtils.processInput(PREDICT.name(), mlInput, connector, new HashMap<>(), scriptService); Assert.assertEquals(expectedInput, ((RemoteInferenceInputDataSet) mlInput.getInputDataset()).getParameters().get("input")); } @@ -168,14 +169,14 @@ public void processInput_TextDocsInputDataSet_PreprocessFunction_MultiTextDoc() public void processOutput_NullResponse() throws IOException { exceptionRule.expect(IllegalArgumentException.class); exceptionRule.expectMessage("model response is null"); - ConnectorUtils.processOutput(null, null, null, null, null); + ConnectorUtils.processOutput(PREDICT.name(), null, null, null, null, null); } @Test public void processOutput_NoPostprocessFunction_jsonResponse() throws IOException { ConnectorAction predictAction = ConnectorAction .builder() - .actionType(ConnectorAction.ActionType.PREDICT) + .actionType(PREDICT) .method("POST") .url("http://test.com/mock") .requestBody("{\"input\": \"${parameters.input}\"}") @@ -192,7 +193,8 @@ public void processOutput_NoPostprocessFunction_jsonResponse() throws IOExceptio .build(); String modelResponse = "{\"object\":\"list\",\"data\":[{\"object\":\"embedding\",\"index\":0,\"embedding\":[-0.014555434,-0.0002135904,0.0035105038]}],\"model\":\"text-embedding-ada-002-v2\",\"usage\":{\"prompt_tokens\":5,\"total_tokens\":5}}"; - ModelTensors tensors = ConnectorUtils.processOutput(modelResponse, connector, scriptService, ImmutableMap.of(), null); + ModelTensors tensors = ConnectorUtils + .processOutput(PREDICT.name(), modelResponse, connector, scriptService, ImmutableMap.of(), null); Assert.assertEquals(1, tensors.getMlModelTensors().size()); Assert.assertEquals("response", tensors.getMlModelTensors().get(0).getName()); Assert.assertEquals(4, tensors.getMlModelTensors().get(0).getDataAsMap().size()); @@ -206,7 +208,7 @@ public void processOutput_PostprocessFunction() throws IOException { ConnectorAction predictAction = ConnectorAction .builder() - .actionType(ConnectorAction.ActionType.PREDICT) + .actionType(PREDICT) .method("POST") .url("http://test.com/mock") .requestBody("{\"input\": \"${parameters.input}\"}") @@ -224,7 +226,8 @@ public void processOutput_PostprocessFunction() throws IOException { .build(); String modelResponse = "{\"object\":\"list\",\"data\":[{\"object\":\"embedding\",\"index\":0,\"embedding\":[-0.014555434,-0.0002135904,0.0035105038]}],\"model\":\"text-embedding-ada-002-v2\",\"usage\":{\"prompt_tokens\":5,\"total_tokens\":5}}"; - ModelTensors tensors = ConnectorUtils.processOutput(modelResponse, connector, scriptService, ImmutableMap.of(), null); + ModelTensors tensors = ConnectorUtils + .processOutput(PREDICT.name(), modelResponse, connector, scriptService, ImmutableMap.of(), null); Assert.assertEquals(1, tensors.getMlModelTensors().size()); Assert.assertEquals("sentence_embedding", tensors.getMlModelTensors().get(0).getName()); Assert.assertNull(tensors.getMlModelTensors().get(0).getDataAsMap()); @@ -246,7 +249,7 @@ private void processInput_TextDocsInputDataSet_PreprocessFunction( ConnectorAction predictAction = ConnectorAction .builder() - .actionType(ConnectorAction.ActionType.PREDICT) + .actionType(PREDICT) .method("POST") .url("http://test.com/mock") .requestBody(requestBody) @@ -263,7 +266,7 @@ private void processInput_TextDocsInputDataSet_PreprocessFunction( .actions(Arrays.asList(predictAction)) .build(); RemoteInferenceInputDataSet remoteInferenceInputDataSet = ConnectorUtils - .processInput(mlInput, connector, new HashMap<>(), scriptService); + .processInput(PREDICT.name(), mlInput, connector, new HashMap<>(), scriptService); Assert.assertNotNull(remoteInferenceInputDataSet.getParameters()); Assert.assertEquals(1, remoteInferenceInputDataSet.getParameters().size()); Assert.assertEquals(expectedProcessedInput, remoteInferenceInputDataSet.getParameters().get(resultKey)); diff --git a/plugin/src/test/java/org/opensearch/ml/action/connector/ExecuteConnectorTransportActionTests.java b/plugin/src/test/java/org/opensearch/ml/action/connector/ExecuteConnectorTransportActionTests.java index 0753eac8a8..bdc267ccc8 100644 --- a/plugin/src/test/java/org/opensearch/ml/action/connector/ExecuteConnectorTransportActionTests.java +++ b/plugin/src/test/java/org/opensearch/ml/action/connector/ExecuteConnectorTransportActionTests.java @@ -95,7 +95,6 @@ public void setup() { when(clusterService.state()).thenReturn(testState); when(request.getConnectorId()).thenReturn("test_connector_id"); - when(request.getConnectorAction()).thenReturn("execute"); Settings settings = Settings.builder().build(); ThreadContext threadContext = new ThreadContext(settings);