Skip to content

Commit

Permalink
agent framework disable/enable flag (#1949)
Browse files Browse the repository at this point in the history
Signed-off-by: Jing Zhang <[email protected]>
(cherry picked from commit f70b433)
  • Loading branch information
jngz-es authored and github-actions[bot] committed Feb 2, 2024
1 parent 939bbee commit d098f78
Show file tree
Hide file tree
Showing 14 changed files with 199 additions and 21 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -663,7 +663,7 @@ public List<RestHandler> getRestHandlers(
RestMLTrainingAction restMLTrainingAction = new RestMLTrainingAction();
RestMLTrainAndPredictAction restMLTrainAndPredictAction = new RestMLTrainAndPredictAction();
RestMLPredictionAction restMLPredictionAction = new RestMLPredictionAction(mlModelManager, mlFeatureEnabledSetting);
RestMLExecuteAction restMLExecuteAction = new RestMLExecuteAction();
RestMLExecuteAction restMLExecuteAction = new RestMLExecuteAction(mlFeatureEnabledSetting);
RestMLGetModelAction restMLGetModelAction = new RestMLGetModelAction();
RestMLDeleteModelAction restMLDeleteModelAction = new RestMLDeleteModelAction();
RestMLSearchModelAction restMLSearchModelAction = new RestMLSearchModelAction();
Expand All @@ -676,7 +676,7 @@ public List<RestHandler> getRestHandlers(
settings,
mlFeatureEnabledSetting
);
RestMLRegisterAgentAction restMLRegisterAgentAction = new RestMLRegisterAgentAction();
RestMLRegisterAgentAction restMLRegisterAgentAction = new RestMLRegisterAgentAction(mlFeatureEnabledSetting);
RestMLDeployModelAction restMLDeployModelAction = new RestMLDeployModelAction();
RestMLUndeployModelAction restMLUndeployModelAction = new RestMLUndeployModelAction(clusterService, settings);
RestMLRegisterModelMetaAction restMLRegisterModelMetaAction = new RestMLRegisterModelMetaAction(clusterService, settings);
Expand Down Expand Up @@ -705,12 +705,12 @@ public List<RestHandler> getRestHandlers(
RestMLGetControllerAction restMLGetControllerAction = new RestMLGetControllerAction();
RestMLUpdateControllerAction restMLUpdateControllerAction = new RestMLUpdateControllerAction();
RestMLDeleteControllerAction restMLDeleteControllerAction = new RestMLDeleteControllerAction();
RestMLGetAgentAction restMLGetAgentAction = new RestMLGetAgentAction();
RestMLDeleteAgentAction restMLDeleteAgentAction = new RestMLDeleteAgentAction();
RestMLGetAgentAction restMLGetAgentAction = new RestMLGetAgentAction(mlFeatureEnabledSetting);
RestMLDeleteAgentAction restMLDeleteAgentAction = new RestMLDeleteAgentAction(mlFeatureEnabledSetting);
RestMemoryUpdateConversationAction restMemoryUpdateConversationAction = new RestMemoryUpdateConversationAction();
RestMemoryUpdateInteractionAction restMemoryUpdateInteractionAction = new RestMemoryUpdateInteractionAction();
RestMemoryGetTracesAction restMemoryGetTracesAction = new RestMemoryGetTracesAction();
RestMLSearchAgentAction restMLSearchAgentAction = new RestMLSearchAgentAction();
RestMLSearchAgentAction restMLSearchAgentAction = new RestMLSearchAgentAction(mlFeatureEnabledSetting);
RestMLListToolsAction restMLListToolsAction = new RestMLListToolsAction(toolFactories);
RestMLGetToolAction restMLGetToolAction = new RestMLGetToolAction(toolFactories);
return ImmutableList
Expand Down Expand Up @@ -873,7 +873,8 @@ public List<Setting<?>> getSettings() {
MLCommonsSettings.ML_COMMONS_LOCAL_MODEL_ELIGIBLE_NODE_ROLES,
MLCommonsSettings.ML_COMMONS_REMOTE_INFERENCE_ENABLED,
MLCommonsSettings.ML_COMMONS_MEMORY_FEATURE_ENABLED,
MLCommonsSettings.ML_COMMONS_RAG_PIPELINE_FEATURE_ENABLED
MLCommonsSettings.ML_COMMONS_RAG_PIPELINE_FEATURE_ENABLED,
MLCommonsSettings.ML_COMMONS_AGENT_FRAMEWORK_ENABLED
);
return settings;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
package org.opensearch.ml.rest;

import static org.opensearch.ml.plugin.MachineLearningPlugin.ML_BASE_URI;
import static org.opensearch.ml.utils.MLExceptionUtils.AGENT_FRAMEWORK_DISABLED_ERR_MSG;
import static org.opensearch.ml.utils.RestActionUtils.PARAMETER_AGENT_ID;

import java.io.IOException;
Expand All @@ -15,6 +16,7 @@
import org.opensearch.client.node.NodeClient;
import org.opensearch.ml.common.transport.agent.MLAgentDeleteAction;
import org.opensearch.ml.common.transport.agent.MLAgentDeleteRequest;
import org.opensearch.ml.settings.MLFeatureEnabledSetting;
import org.opensearch.rest.BaseRestHandler;
import org.opensearch.rest.RestRequest;
import org.opensearch.rest.action.RestToXContentListener;
Expand All @@ -26,8 +28,11 @@
*/
public class RestMLDeleteAgentAction extends BaseRestHandler {
private static final String ML_DELETE_AGENT_ACTION = "ml_delete_agent_action";
private final MLFeatureEnabledSetting mlFeatureEnabledSetting;

public void RestMLDeleteAgentAction() {}
public RestMLDeleteAgentAction(MLFeatureEnabledSetting mlFeatureEnabledSetting) {
this.mlFeatureEnabledSetting = mlFeatureEnabledSetting;
}

@Override
public String getName() {
Expand All @@ -42,6 +47,9 @@ public List<Route> routes() {

@Override
protected RestChannelConsumer prepareRequest(RestRequest request, NodeClient client) throws IOException {
if (!mlFeatureEnabledSetting.isAgentFrameworkEnabled()) {
throw new IllegalStateException(AGENT_FRAMEWORK_DISABLED_ERR_MSG);
}
String agentId = request.param(PARAMETER_AGENT_ID);

MLAgentDeleteRequest mlAgentDeleteRequest = new MLAgentDeleteRequest(agentId);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken;
import static org.opensearch.ml.plugin.MachineLearningPlugin.ML_BASE_URI;
import static org.opensearch.ml.utils.MLExceptionUtils.AGENT_FRAMEWORK_DISABLED_ERR_MSG;
import static org.opensearch.ml.utils.RestActionUtils.PARAMETER_AGENT_ID;
import static org.opensearch.ml.utils.RestActionUtils.PARAMETER_ALGORITHM;
import static org.opensearch.ml.utils.RestActionUtils.getAlgorithm;
Expand All @@ -25,6 +26,7 @@
import org.opensearch.ml.common.transport.execute.MLExecuteTaskRequest;
import org.opensearch.ml.repackage.com.google.common.annotations.VisibleForTesting;
import org.opensearch.ml.repackage.com.google.common.collect.ImmutableList;
import org.opensearch.ml.settings.MLFeatureEnabledSetting;
import org.opensearch.rest.BaseRestHandler;
import org.opensearch.rest.RestRequest;
import org.opensearch.rest.action.RestToXContentListener;
Expand All @@ -34,11 +36,14 @@
@Log4j2
public class RestMLExecuteAction extends BaseRestHandler {
private static final String ML_EXECUTE_ACTION = "ml_execute_action";
private final MLFeatureEnabledSetting mlFeatureEnabledSetting;

/**
* Constructor
*/
public RestMLExecuteAction() {}
public RestMLExecuteAction(MLFeatureEnabledSetting mlFeatureEnabledSetting) {
this.mlFeatureEnabledSetting = mlFeatureEnabledSetting;
}

@Override
public String getName() {
Expand Down Expand Up @@ -75,6 +80,9 @@ MLExecuteTaskRequest getRequest(RestRequest request) throws IOException {
FunctionName functionName = null;
Input input = null;
if (uri.startsWith(ML_BASE_URI + "/agents/")) {
if (!mlFeatureEnabledSetting.isAgentFrameworkEnabled()) {
throw new IllegalStateException(AGENT_FRAMEWORK_DISABLED_ERR_MSG);
}
String agentId = request.param(PARAMETER_AGENT_ID);
functionName = FunctionName.AGENT;
input = MLInput.parse(parser, functionName.name());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
package org.opensearch.ml.rest;

import static org.opensearch.ml.plugin.MachineLearningPlugin.ML_BASE_URI;
import static org.opensearch.ml.utils.MLExceptionUtils.AGENT_FRAMEWORK_DISABLED_ERR_MSG;
import static org.opensearch.ml.utils.RestActionUtils.PARAMETER_AGENT_ID;
import static org.opensearch.ml.utils.RestActionUtils.getParameterId;

Expand All @@ -16,6 +17,7 @@
import org.opensearch.client.node.NodeClient;
import org.opensearch.ml.common.transport.agent.MLAgentGetAction;
import org.opensearch.ml.common.transport.agent.MLAgentGetRequest;
import org.opensearch.ml.settings.MLFeatureEnabledSetting;
import org.opensearch.rest.BaseRestHandler;
import org.opensearch.rest.RestRequest;
import org.opensearch.rest.action.RestToXContentListener;
Expand All @@ -25,11 +27,14 @@

public class RestMLGetAgentAction extends BaseRestHandler {
private static final String ML_GET_Agent_ACTION = "ml_get_agent_action";
private final MLFeatureEnabledSetting mlFeatureEnabledSetting;

/**
* Constructor
*/
public RestMLGetAgentAction() {}
public RestMLGetAgentAction(MLFeatureEnabledSetting mlFeatureEnabledSetting) {
this.mlFeatureEnabledSetting = mlFeatureEnabledSetting;
}

@Override
public String getName() {
Expand All @@ -56,6 +61,9 @@ public RestChannelConsumer prepareRequest(RestRequest request, NodeClient client
*/
@VisibleForTesting
MLAgentGetRequest getRequest(RestRequest request) throws IOException {
if (!mlFeatureEnabledSetting.isAgentFrameworkEnabled()) {
throw new IllegalStateException(AGENT_FRAMEWORK_DISABLED_ERR_MSG);
}
String agentId = getParameterId(request, PARAMETER_AGENT_ID);

return new MLAgentGetRequest(agentId);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken;
import static org.opensearch.ml.plugin.MachineLearningPlugin.ML_BASE_URI;
import static org.opensearch.ml.utils.MLExceptionUtils.AGENT_FRAMEWORK_DISABLED_ERR_MSG;

import java.io.IOException;
import java.util.List;
Expand All @@ -17,6 +18,7 @@
import org.opensearch.ml.common.agent.MLAgent;
import org.opensearch.ml.common.transport.agent.MLRegisterAgentAction;
import org.opensearch.ml.common.transport.agent.MLRegisterAgentRequest;
import org.opensearch.ml.settings.MLFeatureEnabledSetting;
import org.opensearch.rest.BaseRestHandler;
import org.opensearch.rest.RestRequest;
import org.opensearch.rest.action.RestToXContentListener;
Expand All @@ -26,11 +28,14 @@

public class RestMLRegisterAgentAction extends BaseRestHandler {
private static final String ML_REGISTER_AGENT_ACTION = "ml_register_agent_action";
private final MLFeatureEnabledSetting mlFeatureEnabledSetting;

/**
* Constructor
*/
public RestMLRegisterAgentAction() {}
public RestMLRegisterAgentAction(MLFeatureEnabledSetting mlFeatureEnabledSetting) {
this.mlFeatureEnabledSetting = mlFeatureEnabledSetting;
}

@Override
public String getName() {
Expand All @@ -56,6 +61,9 @@ public RestChannelConsumer prepareRequest(RestRequest request, NodeClient client
*/
@VisibleForTesting
MLRegisterAgentRequest getRequest(RestRequest request) throws IOException {
if (!mlFeatureEnabledSetting.isAgentFrameworkEnabled()) {
throw new IllegalStateException(AGENT_FRAMEWORK_DISABLED_ERR_MSG);
}
XContentParser parser = request.contentParser();
ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser);
MLAgent mlAgent = MLAgent.parse(parser);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,24 +7,41 @@

import static org.opensearch.ml.common.CommonValue.ML_AGENT_INDEX;
import static org.opensearch.ml.plugin.MachineLearningPlugin.ML_BASE_URI;
import static org.opensearch.ml.utils.MLExceptionUtils.AGENT_FRAMEWORK_DISABLED_ERR_MSG;

import java.io.IOException;

import org.opensearch.client.node.NodeClient;
import org.opensearch.ml.common.agent.MLAgent;
import org.opensearch.ml.common.transport.agent.MLSearchAgentAction;
import org.opensearch.ml.repackage.com.google.common.collect.ImmutableList;
import org.opensearch.ml.settings.MLFeatureEnabledSetting;
import org.opensearch.rest.RestRequest;

/**
* This class consists of the REST handler to search ML Agents.
*/
public class RestMLSearchAgentAction extends AbstractMLSearchAction<MLAgent> {
private static final String ML_SEARCH_AGENT_ACTION = "ml_search_agent_action";
private static final String SEARCH_AGENT_PATH = ML_BASE_URI + "/agents/_search";
private final MLFeatureEnabledSetting mlFeatureEnabledSetting;

public RestMLSearchAgentAction() {
public RestMLSearchAgentAction(MLFeatureEnabledSetting mlFeatureEnabledSetting) {
super(ImmutableList.of(SEARCH_AGENT_PATH), ML_AGENT_INDEX, MLAgent.class, MLSearchAgentAction.INSTANCE);
this.mlFeatureEnabledSetting = mlFeatureEnabledSetting;
}

@Override
public String getName() {
return ML_SEARCH_AGENT_ACTION;
}

@Override
protected RestChannelConsumer prepareRequest(RestRequest request, NodeClient client) throws IOException {
if (!mlFeatureEnabledSetting.isAgentFrameworkEnabled()) {
throw new IllegalStateException(AGENT_FRAMEWORK_DISABLED_ERR_MSG);
}

return super.prepareRequest(request, client);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -174,4 +174,8 @@ private MLCommonsSettings() {}
// Feature flag for enabling search processors for Retrieval Augmented Generation using OpenSearch and Remote Inference.
public static final Setting<Boolean> ML_COMMONS_RAG_PIPELINE_FEATURE_ENABLED =
GenerativeQAProcessorConstants.RAG_PIPELINE_FEATURE_ENABLED;

// This setting is to enable/disable agent related API register/execute/delete/get/search agent.
public static final Setting<Boolean> ML_COMMONS_AGENT_FRAMEWORK_ENABLED = Setting
.boolSetting("plugins.ml_commons.agent_framework_enabled", false, Setting.Property.NodeScope, Setting.Property.Dynamic);
}
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

package org.opensearch.ml.settings;

import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_AGENT_FRAMEWORK_ENABLED;
import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_REMOTE_INFERENCE_ENABLED;

import org.opensearch.cluster.service.ClusterService;
Expand All @@ -15,13 +16,18 @@
public class MLFeatureEnabledSetting {

private volatile Boolean isRemoteInferenceEnabled;
private volatile Boolean isAgentFrameworkEnabled;

public MLFeatureEnabledSetting(ClusterService clusterService, Settings settings) {
isRemoteInferenceEnabled = ML_COMMONS_REMOTE_INFERENCE_ENABLED.get(settings);
isAgentFrameworkEnabled = ML_COMMONS_AGENT_FRAMEWORK_ENABLED.get(settings);

clusterService
.getClusterSettings()
.addSettingsUpdateConsumer(ML_COMMONS_REMOTE_INFERENCE_ENABLED, it -> isRemoteInferenceEnabled = it);
clusterService
.getClusterSettings()
.addSettingsUpdateConsumer(ML_COMMONS_AGENT_FRAMEWORK_ENABLED, it -> isAgentFrameworkEnabled = it);
}

/**
Expand All @@ -32,4 +38,12 @@ public boolean isRemoteInferenceEnabled() {
return isRemoteInferenceEnabled;
}

/**
* Whether the agent framework feature is enabled. If disabled, APIs in ml-commons will block agent framework.
* @return whether the agent framework is enabled.
*/
public boolean isAgentFrameworkEnabled() {
return isAgentFrameworkEnabled;
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@ public class MLExceptionUtils {
public static final String NOT_SERIALIZABLE_EXCEPTION_WRAPPER = "NotSerializableExceptionWrapper: ";
public static final String REMOTE_INFERENCE_DISABLED_ERR_MSG =
"Remote Inference is currently disabled. To enable it, update the setting \"plugins.ml_commons.remote_inference_enabled\" to true.";
public static final String AGENT_FRAMEWORK_DISABLED_ERR_MSG =
"Agent Framework is currently disabled. To enable it, update the setting \"plugins.ml_commons.agent_framework_enabled\" to true.";

public static String getRootCauseMessage(final Throwable throwable) {
String message = ExceptionUtils.getRootCauseMessage(throwable);
Expand Down
Loading

0 comments on commit d098f78

Please sign in to comment.