From 2c1cd6d41d8d6b34e8fda123aff345dd66d89778 Mon Sep 17 00:00:00 2001 From: Jing Zhang Date: Wed, 27 Mar 2024 12:48:30 -0700 Subject: [PATCH] restrict stash context only for stop words system index Signed-off-by: Jing Zhang --- .../opensearch/ml/common/model/MLGuard.java | 47 +++++++++++++------ 1 file changed, 33 insertions(+), 14 deletions(-) diff --git a/common/src/main/java/org/opensearch/ml/common/model/MLGuard.java b/common/src/main/java/org/opensearch/ml/common/model/MLGuard.java index f51615551d..99999db182 100644 --- a/common/src/main/java/org/opensearch/ml/common/model/MLGuard.java +++ b/common/src/main/java/org/opensearch/ml/common/model/MLGuard.java @@ -5,12 +5,11 @@ package org.opensearch.ml.common.model; +import com.google.common.collect.ImmutableSet; import lombok.Getter; import lombok.NonNull; import lombok.extern.log4j.Log4j2; -import org.opensearch.ResourceNotFoundException; import org.opensearch.action.LatchedActionListener; -import org.opensearch.action.get.GetResponse; import org.opensearch.action.search.SearchRequest; import org.opensearch.action.search.SearchResponse; import org.opensearch.client.Client; @@ -20,7 +19,6 @@ import org.opensearch.core.action.ActionListener; import org.opensearch.core.xcontent.NamedXContentRegistry; import org.opensearch.core.xcontent.XContentParser; -import org.opensearch.search.SearchHit; import org.opensearch.search.builder.SearchSourceBuilder; import java.security.AccessController; @@ -30,15 +28,14 @@ import java.util.HashMap; import java.util.List; import java.util.Map; +import java.util.Set; import java.util.concurrent.CountDownLatch; import java.util.concurrent.atomic.AtomicBoolean; -import java.util.concurrent.atomic.AtomicReference; import java.util.regex.Matcher; import java.util.regex.Pattern; import java.util.stream.Collectors; import static java.util.concurrent.TimeUnit.SECONDS; -import static org.opensearch.ml.common.CommonValue.MASTER_KEY; import static org.opensearch.ml.common.utils.StringUtils.gson; @Log4j2 @@ -52,6 +49,7 @@ public class MLGuard { private List outputRegexPattern; private NamedXContentRegistry xContentRegistry; private Client client; + private Set stopWordsIndices = ImmutableSet.of(".stop-words"); public MLGuard(Guardrails guardrails, NamedXContentRegistry xContentRegistry, Client client) { this.xContentRegistry = xContentRegistry; @@ -128,27 +126,44 @@ public Boolean validateStopWordsSingleIndex(String input, String indexName, List Map queryBodyMap = Map .of("query", Map.of("percolate", Map.of("field", "query", "document", documentMap))); CountDownLatch latch = new CountDownLatch(1); + ThreadContext.StoredContext context = null; - try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) { + try { queryBody = AccessController.doPrivileged((PrivilegedExceptionAction) () -> gson.toJson(queryBodyMap)); SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder(); XContentParser queryParser = XContentType.JSON.xContent().createParser(xContentRegistry, LoggingDeprecationHandler.INSTANCE, queryBody); searchSourceBuilder.parseXContent(queryParser); searchSourceBuilder.size(1); //Only need 1 doc returned, if hit. searchRequest = new SearchRequest().source(searchSourceBuilder).indices(indexName); - context.restore(); - client.search(searchRequest, ActionListener.runBefore(new LatchedActionListener(ActionListener.wrap(r -> { - if (r == null || r.getHits() == null || r.getHits().getTotalHits() == null || r.getHits().getTotalHits().value == 0) { + if (isStopWordsSystemIndex(indexName)) { + context = client.threadPool().getThreadContext().stashContext(); + ThreadContext.StoredContext finalContext = context; + client.search(searchRequest, ActionListener.runBefore(new LatchedActionListener(ActionListener.wrap(r -> { + if (r == null || r.getHits() == null || r.getHits().getTotalHits() == null || r.getHits().getTotalHits().value == 0) { + hitStopWords.set(true); + } + }, e -> { + log.error("Failed to search stop words index {}", indexName, e); + hitStopWords.set(true); + }), latch), () -> finalContext.restore())); + } else { + client.search(searchRequest, new LatchedActionListener(ActionListener.wrap(r -> { + if (r == null || r.getHits() == null || r.getHits().getTotalHits() == null || r.getHits().getTotalHits().value == 0) { + hitStopWords.set(true); + } + }, e -> { + log.error("Failed to search stop words index {}", indexName, e); hitStopWords.set(true); - } - }, e -> { - log.error("Failed to search stop words index {}", indexName, e); - hitStopWords.set(true); - }), latch), () -> context.restore())); + }), latch)); + } } catch (Exception e) { log.error("[validateStopWords] Searching stop words index failed.", e); latch.countDown(); hitStopWords.set(true); + } finally { + if (context != null) { + context.close(); + } } try { @@ -160,6 +175,10 @@ public Boolean validateStopWordsSingleIndex(String input, String indexName, List return hitStopWords.get(); } + private boolean isStopWordsSystemIndex(String index) { + return stopWordsIndices.contains(index); + } + public enum Type { INPUT, OUTPUT