From 76f0f3bc9a08516adefa18feb146a3a69e93aace Mon Sep 17 00:00:00 2001 From: "opensearch-trigger-bot[bot]" <98922864+opensearch-trigger-bot[bot]@users.noreply.github.com> Date: Fri, 24 Jan 2025 09:37:20 -0800 Subject: [PATCH] adding tenantID to the request + undeploy request (#3425) (#3429) Signed-off-by: Dhrubo Saha (cherry picked from commit af96fe0a7a2f8e0571b0ee98bbd4782a1cecfb69) Co-authored-by: Dhrubo Saha --- .../ml/client/MachineLearningClient.java | 28 ++- .../ml/client/MachineLearningNodeClient.java | 7 +- .../ml/client/MachineLearningClientTest.java | 12 +- .../undeploy/MLUndeployModelNodesRequest.java | 13 ++ .../undeploy/MLUndeployModelsRequest.java | 27 ++- .../undeploy/MLUndeployModelsRequestTest.java | 160 ++++++++++++++++++ .../ml/engine/indices/MLIndicesHandler.java | 2 +- .../engine/indices/MLInputDatasetHandler.java | 2 +- .../TransportPredictionTaskAction.java | 150 ++++++++-------- .../TransportRegisterModelAction.java | 2 + .../tasks/DeleteTaskTransportAction.java | 9 +- .../TransportUndeployModelsAction.java | 85 ++++++++-- .../MLCommonsClusterManagerEventListener.java | 12 +- .../opensearch/ml/cluster/MLSyncUpCron.java | 31 ++-- .../ml/plugin/MachineLearningPlugin.java | 10 +- .../ml/rest/RestMLUndeployModelAction.java | 9 +- .../TransportPredictionTaskActionTests.java | 30 +++- .../TransportUndeployModelsActionTests.java | 60 ++++--- .../ml/cluster/MLSyncUpCronTests.java | 12 +- .../rest/RestMLUndeployModelActionTests.java | 8 +- 20 files changed, 517 insertions(+), 152 deletions(-) create mode 100644 common/src/test/java/org/opensearch/ml/common/transport/undeploy/MLUndeployModelsRequestTest.java diff --git a/client/src/main/java/org/opensearch/ml/client/MachineLearningClient.java b/client/src/main/java/org/opensearch/ml/client/MachineLearningClient.java index 32f8fe02fb..6a58cb04f4 100644 --- a/client/src/main/java/org/opensearch/ml/client/MachineLearningClient.java +++ b/client/src/main/java/org/opensearch/ml/client/MachineLearningClient.java @@ -60,7 +60,19 @@ default ActionFuture predict(String modelId, MLInput mlInput) { * @param mlInput ML input * @param listener a listener to be notified of the result */ - void predict(String modelId, MLInput mlInput, ActionListener listener); + default void predict(String modelId, MLInput mlInput, ActionListener listener) { + predict(modelId, null, mlInput, listener); + } + + /** + * Do prediction machine learning job + * For additional info on Predict, refer: https://opensearch.org/docs/latest/ml-commons-plugin/api/#predict + * @param modelId the trained model id + * @param tenantId tenant id + * @param mlInput ML input + * @param listener a listener to be notified of the result + */ + void predict(String modelId, String tenantId, MLInput mlInput, ActionListener listener); /** * Train model then predict with the same data set. @@ -352,7 +364,19 @@ default ActionFuture undeploy(String[] modelIds, @Null * @param modelIds the node ids. May be null for all nodes. * @param listener a listener to be notified of the result */ - void undeploy(String[] modelIds, String[] nodeIds, ActionListener listener); + default void undeploy(String[] modelIds, String[] nodeIds, ActionListener listener) { + undeploy(modelIds, nodeIds, null, listener); + } + + /** + * Undeploy model + * For additional info on deploy, refer: https://opensearch.org/docs/latest/ml-commons-plugin/api/model-apis/undeploy-model/ + * @param modelIds the model ids + * @param modelIds the node ids. May be null for all nodes. + * @param tenantId the tenant id. This is necessary for multi-tenancy. + * @param listener a listener to be notified of the result + */ + void undeploy(String[] modelIds, String[] nodeIds, String tenantId, ActionListener listener); /** * Create connector for remote model diff --git a/client/src/main/java/org/opensearch/ml/client/MachineLearningNodeClient.java b/client/src/main/java/org/opensearch/ml/client/MachineLearningNodeClient.java index 68bb5d2875..e86cf5acae 100644 --- a/client/src/main/java/org/opensearch/ml/client/MachineLearningNodeClient.java +++ b/client/src/main/java/org/opensearch/ml/client/MachineLearningNodeClient.java @@ -101,7 +101,7 @@ public class MachineLearningNodeClient implements MachineLearningClient { Client client; @Override - public void predict(String modelId, MLInput mlInput, ActionListener listener) { + public void predict(String modelId, String tenantId, MLInput mlInput, ActionListener listener) { validateMLInput(mlInput, true); MLPredictionTaskRequest predictionRequest = MLPredictionTaskRequest @@ -109,6 +109,7 @@ public void predict(String modelId, MLInput mlInput, ActionListener li .mlInput(mlInput) .modelId(modelId) .dispatchTask(true) + .tenantId(tenantId) .build(); client.execute(MLPredictionTaskAction.INSTANCE, predictionRequest, getMlPredictionTaskResponseActionListener(listener)); } @@ -262,8 +263,8 @@ public void deploy(String modelId, String tenantId, ActionListener listener) { - MLUndeployModelsRequest undeployModelRequest = new MLUndeployModelsRequest(modelIds, nodeIds); + public void undeploy(String[] modelIds, String[] nodeIds, String tenantId, ActionListener listener) { + MLUndeployModelsRequest undeployModelRequest = new MLUndeployModelsRequest(modelIds, nodeIds, tenantId); client.execute(MLUndeployModelsAction.INSTANCE, undeployModelRequest, getMlUndeployModelsResponseActionListener(listener)); } diff --git a/client/src/test/java/org/opensearch/ml/client/MachineLearningClientTest.java b/client/src/test/java/org/opensearch/ml/client/MachineLearningClientTest.java index 74bc74a58c..d09cc7a95b 100644 --- a/client/src/test/java/org/opensearch/ml/client/MachineLearningClientTest.java +++ b/client/src/test/java/org/opensearch/ml/client/MachineLearningClientTest.java @@ -149,6 +149,11 @@ public void predict(String modelId, MLInput mlInput, ActionListener li listener.onResponse(output); } + @Override + public void predict(String modelId, String tenantId, MLInput mlInput, ActionListener listener) { + listener.onResponse(output); + } + @Override public void trainAndPredict(MLInput mlInput, ActionListener listener) { listener.onResponse(output); @@ -234,6 +239,11 @@ public void undeploy(String[] modelIds, String[] nodeIds, ActionListener listener) { + listener.onResponse(undeployModelsResponse); + } + @Override public void createConnector(MLCreateConnectorInput mlCreateConnectorInput, ActionListener listener) { listener.onResponse(createConnectorResponse); @@ -320,7 +330,7 @@ public void predict_WithAlgoAndParametersAndInputDataAndModelId() { public void predict_WithAlgoAndInputDataAndListener() { MLInput mlInput = MLInput.builder().algorithm(FunctionName.KMEANS).inputDataset(new DataFrameInputDataset(input)).build(); ArgumentCaptor dataFrameArgumentCaptor = ArgumentCaptor.forClass(MLOutput.class); - machineLearningClient.predict(null, mlInput, dataFrameActionListener); + machineLearningClient.predict(null, null, mlInput, dataFrameActionListener); verify(dataFrameActionListener).onResponse(dataFrameArgumentCaptor.capture()); assertEquals(output, dataFrameArgumentCaptor.getValue()); } diff --git a/common/src/main/java/org/opensearch/ml/common/transport/undeploy/MLUndeployModelNodesRequest.java b/common/src/main/java/org/opensearch/ml/common/transport/undeploy/MLUndeployModelNodesRequest.java index 48b2bf7c5c..4ef4084a27 100644 --- a/common/src/main/java/org/opensearch/ml/common/transport/undeploy/MLUndeployModelNodesRequest.java +++ b/common/src/main/java/org/opensearch/ml/common/transport/undeploy/MLUndeployModelNodesRequest.java @@ -5,23 +5,32 @@ package org.opensearch.ml.common.transport.undeploy; +import static org.opensearch.ml.common.CommonValue.VERSION_2_19_0; + import java.io.IOException; +import org.opensearch.Version; import org.opensearch.action.support.nodes.BaseNodesRequest; import org.opensearch.cluster.node.DiscoveryNode; import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.common.io.stream.StreamOutput; import lombok.Getter; +import lombok.Setter; public class MLUndeployModelNodesRequest extends BaseNodesRequest { @Getter private String[] modelIds; + @Getter + @Setter + private String tenantId; public MLUndeployModelNodesRequest(StreamInput in) throws IOException { super(in); + Version streamInputVersion = in.getVersion(); this.modelIds = in.readOptionalStringArray(); + this.tenantId = streamInputVersion.onOrAfter(VERSION_2_19_0) ? in.readOptionalString() : null; } public MLUndeployModelNodesRequest(String[] nodeIds, String[] modelIds) { @@ -36,7 +45,11 @@ public MLUndeployModelNodesRequest(DiscoveryNode... nodes) { @Override public void writeTo(StreamOutput out) throws IOException { super.writeTo(out); + Version streamOutputVersion = out.getVersion(); out.writeOptionalStringArray(modelIds); + if (streamOutputVersion.onOrAfter(VERSION_2_19_0)) { + out.writeOptionalString(tenantId); + } } } diff --git a/common/src/main/java/org/opensearch/ml/common/transport/undeploy/MLUndeployModelsRequest.java b/common/src/main/java/org/opensearch/ml/common/transport/undeploy/MLUndeployModelsRequest.java index 32fdfced27..65de301edd 100644 --- a/common/src/main/java/org/opensearch/ml/common/transport/undeploy/MLUndeployModelsRequest.java +++ b/common/src/main/java/org/opensearch/ml/common/transport/undeploy/MLUndeployModelsRequest.java @@ -6,6 +6,8 @@ package org.opensearch.ml.common.transport.undeploy; import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; +import static org.opensearch.ml.common.CommonValue.TENANT_ID_FIELD; +import static org.opensearch.ml.common.CommonValue.VERSION_2_19_0; import java.io.ByteArrayInputStream; import java.io.ByteArrayOutputStream; @@ -14,6 +16,7 @@ import java.util.ArrayList; import java.util.List; +import org.opensearch.Version; import org.opensearch.action.ActionRequest; import org.opensearch.action.ActionRequestValidationException; import org.opensearch.core.common.io.stream.InputStreamStreamInput; @@ -39,24 +42,28 @@ public class MLUndeployModelsRequest extends MLTaskRequest { private String[] modelIds; private String[] nodeIds; boolean async; + private String tenantId; @Builder - public MLUndeployModelsRequest(String[] modelIds, String[] nodeIds, boolean async, boolean dispatchTask) { + public MLUndeployModelsRequest(String[] modelIds, String[] nodeIds, boolean async, boolean dispatchTask, String tenantId) { super(dispatchTask); this.modelIds = modelIds; this.nodeIds = nodeIds; this.async = async; + this.tenantId = tenantId; } - public MLUndeployModelsRequest(String[] modelIds, String[] nodeIds) { - this(modelIds, nodeIds, false, false); + public MLUndeployModelsRequest(String[] modelIds, String[] nodeIds, String tenantId) { + this(modelIds, nodeIds, false, false, tenantId); } public MLUndeployModelsRequest(StreamInput in) throws IOException { super(in); + Version streamInputVersion = in.getVersion(); this.modelIds = in.readOptionalStringArray(); this.nodeIds = in.readOptionalStringArray(); this.async = in.readBoolean(); + this.tenantId = streamInputVersion.onOrAfter(VERSION_2_19_0) ? in.readOptionalString() : null; } @Override @@ -68,15 +75,20 @@ public ActionRequestValidationException validate() { @Override public void writeTo(StreamOutput out) throws IOException { super.writeTo(out); + Version streamOutputVersion = out.getVersion(); out.writeOptionalStringArray(modelIds); out.writeOptionalStringArray(nodeIds); out.writeBoolean(async); + if (streamOutputVersion.onOrAfter(VERSION_2_19_0)) { + out.writeOptionalString(tenantId); + } } public static MLUndeployModelsRequest parse(XContentParser parser, String modelId) throws IOException { ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser); List modelIdList = new ArrayList<>(); List nodeIdList = new ArrayList<>(); + String tenantId = null; while (parser.nextToken() != XContentParser.Token.END_OBJECT) { String fieldName = parser.currentName(); parser.nextToken(); @@ -94,14 +106,17 @@ public static MLUndeployModelsRequest parse(XContentParser parser, String modelI nodeIdList.add(parser.text()); } break; + case TENANT_ID_FIELD: + tenantId = parser.textOrNull(); + break; default: parser.skipChildren(); break; } } - String[] modelIds = modelIdList == null ? null : modelIdList.toArray(new String[0]); - String[] nodeIds = nodeIdList == null ? null : nodeIdList.toArray(new String[0]); - return new MLUndeployModelsRequest(modelIds, nodeIds, false, true); + String[] modelIds = modelIdList.toArray(new String[0]); + String[] nodeIds = nodeIdList.toArray(new String[0]); + return new MLUndeployModelsRequest(modelIds, nodeIds, false, true, tenantId); } public static MLUndeployModelsRequest fromActionRequest(ActionRequest actionRequest) { diff --git a/common/src/test/java/org/opensearch/ml/common/transport/undeploy/MLUndeployModelsRequestTest.java b/common/src/test/java/org/opensearch/ml/common/transport/undeploy/MLUndeployModelsRequestTest.java new file mode 100644 index 0000000000..6efce9661e --- /dev/null +++ b/common/src/test/java/org/opensearch/ml/common/transport/undeploy/MLUndeployModelsRequestTest.java @@ -0,0 +1,160 @@ +package org.opensearch.ml.common.transport.undeploy; + +import static org.junit.Assert.*; +import static org.opensearch.ml.common.CommonValue.VERSION_2_18_0; +import static org.opensearch.ml.common.CommonValue.VERSION_2_19_0; + +import java.io.IOException; +import java.io.UncheckedIOException; +import java.util.Collections; +import java.util.function.Consumer; + +import org.junit.Before; +import org.junit.Test; +import org.opensearch.action.ActionRequest; +import org.opensearch.action.ActionRequestValidationException; +import org.opensearch.common.io.stream.BytesStreamOutput; +import org.opensearch.common.settings.Settings; +import org.opensearch.common.xcontent.LoggingDeprecationHandler; +import org.opensearch.common.xcontent.XContentType; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.core.xcontent.XContentParser; +import org.opensearch.search.SearchModule; + +public class MLUndeployModelsRequestTest { + + private MLUndeployModelsRequest mlUndeployModelsRequest; + + @Before + public void setUp() { + mlUndeployModelsRequest = MLUndeployModelsRequest + .builder() + .modelIds(new String[] { "model1", "model2" }) + .nodeIds(new String[] { "node1", "node2" }) + .async(true) + .dispatchTask(true) + .tenantId("tenant1") + .build(); + } + + @Test + public void testValidate() { + MLUndeployModelsRequest request = MLUndeployModelsRequest.builder().modelIds(new String[] { "model1" }).build(); + assertNull(request.validate()); + } + + @Test + public void testStreamInputVersionBefore_2_19_0() throws IOException { + BytesStreamOutput out = new BytesStreamOutput(); + out.setVersion(VERSION_2_18_0); + mlUndeployModelsRequest.writeTo(out); + + StreamInput in = out.bytes().streamInput(); + in.setVersion(VERSION_2_18_0); + MLUndeployModelsRequest request = new MLUndeployModelsRequest(in); + + assertArrayEquals(mlUndeployModelsRequest.getModelIds(), request.getModelIds()); + assertArrayEquals(mlUndeployModelsRequest.getNodeIds(), request.getNodeIds()); + assertEquals(mlUndeployModelsRequest.isAsync(), request.isAsync()); + assertEquals(mlUndeployModelsRequest.isDispatchTask(), request.isDispatchTask()); + assertNull(request.getTenantId()); + } + + @Test + public void testStreamInputVersionAfter_2_19_0() throws IOException { + BytesStreamOutput out = new BytesStreamOutput(); + out.setVersion(VERSION_2_19_0); + mlUndeployModelsRequest.writeTo(out); + + StreamInput in = out.bytes().streamInput(); + in.setVersion(VERSION_2_19_0); + MLUndeployModelsRequest request = new MLUndeployModelsRequest(in); + + assertArrayEquals(mlUndeployModelsRequest.getModelIds(), request.getModelIds()); + assertArrayEquals(mlUndeployModelsRequest.getNodeIds(), request.getNodeIds()); + assertEquals(mlUndeployModelsRequest.isAsync(), request.isAsync()); + assertEquals(mlUndeployModelsRequest.isDispatchTask(), request.isDispatchTask()); + assertEquals(mlUndeployModelsRequest.getTenantId(), request.getTenantId()); + } + + @Test + public void testWriteToWithNullFields() throws IOException { + MLUndeployModelsRequest request = MLUndeployModelsRequest + .builder() + .modelIds(null) + .nodeIds(null) + .async(true) + .dispatchTask(true) + .build(); + + BytesStreamOutput out = new BytesStreamOutput(); + out.setVersion(VERSION_2_19_0); + request.writeTo(out); + + StreamInput in = out.bytes().streamInput(); + in.setVersion(VERSION_2_19_0); + MLUndeployModelsRequest result = new MLUndeployModelsRequest(in); + + assertNull(result.getModelIds()); + assertNull(result.getNodeIds()); + assertEquals(request.isAsync(), result.isAsync()); + assertEquals(request.isDispatchTask(), result.isDispatchTask()); + } + + @Test(expected = UncheckedIOException.class) + public void fromActionRequest_IOException() { + ActionRequest actionRequest = new ActionRequest() { + @Override + public ActionRequestValidationException validate() { + return null; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + throw new IOException("test"); + } + }; + MLUndeployModelsRequest.fromActionRequest(actionRequest); + } + + @Test + public void fromActionRequest_Success_WithMLUndeployModelsRequest() { + MLUndeployModelsRequest request = MLUndeployModelsRequest.builder().modelIds(new String[] { "model1" }).build(); + assertSame(MLUndeployModelsRequest.fromActionRequest(request), request); + } + + @Test + public void testParse() throws Exception { + String expectedInputStr = "{\"model_ids\":[\"model1\"],\"node_ids\":[\"node1\"]}"; + parseFromJsonString(expectedInputStr, parsedInput -> { + assertArrayEquals(new String[] { "model1" }, parsedInput.getModelIds()); + assertArrayEquals(new String[] { "node1" }, parsedInput.getNodeIds()); + assertFalse(parsedInput.isAsync()); + assertTrue(parsedInput.isDispatchTask()); + }); + } + + @Test + public void testParseWithInvalidField() throws Exception { + String withInvalidFieldInputStr = "{\"invalid_field\":\"void\",\"model_ids\":[\"model1\"],\"node_ids\":[\"node1\"]}"; + parseFromJsonString(withInvalidFieldInputStr, parsedInput -> { + assertArrayEquals(new String[] { "model1" }, parsedInput.getModelIds()); + assertArrayEquals(new String[] { "node1" }, parsedInput.getNodeIds()); + }); + } + + private void parseFromJsonString(String expectedInputStr, Consumer verify) throws Exception { + XContentParser parser = XContentType.JSON + .xContent() + .createParser( + new NamedXContentRegistry(new SearchModule(Settings.EMPTY, Collections.emptyList()).getNamedXContents()), + LoggingDeprecationHandler.INSTANCE, + expectedInputStr + ); + parser.nextToken(); + MLUndeployModelsRequest parsedInput = MLUndeployModelsRequest.parse(parser, null); + verify.accept(parsedInput); + } +} diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/indices/MLIndicesHandler.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/indices/MLIndicesHandler.java index 9dd4cb25b3..520423ec41 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/indices/MLIndicesHandler.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/indices/MLIndicesHandler.java @@ -190,7 +190,7 @@ public void initMLIndexIfAbsent(MLIndex index, ActionListener listener) */ public void shouldUpdateIndex(String indexName, Integer newVersion, ActionListener listener) { IndexMetadata indexMetaData = clusterService.state().getMetadata().indices().get(indexName); - if (indexMetaData == null) { + if (indexMetaData == null || indexMetaData.mapping() == null) { listener.onResponse(Boolean.FALSE); return; } diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/indices/MLInputDatasetHandler.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/indices/MLInputDatasetHandler.java index 452f836357..e94fd20cd0 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/indices/MLInputDatasetHandler.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/indices/MLInputDatasetHandler.java @@ -68,7 +68,7 @@ public void parseSearchQueryInput(MLInputDataset mlInputDataset, ActionListener< listener.onResponse(dfInputDataset); return; }, e -> { - log.error("Failed to search" + e); + log.error("Failed to search{}", e); listener.onFailure(e); })); return; diff --git a/plugin/src/main/java/org/opensearch/ml/action/prediction/TransportPredictionTaskAction.java b/plugin/src/main/java/org/opensearch/ml/action/prediction/TransportPredictionTaskAction.java index a0e5018ad4..59a8c6aeea 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/prediction/TransportPredictionTaskAction.java +++ b/plugin/src/main/java/org/opensearch/ml/action/prediction/TransportPredictionTaskAction.java @@ -39,6 +39,8 @@ import org.opensearch.ml.task.MLTaskRunner; import org.opensearch.ml.utils.MLNodeUtils; import org.opensearch.ml.utils.RestActionUtils; +import org.opensearch.ml.utils.TenantAwareHelper; +import org.opensearch.remote.metadata.client.SdkClient; import org.opensearch.tasks.Task; import org.opensearch.transport.TransportService; @@ -54,6 +56,7 @@ public class TransportPredictionTaskAction extends HandledTransportAction listener) { MLPredictionTaskRequest mlPredictionTaskRequest = MLPredictionTaskRequest.fromActionRequest(request); String modelId = mlPredictionTaskRequest.getModelId(); - + String tenantId = mlPredictionTaskRequest.getTenantId(); + if (!TenantAwareHelper.validateTenantId(mlFeatureEnabledSetting, tenantId, listener)) { + return; + } User user = mlPredictionTaskRequest.getUser(); if (user == null) { user = RestActionUtils.getUserContext(client); @@ -110,7 +118,7 @@ protected void doExecute(Task task, ActionRequest request, ActionListener wrappedListener = ActionListener.runBefore(listener, () -> context.restore()); + ActionListener wrappedListener = ActionListener.runBefore(listener, context::restore); MLModel cachedMlModel = modelCacheHelper.getModelInfo(modelId); ActionListener modelActionListener = new ActionListener<>() { @Override @@ -123,78 +131,88 @@ public void onResponse(MLModel mlModel) { } mlPredictionTaskRequest.getMlInput().setAlgorithm(functionName); modelAccessControlHelper - .validateModelGroupAccess(userInfo, mlModel.getModelGroupId(), client, ActionListener.wrap(access -> { - if (!access) { - wrappedListener - .onFailure( - new OpenSearchStatusException( - "User Doesn't have privilege to perform this operation on this model", - RestStatus.FORBIDDEN - ) - ); - } else { - if (modelCacheHelper.getIsModelEnabled(modelId) != null && !modelCacheHelper.getIsModelEnabled(modelId)) { - wrappedListener.onFailure(new OpenSearchStatusException("Model is disabled.", RestStatus.FORBIDDEN)); + .validateModelGroupAccess( + userInfo, + mlFeatureEnabledSetting, + tenantId, + mlModel.getModelGroupId(), + client, + sdkClient, + ActionListener.wrap(access -> { + if (!access) { + wrappedListener + .onFailure( + new OpenSearchStatusException( + "User Doesn't have privilege to perform this operation on this model", + RestStatus.FORBIDDEN + ) + ); } else { - if (FunctionName.isDLModel(functionName)) { - if (modelCacheHelper.getRateLimiter(modelId) != null - && !modelCacheHelper.getRateLimiter(modelId).request()) { - wrappedListener - .onFailure( - new OpenSearchStatusException( - "Request is throttled at model level.", - RestStatus.TOO_MANY_REQUESTS - ) - ); - } else if (userInfo != null - && modelCacheHelper.getUserRateLimiter(modelId, userInfo.getName()) != null - && !modelCacheHelper.getUserRateLimiter(modelId, userInfo.getName()).request()) { - wrappedListener - .onFailure( - new OpenSearchStatusException( - "Request is throttled at user level. If you think there's an issue, please contact your cluster admin.", - RestStatus.TOO_MANY_REQUESTS - ) - ); + if (modelCacheHelper.getIsModelEnabled(modelId) != null + && !modelCacheHelper.getIsModelEnabled(modelId)) { + wrappedListener + .onFailure(new OpenSearchStatusException("Model is disabled.", RestStatus.FORBIDDEN)); + } else { + if (FunctionName.isDLModel(functionName)) { + if (modelCacheHelper.getRateLimiter(modelId) != null + && !modelCacheHelper.getRateLimiter(modelId).request()) { + wrappedListener + .onFailure( + new OpenSearchStatusException( + "Request is throttled at model level.", + RestStatus.TOO_MANY_REQUESTS + ) + ); + } else if (userInfo != null + && modelCacheHelper.getUserRateLimiter(modelId, userInfo.getName()) != null + && !modelCacheHelper.getUserRateLimiter(modelId, userInfo.getName()).request()) { + wrappedListener + .onFailure( + new OpenSearchStatusException( + "Request is throttled at user level. If you think there's an issue, please contact your cluster admin.", + RestStatus.TOO_MANY_REQUESTS + ) + ); + } else { + validateInputSchema(modelId, mlPredictionTaskRequest.getMlInput()); + executePredict(mlPredictionTaskRequest, wrappedListener, modelId); + } } else { validateInputSchema(modelId, mlPredictionTaskRequest.getMlInput()); executePredict(mlPredictionTaskRequest, wrappedListener, modelId); } - } else { - validateInputSchema(modelId, mlPredictionTaskRequest.getMlInput()); - executePredict(mlPredictionTaskRequest, wrappedListener, modelId); } } - } - }, e -> { - log.error("Failed to Validate Access for ModelId " + modelId, e); - if (e instanceof OpenSearchStatusException) { - wrappedListener - .onFailure( - new OpenSearchStatusException( - e.getMessage(), - RestStatus.fromCode(((OpenSearchStatusException) e).status().getStatus()) - ) - ); - } else if (e instanceof MLResourceNotFoundException) { - wrappedListener.onFailure(new OpenSearchStatusException(e.getMessage(), RestStatus.NOT_FOUND)); - } else if (e instanceof CircuitBreakingException) { - wrappedListener.onFailure(e); - } else { - wrappedListener - .onFailure( - new OpenSearchStatusException( - "Failed to Validate Access for ModelId " + modelId, - RestStatus.FORBIDDEN - ) - ); - } - })); + }, e -> { + log.error("Failed to Validate Access for ModelId {}", modelId, e); + if (e instanceof OpenSearchStatusException) { + wrappedListener + .onFailure( + new OpenSearchStatusException( + e.getMessage(), + RestStatus.fromCode(((OpenSearchStatusException) e).status().getStatus()) + ) + ); + } else if (e instanceof MLResourceNotFoundException) { + wrappedListener.onFailure(new OpenSearchStatusException(e.getMessage(), RestStatus.NOT_FOUND)); + } else if (e instanceof CircuitBreakingException) { + wrappedListener.onFailure(e); + } else { + wrappedListener + .onFailure( + new OpenSearchStatusException( + "Failed to Validate Access for ModelId " + modelId, + RestStatus.FORBIDDEN + ) + ); + } + }) + ); } @Override public void onFailure(Exception e) { - log.error("Failed to find model " + modelId, e); + log.error("Failed to find model {}", modelId, e); wrappedListener.onFailure(e); } }; @@ -203,7 +221,7 @@ public void onFailure(Exception e) { modelActionListener.onResponse(cachedMlModel); } else { // For multi-node cluster, the function name is null in cache, so should always get model first. - mlModelManager.getModel(modelId, modelActionListener); + mlModelManager.getModel(modelId, tenantId, modelActionListener); } } } @@ -214,7 +232,7 @@ private void executePredict( String modelId ) { String requestId = mlPredictionTaskRequest.getRequestID(); - log.debug("receive predict request " + requestId + " for model " + mlPredictionTaskRequest.getModelId()); + log.debug("receive predict request {} for model {}", requestId, mlPredictionTaskRequest.getModelId()); long startTime = System.nanoTime(); // For remote text embedding model, neural search will set mlPredictionTaskRequest.getMlInput().getAlgorithm() as // TEXT_EMBEDDING. In ml-commons we should always use the real function name of model: REMOTE. So we try to get @@ -233,7 +251,7 @@ private void executePredict( double durationInMs = (endTime - startTime) / 1e6; modelCacheHelper.addPredictRequestDuration(modelId, durationInMs); modelCacheHelper.refreshLastAccessTime(modelId); - log.debug("completed predict request " + requestId + " for model " + modelId); + log.debug("completed predict request {} for model {}", requestId, modelId); }) ); } diff --git a/plugin/src/main/java/org/opensearch/ml/action/register/TransportRegisterModelAction.java b/plugin/src/main/java/org/opensearch/ml/action/register/TransportRegisterModelAction.java index 0e2b02865c..77feac69b9 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/register/TransportRegisterModelAction.java +++ b/plugin/src/main/java/org/opensearch/ml/action/register/TransportRegisterModelAction.java @@ -366,6 +366,7 @@ private void registerModel(MLRegisterModelInput registerModelInput, ActionListen .lastUpdateTime(Instant.now()) .state(MLTaskState.CREATED) .workerNodes(ImmutableList.of(clusterService.localNode().getId())) + .tenantId(registerModelInput.getTenantId()) .build(); if (!isAsync) { @@ -441,6 +442,7 @@ private MLRegisterModelGroupInput createRegisterModelGroupRequest(MLRegisterMode .backendRoles(registerModelInput.getBackendRoles()) .modelAccessMode(registerModelInput.getAccessMode()) .isAddAllBackendRoles(registerModelInput.getAddAllBackendRoles()) + .tenantId(registerModelInput.getTenantId()) .build(); } } diff --git a/plugin/src/main/java/org/opensearch/ml/action/tasks/DeleteTaskTransportAction.java b/plugin/src/main/java/org/opensearch/ml/action/tasks/DeleteTaskTransportAction.java index 5128d0b04c..a7d3545f20 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/tasks/DeleteTaskTransportAction.java +++ b/plugin/src/main/java/org/opensearch/ml/action/tasks/DeleteTaskTransportAction.java @@ -87,6 +87,7 @@ protected void doExecute(Task task, ActionRequest request, ActionListener actionListener) { + private void executeDelete(String taskId, String tenantId, ActionListener actionListener) { DeleteRequest deleteRequest = new DeleteRequest(ML_TASK_INDEX, taskId); try { sdkClient - .deleteDataObjectAsync(DeleteDataObjectRequest.builder().index(deleteRequest.index()).id(deleteRequest.id()).build()) + .deleteDataObjectAsync( + DeleteDataObjectRequest.builder().index(deleteRequest.index()).id(deleteRequest.id()).tenantId(tenantId).build() + ) .whenComplete((deleteDataObjectResponse, throwable) -> { if (throwable != null) { handleDeleteError(throwable, taskId, actionListener); diff --git a/plugin/src/main/java/org/opensearch/ml/action/undeploy/TransportUndeployModelsAction.java b/plugin/src/main/java/org/opensearch/ml/action/undeploy/TransportUndeployModelsAction.java index 579fa51a38..d64f817fa4 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/undeploy/TransportUndeployModelsAction.java +++ b/plugin/src/main/java/org/opensearch/ml/action/undeploy/TransportUndeployModelsAction.java @@ -11,6 +11,7 @@ import java.util.List; import java.util.stream.Collectors; +import org.opensearch.ExceptionsHelper; import org.opensearch.OpenSearchStatusException; import org.opensearch.action.ActionRequest; import org.opensearch.action.search.SearchRequest; @@ -41,9 +42,14 @@ import org.opensearch.ml.engine.ModelHelper; import org.opensearch.ml.helper.ModelAccessControlHelper; import org.opensearch.ml.model.MLModelManager; +import org.opensearch.ml.settings.MLFeatureEnabledSetting; import org.opensearch.ml.task.MLTaskDispatcher; import org.opensearch.ml.task.MLTaskManager; import org.opensearch.ml.utils.RestActionUtils; +import org.opensearch.ml.utils.TenantAwareHelper; +import org.opensearch.remote.metadata.client.SdkClient; +import org.opensearch.remote.metadata.client.SearchDataObjectRequest; +import org.opensearch.remote.metadata.common.SdkClientUtils; import org.opensearch.search.SearchHit; import org.opensearch.search.builder.SearchSourceBuilder; import org.opensearch.tasks.Task; @@ -62,6 +68,7 @@ public class TransportUndeployModelsAction extends HandledTransportAction listener) { MLUndeployModelsRequest undeployModelsRequest = MLUndeployModelsRequest.fromActionRequest(request); String[] modelIds = undeployModelsRequest.getModelIds(); + String tenantId = undeployModelsRequest.getTenantId(); String[] targetNodeIds = undeployModelsRequest.getNodeIds(); + if (!TenantAwareHelper.validateTenantId(mlFeatureEnabledSetting, tenantId, listener)) { + return; + } + if (modelIds == null) { listener.onFailure(new IllegalArgumentException("Must set specific model ids to undeploy")); return; } if (modelIds.length == 1) { String modelId = modelIds[0]; - validateAccess(modelId, ActionListener.wrap(hasPermissionToUndeploy -> { + validateAccess(modelId, tenantId, ActionListener.wrap(hasPermissionToUndeploy -> { if (hasPermissionToUndeploy) { - undeployModels(targetNodeIds, modelIds, listener); + undeployModels(targetNodeIds, modelIds, tenantId, listener); } else { listener.onFailure(new IllegalArgumentException("No permission to undeploy model " + modelId)); } @@ -141,9 +158,9 @@ protected void doExecute(Task task, ActionRequest request, ActionListener !hiddenModelIds.contains(modelId)) .toArray(String[]::new); - undeployModels(targetNodeIds, modelsIDsToUndeploy, listener); + undeployModels(targetNodeIds, modelsIDsToUndeploy, tenantId, listener); } else { - undeployModels(targetNodeIds, modelIds, listener); + undeployModels(targetNodeIds, modelIds, tenantId, listener); } }, e -> { log.error("Failed to search model index", e); @@ -153,20 +170,29 @@ protected void doExecute(Task task, ActionRequest request, ActionListener listener) { + private void undeployModels( + String[] targetNodeIds, + String[] modelIds, + String tenantId, + ActionListener listener + ) { MLUndeployModelNodesRequest mlUndeployModelNodesRequest = new MLUndeployModelNodesRequest(targetNodeIds, modelIds); + mlUndeployModelNodesRequest.setTenantId(tenantId); client.execute(MLUndeployModelAction.INSTANCE, mlUndeployModelNodesRequest, ActionListener.wrap(r -> { listener.onResponse(new MLUndeployModelsResponse(r)); }, listener::onFailure)); } - private void validateAccess(String modelId, ActionListener listener) { + private void validateAccess(String modelId, String tenantId, ActionListener listener) { User user = RestActionUtils.getUserContext(client); boolean isSuperAdmin = isSuperAdminUserWrapper(clusterService, client); String[] excludes = new String[] { MLModel.MODEL_CONTENT_FIELD, MLModel.OLD_MODEL_CONTENT_FIELD }; try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) { - mlModelManager.getModel(modelId, null, excludes, ActionListener.runBefore(ActionListener.wrap(mlModel -> { + mlModelManager.getModel(modelId, tenantId, null, excludes, ActionListener.runBefore(ActionListener.wrap(mlModel -> { + if (!TenantAwareHelper.validateTenantResource(mlFeatureEnabledSetting, tenantId, mlModel.getTenantId(), listener)) { + return; + } Boolean isHidden = mlModel.getIsHidden(); if (isHidden != null && isHidden) { if (isSuperAdmin) { @@ -181,7 +207,16 @@ private void validateAccess(String modelId, ActionListener listener) { ); } } else { - modelAccessControlHelper.validateModelGroupAccess(user, mlModel.getModelGroupId(), client, listener); + modelAccessControlHelper + .validateModelGroupAccess( + user, + mlFeatureEnabledSetting, + tenantId, + mlModel.getModelGroupId(), + client, + sdkClient, + listener + ); } }, e -> { log.error("Failed to find Model", e); @@ -215,14 +250,34 @@ public void searchHiddenModels(String[] modelIds, ActionListener SearchRequest searchRequest = new SearchRequest(ML_MODEL_INDEX).source(searchSourceBuilder); - client.search(searchRequest, ActionListener.runBefore(ActionListener.wrap(models -> { listener.onResponse(models); }, e -> { - if (e instanceof IndexNotFoundException) { - listener.onResponse(null); + SearchDataObjectRequest searchDataObjectRequest = SearchDataObjectRequest + .builder() + .indices(searchRequest.indices()) + .searchSourceBuilder(searchRequest.source()) + .build(); + + sdkClient.searchDataObjectAsync(searchDataObjectRequest).whenComplete((r, throwable) -> { + context.restore(); + if (throwable != null) { + Exception cause = SdkClientUtils.unwrapAndConvertToException(throwable); + log.error("Failed to search model index", cause); + if (ExceptionsHelper.unwrap(cause, IndexNotFoundException.class) != null) { + listener.onResponse(null); + } else { + listener.onFailure(cause); + } } else { - log.error("Failed to search model index", e); - listener.onFailure(e); + try { + SearchResponse searchResponse = SearchResponse.fromXContent(r.parser()); + log.info("Model Index search complete: {}", searchResponse.getHits().getTotalHits()); + listener.onResponse(searchResponse); + } catch (Exception e) { + log.error("Failed to parse search response", e); + listener + .onFailure(new OpenSearchStatusException("Failed to parse search response", RestStatus.INTERNAL_SERVER_ERROR)); + } } - }), () -> context.restore())); + }); } catch (Exception e) { log.error("Failed to search model index", e); listener.onFailure(e); diff --git a/plugin/src/main/java/org/opensearch/ml/cluster/MLCommonsClusterManagerEventListener.java b/plugin/src/main/java/org/opensearch/ml/cluster/MLCommonsClusterManagerEventListener.java index 0cf4215a23..a14d49dd59 100644 --- a/plugin/src/main/java/org/opensearch/ml/cluster/MLCommonsClusterManagerEventListener.java +++ b/plugin/src/main/java/org/opensearch/ml/cluster/MLCommonsClusterManagerEventListener.java @@ -20,6 +20,8 @@ import org.opensearch.ml.autoredeploy.MLModelAutoReDeployer; import org.opensearch.ml.engine.encryptor.Encryptor; import org.opensearch.ml.engine.indices.MLIndicesHandler; +import org.opensearch.ml.settings.MLFeatureEnabledSetting; +import org.opensearch.remote.metadata.client.SdkClient; import org.opensearch.threadpool.Scheduler; import org.opensearch.threadpool.ThreadPool; @@ -30,6 +32,7 @@ public class MLCommonsClusterManagerEventListener implements LocalNodeClusterMan private final ClusterService clusterService; private Client client; + private final SdkClient sdkClient; private ThreadPool threadPool; private Scheduler.Cancellable syncModelRoutingCron; @@ -40,25 +43,30 @@ public class MLCommonsClusterManagerEventListener implements LocalNodeClusterMan private volatile Integer jobInterval; private final MLModelAutoReDeployer mlModelAutoReDeployer; + private final MLFeatureEnabledSetting mlFeatureEnabledSetting; public MLCommonsClusterManagerEventListener( ClusterService clusterService, Client client, + SdkClient sdkClient, Settings settings, ThreadPool threadPool, DiscoveryNodeHelper nodeHelper, MLIndicesHandler mlIndicesHandler, Encryptor encryptor, - MLModelAutoReDeployer modelAutoReDeployer + MLModelAutoReDeployer modelAutoReDeployer, + MLFeatureEnabledSetting mlFeatureEnabledSetting ) { this.clusterService = clusterService; this.client = client; + this.sdkClient = sdkClient; this.threadPool = threadPool; this.clusterService.addListener(this); this.nodeHelper = nodeHelper; this.mlIndicesHandler = mlIndicesHandler; this.encryptor = encryptor; this.mlModelAutoReDeployer = modelAutoReDeployer; + this.mlFeatureEnabledSetting = mlFeatureEnabledSetting; this.jobInterval = ML_COMMONS_SYNC_UP_JOB_INTERVAL_IN_SECONDS.get(settings); clusterService.getClusterSettings().addSettingsUpdateConsumer(ML_COMMONS_SYNC_UP_JOB_INTERVAL_IN_SECONDS, it -> { @@ -94,7 +102,7 @@ private void startSyncModelRoutingCron() { log.info("Starting ML sync up job..."); syncModelRoutingCron = threadPool .scheduleWithFixedDelay( - new MLSyncUpCron(client, clusterService, nodeHelper, mlIndicesHandler, encryptor), + new MLSyncUpCron(client, sdkClient, clusterService, nodeHelper, mlIndicesHandler, encryptor, mlFeatureEnabledSetting), TimeValue.timeValueSeconds(jobInterval), GENERAL_THREAD_POOL ); diff --git a/plugin/src/main/java/org/opensearch/ml/cluster/MLSyncUpCron.java b/plugin/src/main/java/org/opensearch/ml/cluster/MLSyncUpCron.java index a165dda429..533ae18899 100644 --- a/plugin/src/main/java/org/opensearch/ml/cluster/MLSyncUpCron.java +++ b/plugin/src/main/java/org/opensearch/ml/cluster/MLSyncUpCron.java @@ -48,6 +48,8 @@ import org.opensearch.ml.common.transport.undeploy.MLUndeployModelsRequest; import org.opensearch.ml.engine.encryptor.Encryptor; import org.opensearch.ml.engine.indices.MLIndicesHandler; +import org.opensearch.ml.settings.MLFeatureEnabledSetting; +import org.opensearch.remote.metadata.client.SdkClient; import org.opensearch.search.SearchHit; import org.opensearch.search.builder.SearchSourceBuilder; @@ -61,28 +63,34 @@ public class MLSyncUpCron implements Runnable { public static final int DEPLOY_MODEL_TASK_GRACE_TIME_IN_MS = 20_000; private Client client; + private final SdkClient sdkClient; private ClusterService clusterService; private DiscoveryNodeHelper nodeHelper; private MLIndicesHandler mlIndicesHandler; private Encryptor encryptor; private volatile Boolean mlConfigInited; + private final MLFeatureEnabledSetting mlFeatureEnabledSetting; @VisibleForTesting Semaphore updateModelStateSemaphore; public MLSyncUpCron( Client client, + SdkClient sdkClient, ClusterService clusterService, DiscoveryNodeHelper nodeHelper, MLIndicesHandler mlIndicesHandler, - Encryptor encryptor + Encryptor encryptor, + MLFeatureEnabledSetting mlFeatureEnabledSetting ) { this.client = client; + this.sdkClient = sdkClient; this.clusterService = clusterService; this.nodeHelper = nodeHelper; this.mlIndicesHandler = mlIndicesHandler; this.updateModelStateSemaphore = new Semaphore(1); this.mlConfigInited = false; this.encryptor = encryptor; + this.mlFeatureEnabledSetting = mlFeatureEnabledSetting; } @Override @@ -100,7 +108,7 @@ public void run() { // gather running model/tasks on nodes client.execute(MLSyncUpAction.INSTANCE, gatherInfoRequest, ActionListener.wrap(r -> { List responses = r.getNodes(); - if (r.failures() != null && r.failures().size() != 0) { + if (r.failures() != null && !r.failures().isEmpty()) { log .debug( "Received {} failures in the sync up response on nodes. Error messages are {}", @@ -126,14 +134,14 @@ public void run() { } String[] deployedModelIds = response.getDeployedModelIds(); - if (deployedModelIds != null && deployedModelIds.length > 0) { + if (deployedModelIds != null) { for (String modelId : deployedModelIds) { Set workerNodes = modelWorkerNodes.computeIfAbsent(modelId, it -> new HashSet<>()); workerNodes.add(nodeId); } } String[] runningModelIds = response.getRunningDeployModelIds(); - if (runningModelIds != null && runningModelIds.length > 0) { + if (runningModelIds != null) { for (String modelId : runningModelIds) { Set workerNodes = deployingModels.computeIfAbsent(modelId, it -> new HashSet<>()); workerNodes.add(nodeId); @@ -141,7 +149,7 @@ public void run() { } String[] runningDeployModelTaskIds = response.getRunningDeployModelTaskIds(); - if (runningDeployModelTaskIds != null && runningDeployModelTaskIds.length > 0) { + if (runningDeployModelTaskIds != null) { for (String taskId : runningDeployModelTaskIds) { Set workerNodes = runningDeployModelTasks.computeIfAbsent(taskId, it -> new HashSet<>()); workerNodes.add(nodeId); @@ -169,7 +177,7 @@ public void run() { .builder() .syncRunningDeployModelTasks(true) .runningDeployModelTasks(runningDeployModelTasks); - if (modelWorkerNodes.size() == 0) { + if (modelWorkerNodes.isEmpty()) { log.debug("No deployed model found. Will clear model routing on all nodes"); inputBuilder.clearRoutingTable(true); } else { @@ -205,12 +213,13 @@ private void undeployExpiredModels( String[] targetNodeIds = getAllNodes(clusterService); MLUndeployModelsRequest mlUndeployModelsRequest = new MLUndeployModelsRequest( expiredModels.toArray(new String[expiredModels.size()]), - targetNodeIds + targetNodeIds, + null ); client.execute(MLUndeployModelsAction.INSTANCE, mlUndeployModelsRequest, ActionListener.wrap(r -> { MLUndeployModelNodesResponse mlUndeployModelNodesResponse = r.getResponse(); - if (mlUndeployModelNodesResponse.failures() != null && mlUndeployModelNodesResponse.failures().size() != 0) { + if (mlUndeployModelNodesResponse.failures() != null && !mlUndeployModelNodesResponse.failures().isEmpty()) { log.debug("Received failures in undeploying expired models", mlUndeployModelNodesResponse.failures()); } @@ -226,7 +235,7 @@ private void undeployExpiredModels( @VisibleForTesting void initMLConfig() { - if (mlConfigInited) { + if (mlConfigInited || mlFeatureEnabledSetting.isMultiTenancyEnabled()) { return; } mlIndicesHandler.initMLConfigIndex(ActionListener.wrap(r -> { @@ -370,7 +379,7 @@ private MLModelState getNewModelState( int currentWorkerNodeCountInIndex ) { Set deployModelTaskNodes = deployingModels.get(modelId); - if (deployModelTaskNodes != null && deployModelTaskNodes.size() > 0 && state != MLModelState.DEPLOYING) { + if (deployModelTaskNodes != null && !deployModelTaskNodes.isEmpty() && state != MLModelState.DEPLOYING) { // If some node/nodes are deploying the model and model state is not DEPLOYING, then set model state as DEPLOYING. return MLModelState.DEPLOYING; } @@ -418,7 +427,7 @@ private void bulkUpdateModelState( updatedModelIds.addAll(newModelStates.keySet()); updatedModelIds.addAll(newPlanningWorkNodes.keySet()); - if (updatedModelIds.size() > 0) { + if (!updatedModelIds.isEmpty()) { BulkRequest bulkUpdateRequest = new BulkRequest(); for (String modelId : updatedModelIds) { UpdateRequest updateRequest = new UpdateRequest(); diff --git a/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java b/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java index df3d52232e..89cc9b07ab 100644 --- a/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java +++ b/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java @@ -702,12 +702,14 @@ public Collection createComponents( MLCommonsClusterManagerEventListener clusterManagerEventListener = new MLCommonsClusterManagerEventListener( clusterService, client, + sdkClient, settings, threadPool, nodeHelper, mlIndicesHandler, encryptor, - mlModelAutoRedeployer + mlModelAutoRedeployer, + mlFeatureEnabledSetting ); // TODO move this into MLFeatureEnabledSetting @@ -778,7 +780,11 @@ public List getRestHandlers( ); RestMLRegisterAgentAction restMLRegisterAgentAction = new RestMLRegisterAgentAction(mlFeatureEnabledSetting); RestMLDeployModelAction restMLDeployModelAction = new RestMLDeployModelAction(mlFeatureEnabledSetting); - RestMLUndeployModelAction restMLUndeployModelAction = new RestMLUndeployModelAction(clusterService, settings); + RestMLUndeployModelAction restMLUndeployModelAction = new RestMLUndeployModelAction( + clusterService, + settings, + mlFeatureEnabledSetting + ); RestMLRegisterModelMetaAction restMLRegisterModelMetaAction = new RestMLRegisterModelMetaAction(clusterService, settings); RestMLUploadModelChunkAction restMLUploadModelChunkAction = new RestMLUploadModelChunkAction(clusterService, settings); RestMLRegisterModelGroupAction restMLCreateModelGroupAction = new RestMLRegisterModelGroupAction(mlFeatureEnabledSetting); diff --git a/plugin/src/main/java/org/opensearch/ml/rest/RestMLUndeployModelAction.java b/plugin/src/main/java/org/opensearch/ml/rest/RestMLUndeployModelAction.java index 0cc30752df..c66865e484 100644 --- a/plugin/src/main/java/org/opensearch/ml/rest/RestMLUndeployModelAction.java +++ b/plugin/src/main/java/org/opensearch/ml/rest/RestMLUndeployModelAction.java @@ -10,6 +10,7 @@ import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_ALLOW_CUSTOM_DEPLOYMENT_PLAN; import static org.opensearch.ml.utils.RestActionUtils.PARAMETER_MODEL_ID; import static org.opensearch.ml.utils.RestActionUtils.getAllNodes; +import static org.opensearch.ml.utils.TenantAwareHelper.getTenantID; import java.io.IOException; import java.util.List; @@ -23,6 +24,7 @@ import org.opensearch.ml.common.transport.undeploy.MLUndeployModelInput; import org.opensearch.ml.common.transport.undeploy.MLUndeployModelsAction; import org.opensearch.ml.common.transport.undeploy.MLUndeployModelsRequest; +import org.opensearch.ml.settings.MLFeatureEnabledSetting; import org.opensearch.rest.BaseRestHandler; import org.opensearch.rest.RestRequest; import org.opensearch.rest.action.RestToXContentListener; @@ -36,14 +38,16 @@ public class RestMLUndeployModelAction extends BaseRestHandler { private Settings settings; private boolean allowCustomDeploymentPlan; + private final MLFeatureEnabledSetting mlFeatureEnabledSetting; /** * Constructor */ - public RestMLUndeployModelAction(ClusterService clusterService, Settings settings) { + public RestMLUndeployModelAction(ClusterService clusterService, Settings settings, MLFeatureEnabledSetting mlFeatureEnabledSetting) { this.clusterService = clusterService; this.settings = settings; this.allowCustomDeploymentPlan = ML_COMMONS_ALLOW_CUSTOM_DEPLOYMENT_PLAN.get(settings); + this.mlFeatureEnabledSetting = mlFeatureEnabledSetting; clusterService .getClusterSettings() @@ -82,6 +86,7 @@ public RestChannelConsumer prepareRequest(RestRequest request, NodeClient client MLUndeployModelsRequest getRequest(RestRequest request) throws IOException { String modelId = request.param(PARAMETER_MODEL_ID); + String tenantId = getTenantID(mlFeatureEnabledSetting.isMultiTenancyEnabled(), request); String[] targetModelIds = null; if (modelId != null) { targetModelIds = new String[] { modelId }; @@ -109,6 +114,6 @@ MLUndeployModelsRequest getRequest(RestRequest request) throws IOException { targetNodeIds = getAllNodes(clusterService); } - return new MLUndeployModelsRequest(targetModelIds, targetNodeIds); + return new MLUndeployModelsRequest(targetModelIds, targetNodeIds, tenantId); } } diff --git a/plugin/src/test/java/org/opensearch/ml/action/prediction/TransportPredictionTaskActionTests.java b/plugin/src/test/java/org/opensearch/ml/action/prediction/TransportPredictionTaskActionTests.java index baaf2cec05..d7f371ca88 100644 --- a/plugin/src/test/java/org/opensearch/ml/action/prediction/TransportPredictionTaskActionTests.java +++ b/plugin/src/test/java/org/opensearch/ml/action/prediction/TransportPredictionTaskActionTests.java @@ -22,6 +22,7 @@ import org.junit.Before; import org.junit.Rule; +import org.junit.Test; import org.junit.rules.ExpectedException; import org.mockito.ArgumentCaptor; import org.mockito.Mock; @@ -55,6 +56,8 @@ import org.opensearch.ml.model.MLModelManager; import org.opensearch.ml.settings.MLFeatureEnabledSetting; import org.opensearch.ml.task.MLPredictTaskRunner; +import org.opensearch.remote.metadata.client.SdkClient; +import org.opensearch.remote.metadata.client.impl.SdkClientFactory; import org.opensearch.test.OpenSearchTestCase; import org.opensearch.threadpool.ThreadPool; import org.opensearch.transport.TransportService; @@ -71,6 +74,7 @@ public class TransportPredictionTaskActionTests extends OpenSearchTestCase { @Mock private Client client; + SdkClient sdkClient; @Mock private ClusterService clusterService; @@ -128,7 +132,7 @@ public void setup() { .build(); mlPredictionTaskRequest = MLPredictionTaskRequest.builder().modelId("test_id").mlInput(mlInput).user(user).build(); - + sdkClient = SdkClientFactory.createSdkClient(client, NamedXContentRegistry.EMPTY, Collections.emptyMap()); Settings settings = Settings.builder().put(ML_COMMONS_MODEL_AUTO_DEPLOY_ENABLE.getKey(), true).build(); ClusterSettings clusterSettings = new ClusterSettings(settings, new HashSet<>(Arrays.asList(ML_COMMONS_MODEL_AUTO_DEPLOY_ENABLE))); @@ -147,6 +151,7 @@ public void setup() { mlPredictTaskRunner, clusterService, client, + sdkClient, xContentRegistry, mlModelManager, modelAccessControlHelper, @@ -156,15 +161,16 @@ public void setup() { ); } + @Test public void testPrediction_default_exception() { when(modelCacheHelper.getModelInfo(anyString())).thenReturn(model); when(model.getAlgorithm()).thenReturn(FunctionName.KMEANS); doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(3); + ActionListener listener = invocation.getArgument(6); listener.onFailure(new RuntimeException("Exception occurred. Please check log for more details.")); return null; - }).when(modelAccessControlHelper).validateModelGroupAccess(any(), any(), any(), any()); + }).when(modelAccessControlHelper).validateModelGroupAccess(any(), any(), any(), any(), any(), any(), any()); doAnswer(invocation -> { ((ActionListener) invocation.getArguments()[3]).onResponse(null); @@ -178,6 +184,7 @@ public void testPrediction_default_exception() { assertEquals("Failed to Validate Access for ModelId test_id", argumentCaptor.getValue().getMessage()); } + @Test public void testPrediction_local_model_not_exception() { when(modelCacheHelper.getModelInfo(anyString())).thenReturn(model); when(model.getAlgorithm()).thenReturn(FunctionName.TEXT_EMBEDDING); @@ -193,15 +200,16 @@ public void testPrediction_local_model_not_exception() { ); } + @Test public void testPrediction_OpenSearchStatusException() { when(modelCacheHelper.getModelInfo(anyString())).thenReturn(model); when(model.getAlgorithm()).thenReturn(FunctionName.KMEANS); doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(3); + ActionListener listener = invocation.getArgument(6); listener.onFailure(new OpenSearchStatusException("Testing OpenSearchStatusException", RestStatus.BAD_REQUEST)); return null; - }).when(modelAccessControlHelper).validateModelGroupAccess(any(), any(), any(), any()); + }).when(modelAccessControlHelper).validateModelGroupAccess(any(), any(), any(), any(), any(), any(), any()); doAnswer(invocation -> { ((ActionListener) invocation.getArguments()[3]).onResponse(null); @@ -215,15 +223,16 @@ public void testPrediction_OpenSearchStatusException() { assertEquals("Testing OpenSearchStatusException", argumentCaptor.getValue().getMessage()); } + @Test public void testPrediction_MLResourceNotFoundException() { when(modelCacheHelper.getModelInfo(anyString())).thenReturn(model); when(model.getAlgorithm()).thenReturn(FunctionName.KMEANS); doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(3); + ActionListener listener = invocation.getArgument(6); listener.onFailure(new MLResourceNotFoundException("Testing MLResourceNotFoundException")); return null; - }).when(modelAccessControlHelper).validateModelGroupAccess(any(), any(), any(), any()); + }).when(modelAccessControlHelper).validateModelGroupAccess(any(), any(), any(), any(), any(), any(), any()); doAnswer(invocation -> { ((ActionListener) invocation.getArguments()[3]).onResponse(null); @@ -237,15 +246,16 @@ public void testPrediction_MLResourceNotFoundException() { assertEquals("Testing MLResourceNotFoundException", argumentCaptor.getValue().getMessage()); } + @Test public void testPrediction_MLLimitExceededException() { when(modelCacheHelper.getModelInfo(anyString())).thenReturn(model); when(model.getAlgorithm()).thenReturn(FunctionName.TEXT_EMBEDDING); doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(3); + ActionListener listener = invocation.getArgument(6); listener.onFailure(new CircuitBreakingException("Memory Circuit Breaker is open, please check your resources!", CircuitBreaker.Durability.TRANSIENT)); return null; - }).when(modelAccessControlHelper).validateModelGroupAccess(any(), any(), any(), any()); + }).when(modelAccessControlHelper).validateModelGroupAccess(any(), any(), any(), any(), any(), any(), any()); doAnswer(invocation -> { ((ActionListener) invocation.getArguments()[3]).onResponse(null); @@ -259,6 +269,7 @@ public void testPrediction_MLLimitExceededException() { assertEquals("Memory Circuit Breaker is open, please check your resources!", argumentCaptor.getValue().getMessage()); } + @Test public void testValidateInputSchemaSuccess() { RemoteInferenceInputDataSet remoteInferenceInputDataSet = RemoteInferenceInputDataSet .builder() @@ -282,6 +293,7 @@ public void testValidateInputSchemaSuccess() { transportPredictionTaskAction.validateInputSchema("testId", mlInput); } + @Test public void testValidateInputSchemaFailed() { exceptionRule.expect(OpenSearchStatusException.class); RemoteInferenceInputDataSet remoteInferenceInputDataSet = RemoteInferenceInputDataSet diff --git a/plugin/src/test/java/org/opensearch/ml/action/undeploy/TransportUndeployModelsActionTests.java b/plugin/src/test/java/org/opensearch/ml/action/undeploy/TransportUndeployModelsActionTests.java index 42152f473d..2964ed583b 100644 --- a/plugin/src/test/java/org/opensearch/ml/action/undeploy/TransportUndeployModelsActionTests.java +++ b/plugin/src/test/java/org/opensearch/ml/action/undeploy/TransportUndeployModelsActionTests.java @@ -49,8 +49,10 @@ import org.opensearch.ml.engine.ModelHelper; import org.opensearch.ml.helper.ModelAccessControlHelper; import org.opensearch.ml.model.MLModelManager; +import org.opensearch.ml.settings.MLFeatureEnabledSetting; import org.opensearch.ml.task.MLTaskDispatcher; import org.opensearch.ml.task.MLTaskManager; +import org.opensearch.remote.metadata.client.SdkClient; import org.opensearch.tasks.Task; import org.opensearch.test.OpenSearchTestCase; import org.opensearch.threadpool.ThreadPool; @@ -78,6 +80,7 @@ public class TransportUndeployModelsActionTests extends OpenSearchTestCase { @Mock Client client; + SdkClient sdkClient; @Mock NamedXContentRegistry xContentRegistry; @@ -103,6 +106,9 @@ public class TransportUndeployModelsActionTests extends OpenSearchTestCase { @Mock Task task; + @Mock + private MLFeatureEnabledSetting mlFeatureEnabledSetting; + TransportUndeployModelsAction transportUndeployModelsAction; private String[] modelIds = { "modelId1" }; @@ -129,12 +135,14 @@ public void setup() throws IOException { clusterService, threadPool, client, + sdkClient, settings, xContentRegistry, nodeFilter, mlTaskDispatcher, mlModelManager, - modelAccessControlHelper + modelAccessControlHelper, + mlFeatureEnabledSetting ) ); when(modelAccessControlHelper.isModelAccessControlEnabled()).thenReturn(true); @@ -158,10 +166,10 @@ public void setup() throws IOException { .isHidden(false) .build(); doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(3); + ActionListener listener = invocation.getArgument(4); listener.onResponse(mlModel); return null; - }).when(mlModelManager).getModel(any(), any(), any(), isA(ActionListener.class)); + }).when(mlModelManager).getModel(any(), any(), any(), any(), isA(ActionListener.class)); } public void testHiddenModelSuccess() { @@ -178,10 +186,10 @@ public void testHiddenModelSuccess() { .isHidden(true) .build(); doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(3); + ActionListener listener = invocation.getArgument(4); listener.onResponse(mlModel); return null; - }).when(mlModelManager).getModel(any(), any(), any(), isA(ActionListener.class)); + }).when(mlModelManager).getModel(any(), any(), any(), any(), isA(ActionListener.class)); List responseList = new ArrayList<>(); List failuresList = new ArrayList<>(); @@ -193,7 +201,7 @@ public void testHiddenModelSuccess() { }).when(client).execute(any(), any(), isA(ActionListener.class)); doReturn(true).when(transportUndeployModelsAction).isSuperAdminUserWrapper(clusterService, client); - MLUndeployModelsRequest request = new MLUndeployModelsRequest(modelIds, nodeIds); + MLUndeployModelsRequest request = new MLUndeployModelsRequest(modelIds, nodeIds, null); transportUndeployModelsAction.doExecute(task, request, actionListener); verify(actionListener).onResponse(any(MLUndeployModelsResponse.class)); } @@ -212,10 +220,10 @@ public void testHiddenModelPermissionError() { .isHidden(true) .build(); doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(3); + ActionListener listener = invocation.getArgument(4); listener.onResponse(mlModel); return null; - }).when(mlModelManager).getModel(any(), any(), any(), isA(ActionListener.class)); + }).when(mlModelManager).getModel(any(), any(), any(), any(), isA(ActionListener.class)); List responseList = new ArrayList<>(); List failuresList = new ArrayList<>(); @@ -227,7 +235,7 @@ public void testHiddenModelPermissionError() { }).when(client).execute(any(), any(), isA(ActionListener.class)); doReturn(false).when(transportUndeployModelsAction).isSuperAdminUserWrapper(clusterService, client); - MLUndeployModelsRequest request = new MLUndeployModelsRequest(modelIds, nodeIds); + MLUndeployModelsRequest request = new MLUndeployModelsRequest(modelIds, nodeIds, null); transportUndeployModelsAction.doExecute(task, request, actionListener); ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); verify(actionListener).onFailure(argumentCaptor.capture()); @@ -236,10 +244,10 @@ public void testHiddenModelPermissionError() { public void testDoExecute() { doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(3); + ActionListener listener = invocation.getArgument(6); listener.onResponse(true); return null; - }).when(modelAccessControlHelper).validateModelGroupAccess(any(), any(), any(), isA(ActionListener.class)); + }).when(modelAccessControlHelper).validateModelGroupAccess(any(), any(), any(), any(), any(), any(), any()); List responseList = new ArrayList<>(); List failuresList = new ArrayList<>(); @@ -249,7 +257,7 @@ public void testDoExecute() { listener.onResponse(response); return null; }).when(client).execute(any(), any(), isA(ActionListener.class)); - MLUndeployModelsRequest request = new MLUndeployModelsRequest(modelIds, nodeIds); + MLUndeployModelsRequest request = new MLUndeployModelsRequest(modelIds, nodeIds, null); transportUndeployModelsAction.doExecute(task, request, actionListener); verify(actionListener).onResponse(any(MLUndeployModelsResponse.class)); } @@ -257,10 +265,10 @@ public void testDoExecute() { public void testDoExecute_modelAccessControl_notEnabled() { when(modelAccessControlHelper.isModelAccessControlEnabled()).thenReturn(false); doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(3); + ActionListener listener = invocation.getArgument(6); listener.onResponse(true); return null; - }).when(modelAccessControlHelper).validateModelGroupAccess(any(), any(), any(), isA(ActionListener.class)); + }).when(modelAccessControlHelper).validateModelGroupAccess(any(), any(), any(), any(), any(), any(), any()); MLUndeployModelsResponse mlUndeployModelsResponse = new MLUndeployModelsResponse(mock(MLUndeployModelNodesResponse.class)); doAnswer(invocation -> { @@ -268,49 +276,51 @@ public void testDoExecute_modelAccessControl_notEnabled() { listener.onResponse(mlUndeployModelsResponse); return null; }).when(client).execute(any(), any(), isA(ActionListener.class)); - MLUndeployModelsRequest request = new MLUndeployModelsRequest(modelIds, nodeIds); + MLUndeployModelsRequest request = new MLUndeployModelsRequest(modelIds, nodeIds, null); transportUndeployModelsAction.doExecute(task, request, actionListener); verify(actionListener).onFailure(isA(Exception.class)); } public void testDoExecute_validate_false() { doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(3); - listener.onResponse(false); + ActionListener listener = invocation.getArgument(6); + listener.onResponse(true); return null; - }).when(modelAccessControlHelper).validateModelGroupAccess(any(), any(), any(), isA(ActionListener.class)); + }).when(modelAccessControlHelper).validateModelGroupAccess(any(), any(), any(), any(), any(), any(), any()); doAnswer(invocation -> { ActionListener listener = invocation.getArgument(2); listener.onFailure(new IllegalArgumentException()); return null; }).when(client).execute(any(), any(), isA(ActionListener.class)); - MLUndeployModelsRequest request = new MLUndeployModelsRequest(modelIds, nodeIds); + MLUndeployModelsRequest request = new MLUndeployModelsRequest(modelIds, nodeIds, null); transportUndeployModelsAction.doExecute(task, request, actionListener); verify(actionListener).onFailure(isA(IllegalArgumentException.class)); } public void testDoExecute_getModel_exception() { doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(3); + ActionListener listener = invocation.getArgument(4); listener.onFailure(new RuntimeException("runtime exception")); return null; - }).when(mlModelManager).getModel(any(), any(), any(), isA(ActionListener.class)); - MLUndeployModelsRequest request = new MLUndeployModelsRequest(modelIds, nodeIds); + }).when(mlModelManager).getModel(any(), any(), any(), any(), isA(ActionListener.class)); + MLUndeployModelsRequest request = new MLUndeployModelsRequest(modelIds, nodeIds, null); transportUndeployModelsAction.doExecute(task, request, actionListener); verify(actionListener).onFailure(isA(RuntimeException.class)); } public void testDoExecute_validateAccess_exception() { - doThrow(new RuntimeException("runtime exception")).when(mlModelManager).getModel(any(), any(), any(), isA(ActionListener.class)); - MLUndeployModelsRequest request = new MLUndeployModelsRequest(modelIds, nodeIds); + doThrow(new RuntimeException("runtime exception")) + .when(mlModelManager) + .getModel(any(), any(), any(), any(), isA(ActionListener.class)); + MLUndeployModelsRequest request = new MLUndeployModelsRequest(modelIds, nodeIds, null); transportUndeployModelsAction.doExecute(task, request, actionListener); verify(actionListener).onFailure(isA(RuntimeException.class)); } public void testDoExecute_modelIds_moreThan1() { expectedException.expect(IllegalArgumentException.class); - MLUndeployModelsRequest request = new MLUndeployModelsRequest(new String[] { "modelId1", "modelId2" }, nodeIds); + MLUndeployModelsRequest request = new MLUndeployModelsRequest(new String[] { "modelId1", "modelId2" }, nodeIds, null); transportUndeployModelsAction.doExecute(task, request, actionListener); } } diff --git a/plugin/src/test/java/org/opensearch/ml/cluster/MLSyncUpCronTests.java b/plugin/src/test/java/org/opensearch/ml/cluster/MLSyncUpCronTests.java index 523b8fee36..696faa7432 100644 --- a/plugin/src/test/java/org/opensearch/ml/cluster/MLSyncUpCronTests.java +++ b/plugin/src/test/java/org/opensearch/ml/cluster/MLSyncUpCronTests.java @@ -38,6 +38,7 @@ import org.junit.Before; import org.mockito.ArgumentCaptor; import org.mockito.Mock; +import org.mockito.Mockito; import org.mockito.MockitoAnnotations; import org.opensearch.Version; import org.opensearch.action.bulk.BulkRequest; @@ -61,6 +62,7 @@ import org.opensearch.core.action.ActionListener; import org.opensearch.core.common.bytes.BytesReference; import org.opensearch.core.common.transport.TransportAddress; +import org.opensearch.core.xcontent.NamedXContentRegistry; import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.ml.common.FunctionName; import org.opensearch.ml.common.MLModel; @@ -71,7 +73,10 @@ import org.opensearch.ml.engine.encryptor.Encryptor; import org.opensearch.ml.engine.encryptor.EncryptorImpl; import org.opensearch.ml.engine.indices.MLIndicesHandler; +import org.opensearch.ml.settings.MLFeatureEnabledSetting; import org.opensearch.ml.utils.TestHelper; +import org.opensearch.remote.metadata.client.SdkClient; +import org.opensearch.remote.metadata.client.impl.SdkClientFactory; import org.opensearch.search.SearchHit; import org.opensearch.search.SearchHits; import org.opensearch.search.aggregations.InternalAggregations; @@ -88,6 +93,7 @@ public class MLSyncUpCronTests extends OpenSearchTestCase { @Mock private Client client; + private SdkClient sdkClient; @Mock private ClusterService clusterService; @Mock @@ -95,6 +101,9 @@ public class MLSyncUpCronTests extends OpenSearchTestCase { @Mock private MLIndicesHandler mlIndicesHandler; + @Mock + private MLFeatureEnabledSetting mlFeatureEnabledSetting; + private DiscoveryNode mlNode1; private DiscoveryNode mlNode2; private MLSyncUpCron syncUpCron; @@ -116,7 +125,6 @@ public void setup() throws IOException { mlNode1 = new DiscoveryNode(mlNode1Id, buildNewFakeTransportAddress(), emptyMap(), ImmutableSet.of(ML_ROLE), Version.CURRENT); mlNode2 = new DiscoveryNode(mlNode2Id, buildNewFakeTransportAddress(), emptyMap(), ImmutableSet.of(ML_ROLE), Version.CURRENT); encryptor = spy(new EncryptorImpl(null)); - syncUpCron = new MLSyncUpCron(client, clusterService, nodeHelper, mlIndicesHandler, encryptor); testState = setupTestClusterState("node"); when(clusterService.state()).thenReturn(testState); @@ -128,10 +136,12 @@ public void setup() throws IOException { }).when(mlIndicesHandler).initMLConfigIndex(any()); Settings settings = Settings.builder().build(); + sdkClient = Mockito.spy(SdkClientFactory.createSdkClient(client, NamedXContentRegistry.EMPTY, Collections.emptyMap())); threadContext = new ThreadContext(settings); threadContext.putTransient(ConfigConstants.OPENSEARCH_SECURITY_USER_INFO_THREAD_CONTEXT, USER_STRING); when(client.threadPool()).thenReturn(threadPool); when(threadPool.getThreadContext()).thenReturn(threadContext); + syncUpCron = new MLSyncUpCron(client, sdkClient, clusterService, nodeHelper, mlIndicesHandler, encryptor, mlFeatureEnabledSetting); } public void testInitMlConfig_MasterKeyNotExist() { diff --git a/plugin/src/test/java/org/opensearch/ml/rest/RestMLUndeployModelActionTests.java b/plugin/src/test/java/org/opensearch/ml/rest/RestMLUndeployModelActionTests.java index e98d1a5f31..34d85b6e10 100644 --- a/plugin/src/test/java/org/opensearch/ml/rest/RestMLUndeployModelActionTests.java +++ b/plugin/src/test/java/org/opensearch/ml/rest/RestMLUndeployModelActionTests.java @@ -37,6 +37,7 @@ import org.opensearch.ml.common.transport.model.MLModelGetResponse; import org.opensearch.ml.common.transport.undeploy.MLUndeployModelsAction; import org.opensearch.ml.common.transport.undeploy.MLUndeployModelsRequest; +import org.opensearch.ml.settings.MLFeatureEnabledSetting; import org.opensearch.rest.RestChannel; import org.opensearch.rest.RestHandler; import org.opensearch.rest.RestRequest; @@ -57,6 +58,9 @@ public class RestMLUndeployModelActionTests extends OpenSearchTestCase { @Mock ClusterState testState; + @Mock + private MLFeatureEnabledSetting mlFeatureEnabledSetting; + @Mock RestChannel channel; @@ -70,7 +74,7 @@ public void setup() { testState = setupTestClusterState("node"); when(clusterService.state()).thenReturn(testState); when(clusterService.getClusterSettings()).thenReturn(clusterSettings); - restMLUndeployModelAction = new RestMLUndeployModelAction(clusterService, settings); + restMLUndeployModelAction = new RestMLUndeployModelAction(clusterService, settings, mlFeatureEnabledSetting); threadPool = new TestThreadPool(this.getClass().getSimpleName() + "ThreadPool"); client = spy(new NodeClient(Settings.EMPTY, threadPool)); doAnswer(invocation -> { @@ -88,7 +92,7 @@ public void tearDown() throws Exception { } public void testConstructor() { - RestMLUndeployModelAction undeployModel = new RestMLUndeployModelAction(clusterService, settings); + RestMLUndeployModelAction undeployModel = new RestMLUndeployModelAction(clusterService, settings, mlFeatureEnabledSetting); assertNotNull(undeployModel); }