Skip to content

Commit

Permalink
test
Browse files Browse the repository at this point in the history
Signed-off-by: Yaliang Wu <[email protected]>
  • Loading branch information
ylwu-amzn committed Jun 6, 2024
1 parent 39ab97a commit f21261b
Show file tree
Hide file tree
Showing 3 changed files with 15 additions and 15 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -31,13 +31,11 @@ public class MLExecuteConnectorRequestTests {
private MLExecuteConnectorRequest mlExecuteConnectorRequest;
private MLInput mlInput;
private String connectorId;
private String action;

@Before
public void setUp(){
MLInputDataset inputDataSet = RemoteInferenceInputDataSet.builder().parameters(Map.of("input", "hello")).build();
connectorId = "test_connector";
action = "execute";
mlInput = RemoteInferenceMLInput.builder().inputDataset(inputDataSet).algorithm(FunctionName.CONNECTOR).build();
mlExecuteConnectorRequest = MLExecuteConnectorRequest.builder().connectorId(connectorId).mlInput(mlInput).build();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.when;
import static org.opensearch.ml.common.connector.ConnectorAction.ActionType.PREDICT;
import static org.opensearch.ml.common.utils.StringUtils.gson;

import java.io.IOException;
Expand Down Expand Up @@ -56,7 +57,7 @@ public void setUp() {
public void processInput_NullInput() {
exceptionRule.expect(IllegalArgumentException.class);
exceptionRule.expectMessage("Input is null");
ConnectorUtils.processInput(null, null, new HashMap<>(), null);
ConnectorUtils.processInput(PREDICT.name(), null, null, new HashMap<>(), null);
}

@Test
Expand All @@ -66,7 +67,7 @@ public void processInput_TextDocsInputDataSet_NoPreprocessFunction() {

ConnectorAction predictAction = ConnectorAction
.builder()
.actionType(ConnectorAction.ActionType.PREDICT)
.actionType(PREDICT)
.method("POST")
.url("http://test.com/mock")
.requestBody("{\"input\": \"${parameters.input}\"}")
Expand All @@ -78,7 +79,7 @@ public void processInput_TextDocsInputDataSet_NoPreprocessFunction() {
.protocol("http")
.actions(Arrays.asList(predictAction))
.build();
ConnectorUtils.processInput(mlInput, connector, new HashMap<>(), scriptService);
ConnectorUtils.processInput(PREDICT.name(), mlInput, connector, new HashMap<>(), scriptService);
}

@Test
Expand Down Expand Up @@ -120,7 +121,7 @@ private void processInput_RemoteInferenceInputDataSet(String input, String expec

ConnectorAction predictAction = ConnectorAction
.builder()
.actionType(ConnectorAction.ActionType.PREDICT)
.actionType(PREDICT)
.method("POST")
.url("http://test.com/mock")
.requestBody("{\"input\": \"${parameters.input}\"}")
Expand All @@ -132,7 +133,7 @@ private void processInput_RemoteInferenceInputDataSet(String input, String expec
.protocol("http")
.actions(Arrays.asList(predictAction))
.build();
ConnectorUtils.processInput(mlInput, connector, new HashMap<>(), scriptService);
ConnectorUtils.processInput(PREDICT.name(), mlInput, connector, new HashMap<>(), scriptService);
Assert.assertEquals(expectedInput, ((RemoteInferenceInputDataSet) mlInput.getInputDataset()).getParameters().get("input"));
}

Expand Down Expand Up @@ -168,14 +169,14 @@ public void processInput_TextDocsInputDataSet_PreprocessFunction_MultiTextDoc()
public void processOutput_NullResponse() throws IOException {
exceptionRule.expect(IllegalArgumentException.class);
exceptionRule.expectMessage("model response is null");
ConnectorUtils.processOutput(null, null, null, null, null);
ConnectorUtils.processOutput(PREDICT.name(), null, null, null, null, null);
}

@Test
public void processOutput_NoPostprocessFunction_jsonResponse() throws IOException {
ConnectorAction predictAction = ConnectorAction
.builder()
.actionType(ConnectorAction.ActionType.PREDICT)
.actionType(PREDICT)
.method("POST")
.url("http://test.com/mock")
.requestBody("{\"input\": \"${parameters.input}\"}")
Expand All @@ -192,7 +193,8 @@ public void processOutput_NoPostprocessFunction_jsonResponse() throws IOExceptio
.build();
String modelResponse =
"{\"object\":\"list\",\"data\":[{\"object\":\"embedding\",\"index\":0,\"embedding\":[-0.014555434,-0.0002135904,0.0035105038]}],\"model\":\"text-embedding-ada-002-v2\",\"usage\":{\"prompt_tokens\":5,\"total_tokens\":5}}";
ModelTensors tensors = ConnectorUtils.processOutput(modelResponse, connector, scriptService, ImmutableMap.of(), null);
ModelTensors tensors = ConnectorUtils
.processOutput(PREDICT.name(), modelResponse, connector, scriptService, ImmutableMap.of(), null);
Assert.assertEquals(1, tensors.getMlModelTensors().size());
Assert.assertEquals("response", tensors.getMlModelTensors().get(0).getName());
Assert.assertEquals(4, tensors.getMlModelTensors().get(0).getDataAsMap().size());
Expand All @@ -206,7 +208,7 @@ public void processOutput_PostprocessFunction() throws IOException {

ConnectorAction predictAction = ConnectorAction
.builder()
.actionType(ConnectorAction.ActionType.PREDICT)
.actionType(PREDICT)
.method("POST")
.url("http://test.com/mock")
.requestBody("{\"input\": \"${parameters.input}\"}")
Expand All @@ -224,7 +226,8 @@ public void processOutput_PostprocessFunction() throws IOException {
.build();
String modelResponse =
"{\"object\":\"list\",\"data\":[{\"object\":\"embedding\",\"index\":0,\"embedding\":[-0.014555434,-0.0002135904,0.0035105038]}],\"model\":\"text-embedding-ada-002-v2\",\"usage\":{\"prompt_tokens\":5,\"total_tokens\":5}}";
ModelTensors tensors = ConnectorUtils.processOutput(modelResponse, connector, scriptService, ImmutableMap.of(), null);
ModelTensors tensors = ConnectorUtils
.processOutput(PREDICT.name(), modelResponse, connector, scriptService, ImmutableMap.of(), null);
Assert.assertEquals(1, tensors.getMlModelTensors().size());
Assert.assertEquals("sentence_embedding", tensors.getMlModelTensors().get(0).getName());
Assert.assertNull(tensors.getMlModelTensors().get(0).getDataAsMap());
Expand All @@ -246,7 +249,7 @@ private void processInput_TextDocsInputDataSet_PreprocessFunction(

ConnectorAction predictAction = ConnectorAction
.builder()
.actionType(ConnectorAction.ActionType.PREDICT)
.actionType(PREDICT)
.method("POST")
.url("http://test.com/mock")
.requestBody(requestBody)
Expand All @@ -263,7 +266,7 @@ private void processInput_TextDocsInputDataSet_PreprocessFunction(
.actions(Arrays.asList(predictAction))
.build();
RemoteInferenceInputDataSet remoteInferenceInputDataSet = ConnectorUtils
.processInput(mlInput, connector, new HashMap<>(), scriptService);
.processInput(PREDICT.name(), mlInput, connector, new HashMap<>(), scriptService);
Assert.assertNotNull(remoteInferenceInputDataSet.getParameters());
Assert.assertEquals(1, remoteInferenceInputDataSet.getParameters().size());
Assert.assertEquals(expectedProcessedInput, remoteInferenceInputDataSet.getParameters().get(resultKey));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,6 @@ public void setup() {
when(clusterService.state()).thenReturn(testState);

when(request.getConnectorId()).thenReturn("test_connector_id");
when(request.getConnectorAction()).thenReturn("execute");

Settings settings = Settings.builder().build();
ThreadContext threadContext = new ThreadContext(settings);
Expand Down

0 comments on commit f21261b

Please sign in to comment.