Commit 8924e22
Change the setting name to a more appropriate one
Signed-off-by: Bhavana Ramaram <[email protected]>
rbhavna committed Mar 20, 2024
1 parent e1a81e4 commit 8924e22
Showing 14 changed files with 25 additions and 27 deletions.
@@ -144,7 +144,7 @@ protected void doExecute(Task task, ActionRequest request, ActionListener<MLDepl
Boolean isHidden = mlModel.getIsHidden();
if (functionName == FunctionName.REMOTE && !mlFeatureEnabledSetting.isRemoteInferenceEnabled()) {
throw new IllegalStateException(REMOTE_INFERENCE_DISABLED_ERR_MSG);
-} else if (FunctionName.isDLModel(functionName) && !mlFeatureEnabledSetting.isLocalModelInferenceEnabled()) {
+} else if (FunctionName.isDLModel(functionName) && !mlFeatureEnabledSetting.isLocalModelEnabled()) {
throw new IllegalStateException(LOCAL_MODEL_DISABLED_ERR_MSG);
}
if (!isUserInitiatedDeployRequest) {
@@ -113,7 +113,7 @@ public void onResponse(MLModel mlModel) {
context.restore();
modelCacheHelper.setModelInfo(modelId, mlModel);
FunctionName functionName = mlModel.getAlgorithm();
-if (FunctionName.isDLModel(functionName) && !mlFeatureEnabledSetting.isLocalModelInferenceEnabled()) {
+if (FunctionName.isDLModel(functionName) && !mlFeatureEnabledSetting.isLocalModelEnabled()) {
throw new IllegalStateException(LOCAL_MODEL_DISABLED_ERR_MSG);
}
mlPredictionTaskRequest.getMlInput().setAlgorithm(functionName);
@@ -152,7 +152,7 @@ public TransportRegisterModelAction(
protected void doExecute(Task task, ActionRequest request, ActionListener<MLRegisterModelResponse> listener) {
MLRegisterModelRequest registerModelRequest = MLRegisterModelRequest.fromActionRequest(request);
MLRegisterModelInput registerModelInput = registerModelRequest.getRegisterModelInput();
-if (FunctionName.isDLModel(registerModelInput.getFunctionName()) && !mlFeatureEnabledSetting.isLocalModelInferenceEnabled()) {
+if (FunctionName.isDLModel(registerModelInput.getFunctionName()) && !mlFeatureEnabledSetting.isLocalModelEnabled()) {
throw new IllegalStateException(LOCAL_MODEL_DISABLED_ERR_MSG);
}
if (registerModelInput.getUrl() != null && !isModelUrlAllowed) {
@@ -897,7 +897,7 @@ public List<Setting<?>> getSettings() {
MLCommonsSettings.ML_COMMONS_REMOTE_MODEL_ELIGIBLE_NODE_ROLES,
MLCommonsSettings.ML_COMMONS_LOCAL_MODEL_ELIGIBLE_NODE_ROLES,
MLCommonsSettings.ML_COMMONS_REMOTE_INFERENCE_ENABLED,
-MLCommonsSettings.ML_COMMONS_LOCAL_MODEL_INFERENCE_ENABLED,
+MLCommonsSettings.ML_COMMONS_LOCAL_MODEL_ENABLED,
MLCommonsSettings.ML_COMMONS_MEMORY_FEATURE_ENABLED,
MLCommonsSettings.ML_COMMONS_RAG_PIPELINE_FEATURE_ENABLED,
MLCommonsSettings.ML_COMMONS_AGENT_FRAMEWORK_ENABLED,
@@ -122,7 +122,7 @@ public RestChannelConsumer prepareRequest(RestRequest request, NodeClient client
MLPredictionTaskRequest getRequest(String modelId, String algorithm, RestRequest request) throws IOException {
if (FunctionName.REMOTE.name().equals(algorithm) && !mlFeatureEnabledSetting.isRemoteInferenceEnabled()) {
throw new IllegalStateException(REMOTE_INFERENCE_DISABLED_ERR_MSG);
-} else if (FunctionName.isDLModel(FunctionName.from(algorithm)) && !mlFeatureEnabledSetting.isLocalModelInferenceEnabled()) {
+} else if (FunctionName.isDLModel(FunctionName.from(algorithm)) && !mlFeatureEnabledSetting.isLocalModelEnabled()) {
throw new IllegalStateException(LOCAL_MODEL_DISABLED_ERR_MSG);
}
XContentParser parser = request.contentParser();
@@ -99,7 +99,7 @@ MLRegisterModelRequest getRequest(RestRequest request) throws IOException {
MLRegisterModelInput mlInput = MLRegisterModelInput.parse(parser, loadModel);
if (mlInput.getFunctionName() == FunctionName.REMOTE && !mlFeatureEnabledSetting.isRemoteInferenceEnabled()) {
throw new IllegalStateException(REMOTE_INFERENCE_DISABLED_ERR_MSG);
-} else if (FunctionName.isDLModel(mlInput.getFunctionName()) && !mlFeatureEnabledSetting.isLocalModelInferenceEnabled()) {
+} else if (FunctionName.isDLModel(mlInput.getFunctionName()) && !mlFeatureEnabledSetting.isLocalModelEnabled()) {
throw new IllegalStateException(LOCAL_MODEL_DISABLED_ERR_MSG);
}
return new MLRegisterModelRequest(mlInput);
@@ -117,8 +117,8 @@ private MLCommonsSettings() {}
public static final Setting<Boolean> ML_COMMONS_REMOTE_INFERENCE_ENABLED = Setting
.boolSetting("plugins.ml_commons.remote_inference.enabled", true, Setting.Property.NodeScope, Setting.Property.Dynamic);

-public static final Setting<Boolean> ML_COMMONS_LOCAL_MODEL_INFERENCE_ENABLED = Setting
-.boolSetting("plugins.ml_commons.local_model_inference_enabled", true, Setting.Property.NodeScope, Setting.Property.Dynamic);
+public static final Setting<Boolean> ML_COMMONS_LOCAL_MODEL_ENABLED = Setting
+.boolSetting("plugins.ml_commons.local_model.enabled", true, Setting.Property.NodeScope, Setting.Property.Dynamic);

public static final Setting<Boolean> ML_COMMONS_MODEL_ACCESS_CONTROL_ENABLED = Setting
.boolSetting("plugins.ml_commons.model_access_control_enabled", false, Setting.Property.NodeScope, Setting.Property.Dynamic);
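For illustration, here is a minimal, self-contained sketch (not part of this commit) of how the renamed key resolves through the Setting declaration above; the class name LocalModelSettingDemo and its main method are hypothetical.

```java
import org.opensearch.common.settings.Setting;
import org.opensearch.common.settings.Settings;

// Hypothetical demo class: mirrors the ML_COMMONS_LOCAL_MODEL_ENABLED declaration above.
public class LocalModelSettingDemo {

    private static final Setting<Boolean> LOCAL_MODEL_ENABLED = Setting
        .boolSetting("plugins.ml_commons.local_model.enabled", true, Setting.Property.NodeScope, Setting.Property.Dynamic);

    public static void main(String[] args) {
        // Key absent: falls back to the declared default (true).
        System.out.println(LOCAL_MODEL_ENABLED.get(Settings.EMPTY));

        // Key set explicitly (for example from opensearch.yml): resolves to false.
        Settings disabled = Settings.builder().put("plugins.ml_commons.local_model.enabled", false).build();
        System.out.println(LOCAL_MODEL_ENABLED.get(disabled));
    }
}
```

Because the setting keeps Setting.Property.Dynamic, it can also be updated at runtime through the cluster settings API rather than only at node startup.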
@@ -8,7 +8,7 @@
package org.opensearch.ml.settings;

import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_AGENT_FRAMEWORK_ENABLED;
-import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_LOCAL_MODEL_INFERENCE_ENABLED;
+import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_LOCAL_MODEL_ENABLED;
import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_REMOTE_INFERENCE_ENABLED;

import org.opensearch.cluster.service.ClusterService;
@@ -19,22 +19,20 @@ public class MLFeatureEnabledSetting {
private volatile Boolean isRemoteInferenceEnabled;
private volatile Boolean isAgentFrameworkEnabled;

-private volatile Boolean isLocalModelInferenceEnabled;
+private volatile Boolean isLocalModelEnabled;

public MLFeatureEnabledSetting(ClusterService clusterService, Settings settings) {
isRemoteInferenceEnabled = ML_COMMONS_REMOTE_INFERENCE_ENABLED.get(settings);
isAgentFrameworkEnabled = ML_COMMONS_AGENT_FRAMEWORK_ENABLED.get(settings);
-isLocalModelInferenceEnabled = ML_COMMONS_LOCAL_MODEL_INFERENCE_ENABLED.get(settings);
+isLocalModelEnabled = ML_COMMONS_LOCAL_MODEL_ENABLED.get(settings);

clusterService
.getClusterSettings()
.addSettingsUpdateConsumer(ML_COMMONS_REMOTE_INFERENCE_ENABLED, it -> isRemoteInferenceEnabled = it);
clusterService
.getClusterSettings()
.addSettingsUpdateConsumer(ML_COMMONS_AGENT_FRAMEWORK_ENABLED, it -> isAgentFrameworkEnabled = it);
-clusterService
-.getClusterSettings()
-.addSettingsUpdateConsumer(ML_COMMONS_LOCAL_MODEL_INFERENCE_ENABLED, it -> isLocalModelInferenceEnabled = it);
+clusterService.getClusterSettings().addSettingsUpdateConsumer(ML_COMMONS_LOCAL_MODEL_ENABLED, it -> isLocalModelEnabled = it);
}

/**
@@ -57,8 +55,8 @@ public boolean isAgentFrameworkEnabled() {
* Whether the local model feature is enabled. If disabled, APIs in ml-commons will block local model inference.
* @return whether the local inference is enabled.
*/
-public boolean isLocalModelInferenceEnabled() {
-return isLocalModelInferenceEnabled;
+public boolean isLocalModelEnabled() {
+return isLocalModelEnabled;
}

}
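The transport and REST hunks in this commit all repeat the same check against the renamed accessor. Below is a minimal consolidated sketch of that guard; the helper class LocalModelGuard is hypothetical, the FunctionName import path is assumed, and the message string mirrors the LOCAL_MODEL_DISABLED_ERR_MSG constant updated later in this diff.

```java
import org.opensearch.ml.common.FunctionName;          // assumed package path
import org.opensearch.ml.settings.MLFeatureEnabledSetting;

// Hypothetical helper; the commit keeps this check inlined in each action instead.
final class LocalModelGuard {

    // Mirrors MLExceptionUtils.LOCAL_MODEL_DISABLED_ERR_MSG as changed in this commit.
    static final String LOCAL_MODEL_DISABLED_ERR_MSG =
        "Local Model is currently disabled. To enable it, update the setting \"plugins.ml_commons.local_model.enabled\" to true.";

    private LocalModelGuard() {}

    // Rejects local (DL) model requests while the feature flag is off.
    static void checkLocalModelEnabled(FunctionName functionName, MLFeatureEnabledSetting setting) {
        if (FunctionName.isDLModel(functionName) && !setting.isLocalModelEnabled()) {
            throw new IllegalStateException(LOCAL_MODEL_DISABLED_ERR_MSG);
        }
    }
}
```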
@@ -21,7 +21,7 @@ public class MLExceptionUtils {
public static final String REMOTE_INFERENCE_DISABLED_ERR_MSG =
"Remote Inference is currently disabled. To enable it, update the setting \"plugins.ml_commons.remote_inference_enabled\" to true.";
public static final String LOCAL_MODEL_DISABLED_ERR_MSG =
-"Local Model Inference is currently disabled. To enable it, update the setting \"plugins.ml_commons.local_model_inference_enabled\" to true.";
+"Local Model is currently disabled. To enable it, update the setting \"plugins.ml_commons.local_model.enabled\" to true.";
public static final String AGENT_FRAMEWORK_DISABLED_ERR_MSG =
"Agent Framework is currently disabled. To enable it, update the setting \"plugins.ml_commons.agent_framework_enabled\" to true.";

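The updated message points operators at the new key. As a hedged sketch (not from this commit), re-enabling local models at runtime from Java client code could look roughly like the following; the class and method names are illustrative, and a standard node Client is assumed to be available.

```java
import org.opensearch.action.admin.cluster.settings.ClusterUpdateSettingsRequest;
import org.opensearch.client.Client;
import org.opensearch.common.settings.Settings;

// Illustrative sketch: turn the renamed dynamic setting back on via the cluster settings API.
public class EnableLocalModel {

    public static void enable(Client client) {
        ClusterUpdateSettingsRequest request = new ClusterUpdateSettingsRequest()
            .persistentSettings(Settings.builder().put("plugins.ml_commons.local_model.enabled", true));
        client.admin().cluster().updateSettings(request).actionGet();
    }
}
```

The equivalent REST call is a PUT to _cluster/settings with the same key under the persistent (or transient) section.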
@@ -177,7 +177,7 @@ public void setup() {
when(mlDeployModelRequest.isUserInitiatedDeployRequest()).thenReturn(true);

when(mlFeatureEnabledSetting.isRemoteInferenceEnabled()).thenReturn(true);
-when(mlFeatureEnabledSetting.isLocalModelInferenceEnabled()).thenReturn(true);
+when(mlFeatureEnabledSetting.isLocalModelEnabled()).thenReturn(true);

MLStat mlStat = mock(MLStat.class);
when(mlStats.getStat(eq(MLNodeLevelStat.ML_REQUEST_COUNT))).thenReturn(mlStat);
@@ -385,7 +385,7 @@ public void testDoExecuteLocalInferenceDisabled() {
return null;
}).when(mlModelManager).getModel(anyString(), isNull(), any(String[].class), Mockito.isA(ActionListener.class));

-when(mlFeatureEnabledSetting.isLocalModelInferenceEnabled()).thenReturn(false);
+when(mlFeatureEnabledSetting.isLocalModelEnabled()).thenReturn(false);
ActionListener<MLDeployModelResponse> deployModelResponseListener = mock(ActionListener.class);
transportDeployModelAction.doExecute(mock(Task.class), mlDeployModelRequest, deployModelResponseListener);
ArgumentCaptor<Exception> argumentCaptor = ArgumentCaptor.forClass(IllegalStateException.class);
@@ -133,7 +133,7 @@ public void setup() {
when(client.threadPool()).thenReturn(threadPool);
when(threadPool.getThreadContext()).thenReturn(threadContext);
when(clusterService.getClusterSettings()).thenReturn(clusterSettings);
-when(mlFeatureEnabledSetting.isLocalModelInferenceEnabled()).thenReturn(true);
+when(mlFeatureEnabledSetting.isLocalModelEnabled()).thenReturn(true);

transportPredictionTaskAction = spy(
new TransportPredictionTaskAction(
@@ -177,7 +177,7 @@ public void testPrediction_default_exception() {
public void testPrediction_local_model_not_exception() {
when(modelCacheHelper.getModelInfo(anyString())).thenReturn(model);
when(model.getAlgorithm()).thenReturn(FunctionName.TEXT_EMBEDDING);
-when(mlFeatureEnabledSetting.isLocalModelInferenceEnabled()).thenReturn(false);
+when(mlFeatureEnabledSetting.isLocalModelEnabled()).thenReturn(false);

IllegalStateException e = assertThrows(
IllegalStateException.class,
@@ -222,7 +222,7 @@ public void setup() throws IOException {
return null;
}).when(mlModelGroupManager).validateUniqueModelGroupName(any(), any());

-when(mlFeatureEnabledSetting.isLocalModelInferenceEnabled()).thenReturn(true);
+when(mlFeatureEnabledSetting.isLocalModelEnabled()).thenReturn(true);

when(clusterService.localNode()).thenReturn(node2);
when(node2.getId()).thenReturn("node2Id");
@@ -235,7 +235,7 @@ public void setup() throws IOException {
}

public void testDoExecute_LocalModelDisabledException() {
-when(mlFeatureEnabledSetting.isLocalModelInferenceEnabled()).thenReturn(false);
+when(mlFeatureEnabledSetting.isLocalModelEnabled()).thenReturn(false);

MLRegisterModelInput registerModelInput = MLRegisterModelInput
.builder()
@@ -68,7 +68,7 @@ public void setup() {
MockitoAnnotations.openMocks(this);
when(modelManager.getOptionalModelFunctionName(anyString())).thenReturn(Optional.empty());
when(mlFeatureEnabledSetting.isRemoteInferenceEnabled()).thenReturn(true);
-when(mlFeatureEnabledSetting.isLocalModelInferenceEnabled()).thenReturn(true);
+when(mlFeatureEnabledSetting.isLocalModelEnabled()).thenReturn(true);
restMLPredictionAction = new RestMLPredictionAction(modelManager, mlFeatureEnabledSetting);

threadPool = new TestThreadPool(this.getClass().getSimpleName() + "ThreadPool");
@@ -128,7 +128,7 @@ public void testGetRequest_LocalModelInferenceDisabled() throws IOException {
thrown.expect(IllegalStateException.class);
thrown.expectMessage(LOCAL_MODEL_DISABLED_ERR_MSG);

-when(mlFeatureEnabledSetting.isLocalModelInferenceEnabled()).thenReturn(false);
+when(mlFeatureEnabledSetting.isLocalModelEnabled()).thenReturn(false);
RestRequest request = getRestRequest_PredictModel();
MLPredictionTaskRequest mlPredictionTaskRequest = restMLPredictionAction
.getRequest("modelId", FunctionName.TEXT_EMBEDDING.name(), request);
@@ -80,7 +80,7 @@ public void setup() {
when(clusterService.getClusterSettings()).thenReturn(clusterSettings);
when(clusterService.getSettings()).thenReturn(settings);
when(mlFeatureEnabledSetting.isRemoteInferenceEnabled()).thenReturn(true);
-when(mlFeatureEnabledSetting.isLocalModelInferenceEnabled()).thenReturn(true);
+when(mlFeatureEnabledSetting.isLocalModelEnabled()).thenReturn(true);
restMLRegisterModelAction = new RestMLRegisterModelAction(clusterService, settings, mlFeatureEnabledSetting);
threadPool = new TestThreadPool(this.getClass().getSimpleName() + "ThreadPool");
client = spy(new NodeClient(Settings.EMPTY, threadPool));
@@ -169,7 +169,7 @@ public void testRegisterModelRequestLocalInferenceDisabled() throws Exception {
exceptionRule.expect(IllegalStateException.class);
exceptionRule.expectMessage(LOCAL_MODEL_DISABLED_ERR_MSG);

-when(mlFeatureEnabledSetting.isLocalModelInferenceEnabled()).thenReturn(false);
+when(mlFeatureEnabledSetting.isLocalModelEnabled()).thenReturn(false);
RestRequest request = getRestRequestWithNullModelId(FunctionName.TEXT_EMBEDDING);
restMLRegisterModelAction.handleRequest(request, channel, client);
}