Skip to content

Commit

Permalink
Fix failure UTs and add more UTs
Browse files Browse the repository at this point in the history
Signed-off-by: zane-neo <[email protected]>
  • Loading branch information
zane-neo committed Feb 4, 2024
1 parent 726dd25 commit 4b2f28b
Show file tree
Hide file tree
Showing 9 changed files with 35 additions and 34 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@

import org.opensearch.core.action.ActionListener;
import org.opensearch.ml.common.MLModel;
import org.opensearch.ml.common.MLTask;
import org.opensearch.ml.common.input.MLInput;
import org.opensearch.ml.common.output.MLOutput;
import org.opensearch.ml.common.transport.MLTaskResponse;
Expand Down Expand Up @@ -38,8 +37,8 @@ default MLOutput predict(MLInput mlInput) {
throw new IllegalStateException("Method is not implemented");
}

default void predict(MLInput mlInput, MLTask mlTask, ActionListener<MLTaskResponse> actionListener) {

default void asyncPredict(MLInput mlInput, ActionListener<MLTaskResponse> actionListener) {
throw new IllegalStateException("Method is not implemented");
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,6 @@
import java.util.List;
import java.util.Map;
import java.util.concurrent.CompletableFuture;
import java.util.regex.Matcher;
import java.util.regex.Pattern;

import org.opensearch.client.Client;
import org.opensearch.common.util.TokenBucket;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
import org.opensearch.core.xcontent.NamedXContentRegistry;
import org.opensearch.ml.common.FunctionName;
import org.opensearch.ml.common.MLModel;
import org.opensearch.ml.common.MLTask;
import org.opensearch.ml.common.connector.Connector;
import org.opensearch.ml.common.exception.MLException;
import org.opensearch.ml.common.input.MLInput;
Expand Down Expand Up @@ -56,12 +55,13 @@ public MLOutput predict(MLInput mlInput, MLModel model) {
}

@Override
public void predict(MLInput mlInput, MLTask mlTask, ActionListener<MLTaskResponse> actionListener) {
public void asyncPredict(MLInput mlInput, ActionListener<MLTaskResponse> actionListener) {
if (!isModelReady()) {
actionListener
.onFailure(
new IllegalArgumentException("Model not ready yet. Please run this first: POST /_plugins/_ml/models/<model_id>/_deploy")
);
return;
}
try {
connectorExecutor.executePredict(mlInput, actionListener);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,9 @@
import java.util.Locale;
import java.util.Optional;

import lombok.extern.log4j.Log4j2;
import org.apache.commons.lang3.math.NumberUtils;

import lombok.extern.log4j.Log4j2;
import software.amazon.awssdk.http.async.SdkAsyncHttpClient;
import software.amazon.awssdk.http.nio.netty.NettyNioAsyncHttpClient;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -131,12 +131,12 @@ public void executePredict_RemoteInferenceInput_invalidIp() {
}

@Test
public void executePredict_RemoteInferenceInput_unformattedIp() {
public void executePredict_RemoteInferenceInput_EmptyIpAddress() {
ConnectorAction predictAction = ConnectorAction
.builder()
.actionType(ConnectorAction.ActionType.PREDICT)
.method("POST")
.url("http://0177.1/mock")
.url("http:///mock")
.requestBody("{\"input\": \"${parameters.input}\"}")
.build();
Map<String, String> credential = ImmutableMap
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,11 @@

package org.opensearch.ml.engine.algorithms.remote;

import static org.junit.Assert.assertEquals;

import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.CountDownLatch;

import org.junit.Before;
Expand All @@ -21,7 +22,6 @@
import org.mockito.MockitoAnnotations;
import org.opensearch.core.action.ActionListener;
import org.opensearch.ml.common.FunctionName;
import org.opensearch.ml.common.connector.AwsConnector;
import org.opensearch.ml.common.connector.Connector;
import org.opensearch.ml.common.connector.ConnectorAction;
import org.opensearch.ml.common.connector.HttpConnector;
Expand All @@ -32,13 +32,6 @@

import com.google.common.collect.ImmutableMap;

import static org.junit.Assert.assertEquals;
import static org.mockito.Mockito.spy;
import static org.opensearch.ml.common.connector.AbstractConnector.ACCESS_KEY_FIELD;
import static org.opensearch.ml.common.connector.AbstractConnector.SECRET_KEY_FIELD;
import static org.opensearch.ml.common.connector.HttpConnector.REGION_FIELD;
import static org.opensearch.ml.common.connector.HttpConnector.SERVICE_NAME_FIELD;

public class HttpJsonConnectorExecutorTest {
@Rule
public ExpectedException exceptionRule = ExpectedException.none();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,12 @@

package org.opensearch.ml.engine.algorithms.remote;

import static org.junit.Assert.assertEquals;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.doThrow;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.spy;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;

import java.util.Arrays;
Expand All @@ -18,14 +21,17 @@
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.ExpectedException;
import org.mockito.ArgumentCaptor;
import org.mockito.Mock;
import org.mockito.MockitoAnnotations;
import org.opensearch.core.action.ActionListener;
import org.opensearch.ml.common.MLModel;
import org.opensearch.ml.common.connector.Connector;
import org.opensearch.ml.common.connector.ConnectorAction;
import org.opensearch.ml.common.connector.ConnectorProtocols;
import org.opensearch.ml.common.connector.HttpConnector;
import org.opensearch.ml.common.input.MLInput;
import org.opensearch.ml.common.transport.MLTaskResponse;
import org.opensearch.ml.engine.encryptor.Encryptor;
import org.opensearch.ml.engine.encryptor.EncryptorImpl;

Expand Down Expand Up @@ -61,19 +67,28 @@ public void predict_ModelNotDeployed() {

@Test
public void predict_NullConnectorExecutor() {
exceptionRule.expect(RuntimeException.class);
exceptionRule.expectMessage("Model not ready yet");
remoteModel.predict(mlInput);
ActionListener<MLTaskResponse> actionListener = mock(ActionListener.class);
remoteModel.asyncPredict(mlInput, actionListener);
ArgumentCaptor<Exception> argumentCaptor = ArgumentCaptor.forClass(Exception.class);
verify(actionListener).onFailure(argumentCaptor.capture());
assert argumentCaptor.getValue() instanceof RuntimeException;
assertEquals(
"Model not ready yet. Please run this first: POST /_plugins/_ml/models/<model_id>/_deploy",
argumentCaptor.getValue().getMessage()
);
}

@Test
public void predict_ModelDeployed_WrongInput() {
exceptionRule.expect(RuntimeException.class);
exceptionRule.expectMessage("pre_process_function not defined in connector");
Connector connector = createConnector(ImmutableMap.of("Authorization", "Bearer ${credential.key}"));
when(mlModel.getConnector()).thenReturn(connector);
remoteModel.initModel(mlModel, ImmutableMap.of(), encryptor);
remoteModel.predict(mlInput);
ActionListener<MLTaskResponse> actionListener = mock(ActionListener.class);
remoteModel.asyncPredict(mlInput, actionListener);
ArgumentCaptor<Exception> argumentCaptor = ArgumentCaptor.forClass(Exception.class);
verify(actionListener).onFailure(argumentCaptor.capture());
assert argumentCaptor.getValue() instanceof RuntimeException;
assertEquals("pre_process_function not defined in connector", argumentCaptor.getValue().getMessage());
}

@Test
Expand Down Expand Up @@ -105,8 +120,8 @@ public void initModel_WithHeader() {
Assert.assertNotNull(executor);
Assert.assertNull(decryptedHeaders);
Assert.assertNotNull(executor.getConnector().getDecryptedHeaders());
Assert.assertEquals(1, executor.getConnector().getDecryptedHeaders().size());
Assert.assertEquals("Bearer test_api_key", executor.getConnector().getDecryptedHeaders().get("Authorization"));
assertEquals(1, executor.getConnector().getDecryptedHeaders().size());
assertEquals("Bearer test_api_key", executor.getConnector().getDecryptedHeaders().get("Authorization"));

remoteModel.close();
Assert.assertNull(remoteModel.getConnectorExecutor());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,24 +7,19 @@

import static org.junit.Assert.assertNotNull;
import static org.junit.Assert.assertNull;
import static software.amazon.awssdk.http.SdkHttpMethod.POST;

import java.net.URI;
import java.net.UnknownHostException;
import java.nio.charset.Charset;
import java.util.Arrays;
import java.util.Map;

import com.google.common.collect.ImmutableMap;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.ExpectedException;

import org.opensearch.ml.common.connector.Connector;
import org.opensearch.ml.common.connector.ConnectorAction;
import org.opensearch.ml.common.connector.HttpConnector;
import org.opensearch.ml.engine.algorithms.remote.ConnectorUtils;
import software.amazon.awssdk.core.sync.RequestBody;

import software.amazon.awssdk.http.SdkHttpFullRequest;
import software.amazon.awssdk.http.SdkHttpMethod;
import software.amazon.awssdk.http.async.SdkAsyncHttpClient;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -225,7 +225,7 @@ private void predict(String modelId, MLTask mlTask, MLInput mlInput, ActionListe
mlModelManager.trackPredictDuration(modelId, startTime);
internalListener.onResponse(output);
}, internalListener::onFailure);
predictor.predict(mlInput, mlTask, trackPredictDurationListener);
predictor.asyncPredict(mlInput, trackPredictDurationListener);
} else {
MLOutput output = mlModelManager.trackPredictDuration(modelId, () -> predictor.predict(mlInput));
if (output instanceof MLPredictionOutput) {
Expand Down

0 comments on commit 4b2f28b

Please sign in to comment.