diff --git a/plugin/src/main/java/org/opensearch/ml/action/prediction/TransportPredictionTaskAction.java b/plugin/src/main/java/org/opensearch/ml/action/prediction/TransportPredictionTaskAction.java index 4a3f77a838..df0e4aa0e4 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/prediction/TransportPredictionTaskAction.java +++ b/plugin/src/main/java/org/opensearch/ml/action/prediction/TransportPredictionTaskAction.java @@ -6,6 +6,7 @@ package org.opensearch.ml.action.prediction; import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_MODEL_AUTO_DEPLOY_ENABLE; +import static org.opensearch.ml.utils.MLExceptionUtils.LOCAL_MODEL_DISABLED_ERR_MSG; import org.opensearch.OpenSearchStatusException; import org.opensearch.action.ActionRequest; @@ -29,6 +30,7 @@ import org.opensearch.ml.helper.ModelAccessControlHelper; import org.opensearch.ml.model.MLModelCacheHelper; import org.opensearch.ml.model.MLModelManager; +import org.opensearch.ml.settings.MLFeatureEnabledSetting; import org.opensearch.ml.task.MLPredictTaskRunner; import org.opensearch.ml.task.MLTaskRunner; import org.opensearch.ml.utils.RestActionUtils; @@ -58,6 +60,8 @@ public class TransportPredictionTaskAction extends HandledTransportAction { diff --git a/plugin/src/main/java/org/opensearch/ml/action/register/TransportRegisterModelAction.java b/plugin/src/main/java/org/opensearch/ml/action/register/TransportRegisterModelAction.java index e13ea03173..c3b3e9c587 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/register/TransportRegisterModelAction.java +++ b/plugin/src/main/java/org/opensearch/ml/action/register/TransportRegisterModelAction.java @@ -11,6 +11,7 @@ import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_TRUSTED_CONNECTOR_ENDPOINTS_REGEX; import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_TRUSTED_URL_REGEX; import static org.opensearch.ml.task.MLTaskManager.TASK_SEMAPHORE_TIMEOUT; +import static org.opensearch.ml.utils.MLExceptionUtils.LOCAL_MODEL_DISABLED_ERR_MSG; import static org.opensearch.ml.utils.MLExceptionUtils.logException; import java.time.Instant; @@ -56,6 +57,7 @@ import org.opensearch.ml.helper.ModelAccessControlHelper; import org.opensearch.ml.model.MLModelGroupManager; import org.opensearch.ml.model.MLModelManager; +import org.opensearch.ml.settings.MLFeatureEnabledSetting; import org.opensearch.ml.stats.MLStats; import org.opensearch.ml.task.MLTaskDispatcher; import org.opensearch.ml.task.MLTaskManager; @@ -94,6 +96,7 @@ public class TransportRegisterModelAction extends HandledTransportAction listener) { MLRegisterModelRequest registerModelRequest = MLRegisterModelRequest.fromActionRequest(request); MLRegisterModelInput registerModelInput = registerModelRequest.getRegisterModelInput(); + if (FunctionName.isDLModel(registerModelInput.getFunctionName()) && !mlFeatureEnabledSetting.isLocalModelInferenceEnabled()) { + throw new IllegalStateException(LOCAL_MODEL_DISABLED_ERR_MSG); + } if (registerModelInput.getUrl() != null && !isModelUrlAllowed) { throw new IllegalArgumentException( "To upload custom model user needs to enable allow_registering_model_via_url settings. Otherwise please use OpenSearch pre-trained models." diff --git a/plugin/src/test/java/org/opensearch/ml/action/prediction/TransportPredictionTaskActionTests.java b/plugin/src/test/java/org/opensearch/ml/action/prediction/TransportPredictionTaskActionTests.java index 461f2b834b..5aaf984441 100644 --- a/plugin/src/test/java/org/opensearch/ml/action/prediction/TransportPredictionTaskActionTests.java +++ b/plugin/src/test/java/org/opensearch/ml/action/prediction/TransportPredictionTaskActionTests.java @@ -12,6 +12,7 @@ import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_MODEL_AUTO_DEPLOY_ENABLE; +import static org.opensearch.ml.utils.MLExceptionUtils.LOCAL_MODEL_DISABLED_ERR_MSG; import java.util.Arrays; import java.util.Collections; @@ -48,6 +49,7 @@ import org.opensearch.ml.helper.ModelAccessControlHelper; import org.opensearch.ml.model.MLModelCacheHelper; import org.opensearch.ml.model.MLModelManager; +import org.opensearch.ml.settings.MLFeatureEnabledSetting; import org.opensearch.ml.task.MLPredictTaskRunner; import org.opensearch.test.OpenSearchTestCase; import org.opensearch.threadpool.ThreadPool; @@ -83,6 +85,8 @@ public class TransportPredictionTaskActionTests extends OpenSearchTestCase { @Mock private ModelAccessControlHelper modelAccessControlHelper; + @Mock + private MLFeatureEnabledSetting mlFeatureEnabledSetting; @Mock ActionFilters actionFilters; @@ -129,6 +133,7 @@ public void setup() { when(client.threadPool()).thenReturn(threadPool); when(threadPool.getThreadContext()).thenReturn(threadContext); when(clusterService.getClusterSettings()).thenReturn(clusterSettings); + when(mlFeatureEnabledSetting.isLocalModelInferenceEnabled()).thenReturn(true); transportPredictionTaskAction = spy( new TransportPredictionTaskAction( @@ -141,6 +146,7 @@ public void setup() { xContentRegistry, mlModelManager, modelAccessControlHelper, + mlFeatureEnabledSetting, settings ) ); @@ -168,6 +174,21 @@ public void testPrediction_default_exception() { assertEquals("Failed to Validate Access for ModelId test_id", argumentCaptor.getValue().getMessage()); } + public void testPrediction_local_model_not_exception() { + when(modelCacheHelper.getModelInfo(anyString())).thenReturn(model); + when(model.getAlgorithm()).thenReturn(FunctionName.TEXT_EMBEDDING); + when(mlFeatureEnabledSetting.isLocalModelInferenceEnabled()).thenReturn(false); + + IllegalStateException e = assertThrows( + IllegalStateException.class, + () -> transportPredictionTaskAction.doExecute(null, mlPredictionTaskRequest, actionListener) + ); + assertEquals( + e.getMessage(), + LOCAL_MODEL_DISABLED_ERR_MSG + ); + } + public void testPrediction_OpenSearchStatusException() { when(modelCacheHelper.getModelInfo(anyString())).thenReturn(model); when(model.getAlgorithm()).thenReturn(FunctionName.KMEANS); diff --git a/plugin/src/test/java/org/opensearch/ml/action/register/TransportRegisterModelActionTests.java b/plugin/src/test/java/org/opensearch/ml/action/register/TransportRegisterModelActionTests.java index 83ecd01069..9a619f9d10 100644 --- a/plugin/src/test/java/org/opensearch/ml/action/register/TransportRegisterModelActionTests.java +++ b/plugin/src/test/java/org/opensearch/ml/action/register/TransportRegisterModelActionTests.java @@ -17,6 +17,7 @@ import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_ALLOW_MODEL_URL; import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_TRUSTED_CONNECTOR_ENDPOINTS_REGEX; import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_TRUSTED_URL_REGEX; +import static org.opensearch.ml.utils.MLExceptionUtils.LOCAL_MODEL_DISABLED_ERR_MSG; import static org.opensearch.ml.utils.TestHelper.clusterSetting; import java.io.IOException; @@ -60,6 +61,7 @@ import org.opensearch.ml.helper.ModelAccessControlHelper; import org.opensearch.ml.model.MLModelGroupManager; import org.opensearch.ml.model.MLModelManager; +import org.opensearch.ml.settings.MLFeatureEnabledSetting; import org.opensearch.ml.stats.MLNodeLevelStat; import org.opensearch.ml.stats.MLStat; import org.opensearch.ml.stats.MLStats; @@ -150,6 +152,9 @@ public class TransportRegisterModelActionTests extends OpenSearchTestCase { @Mock private ConnectorAccessControlHelper connectorAccessControlHelper; + @Mock + private MLFeatureEnabledSetting mlFeatureEnabledSetting; + @Before public void setup() throws IOException { MockitoAnnotations.openMocks(this); @@ -184,7 +189,8 @@ public void setup() throws IOException { mlStats, modelAccessControlHelper, connectorAccessControlHelper, - mlModelGroupManager + mlModelGroupManager, + mlFeatureEnabledSetting ); assertNotNull(transportRegisterModelAction); @@ -216,6 +222,8 @@ public void setup() throws IOException { return null; }).when(mlModelGroupManager).validateUniqueModelGroupName(any(), any()); + when(mlFeatureEnabledSetting.isLocalModelInferenceEnabled()).thenReturn(true); + when(clusterService.localNode()).thenReturn(node2); when(node2.getId()).thenReturn("node2Id"); @@ -226,6 +234,42 @@ public void setup() throws IOException { when(threadPool.getThreadContext()).thenReturn(threadContext); } + public void testDoExecute_LocalModelDisabledException() { + when(mlFeatureEnabledSetting.isLocalModelInferenceEnabled()).thenReturn(false); + + MLRegisterModelInput registerModelInput = MLRegisterModelInput + .builder() + .functionName(FunctionName.TEXT_EMBEDDING) + .deployModel(true) + .modelGroupId("modelGroupID") + .modelName("Test Model") + .modelConfig( + new TextEmbeddingModelConfig( + "CUSTOM", + 123, + TextEmbeddingModelConfig.FrameworkType.SENTENCE_TRANSFORMERS, + "all config", + TextEmbeddingModelConfig.PoolingMode.MEAN, + true, + 512 + ) + ) + .modelFormat(MLModelFormat.TORCH_SCRIPT) + .url("http://test_url") + .build(); + + MLRegisterModelRequest mlRegisterModelRequest = new MLRegisterModelRequest(registerModelInput); + + IllegalStateException e = assertThrows( + IllegalStateException.class, + () -> transportRegisterModelAction.doExecute(task, mlRegisterModelRequest, actionListener) + ); + assertEquals( + e.getMessage(), + LOCAL_MODEL_DISABLED_ERR_MSG + ); + } + public void testDoExecute_userHasNoAccessException() { doAnswer(invocation -> { ActionListener listener = invocation.getArgument(3); @@ -328,7 +372,8 @@ public void testRegisterModelUrlNotAllowed() throws Exception { mlStats, modelAccessControlHelper, connectorAccessControlHelper, - mlModelGroupManager + mlModelGroupManager, + mlFeatureEnabledSetting ); IllegalArgumentException e = assertThrows(