Skip to content

Commit

Permalink
Issue 1878/1879/1880 fixing index not found for model group/model/tas…
Browse files Browse the repository at this point in the history
…ks (#1889) (#1895)

* adding tests for search of ML constructs



* improve test coverage on MLSearchHandler



* fix headers on test



---------


(cherry picked from commit 41fac82)

Signed-off-by: Samuel Herman <[email protected]>
Signed-off-by: samuel-oci <[email protected]>
  • Loading branch information
sam-herman authored Jan 22, 2024
1 parent 56a9ab1 commit c278cb1
Show file tree
Hide file tree
Showing 12 changed files with 328 additions and 31 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@

package org.opensearch.ml.action.connector;

import static org.opensearch.ml.utils.RestActionUtils.wrapListenerToHandleSearchIndexNotFound;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
Expand Down Expand Up @@ -85,8 +87,8 @@ private void search(SearchRequest request, ActionListener<SearchResponse> action
);
request.source().fetchSource(rebuiltFetchSourceContext);

ActionListener<SearchResponse> doubleWrappedListener = ActionListener
.wrap(wrappedListener::onResponse, e -> wrapListenerToHandleConnectorIndexNotFound(e, actionListener));
final ActionListener<SearchResponse> doubleWrappedListener = ActionListener
.wrap(wrappedListener::onResponse, e -> wrapListenerToHandleSearchIndexNotFound(e, wrappedListener));

if (connectorAccessControlHelper.skipConnectorAccessControl(user)) {
client.search(request, doubleWrappedListener);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

import static org.opensearch.core.rest.RestStatus.BAD_REQUEST;
import static org.opensearch.core.rest.RestStatus.INTERNAL_SERVER_ERROR;
import static org.opensearch.ml.utils.RestActionUtils.wrapListenerToHandleSearchIndexNotFound;

import java.util.ArrayList;
import java.util.Arrays;
Expand Down Expand Up @@ -45,6 +46,8 @@
import org.opensearch.search.builder.SearchSourceBuilder;
import org.opensearch.search.fetch.subphase.FetchSourceContext;

import com.google.common.annotations.VisibleForTesting;

import lombok.extern.log4j.Log4j2;

/**
Expand Down Expand Up @@ -121,10 +124,12 @@ public void search(SearchRequest request, ActionListener<SearchResponse> actionL

request.source().query(queryBuilder);
request.source().fetchSource(rebuiltFetchSourceContext);
final ActionListener<SearchResponse> doubleWrapperListener = ActionListener
.wrap(wrappedListener::onResponse, e -> wrapListenerToHandleSearchIndexNotFound(e, wrappedListener));
if (modelAccessControlHelper.skipModelAccessControl(user)) {
client.search(request, wrappedListener);
client.search(request, doubleWrapperListener);
} else if (!clusterService.state().metadata().hasIndex(CommonValue.ML_MODEL_GROUP_INDEX)) {
client.search(request, wrappedListener);
client.search(request, doubleWrapperListener);
} else {
SearchSourceBuilder sourceBuilder = modelAccessControlHelper.createSearchSourceBuilder(user);
SearchRequest modelGroupSearchRequest = new SearchRequest();
Expand All @@ -143,11 +148,11 @@ public void search(SearchRequest request, ActionListener<SearchResponse> actionL
Arrays.stream(r.getHits().getHits()).forEach(hit -> { modelGroupIds.add(hit.getId()); });

request.source().query(rewriteQueryBuilder(request.source().query(), modelGroupIds));
client.search(request, wrappedListener);
client.search(request, doubleWrapperListener);
} else {
log.debug("No model group found");
request.source().query(rewriteQueryBuilder(request.source().query(), null));
client.search(request, wrappedListener);
client.search(request, doubleWrapperListener);
}
}, e -> {
log.error("Fail to search model groups!", e);
Expand All @@ -161,7 +166,8 @@ public void search(SearchRequest request, ActionListener<SearchResponse> actionL
}
}

private QueryBuilder rewriteQueryBuilder(QueryBuilder queryBuilder, List<String> modelGroupIds) {
@VisibleForTesting
static QueryBuilder rewriteQueryBuilder(QueryBuilder queryBuilder, List<String> modelGroupIds) {
ExistsQueryBuilder existsQueryBuilder = new ExistsQueryBuilder(MLModelGroup.MODEL_GROUP_ID_FIELD);
BoolQueryBuilder modelGroupIdMustNotExistBoolQuery = new BoolQueryBuilder();
modelGroupIdMustNotExistBoolQuery.mustNot(existsQueryBuilder);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
package org.opensearch.ml.action.model_group;

import static org.opensearch.ml.action.handler.MLSearchHandler.wrapRestActionListener;
import static org.opensearch.ml.utils.RestActionUtils.wrapListenerToHandleSearchIndexNotFound;

import org.opensearch.action.search.SearchRequest;
import org.opensearch.action.search.SearchResponse;
Expand Down Expand Up @@ -58,13 +59,17 @@ protected void doExecute(Task task, SearchRequest request, ActionListener<Search
private void preProcessRoleAndPerformSearch(SearchRequest request, User user, ActionListener<SearchResponse> listener) {
try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) {
ActionListener<SearchResponse> wrappedListener = ActionListener.runBefore(listener, () -> context.restore());

final ActionListener<SearchResponse> doubleWrappedListener = ActionListener
.wrap(wrappedListener::onResponse, e -> wrapListenerToHandleSearchIndexNotFound(e, wrappedListener));

if (modelAccessControlHelper.skipModelAccessControl(user)) {
client.search(request, wrappedListener);
client.search(request, doubleWrappedListener);
} else {
// Security is enabled, filter is enabled and user isn't admin
modelAccessControlHelper.addUserBackendRolesFilter(user, request.source());
log.debug("Filtering result by " + user.getBackendRoles());
client.search(request, wrappedListener);
client.search(request, doubleWrappedListener);
}
} catch (Exception e) {
log.error("Failed to search", e);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@

@Log4j2
public class SearchModelTransportAction extends HandledTransportAction<SearchRequest, SearchResponse> {
private MLSearchHandler mlSearchHandler;
private final MLSearchHandler mlSearchHandler;

@Inject
public SearchModelTransportAction(TransportService transportService, ActionFilters actionFilters, MLSearchHandler mlSearchHandler) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@

package org.opensearch.ml.action.tasks;

import static org.opensearch.ml.utils.RestActionUtils.wrapListenerToHandleSearchIndexNotFound;

import org.opensearch.action.search.SearchRequest;
import org.opensearch.action.search.SearchResponse;
import org.opensearch.action.support.ActionFilters;
Expand Down Expand Up @@ -32,7 +34,9 @@ public SearchTaskTransportAction(TransportService transportService, ActionFilter
@Override
protected void doExecute(Task task, SearchRequest request, ActionListener<SearchResponse> actionListener) {
try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) {
client.search(request, ActionListener.runBefore(actionListener, () -> context.restore()));
final ActionListener<SearchResponse> wrappedListener = ActionListener
.wrap(actionListener::onResponse, e -> wrapListenerToHandleSearchIndexNotFound(e, actionListener));
client.search(request, ActionListener.runBefore(wrappedListener, () -> context.restore()));
} catch (Exception e) {
log.error(e.getMessage(), e);
actionListener.onFailure(e);
Expand Down
42 changes: 42 additions & 0 deletions plugin/src/main/java/org/opensearch/ml/utils/RestActionUtils.java
Original file line number Diff line number Diff line change
Expand Up @@ -23,23 +23,32 @@
import org.apache.commons.lang3.ArrayUtils;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.opensearch.ExceptionsHelper;
import org.opensearch.action.search.SearchResponse;
import org.opensearch.action.search.ShardSearchFailure;
import org.opensearch.client.Client;
import org.opensearch.cluster.node.DiscoveryNode;
import org.opensearch.cluster.service.ClusterService;
import org.opensearch.common.Nullable;
import org.opensearch.common.util.concurrent.ThreadContext;
import org.opensearch.commons.ConfigConstants;
import org.opensearch.commons.authuser.User;
import org.opensearch.core.action.ActionListener;
import org.opensearch.core.common.Strings;
import org.opensearch.core.rest.RestStatus;
import org.opensearch.index.IndexNotFoundException;
import org.opensearch.rest.BytesRestResponse;
import org.opensearch.rest.RestChannel;
import org.opensearch.rest.RestRequest;
import org.opensearch.search.builder.SearchSourceBuilder;
import org.opensearch.search.fetch.subphase.FetchSourceContext;
import org.opensearch.search.internal.InternalSearchResponse;

import com.google.common.annotations.VisibleForTesting;

import lombok.extern.log4j.Log4j2;

@Log4j2
public class RestActionUtils {

private static final Logger logger = LogManager.getLogger(RestActionUtils.class);
Expand Down Expand Up @@ -246,4 +255,37 @@ private static boolean isAdminDN(LdapName dn) {
return isAdmin;
}

/**
* Utility to wrap over an action listener to handle index not found error to return empty results instead of failing.
* This is important when the user is performing a search request against connectors/models/model groups/tasks or other constructs that
* do not imply an index error but rather imply no items found.
* @see <a href=https://github.com/opensearch-project/ml-commons/issues/1787>Issue 1787</a>
* @see <a href=https://github.com/opensearch-project/ml-commons/issues/1778>Issue 1878</a>
* @see <a href=https://github.com/opensearch-project/ml-commons/issues/1879>Issue 1879</a>
* @see <a href=https://github.com/opensearch-project/ml-commons/issues/1880>Issue 1880</a>
* @param e Exception to wrap
* @param listener ActionListener for a search response to wrap
*/
public static void wrapListenerToHandleSearchIndexNotFound(Exception e, ActionListener<SearchResponse> listener) {
if (ExceptionsHelper.unwrapCause(e) instanceof IndexNotFoundException) {
log.debug("Connectors index not created yet, therefore we will swallow the exception and return an empty search result");
final InternalSearchResponse internalSearchResponse = InternalSearchResponse.empty();
final SearchResponse emptySearchResponse = new SearchResponse(
internalSearchResponse,
null,
0,
0,
0,
0,
null,
new ShardSearchFailure[] {},
SearchResponse.Clusters.EMPTY,
null
);
listener.onResponse(emptySearchResponse);
} else {
listener.onFailure(e);
}
}

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.ml.action.handler;

import java.util.List;

import org.junit.Assert;
import org.opensearch.index.query.QueryBuilder;
import org.opensearch.index.query.QueryStringQueryBuilder;
import org.opensearch.test.OpenSearchTestCase;

public class MLSearchHandlerTests extends OpenSearchTestCase {

public void testRewriteQueryBuilder_accessControlIgnored() {
final String expectedQueryString = "{\n"
+ " \"bool\" : {\n"
+ " \"must\" : [\n"
+ " {\n"
+ " \"query_string\" : {\n"
+ " \"query\" : \"\",\n"
+ " \"fields\" : [ ],\n"
+ " \"type\" : \"best_fields\",\n"
+ " \"default_operator\" : \"or\",\n"
+ " \"max_determinized_states\" : 10000,\n"
+ " \"enable_position_increments\" : true,\n"
+ " \"fuzziness\" : \"AUTO\",\n"
+ " \"fuzzy_prefix_length\" : 0,\n"
+ " \"fuzzy_max_expansions\" : 50,\n"
+ " \"phrase_slop\" : 0,\n"
+ " \"escape\" : false,\n"
+ " \"auto_generate_synonyms_phrase_query\" : true,\n"
+ " \"fuzzy_transpositions\" : true,\n"
+ " \"boost\" : 1.0\n"
+ " }\n"
+ " },\n"
+ " {\n"
+ " \"bool\" : {\n"
+ " \"must_not\" : [\n"
+ " {\n"
+ " \"exists\" : {\n"
+ " \"field\" : \"model_group_id\",\n"
+ " \"boost\" : 1.0\n"
+ " }\n"
+ " }\n"
+ " ],\n"
+ " \"adjust_pure_negative\" : true,\n"
+ " \"boost\" : 1.0\n"
+ " }\n"
+ " }\n"
+ " ],\n"
+ " \"adjust_pure_negative\" : true,\n"
+ " \"boost\" : 1.0\n"
+ " }\n"
+ "}";
final QueryBuilder queryBuilder = MLSearchHandler.rewriteQueryBuilder(new QueryStringQueryBuilder(""), List.of("group1", "group2"));
final String queryString = queryBuilder.toString();
Assert.assertEquals(expectedQueryString, queryString);
}

public void testRewriteQueryBuilder_accessControlUsed_withNullQuery() {
final String expectedQueryString = "{\n"
+ " \"bool\" : {\n"
+ " \"should\" : [\n"
+ " {\n"
+ " \"terms\" : {\n"
+ " \"model_group_id\" : [\n"
+ " \"group1\",\n"
+ " \"group2\"\n"
+ " ],\n"
+ " \"boost\" : 1.0\n"
+ " }\n"
+ " },\n"
+ " {\n"
+ " \"bool\" : {\n"
+ " \"must_not\" : [\n"
+ " {\n"
+ " \"exists\" : {\n"
+ " \"field\" : \"model_group_id\",\n"
+ " \"boost\" : 1.0\n"
+ " }\n"
+ " }\n"
+ " ],\n"
+ " \"adjust_pure_negative\" : true,\n"
+ " \"boost\" : 1.0\n"
+ " }\n"
+ " }\n"
+ " ],\n"
+ " \"adjust_pure_negative\" : true,\n"
+ " \"boost\" : 1.0\n"
+ " }\n"
+ "}";
final QueryBuilder queryBuilder = MLSearchHandler.rewriteQueryBuilder(null, List.of("group1", "group2"));
final String queryString = queryBuilder.toString();
Assert.assertEquals(expectedQueryString, queryString);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@
import org.opensearch.OpenSearchStatusException;
import org.opensearch.action.search.SearchRequest;
import org.opensearch.action.search.SearchResponse;
import org.opensearch.action.search.SearchResponseSections;
import org.opensearch.action.search.ShardSearchFailure;
import org.opensearch.action.support.ActionFilters;
import org.opensearch.client.Client;
import org.opensearch.cluster.ClusterName;
Expand All @@ -45,6 +47,7 @@
import org.opensearch.ml.utils.TestHelper;
import org.opensearch.search.SearchHit;
import org.opensearch.search.SearchHits;
import org.opensearch.search.aggregations.InternalAggregations;
import org.opensearch.search.builder.SearchSourceBuilder;
import org.opensearch.search.fetch.subphase.FetchSourceContext;
import org.opensearch.test.OpenSearchTestCase;
Expand Down Expand Up @@ -161,7 +164,7 @@ public void test_DoExecute_addBackendRoles_exception() {
verify(client, times(1)).search(any(), any());
}

public void test_DoExecute_searchModel_indexNotFound_exception() {
public void test_DoExecute_searchModel_before_model_creation_no_exception() {
doAnswer(invocation -> {
ActionListener<SearchResponse> listener = invocation.getArgument(1);
listener.onFailure(new IndexNotFoundException("index not found exception"));
Expand All @@ -171,7 +174,41 @@ public void test_DoExecute_searchModel_indexNotFound_exception() {
searchModelTransportAction.doExecute(null, searchRequest, actionListener);
verify(mlSearchHandler).search(searchRequest, actionListener);
verify(client, times(1)).search(any(), any());
verify(actionListener, times(1)).onFailure(any(IndexNotFoundException.class));
verify(actionListener, times(0)).onFailure(any(IndexNotFoundException.class));
}

public void test_DoExecute_searchModel_before_model_creation_empty_search() {
doAnswer(invocation -> {
ActionListener<SearchResponse> listener = invocation.getArgument(1);
SearchHits hits = new SearchHits(new SearchHit[] {}, new TotalHits(0, TotalHits.Relation.EQUAL_TO), Float.NaN);
SearchResponseSections searchSections = new SearchResponseSections(
hits,
InternalAggregations.EMPTY,
null,
false,
false,
null,
1
);
final SearchResponse searchResponse = new SearchResponse(
searchSections,
null,
1,
1,
0,
11,
ShardSearchFailure.EMPTY_ARRAY,
SearchResponse.Clusters.EMPTY
);
listener.onResponse(searchResponse);
return null;
}).when(client).search(any(), isA(ActionListener.class));
when(modelAccessControlHelper.skipModelAccessControl(any())).thenReturn(true);
searchModelTransportAction.doExecute(null, searchRequest, actionListener);
verify(mlSearchHandler).search(searchRequest, actionListener);
verify(client, times(1)).search(any(), any());
verify(actionListener, times(0)).onFailure(any(IndexNotFoundException.class));
verify(actionListener, times(1)).onResponse(any(SearchResponse.class));
}

public void test_DoExecute_searchModel_MLResourceNotFoundException_exception() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -765,7 +765,7 @@ private String parseTaskIdFromResponse(Response response) throws IOException {
return taskId;
}

private Map parseResponseToMap(Response response) throws IOException {
Map parseResponseToMap(Response response) throws IOException {
HttpEntity entity = response.getEntity();
assertNotNull(response);
String entityString = TestHelper.httpEntityToString(entity);
Expand Down
Loading

0 comments on commit c278cb1

Please sign in to comment.