diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/AwsConnectorExecutorTest.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/AwsConnectorExecutorTest.java index 3d209888f5..cb192e83f9 100644 --- a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/AwsConnectorExecutorTest.java +++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/AwsConnectorExecutorTest.java @@ -7,6 +7,7 @@ import static org.junit.Assert.assertEquals; import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyString; import static org.mockito.Mockito.doAnswer; import static org.mockito.Mockito.doNothing; import static org.mockito.Mockito.mock; @@ -260,7 +261,7 @@ public void executePredict_TextDocsInferenceInput_withStepSize() { ); Mockito.verify(actionListener, times(0)).onFailure(any()); - Mockito.verify(executor, times(3)).preparePayloadAndInvoke(PREDICT.name(), any(), any(), any()); + Mockito.verify(executor, times(3)).preparePayloadAndInvoke(anyString(), any(), any(), any()); } @Test @@ -295,8 +296,8 @@ public void executePredict_TextDocsInferenceInput_withStepSize_returnOrderedResu when(client.threadPool()).thenReturn(threadPool); when(threadPool.getThreadContext()).thenReturn(threadContext); doAnswer(invocation -> { - MLInput mlInput = invocation.getArgument(0); - ActionListener> actionListener = invocation.getArgument(4); + MLInput mlInput = invocation.getArgument(1); + ActionListener> actionListener = invocation.getArgument(5); String doc = ((TextDocsInputDataSet) mlInput.getInputDataset()).getDocs().get(0); Integer idx = Integer.parseInt(doc.substring(doc.length() - 1)); actionListener.onResponse(new Tuple<>(3 - idx, new ModelTensors(modelTensors.subList(3 - idx, 4 - idx)))); @@ -355,8 +356,8 @@ public void executePredict_TextDocsInferenceInput_withStepSize_partiallyFailed_t when(client.threadPool()).thenReturn(threadPool); when(threadPool.getThreadContext()).thenReturn(threadContext); doAnswer(invocation -> { - MLInput mlInput = invocation.getArgument(0); - ActionListener> actionListener = invocation.getArgument(4); + MLInput mlInput = invocation.getArgument(1); + ActionListener> actionListener = invocation.getArgument(5); String doc = ((TextDocsInputDataSet) mlInput.getInputDataset()).getDocs().get(0); if (doc.endsWith("1")) { actionListener.onFailure(new OpenSearchStatusException("test failure", RestStatus.BAD_REQUEST)); @@ -412,8 +413,8 @@ public void executePredict_TextDocsInferenceInput_withStepSize_failWithMultipleF when(client.threadPool()).thenReturn(threadPool); when(threadPool.getThreadContext()).thenReturn(threadContext); doAnswer(invocation -> { - MLInput mlInput = invocation.getArgument(0); - ActionListener> actionListener = invocation.getArgument(4); + MLInput mlInput = invocation.getArgument(1); + ActionListener> actionListener = invocation.getArgument(5); String doc = ((TextDocsInputDataSet) mlInput.getInputDataset()).getDocs().get(0); if (!doc.endsWith("1")) { actionListener.onFailure(new OpenSearchStatusException("test failure", RestStatus.BAD_REQUEST)); @@ -566,7 +567,7 @@ public void executePredict_TextDocsInferenceInput_withoutStepSize_emptyPredictio ArgumentCaptor exceptionArgumentCaptor = ArgumentCaptor.forClass(Exception.class); Mockito.verify(actionListener, times(1)).onFailure(exceptionArgumentCaptor.capture()); assert exceptionArgumentCaptor.getValue() instanceof IllegalArgumentException; - assert "no predict action found".equals(exceptionArgumentCaptor.getValue().getMessage()); + assert "no PREDICT action found".equals(exceptionArgumentCaptor.getValue().getMessage()); } @Test @@ -729,7 +730,7 @@ public void invokeRemoteServiceWithRetry_whenRetryableException_thenRetryUntilSu @Override public Void answer(InvocationOnMock invocation) { - ActionListener> actionListener = invocation.getArgument(4); + ActionListener> actionListener = invocation.getArgument(5); // fail the first 10 invocation, then success if (countOfInvocation++ < 10) { actionListener.onFailure(new RemoteConnectorThrottlingException("test failure retryable", RestStatus.BAD_REQUEST)); @@ -776,7 +777,7 @@ public void invokeRemoteServiceWithRetry_whenRetryExceedMaxRetryTimes_thenCallOn @Override public Void answer(InvocationOnMock invocation) { - ActionListener> actionListener = invocation.getArgument(4); + ActionListener> actionListener = invocation.getArgument(5); // fail the first 10 invocation, then success if (countOfInvocation++ < 10) { actionListener.onFailure(new RemoteConnectorThrottlingException("test failure retryable", RestStatus.BAD_REQUEST)); @@ -823,7 +824,7 @@ public void invokeRemoteServiceWithRetry_whenNonRetryableException_thenCallOnFai @Override public Void answer(InvocationOnMock invocation) { - ActionListener> actionListener = invocation.getArgument(4); + ActionListener> actionListener = invocation.getArgument(5); // fail the first 2 invocation with retryable exception, then fail with non-retryable exception if (countOfInvocation++ < 2) { actionListener.onFailure(new RemoteConnectorThrottlingException("test failure retryable", RestStatus.BAD_REQUEST)); diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/MLSdkAsyncHttpResponseHandlerTest.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/MLSdkAsyncHttpResponseHandlerTest.java index ed8f007fc7..f6c9b76071 100644 --- a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/MLSdkAsyncHttpResponseHandlerTest.java +++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/MLSdkAsyncHttpResponseHandlerTest.java @@ -422,6 +422,6 @@ public void test_onComplete_processOutputFail_onFailure() { ArgumentCaptor captor = ArgumentCaptor.forClass(MLException.class); verify(actionListener, times(1)).onFailure(captor.capture()); - assert captor.getValue().getMessage().equals("Fail to execute predict in aws connector"); + assert captor.getValue().getMessage().equals("Fail to execute PREDICT in aws connector"); } } diff --git a/plugin/src/test/java/org/opensearch/ml/action/register/TransportRegisterModelActionTests.java b/plugin/src/test/java/org/opensearch/ml/action/register/TransportRegisterModelActionTests.java index dc46b06110..d936f199a2 100644 --- a/plugin/src/test/java/org/opensearch/ml/action/register/TransportRegisterModelActionTests.java +++ b/plugin/src/test/java/org/opensearch/ml/action/register/TransportRegisterModelActionTests.java @@ -14,7 +14,6 @@ import static org.mockito.Mockito.mock; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; -import static org.opensearch.ml.common.connector.ConnectorAction.ActionType.PREDICT; import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_ALLOW_MODEL_URL; import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_TRUSTED_CONNECTOR_ENDPOINTS_REGEX; import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_TRUSTED_URL_REGEX; @@ -523,7 +522,7 @@ public void test_execute_registerRemoteModel_withInternalConnector_success() { when(input.getFunctionName()).thenReturn(FunctionName.REMOTE); Connector connector = mock(Connector.class); when(input.getConnector()).thenReturn(connector); - when(connector.getActionEndpoint(PREDICT.name(), any(Map.class))).thenReturn("https://api.openai.com"); + when(connector.getActionEndpoint(anyString(), any(Map.class))).thenReturn("https://api.openai.com"); MLCreateConnectorResponse mlCreateConnectorResponse = mock(MLCreateConnectorResponse.class); doAnswer(invocation -> { ActionListener listener = invocation.getArgument(2); @@ -557,7 +556,7 @@ public void test_execute_registerRemoteModel_withInternalConnector_predictEndpoi when(request.getRegisterModelInput()).thenReturn(input); when(input.getFunctionName()).thenReturn(FunctionName.REMOTE); Connector connector = mock(Connector.class); - when(connector.getActionEndpoint(PREDICT.name(), any(Map.class))).thenReturn(null); + when(connector.getActionEndpoint(anyString(), any(Map.class))).thenReturn(null); when(input.getConnector()).thenReturn(connector); transportRegisterModelAction.doExecute(task, request, actionListener); ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class);