Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

restrict stash context only for stop words system index #2283

Merged
merged 1 commit into from
Mar 27, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 33 additions & 14 deletions common/src/main/java/org/opensearch/ml/common/model/MLGuard.java
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,11 @@

package org.opensearch.ml.common.model;

import com.google.common.collect.ImmutableSet;
import lombok.Getter;
import lombok.NonNull;
import lombok.extern.log4j.Log4j2;
import org.opensearch.ResourceNotFoundException;
import org.opensearch.action.LatchedActionListener;
import org.opensearch.action.get.GetResponse;
import org.opensearch.action.search.SearchRequest;
import org.opensearch.action.search.SearchResponse;
import org.opensearch.client.Client;
Expand All @@ -20,7 +19,6 @@
import org.opensearch.core.action.ActionListener;
import org.opensearch.core.xcontent.NamedXContentRegistry;
import org.opensearch.core.xcontent.XContentParser;
import org.opensearch.search.SearchHit;
import org.opensearch.search.builder.SearchSourceBuilder;

import java.security.AccessController;
Expand All @@ -30,15 +28,14 @@
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicReference;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
import java.util.stream.Collectors;

import static java.util.concurrent.TimeUnit.SECONDS;
import static org.opensearch.ml.common.CommonValue.MASTER_KEY;
import static org.opensearch.ml.common.utils.StringUtils.gson;

@Log4j2
Expand All @@ -52,6 +49,7 @@ public class MLGuard {
private List<Pattern> outputRegexPattern;
private NamedXContentRegistry xContentRegistry;
private Client client;
private Set<String> stopWordsIndices = ImmutableSet.of(".plugins-ml-stop-words");

public MLGuard(Guardrails guardrails, NamedXContentRegistry xContentRegistry, Client client) {
this.xContentRegistry = xContentRegistry;
Expand Down Expand Up @@ -128,27 +126,44 @@ public Boolean validateStopWordsSingleIndex(String input, String indexName, List
Map<String, Object> queryBodyMap = Map
.of("query", Map.of("percolate", Map.of("field", "query", "document", documentMap)));
CountDownLatch latch = new CountDownLatch(1);
ThreadContext.StoredContext context = null;

try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) {
try {
queryBody = AccessController.doPrivileged((PrivilegedExceptionAction<String>) () -> gson.toJson(queryBodyMap));
SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder();
XContentParser queryParser = XContentType.JSON.xContent().createParser(xContentRegistry, LoggingDeprecationHandler.INSTANCE, queryBody);
searchSourceBuilder.parseXContent(queryParser);
searchSourceBuilder.size(1); //Only need 1 doc returned, if hit.
searchRequest = new SearchRequest().source(searchSourceBuilder).indices(indexName);
context.restore();
client.search(searchRequest, ActionListener.runBefore(new LatchedActionListener(ActionListener.<SearchResponse>wrap(r -> {
if (r == null || r.getHits() == null || r.getHits().getTotalHits() == null || r.getHits().getTotalHits().value == 0) {
if (isStopWordsSystemIndex(indexName)) {
context = client.threadPool().getThreadContext().stashContext();
ThreadContext.StoredContext finalContext = context;
client.search(searchRequest, ActionListener.runBefore(new LatchedActionListener(ActionListener.<SearchResponse>wrap(r -> {
if (r == null || r.getHits() == null || r.getHits().getTotalHits() == null || r.getHits().getTotalHits().value == 0) {
hitStopWords.set(true);
}
}, e -> {
log.error("Failed to search stop words index {}", indexName, e);
hitStopWords.set(true);
}), latch), () -> finalContext.restore()));
} else {
client.search(searchRequest, new LatchedActionListener(ActionListener.<SearchResponse>wrap(r -> {
if (r == null || r.getHits() == null || r.getHits().getTotalHits() == null || r.getHits().getTotalHits().value == 0) {
hitStopWords.set(true);
}
}, e -> {
log.error("Failed to search stop words index {}", indexName, e);
hitStopWords.set(true);
}
}, e -> {
log.error("Failed to search stop words index {}", indexName, e);
hitStopWords.set(true);
}), latch), () -> context.restore()));
}), latch));
}
} catch (Exception e) {
log.error("[validateStopWords] Searching stop words index failed.", e);
latch.countDown();
hitStopWords.set(true);
} finally {
if (context != null) {
context.close();
}
}

try {
Expand All @@ -160,6 +175,10 @@ public Boolean validateStopWordsSingleIndex(String input, String indexName, List
return hitStopWords.get();
}

private boolean isStopWordsSystemIndex(String index) {
return stopWordsIndices.contains(index);
}

public enum Type {
INPUT,
OUTPUT
Expand Down
Loading