From 1d667fd4a6ba741afd09106d9d9100db956b045e Mon Sep 17 00:00:00 2001 From: Divit Rawal Date: Sun, 22 Oct 2023 18:15:53 -0700 Subject: [PATCH 1/4] updated testcases for MLEngine.java Signed-off-by: Divit Rawal --- .../opensearch/ml/engine/MLEngineTest.java | 47 ++++++++++++++++++- 1 file changed, 46 insertions(+), 1 deletion(-) diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/MLEngineTest.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/MLEngineTest.java index 85951820e4..5e62502bb5 100644 --- a/ml-algorithms/src/test/java/org/opensearch/ml/engine/MLEngineTest.java +++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/MLEngineTest.java @@ -11,6 +11,7 @@ import org.junit.Rule; import org.junit.Test; import org.junit.rules.ExpectedException; +import org.mockito.Mock; import org.mockito.MockedStatic; import org.mockito.Mockito; import org.opensearch.core.common.io.stream.StreamOutput; @@ -41,6 +42,7 @@ import java.util.UUID; import static org.junit.Assert.assertEquals; +import static org.mockito.Mockito.when; import static org.opensearch.ml.engine.helper.LinearRegressionHelper.constructLinearRegressionPredictionDataFrame; import static org.opensearch.ml.engine.helper.LinearRegressionHelper.constructLinearRegressionTrainDataFrame; import static org.opensearch.ml.engine.helper.MLTestHelper.constructTestDataFrame; @@ -54,7 +56,7 @@ public class MLEngineTest { @Before public void setUp() { Encryptor encryptor = new EncryptorImpl("m+dWmfmnNRiNlOdej/QelEkvMTyH//frS2TBeS2BP4w="); - mlEngine = new MLEngine(Path.of("/tmp/test" + UUID.randomUUID()), encryptor); + MLEngine mlEngine = new MLEngine(Path.of("/tmp/test" + UUID.randomUUID()), encryptor); } @Test @@ -68,6 +70,22 @@ public void testPrebuiltModelPath() { assertEquals("https://artifacts.opensearch.org/models/ml-models/huggingface/sentence-transformers/msmarco-distilbert-base-tas-b/1.0.1/torch_script/config.json", prebuiltModelConfigPath); } + @Test + public void testDeployModelZipPath() { + String modelId = "test_id"; + String modelName = "huggingface/sentence-transformers/msmarco-distilbert-base-tas-b"; + String modelZipPath = mlEngine.getDeployModelZipPath(modelId, modelName); + assertEquals(mlEngine.getMlCachePath() + "/models_cache/deploy/test_id/huggingface/sentence-transformers/msmarco-distilbert-base-tas-b.zip", modelZipPath); + } + + @Test + public void testGetDeployModelChunkPath() { + String modelId = "test_id"; + Integer chunkNum = 1; + Path chunkPath = mlEngine.getDeployModelChunkPath(modelId, chunkNum); + assertEquals(Path.of(mlEngine.getMlCachePath().toString() + "/models_cache/deploy/test_id/chunks/1"), chunkPath); + } + @Test public void predictKMeans() { MLModel model = trainKMeansModel(); @@ -142,6 +160,33 @@ public void train_NullInput() { } } + @Test + public void train_NullTrainable() { + exceptionRule.expect(IllegalArgumentException.class); + MLInput mlInput = Mockito.mock(MLInput.class); + when(mlInput.getAlgorithm()).thenReturn(FunctionName.LINEAR_REGRESSION); + when(MLEngineClassLoader.initInstance(mlInput.getAlgorithm(), mlInput.getParameters(), MLAlgoParams.class)).thenReturn(null); + mlEngine.train(mlInput); + } + + @Test + public void predict_NullTrainAndPredictable() { + exceptionRule.expect(IllegalArgumentException.class); + MLInput mlInput = Mockito.mock(MLInput.class); + MLModel mlModel = Mockito.mock(MLModel.class); + when(mlInput.getAlgorithm()).thenReturn(FunctionName.LINEAR_REGRESSION); + when(MLEngineClassLoader.initInstance(mlInput.getAlgorithm(), mlInput.getParameters(), MLAlgoParams.class)).thenReturn(null); + mlEngine.predict(mlInput, mlModel); + } + + @Test + public void trainAndPredict_NullTrainable() { + exceptionRule.expect(IllegalArgumentException.class); + MLInput mlInput = Mockito.mock(MLInput.class); + when(mlInput.getAlgorithm()).thenReturn(FunctionName.LINEAR_REGRESSION); + when(MLEngineClassLoader.initInstance(mlInput.getAlgorithm(), mlInput.getParameters(), MLAlgoParams.class)).thenReturn(null); + mlEngine.trainAndPredict(mlInput); + } //TODO: fix mockito error @Ignore @Test From 73b85458386cc128be84b69e5586f21c6eef6407 Mon Sep 17 00:00:00 2001 From: Divit Rawal Date: Mon, 23 Oct 2023 16:13:31 -0700 Subject: [PATCH 2/4] updated setUp Signed-off-by: Divit Rawal --- .../src/test/java/org/opensearch/ml/engine/MLEngineTest.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/MLEngineTest.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/MLEngineTest.java index 5e62502bb5..04bca84bdb 100644 --- a/ml-algorithms/src/test/java/org/opensearch/ml/engine/MLEngineTest.java +++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/MLEngineTest.java @@ -56,7 +56,7 @@ public class MLEngineTest { @Before public void setUp() { Encryptor encryptor = new EncryptorImpl("m+dWmfmnNRiNlOdej/QelEkvMTyH//frS2TBeS2BP4w="); - MLEngine mlEngine = new MLEngine(Path.of("/tmp/test" + UUID.randomUUID()), encryptor); + mlEngine = new MLEngine(Path.of("/tmp/test" + UUID.randomUUID()), encryptor); } @Test From d2b7758567819d2cf5dafc62ddddb97752576b45 Mon Sep 17 00:00:00 2001 From: Divit Rawal Date: Sun, 29 Oct 2023 13:50:12 -0700 Subject: [PATCH 3/4] updated chunkNum to loop Signed-off-by: Divit Rawal --- .../org/opensearch/ml/engine/MLEngineTest.java | 16 +++++++++------- 1 file changed, 9 insertions(+), 7 deletions(-) diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/MLEngineTest.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/MLEngineTest.java index 04bca84bdb..a15c7796d9 100644 --- a/ml-algorithms/src/test/java/org/opensearch/ml/engine/MLEngineTest.java +++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/MLEngineTest.java @@ -71,19 +71,21 @@ public void testPrebuiltModelPath() { } @Test - public void testDeployModelZipPath() { + public void testGetDeployModelZipPath() { String modelId = "test_id"; String modelName = "huggingface/sentence-transformers/msmarco-distilbert-base-tas-b"; String modelZipPath = mlEngine.getDeployModelZipPath(modelId, modelName); - assertEquals(mlEngine.getMlCachePath() + "/models_cache/deploy/test_id/huggingface/sentence-transformers/msmarco-distilbert-base-tas-b.zip", modelZipPath); + assertEquals(mlEngine.getMlCachePath() + Path.of("/models_cache/deploy/test_id/huggingface/sentence-transformers/msmarco-distilbert-base-tas-b.zip").toString(), modelZipPath); } @Test public void testGetDeployModelChunkPath() { String modelId = "test_id"; - Integer chunkNum = 1; - Path chunkPath = mlEngine.getDeployModelChunkPath(modelId, chunkNum); - assertEquals(Path.of(mlEngine.getMlCachePath().toString() + "/models_cache/deploy/test_id/chunks/1"), chunkPath); + for (int i = 1; i <= 10; i++) { + Integer chunkNum = i; + Path chunkPath = mlEngine.getDeployModelChunkPath(modelId, chunkNum); + assertEquals(Path.of(mlEngine.getMlCachePath().toString() + "/models_cache/deploy/test_id/chunks/" + chunkNum.toString()), chunkPath); + } } @Test @@ -170,7 +172,7 @@ public void train_NullTrainable() { } @Test - public void predict_NullTrainAndPredictable() { + public void predictNullPredictable() { exceptionRule.expect(IllegalArgumentException.class); MLInput mlInput = Mockito.mock(MLInput.class); MLModel mlModel = Mockito.mock(MLModel.class); @@ -180,7 +182,7 @@ public void predict_NullTrainAndPredictable() { } @Test - public void trainAndPredict_NullTrainable() { + public void trainAndPredictNullTrainable() { exceptionRule.expect(IllegalArgumentException.class); MLInput mlInput = Mockito.mock(MLInput.class); when(mlInput.getAlgorithm()).thenReturn(FunctionName.LINEAR_REGRESSION); From ec647c5add353e509f66b311248127c0ebcfa282 Mon Sep 17 00:00:00 2001 From: Divit Rawal Date: Sun, 29 Oct 2023 13:53:00 -0700 Subject: [PATCH 4/4] updated testcase name Signed-off-by: Divit Rawal --- .../src/test/java/org/opensearch/ml/engine/MLEngineTest.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/MLEngineTest.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/MLEngineTest.java index a15c7796d9..9d4710305f 100644 --- a/ml-algorithms/src/test/java/org/opensearch/ml/engine/MLEngineTest.java +++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/MLEngineTest.java @@ -163,7 +163,7 @@ public void train_NullInput() { } @Test - public void train_NullTrainable() { + public void testTrainNullTrainable() { exceptionRule.expect(IllegalArgumentException.class); MLInput mlInput = Mockito.mock(MLInput.class); when(mlInput.getAlgorithm()).thenReturn(FunctionName.LINEAR_REGRESSION);