Skip to content

Commit

Permalink
updated testcases for MLEngine.java
Browse files Browse the repository at this point in the history
Signed-off-by: Divit Rawal <[email protected]>
  • Loading branch information
divitr committed Oct 23, 2023
1 parent e9e3834 commit 1d667fd
Showing 1 changed file with 46 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.ExpectedException;
import org.mockito.Mock;
import org.mockito.MockedStatic;
import org.mockito.Mockito;
import org.opensearch.core.common.io.stream.StreamOutput;
Expand Down Expand Up @@ -41,6 +42,7 @@
import java.util.UUID;

import static org.junit.Assert.assertEquals;
import static org.mockito.Mockito.when;
import static org.opensearch.ml.engine.helper.LinearRegressionHelper.constructLinearRegressionPredictionDataFrame;
import static org.opensearch.ml.engine.helper.LinearRegressionHelper.constructLinearRegressionTrainDataFrame;
import static org.opensearch.ml.engine.helper.MLTestHelper.constructTestDataFrame;
Expand All @@ -54,7 +56,7 @@ public class MLEngineTest {
@Before
public void setUp() {
Encryptor encryptor = new EncryptorImpl("m+dWmfmnNRiNlOdej/QelEkvMTyH//frS2TBeS2BP4w=");
mlEngine = new MLEngine(Path.of("/tmp/test" + UUID.randomUUID()), encryptor);
MLEngine mlEngine = new MLEngine(Path.of("/tmp/test" + UUID.randomUUID()), encryptor);
}

@Test
Expand All @@ -68,6 +70,22 @@ public void testPrebuiltModelPath() {
assertEquals("https://artifacts.opensearch.org/models/ml-models/huggingface/sentence-transformers/msmarco-distilbert-base-tas-b/1.0.1/torch_script/config.json", prebuiltModelConfigPath);
}

@Test
public void testDeployModelZipPath() {
String modelId = "test_id";
String modelName = "huggingface/sentence-transformers/msmarco-distilbert-base-tas-b";
String modelZipPath = mlEngine.getDeployModelZipPath(modelId, modelName);
assertEquals(mlEngine.getMlCachePath() + "/models_cache/deploy/test_id/huggingface/sentence-transformers/msmarco-distilbert-base-tas-b.zip", modelZipPath);
}

@Test
public void testGetDeployModelChunkPath() {
String modelId = "test_id";
Integer chunkNum = 1;
Path chunkPath = mlEngine.getDeployModelChunkPath(modelId, chunkNum);
assertEquals(Path.of(mlEngine.getMlCachePath().toString() + "/models_cache/deploy/test_id/chunks/1"), chunkPath);
}

@Test
public void predictKMeans() {
MLModel model = trainKMeansModel();
Expand Down Expand Up @@ -142,6 +160,33 @@ public void train_NullInput() {
}
}

@Test
public void train_NullTrainable() {
exceptionRule.expect(IllegalArgumentException.class);
MLInput mlInput = Mockito.mock(MLInput.class);
when(mlInput.getAlgorithm()).thenReturn(FunctionName.LINEAR_REGRESSION);
when(MLEngineClassLoader.initInstance(mlInput.getAlgorithm(), mlInput.getParameters(), MLAlgoParams.class)).thenReturn(null);
mlEngine.train(mlInput);
}

@Test
public void predict_NullTrainAndPredictable() {
exceptionRule.expect(IllegalArgumentException.class);
MLInput mlInput = Mockito.mock(MLInput.class);
MLModel mlModel = Mockito.mock(MLModel.class);
when(mlInput.getAlgorithm()).thenReturn(FunctionName.LINEAR_REGRESSION);
when(MLEngineClassLoader.initInstance(mlInput.getAlgorithm(), mlInput.getParameters(), MLAlgoParams.class)).thenReturn(null);
mlEngine.predict(mlInput, mlModel);
}

@Test
public void trainAndPredict_NullTrainable() {
exceptionRule.expect(IllegalArgumentException.class);
MLInput mlInput = Mockito.mock(MLInput.class);
when(mlInput.getAlgorithm()).thenReturn(FunctionName.LINEAR_REGRESSION);
when(MLEngineClassLoader.initInstance(mlInput.getAlgorithm(), mlInput.getParameters(), MLAlgoParams.class)).thenReturn(null);
mlEngine.trainAndPredict(mlInput);
}
//TODO: fix mockito error
@Ignore
@Test
Expand Down

0 comments on commit 1d667fd

Please sign in to comment.