diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/MLEngineTest.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/MLEngineTest.java index 85951820e4..9d4710305f 100644 --- a/ml-algorithms/src/test/java/org/opensearch/ml/engine/MLEngineTest.java +++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/MLEngineTest.java @@ -11,6 +11,7 @@ import org.junit.Rule; import org.junit.Test; import org.junit.rules.ExpectedException; +import org.mockito.Mock; import org.mockito.MockedStatic; import org.mockito.Mockito; import org.opensearch.core.common.io.stream.StreamOutput; @@ -41,6 +42,7 @@ import java.util.UUID; import static org.junit.Assert.assertEquals; +import static org.mockito.Mockito.when; import static org.opensearch.ml.engine.helper.LinearRegressionHelper.constructLinearRegressionPredictionDataFrame; import static org.opensearch.ml.engine.helper.LinearRegressionHelper.constructLinearRegressionTrainDataFrame; import static org.opensearch.ml.engine.helper.MLTestHelper.constructTestDataFrame; @@ -68,6 +70,24 @@ public void testPrebuiltModelPath() { assertEquals("https://artifacts.opensearch.org/models/ml-models/huggingface/sentence-transformers/msmarco-distilbert-base-tas-b/1.0.1/torch_script/config.json", prebuiltModelConfigPath); } + @Test + public void testGetDeployModelZipPath() { + String modelId = "test_id"; + String modelName = "huggingface/sentence-transformers/msmarco-distilbert-base-tas-b"; + String modelZipPath = mlEngine.getDeployModelZipPath(modelId, modelName); + assertEquals(mlEngine.getMlCachePath() + Path.of("/models_cache/deploy/test_id/huggingface/sentence-transformers/msmarco-distilbert-base-tas-b.zip").toString(), modelZipPath); + } + + @Test + public void testGetDeployModelChunkPath() { + String modelId = "test_id"; + for (int i = 1; i <= 10; i++) { + Integer chunkNum = i; + Path chunkPath = mlEngine.getDeployModelChunkPath(modelId, chunkNum); + assertEquals(Path.of(mlEngine.getMlCachePath().toString() + "/models_cache/deploy/test_id/chunks/" + chunkNum.toString()), chunkPath); + } + } + @Test public void predictKMeans() { MLModel model = trainKMeansModel(); @@ -142,6 +162,33 @@ public void train_NullInput() { } } + @Test + public void testTrainNullTrainable() { + exceptionRule.expect(IllegalArgumentException.class); + MLInput mlInput = Mockito.mock(MLInput.class); + when(mlInput.getAlgorithm()).thenReturn(FunctionName.LINEAR_REGRESSION); + when(MLEngineClassLoader.initInstance(mlInput.getAlgorithm(), mlInput.getParameters(), MLAlgoParams.class)).thenReturn(null); + mlEngine.train(mlInput); + } + + @Test + public void predictNullPredictable() { + exceptionRule.expect(IllegalArgumentException.class); + MLInput mlInput = Mockito.mock(MLInput.class); + MLModel mlModel = Mockito.mock(MLModel.class); + when(mlInput.getAlgorithm()).thenReturn(FunctionName.LINEAR_REGRESSION); + when(MLEngineClassLoader.initInstance(mlInput.getAlgorithm(), mlInput.getParameters(), MLAlgoParams.class)).thenReturn(null); + mlEngine.predict(mlInput, mlModel); + } + + @Test + public void trainAndPredictNullTrainable() { + exceptionRule.expect(IllegalArgumentException.class); + MLInput mlInput = Mockito.mock(MLInput.class); + when(mlInput.getAlgorithm()).thenReturn(FunctionName.LINEAR_REGRESSION); + when(MLEngineClassLoader.initInstance(mlInput.getAlgorithm(), mlInput.getParameters(), MLAlgoParams.class)).thenReturn(null); + mlEngine.trainAndPredict(mlInput); + } //TODO: fix mockito error @Ignore @Test