Skip to content

Commit

Permalink
add local model enabling/disabling setting (#2232)
Browse files Browse the repository at this point in the history
add local model enabling setting

Signed-off-by: Bhavana Ramaram <[email protected]>
(cherry picked from commit b27673b)
  • Loading branch information
rbhavna authored and github-actions[bot] committed Mar 20, 2024
1 parent bda2995 commit 083e84a
Show file tree
Hide file tree
Showing 14 changed files with 160 additions and 7 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import static org.opensearch.ml.plugin.MachineLearningPlugin.DEPLOY_THREAD_POOL;
import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_ALLOW_CUSTOM_DEPLOYMENT_PLAN;
import static org.opensearch.ml.task.MLTaskManager.TASK_SEMAPHORE_TIMEOUT;
import static org.opensearch.ml.utils.MLExceptionUtils.LOCAL_MODEL_DISABLED_ERR_MSG;
import static org.opensearch.ml.utils.MLExceptionUtils.REMOTE_INFERENCE_DISABLED_ERR_MSG;

import java.time.Instant;
Expand Down Expand Up @@ -143,6 +144,8 @@ protected void doExecute(Task task, ActionRequest request, ActionListener<MLDepl
Boolean isHidden = mlModel.getIsHidden();
if (functionName == FunctionName.REMOTE && !mlFeatureEnabledSetting.isRemoteInferenceEnabled()) {
throw new IllegalStateException(REMOTE_INFERENCE_DISABLED_ERR_MSG);
} else if (FunctionName.isDLModel(functionName) && !mlFeatureEnabledSetting.isLocalModelEnabled()) {
throw new IllegalStateException(LOCAL_MODEL_DISABLED_ERR_MSG);
}
if (!isUserInitiatedDeployRequest) {
deployModel(deployModelRequest, mlModel, modelId, wrappedListener, listener);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
package org.opensearch.ml.action.prediction;

import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_MODEL_AUTO_DEPLOY_ENABLE;
import static org.opensearch.ml.utils.MLExceptionUtils.LOCAL_MODEL_DISABLED_ERR_MSG;

import org.opensearch.OpenSearchStatusException;
import org.opensearch.action.ActionRequest;
Expand All @@ -29,6 +30,7 @@
import org.opensearch.ml.helper.ModelAccessControlHelper;
import org.opensearch.ml.model.MLModelCacheHelper;
import org.opensearch.ml.model.MLModelManager;
import org.opensearch.ml.settings.MLFeatureEnabledSetting;
import org.opensearch.ml.task.MLPredictTaskRunner;
import org.opensearch.ml.task.MLTaskRunner;
import org.opensearch.ml.utils.RestActionUtils;
Expand Down Expand Up @@ -58,6 +60,8 @@ public class TransportPredictionTaskAction extends HandledTransportAction<Action

private volatile boolean enableAutomaticDeployment;

private MLFeatureEnabledSetting mlFeatureEnabledSetting;

@Inject
public TransportPredictionTaskAction(
TransportService transportService,
Expand All @@ -69,6 +73,7 @@ public TransportPredictionTaskAction(
NamedXContentRegistry xContentRegistry,
MLModelManager mlModelManager,
ModelAccessControlHelper modelAccessControlHelper,
MLFeatureEnabledSetting mlFeatureEnabledSetting,
Settings settings
) {
super(MLPredictionTaskAction.NAME, transportService, actionFilters, MLPredictionTaskRequest::new);
Expand All @@ -80,6 +85,7 @@ public TransportPredictionTaskAction(
this.xContentRegistry = xContentRegistry;
this.mlModelManager = mlModelManager;
this.modelAccessControlHelper = modelAccessControlHelper;
this.mlFeatureEnabledSetting = mlFeatureEnabledSetting;
enableAutomaticDeployment = ML_COMMONS_MODEL_AUTO_DEPLOY_ENABLE.get(settings);
clusterService
.getClusterSettings()
Expand Down Expand Up @@ -107,6 +113,9 @@ public void onResponse(MLModel mlModel) {
context.restore();
modelCacheHelper.setModelInfo(modelId, mlModel);
FunctionName functionName = mlModel.getAlgorithm();
if (FunctionName.isDLModel(functionName) && !mlFeatureEnabledSetting.isLocalModelEnabled()) {
throw new IllegalStateException(LOCAL_MODEL_DISABLED_ERR_MSG);
}
mlPredictionTaskRequest.getMlInput().setAlgorithm(functionName);
modelAccessControlHelper
.validateModelGroupAccess(userInfo, mlModel.getModelGroupId(), client, ActionListener.wrap(access -> {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_TRUSTED_CONNECTOR_ENDPOINTS_REGEX;
import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_TRUSTED_URL_REGEX;
import static org.opensearch.ml.task.MLTaskManager.TASK_SEMAPHORE_TIMEOUT;
import static org.opensearch.ml.utils.MLExceptionUtils.LOCAL_MODEL_DISABLED_ERR_MSG;
import static org.opensearch.ml.utils.MLExceptionUtils.logException;

import java.time.Instant;
Expand Down Expand Up @@ -56,6 +57,7 @@
import org.opensearch.ml.helper.ModelAccessControlHelper;
import org.opensearch.ml.model.MLModelGroupManager;
import org.opensearch.ml.model.MLModelManager;
import org.opensearch.ml.settings.MLFeatureEnabledSetting;
import org.opensearch.ml.stats.MLStats;
import org.opensearch.ml.task.MLTaskDispatcher;
import org.opensearch.ml.task.MLTaskManager;
Expand Down Expand Up @@ -94,6 +96,7 @@ public class TransportRegisterModelAction extends HandledTransportAction<ActionR

ConnectorAccessControlHelper connectorAccessControlHelper;
MLModelGroupManager mlModelGroupManager;
private MLFeatureEnabledSetting mlFeatureEnabledSetting;

@Inject
public TransportRegisterModelAction(
Expand All @@ -112,7 +115,8 @@ public TransportRegisterModelAction(
MLStats mlStats,
ModelAccessControlHelper modelAccessControlHelper,
ConnectorAccessControlHelper connectorAccessControlHelper,
MLModelGroupManager mlModelGroupManager
MLModelGroupManager mlModelGroupManager,
MLFeatureEnabledSetting mlFeatureEnabledSetting
) {
super(MLRegisterModelAction.NAME, transportService, actionFilters, MLRegisterModelRequest::new);
this.transportService = transportService;
Expand All @@ -129,6 +133,7 @@ public TransportRegisterModelAction(
this.modelAccessControlHelper = modelAccessControlHelper;
this.connectorAccessControlHelper = connectorAccessControlHelper;
this.mlModelGroupManager = mlModelGroupManager;
this.mlFeatureEnabledSetting = mlFeatureEnabledSetting;
this.settings = settings;

trustedUrlRegex = ML_COMMONS_TRUSTED_URL_REGEX.get(settings);
Expand All @@ -147,6 +152,9 @@ public TransportRegisterModelAction(
protected void doExecute(Task task, ActionRequest request, ActionListener<MLRegisterModelResponse> listener) {
MLRegisterModelRequest registerModelRequest = MLRegisterModelRequest.fromActionRequest(request);
MLRegisterModelInput registerModelInput = registerModelRequest.getRegisterModelInput();
if (FunctionName.isDLModel(registerModelInput.getFunctionName()) && !mlFeatureEnabledSetting.isLocalModelEnabled()) {
throw new IllegalStateException(LOCAL_MODEL_DISABLED_ERR_MSG);
}
if (registerModelInput.getUrl() != null && !isModelUrlAllowed) {
throw new IllegalArgumentException(
"To upload custom model user needs to enable allow_registering_model_via_url settings. Otherwise please use OpenSearch pre-trained models."
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -898,6 +898,7 @@ public List<Setting<?>> getSettings() {
MLCommonsSettings.ML_COMMONS_REMOTE_MODEL_ELIGIBLE_NODE_ROLES,
MLCommonsSettings.ML_COMMONS_LOCAL_MODEL_ELIGIBLE_NODE_ROLES,
MLCommonsSettings.ML_COMMONS_REMOTE_INFERENCE_ENABLED,
MLCommonsSettings.ML_COMMONS_LOCAL_MODEL_ENABLED,
MLCommonsSettings.ML_COMMONS_MEMORY_FEATURE_ENABLED,
MLCommonsSettings.ML_COMMONS_RAG_PIPELINE_FEATURE_ENABLED,
MLCommonsSettings.ML_COMMONS_AGENT_FRAMEWORK_ENABLED,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken;
import static org.opensearch.ml.plugin.MachineLearningPlugin.ML_BASE_URI;
import static org.opensearch.ml.utils.MLExceptionUtils.LOCAL_MODEL_DISABLED_ERR_MSG;
import static org.opensearch.ml.utils.MLExceptionUtils.REMOTE_INFERENCE_DISABLED_ERR_MSG;
import static org.opensearch.ml.utils.RestActionUtils.PARAMETER_ALGORITHM;
import static org.opensearch.ml.utils.RestActionUtils.PARAMETER_MODEL_ID;
Expand Down Expand Up @@ -121,6 +122,8 @@ public RestChannelConsumer prepareRequest(RestRequest request, NodeClient client
MLPredictionTaskRequest getRequest(String modelId, String algorithm, RestRequest request) throws IOException {
if (FunctionName.REMOTE.name().equals(algorithm) && !mlFeatureEnabledSetting.isRemoteInferenceEnabled()) {
throw new IllegalStateException(REMOTE_INFERENCE_DISABLED_ERR_MSG);
} else if (FunctionName.isDLModel(FunctionName.from(algorithm)) && !mlFeatureEnabledSetting.isLocalModelEnabled()) {
throw new IllegalStateException(LOCAL_MODEL_DISABLED_ERR_MSG);
}
XContentParser parser = request.contentParser();
ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken;
import static org.opensearch.ml.plugin.MachineLearningPlugin.ML_BASE_URI;
import static org.opensearch.ml.utils.MLExceptionUtils.LOCAL_MODEL_DISABLED_ERR_MSG;
import static org.opensearch.ml.utils.MLExceptionUtils.REMOTE_INFERENCE_DISABLED_ERR_MSG;
import static org.opensearch.ml.utils.RestActionUtils.PARAMETER_DEPLOY_MODEL;
import static org.opensearch.ml.utils.RestActionUtils.PARAMETER_MODEL_ID;
Expand Down Expand Up @@ -98,6 +99,8 @@ MLRegisterModelRequest getRequest(RestRequest request) throws IOException {
MLRegisterModelInput mlInput = MLRegisterModelInput.parse(parser, loadModel);
if (mlInput.getFunctionName() == FunctionName.REMOTE && !mlFeatureEnabledSetting.isRemoteInferenceEnabled()) {
throw new IllegalStateException(REMOTE_INFERENCE_DISABLED_ERR_MSG);
} else if (FunctionName.isDLModel(mlInput.getFunctionName()) && !mlFeatureEnabledSetting.isLocalModelEnabled()) {
throw new IllegalStateException(LOCAL_MODEL_DISABLED_ERR_MSG);
}
return new MLRegisterModelRequest(mlInput);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,9 @@ private MLCommonsSettings() {}
public static final Setting<Boolean> ML_COMMONS_REMOTE_INFERENCE_ENABLED = Setting
.boolSetting("plugins.ml_commons.remote_inference.enabled", true, Setting.Property.NodeScope, Setting.Property.Dynamic);

public static final Setting<Boolean> ML_COMMONS_LOCAL_MODEL_ENABLED = Setting
.boolSetting("plugins.ml_commons.local_model.enabled", true, Setting.Property.NodeScope, Setting.Property.Dynamic);

public static final Setting<Boolean> ML_COMMONS_MODEL_ACCESS_CONTROL_ENABLED = Setting
.boolSetting("plugins.ml_commons.model_access_control_enabled", false, Setting.Property.NodeScope, Setting.Property.Dynamic);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
package org.opensearch.ml.settings;

import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_AGENT_FRAMEWORK_ENABLED;
import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_LOCAL_MODEL_ENABLED;
import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_REMOTE_INFERENCE_ENABLED;

import org.opensearch.cluster.service.ClusterService;
Expand All @@ -18,16 +19,20 @@ public class MLFeatureEnabledSetting {
private volatile Boolean isRemoteInferenceEnabled;
private volatile Boolean isAgentFrameworkEnabled;

private volatile Boolean isLocalModelEnabled;

public MLFeatureEnabledSetting(ClusterService clusterService, Settings settings) {
isRemoteInferenceEnabled = ML_COMMONS_REMOTE_INFERENCE_ENABLED.get(settings);
isAgentFrameworkEnabled = ML_COMMONS_AGENT_FRAMEWORK_ENABLED.get(settings);
isLocalModelEnabled = ML_COMMONS_LOCAL_MODEL_ENABLED.get(settings);

clusterService
.getClusterSettings()
.addSettingsUpdateConsumer(ML_COMMONS_REMOTE_INFERENCE_ENABLED, it -> isRemoteInferenceEnabled = it);
clusterService
.getClusterSettings()
.addSettingsUpdateConsumer(ML_COMMONS_AGENT_FRAMEWORK_ENABLED, it -> isAgentFrameworkEnabled = it);
clusterService.getClusterSettings().addSettingsUpdateConsumer(ML_COMMONS_LOCAL_MODEL_ENABLED, it -> isLocalModelEnabled = it);
}

/**
Expand All @@ -46,4 +51,12 @@ public boolean isAgentFrameworkEnabled() {
return isAgentFrameworkEnabled;
}

/**
* Whether the local model feature is enabled. If disabled, APIs in ml-commons will block local model inference.
* @return whether the local inference is enabled.
*/
public boolean isLocalModelEnabled() {
return isLocalModelEnabled;
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@ public class MLExceptionUtils {
public static final String NOT_SERIALIZABLE_EXCEPTION_WRAPPER = "NotSerializableExceptionWrapper: ";
public static final String REMOTE_INFERENCE_DISABLED_ERR_MSG =
"Remote Inference is currently disabled. To enable it, update the setting \"plugins.ml_commons.remote_inference_enabled\" to true.";
public static final String LOCAL_MODEL_DISABLED_ERR_MSG =
"Local Model is currently disabled. To enable it, update the setting \"plugins.ml_commons.local_model.enabled\" to true.";
public static final String AGENT_FRAMEWORK_DISABLED_ERR_MSG =
"Agent Framework is currently disabled. To enable it, update the setting \"plugins.ml_commons.agent_framework_enabled\" to true.";

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_ALLOW_CUSTOM_DEPLOYMENT_PLAN;
import static org.opensearch.ml.utils.MLExceptionUtils.LOCAL_MODEL_DISABLED_ERR_MSG;
import static org.opensearch.ml.utils.MLExceptionUtils.REMOTE_INFERENCE_DISABLED_ERR_MSG;

import java.lang.reflect.Field;
Expand Down Expand Up @@ -176,6 +177,7 @@ public void setup() {
when(mlDeployModelRequest.isUserInitiatedDeployRequest()).thenReturn(true);

when(mlFeatureEnabledSetting.isRemoteInferenceEnabled()).thenReturn(true);
when(mlFeatureEnabledSetting.isLocalModelEnabled()).thenReturn(true);

MLStat mlStat = mock(MLStat.class);
when(mlStats.getStat(eq(MLNodeLevelStat.ML_REQUEST_COUNT))).thenReturn(mlStat);
Expand Down Expand Up @@ -374,6 +376,23 @@ public void testDoExecuteRemoteInferenceDisabled() {
assertEquals(REMOTE_INFERENCE_DISABLED_ERR_MSG, argumentCaptor.getValue().getMessage());
}

public void testDoExecuteLocalInferenceDisabled() {
MLModel mlModel = mock(MLModel.class);
when(mlModel.getAlgorithm()).thenReturn(FunctionName.TEXT_EMBEDDING);
doAnswer(invocation -> {
ActionListener<MLModel> listener = invocation.getArgument(3);
listener.onResponse(mlModel);
return null;
}).when(mlModelManager).getModel(anyString(), isNull(), any(String[].class), Mockito.isA(ActionListener.class));

when(mlFeatureEnabledSetting.isLocalModelEnabled()).thenReturn(false);
ActionListener<MLDeployModelResponse> deployModelResponseListener = mock(ActionListener.class);
transportDeployModelAction.doExecute(mock(Task.class), mlDeployModelRequest, deployModelResponseListener);
ArgumentCaptor<Exception> argumentCaptor = ArgumentCaptor.forClass(IllegalStateException.class);
verify(deployModelResponseListener).onFailure(argumentCaptor.capture());
assertEquals(LOCAL_MODEL_DISABLED_ERR_MSG, argumentCaptor.getValue().getMessage());
}

public void test_ValidationFailedException() {
MLModel mlModel = mock(MLModel.class);
when(mlModel.getAlgorithm()).thenReturn(FunctionName.ANOMALY_LOCALIZATION);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_MODEL_AUTO_DEPLOY_ENABLE;
import static org.opensearch.ml.utils.MLExceptionUtils.LOCAL_MODEL_DISABLED_ERR_MSG;

import java.util.Arrays;
import java.util.Collections;
Expand Down Expand Up @@ -48,6 +49,7 @@
import org.opensearch.ml.helper.ModelAccessControlHelper;
import org.opensearch.ml.model.MLModelCacheHelper;
import org.opensearch.ml.model.MLModelManager;
import org.opensearch.ml.settings.MLFeatureEnabledSetting;
import org.opensearch.ml.task.MLPredictTaskRunner;
import org.opensearch.test.OpenSearchTestCase;
import org.opensearch.threadpool.ThreadPool;
Expand Down Expand Up @@ -83,6 +85,8 @@ public class TransportPredictionTaskActionTests extends OpenSearchTestCase {

@Mock
private ModelAccessControlHelper modelAccessControlHelper;
@Mock
private MLFeatureEnabledSetting mlFeatureEnabledSetting;

@Mock
ActionFilters actionFilters;
Expand Down Expand Up @@ -129,6 +133,7 @@ public void setup() {
when(client.threadPool()).thenReturn(threadPool);
when(threadPool.getThreadContext()).thenReturn(threadContext);
when(clusterService.getClusterSettings()).thenReturn(clusterSettings);
when(mlFeatureEnabledSetting.isLocalModelEnabled()).thenReturn(true);

transportPredictionTaskAction = spy(
new TransportPredictionTaskAction(
Expand All @@ -141,6 +146,7 @@ public void setup() {
xContentRegistry,
mlModelManager,
modelAccessControlHelper,
mlFeatureEnabledSetting,
settings
)
);
Expand Down Expand Up @@ -168,6 +174,21 @@ public void testPrediction_default_exception() {
assertEquals("Failed to Validate Access for ModelId test_id", argumentCaptor.getValue().getMessage());
}

public void testPrediction_local_model_not_exception() {
when(modelCacheHelper.getModelInfo(anyString())).thenReturn(model);
when(model.getAlgorithm()).thenReturn(FunctionName.TEXT_EMBEDDING);
when(mlFeatureEnabledSetting.isLocalModelEnabled()).thenReturn(false);

IllegalStateException e = assertThrows(
IllegalStateException.class,
() -> transportPredictionTaskAction.doExecute(null, mlPredictionTaskRequest, actionListener)
);
assertEquals(
e.getMessage(),
LOCAL_MODEL_DISABLED_ERR_MSG
);
}

public void testPrediction_OpenSearchStatusException() {
when(modelCacheHelper.getModelInfo(anyString())).thenReturn(model);
when(model.getAlgorithm()).thenReturn(FunctionName.KMEANS);
Expand Down
Loading

0 comments on commit 083e84a

Please sign in to comment.