Skip to content

Commit

Permalink
fix ut
Browse files Browse the repository at this point in the history
Signed-off-by: Yaliang Wu <[email protected]>
  • Loading branch information
ylwu-amzn committed Jun 6, 2024
1 parent 2d25764 commit 942d0fe
Show file tree
Hide file tree
Showing 3 changed files with 15 additions and 15 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

import static org.junit.Assert.assertEquals;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.anyString;
import static org.mockito.Mockito.doAnswer;
import static org.mockito.Mockito.doNothing;
import static org.mockito.Mockito.mock;
Expand Down Expand Up @@ -260,7 +261,7 @@ public void executePredict_TextDocsInferenceInput_withStepSize() {
);

Mockito.verify(actionListener, times(0)).onFailure(any());
Mockito.verify(executor, times(3)).preparePayloadAndInvoke(PREDICT.name(), any(), any(), any());
Mockito.verify(executor, times(3)).preparePayloadAndInvoke(anyString(), any(), any(), any());
}

@Test
Expand Down Expand Up @@ -295,8 +296,8 @@ public void executePredict_TextDocsInferenceInput_withStepSize_returnOrderedResu
when(client.threadPool()).thenReturn(threadPool);
when(threadPool.getThreadContext()).thenReturn(threadContext);
doAnswer(invocation -> {
MLInput mlInput = invocation.getArgument(0);
ActionListener<Tuple<Integer, ModelTensors>> actionListener = invocation.getArgument(4);
MLInput mlInput = invocation.getArgument(1);
ActionListener<Tuple<Integer, ModelTensors>> actionListener = invocation.getArgument(5);
String doc = ((TextDocsInputDataSet) mlInput.getInputDataset()).getDocs().get(0);
Integer idx = Integer.parseInt(doc.substring(doc.length() - 1));
actionListener.onResponse(new Tuple<>(3 - idx, new ModelTensors(modelTensors.subList(3 - idx, 4 - idx))));
Expand Down Expand Up @@ -355,8 +356,8 @@ public void executePredict_TextDocsInferenceInput_withStepSize_partiallyFailed_t
when(client.threadPool()).thenReturn(threadPool);
when(threadPool.getThreadContext()).thenReturn(threadContext);
doAnswer(invocation -> {
MLInput mlInput = invocation.getArgument(0);
ActionListener<Tuple<Integer, ModelTensors>> actionListener = invocation.getArgument(4);
MLInput mlInput = invocation.getArgument(1);
ActionListener<Tuple<Integer, ModelTensors>> actionListener = invocation.getArgument(5);
String doc = ((TextDocsInputDataSet) mlInput.getInputDataset()).getDocs().get(0);
if (doc.endsWith("1")) {
actionListener.onFailure(new OpenSearchStatusException("test failure", RestStatus.BAD_REQUEST));
Expand Down Expand Up @@ -412,8 +413,8 @@ public void executePredict_TextDocsInferenceInput_withStepSize_failWithMultipleF
when(client.threadPool()).thenReturn(threadPool);
when(threadPool.getThreadContext()).thenReturn(threadContext);
doAnswer(invocation -> {
MLInput mlInput = invocation.getArgument(0);
ActionListener<Tuple<Integer, ModelTensors>> actionListener = invocation.getArgument(4);
MLInput mlInput = invocation.getArgument(1);
ActionListener<Tuple<Integer, ModelTensors>> actionListener = invocation.getArgument(5);
String doc = ((TextDocsInputDataSet) mlInput.getInputDataset()).getDocs().get(0);
if (!doc.endsWith("1")) {
actionListener.onFailure(new OpenSearchStatusException("test failure", RestStatus.BAD_REQUEST));
Expand Down Expand Up @@ -566,7 +567,7 @@ public void executePredict_TextDocsInferenceInput_withoutStepSize_emptyPredictio
ArgumentCaptor<Exception> exceptionArgumentCaptor = ArgumentCaptor.forClass(Exception.class);
Mockito.verify(actionListener, times(1)).onFailure(exceptionArgumentCaptor.capture());
assert exceptionArgumentCaptor.getValue() instanceof IllegalArgumentException;
assert "no predict action found".equals(exceptionArgumentCaptor.getValue().getMessage());
assert "no PREDICT action found".equals(exceptionArgumentCaptor.getValue().getMessage());
}

@Test
Expand Down Expand Up @@ -729,7 +730,7 @@ public void invokeRemoteServiceWithRetry_whenRetryableException_thenRetryUntilSu

@Override
public Void answer(InvocationOnMock invocation) {
ActionListener<Tuple<Integer, ModelTensors>> actionListener = invocation.getArgument(4);
ActionListener<Tuple<Integer, ModelTensors>> actionListener = invocation.getArgument(5);
// fail the first 10 invocation, then success
if (countOfInvocation++ < 10) {
actionListener.onFailure(new RemoteConnectorThrottlingException("test failure retryable", RestStatus.BAD_REQUEST));
Expand Down Expand Up @@ -776,7 +777,7 @@ public void invokeRemoteServiceWithRetry_whenRetryExceedMaxRetryTimes_thenCallOn

@Override
public Void answer(InvocationOnMock invocation) {
ActionListener<Tuple<Integer, ModelTensors>> actionListener = invocation.getArgument(4);
ActionListener<Tuple<Integer, ModelTensors>> actionListener = invocation.getArgument(5);
// fail the first 10 invocation, then success
if (countOfInvocation++ < 10) {
actionListener.onFailure(new RemoteConnectorThrottlingException("test failure retryable", RestStatus.BAD_REQUEST));
Expand Down Expand Up @@ -823,7 +824,7 @@ public void invokeRemoteServiceWithRetry_whenNonRetryableException_thenCallOnFai

@Override
public Void answer(InvocationOnMock invocation) {
ActionListener<Tuple<Integer, ModelTensors>> actionListener = invocation.getArgument(4);
ActionListener<Tuple<Integer, ModelTensors>> actionListener = invocation.getArgument(5);
// fail the first 2 invocation with retryable exception, then fail with non-retryable exception
if (countOfInvocation++ < 2) {
actionListener.onFailure(new RemoteConnectorThrottlingException("test failure retryable", RestStatus.BAD_REQUEST));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -422,6 +422,6 @@ public void test_onComplete_processOutputFail_onFailure() {

ArgumentCaptor<MLException> captor = ArgumentCaptor.forClass(MLException.class);
verify(actionListener, times(1)).onFailure(captor.capture());
assert captor.getValue().getMessage().equals("Fail to execute predict in aws connector");
assert captor.getValue().getMessage().equals("Fail to execute PREDICT in aws connector");
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
import static org.opensearch.ml.common.connector.ConnectorAction.ActionType.PREDICT;
import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_ALLOW_MODEL_URL;
import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_TRUSTED_CONNECTOR_ENDPOINTS_REGEX;
import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_TRUSTED_URL_REGEX;
Expand Down Expand Up @@ -523,7 +522,7 @@ public void test_execute_registerRemoteModel_withInternalConnector_success() {
when(input.getFunctionName()).thenReturn(FunctionName.REMOTE);
Connector connector = mock(Connector.class);
when(input.getConnector()).thenReturn(connector);
when(connector.getActionEndpoint(PREDICT.name(), any(Map.class))).thenReturn("https://api.openai.com");
when(connector.getActionEndpoint(anyString(), any(Map.class))).thenReturn("https://api.openai.com");
MLCreateConnectorResponse mlCreateConnectorResponse = mock(MLCreateConnectorResponse.class);
doAnswer(invocation -> {
ActionListener<MLCreateConnectorResponse> listener = invocation.getArgument(2);
Expand Down Expand Up @@ -557,7 +556,7 @@ public void test_execute_registerRemoteModel_withInternalConnector_predictEndpoi
when(request.getRegisterModelInput()).thenReturn(input);
when(input.getFunctionName()).thenReturn(FunctionName.REMOTE);
Connector connector = mock(Connector.class);
when(connector.getActionEndpoint(PREDICT.name(), any(Map.class))).thenReturn(null);
when(connector.getActionEndpoint(anyString(), any(Map.class))).thenReturn(null);
when(input.getConnector()).thenReturn(connector);
transportRegisterModelAction.doExecute(task, request, actionListener);
ArgumentCaptor<Exception> argumentCaptor = ArgumentCaptor.forClass(Exception.class);
Expand Down

0 comments on commit 942d0fe

Please sign in to comment.