diff --git a/plugin/src/main/java/org/opensearch/ml/action/deploy/TransportDeployModelAction.java b/plugin/src/main/java/org/opensearch/ml/action/deploy/TransportDeployModelAction.java index 3911ed5f1f..ef04453b8d 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/deploy/TransportDeployModelAction.java +++ b/plugin/src/main/java/org/opensearch/ml/action/deploy/TransportDeployModelAction.java @@ -11,6 +11,7 @@ import static org.opensearch.ml.plugin.MachineLearningPlugin.DEPLOY_THREAD_POOL; import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_ALLOW_CUSTOM_DEPLOYMENT_PLAN; import static org.opensearch.ml.task.MLTaskManager.TASK_SEMAPHORE_TIMEOUT; +import static org.opensearch.ml.utils.MLExceptionUtils.LOCAL_MODEL_DISABLED_ERR_MSG; import static org.opensearch.ml.utils.MLExceptionUtils.REMOTE_INFERENCE_DISABLED_ERR_MSG; import java.time.Instant; @@ -143,6 +144,8 @@ protected void doExecute(Task task, ActionRequest request, ActionListener { diff --git a/plugin/src/main/java/org/opensearch/ml/action/register/TransportRegisterModelAction.java b/plugin/src/main/java/org/opensearch/ml/action/register/TransportRegisterModelAction.java index e13ea03173..6c119d46d2 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/register/TransportRegisterModelAction.java +++ b/plugin/src/main/java/org/opensearch/ml/action/register/TransportRegisterModelAction.java @@ -11,6 +11,7 @@ import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_TRUSTED_CONNECTOR_ENDPOINTS_REGEX; import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_TRUSTED_URL_REGEX; import static org.opensearch.ml.task.MLTaskManager.TASK_SEMAPHORE_TIMEOUT; +import static org.opensearch.ml.utils.MLExceptionUtils.LOCAL_MODEL_DISABLED_ERR_MSG; import static org.opensearch.ml.utils.MLExceptionUtils.logException; import java.time.Instant; @@ -56,6 +57,7 @@ import org.opensearch.ml.helper.ModelAccessControlHelper; import org.opensearch.ml.model.MLModelGroupManager; import org.opensearch.ml.model.MLModelManager; +import org.opensearch.ml.settings.MLFeatureEnabledSetting; import org.opensearch.ml.stats.MLStats; import org.opensearch.ml.task.MLTaskDispatcher; import org.opensearch.ml.task.MLTaskManager; @@ -94,6 +96,7 @@ public class TransportRegisterModelAction extends HandledTransportAction listener) { MLRegisterModelRequest registerModelRequest = MLRegisterModelRequest.fromActionRequest(request); MLRegisterModelInput registerModelInput = registerModelRequest.getRegisterModelInput(); + if (FunctionName.isDLModel(registerModelInput.getFunctionName()) && !mlFeatureEnabledSetting.isLocalModelEnabled()) { + throw new IllegalStateException(LOCAL_MODEL_DISABLED_ERR_MSG); + } if (registerModelInput.getUrl() != null && !isModelUrlAllowed) { throw new IllegalArgumentException( "To upload custom model user needs to enable allow_registering_model_via_url settings. Otherwise please use OpenSearch pre-trained models." diff --git a/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java b/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java index 934f2da210..e25c762912 100644 --- a/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java +++ b/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java @@ -897,6 +897,7 @@ public List> getSettings() { MLCommonsSettings.ML_COMMONS_REMOTE_MODEL_ELIGIBLE_NODE_ROLES, MLCommonsSettings.ML_COMMONS_LOCAL_MODEL_ELIGIBLE_NODE_ROLES, MLCommonsSettings.ML_COMMONS_REMOTE_INFERENCE_ENABLED, + MLCommonsSettings.ML_COMMONS_LOCAL_MODEL_ENABLED, MLCommonsSettings.ML_COMMONS_MEMORY_FEATURE_ENABLED, MLCommonsSettings.ML_COMMONS_RAG_PIPELINE_FEATURE_ENABLED, MLCommonsSettings.ML_COMMONS_AGENT_FRAMEWORK_ENABLED, diff --git a/plugin/src/main/java/org/opensearch/ml/rest/RestMLPredictionAction.java b/plugin/src/main/java/org/opensearch/ml/rest/RestMLPredictionAction.java index 60bfe07984..5943a2b489 100644 --- a/plugin/src/main/java/org/opensearch/ml/rest/RestMLPredictionAction.java +++ b/plugin/src/main/java/org/opensearch/ml/rest/RestMLPredictionAction.java @@ -7,6 +7,7 @@ import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; import static org.opensearch.ml.plugin.MachineLearningPlugin.ML_BASE_URI; +import static org.opensearch.ml.utils.MLExceptionUtils.LOCAL_MODEL_DISABLED_ERR_MSG; import static org.opensearch.ml.utils.MLExceptionUtils.REMOTE_INFERENCE_DISABLED_ERR_MSG; import static org.opensearch.ml.utils.RestActionUtils.PARAMETER_ALGORITHM; import static org.opensearch.ml.utils.RestActionUtils.PARAMETER_MODEL_ID; @@ -121,6 +122,8 @@ public RestChannelConsumer prepareRequest(RestRequest request, NodeClient client MLPredictionTaskRequest getRequest(String modelId, String algorithm, RestRequest request) throws IOException { if (FunctionName.REMOTE.name().equals(algorithm) && !mlFeatureEnabledSetting.isRemoteInferenceEnabled()) { throw new IllegalStateException(REMOTE_INFERENCE_DISABLED_ERR_MSG); + } else if (FunctionName.isDLModel(FunctionName.from(algorithm)) && !mlFeatureEnabledSetting.isLocalModelEnabled()) { + throw new IllegalStateException(LOCAL_MODEL_DISABLED_ERR_MSG); } XContentParser parser = request.contentParser(); ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser); diff --git a/plugin/src/main/java/org/opensearch/ml/rest/RestMLRegisterModelAction.java b/plugin/src/main/java/org/opensearch/ml/rest/RestMLRegisterModelAction.java index 631462e773..68fd73b20a 100644 --- a/plugin/src/main/java/org/opensearch/ml/rest/RestMLRegisterModelAction.java +++ b/plugin/src/main/java/org/opensearch/ml/rest/RestMLRegisterModelAction.java @@ -7,6 +7,7 @@ import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; import static org.opensearch.ml.plugin.MachineLearningPlugin.ML_BASE_URI; +import static org.opensearch.ml.utils.MLExceptionUtils.LOCAL_MODEL_DISABLED_ERR_MSG; import static org.opensearch.ml.utils.MLExceptionUtils.REMOTE_INFERENCE_DISABLED_ERR_MSG; import static org.opensearch.ml.utils.RestActionUtils.PARAMETER_DEPLOY_MODEL; import static org.opensearch.ml.utils.RestActionUtils.PARAMETER_MODEL_ID; @@ -98,6 +99,8 @@ MLRegisterModelRequest getRequest(RestRequest request) throws IOException { MLRegisterModelInput mlInput = MLRegisterModelInput.parse(parser, loadModel); if (mlInput.getFunctionName() == FunctionName.REMOTE && !mlFeatureEnabledSetting.isRemoteInferenceEnabled()) { throw new IllegalStateException(REMOTE_INFERENCE_DISABLED_ERR_MSG); + } else if (FunctionName.isDLModel(mlInput.getFunctionName()) && !mlFeatureEnabledSetting.isLocalModelEnabled()) { + throw new IllegalStateException(LOCAL_MODEL_DISABLED_ERR_MSG); } return new MLRegisterModelRequest(mlInput); } diff --git a/plugin/src/main/java/org/opensearch/ml/settings/MLCommonsSettings.java b/plugin/src/main/java/org/opensearch/ml/settings/MLCommonsSettings.java index f96eb37e02..7dfc1c98ac 100644 --- a/plugin/src/main/java/org/opensearch/ml/settings/MLCommonsSettings.java +++ b/plugin/src/main/java/org/opensearch/ml/settings/MLCommonsSettings.java @@ -117,6 +117,9 @@ private MLCommonsSettings() {} public static final Setting ML_COMMONS_REMOTE_INFERENCE_ENABLED = Setting .boolSetting("plugins.ml_commons.remote_inference.enabled", true, Setting.Property.NodeScope, Setting.Property.Dynamic); + public static final Setting ML_COMMONS_LOCAL_MODEL_ENABLED = Setting + .boolSetting("plugins.ml_commons.local_model.enabled", true, Setting.Property.NodeScope, Setting.Property.Dynamic); + public static final Setting ML_COMMONS_MODEL_ACCESS_CONTROL_ENABLED = Setting .boolSetting("plugins.ml_commons.model_access_control_enabled", false, Setting.Property.NodeScope, Setting.Property.Dynamic); diff --git a/plugin/src/main/java/org/opensearch/ml/settings/MLFeatureEnabledSetting.java b/plugin/src/main/java/org/opensearch/ml/settings/MLFeatureEnabledSetting.java index ad3337aa0b..f636f33722 100644 --- a/plugin/src/main/java/org/opensearch/ml/settings/MLFeatureEnabledSetting.java +++ b/plugin/src/main/java/org/opensearch/ml/settings/MLFeatureEnabledSetting.java @@ -8,6 +8,7 @@ package org.opensearch.ml.settings; import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_AGENT_FRAMEWORK_ENABLED; +import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_LOCAL_MODEL_ENABLED; import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_REMOTE_INFERENCE_ENABLED; import org.opensearch.cluster.service.ClusterService; @@ -18,9 +19,12 @@ public class MLFeatureEnabledSetting { private volatile Boolean isRemoteInferenceEnabled; private volatile Boolean isAgentFrameworkEnabled; + private volatile Boolean isLocalModelEnabled; + public MLFeatureEnabledSetting(ClusterService clusterService, Settings settings) { isRemoteInferenceEnabled = ML_COMMONS_REMOTE_INFERENCE_ENABLED.get(settings); isAgentFrameworkEnabled = ML_COMMONS_AGENT_FRAMEWORK_ENABLED.get(settings); + isLocalModelEnabled = ML_COMMONS_LOCAL_MODEL_ENABLED.get(settings); clusterService .getClusterSettings() @@ -28,6 +32,7 @@ public MLFeatureEnabledSetting(ClusterService clusterService, Settings settings) clusterService .getClusterSettings() .addSettingsUpdateConsumer(ML_COMMONS_AGENT_FRAMEWORK_ENABLED, it -> isAgentFrameworkEnabled = it); + clusterService.getClusterSettings().addSettingsUpdateConsumer(ML_COMMONS_LOCAL_MODEL_ENABLED, it -> isLocalModelEnabled = it); } /** @@ -46,4 +51,12 @@ public boolean isAgentFrameworkEnabled() { return isAgentFrameworkEnabled; } + /** + * Whether the local model feature is enabled. If disabled, APIs in ml-commons will block local model inference. + * @return whether the local inference is enabled. + */ + public boolean isLocalModelEnabled() { + return isLocalModelEnabled; + } + } diff --git a/plugin/src/main/java/org/opensearch/ml/utils/MLExceptionUtils.java b/plugin/src/main/java/org/opensearch/ml/utils/MLExceptionUtils.java index dc29656f7c..6838a9ff79 100644 --- a/plugin/src/main/java/org/opensearch/ml/utils/MLExceptionUtils.java +++ b/plugin/src/main/java/org/opensearch/ml/utils/MLExceptionUtils.java @@ -20,6 +20,8 @@ public class MLExceptionUtils { public static final String NOT_SERIALIZABLE_EXCEPTION_WRAPPER = "NotSerializableExceptionWrapper: "; public static final String REMOTE_INFERENCE_DISABLED_ERR_MSG = "Remote Inference is currently disabled. To enable it, update the setting \"plugins.ml_commons.remote_inference_enabled\" to true."; + public static final String LOCAL_MODEL_DISABLED_ERR_MSG = + "Local Model is currently disabled. To enable it, update the setting \"plugins.ml_commons.local_model.enabled\" to true."; public static final String AGENT_FRAMEWORK_DISABLED_ERR_MSG = "Agent Framework is currently disabled. To enable it, update the setting \"plugins.ml_commons.agent_framework_enabled\" to true."; diff --git a/plugin/src/test/java/org/opensearch/ml/action/deploy/TransportDeployModelActionTests.java b/plugin/src/test/java/org/opensearch/ml/action/deploy/TransportDeployModelActionTests.java index e544f34e40..24c40a950e 100644 --- a/plugin/src/test/java/org/opensearch/ml/action/deploy/TransportDeployModelActionTests.java +++ b/plugin/src/test/java/org/opensearch/ml/action/deploy/TransportDeployModelActionTests.java @@ -22,6 +22,7 @@ import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_ALLOW_CUSTOM_DEPLOYMENT_PLAN; +import static org.opensearch.ml.utils.MLExceptionUtils.LOCAL_MODEL_DISABLED_ERR_MSG; import static org.opensearch.ml.utils.MLExceptionUtils.REMOTE_INFERENCE_DISABLED_ERR_MSG; import java.lang.reflect.Field; @@ -176,6 +177,7 @@ public void setup() { when(mlDeployModelRequest.isUserInitiatedDeployRequest()).thenReturn(true); when(mlFeatureEnabledSetting.isRemoteInferenceEnabled()).thenReturn(true); + when(mlFeatureEnabledSetting.isLocalModelEnabled()).thenReturn(true); MLStat mlStat = mock(MLStat.class); when(mlStats.getStat(eq(MLNodeLevelStat.ML_REQUEST_COUNT))).thenReturn(mlStat); @@ -374,6 +376,23 @@ public void testDoExecuteRemoteInferenceDisabled() { assertEquals(REMOTE_INFERENCE_DISABLED_ERR_MSG, argumentCaptor.getValue().getMessage()); } + public void testDoExecuteLocalInferenceDisabled() { + MLModel mlModel = mock(MLModel.class); + when(mlModel.getAlgorithm()).thenReturn(FunctionName.TEXT_EMBEDDING); + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(3); + listener.onResponse(mlModel); + return null; + }).when(mlModelManager).getModel(anyString(), isNull(), any(String[].class), Mockito.isA(ActionListener.class)); + + when(mlFeatureEnabledSetting.isLocalModelEnabled()).thenReturn(false); + ActionListener deployModelResponseListener = mock(ActionListener.class); + transportDeployModelAction.doExecute(mock(Task.class), mlDeployModelRequest, deployModelResponseListener); + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(IllegalStateException.class); + verify(deployModelResponseListener).onFailure(argumentCaptor.capture()); + assertEquals(LOCAL_MODEL_DISABLED_ERR_MSG, argumentCaptor.getValue().getMessage()); + } + public void test_ValidationFailedException() { MLModel mlModel = mock(MLModel.class); when(mlModel.getAlgorithm()).thenReturn(FunctionName.ANOMALY_LOCALIZATION); diff --git a/plugin/src/test/java/org/opensearch/ml/action/prediction/TransportPredictionTaskActionTests.java b/plugin/src/test/java/org/opensearch/ml/action/prediction/TransportPredictionTaskActionTests.java index 461f2b834b..a1832dcd62 100644 --- a/plugin/src/test/java/org/opensearch/ml/action/prediction/TransportPredictionTaskActionTests.java +++ b/plugin/src/test/java/org/opensearch/ml/action/prediction/TransportPredictionTaskActionTests.java @@ -12,6 +12,7 @@ import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_MODEL_AUTO_DEPLOY_ENABLE; +import static org.opensearch.ml.utils.MLExceptionUtils.LOCAL_MODEL_DISABLED_ERR_MSG; import java.util.Arrays; import java.util.Collections; @@ -48,6 +49,7 @@ import org.opensearch.ml.helper.ModelAccessControlHelper; import org.opensearch.ml.model.MLModelCacheHelper; import org.opensearch.ml.model.MLModelManager; +import org.opensearch.ml.settings.MLFeatureEnabledSetting; import org.opensearch.ml.task.MLPredictTaskRunner; import org.opensearch.test.OpenSearchTestCase; import org.opensearch.threadpool.ThreadPool; @@ -83,6 +85,8 @@ public class TransportPredictionTaskActionTests extends OpenSearchTestCase { @Mock private ModelAccessControlHelper modelAccessControlHelper; + @Mock + private MLFeatureEnabledSetting mlFeatureEnabledSetting; @Mock ActionFilters actionFilters; @@ -129,6 +133,7 @@ public void setup() { when(client.threadPool()).thenReturn(threadPool); when(threadPool.getThreadContext()).thenReturn(threadContext); when(clusterService.getClusterSettings()).thenReturn(clusterSettings); + when(mlFeatureEnabledSetting.isLocalModelEnabled()).thenReturn(true); transportPredictionTaskAction = spy( new TransportPredictionTaskAction( @@ -141,6 +146,7 @@ public void setup() { xContentRegistry, mlModelManager, modelAccessControlHelper, + mlFeatureEnabledSetting, settings ) ); @@ -168,6 +174,21 @@ public void testPrediction_default_exception() { assertEquals("Failed to Validate Access for ModelId test_id", argumentCaptor.getValue().getMessage()); } + public void testPrediction_local_model_not_exception() { + when(modelCacheHelper.getModelInfo(anyString())).thenReturn(model); + when(model.getAlgorithm()).thenReturn(FunctionName.TEXT_EMBEDDING); + when(mlFeatureEnabledSetting.isLocalModelEnabled()).thenReturn(false); + + IllegalStateException e = assertThrows( + IllegalStateException.class, + () -> transportPredictionTaskAction.doExecute(null, mlPredictionTaskRequest, actionListener) + ); + assertEquals( + e.getMessage(), + LOCAL_MODEL_DISABLED_ERR_MSG + ); + } + public void testPrediction_OpenSearchStatusException() { when(modelCacheHelper.getModelInfo(anyString())).thenReturn(model); when(model.getAlgorithm()).thenReturn(FunctionName.KMEANS); diff --git a/plugin/src/test/java/org/opensearch/ml/action/register/TransportRegisterModelActionTests.java b/plugin/src/test/java/org/opensearch/ml/action/register/TransportRegisterModelActionTests.java index 83ecd01069..d30ef15a5a 100644 --- a/plugin/src/test/java/org/opensearch/ml/action/register/TransportRegisterModelActionTests.java +++ b/plugin/src/test/java/org/opensearch/ml/action/register/TransportRegisterModelActionTests.java @@ -17,6 +17,7 @@ import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_ALLOW_MODEL_URL; import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_TRUSTED_CONNECTOR_ENDPOINTS_REGEX; import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_TRUSTED_URL_REGEX; +import static org.opensearch.ml.utils.MLExceptionUtils.LOCAL_MODEL_DISABLED_ERR_MSG; import static org.opensearch.ml.utils.TestHelper.clusterSetting; import java.io.IOException; @@ -60,6 +61,7 @@ import org.opensearch.ml.helper.ModelAccessControlHelper; import org.opensearch.ml.model.MLModelGroupManager; import org.opensearch.ml.model.MLModelManager; +import org.opensearch.ml.settings.MLFeatureEnabledSetting; import org.opensearch.ml.stats.MLNodeLevelStat; import org.opensearch.ml.stats.MLStat; import org.opensearch.ml.stats.MLStats; @@ -150,6 +152,9 @@ public class TransportRegisterModelActionTests extends OpenSearchTestCase { @Mock private ConnectorAccessControlHelper connectorAccessControlHelper; + @Mock + private MLFeatureEnabledSetting mlFeatureEnabledSetting; + @Before public void setup() throws IOException { MockitoAnnotations.openMocks(this); @@ -184,7 +189,8 @@ public void setup() throws IOException { mlStats, modelAccessControlHelper, connectorAccessControlHelper, - mlModelGroupManager + mlModelGroupManager, + mlFeatureEnabledSetting ); assertNotNull(transportRegisterModelAction); @@ -216,6 +222,8 @@ public void setup() throws IOException { return null; }).when(mlModelGroupManager).validateUniqueModelGroupName(any(), any()); + when(mlFeatureEnabledSetting.isLocalModelEnabled()).thenReturn(true); + when(clusterService.localNode()).thenReturn(node2); when(node2.getId()).thenReturn("node2Id"); @@ -226,6 +234,42 @@ public void setup() throws IOException { when(threadPool.getThreadContext()).thenReturn(threadContext); } + public void testDoExecute_LocalModelDisabledException() { + when(mlFeatureEnabledSetting.isLocalModelEnabled()).thenReturn(false); + + MLRegisterModelInput registerModelInput = MLRegisterModelInput + .builder() + .functionName(FunctionName.TEXT_EMBEDDING) + .deployModel(true) + .modelGroupId("modelGroupID") + .modelName("Test Model") + .modelConfig( + new TextEmbeddingModelConfig( + "CUSTOM", + 123, + TextEmbeddingModelConfig.FrameworkType.SENTENCE_TRANSFORMERS, + "all config", + TextEmbeddingModelConfig.PoolingMode.MEAN, + true, + 512 + ) + ) + .modelFormat(MLModelFormat.TORCH_SCRIPT) + .url("http://test_url") + .build(); + + MLRegisterModelRequest mlRegisterModelRequest = new MLRegisterModelRequest(registerModelInput); + + IllegalStateException e = assertThrows( + IllegalStateException.class, + () -> transportRegisterModelAction.doExecute(task, mlRegisterModelRequest, actionListener) + ); + assertEquals( + e.getMessage(), + LOCAL_MODEL_DISABLED_ERR_MSG + ); + } + public void testDoExecute_userHasNoAccessException() { doAnswer(invocation -> { ActionListener listener = invocation.getArgument(3); @@ -328,7 +372,8 @@ public void testRegisterModelUrlNotAllowed() throws Exception { mlStats, modelAccessControlHelper, connectorAccessControlHelper, - mlModelGroupManager + mlModelGroupManager, + mlFeatureEnabledSetting ); IllegalArgumentException e = assertThrows( diff --git a/plugin/src/test/java/org/opensearch/ml/rest/RestMLPredictionActionTests.java b/plugin/src/test/java/org/opensearch/ml/rest/RestMLPredictionActionTests.java index ceeda75277..d34e0fd00e 100644 --- a/plugin/src/test/java/org/opensearch/ml/rest/RestMLPredictionActionTests.java +++ b/plugin/src/test/java/org/opensearch/ml/rest/RestMLPredictionActionTests.java @@ -8,6 +8,7 @@ import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.eq; import static org.mockito.Mockito.*; +import static org.opensearch.ml.utils.MLExceptionUtils.LOCAL_MODEL_DISABLED_ERR_MSG; import static org.opensearch.ml.utils.MLExceptionUtils.REMOTE_INFERENCE_DISABLED_ERR_MSG; import static org.opensearch.ml.utils.RestActionUtils.PARAMETER_MODEL_ID; import static org.opensearch.ml.utils.TestHelper.getKMeansRestRequest; @@ -67,6 +68,7 @@ public void setup() { MockitoAnnotations.openMocks(this); when(modelManager.getOptionalModelFunctionName(anyString())).thenReturn(Optional.empty()); when(mlFeatureEnabledSetting.isRemoteInferenceEnabled()).thenReturn(true); + when(mlFeatureEnabledSetting.isLocalModelEnabled()).thenReturn(true); restMLPredictionAction = new RestMLPredictionAction(modelManager, mlFeatureEnabledSetting); threadPool = new TestThreadPool(this.getClass().getSimpleName() + "ThreadPool"); @@ -122,6 +124,16 @@ public void testGetRequest_RemoteInferenceDisabled() throws IOException { MLPredictionTaskRequest mlPredictionTaskRequest = restMLPredictionAction.getRequest("modelId", FunctionName.REMOTE.name(), request); } + public void testGetRequest_LocalModelInferenceDisabled() throws IOException { + thrown.expect(IllegalStateException.class); + thrown.expectMessage(LOCAL_MODEL_DISABLED_ERR_MSG); + + when(mlFeatureEnabledSetting.isLocalModelEnabled()).thenReturn(false); + RestRequest request = getRestRequest_PredictModel(); + MLPredictionTaskRequest mlPredictionTaskRequest = restMLPredictionAction + .getRequest("modelId", FunctionName.TEXT_EMBEDDING.name(), request); + } + public void testPrepareRequest() throws Exception { RestRequest request = getRestRequest_PredictModel(); restMLPredictionAction.handleRequest(request, channel, client); diff --git a/plugin/src/test/java/org/opensearch/ml/rest/RestMLRegisterModelActionTests.java b/plugin/src/test/java/org/opensearch/ml/rest/RestMLRegisterModelActionTests.java index 2ade96cab2..70f2d642ae 100644 --- a/plugin/src/test/java/org/opensearch/ml/rest/RestMLRegisterModelActionTests.java +++ b/plugin/src/test/java/org/opensearch/ml/rest/RestMLRegisterModelActionTests.java @@ -13,6 +13,7 @@ import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_ALLOW_MODEL_URL; +import static org.opensearch.ml.utils.MLExceptionUtils.LOCAL_MODEL_DISABLED_ERR_MSG; import static org.opensearch.ml.utils.MLExceptionUtils.REMOTE_INFERENCE_DISABLED_ERR_MSG; import static org.opensearch.ml.utils.TestHelper.clusterSetting; @@ -79,6 +80,7 @@ public void setup() { when(clusterService.getClusterSettings()).thenReturn(clusterSettings); when(clusterService.getSettings()).thenReturn(settings); when(mlFeatureEnabledSetting.isRemoteInferenceEnabled()).thenReturn(true); + when(mlFeatureEnabledSetting.isLocalModelEnabled()).thenReturn(true); restMLRegisterModelAction = new RestMLRegisterModelAction(clusterService, settings, mlFeatureEnabledSetting); threadPool = new TestThreadPool(this.getClass().getSimpleName() + "ThreadPool"); client = spy(new NodeClient(Settings.EMPTY, threadPool)); @@ -159,7 +161,16 @@ public void testRegisterModelRequestRemoteInferenceDisabled() throws Exception { exceptionRule.expectMessage(REMOTE_INFERENCE_DISABLED_ERR_MSG); when(mlFeatureEnabledSetting.isRemoteInferenceEnabled()).thenReturn(false); - RestRequest request = getRestRequestWithNullModelId(); + RestRequest request = getRestRequestWithNullModelId(FunctionName.REMOTE); + restMLRegisterModelAction.handleRequest(request, channel, client); + } + + public void testRegisterModelRequestLocalInferenceDisabled() throws Exception { + exceptionRule.expect(IllegalStateException.class); + exceptionRule.expectMessage(LOCAL_MODEL_DISABLED_ERR_MSG); + + when(mlFeatureEnabledSetting.isLocalModelEnabled()).thenReturn(false); + RestRequest request = getRestRequestWithNullModelId(FunctionName.TEXT_EMBEDDING); restMLRegisterModelAction.handleRequest(request, channel, client); } @@ -189,7 +200,7 @@ public void testRegisterModelRequestWithNullUrl() throws Exception { } public void testRegisterModelRequestWithNullModelID() throws Exception { - RestRequest request = getRestRequestWithNullModelId(); + RestRequest request = getRestRequestWithNullModelId(FunctionName.REMOTE); restMLRegisterModelAction.handleRequest(request, channel, client); ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(MLRegisterModelRequest.class); verify(client, times(1)).execute(eq(MLRegisterModelAction.INSTANCE), argumentCaptor.capture(), any()); @@ -271,7 +282,7 @@ private RestRequest getRestRequestAsymmetricModel() { return request; } - private RestRequest getRestRequestWithNullModelId() { + private RestRequest getRestRequestWithNullModelId(FunctionName functionName) { RestRequest.Method method = RestRequest.Method.POST; final Map modelConfig = Map .of("model_type", "bert", "embedding_dimension", 384, "framework_type", "sentence_transformers", "all_config", "All Config"); @@ -290,7 +301,7 @@ private RestRequest getRestRequestWithNullModelId() { "model_config", modelConfig, "function_name", - FunctionName.REMOTE + functionName ); String requestContent = new Gson().toJson(model).toString(); RestRequest request = new FakeRestRequest.Builder(NamedXContentRegistry.EMPTY)