diff --git a/plugin/build.gradle b/plugin/build.gradle index dd34186115..456e368a3d 100644 --- a/plugin/build.gradle +++ b/plugin/build.gradle @@ -276,6 +276,7 @@ List jacocoExclusions = [ 'org.opensearch.ml.profile.MLPredictRequestStats', 'org.opensearch.ml.action.deploy.TransportDeployModelAction', 'org.opensearch.ml.action.deploy.TransportDeployModelOnNodeAction', + 'org.opensearch.ml.action.undeploy.TransportUndeployModelsAction', 'org.opensearch.ml.action.prediction.TransportPredictionTaskAction', 'org.opensearch.ml.action.prediction.TransportPredictionTaskAction.1', 'org.opensearch.ml.action.tasks.GetTaskTransportAction', diff --git a/plugin/src/main/java/org/opensearch/ml/action/undeploy/TransportUndeployModelsAction.java b/plugin/src/main/java/org/opensearch/ml/action/undeploy/TransportUndeployModelsAction.java index da030239ae..3e1c5640ab 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/undeploy/TransportUndeployModelsAction.java +++ b/plugin/src/main/java/org/opensearch/ml/action/undeploy/TransportUndeployModelsAction.java @@ -5,8 +5,16 @@ package org.opensearch.ml.action.undeploy; +import static org.opensearch.ml.common.CommonValue.ML_MODEL_INDEX; + +import java.util.Arrays; +import java.util.List; +import java.util.stream.Collectors; + import org.opensearch.OpenSearchStatusException; import org.opensearch.action.ActionRequest; +import org.opensearch.action.search.SearchRequest; +import org.opensearch.action.search.SearchResponse; import org.opensearch.action.support.ActionFilters; import org.opensearch.action.support.HandledTransportAction; import org.opensearch.client.Client; @@ -18,6 +26,10 @@ import org.opensearch.core.action.ActionListener; import org.opensearch.core.rest.RestStatus; import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.index.IndexNotFoundException; +import org.opensearch.index.query.QueryBuilders; +import org.opensearch.index.query.TermQueryBuilder; +import org.opensearch.index.query.TermsQueryBuilder; import org.opensearch.ml.cluster.DiscoveryNodeHelper; import org.opensearch.ml.common.MLModel; import org.opensearch.ml.common.transport.deploy.MLDeployModelRequest; @@ -32,6 +44,8 @@ import org.opensearch.ml.task.MLTaskDispatcher; import org.opensearch.ml.task.MLTaskManager; import org.opensearch.ml.utils.RestActionUtils; +import org.opensearch.search.SearchHit; +import org.opensearch.search.builder.SearchSourceBuilder; import org.opensearch.tasks.Task; import org.opensearch.threadpool.ThreadPool; import org.opensearch.transport.TransportService; @@ -93,27 +107,49 @@ protected void doExecute(Task task, ActionRequest request, ActionListener { if (hasPermissionToUndeploy) { - MLUndeployModelNodesRequest mlUndeployModelNodesRequest = new MLUndeployModelNodesRequest(targetNodeIds, modelIds); - - client.execute(MLUndeployModelAction.INSTANCE, mlUndeployModelNodesRequest, ActionListener.wrap(r -> { - listener.onResponse(new MLUndeployModelsResponse(r)); - }, e -> { listener.onFailure(e); })); + undeployModels(targetNodeIds, modelIds, listener); } else { listener.onFailure(new IllegalArgumentException("No permission to undeploy model " + modelId)); } }, listener::onFailure)); - return; + } else { + // Only allow user to undeploy one model if model access control enabled. + // With multiple models, it is difficult to check to which models user has access to. + if (modelAccessControlHelper.isModelAccessControlEnabled()) { + throw new IllegalArgumentException("only support undeploy one model"); + } else { + searchHiddenModels(modelIds, ActionListener.wrap(hiddenModels -> { + if (hiddenModels != null + && hiddenModels.getHits().getTotalHits() != null + && hiddenModels.getHits().getTotalHits().value != 0 + && !isSuperAdminUserWrapper(clusterService, client)) { + List hiddenModelIds = Arrays + .stream(hiddenModels.getHits().getHits()) + .map(SearchHit::getId) + .collect(Collectors.toList()); + + String[] modelsIDsToUndeploy = Arrays + .stream(modelIds) + .filter(modelId -> !hiddenModelIds.contains(modelId)) + .toArray(String[]::new); + + undeployModels(targetNodeIds, modelsIDsToUndeploy, listener); + } else { + undeployModels(targetNodeIds, modelIds, listener); + } + }, e -> { + log.error("Failed to search model index", e); + listener.onFailure(e); + })); + } } + } + private void undeployModels(String[] targetNodeIds, String[] modelIds, ActionListener listener) { MLUndeployModelNodesRequest mlUndeployModelNodesRequest = new MLUndeployModelNodesRequest(targetNodeIds, modelIds); client.execute(MLUndeployModelAction.INSTANCE, mlUndeployModelNodesRequest, ActionListener.wrap(r -> { @@ -153,6 +189,42 @@ private void validateAccess(String modelId, ActionListener listener) { } } + public void searchHiddenModels(String[] modelIds, ActionListener listener) throws IllegalArgumentException { + try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) { + // Create a TermsQueryBuilder for MODEL_ID_FIELD using the modelIds + TermsQueryBuilder termsQuery = QueryBuilders.termsQuery("_id", modelIds); + + // Create a TermQueryBuilder for IS_HIDDEN_FIELD with value true + TermQueryBuilder isHiddenQuery = QueryBuilders.termQuery(MLModel.IS_HIDDEN_FIELD, true); + + // Create an existsQuery to exclude model chunks + // Combine the queries using a bool query with must and mustNot clause + SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder(); + searchSourceBuilder + .query( + QueryBuilders + .boolQuery() + .must(termsQuery) + .must(isHiddenQuery) + .mustNot(QueryBuilders.existsQuery(MLModel.CHUNK_NUMBER_FIELD)) + ); + + SearchRequest searchRequest = new SearchRequest(ML_MODEL_INDEX).source(searchSourceBuilder); + + client.search(searchRequest, ActionListener.runBefore(ActionListener.wrap(models -> { listener.onResponse(models); }, e -> { + if (e instanceof IndexNotFoundException) { + listener.onResponse(null); + } else { + log.error("Failed to search model index", e); + listener.onFailure(e); + } + }), () -> context.restore())); + } catch (Exception e) { + log.error("Failed to search model index", e); + listener.onFailure(e); + } + } + @VisibleForTesting boolean isSuperAdminUserWrapper(ClusterService clusterService, Client client) { return RestActionUtils.isSuperAdminUser(clusterService, client);