Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add local inference enabling/disabling setting #2232

Merged
merged 3 commits into from
Mar 20, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import static org.opensearch.ml.plugin.MachineLearningPlugin.DEPLOY_THREAD_POOL;
import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_ALLOW_CUSTOM_DEPLOYMENT_PLAN;
import static org.opensearch.ml.task.MLTaskManager.TASK_SEMAPHORE_TIMEOUT;
import static org.opensearch.ml.utils.MLExceptionUtils.LOCAL_MODEL_DISABLED_ERR_MSG;
import static org.opensearch.ml.utils.MLExceptionUtils.REMOTE_INFERENCE_DISABLED_ERR_MSG;

import java.time.Instant;
Expand Down Expand Up @@ -143,6 +144,8 @@ protected void doExecute(Task task, ActionRequest request, ActionListener<MLDepl
Boolean isHidden = mlModel.getIsHidden();
if (functionName == FunctionName.REMOTE && !mlFeatureEnabledSetting.isRemoteInferenceEnabled()) {
throw new IllegalStateException(REMOTE_INFERENCE_DISABLED_ERR_MSG);
} else if (FunctionName.isDLModel(functionName) && !mlFeatureEnabledSetting.isLocalModelEnabled()) {
throw new IllegalStateException(LOCAL_MODEL_DISABLED_ERR_MSG);
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just wondering: should we throw an IllegalStateException here, or a Forbidden exception?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We can maybe use the same one as remote inference and agent framework to stay consistent?

}
if (!isUserInitiatedDeployRequest) {
deployModel(deployModelRequest, mlModel, modelId, wrappedListener, listener);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
package org.opensearch.ml.action.prediction;

import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_MODEL_AUTO_DEPLOY_ENABLE;
import static org.opensearch.ml.utils.MLExceptionUtils.LOCAL_MODEL_DISABLED_ERR_MSG;

import org.opensearch.OpenSearchStatusException;
import org.opensearch.action.ActionRequest;
Expand All @@ -29,6 +30,7 @@
import org.opensearch.ml.helper.ModelAccessControlHelper;
import org.opensearch.ml.model.MLModelCacheHelper;
import org.opensearch.ml.model.MLModelManager;
import org.opensearch.ml.settings.MLFeatureEnabledSetting;
import org.opensearch.ml.task.MLPredictTaskRunner;
import org.opensearch.ml.task.MLTaskRunner;
import org.opensearch.ml.utils.RestActionUtils;
Expand Down Expand Up @@ -58,6 +60,8 @@ public class TransportPredictionTaskAction extends HandledTransportAction<Action

private volatile boolean enableAutomaticDeployment;

private MLFeatureEnabledSetting mlFeatureEnabledSetting;

@Inject
public TransportPredictionTaskAction(
TransportService transportService,
Expand All @@ -69,6 +73,7 @@ public TransportPredictionTaskAction(
NamedXContentRegistry xContentRegistry,
MLModelManager mlModelManager,
ModelAccessControlHelper modelAccessControlHelper,
MLFeatureEnabledSetting mlFeatureEnabledSetting,
Settings settings
) {
super(MLPredictionTaskAction.NAME, transportService, actionFilters, MLPredictionTaskRequest::new);
Expand All @@ -80,6 +85,7 @@ public TransportPredictionTaskAction(
this.xContentRegistry = xContentRegistry;
this.mlModelManager = mlModelManager;
this.modelAccessControlHelper = modelAccessControlHelper;
this.mlFeatureEnabledSetting = mlFeatureEnabledSetting;
enableAutomaticDeployment = ML_COMMONS_MODEL_AUTO_DEPLOY_ENABLE.get(settings);
clusterService
.getClusterSettings()
Expand Down Expand Up @@ -107,6 +113,9 @@ public void onResponse(MLModel mlModel) {
context.restore();
modelCacheHelper.setModelInfo(modelId, mlModel);
FunctionName functionName = mlModel.getAlgorithm();
if (FunctionName.isDLModel(functionName) && !mlFeatureEnabledSetting.isLocalModelEnabled()) {
throw new IllegalStateException(LOCAL_MODEL_DISABLED_ERR_MSG);
}
mlPredictionTaskRequest.getMlInput().setAlgorithm(functionName);
modelAccessControlHelper
.validateModelGroupAccess(userInfo, mlModel.getModelGroupId(), client, ActionListener.wrap(access -> {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_TRUSTED_CONNECTOR_ENDPOINTS_REGEX;
import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_TRUSTED_URL_REGEX;
import static org.opensearch.ml.task.MLTaskManager.TASK_SEMAPHORE_TIMEOUT;
import static org.opensearch.ml.utils.MLExceptionUtils.LOCAL_MODEL_DISABLED_ERR_MSG;
import static org.opensearch.ml.utils.MLExceptionUtils.logException;

import java.time.Instant;
Expand Down Expand Up @@ -56,6 +57,7 @@
import org.opensearch.ml.helper.ModelAccessControlHelper;
import org.opensearch.ml.model.MLModelGroupManager;
import org.opensearch.ml.model.MLModelManager;
import org.opensearch.ml.settings.MLFeatureEnabledSetting;
import org.opensearch.ml.stats.MLStats;
import org.opensearch.ml.task.MLTaskDispatcher;
import org.opensearch.ml.task.MLTaskManager;
Expand Down Expand Up @@ -94,6 +96,7 @@ public class TransportRegisterModelAction extends HandledTransportAction<ActionR

ConnectorAccessControlHelper connectorAccessControlHelper;
MLModelGroupManager mlModelGroupManager;
private MLFeatureEnabledSetting mlFeatureEnabledSetting;

@Inject
public TransportRegisterModelAction(
Expand All @@ -112,7 +115,8 @@ public TransportRegisterModelAction(
MLStats mlStats,
ModelAccessControlHelper modelAccessControlHelper,
ConnectorAccessControlHelper connectorAccessControlHelper,
MLModelGroupManager mlModelGroupManager
MLModelGroupManager mlModelGroupManager,
MLFeatureEnabledSetting mlFeatureEnabledSetting
) {
super(MLRegisterModelAction.NAME, transportService, actionFilters, MLRegisterModelRequest::new);
this.transportService = transportService;
Expand All @@ -129,6 +133,7 @@ public TransportRegisterModelAction(
this.modelAccessControlHelper = modelAccessControlHelper;
this.connectorAccessControlHelper = connectorAccessControlHelper;
this.mlModelGroupManager = mlModelGroupManager;
this.mlFeatureEnabledSetting = mlFeatureEnabledSetting;
this.settings = settings;

trustedUrlRegex = ML_COMMONS_TRUSTED_URL_REGEX.get(settings);
Expand All @@ -147,6 +152,9 @@ public TransportRegisterModelAction(
protected void doExecute(Task task, ActionRequest request, ActionListener<MLRegisterModelResponse> listener) {
MLRegisterModelRequest registerModelRequest = MLRegisterModelRequest.fromActionRequest(request);
MLRegisterModelInput registerModelInput = registerModelRequest.getRegisterModelInput();
if (FunctionName.isDLModel(registerModelInput.getFunctionName()) && !mlFeatureEnabledSetting.isLocalModelEnabled()) {
throw new IllegalStateException(LOCAL_MODEL_DISABLED_ERR_MSG);
}
if (registerModelInput.getUrl() != null && !isModelUrlAllowed) {
throw new IllegalArgumentException(
"To upload custom model user needs to enable allow_registering_model_via_url settings. Otherwise please use OpenSearch pre-trained models."
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -897,6 +897,7 @@ public List<Setting<?>> getSettings() {
MLCommonsSettings.ML_COMMONS_REMOTE_MODEL_ELIGIBLE_NODE_ROLES,
MLCommonsSettings.ML_COMMONS_LOCAL_MODEL_ELIGIBLE_NODE_ROLES,
MLCommonsSettings.ML_COMMONS_REMOTE_INFERENCE_ENABLED,
MLCommonsSettings.ML_COMMONS_LOCAL_MODEL_ENABLED,
MLCommonsSettings.ML_COMMONS_MEMORY_FEATURE_ENABLED,
MLCommonsSettings.ML_COMMONS_RAG_PIPELINE_FEATURE_ENABLED,
MLCommonsSettings.ML_COMMONS_AGENT_FRAMEWORK_ENABLED,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken;
import static org.opensearch.ml.plugin.MachineLearningPlugin.ML_BASE_URI;
import static org.opensearch.ml.utils.MLExceptionUtils.LOCAL_MODEL_DISABLED_ERR_MSG;
import static org.opensearch.ml.utils.MLExceptionUtils.REMOTE_INFERENCE_DISABLED_ERR_MSG;
import static org.opensearch.ml.utils.RestActionUtils.PARAMETER_ALGORITHM;
import static org.opensearch.ml.utils.RestActionUtils.PARAMETER_MODEL_ID;
Expand Down Expand Up @@ -121,6 +122,8 @@ public RestChannelConsumer prepareRequest(RestRequest request, NodeClient client
MLPredictionTaskRequest getRequest(String modelId, String algorithm, RestRequest request) throws IOException {
if (FunctionName.REMOTE.name().equals(algorithm) && !mlFeatureEnabledSetting.isRemoteInferenceEnabled()) {
throw new IllegalStateException(REMOTE_INFERENCE_DISABLED_ERR_MSG);
} else if (FunctionName.isDLModel(FunctionName.from(algorithm)) && !mlFeatureEnabledSetting.isLocalModelEnabled()) {
throw new IllegalStateException(LOCAL_MODEL_DISABLED_ERR_MSG);
}
XContentParser parser = request.contentParser();
ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken;
import static org.opensearch.ml.plugin.MachineLearningPlugin.ML_BASE_URI;
import static org.opensearch.ml.utils.MLExceptionUtils.LOCAL_MODEL_DISABLED_ERR_MSG;
import static org.opensearch.ml.utils.MLExceptionUtils.REMOTE_INFERENCE_DISABLED_ERR_MSG;
import static org.opensearch.ml.utils.RestActionUtils.PARAMETER_DEPLOY_MODEL;
import static org.opensearch.ml.utils.RestActionUtils.PARAMETER_MODEL_ID;
Expand Down Expand Up @@ -98,6 +99,8 @@ MLRegisterModelRequest getRequest(RestRequest request) throws IOException {
MLRegisterModelInput mlInput = MLRegisterModelInput.parse(parser, loadModel);
if (mlInput.getFunctionName() == FunctionName.REMOTE && !mlFeatureEnabledSetting.isRemoteInferenceEnabled()) {
throw new IllegalStateException(REMOTE_INFERENCE_DISABLED_ERR_MSG);
} else if (FunctionName.isDLModel(mlInput.getFunctionName()) && !mlFeatureEnabledSetting.isLocalModelEnabled()) {
throw new IllegalStateException(LOCAL_MODEL_DISABLED_ERR_MSG);
}
return new MLRegisterModelRequest(mlInput);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,9 @@ private MLCommonsSettings() {}
/**
 * Boolean setting registered under the key "plugins.ml_commons.remote_inference.enabled".
 * Defaults to true and is declared with the NodeScope and Dynamic properties.
 */
public static final Setting<Boolean> ML_COMMONS_REMOTE_INFERENCE_ENABLED = Setting
.boolSetting("plugins.ml_commons.remote_inference.enabled", true, Setting.Property.NodeScope, Setting.Property.Dynamic);

/**
 * Boolean setting registered under the key "plugins.ml_commons.local_model.enabled".
 * Defaults to true and is declared with the NodeScope and Dynamic properties.
 */
public static final Setting<Boolean> ML_COMMONS_LOCAL_MODEL_ENABLED = Setting
.boolSetting("plugins.ml_commons.local_model.enabled", true, Setting.Property.NodeScope, Setting.Property.Dynamic);

/**
 * Boolean setting registered under the key "plugins.ml_commons.model_access_control_enabled".
 * Defaults to false and is declared with the NodeScope and Dynamic properties.
 * NOTE(review): presumably toggles model access control checks - confirm against its consumers.
 */
public static final Setting<Boolean> ML_COMMONS_MODEL_ACCESS_CONTROL_ENABLED = Setting
.boolSetting("plugins.ml_commons.model_access_control_enabled", false, Setting.Property.NodeScope, Setting.Property.Dynamic);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
package org.opensearch.ml.settings;

import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_AGENT_FRAMEWORK_ENABLED;
import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_LOCAL_MODEL_ENABLED;
import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_REMOTE_INFERENCE_ENABLED;

import org.opensearch.cluster.service.ClusterService;
Expand All @@ -18,16 +19,20 @@ public class MLFeatureEnabledSetting {
// Current value of ML_COMMONS_REMOTE_INFERENCE_ENABLED; reassigned by a settings update consumer.
private volatile Boolean isRemoteInferenceEnabled;
// Current value of ML_COMMONS_AGENT_FRAMEWORK_ENABLED; reassigned by a settings update consumer.
private volatile Boolean isAgentFrameworkEnabled;

// Current value of ML_COMMONS_LOCAL_MODEL_ENABLED; reassigned by a settings update consumer.
private volatile Boolean isLocalModelEnabled;

/**
 * Seeds every feature flag from the supplied settings and registers an update consumer per flag
 * on the cluster settings, so that later changes to a setting replace the stored value.
 *
 * @param clusterService provides the cluster settings on which the update consumers are registered
 * @param settings       settings from which the initial flag values are read
 */
public MLFeatureEnabledSetting(ClusterService clusterService, Settings settings) {
    // Remote inference: initial value, then follow updates.
    this.isRemoteInferenceEnabled = ML_COMMONS_REMOTE_INFERENCE_ENABLED.get(settings);
    clusterService.getClusterSettings().addSettingsUpdateConsumer(ML_COMMONS_REMOTE_INFERENCE_ENABLED, updated -> {
        this.isRemoteInferenceEnabled = updated;
    });

    // Agent framework: initial value, then follow updates.
    this.isAgentFrameworkEnabled = ML_COMMONS_AGENT_FRAMEWORK_ENABLED.get(settings);
    clusterService.getClusterSettings().addSettingsUpdateConsumer(ML_COMMONS_AGENT_FRAMEWORK_ENABLED, updated -> {
        this.isAgentFrameworkEnabled = updated;
    });

    // Local model: initial value, then follow updates.
    this.isLocalModelEnabled = ML_COMMONS_LOCAL_MODEL_ENABLED.get(settings);
    clusterService.getClusterSettings().addSettingsUpdateConsumer(ML_COMMONS_LOCAL_MODEL_ENABLED, updated -> {
        this.isLocalModelEnabled = updated;
    });
}

/**
Expand All @@ -46,4 +51,12 @@ public boolean isAgentFrameworkEnabled() {
return isAgentFrameworkEnabled;
}

/**
 * Reports whether the local model feature is currently enabled. When it is disabled, the
 * ml-commons register, deploy and predict paths reject local (deep-learning) models with
 * LOCAL_MODEL_DISABLED_ERR_MSG.
 *
 * @return the current value of the local model feature flag
 */
public boolean isLocalModelEnabled() {
return isLocalModelEnabled;
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@ public class MLExceptionUtils {
public static final String NOT_SERIALIZABLE_EXCEPTION_WRAPPER = "NotSerializableExceptionWrapper: ";
/**
 * Message used when a remote-inference request is rejected because the feature flag is off.
 * The quoted key matches the one ML_COMMONS_REMOTE_INFERENCE_ENABLED is registered under
 * ("plugins.ml_commons.remote_inference.enabled"), so users can copy it verbatim.
 */
public static final String REMOTE_INFERENCE_DISABLED_ERR_MSG =
    "Remote Inference is currently disabled. To enable it, update the setting \"plugins.ml_commons.remote_inference.enabled\" to true.";
/** Message used when a local (deep-learning) model request is rejected because the feature flag is off. */
public static final String LOCAL_MODEL_DISABLED_ERR_MSG =
    "Local Model is currently disabled. To enable it, update the setting \"plugins.ml_commons.local_model.enabled\" to true.";
/** Message used when an agent-framework request is rejected because the feature flag is off. */
public static final String AGENT_FRAMEWORK_DISABLED_ERR_MSG =
    "Agent Framework is currently disabled. To enable it, update the setting \"plugins.ml_commons.agent_framework_enabled\" to true.";

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_ALLOW_CUSTOM_DEPLOYMENT_PLAN;
import static org.opensearch.ml.utils.MLExceptionUtils.LOCAL_MODEL_DISABLED_ERR_MSG;
import static org.opensearch.ml.utils.MLExceptionUtils.REMOTE_INFERENCE_DISABLED_ERR_MSG;

import java.lang.reflect.Field;
Expand Down Expand Up @@ -176,6 +177,7 @@ public void setup() {
when(mlDeployModelRequest.isUserInitiatedDeployRequest()).thenReturn(true);

when(mlFeatureEnabledSetting.isRemoteInferenceEnabled()).thenReturn(true);
when(mlFeatureEnabledSetting.isLocalModelEnabled()).thenReturn(true);

MLStat mlStat = mock(MLStat.class);
when(mlStats.getStat(eq(MLNodeLevelStat.ML_REQUEST_COUNT))).thenReturn(mlStat);
Expand Down Expand Up @@ -374,6 +376,23 @@ public void testDoExecuteRemoteInferenceDisabled() {
assertEquals(REMOTE_INFERENCE_DISABLED_ERR_MSG, argumentCaptor.getValue().getMessage());
}

/**
 * Deploying a TEXT_EMBEDDING model while the local model flag is off must report
 * LOCAL_MODEL_DISABLED_ERR_MSG through the response listener's onFailure callback.
 */
public void testDoExecuteLocalInferenceDisabled() {
    // setup() stubs the flag to true; turn it off for this test only.
    when(mlFeatureEnabledSetting.isLocalModelEnabled()).thenReturn(false);

    MLModel textEmbeddingModel = mock(MLModel.class);
    when(textEmbeddingModel.getAlgorithm()).thenReturn(FunctionName.TEXT_EMBEDDING);

    // Make the model manager hand the mocked model straight to its callback (4th argument).
    doAnswer(invocation -> {
        ActionListener<MLModel> modelListener = invocation.getArgument(3);
        modelListener.onResponse(textEmbeddingModel);
        return null;
    }).when(mlModelManager).getModel(anyString(), isNull(), any(String[].class), Mockito.isA(ActionListener.class));

    ActionListener<MLDeployModelResponse> responseListener = mock(ActionListener.class);
    transportDeployModelAction.doExecute(mock(Task.class), mlDeployModelRequest, responseListener);

    ArgumentCaptor<Exception> failureCaptor = ArgumentCaptor.forClass(IllegalStateException.class);
    verify(responseListener).onFailure(failureCaptor.capture());
    Exception reported = failureCaptor.getValue();
    assertEquals(LOCAL_MODEL_DISABLED_ERR_MSG, reported.getMessage());
}

public void test_ValidationFailedException() {
MLModel mlModel = mock(MLModel.class);
when(mlModel.getAlgorithm()).thenReturn(FunctionName.ANOMALY_LOCALIZATION);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_MODEL_AUTO_DEPLOY_ENABLE;
import static org.opensearch.ml.utils.MLExceptionUtils.LOCAL_MODEL_DISABLED_ERR_MSG;

import java.util.Arrays;
import java.util.Collections;
Expand Down Expand Up @@ -48,6 +49,7 @@
import org.opensearch.ml.helper.ModelAccessControlHelper;
import org.opensearch.ml.model.MLModelCacheHelper;
import org.opensearch.ml.model.MLModelManager;
import org.opensearch.ml.settings.MLFeatureEnabledSetting;
import org.opensearch.ml.task.MLPredictTaskRunner;
import org.opensearch.test.OpenSearchTestCase;
import org.opensearch.threadpool.ThreadPool;
Expand Down Expand Up @@ -83,6 +85,8 @@ public class TransportPredictionTaskActionTests extends OpenSearchTestCase {

@Mock
private ModelAccessControlHelper modelAccessControlHelper;
@Mock
private MLFeatureEnabledSetting mlFeatureEnabledSetting;

@Mock
ActionFilters actionFilters;
Expand Down Expand Up @@ -129,6 +133,7 @@ public void setup() {
when(client.threadPool()).thenReturn(threadPool);
when(threadPool.getThreadContext()).thenReturn(threadContext);
when(clusterService.getClusterSettings()).thenReturn(clusterSettings);
when(mlFeatureEnabledSetting.isLocalModelEnabled()).thenReturn(true);

transportPredictionTaskAction = spy(
new TransportPredictionTaskAction(
Expand All @@ -141,6 +146,7 @@ public void setup() {
xContentRegistry,
mlModelManager,
modelAccessControlHelper,
mlFeatureEnabledSetting,
settings
)
);
Expand Down Expand Up @@ -168,6 +174,21 @@ public void testPrediction_default_exception() {
assertEquals("Failed to Validate Access for ModelId test_id", argumentCaptor.getValue().getMessage());
}

/**
 * Prediction against a cached TEXT_EMBEDDING model must fail with an IllegalStateException
 * carrying LOCAL_MODEL_DISABLED_ERR_MSG when the local model feature flag is off.
 */
public void testPrediction_local_model_not_exception() {
    when(modelCacheHelper.getModelInfo(anyString())).thenReturn(model);
    when(model.getAlgorithm()).thenReturn(FunctionName.TEXT_EMBEDDING);
    // setup() stubs the flag to true; turn it off for this test only.
    when(mlFeatureEnabledSetting.isLocalModelEnabled()).thenReturn(false);

    IllegalStateException e = assertThrows(
        IllegalStateException.class,
        () -> transportPredictionTaskAction.doExecute(null, mlPredictionTaskRequest, actionListener)
    );
    // Expected value goes first so a failure report labels "expected" and "actual" correctly.
    assertEquals(LOCAL_MODEL_DISABLED_ERR_MSG, e.getMessage());
}

public void testPrediction_OpenSearchStatusException() {
when(modelCacheHelper.getModelInfo(anyString())).thenReturn(model);
when(model.getAlgorithm()).thenReturn(FunctionName.KMEANS);
Expand Down
Loading
Loading