diff --git a/CHANGELOG.md b/CHANGELOG.md index cf648c641..12248ccad 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -15,6 +15,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), ## [Unreleased 2.x](https://github.com/opensearch-project/neural-search/compare/2.18...2.x) ### Features +- Pagination in Hybrid query ([#1048](https://github.com/opensearch-project/neural-search/pull/1048)) ### Enhancements - Explainability in hybrid query ([#970](https://github.com/opensearch-project/neural-search/pull/970)) - Implement pruning for neural sparse ingestion pipeline and two phase search processor ([#988](https://github.com/opensearch-project/neural-search/pull/988)) diff --git a/src/main/java/org/opensearch/neuralsearch/common/MinClusterVersionUtil.java b/src/main/java/org/opensearch/neuralsearch/common/MinClusterVersionUtil.java index a17e138e2..13410d1c7 100644 --- a/src/main/java/org/opensearch/neuralsearch/common/MinClusterVersionUtil.java +++ b/src/main/java/org/opensearch/neuralsearch/common/MinClusterVersionUtil.java @@ -24,6 +24,7 @@ public final class MinClusterVersionUtil { private static final Version MINIMAL_SUPPORTED_VERSION_DEFAULT_MODEL_ID = Version.V_2_11_0; private static final Version MINIMAL_SUPPORTED_VERSION_RADIAL_SEARCH = Version.V_2_14_0; private static final Version MINIMAL_SUPPORTED_VERSION_QUERY_IMAGE_FIX = Version.V_2_19_0; + private static final Version MINIMAL_SUPPORTED_VERSION_PAGINATION_IN_HYBRID_QUERY = Version.V_2_19_0; // Note this minimal version will act as a override private static final Map MINIMAL_VERSION_NEURAL = ImmutableMap.builder() @@ -41,6 +42,10 @@ public static boolean isClusterOnOrAfterMinReqVersionForRadialSearch() { return NeuralSearchClusterUtil.instance().getClusterMinVersion().onOrAfter(MINIMAL_SUPPORTED_VERSION_RADIAL_SEARCH); } + public static boolean isClusterOnOrAfterMinReqVersionForPaginationInHybridQuery() { + return 
NeuralSearchClusterUtil.instance().getClusterMinVersion().onOrAfter(MINIMAL_SUPPORTED_VERSION_PAGINATION_IN_HYBRID_QUERY); + } + public static boolean isClusterOnOrAfterMinReqVersion(String key) { Version version; if (MINIMAL_VERSION_NEURAL.containsKey(key)) { diff --git a/src/main/java/org/opensearch/neuralsearch/processor/NormalizationProcessor.java b/src/main/java/org/opensearch/neuralsearch/processor/NormalizationProcessor.java index d2008ae97..d2fa03fde 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/NormalizationProcessor.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/NormalizationProcessor.java @@ -93,6 +93,7 @@ private void prepareAndExecuteNormalizationWo .combinationTechnique(combinationTechnique) .explain(explain) .pipelineProcessingContext(requestContextOptional.orElse(null)) + .searchPhaseContext(searchPhaseContext) .build(); normalizationWorkflow.execute(request); } diff --git a/src/main/java/org/opensearch/neuralsearch/processor/NormalizationProcessorWorkflow.java b/src/main/java/org/opensearch/neuralsearch/processor/NormalizationProcessorWorkflow.java index f2699d967..db3747a13 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/NormalizationProcessorWorkflow.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/NormalizationProcessorWorkflow.java @@ -19,6 +19,7 @@ import org.apache.lucene.search.Sort; import org.apache.lucene.search.TopFieldDocs; import org.apache.lucene.search.FieldDoc; +import org.opensearch.action.search.SearchPhaseContext; import org.opensearch.common.lucene.search.TopDocsAndMaxScore; import org.opensearch.neuralsearch.processor.combination.CombineScoresDto; import org.opensearch.neuralsearch.processor.combination.ScoreCombinationTechnique; @@ -64,7 +65,8 @@ public void execute( final List querySearchResults, final Optional fetchSearchResultOptional, final ScoreNormalizationTechnique normalizationTechnique, - final ScoreCombinationTechnique combinationTechnique + final 
ScoreCombinationTechnique combinationTechnique, + final SearchPhaseContext searchPhaseContext ) { NormalizationProcessorWorkflowExecuteRequest request = NormalizationProcessorWorkflowExecuteRequest.builder() .querySearchResults(querySearchResults) @@ -72,17 +74,21 @@ public void execute( .normalizationTechnique(normalizationTechnique) .combinationTechnique(combinationTechnique) .explain(false) + .searchPhaseContext(searchPhaseContext) .build(); execute(request); } public void execute(final NormalizationProcessorWorkflowExecuteRequest request) { + List querySearchResults = request.getQuerySearchResults(); + Optional fetchSearchResultOptional = request.getFetchSearchResultOptional(); + // save original state - List unprocessedDocIds = unprocessedDocIds(request.getQuerySearchResults()); + List unprocessedDocIds = unprocessedDocIds(querySearchResults); // pre-process data log.debug("Pre-process query results"); - List queryTopDocs = getQueryTopDocs(request.getQuerySearchResults()); + List queryTopDocs = getQueryTopDocs(querySearchResults); explain(request, queryTopDocs); @@ -93,8 +99,9 @@ public void execute(final NormalizationProcessorWorkflowExecuteRequest request) CombineScoresDto combineScoresDTO = CombineScoresDto.builder() .queryTopDocs(queryTopDocs) .scoreCombinationTechnique(request.getCombinationTechnique()) - .querySearchResults(request.getQuerySearchResults()) - .sort(evaluateSortCriteria(request.getQuerySearchResults(), queryTopDocs)) + .querySearchResults(querySearchResults) + .sort(evaluateSortCriteria(querySearchResults, queryTopDocs)) + .fromValueForSingleShard(getFromValueIfSingleShard(request)) .build(); // combine @@ -103,8 +110,26 @@ public void execute(final NormalizationProcessorWorkflowExecuteRequest request) // post-process data log.debug("Post-process query results after score normalization and combination"); - updateOriginalQueryResults(combineScoresDTO); - updateOriginalFetchResults(request.getQuerySearchResults(), 
request.getFetchSearchResultOptional(), unprocessedDocIds); + updateOriginalQueryResults(combineScoresDTO, fetchSearchResultOptional.isPresent()); + updateOriginalFetchResults( + querySearchResults, + fetchSearchResultOptional, + unprocessedDocIds, + combineScoresDTO.getFromValueForSingleShard() + ); + } + + /** + * Get the value of the from parameter when there is a single shard + * and the fetch phase has already been executed; returns -1 otherwise + * Ref https://github.com/opensearch-project/OpenSearch/blob/main/server/src/main/java/org/opensearch/search/SearchService.java#L715 + */ + private int getFromValueIfSingleShard(final NormalizationProcessorWorkflowExecuteRequest request) { + final SearchPhaseContext searchPhaseContext = request.getSearchPhaseContext(); + if (searchPhaseContext.getNumShards() > 1 || request.getFetchSearchResultOptional().isEmpty()) { + return -1; + } + return searchPhaseContext.getRequest().source().from(); } /** @@ -173,19 +198,33 @@ private List getQueryTopDocs(final List quer return queryTopDocs; } - private void updateOriginalQueryResults(final CombineScoresDto combineScoresDTO) { + private void updateOriginalQueryResults(final CombineScoresDto combineScoresDTO, final boolean isFetchPhaseExecuted) { final List querySearchResults = combineScoresDTO.getQuerySearchResults(); final List queryTopDocs = getCompoundTopDocs(combineScoresDTO, querySearchResults); final Sort sort = combineScoresDTO.getSort(); + int totalScoreDocsCount = 0; for (int index = 0; index < querySearchResults.size(); index++) { QuerySearchResult querySearchResult = querySearchResults.get(index); CompoundTopDocs updatedTopDocs = queryTopDocs.get(index); + totalScoreDocsCount += updatedTopDocs.getScoreDocs().size(); TopDocsAndMaxScore updatedTopDocsAndMaxScore = new TopDocsAndMaxScore( buildTopDocs(updatedTopDocs, sort), maxScoreForShard(updatedTopDocs, sort != null) ); + // The fetch phase ran before the normalization phase, therefore update the from value in result of each shard.
+ // This will ensure the trimming of the search results. + if (isFetchPhaseExecuted) { + querySearchResult.from(combineScoresDTO.getFromValueForSingleShard()); + } querySearchResult.topDocs(updatedTopDocsAndMaxScore, querySearchResult.sortValueFormats()); } + + final int from = querySearchResults.get(0).from(); + if (from > totalScoreDocsCount) { + throw new IllegalArgumentException( + String.format(Locale.ROOT, "Reached end of search result, increase pagination_depth value to see more results") + ); + } } private List getCompoundTopDocs(CombineScoresDto combineScoresDTO, List querySearchResults) { @@ -244,7 +283,8 @@ private TopDocs buildTopDocs(CompoundTopDocs updatedTopDocs, Sort sort) { private void updateOriginalFetchResults( final List querySearchResults, final Optional fetchSearchResultOptional, - final List docIds + final List docIds, + final int fromValueForSingleShard ) { if (fetchSearchResultOptional.isEmpty()) { return; @@ -276,14 +316,21 @@ private void updateOriginalFetchResults( QuerySearchResult querySearchResult = querySearchResults.get(0); TopDocs topDocs = querySearchResult.topDocs().topDocs; + // Scenario to handle when calculating the trimmed length of updated search hits + // When normalization process runs after fetch phase, then search hits already fetched. Therefore, use the from value sent in the + // search request to calculate the effective length of updated search hits array. 
+ int trimmedLengthOfSearchHits = topDocs.scoreDocs.length - fromValueForSingleShard; // iterate over the normalized/combined scores, that solves (1) and (3) - SearchHit[] updatedSearchHitArray = Arrays.stream(topDocs.scoreDocs).map(scoreDoc -> { + SearchHit[] updatedSearchHitArray = new SearchHit[trimmedLengthOfSearchHits]; + for (int i = 0; i < trimmedLengthOfSearchHits; i++) { + // Read topDocs after the desired from length + ScoreDoc scoreDoc = topDocs.scoreDocs[i + fromValueForSingleShard]; // get fetched hit content by doc_id SearchHit searchHit = docIdToSearchHit.get(scoreDoc.doc); // update score to normalized/combined value (3) searchHit.score(scoreDoc.score); - return searchHit; - }).toArray(SearchHit[]::new); + updatedSearchHitArray[i] = searchHit; + } SearchHits updatedSearchHits = new SearchHits( updatedSearchHitArray, querySearchResult.getTotalHits(), diff --git a/src/main/java/org/opensearch/neuralsearch/processor/NormalizationProcessorWorkflowExecuteRequest.java b/src/main/java/org/opensearch/neuralsearch/processor/NormalizationProcessorWorkflowExecuteRequest.java index ea0b54b9c..e818c1b31 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/NormalizationProcessorWorkflowExecuteRequest.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/NormalizationProcessorWorkflowExecuteRequest.java @@ -7,6 +7,7 @@ import lombok.AllArgsConstructor; import lombok.Builder; import lombok.Getter; +import org.opensearch.action.search.SearchPhaseContext; import org.opensearch.neuralsearch.processor.combination.ScoreCombinationTechnique; import org.opensearch.neuralsearch.processor.normalization.ScoreNormalizationTechnique; import org.opensearch.search.fetch.FetchSearchResult; @@ -29,4 +30,5 @@ public class NormalizationProcessorWorkflowExecuteRequest { final ScoreCombinationTechnique combinationTechnique; boolean explain; final PipelineProcessingContext pipelineProcessingContext; + final SearchPhaseContext searchPhaseContext; } diff --git 
a/src/main/java/org/opensearch/neuralsearch/processor/combination/CombineScoresDto.java b/src/main/java/org/opensearch/neuralsearch/processor/combination/CombineScoresDto.java index c4783969b..fecf5ca09 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/combination/CombineScoresDto.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/combination/CombineScoresDto.java @@ -29,4 +29,5 @@ public class CombineScoresDto { private List querySearchResults; @Nullable private Sort sort; + private int fromValueForSingleShard; } diff --git a/src/main/java/org/opensearch/neuralsearch/processor/combination/ScoreCombiner.java b/src/main/java/org/opensearch/neuralsearch/processor/combination/ScoreCombiner.java index 1779f20f7..40625adfb 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/combination/ScoreCombiner.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/combination/ScoreCombiner.java @@ -70,14 +70,10 @@ public class ScoreCombiner { public void combineScores(final CombineScoresDto combineScoresDTO) { // iterate over results from each shard. 
Every CompoundTopDocs object has results from // multiple sub queries, doc ids may repeat for each sub query results + ScoreCombinationTechnique scoreCombinationTechnique = combineScoresDTO.getScoreCombinationTechnique(); + Sort sort = combineScoresDTO.getSort(); combineScoresDTO.getQueryTopDocs() - .forEach( - compoundQueryTopDocs -> combineShardScores( - combineScoresDTO.getScoreCombinationTechnique(), - compoundQueryTopDocs, - combineScoresDTO.getSort() - ) - ); + .forEach(compoundQueryTopDocs -> combineShardScores(scoreCombinationTechnique, compoundQueryTopDocs, sort)); } private void combineShardScores( diff --git a/src/main/java/org/opensearch/neuralsearch/query/HybridQuery.java b/src/main/java/org/opensearch/neuralsearch/query/HybridQuery.java index 60d5870da..d1e339bd5 100644 --- a/src/main/java/org/opensearch/neuralsearch/query/HybridQuery.java +++ b/src/main/java/org/opensearch/neuralsearch/query/HybridQuery.java @@ -34,17 +34,22 @@ public final class HybridQuery extends Query implements Iterable { private final List subQueries; + private final HybridQueryContext queryContext; /** * Create new instance of hybrid query object based on collection of sub queries and filter query * @param subQueries collection of queries that are executed individually and contribute to a final list of combined scores * @param filterQueries list of filters that will be applied to each sub query. Each filter from the list is added as bool "filter" clause. 
If this is null sub queries will be executed as is */ - public HybridQuery(final Collection subQueries, final List filterQueries) { + public HybridQuery(final Collection subQueries, final List filterQueries, final HybridQueryContext hybridQueryContext) { Objects.requireNonNull(subQueries, "collection of queries must not be null"); if (subQueries.isEmpty()) { throw new IllegalArgumentException("collection of queries must not be empty"); } + Integer paginationDepth = hybridQueryContext.getPaginationDepth(); + if (Objects.nonNull(paginationDepth) && paginationDepth == 0) { + throw new IllegalArgumentException("pagination_depth must not be zero"); + } if (Objects.isNull(filterQueries) || filterQueries.isEmpty()) { this.subQueries = new ArrayList<>(subQueries); } else { @@ -57,10 +62,11 @@ public HybridQuery(final Collection subQueries, final List filterQ } this.subQueries = modifiedSubQueries; } + this.queryContext = hybridQueryContext; } - public HybridQuery(final Collection subQueries) { - this(subQueries, List.of()); + public HybridQuery(final Collection subQueries, final HybridQueryContext hybridQueryContext) { + this(subQueries, List.of(), hybridQueryContext); } /** @@ -128,7 +134,7 @@ public Query rewrite(IndexSearcher indexSearcher) throws IOException { return super.rewrite(indexSearcher); } final List rewrittenSubQueries = manager.getQueriesAfterRewrite(collectors); - return new HybridQuery(rewrittenSubQueries); + return new HybridQuery(rewrittenSubQueries, queryContext); } private Void rewriteQuery(Query query, HybridQueryExecutorCollector> collector) { @@ -190,6 +196,10 @@ public Collection getSubQueries() { return Collections.unmodifiableCollection(subQueries); } + public HybridQueryContext getQueryContext() { + return queryContext; + } + /** * Create the Weight used to score this query * diff --git a/src/main/java/org/opensearch/neuralsearch/query/HybridQueryBuilder.java b/src/main/java/org/opensearch/neuralsearch/query/HybridQueryBuilder.java index 
338758802..bea94e603 100644 --- a/src/main/java/org/opensearch/neuralsearch/query/HybridQueryBuilder.java +++ b/src/main/java/org/opensearch/neuralsearch/query/HybridQueryBuilder.java @@ -22,6 +22,7 @@ import org.opensearch.core.common.io.stream.StreamOutput; import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.core.xcontent.XContentParser; +import org.opensearch.index.IndexSettings; import org.opensearch.index.query.AbstractQueryBuilder; import org.opensearch.index.query.QueryBuilder; import org.opensearch.index.query.QueryRewriteContext; @@ -35,6 +36,8 @@ import lombok.experimental.Accessors; import lombok.extern.log4j.Log4j2; +import static org.opensearch.neuralsearch.common.MinClusterVersionUtil.isClusterOnOrAfterMinReqVersionForPaginationInHybridQuery; + /** * Class abstract creation of a Query type "hybrid". Hybrid query will allow execution of multiple sub-queries and * collects score for each of those sub-query. @@ -48,14 +51,22 @@ public final class HybridQueryBuilder extends AbstractQueryBuilder queries = new ArrayList<>(); + private Integer paginationDepth; + static final int MAX_NUMBER_OF_SUB_QUERIES = 5; + private final static int DEFAULT_PAGINATION_DEPTH = 10; + private static final int LOWER_BOUND_OF_PAGINATION_DEPTH = 0; public HybridQueryBuilder(StreamInput in) throws IOException { super(in); queries.addAll(readQueries(in)); + if (isClusterOnOrAfterMinReqVersionForPaginationInHybridQuery()) { + paginationDepth = in.readOptionalInt(); + } } /** @@ -66,6 +77,9 @@ public HybridQueryBuilder(StreamInput in) throws IOException { @Override protected void doWriteTo(StreamOutput out) throws IOException { writeQueries(out, queries); + if (isClusterOnOrAfterMinReqVersionForPaginationInHybridQuery()) { + out.writeOptionalInt(paginationDepth); + } } /** @@ -95,6 +109,10 @@ protected void doXContent(XContentBuilder builder, Params params) throws IOExcep queryBuilder.toXContent(builder, params); } builder.endArray(); + // TODO 
https://github.com/opensearch-project/neural-search/issues/1097 + if (Objects.nonNull(paginationDepth)) { + builder.field(PAGINATION_DEPTH_FIELD.getPreferredName(), paginationDepth); + } printBoostAndQueryName(builder); builder.endObject(); } @@ -111,7 +129,9 @@ protected Query doToQuery(QueryShardContext queryShardContext) throws IOExceptio if (queryCollection.isEmpty()) { return Queries.newMatchNoDocsQuery(String.format(Locale.ROOT, "no clauses for %s query", NAME)); } - return new HybridQuery(queryCollection); + validatePaginationDepth(paginationDepth, queryShardContext); + HybridQueryContext hybridQueryContext = HybridQueryContext.builder().paginationDepth(paginationDepth).build(); + return new HybridQuery(queryCollection, hybridQueryContext); } /** @@ -147,6 +167,7 @@ protected Query doToQuery(QueryShardContext queryShardContext) throws IOExceptio public static HybridQueryBuilder fromXContent(XContentParser parser) throws IOException { float boost = AbstractQueryBuilder.DEFAULT_BOOST; + int paginationDepth = DEFAULT_PAGINATION_DEPTH; final List queries = new ArrayList<>(); String queryName = null; @@ -194,6 +215,8 @@ public static HybridQueryBuilder fromXContent(XContentParser parser) throws IOEx } } else if (AbstractQueryBuilder.NAME_FIELD.match(currentFieldName, parser.getDeprecationHandler())) { queryName = parser.text(); + } else if (PAGINATION_DEPTH_FIELD.match(currentFieldName, parser.getDeprecationHandler())) { + paginationDepth = parser.intValue(); } else { log.error(String.format(Locale.ROOT, "[%s] query does not support [%s]", NAME, currentFieldName)); throw new ParsingException( @@ -214,6 +237,9 @@ public static HybridQueryBuilder fromXContent(XContentParser parser) throws IOEx HybridQueryBuilder compoundQueryBuilder = new HybridQueryBuilder(); compoundQueryBuilder.queryName(queryName); compoundQueryBuilder.boost(boost); + if (isClusterOnOrAfterMinReqVersionForPaginationInHybridQuery()) { + compoundQueryBuilder.paginationDepth(paginationDepth); + } 
for (QueryBuilder query : queries) { compoundQueryBuilder.add(query); } @@ -233,6 +259,9 @@ protected QueryBuilder doRewrite(QueryRewriteContext queryShardContext) throws I if (changed) { newBuilder.queryName(queryName); newBuilder.boost(boost); + if (isClusterOnOrAfterMinReqVersionForPaginationInHybridQuery()) { + newBuilder.paginationDepth(paginationDepth); + } return newBuilder; } else { return this; @@ -254,6 +283,7 @@ protected boolean doEquals(HybridQueryBuilder obj) { } EqualsBuilder equalsBuilder = new EqualsBuilder(); equalsBuilder.append(queries, obj.queries); + equalsBuilder.append(paginationDepth, obj.paginationDepth); return equalsBuilder.isEquals(); } @@ -263,7 +293,7 @@ protected boolean doEquals(HybridQueryBuilder obj) { */ @Override protected int doHashCode() { - return Objects.hash(queries); + return Objects.hash(queries, paginationDepth); } /** @@ -294,6 +324,29 @@ private Collection toQueries(Collection queryBuilders, Quer return queries; } + private static void validatePaginationDepth(final Integer paginationDepth, final QueryShardContext queryShardContext) { + if (Objects.isNull(paginationDepth)) { // pagination_depth was not set, nothing to validate + return; + } + if (paginationDepth < LOWER_BOUND_OF_PAGINATION_DEPTH) { + throw new IllegalArgumentException( + String.format(Locale.ROOT, "pagination_depth should be greater than %s", LOWER_BOUND_OF_PAGINATION_DEPTH) + ); + } + // compare pagination depth with OpenSearch setting index.max_result_window + // see https://opensearch.org/docs/latest/install-and-configure/configuring-opensearch/index-settings/ + int maxResultWindowIndexSetting = queryShardContext.getIndexSettings().getMaxResultWindow(); + if (paginationDepth > maxResultWindowIndexSetting) { + throw new IllegalArgumentException( + String.format( + Locale.ROOT, + "pagination_depth should be less than or equal to %s setting", + IndexSettings.MAX_RESULT_WINDOW_SETTING.getKey() + ) + ); + } + } + /** * visit method to parse the HybridQueryBuilder by a visitor */ diff --git
a/src/main/java/org/opensearch/neuralsearch/query/HybridQueryContext.java b/src/main/java/org/opensearch/neuralsearch/query/HybridQueryContext.java new file mode 100644 index 000000000..34706e6e7 --- /dev/null +++ b/src/main/java/org/opensearch/neuralsearch/query/HybridQueryContext.java @@ -0,0 +1,17 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ +package org.opensearch.neuralsearch.query; + +import lombok.Builder; +import lombok.Getter; + +/** + * Class that holds the low level information of hybrid query in the form of context + */ +@Builder +@Getter +public class HybridQueryContext { + private Integer paginationDepth; +} diff --git a/src/main/java/org/opensearch/neuralsearch/search/query/HybridCollectorManager.java b/src/main/java/org/opensearch/neuralsearch/search/query/HybridCollectorManager.java index f9457f6ca..3c6a7271f 100644 --- a/src/main/java/org/opensearch/neuralsearch/search/query/HybridCollectorManager.java +++ b/src/main/java/org/opensearch/neuralsearch/search/query/HybridCollectorManager.java @@ -8,6 +8,7 @@ import lombok.RequiredArgsConstructor; import lombok.extern.log4j.Log4j2; import org.apache.lucene.index.IndexReader; +import org.apache.lucene.search.BooleanQuery; import org.apache.lucene.search.Collector; import org.apache.lucene.search.CollectorManager; import org.apache.lucene.search.Weight; @@ -22,6 +23,7 @@ import org.opensearch.common.Nullable; import org.opensearch.common.lucene.search.FilteredCollector; import org.opensearch.common.lucene.search.TopDocsAndMaxScore; +import org.opensearch.neuralsearch.query.HybridQuery; import org.opensearch.neuralsearch.search.HitsThresholdChecker; import org.opensearch.neuralsearch.search.collector.HybridSearchCollector; import org.opensearch.neuralsearch.search.collector.HybridTopFieldDocSortCollector; @@ -52,6 +54,7 @@ import static org.opensearch.neuralsearch.search.util.HybridSearchResultFormatUtil.createFieldDocStartStopElementForHybridSearchResults; 
import static org.opensearch.neuralsearch.search.util.HybridSearchResultFormatUtil.createFieldDocDelimiterElementForHybridSearchResults; import static org.opensearch.neuralsearch.search.util.HybridSearchResultFormatUtil.createSortFieldsForDelimiterResults; +import static org.opensearch.neuralsearch.util.HybridQueryUtil.isHybridQueryWrappedInBooleanQuery; /** * Collector manager based on HybridTopScoreDocCollector that allows users to parallelize counting the number of hits. @@ -80,14 +83,28 @@ public abstract class HybridCollectorManager implements CollectorManager 0) { + searchContext.from(0); + } + Weight filteringWeight = null; // Check for post filter to create weight for filter query and later use that weight in the search workflow if (Objects.nonNull(searchContext.parsedPostFilter()) && Objects.nonNull(searchContext.parsedPostFilter().query())) { @@ -461,6 +478,39 @@ private ReduceableSearchResult reduceSearchResults(final List clauseQuery.getQuery() instanceof HybridQuery); - } - @VisibleForTesting protected Query extractHybridQuery(final SearchContext searchContext, final Query query) { - if ((hasAliasFilter(query, searchContext) || hasNestedFieldOrNestedDocs(query, searchContext)) - && isWrappedHybridQuery(query) - && !((BooleanQuery) query).clauses().isEmpty()) { + if (isHybridQueryWrappedInBooleanQuery(searchContext, query)) { List booleanClauses = ((BooleanQuery) query).clauses(); if (!(booleanClauses.get(0).getQuery() instanceof HybridQuery)) { throw new IllegalStateException("cannot process hybrid query due to incorrect structure of top level query"); @@ -97,7 +85,7 @@ && isWrappedHybridQuery(query) .filter(clause -> BooleanClause.Occur.FILTER == clause.getOccur()) .map(BooleanClause::getQuery) .collect(Collectors.toList()); - HybridQuery hybridQueryWithFilter = new HybridQuery(hybridQuery.getSubQueries(), filterQueries); + HybridQuery hybridQueryWithFilter = new HybridQuery(hybridQuery.getSubQueries(), filterQueries, hybridQuery.getQueryContext()); 
return hybridQueryWithFilter; } return query; diff --git a/src/main/java/org/opensearch/neuralsearch/util/HybridQueryUtil.java b/src/main/java/org/opensearch/neuralsearch/util/HybridQueryUtil.java index d19985f5c..e8794131f 100644 --- a/src/main/java/org/opensearch/neuralsearch/util/HybridQueryUtil.java +++ b/src/main/java/org/opensearch/neuralsearch/util/HybridQueryUtil.java @@ -20,6 +20,9 @@ @NoArgsConstructor(access = AccessLevel.PRIVATE) public class HybridQueryUtil { + /** + * This method validates whether the query object is an instance of hybrid query + */ public static boolean isHybridQuery(final Query query, final SearchContext searchContext) { if (query instanceof HybridQuery) { return true; @@ -52,7 +55,7 @@ public static boolean isHybridQuery(final Query query, final SearchContext searc return false; } - public static boolean hasNestedFieldOrNestedDocs(final Query query, final SearchContext searchContext) { + private static boolean hasNestedFieldOrNestedDocs(final Query query, final SearchContext searchContext) { return searchContext.mapperService().hasNested() && new NestedHelper(searchContext.mapperService()).mightMatchNestedDocs(query); } @@ -61,7 +64,16 @@ private static boolean isWrappedHybridQuery(final Query query) { && ((BooleanQuery) query).clauses().stream().anyMatch(clauseQuery -> clauseQuery.getQuery() instanceof HybridQuery); } - public static boolean hasAliasFilter(final Query query, final SearchContext searchContext) { + private static boolean hasAliasFilter(final Query query, final SearchContext searchContext) { return Objects.nonNull(searchContext.aliasFilter()); } + + /** + * This method checks whether hybrid query is wrapped under boolean query object + */ + public static boolean isHybridQueryWrappedInBooleanQuery(final SearchContext searchContext, final Query query) { + return ((hasAliasFilter(query, searchContext) || hasNestedFieldOrNestedDocs(query, searchContext)) + && isWrappedHybridQuery(query) + && !((BooleanQuery) 
query).clauses().isEmpty()); + } } diff --git a/src/test/java/org/opensearch/neuralsearch/processor/NormalizationProcessorTests.java b/src/test/java/org/opensearch/neuralsearch/processor/NormalizationProcessorTests.java index 5f45b14fe..87dac8674 100644 --- a/src/test/java/org/opensearch/neuralsearch/processor/NormalizationProcessorTests.java +++ b/src/test/java/org/opensearch/neuralsearch/processor/NormalizationProcessorTests.java @@ -274,7 +274,7 @@ public void testEmptySearchResults_whenEmptySearchResults_thenDoNotExecuteWorkfl SearchPhaseContext searchPhaseContext = mock(SearchPhaseContext.class); normalizationProcessor.process(null, searchPhaseContext); - verify(normalizationProcessorWorkflow, never()).execute(any(), any(), any(), any()); + verify(normalizationProcessorWorkflow, never()).execute(any(), any(), any(), any(), any()); } public void testNotHybridSearchResult_whenResultsNotEmptyAndNotHybridSearchResult_thenDoNotExecuteWorkflow() { @@ -330,7 +330,7 @@ public void testNotHybridSearchResult_whenResultsNotEmptyAndNotHybridSearchResul when(searchPhaseContext.getNumShards()).thenReturn(numberOfShards); normalizationProcessor.process(queryPhaseResultConsumer, searchPhaseContext); - verify(normalizationProcessorWorkflow, never()).execute(any(), any(), any(), any()); + verify(normalizationProcessorWorkflow, never()).execute(any(), any(), any(), any(), any()); } public void testResultTypes_whenQueryAndFetchPresentAndSizeSame_thenCallNormalization() { @@ -346,6 +346,7 @@ public void testResultTypes_whenQueryAndFetchPresentAndSizeSame_thenCallNormaliz ); SearchRequest searchRequest = new SearchRequest(INDEX_NAME); + searchRequest.source().from(0); searchRequest.setBatchedReduceSize(4); AtomicReference onPartialMergeFailure = new AtomicReference<>(); QueryPhaseResultConsumer queryPhaseResultConsumer = new QueryPhaseResultConsumer( diff --git a/src/test/java/org/opensearch/neuralsearch/processor/NormalizationProcessorWorkflowTests.java 
b/src/test/java/org/opensearch/neuralsearch/processor/NormalizationProcessorWorkflowTests.java index 59fb51563..9969081a6 100644 --- a/src/test/java/org/opensearch/neuralsearch/processor/NormalizationProcessorWorkflowTests.java +++ b/src/test/java/org/opensearch/neuralsearch/processor/NormalizationProcessorWorkflowTests.java @@ -12,6 +12,7 @@ import java.util.ArrayList; import java.util.List; +import java.util.Locale; import java.util.Map; import java.util.Optional; @@ -19,6 +20,8 @@ import org.apache.lucene.search.TopDocs; import org.apache.lucene.search.TotalHits; import org.opensearch.action.OriginalIndices; +import org.opensearch.action.search.SearchPhaseContext; +import org.opensearch.action.search.SearchRequest; import org.opensearch.common.lucene.search.TopDocsAndMaxScore; import org.opensearch.core.index.shard.ShardId; import org.opensearch.neuralsearch.util.TestUtils; @@ -29,6 +32,7 @@ import org.opensearch.search.SearchHit; import org.opensearch.search.SearchHits; import org.opensearch.search.SearchShardTarget; +import org.opensearch.search.builder.SearchSourceBuilder; import org.opensearch.search.fetch.FetchSearchResult; import org.opensearch.search.internal.ShardSearchRequest; import org.opensearch.search.query.QuerySearchResult; @@ -71,12 +75,18 @@ public void testSearchResultTypes_whenResultsOfHybridSearch_thenDoNormalizationC querySearchResult.setShardIndex(shardId); querySearchResults.add(querySearchResult); } - + SearchPhaseContext searchPhaseContext = mock(SearchPhaseContext.class); + SearchRequest searchRequest = mock(SearchRequest.class); + SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder(); + searchSourceBuilder.from(0); + when(searchPhaseContext.getRequest()).thenReturn(searchRequest); + when(searchRequest.source()).thenReturn(searchSourceBuilder); normalizationProcessorWorkflow.execute( querySearchResults, Optional.empty(), ScoreNormalizationFactory.DEFAULT_METHOD, - ScoreCombinationFactory.DEFAULT_METHOD + 
ScoreCombinationFactory.DEFAULT_METHOD, + searchPhaseContext ); TestUtils.assertQueryResultScores(querySearchResults); @@ -113,12 +123,18 @@ public void testSearchResultTypes_whenNoMatches_thenReturnZeroResults() { querySearchResult.setShardIndex(shardId); querySearchResults.add(querySearchResult); } - + SearchPhaseContext searchPhaseContext = mock(SearchPhaseContext.class); + SearchRequest searchRequest = mock(SearchRequest.class); + SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder(); + searchSourceBuilder.from(0); + when(searchPhaseContext.getRequest()).thenReturn(searchRequest); + when(searchRequest.source()).thenReturn(searchSourceBuilder); normalizationProcessorWorkflow.execute( querySearchResults, Optional.empty(), ScoreNormalizationFactory.DEFAULT_METHOD, - ScoreCombinationFactory.DEFAULT_METHOD + ScoreCombinationFactory.DEFAULT_METHOD, + searchPhaseContext ); TestUtils.assertQueryResultScoresWithNoMatches(querySearchResults); @@ -172,12 +188,18 @@ public void testFetchResults_whenOneShardAndQueryAndFetchResultsPresent_thenDoNo new SearchHit(0, "10", Map.of(), Map.of()), }; SearchHits searchHits = new SearchHits(searchHitArray, new TotalHits(7, TotalHits.Relation.EQUAL_TO), 10); fetchSearchResult.hits(searchHits); - + SearchPhaseContext searchPhaseContext = mock(SearchPhaseContext.class); + SearchRequest searchRequest = mock(SearchRequest.class); + SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder(); + searchSourceBuilder.from(0); + when(searchPhaseContext.getRequest()).thenReturn(searchRequest); + when(searchRequest.source()).thenReturn(searchSourceBuilder); normalizationProcessorWorkflow.execute( querySearchResults, Optional.of(fetchSearchResult), ScoreNormalizationFactory.DEFAULT_METHOD, - ScoreCombinationFactory.DEFAULT_METHOD + ScoreCombinationFactory.DEFAULT_METHOD, + searchPhaseContext ); TestUtils.assertQueryResultScores(querySearchResults); @@ -232,12 +254,18 @@ public void 
testFetchResults_whenOneShardAndMultipleNodes_thenDoNormalizationCom new SearchHit(-1, "10", Map.of(), Map.of()), }; SearchHits searchHits = new SearchHits(searchHitArray, new TotalHits(7, TotalHits.Relation.EQUAL_TO), 10); fetchSearchResult.hits(searchHits); - + SearchPhaseContext searchPhaseContext = mock(SearchPhaseContext.class); + SearchRequest searchRequest = mock(SearchRequest.class); + SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder(); + searchSourceBuilder.from(0); + when(searchPhaseContext.getRequest()).thenReturn(searchRequest); + when(searchRequest.source()).thenReturn(searchSourceBuilder); normalizationProcessorWorkflow.execute( querySearchResults, Optional.of(fetchSearchResult), ScoreNormalizationFactory.DEFAULT_METHOD, - ScoreCombinationFactory.DEFAULT_METHOD + ScoreCombinationFactory.DEFAULT_METHOD, + searchPhaseContext ); TestUtils.assertQueryResultScores(querySearchResults); @@ -284,14 +312,20 @@ public void testFetchResultsAndNoCache_whenOneShardAndMultipleNodesAndMismatchRe querySearchResults.add(querySearchResult); SearchHits searchHits = getSearchHits(); fetchSearchResult.hits(searchHits); - + SearchPhaseContext searchPhaseContext = mock(SearchPhaseContext.class); + SearchRequest searchRequest = mock(SearchRequest.class); + SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder(); + searchSourceBuilder.from(0); + when(searchPhaseContext.getRequest()).thenReturn(searchRequest); + when(searchRequest.source()).thenReturn(searchSourceBuilder); expectThrows( IllegalStateException.class, () -> normalizationProcessorWorkflow.execute( querySearchResults, Optional.of(fetchSearchResult), ScoreNormalizationFactory.DEFAULT_METHOD, - ScoreCombinationFactory.DEFAULT_METHOD + ScoreCombinationFactory.DEFAULT_METHOD, + searchPhaseContext ) ); } @@ -336,18 +370,88 @@ public void testFetchResultsAndCache_whenOneShardAndMultipleNodesAndMismatchResu querySearchResults.add(querySearchResult); SearchHits searchHits = getSearchHits(); 
fetchSearchResult.hits(searchHits); - + SearchPhaseContext searchPhaseContext = mock(SearchPhaseContext.class); + SearchRequest searchRequest = mock(SearchRequest.class); + SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder(); + searchSourceBuilder.from(0); + when(searchPhaseContext.getRequest()).thenReturn(searchRequest); + when(searchRequest.source()).thenReturn(searchSourceBuilder); normalizationProcessorWorkflow.execute( querySearchResults, Optional.of(fetchSearchResult), ScoreNormalizationFactory.DEFAULT_METHOD, - ScoreCombinationFactory.DEFAULT_METHOD + ScoreCombinationFactory.DEFAULT_METHOD, + searchPhaseContext ); TestUtils.assertQueryResultScores(querySearchResults); TestUtils.assertFetchResultScores(fetchSearchResult, 4); } + public void testNormalization_whenFromIsGreaterThanResultsSize_thenFail() { + NormalizationProcessorWorkflow normalizationProcessorWorkflow = spy( + new NormalizationProcessorWorkflow(new ScoreNormalizer(), new ScoreCombiner()) + ); + + List querySearchResults = new ArrayList<>(); + for (int shardId = 0; shardId < 4; shardId++) { + SearchShardTarget searchShardTarget = new SearchShardTarget( + "node", + new ShardId("index", "uuid", shardId), + null, + OriginalIndices.NONE + ); + QuerySearchResult querySearchResult = new QuerySearchResult(); + querySearchResult.topDocs( + new TopDocsAndMaxScore( + new TopDocs( + new TotalHits(4, TotalHits.Relation.EQUAL_TO), + new ScoreDoc[] { + createStartStopElementForHybridSearchResults(0), + createDelimiterElementForHybridSearchResults(0), + new ScoreDoc(0, 0.5f), + new ScoreDoc(2, 0.3f), + new ScoreDoc(4, 0.25f), + new ScoreDoc(10, 0.2f), + createStartStopElementForHybridSearchResults(0) } + ), + 0.5f + ), + null + ); + querySearchResult.setSearchShardTarget(searchShardTarget); + querySearchResult.setShardIndex(shardId); + // requested page is out of bound for the total number of results + querySearchResult.from(17); + querySearchResults.add(querySearchResult); + } + + 
SearchPhaseContext searchPhaseContext = mock(SearchPhaseContext.class); + when(searchPhaseContext.getNumShards()).thenReturn(4); + SearchRequest searchRequest = mock(SearchRequest.class); + SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder(); + searchSourceBuilder.from(17); + when(searchPhaseContext.getRequest()).thenReturn(searchRequest); + + NormalizationProcessorWorkflowExecuteRequest normalizationExecuteDto = NormalizationProcessorWorkflowExecuteRequest.builder() + .querySearchResults(querySearchResults) + .fetchSearchResultOptional(Optional.empty()) + .normalizationTechnique(ScoreNormalizationFactory.DEFAULT_METHOD) + .combinationTechnique(ScoreCombinationFactory.DEFAULT_METHOD) + .searchPhaseContext(searchPhaseContext) + .build(); + + IllegalArgumentException illegalArgumentException = assertThrows( + IllegalArgumentException.class, + () -> normalizationProcessorWorkflow.execute(normalizationExecuteDto) + ); + + assertEquals( + String.format(Locale.ROOT, "Reached end of search result, increase pagination_depth value to see more results"), + illegalArgumentException.getMessage() + ); + } + private static SearchHits getSearchHits() { SearchHit[] searchHitArray = new SearchHit[] { new SearchHit(-1, "10", Map.of(), Map.of()), diff --git a/src/test/java/org/opensearch/neuralsearch/query/HybridQueryBuilderTests.java b/src/test/java/org/opensearch/neuralsearch/query/HybridQueryBuilderTests.java index 1640d8e02..a6cf4d29e 100644 --- a/src/test/java/org/opensearch/neuralsearch/query/HybridQueryBuilderTests.java +++ b/src/test/java/org/opensearch/neuralsearch/query/HybridQueryBuilderTests.java @@ -11,6 +11,7 @@ import static org.opensearch.core.xcontent.ToXContent.EMPTY_PARAMS; import static org.opensearch.index.query.AbstractQueryBuilder.BOOST_FIELD; import static org.opensearch.index.query.AbstractQueryBuilder.DEFAULT_BOOST; +import static org.opensearch.index.remote.RemoteStoreEnums.PathType.HASHED_PREFIX; import static 
org.opensearch.knn.index.query.KNNQueryBuilder.FILTER_FIELD; import static org.opensearch.neuralsearch.util.TestUtils.xContentBuilderToMap; import static org.opensearch.neuralsearch.query.NeuralQueryBuilder.K_FIELD; @@ -33,7 +34,9 @@ import org.mockito.Mock; import org.mockito.MockitoAnnotations; import org.opensearch.Version; +import org.opensearch.cluster.metadata.IndexMetadata; import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.UUIDs; import org.opensearch.common.io.stream.BytesStreamOutput; import org.opensearch.common.settings.ClusterSettings; import org.opensearch.common.settings.Setting; @@ -50,6 +53,7 @@ import org.opensearch.core.xcontent.ToXContent; import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.core.xcontent.XContentParser; +import org.opensearch.index.IndexSettings; import org.opensearch.index.mapper.MappedFieldType; import org.opensearch.index.mapper.TextFieldMapper; import org.opensearch.index.query.MatchAllQueryBuilder; @@ -57,6 +61,7 @@ import org.opensearch.index.query.QueryBuilders; import org.opensearch.index.query.QueryShardContext; import org.opensearch.index.query.TermQueryBuilder; +import org.opensearch.index.remote.RemoteStoreEnums; import org.opensearch.knn.index.KNNSettings; import org.opensearch.knn.index.SpaceType; import org.opensearch.knn.index.VectorDataType; @@ -119,6 +124,7 @@ public void testDoToQuery_whenNoSubqueries_thenBuildSuccessfully() { @SneakyThrows public void testDoToQuery_whenOneSubquery_thenBuildSuccessfully() { HybridQueryBuilder queryBuilder = new HybridQueryBuilder(); + queryBuilder.paginationDepth(10); Index dummyIndex = new Index("dummy", "dummy"); QueryShardContext mockQueryShardContext = mock(QueryShardContext.class); KNNVectorFieldType mockKNNVectorField = mock(KNNVectorFieldType.class); @@ -130,6 +136,10 @@ public void testDoToQuery_whenOneSubquery_thenBuildSuccessfully() { when(mockKNNVectorField.getKnnMappingConfig().getDimension()).thenReturn(4); 
when(mockKNNVectorField.getVectorDataType()).thenReturn(VectorDataType.FLOAT); when(mockQueryShardContext.fieldMapper(eq(VECTOR_FIELD_NAME))).thenReturn(mockKNNVectorField); + IndexMetadata indexMetadata = getIndexMetadata(); + Settings settings = Settings.builder().put(IndexMetadata.SETTING_NUMBER_OF_SHARDS, Integer.toString(3)).build(); + IndexSettings indexSettings = new IndexSettings(indexMetadata, settings); + when(mockQueryShardContext.getIndexSettings()).thenReturn(indexSettings); NeuralQueryBuilder neuralQueryBuilder = NeuralQueryBuilder.builder() .fieldName(VECTOR_FIELD_NAME) @@ -154,6 +164,7 @@ public void testDoToQuery_whenOneSubquery_thenBuildSuccessfully() { @SneakyThrows public void testDoToQuery_whenMultipleSubqueries_thenBuildSuccessfully() { HybridQueryBuilder queryBuilder = new HybridQueryBuilder(); + queryBuilder.paginationDepth(10); Index dummyIndex = new Index("dummy", "dummy"); QueryShardContext mockQueryShardContext = mock(QueryShardContext.class); KNNVectorFieldType mockKNNVectorField = mock(KNNVectorFieldType.class); @@ -165,6 +176,10 @@ public void testDoToQuery_whenMultipleSubqueries_thenBuildSuccessfully() { when(mockKNNVectorField.getKnnMappingConfig().getDimension()).thenReturn(4); when(mockKNNVectorField.getVectorDataType()).thenReturn(VectorDataType.FLOAT); when(mockQueryShardContext.fieldMapper(eq(VECTOR_FIELD_NAME))).thenReturn(mockKNNVectorField); + IndexMetadata indexMetadata = getIndexMetadata(); + Settings settings = Settings.builder().put(IndexMetadata.SETTING_NUMBER_OF_SHARDS, Integer.toString(3)).build(); + IndexSettings indexSettings = new IndexSettings(indexMetadata, settings); + when(mockQueryShardContext.getIndexSettings()).thenReturn(indexSettings); NeuralQueryBuilder neuralQueryBuilder = NeuralQueryBuilder.builder() .fieldName(VECTOR_FIELD_NAME) @@ -201,6 +216,81 @@ public void testDoToQuery_whenMultipleSubqueries_thenBuildSuccessfully() { assertEquals(TERM_QUERY_TEXT, termQuery.getTerm().text()); } + @SneakyThrows + 
public void testDoToQuery_whenPaginationDepthIsGreaterThan10000_thenBuildSuccessfully() { + HybridQueryBuilder queryBuilder = new HybridQueryBuilder(); + queryBuilder.paginationDepth(10001); + Index dummyIndex = new Index("dummy", "dummy"); + QueryShardContext mockQueryShardContext = mock(QueryShardContext.class); + KNNVectorFieldType mockKNNVectorField = mock(KNNVectorFieldType.class); + KNNMappingConfig mockKNNMappingConfig = mock(KNNMappingConfig.class); + KNNMethodContext knnMethodContext = new KNNMethodContext(KNNEngine.FAISS, SpaceType.L2, MethodComponentContext.EMPTY); + when(mockKNNVectorField.getKnnMappingConfig()).thenReturn(mockKNNMappingConfig); + when(mockKNNMappingConfig.getKnnMethodContext()).thenReturn(Optional.of(knnMethodContext)); + when(mockQueryShardContext.index()).thenReturn(dummyIndex); + when(mockKNNVectorField.getKnnMappingConfig().getDimension()).thenReturn(4); + when(mockKNNVectorField.getVectorDataType()).thenReturn(VectorDataType.FLOAT); + when(mockQueryShardContext.fieldMapper(eq(VECTOR_FIELD_NAME))).thenReturn(mockKNNVectorField); + IndexMetadata indexMetadata = getIndexMetadata(); + Settings settings = Settings.builder().put(IndexMetadata.SETTING_NUMBER_OF_SHARDS, Integer.toString(3)).build(); + IndexSettings indexSettings = new IndexSettings(indexMetadata, settings); + when(mockQueryShardContext.getIndexSettings()).thenReturn(indexSettings); + + NeuralQueryBuilder neuralQueryBuilder = NeuralQueryBuilder.builder() + .fieldName(VECTOR_FIELD_NAME) + .queryText(QUERY_TEXT) + .modelId(MODEL_ID) + .k(K) + .vectorSupplier(TEST_VECTOR_SUPPLIER) + .build(); + + queryBuilder.add(neuralQueryBuilder); + IllegalArgumentException exception = expectThrows( + IllegalArgumentException.class, + () -> queryBuilder.doToQuery(mockQueryShardContext) + ); + assertThat( + exception.getMessage(), + containsString("pagination_depth should be less than or equal to index.max_result_window setting") + ); + } + + @SneakyThrows + public void 
testDoToQuery_whenPaginationDepthIsLessThanZero_thenBuildSuccessfully() { + HybridQueryBuilder queryBuilder = new HybridQueryBuilder(); + queryBuilder.paginationDepth(-1); + Index dummyIndex = new Index("dummy", "dummy"); + QueryShardContext mockQueryShardContext = mock(QueryShardContext.class); + KNNVectorFieldType mockKNNVectorField = mock(KNNVectorFieldType.class); + KNNMappingConfig mockKNNMappingConfig = mock(KNNMappingConfig.class); + KNNMethodContext knnMethodContext = new KNNMethodContext(KNNEngine.FAISS, SpaceType.L2, MethodComponentContext.EMPTY); + when(mockKNNVectorField.getKnnMappingConfig()).thenReturn(mockKNNMappingConfig); + when(mockKNNMappingConfig.getKnnMethodContext()).thenReturn(Optional.of(knnMethodContext)); + when(mockQueryShardContext.index()).thenReturn(dummyIndex); + when(mockKNNVectorField.getKnnMappingConfig().getDimension()).thenReturn(4); + when(mockKNNVectorField.getVectorDataType()).thenReturn(VectorDataType.FLOAT); + when(mockQueryShardContext.fieldMapper(eq(VECTOR_FIELD_NAME))).thenReturn(mockKNNVectorField); + IndexMetadata indexMetadata = getIndexMetadata(); + Settings settings = Settings.builder().put(IndexMetadata.SETTING_NUMBER_OF_SHARDS, Integer.toString(3)).build(); + IndexSettings indexSettings = new IndexSettings(indexMetadata, settings); + when(mockQueryShardContext.getIndexSettings()).thenReturn(indexSettings); + + NeuralQueryBuilder neuralQueryBuilder = NeuralQueryBuilder.builder() + .fieldName(VECTOR_FIELD_NAME) + .queryText(QUERY_TEXT) + .modelId(MODEL_ID) + .k(K) + .vectorSupplier(TEST_VECTOR_SUPPLIER) + .build(); + + queryBuilder.add(neuralQueryBuilder); + IllegalArgumentException exception = expectThrows( + IllegalArgumentException.class, + () -> queryBuilder.doToQuery(mockQueryShardContext) + ); + assertThat(exception.getMessage(), containsString("pagination_depth should be greater than 0")); + } + @SneakyThrows public void testDoToQuery_whenTooManySubqueries_thenFail() { // create query with 6 sub-queries, which 
is more than current max allowed @@ -336,6 +426,7 @@ public void testFromXContent_whenMultipleSubQueries_thenBuildSuccessfully() { assertEquals(2, queryTwoSubQueries.queries().size()); assertTrue(queryTwoSubQueries.queries().get(0) instanceof NeuralQueryBuilder); assertTrue(queryTwoSubQueries.queries().get(1) instanceof TermQueryBuilder); + assertEquals(10, queryTwoSubQueries.paginationDepth().intValue()); // verify knn vector query NeuralQueryBuilder neuralQueryBuilder = (NeuralQueryBuilder) queryTwoSubQueries.queries().get(0); assertEquals(VECTOR_FIELD_NAME, neuralQueryBuilder.fieldName()); @@ -409,6 +500,7 @@ public void testFromXContent_whenIncorrectFormat_thenFail() { @SneakyThrows public void testToXContent_whenIncomingJsonIsCorrect_thenSuccessful() { + setUpClusterService(); HybridQueryBuilder queryBuilder = new HybridQueryBuilder(); Index dummyIndex = new Index("dummy", "dummy"); QueryShardContext mockQueryShardContext = mock(QueryShardContext.class); @@ -537,6 +629,7 @@ public void testHashAndEquals_whenSameOrIdenticalObject_thenReturnEqual() { } public void testHashAndEquals_whenSubQueriesDifferent_thenReturnNotEqual() { + setUpClusterService(); String modelId = "testModelId"; String fieldName = "fieldTwo"; String queryText = "query text"; @@ -637,6 +730,7 @@ public void testHashAndEquals_whenSubQueriesDifferent_thenReturnNotEqual() { @SneakyThrows public void testRewrite_whenMultipleSubQueries_thenReturnBuilderForEachSubQuery() { + setUpClusterService(); HybridQueryBuilder queryBuilder = new HybridQueryBuilder(); NeuralQueryBuilder neuralQueryBuilder = NeuralQueryBuilder.builder() .fieldName(VECTOR_FIELD_NAME) @@ -744,6 +838,7 @@ public void testBoost_whenNonDefaultBoostSet_thenFail() { @SneakyThrows public void testBoost_whenDefaultBoostSet_thenBuildSuccessfully() { + setUpClusterService(); // create query with 6 sub-queries, which is more than current max allowed XContentBuilder xContentBuilderWithNonDefaultBoost = XContentFactory.jsonBuilder() 
.startObject() @@ -794,6 +889,10 @@ public void testBuild_whenValidParameters_thenCreateQuery() { MappedFieldType fieldType = mock(MappedFieldType.class); when(context.fieldMapper(fieldName)).thenReturn(fieldType); when(fieldType.typeName()).thenReturn("rank_features"); + IndexMetadata indexMetadata = getIndexMetadata(); + Settings settings = Settings.builder().put(IndexMetadata.SETTING_NUMBER_OF_SHARDS, Integer.toString(3)).build(); + IndexSettings indexSettings = new IndexSettings(indexMetadata, settings); + when(context.getIndexSettings()).thenReturn(indexSettings); // Create HybridQueryBuilder instance (no spy since it's final) NeuralSparseQueryBuilder neuralSparseQueryBuilder = new NeuralSparseQueryBuilder(); @@ -802,6 +901,7 @@ public void testBuild_whenValidParameters_thenCreateQuery() { .modelId(modelId) .queryTokensSupplier(() -> Map.of("token1", 1.0f, "token2", 0.5f)); HybridQueryBuilder builder = new HybridQueryBuilder().add(neuralSparseQueryBuilder); + builder.paginationDepth(10); // Build query Query query = builder.toQuery(context); @@ -813,6 +913,7 @@ public void testBuild_whenValidParameters_thenCreateQuery() { @SneakyThrows public void testDoEquals_whenSameParameters_thenEqual() { + setUpClusterService(); // Create neural queries NeuralQueryBuilder neuralQueryBuilder1 = NeuralQueryBuilder.builder() .fieldName("test") @@ -894,4 +995,25 @@ private void initKNNSettings() { when(clusterService.getClusterSettings()).thenReturn(new ClusterSettings(Settings.EMPTY, defaultClusterSettings)); KNNSettings.state().setClusterService(clusterService); } + + private static IndexMetadata getIndexMetadata() { + Map remoteCustomData = Map.of( + RemoteStoreEnums.PathType.NAME, + HASHED_PREFIX.name(), + RemoteStoreEnums.PathHashAlgorithm.NAME, + RemoteStoreEnums.PathHashAlgorithm.FNV_1A_BASE64.name(), + IndexMetadata.TRANSLOG_METADATA_KEY, + "false" + ); + Settings idxSettings = Settings.builder() + .put(IndexMetadata.SETTING_VERSION_CREATED, Version.CURRENT) + 
.put(IndexMetadata.SETTING_INDEX_UUID, UUIDs.randomBase64UUID()) + .build(); + IndexMetadata indexMetadata = new IndexMetadata.Builder("test").settings(idxSettings) + .numberOfShards(1) + .numberOfReplicas(0) + .putCustom(IndexMetadata.REMOTE_STORE_CUSTOM_KEY, remoteCustomData) + .build(); + return indexMetadata; + } } diff --git a/src/test/java/org/opensearch/neuralsearch/query/HybridQueryIT.java b/src/test/java/org/opensearch/neuralsearch/query/HybridQueryIT.java index 610e08dd0..c3087a1e4 100644 --- a/src/test/java/org/opensearch/neuralsearch/query/HybridQueryIT.java +++ b/src/test/java/org/opensearch/neuralsearch/query/HybridQueryIT.java @@ -28,6 +28,7 @@ import org.junit.Before; import org.opensearch.client.ResponseException; import org.opensearch.index.query.BoolQueryBuilder; +import org.opensearch.index.query.MatchAllQueryBuilder; import org.opensearch.index.query.MatchQueryBuilder; import org.opensearch.index.query.NestedQueryBuilder; import org.opensearch.index.query.QueryBuilder; @@ -793,21 +794,130 @@ public void testConcurrentSearchWithMultipleSlices_whenMultipleShardsIndex_thenS } } - // TODO remove this test after following issue https://github.com/opensearch-project/neural-search/issues/280 gets resolved. 
@SneakyThrows - public void testHybridQuery_whenFromIsSetInSearchRequest_thenFail() { + public void testPaginationOnSingleShard_whenConcurrentSearchEnabled_thenSuccessful() { try { + updateClusterSettings(CONCURRENT_SEGMENT_SEARCH_ENABLED, true); initializeIndexIfNotExist(TEST_MULTI_DOC_INDEX_NAME_ONE_SHARD); createSearchPipelineWithResultsPostProcessor(SEARCH_PIPELINE); - MatchQueryBuilder matchQueryBuilder = QueryBuilders.matchQuery(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT3); - HybridQueryBuilder hybridQueryBuilderOnlyTerm = new HybridQueryBuilder(); - hybridQueryBuilderOnlyTerm.add(matchQueryBuilder); + testHybridQuery_whenFromAndPaginationDepthIsGreaterThanZero_thenSuccessful(TEST_MULTI_DOC_INDEX_NAME_ONE_SHARD); + } finally { + wipeOfTestResources(TEST_MULTI_DOC_INDEX_NAME_ONE_SHARD, null, null, SEARCH_PIPELINE); + } + } - ResponseException exceptionNoNestedTypes = expectThrows( + @SneakyThrows + public void testPaginationOnSingleShard_whenConcurrentSearchDisabled_thenSuccessful() { + try { + updateClusterSettings(CONCURRENT_SEGMENT_SEARCH_ENABLED, false); + initializeIndexIfNotExist(TEST_MULTI_DOC_INDEX_NAME_ONE_SHARD); + createSearchPipelineWithResultsPostProcessor(SEARCH_PIPELINE); + testHybridQuery_whenFromAndPaginationDepthIsGreaterThanZero_thenSuccessful(TEST_MULTI_DOC_INDEX_NAME_ONE_SHARD); + } finally { + wipeOfTestResources(TEST_MULTI_DOC_INDEX_NAME_ONE_SHARD, null, null, SEARCH_PIPELINE); + } + } + + @SneakyThrows + public void testPaginationOnMultipleShard_whenConcurrentSearchEnabled_thenSuccessful() { + try { + updateClusterSettings(CONCURRENT_SEGMENT_SEARCH_ENABLED, true); + initializeIndexIfNotExist(TEST_MULTI_DOC_INDEX_NAME); + createSearchPipelineWithResultsPostProcessor(SEARCH_PIPELINE); + testHybridQuery_whenFromAndPaginationDepthIsGreaterThanZero_thenSuccessful(TEST_MULTI_DOC_INDEX_NAME); + } finally { + wipeOfTestResources(TEST_MULTI_DOC_INDEX_NAME, null, null, SEARCH_PIPELINE); + } + } + + @SneakyThrows + public void 
testPaginationOnMultipleShard_whenConcurrentSearchDisabled_thenSuccessful() { + try { + updateClusterSettings(CONCURRENT_SEGMENT_SEARCH_ENABLED, false); + initializeIndexIfNotExist(TEST_MULTI_DOC_INDEX_NAME); + createSearchPipelineWithResultsPostProcessor(SEARCH_PIPELINE); + testHybridQuery_whenFromAndPaginationDepthIsGreaterThanZero_thenSuccessful(TEST_MULTI_DOC_INDEX_NAME); + } finally { + wipeOfTestResources(TEST_MULTI_DOC_INDEX_NAME, null, null, SEARCH_PIPELINE); + } + } + + @SneakyThrows + public void testHybridQuery_whenFromAndPaginationDepthIsGreaterThanZero_thenSuccessful(String indexName) { + HybridQueryBuilder hybridQueryBuilderOnlyMatchAll = new HybridQueryBuilder(); + hybridQueryBuilderOnlyMatchAll.add(new MatchAllQueryBuilder()); + hybridQueryBuilderOnlyMatchAll.paginationDepth(10); + + Map searchResponseAsMap = search( + indexName, + hybridQueryBuilderOnlyMatchAll, + null, + 10, + Map.of("search_pipeline", SEARCH_PIPELINE), + null, + null, + null, + false, + null, + 2 + ); + + assertEquals(2, getHitCount(searchResponseAsMap)); + Map total = getTotalHits(searchResponseAsMap); + assertNotNull(total.get("value")); + assertEquals(4, total.get("value")); + assertNotNull(total.get("relation")); + assertEquals(RELATION_EQUAL_TO, total.get("relation")); + } + + @SneakyThrows + public void testHybridQuery_whenFromIsGreaterThanZeroAndPaginationDepthIsNotSent_thenSuccessful() { + try { + updateClusterSettings(CONCURRENT_SEGMENT_SEARCH_ENABLED, false); + initializeIndexIfNotExist(TEST_MULTI_DOC_INDEX_NAME_ONE_SHARD); + createSearchPipelineWithResultsPostProcessor(SEARCH_PIPELINE); + HybridQueryBuilder hybridQueryBuilderOnlyMatchAll = new HybridQueryBuilder(); + hybridQueryBuilderOnlyMatchAll.add(new MatchAllQueryBuilder()); + + Map searchResponseAsMap = search( + TEST_MULTI_DOC_INDEX_NAME_ONE_SHARD, + hybridQueryBuilderOnlyMatchAll, + null, + 10, + Map.of("search_pipeline", SEARCH_PIPELINE), + null, + null, + null, + false, + null, + 2 + ); + + assertEquals(2, 
getHitCount(searchResponseAsMap)); + Map total = getTotalHits(searchResponseAsMap); + assertNotNull(total.get("value")); + assertEquals(4, total.get("value")); + assertNotNull(total.get("relation")); + assertEquals(RELATION_EQUAL_TO, total.get("relation")); + } finally { + wipeOfTestResources(TEST_MULTI_DOC_INDEX_NAME_ONE_SHARD, null, null, SEARCH_PIPELINE); + } + } + + @SneakyThrows + public void testHybridQuery_whenFromIsGreaterThanTotalResultCount_thenFail() { + try { + updateClusterSettings(CONCURRENT_SEGMENT_SEARCH_ENABLED, false); + initializeIndexIfNotExist(TEST_MULTI_DOC_INDEX_NAME_ONE_SHARD); + createSearchPipelineWithResultsPostProcessor(SEARCH_PIPELINE); + HybridQueryBuilder hybridQueryBuilderOnlyMatchAll = new HybridQueryBuilder(); + hybridQueryBuilderOnlyMatchAll.add(new MatchAllQueryBuilder()); + + ResponseException responseException = assertThrows( ResponseException.class, () -> search( TEST_MULTI_DOC_INDEX_NAME_ONE_SHARD, - hybridQueryBuilderOnlyTerm, + hybridQueryBuilderOnlyMatchAll, null, 10, Map.of("search_pipeline", SEARCH_PIPELINE), @@ -816,18 +926,50 @@ public void testHybridQuery_whenFromIsSetInSearchRequest_thenFail() { null, false, null, - 10 + 5 ) - ); org.hamcrest.MatcherAssert.assertThat( - exceptionNoNestedTypes.getMessage(), - allOf( - containsString("In the current OpenSearch version pagination is not supported with hybrid query"), - containsString("illegal_argument_exception") + responseException.getMessage(), + allOf(containsString("Reached end of search result, increase pagination_depth value to see more results")) + ); + } finally { + wipeOfTestResources(TEST_MULTI_DOC_INDEX_NAME_ONE_SHARD, null, null, SEARCH_PIPELINE); + } + } + + @SneakyThrows + public void testHybridQuery_whenPaginationDepthIsOutOfRange_thenFail() { + try { + updateClusterSettings(CONCURRENT_SEGMENT_SEARCH_ENABLED, false); + initializeIndexIfNotExist(TEST_MULTI_DOC_INDEX_NAME_ONE_SHARD); + createSearchPipelineWithResultsPostProcessor(SEARCH_PIPELINE); + 
HybridQueryBuilder hybridQueryBuilderOnlyMatchAll = new HybridQueryBuilder(); + hybridQueryBuilderOnlyMatchAll.add(new MatchAllQueryBuilder()); + hybridQueryBuilderOnlyMatchAll.paginationDepth(100001); + + ResponseException responseException = assertThrows( + ResponseException.class, + () -> search( + TEST_MULTI_DOC_INDEX_NAME_ONE_SHARD, + hybridQueryBuilderOnlyMatchAll, + null, + 10, + Map.of("search_pipeline", SEARCH_PIPELINE), + null, + null, + null, + false, + null, + 0 ) ); + + org.hamcrest.MatcherAssert.assertThat( + responseException.getMessage(), + allOf(containsString("pagination_depth should be less than or equal to index.max_result_window setting")) + ); } finally { wipeOfTestResources(TEST_MULTI_DOC_INDEX_NAME_ONE_SHARD, null, null, SEARCH_PIPELINE); } diff --git a/src/test/java/org/opensearch/neuralsearch/query/HybridQueryTests.java b/src/test/java/org/opensearch/neuralsearch/query/HybridQueryTests.java index 15f0621e8..26babdbce 100644 --- a/src/test/java/org/opensearch/neuralsearch/query/HybridQueryTests.java +++ b/src/test/java/org/opensearch/neuralsearch/query/HybridQueryTests.java @@ -72,16 +72,19 @@ public void testQueryBasics_whenMultipleDifferentQueries_thenSuccessful() { when(mockQueryShardContext.fieldMapper(eq(TEXT_FIELD_NAME))).thenReturn(fieldType); HybridQuery query1 = new HybridQuery( - List.of(QueryBuilders.termQuery(TEXT_FIELD_NAME, TERM_QUERY_TEXT).toQuery(mockQueryShardContext)) + List.of(QueryBuilders.termQuery(TEXT_FIELD_NAME, TERM_QUERY_TEXT).toQuery(mockQueryShardContext)), + new HybridQueryContext(10) ); HybridQuery query2 = new HybridQuery( - List.of(QueryBuilders.termQuery(TEXT_FIELD_NAME, TERM_QUERY_TEXT).toQuery(mockQueryShardContext)) + List.of(QueryBuilders.termQuery(TEXT_FIELD_NAME, TERM_QUERY_TEXT).toQuery(mockQueryShardContext)), + new HybridQueryContext(10) ); HybridQuery query3 = new HybridQuery( List.of( QueryBuilders.termQuery(TEXT_FIELD_NAME, TERM_QUERY_TEXT).toQuery(mockQueryShardContext), 
QueryBuilders.termQuery(TEXT_FIELD_NAME, TERM_ANOTHER_QUERY_TEXT).toQuery(mockQueryShardContext) - ) + ), + new HybridQueryContext(10) ); QueryUtils.check(query1); QueryUtils.checkEqual(query1, query2); @@ -96,6 +99,7 @@ public void testQueryBasics_whenMultipleDifferentQueries_thenSuccessful() { countOfQueries++; } assertEquals(2, countOfQueries); + assertEquals(10, query3.getQueryContext().getPaginationDepth().intValue()); } @SneakyThrows @@ -103,6 +107,7 @@ public void testRewrite_whenRewriteQuery_thenSuccessful() { QueryShardContext mockQueryShardContext = mock(QueryShardContext.class); TextFieldMapper.TextFieldType fieldType = (TextFieldMapper.TextFieldType) createMapperService().fieldType(TEXT_FIELD_NAME); when(mockQueryShardContext.fieldMapper(eq(TEXT_FIELD_NAME))).thenReturn(fieldType); + String field1Value = "text1"; Directory directory = newDirectory(); final IndexWriter w = new IndexWriter(directory, newIndexWriterConfig(new MockAnalyzer(random()))); @@ -120,14 +125,18 @@ public void testRewrite_whenRewriteQuery_thenSuccessful() { // Test with TermQuery HybridQuery hybridQueryWithTerm = new HybridQuery( - List.of(QueryBuilders.termQuery(TEXT_FIELD_NAME, TERM_QUERY_TEXT).toQuery(mockQueryShardContext)) + List.of(QueryBuilders.termQuery(TEXT_FIELD_NAME, TERM_QUERY_TEXT).toQuery(mockQueryShardContext)), + new HybridQueryContext(10) ); Query rewritten = hybridQueryWithTerm.rewrite(reader); // term query is the same after we rewrite it assertSame(hybridQueryWithTerm, rewritten); // Test empty query list - IllegalArgumentException exception = expectThrows(IllegalArgumentException.class, () -> new HybridQuery(List.of())); + IllegalArgumentException exception = expectThrows( + IllegalArgumentException.class, + () -> new HybridQuery(List.of(), new HybridQueryContext(10)) + ); assertThat(exception.getMessage(), containsString("collection of queries must not be empty")); w.close(); @@ -160,7 +169,8 @@ public void 
testWithRandomDocuments_whenMultipleTermSubQueriesWithMatch_thenRetu IndexSearcher searcher = newSearcher(reader); HybridQuery query = new HybridQuery( - List.of(new TermQuery(new Term(TEXT_FIELD_NAME, field1Value)), new TermQuery(new Term(TEXT_FIELD_NAME, field2Value))) + List.of(new TermQuery(new Term(TEXT_FIELD_NAME, field1Value)), new TermQuery(new Term(TEXT_FIELD_NAME, field2Value))), + new HybridQueryContext(10) ); // executing search query, getting up to 3 docs in result TopDocs hybridQueryResult = searcher.search(query, 3); @@ -206,7 +216,7 @@ public void testWithRandomDocuments_whenOneTermSubQueryWithoutMatch_thenReturnSu DirectoryReader reader = DirectoryReader.open(w); IndexSearcher searcher = newSearcher(reader); - HybridQuery query = new HybridQuery(List.of(new TermQuery(new Term(TEXT_FIELD_NAME, QUERY_TEXT)))); + HybridQuery query = new HybridQuery(List.of(new TermQuery(new Term(TEXT_FIELD_NAME, QUERY_TEXT))), new HybridQueryContext(10)); // executing search query, getting up to 3 docs in result TopDocs hybridQueryResult = searcher.search(query, 3); @@ -242,7 +252,8 @@ public void testWithRandomDocuments_whenMultipleTermSubQueriesWithoutMatch_thenR IndexSearcher searcher = newSearcher(reader); HybridQuery query = new HybridQuery( - List.of(new TermQuery(new Term(TEXT_FIELD_NAME, QUERY_TEXT)), new TermQuery(new Term(TEXT_FIELD_NAME, QUERY_TEXT))) + List.of(new TermQuery(new Term(TEXT_FIELD_NAME, QUERY_TEXT)), new TermQuery(new Term(TEXT_FIELD_NAME, QUERY_TEXT))), + new HybridQueryContext(10) ); // executing search query, getting up to 3 docs in result TopDocs hybridQueryResult = searcher.search(query, 3); @@ -256,10 +267,25 @@ public void testWithRandomDocuments_whenMultipleTermSubQueriesWithoutMatch_thenR @SneakyThrows public void testWithRandomDocuments_whenNoSubQueries_thenFail() { - IllegalArgumentException exception = expectThrows(IllegalArgumentException.class, () -> new HybridQuery(List.of())); + IllegalArgumentException exception = 
expectThrows( + IllegalArgumentException.class, + () -> new HybridQuery(List.of(), new HybridQueryContext(10)) + ); assertThat(exception.getMessage(), containsString("collection of queries must not be empty")); } + @SneakyThrows + public void testWithRandomDocuments_whenPaginationDepthIsZero_thenFail() { + IllegalArgumentException exception = expectThrows( + IllegalArgumentException.class, + () -> new HybridQuery( + List.of(new TermQuery(new Term(TEXT_FIELD_NAME, QUERY_TEXT)), new TermQuery(new Term(TEXT_FIELD_NAME, QUERY_TEXT))), + new HybridQueryContext(0) + ) + ); + assertThat(exception.getMessage(), containsString("pagination_depth must not be zero")); + } + @SneakyThrows public void testToString_whenCallQueryToString_thenSuccessful() { QueryShardContext mockQueryShardContext = mock(QueryShardContext.class); @@ -273,7 +299,8 @@ public void testToString_whenCallQueryToString_thenSuccessful() { new BoolQueryBuilder().should(QueryBuilders.termQuery(TEXT_FIELD_NAME, TERM_QUERY_TEXT)) .should(QueryBuilders.termQuery(TEXT_FIELD_NAME, TERM_ANOTHER_QUERY_TEXT)) .toQuery(mockQueryShardContext) - ) + ), + new HybridQueryContext(10) ); String queryString = query.toString(TEXT_FIELD_NAME); @@ -293,7 +320,8 @@ public void testFilter_whenSubQueriesWithFilterPassed_thenSuccessful() { QueryBuilders.termQuery(TEXT_FIELD_NAME, TERM_QUERY_TEXT).toQuery(mockQueryShardContext), QueryBuilders.termQuery(TEXT_FIELD_NAME, TERM_ANOTHER_QUERY_TEXT).toQuery(mockQueryShardContext) ), - List.of(filter) + List.of(filter), + new HybridQueryContext(10) ); QueryUtils.check(hybridQuery); diff --git a/src/test/java/org/opensearch/neuralsearch/query/HybridQueryWeightTests.java b/src/test/java/org/opensearch/neuralsearch/query/HybridQueryWeightTests.java index 0e32b5e78..024c5e6e3 100644 --- a/src/test/java/org/opensearch/neuralsearch/query/HybridQueryWeightTests.java +++ b/src/test/java/org/opensearch/neuralsearch/query/HybridQueryWeightTests.java @@ -61,7 +61,8 @@ public void 
testScorerIterator_whenExecuteQuery_thenScorerIteratorSuccessful() { IndexReader reader = DirectoryReader.open(w); HybridQuery hybridQueryWithTerm = new HybridQuery( - List.of(QueryBuilders.termQuery(TEXT_FIELD_NAME, TERM_QUERY_TEXT).toQuery(mockQueryShardContext)) + List.of(QueryBuilders.termQuery(TEXT_FIELD_NAME, TERM_QUERY_TEXT).toQuery(mockQueryShardContext)), + new HybridQueryContext(10) ); IndexSearcher searcher = newSearcher(reader); Weight weight = hybridQueryWithTerm.createWeight(searcher, ScoreMode.TOP_SCORES, 1.0f); @@ -117,7 +118,8 @@ public void testSubQueries_whenMultipleEqualSubQueries_thenSuccessful() { .rewrite(mockQueryShardContext) .toQuery(mockQueryShardContext), QueryBuilders.termQuery(TEXT_FIELD_NAME, TERM_QUERY_TEXT).toQuery(mockQueryShardContext) - ) + ), + new HybridQueryContext(10) ); IndexSearcher searcher = newSearcher(reader); Weight weight = hybridQueryWithTerm.createWeight(searcher, ScoreMode.TOP_SCORES, 1.0f); @@ -164,7 +166,8 @@ public void testExplain_whenCallExplain_thenSuccessful() { IndexReader reader = DirectoryReader.open(w); HybridQuery hybridQueryWithTerm = new HybridQuery( - List.of(QueryBuilders.termQuery(TEXT_FIELD_NAME, TERM_QUERY_TEXT).toQuery(mockQueryShardContext)) + List.of(QueryBuilders.termQuery(TEXT_FIELD_NAME, TERM_QUERY_TEXT).toQuery(mockQueryShardContext)), + new HybridQueryContext(10) ); IndexSearcher searcher = newSearcher(reader); Weight weight = searcher.createWeight(hybridQueryWithTerm, ScoreMode.COMPLETE, 1.0f); diff --git a/src/test/java/org/opensearch/neuralsearch/search/query/HybridAggregationProcessorTests.java b/src/test/java/org/opensearch/neuralsearch/search/query/HybridAggregationProcessorTests.java index f44e762f0..acbc2148c 100644 --- a/src/test/java/org/opensearch/neuralsearch/search/query/HybridAggregationProcessorTests.java +++ b/src/test/java/org/opensearch/neuralsearch/search/query/HybridAggregationProcessorTests.java @@ -15,11 +15,13 @@ import org.opensearch.action.OriginalIndices; import 
org.opensearch.common.lucene.search.TopDocsAndMaxScore; import org.opensearch.core.index.shard.ShardId; +import org.opensearch.index.mapper.MapperService; import org.opensearch.index.mapper.TextFieldMapper; import org.opensearch.index.query.QueryBuilders; import org.opensearch.index.query.QueryShardContext; import org.opensearch.index.query.TermQueryBuilder; import org.opensearch.neuralsearch.query.HybridQuery; +import org.opensearch.neuralsearch.query.HybridQueryContext; import org.opensearch.neuralsearch.query.OpenSearchQueryTestCase; import org.opensearch.search.DocValueFormat; import org.opensearch.search.SearchShardTarget; @@ -69,9 +71,12 @@ public void testCollectorManager_whenHybridQueryAndNotConcurrentSearch_thenSucce TextFieldMapper.TextFieldType fieldType = (TextFieldMapper.TextFieldType) createMapperService().fieldType(TEXT_FIELD_NAME); when(mockQueryShardContext.fieldMapper(eq(TEXT_FIELD_NAME))).thenReturn(fieldType); TermQueryBuilder termSubQuery = QueryBuilders.termQuery(TEXT_FIELD_NAME, TERM_QUERY_TEXT); - HybridQuery hybridQuery = new HybridQuery(List.of(termSubQuery.toQuery(mockQueryShardContext))); + HybridQueryContext hybridQueryContext = HybridQueryContext.builder().paginationDepth(10).build(); + HybridQuery hybridQuery = new HybridQuery(List.of(termSubQuery.toQuery(mockQueryShardContext)), hybridQueryContext); when(searchContext.query()).thenReturn(hybridQuery); + MapperService mapperService = mock(MapperService.class); + when(searchContext.mapperService()).thenReturn(mapperService); ContextIndexSearcher indexSearcher = mock(ContextIndexSearcher.class); IndexReader indexReader = mock(IndexReader.class); when(indexSearcher.getIndexReader()).thenReturn(indexReader); @@ -129,9 +134,12 @@ public void testCollectorManager_whenHybridQueryAndConcurrentSearch_thenSuccessf TextFieldMapper.TextFieldType fieldType = (TextFieldMapper.TextFieldType) createMapperService().fieldType(TEXT_FIELD_NAME); 
when(mockQueryShardContext.fieldMapper(eq(TEXT_FIELD_NAME))).thenReturn(fieldType); TermQueryBuilder termSubQuery = QueryBuilders.termQuery(TEXT_FIELD_NAME, TERM_QUERY_TEXT); - HybridQuery hybridQuery = new HybridQuery(List.of(termSubQuery.toQuery(mockQueryShardContext))); + HybridQueryContext hybridQueryContext = HybridQueryContext.builder().paginationDepth(10).build(); + HybridQuery hybridQuery = new HybridQuery(List.of(termSubQuery.toQuery(mockQueryShardContext)), hybridQueryContext); when(searchContext.query()).thenReturn(hybridQuery); + MapperService mapperService = createMapperService(); + when(searchContext.mapperService()).thenReturn(mapperService); ContextIndexSearcher indexSearcher = mock(ContextIndexSearcher.class); IndexReader indexReader = mock(IndexReader.class); when(indexSearcher.getIndexReader()).thenReturn(indexReader); diff --git a/src/test/java/org/opensearch/neuralsearch/search/query/HybridCollectorManagerTests.java b/src/test/java/org/opensearch/neuralsearch/search/query/HybridCollectorManagerTests.java index 24ebebe5b..e0d95f24e 100644 --- a/src/test/java/org/opensearch/neuralsearch/search/query/HybridCollectorManagerTests.java +++ b/src/test/java/org/opensearch/neuralsearch/search/query/HybridCollectorManagerTests.java @@ -36,6 +36,7 @@ import org.apache.lucene.tests.analysis.MockAnalyzer; import org.opensearch.common.lucene.search.FilteredCollector; import org.opensearch.common.lucene.search.TopDocsAndMaxScore; +import org.opensearch.index.mapper.MapperService; import org.opensearch.index.mapper.TextFieldMapper; import org.opensearch.index.query.BoostingQueryBuilder; import org.opensearch.index.query.QueryBuilders; @@ -44,6 +45,7 @@ import org.opensearch.index.query.QueryBuilder; import org.opensearch.index.query.ParsedQuery; import org.opensearch.neuralsearch.query.HybridQuery; +import org.opensearch.neuralsearch.query.HybridQueryContext; import org.opensearch.neuralsearch.query.HybridQueryWeight; import 
org.opensearch.neuralsearch.query.OpenSearchQueryTestCase; import org.opensearch.neuralsearch.search.collector.HybridTopScoreDocCollector; @@ -52,12 +54,14 @@ import org.opensearch.neuralsearch.search.query.exception.HybridSearchRescoreQueryException; import org.opensearch.search.DocValueFormat; import org.opensearch.search.internal.ContextIndexSearcher; +import org.opensearch.search.internal.ScrollContext; import org.opensearch.search.internal.SearchContext; import org.opensearch.search.query.QuerySearchResult; import org.opensearch.search.query.ReduceableSearchResult; import java.util.HashMap; import java.util.List; +import java.util.Locale; import java.util.Map; import static org.mockito.ArgumentMatchers.any; @@ -88,11 +92,14 @@ public class HybridCollectorManagerTests extends OpenSearchQueryTestCase { public void testNewCollector_whenNotConcurrentSearch_thenSuccessful() { SearchContext searchContext = mock(SearchContext.class); QueryShardContext mockQueryShardContext = mock(QueryShardContext.class); + MapperService mapperService = createMapperService(); TextFieldMapper.TextFieldType fieldType = (TextFieldMapper.TextFieldType) createMapperService().fieldType(TEXT_FIELD_NAME); when(mockQueryShardContext.fieldMapper(eq(TEXT_FIELD_NAME))).thenReturn(fieldType); TermQueryBuilder termSubQuery = QueryBuilders.termQuery(TEXT_FIELD_NAME, QUERY1); - HybridQuery hybridQuery = new HybridQuery(List.of(termSubQuery.toQuery(mockQueryShardContext))); + HybridQueryContext hybridQueryContext = HybridQueryContext.builder().paginationDepth(10).build(); + HybridQuery hybridQuery = new HybridQuery(List.of(termSubQuery.toQuery(mockQueryShardContext)), hybridQueryContext); + when(searchContext.mapperService()).thenReturn(mapperService); when(searchContext.query()).thenReturn(hybridQuery); ContextIndexSearcher indexSearcher = mock(ContextIndexSearcher.class); IndexReader indexReader = mock(IndexReader.class); @@ -122,8 +129,11 @@ public void 
testNewCollector_whenConcurrentSearch_thenSuccessful() { TextFieldMapper.TextFieldType fieldType = (TextFieldMapper.TextFieldType) createMapperService().fieldType(TEXT_FIELD_NAME); when(mockQueryShardContext.fieldMapper(eq(TEXT_FIELD_NAME))).thenReturn(fieldType); TermQueryBuilder termSubQuery = QueryBuilders.termQuery(TEXT_FIELD_NAME, QUERY1); - HybridQuery hybridQuery = new HybridQuery(List.of(termSubQuery.toQuery(mockQueryShardContext))); + HybridQueryContext hybridQueryContext = HybridQueryContext.builder().paginationDepth(10).build(); + HybridQuery hybridQuery = new HybridQuery(List.of(termSubQuery.toQuery(mockQueryShardContext)), hybridQueryContext); + MapperService mapperService = createMapperService(); + when(searchContext.mapperService()).thenReturn(mapperService); when(searchContext.query()).thenReturn(hybridQuery); ContextIndexSearcher indexSearcher = mock(ContextIndexSearcher.class); IndexReader indexReader = mock(IndexReader.class); @@ -153,8 +163,11 @@ public void testPostFilter_whenNotConcurrentSearch_thenSuccessful() { TextFieldMapper.TextFieldType fieldType = (TextFieldMapper.TextFieldType) createMapperService().fieldType(TEXT_FIELD_NAME); when(mockQueryShardContext.fieldMapper(eq(TEXT_FIELD_NAME))).thenReturn(fieldType); TermQueryBuilder termSubQuery = QueryBuilders.termQuery(TEXT_FIELD_NAME, QUERY1); - HybridQuery hybridQuery = new HybridQuery(List.of(termSubQuery.toQuery(mockQueryShardContext))); + HybridQueryContext hybridQueryContext = HybridQueryContext.builder().paginationDepth(10).build(); + HybridQuery hybridQuery = new HybridQuery(List.of(termSubQuery.toQuery(mockQueryShardContext)), hybridQueryContext); + MapperService mapperService = createMapperService(); + when(searchContext.mapperService()).thenReturn(mapperService); QueryBuilder postFilterQuery = QueryBuilders.termQuery(TEXT_FIELD_NAME, "world"); ParsedQuery parsedQuery = new ParsedQuery(postFilterQuery.toQuery(mockQueryShardContext)); searchContext.parsedQuery(parsedQuery); @@ 
-197,7 +210,11 @@ public void testPostFilter_whenConcurrentSearch_thenSuccessful() { TextFieldMapper.TextFieldType fieldType = (TextFieldMapper.TextFieldType) createMapperService().fieldType(TEXT_FIELD_NAME); when(mockQueryShardContext.fieldMapper(eq(TEXT_FIELD_NAME))).thenReturn(fieldType); TermQueryBuilder termSubQuery = QueryBuilders.termQuery(TEXT_FIELD_NAME, QUERY1); - HybridQuery hybridQuery = new HybridQuery(List.of(termSubQuery.toQuery(mockQueryShardContext))); + HybridQueryContext hybridQueryContext = HybridQueryContext.builder().paginationDepth(10).build(); + HybridQuery hybridQuery = new HybridQuery(List.of(termSubQuery.toQuery(mockQueryShardContext)), hybridQueryContext); + + MapperService mapperService = createMapperService(); + when(searchContext.mapperService()).thenReturn(mapperService); QueryBuilder postFilterQuery = QueryBuilders.termQuery(TEXT_FIELD_NAME, "world"); Query pfQuery = postFilterQuery.toQuery(mockQueryShardContext); @@ -240,9 +257,14 @@ public void testReduce_whenMatchedDocs_thenSuccessful() { QueryShardContext mockQueryShardContext = mock(QueryShardContext.class); TextFieldMapper.TextFieldType fieldType = (TextFieldMapper.TextFieldType) createMapperService().fieldType(TEXT_FIELD_NAME); when(mockQueryShardContext.fieldMapper(eq(TEXT_FIELD_NAME))).thenReturn(fieldType); + HybridQueryContext hybridQueryContext = HybridQueryContext.builder().paginationDepth(10).build(); + + MapperService mapperService = createMapperService(); + when(searchContext.mapperService()).thenReturn(mapperService); HybridQuery hybridQueryWithTerm = new HybridQuery( - List.of(QueryBuilders.termQuery(TEXT_FIELD_NAME, QUERY1).toQuery(mockQueryShardContext)) + List.of(QueryBuilders.termQuery(TEXT_FIELD_NAME, QUERY1).toQuery(mockQueryShardContext)), + hybridQueryContext ); when(searchContext.query()).thenReturn(hybridQueryWithTerm); ContextIndexSearcher indexSearcher = mock(ContextIndexSearcher.class); @@ -343,9 +365,13 @@ public void 
testNewCollector_whenNotConcurrentSearchAndSortingIsApplied_thenSucc TextFieldMapper.TextFieldType fieldType = (TextFieldMapper.TextFieldType) createMapperService().fieldType(TEXT_FIELD_NAME); when(mockQueryShardContext.fieldMapper(eq(TEXT_FIELD_NAME))).thenReturn(fieldType); TermQueryBuilder termSubQuery = QueryBuilders.termQuery(TEXT_FIELD_NAME, QUERY1); - HybridQuery hybridQuery = new HybridQuery(List.of(termSubQuery.toQuery(mockQueryShardContext))); + HybridQueryContext hybridQueryContext = HybridQueryContext.builder().paginationDepth(10).build(); + + HybridQuery hybridQuery = new HybridQuery(List.of(termSubQuery.toQuery(mockQueryShardContext)), hybridQueryContext); when(searchContext.query()).thenReturn(hybridQuery); + MapperService mapperService = createMapperService(); + when(searchContext.mapperService()).thenReturn(mapperService); ContextIndexSearcher indexSearcher = mock(ContextIndexSearcher.class); IndexReader indexReader = mock(IndexReader.class); when(indexSearcher.getIndexReader()).thenReturn(indexReader); @@ -380,9 +406,13 @@ public void testNewCollector_whenNotConcurrentSearchAndSortingAndSearchAfterAreA TextFieldMapper.TextFieldType fieldType = (TextFieldMapper.TextFieldType) createMapperService().fieldType(TEXT_FIELD_NAME); when(mockQueryShardContext.fieldMapper(eq(TEXT_FIELD_NAME))).thenReturn(fieldType); TermQueryBuilder termSubQuery = QueryBuilders.termQuery(TEXT_FIELD_NAME, QUERY1); - HybridQuery hybridQuery = new HybridQuery(List.of(termSubQuery.toQuery(mockQueryShardContext))); + HybridQueryContext hybridQueryContext = HybridQueryContext.builder().paginationDepth(10).build(); + + HybridQuery hybridQuery = new HybridQuery(List.of(termSubQuery.toQuery(mockQueryShardContext)), hybridQueryContext); when(searchContext.query()).thenReturn(hybridQuery); + MapperService mapperService = createMapperService(); + when(searchContext.mapperService()).thenReturn(mapperService); ContextIndexSearcher indexSearcher = mock(ContextIndexSearcher.class); 
IndexReader indexReader = mock(IndexReader.class); when(indexSearcher.getIndexReader()).thenReturn(indexReader); @@ -410,8 +440,12 @@ public void testReduce_whenMatchedDocsAndSortingIsApplied_thenSuccessful() { QueryShardContext mockQueryShardContext = mock(QueryShardContext.class); TextFieldMapper.TextFieldType fieldType = (TextFieldMapper.TextFieldType) createMapperService().fieldType(TEXT_FIELD_NAME); when(mockQueryShardContext.fieldMapper(eq(TEXT_FIELD_NAME))).thenReturn(fieldType); + HybridQueryContext hybridQueryContext = HybridQueryContext.builder().paginationDepth(10).build(); - HybridQuery hybridQueryWithMatchAll = new HybridQuery(List.of(QueryBuilders.matchAllQuery().toQuery(mockQueryShardContext))); + HybridQuery hybridQueryWithMatchAll = new HybridQuery( + List.of(QueryBuilders.matchAllQuery().toQuery(mockQueryShardContext)), + hybridQueryContext + ); when(searchContext.query()).thenReturn(hybridQueryWithMatchAll); ContextIndexSearcher indexSearcher = mock(ContextIndexSearcher.class); IndexReader indexReader = mock(IndexReader.class); @@ -420,6 +454,9 @@ public void testReduce_whenMatchedDocsAndSortingIsApplied_thenSuccessful() { when(searchContext.searcher()).thenReturn(indexSearcher); when(searchContext.size()).thenReturn(1); + MapperService mapperService = createMapperService(); + when(searchContext.mapperService()).thenReturn(mapperService); + Map, CollectorManager> classCollectorManagerMap = new HashMap<>(); when(searchContext.queryCollectorManagers()).thenReturn(classCollectorManagerMap); when(searchContext.shouldUseConcurrentSearch()).thenReturn(false); @@ -503,14 +540,18 @@ public void testReduceWithConcurrentSegmentSearch_whenMultipleCollectorsMatchedD QueryShardContext mockQueryShardContext = mock(QueryShardContext.class); TextFieldMapper.TextFieldType fieldType = (TextFieldMapper.TextFieldType) createMapperService().fieldType(TEXT_FIELD_NAME); when(mockQueryShardContext.fieldMapper(eq(TEXT_FIELD_NAME))).thenReturn(fieldType); + 
HybridQueryContext hybridQueryContext = HybridQueryContext.builder().paginationDepth(10).build(); HybridQuery hybridQueryWithTerm = new HybridQuery( List.of( QueryBuilders.termQuery(TEXT_FIELD_NAME, QUERY1).toQuery(mockQueryShardContext), QueryBuilders.termQuery(TEXT_FIELD_NAME, QUERY2).toQuery(mockQueryShardContext) - ) + ), + hybridQueryContext ); when(searchContext.query()).thenReturn(hybridQueryWithTerm); + MapperService mapperService = createMapperService(); + when(searchContext.mapperService()).thenReturn(mapperService); ContextIndexSearcher indexSearcher = mock(ContextIndexSearcher.class); IndexReader indexReader = mock(IndexReader.class); when(indexReader.numDocs()).thenReturn(3); @@ -593,9 +634,15 @@ public void testReduceWithConcurrentSegmentSearch_whenMultipleCollectorsMatchedD QueryShardContext mockQueryShardContext = mock(QueryShardContext.class); TextFieldMapper.TextFieldType fieldType = (TextFieldMapper.TextFieldType) createMapperService().fieldType(TEXT_FIELD_NAME); when(mockQueryShardContext.fieldMapper(eq(TEXT_FIELD_NAME))).thenReturn(fieldType); + HybridQueryContext hybridQueryContext = HybridQueryContext.builder().paginationDepth(10).build(); - HybridQuery hybridQueryWithTerm = new HybridQuery(List.of(QueryBuilders.matchAllQuery().toQuery(mockQueryShardContext))); + HybridQuery hybridQueryWithTerm = new HybridQuery( + List.of(QueryBuilders.matchAllQuery().toQuery(mockQueryShardContext)), + hybridQueryContext + ); when(searchContext.query()).thenReturn(hybridQueryWithTerm); + MapperService mapperService = createMapperService(); + when(searchContext.mapperService()).thenReturn(mapperService); ContextIndexSearcher indexSearcher = mock(ContextIndexSearcher.class); IndexReader indexReader = mock(IndexReader.class); when(indexReader.numDocs()).thenReturn(2); @@ -718,14 +765,18 @@ public void testReduceAndRescore_whenMatchedDocsAndRescoreContextPresent_thenSuc QueryShardContext mockQueryShardContext = mock(QueryShardContext.class); 
TextFieldMapper.TextFieldType fieldType = (TextFieldMapper.TextFieldType) createMapperService().fieldType(TEXT_FIELD_NAME); when(mockQueryShardContext.fieldMapper(eq(TEXT_FIELD_NAME))).thenReturn(fieldType); + HybridQueryContext hybridQueryContext = HybridQueryContext.builder().paginationDepth(10).build(); HybridQuery hybridQueryWithTerm = new HybridQuery( List.of( QueryBuilders.termQuery(TEXT_FIELD_NAME, QUERY1).toQuery(mockQueryShardContext), QueryBuilders.termQuery(TEXT_FIELD_NAME, QUERY2).toQuery(mockQueryShardContext) - ) + ), + hybridQueryContext ); when(searchContext.query()).thenReturn(hybridQueryWithTerm); + MapperService mapperService = createMapperService(); + when(searchContext.mapperService()).thenReturn(mapperService); ContextIndexSearcher indexSearcher = mock(ContextIndexSearcher.class); IndexReader indexReader = mock(IndexReader.class); when(indexReader.numDocs()).thenReturn(3); @@ -835,15 +886,19 @@ public void testRescoreWithConcurrentSegmentSearch_whenMatchedDocsAndRescore_the QueryShardContext mockQueryShardContext = mock(QueryShardContext.class); TextFieldMapper.TextFieldType fieldType = (TextFieldMapper.TextFieldType) createMapperService().fieldType(TEXT_FIELD_NAME); when(mockQueryShardContext.fieldMapper(eq(TEXT_FIELD_NAME))).thenReturn(fieldType); + HybridQueryContext hybridQueryContext = HybridQueryContext.builder().paginationDepth(10).build(); HybridQuery hybridQueryWithTerm = new HybridQuery( List.of( QueryBuilders.termQuery(TEXT_FIELD_NAME, QUERY1).toQuery(mockQueryShardContext), QueryBuilders.termQuery(TEXT_FIELD_NAME, QUERY2).toQuery(mockQueryShardContext), QueryBuilders.termQuery(TEXT_FIELD_NAME, QUERY3).toQuery(mockQueryShardContext) - ) + ), + hybridQueryContext ); when(searchContext.query()).thenReturn(hybridQueryWithTerm); + MapperService mapperService = createMapperService(); + when(searchContext.mapperService()).thenReturn(mapperService); ContextIndexSearcher indexSearcher = mock(ContextIndexSearcher.class); IndexReader 
indexReader = mock(IndexReader.class); when(indexReader.numDocs()).thenReturn(2); @@ -979,14 +1034,18 @@ public void testReduceAndRescore_whenRescorerThrowsException_thenFail() { QueryShardContext mockQueryShardContext = mock(QueryShardContext.class); TextFieldMapper.TextFieldType fieldType = (TextFieldMapper.TextFieldType) createMapperService().fieldType(TEXT_FIELD_NAME); when(mockQueryShardContext.fieldMapper(eq(TEXT_FIELD_NAME))).thenReturn(fieldType); + HybridQueryContext hybridQueryContext = HybridQueryContext.builder().paginationDepth(10).build(); HybridQuery hybridQueryWithTerm = new HybridQuery( List.of( QueryBuilders.termQuery(TEXT_FIELD_NAME, QUERY1).toQuery(mockQueryShardContext), QueryBuilders.termQuery(TEXT_FIELD_NAME, QUERY2).toQuery(mockQueryShardContext) - ) + ), + hybridQueryContext ); when(searchContext.query()).thenReturn(hybridQueryWithTerm); + MapperService mapperService = createMapperService(); + when(searchContext.mapperService()).thenReturn(mapperService); ContextIndexSearcher indexSearcher = mock(ContextIndexSearcher.class); IndexReader indexReader = mock(IndexReader.class); when(indexReader.numDocs()).thenReturn(3); @@ -1042,4 +1101,73 @@ public void testReduceAndRescore_whenRescorerThrowsException_thenFail() { reader.close(); directory.close(); } + + @SneakyThrows + public void testCreateCollectorManager_whenFromAreEqualToZeroAndPaginationDepthInRange_thenSuccessful() { + SearchContext searchContext = mock(SearchContext.class); + QueryShardContext mockQueryShardContext = mock(QueryShardContext.class); + TextFieldMapper.TextFieldType fieldType = (TextFieldMapper.TextFieldType) createMapperService().fieldType(TEXT_FIELD_NAME); + when(mockQueryShardContext.fieldMapper(eq(TEXT_FIELD_NAME))).thenReturn(fieldType); + TermQueryBuilder termSubQuery = QueryBuilders.termQuery(TEXT_FIELD_NAME, QUERY1); + HybridQueryContext hybridQueryContext = HybridQueryContext.builder().paginationDepth(10).build(); + + // pagination_depth=10 + HybridQuery 
hybridQuery = new HybridQuery(List.of(termSubQuery.toQuery(mockQueryShardContext)), hybridQueryContext); + + when(searchContext.query()).thenReturn(hybridQuery); + MapperService mapperService = createMapperService(); + when(searchContext.mapperService()).thenReturn(mapperService); + ContextIndexSearcher indexSearcher = mock(ContextIndexSearcher.class); + IndexReader indexReader = mock(IndexReader.class); + when(indexSearcher.getIndexReader()).thenReturn(indexReader); + when(searchContext.searcher()).thenReturn(indexSearcher); + + Map, CollectorManager> classCollectorManagerMap = new HashMap<>(); + when(searchContext.queryCollectorManagers()).thenReturn(classCollectorManagerMap); + when(searchContext.shouldUseConcurrentSearch()).thenReturn(false); + + CollectorManager hybridCollectorManager = HybridCollectorManager.createHybridCollectorManager(searchContext); + assertNotNull(hybridCollectorManager); + assertTrue(hybridCollectorManager instanceof HybridCollectorManager.HybridCollectorNonConcurrentManager); + + Collector collector = hybridCollectorManager.newCollector(); + assertNotNull(collector); + assertTrue(collector instanceof HybridTopScoreDocCollector); + + Collector secondCollector = hybridCollectorManager.newCollector(); + assertSame(collector, secondCollector); + } + + @SneakyThrows + public void testScrollWithHybridQuery_thenFail() { + SearchContext searchContext = mock(SearchContext.class); + ScrollContext scrollContext = new ScrollContext(); + when(searchContext.scrollContext()).thenReturn(scrollContext); + QueryShardContext mockQueryShardContext = mock(QueryShardContext.class); + TextFieldMapper.TextFieldType fieldType = (TextFieldMapper.TextFieldType) createMapperService().fieldType(TEXT_FIELD_NAME); + when(mockQueryShardContext.fieldMapper(eq(TEXT_FIELD_NAME))).thenReturn(fieldType); + TermQueryBuilder termSubQuery = QueryBuilders.termQuery(TEXT_FIELD_NAME, QUERY1); + HybridQueryContext hybridQueryContext = 
HybridQueryContext.builder().paginationDepth(10).build(); + + HybridQuery hybridQuery = new HybridQuery(List.of(termSubQuery.toQuery(mockQueryShardContext)), hybridQueryContext); + + when(searchContext.query()).thenReturn(hybridQuery); + ContextIndexSearcher indexSearcher = mock(ContextIndexSearcher.class); + IndexReader indexReader = mock(IndexReader.class); + when(indexSearcher.getIndexReader()).thenReturn(indexReader); + when(searchContext.searcher()).thenReturn(indexSearcher); + + Map, CollectorManager> classCollectorManagerMap = new HashMap<>(); + when(searchContext.queryCollectorManagers()).thenReturn(classCollectorManagerMap); + when(searchContext.shouldUseConcurrentSearch()).thenReturn(false); + + IllegalArgumentException illegalArgumentException = assertThrows( + IllegalArgumentException.class, + () -> HybridCollectorManager.createHybridCollectorManager(searchContext) + ); + assertEquals( + String.format(Locale.ROOT, "Scroll operation is not supported in hybrid query"), + illegalArgumentException.getMessage() + ); + } } diff --git a/src/test/java/org/opensearch/neuralsearch/search/query/HybridQueryPhaseSearcherTests.java b/src/test/java/org/opensearch/neuralsearch/search/query/HybridQueryPhaseSearcherTests.java index a8cad5ec7..2aafa2ece 100644 --- a/src/test/java/org/opensearch/neuralsearch/search/query/HybridQueryPhaseSearcherTests.java +++ b/src/test/java/org/opensearch/neuralsearch/search/query/HybridQueryPhaseSearcherTests.java @@ -138,6 +138,10 @@ public void testQueryType_whenQueryIsHybrid_thenCallHybridDocCollector() { when(searchContext.indexShard()).thenReturn(indexShard); when(searchContext.bucketCollectorProcessor()).thenReturn(SearchContext.NO_OP_BUCKET_COLLECTOR_PROCESSOR); when(searchContext.mapperService()).thenReturn(mapperService); + IndexMetadata indexMetadata = getIndexMetadata(); + Settings settings = Settings.builder().put(IndexMetadata.SETTING_NUMBER_OF_SHARDS, Integer.toString(1)).build(); + IndexSettings indexSettings = new 
IndexSettings(indexMetadata, settings); + when(mockQueryShardContext.getIndexSettings()).thenReturn(indexSettings); LinkedList collectors = new LinkedList<>(); boolean hasFilterCollector = randomBoolean(); @@ -150,6 +154,7 @@ public void testQueryType_whenQueryIsHybrid_thenCallHybridDocCollector() { TermQueryBuilder termSubQuery2 = QueryBuilders.termQuery(TEXT_FIELD_NAME, QUERY_TEXT2); queryBuilder.add(termSubQuery1); queryBuilder.add(termSubQuery2); + queryBuilder.paginationDepth(10); Query query = queryBuilder.toQuery(mockQueryShardContext); when(searchContext.query()).thenReturn(query); @@ -287,6 +292,10 @@ public void testQueryResult_whenOneSubQueryWithHits_thenHybridResultsAreSet() { when(searchContext.queryResult()).thenReturn(querySearchResult); when(searchContext.bucketCollectorProcessor()).thenReturn(SearchContext.NO_OP_BUCKET_COLLECTOR_PROCESSOR); when(searchContext.mapperService()).thenReturn(mapperService); + IndexMetadata indexMetadata = getIndexMetadata(); + Settings settings = Settings.builder().put(IndexMetadata.SETTING_NUMBER_OF_SHARDS, Integer.toString(1)).build(); + IndexSettings indexSettings = new IndexSettings(indexMetadata, settings); + when(mockQueryShardContext.getIndexSettings()).thenReturn(indexSettings); LinkedList collectors = new LinkedList<>(); boolean hasFilterCollector = randomBoolean(); @@ -296,6 +305,7 @@ public void testQueryResult_whenOneSubQueryWithHits_thenHybridResultsAreSet() { TermQueryBuilder termSubQuery = QueryBuilders.termQuery(TEXT_FIELD_NAME, QUERY_TEXT1); queryBuilder.add(termSubQuery); + queryBuilder.paginationDepth(10); Query query = queryBuilder.toQuery(mockQueryShardContext); when(searchContext.query()).thenReturn(query); @@ -372,6 +382,10 @@ public void testQueryResult_whenMultipleTextSubQueriesWithSomeHits_thenHybridRes when(searchContext.indexShard()).thenReturn(indexShard); when(searchContext.bucketCollectorProcessor()).thenReturn(SearchContext.NO_OP_BUCKET_COLLECTOR_PROCESSOR); 
when(searchContext.mapperService()).thenReturn(mapperService); + IndexMetadata indexMetadata = getIndexMetadata(); + Settings settings = Settings.builder().put(IndexMetadata.SETTING_NUMBER_OF_SHARDS, Integer.toString(1)).build(); + IndexSettings indexSettings = new IndexSettings(indexMetadata, settings); + when(mockQueryShardContext.getIndexSettings()).thenReturn(indexSettings); LinkedList collectors = new LinkedList<>(); boolean hasFilterCollector = randomBoolean(); @@ -382,6 +396,7 @@ public void testQueryResult_whenMultipleTextSubQueriesWithSomeHits_thenHybridRes queryBuilder.add(QueryBuilders.termQuery(TEXT_FIELD_NAME, QUERY_TEXT1)); queryBuilder.add(QueryBuilders.termQuery(TEXT_FIELD_NAME, QUERY_TEXT2)); queryBuilder.add(QueryBuilders.matchAllQuery()); + queryBuilder.paginationDepth(10); Query query = queryBuilder.toQuery(mockQueryShardContext); when(searchContext.query()).thenReturn(query); @@ -473,6 +488,7 @@ public void testWrappedHybridQuery_whenHybridWrappedIntoBool_thenFail() { queryBuilder.add(QueryBuilders.termQuery(TEXT_FIELD_NAME, QUERY_TEXT1)); queryBuilder.add(QueryBuilders.termQuery(TEXT_FIELD_NAME, QUERY_TEXT2)); + queryBuilder.paginationDepth(10); TermQueryBuilder termQuery3 = QueryBuilders.termQuery(TEXT_FIELD_NAME, QUERY_TEXT1); BoolQueryBuilder boolQueryBuilder = QueryBuilders.boolQuery().should(queryBuilder).should(termQuery3); @@ -578,6 +594,7 @@ public void testWrappedHybridQuery_whenHybridWrappedIntoBoolAndIncorrectStructur queryBuilder.add(QueryBuilders.termQuery(TEXT_FIELD_NAME, QUERY_TEXT1)); queryBuilder.add(QueryBuilders.termQuery(TEXT_FIELD_NAME, QUERY_TEXT2)); + queryBuilder.paginationDepth(10); BooleanQuery.Builder builder = new BooleanQuery.Builder(); builder.add(Queries.newNonNestedFilter(), BooleanClause.Occur.FILTER) @@ -694,6 +711,7 @@ public void testWrappedHybridQuery_whenHybridWrappedIntoBoolBecauseOfNested_then queryBuilder.add(QueryBuilders.termQuery(TEXT_FIELD_NAME, QUERY_TEXT1)); 
queryBuilder.add(QueryBuilders.termQuery(TEXT_FIELD_NAME, QUERY_TEXT2)); + queryBuilder.paginationDepth(10); BooleanQuery.Builder builder = new BooleanQuery.Builder(); builder.add(queryBuilder.toQuery(mockQueryShardContext), BooleanClause.Occur.MUST) @@ -868,6 +886,10 @@ public void testAggregations_whenMetricAggregation_thenSuccessful() { when(searchContext.indexShard()).thenReturn(indexShard); when(searchContext.bucketCollectorProcessor()).thenReturn(SearchContext.NO_OP_BUCKET_COLLECTOR_PROCESSOR); when(searchContext.mapperService()).thenReturn(mapperService); + IndexMetadata indexMetadata = getIndexMetadata(); + Settings settings = Settings.builder().put(IndexMetadata.SETTING_NUMBER_OF_SHARDS, Integer.toString(1)).build(); + IndexSettings indexSettings = new IndexSettings(indexMetadata, settings); + when(mockQueryShardContext.getIndexSettings()).thenReturn(indexSettings); LinkedList collectors = new LinkedList<>(); @@ -881,6 +903,7 @@ public void testAggregations_whenMetricAggregation_thenSuccessful() { TermQueryBuilder termSubQuery2 = QueryBuilders.termQuery(TEXT_FIELD_NAME, QUERY_TEXT2); queryBuilder.add(termSubQuery1); queryBuilder.add(termSubQuery2); + queryBuilder.paginationDepth(10); Query query = queryBuilder.toQuery(mockQueryShardContext); when(searchContext.query()).thenReturn(query); @@ -965,6 +988,10 @@ public void testAliasWithFilter_whenHybridWrappedIntoBoolBecauseOfIndexAlias_the when(searchContext.indexShard()).thenReturn(indexShard); when(searchContext.bucketCollectorProcessor()).thenReturn(SearchContext.NO_OP_BUCKET_COLLECTOR_PROCESSOR); when(searchContext.mapperService()).thenReturn(mapperService); + IndexMetadata indexMetadata = getIndexMetadata(); + Settings settings = Settings.builder().put(IndexMetadata.SETTING_NUMBER_OF_SHARDS, Integer.toString(1)).build(); + IndexSettings indexSettings = new IndexSettings(indexMetadata, settings); + when(mockQueryShardContext.getIndexSettings()).thenReturn(indexSettings); LinkedList collectors = new 
LinkedList<>(); boolean hasFilterCollector = randomBoolean(); @@ -974,6 +1001,7 @@ public void testAliasWithFilter_whenHybridWrappedIntoBoolBecauseOfIndexAlias_the queryBuilder.add(QueryBuilders.termQuery(TEXT_FIELD_NAME, QUERY_TEXT1)); queryBuilder.add(QueryBuilders.termQuery(TEXT_FIELD_NAME, QUERY_TEXT2)); + queryBuilder.paginationDepth(10); Query termFilter = QueryBuilders.termQuery(TEXT_FIELD_NAME, QUERY_TEXT1).toQuery(mockQueryShardContext); BooleanQuery.Builder builder = new BooleanQuery.Builder(); diff --git a/src/test/java/org/opensearch/neuralsearch/util/HybridQueryUtilTests.java b/src/test/java/org/opensearch/neuralsearch/util/HybridQueryUtilTests.java index be9dbc2cc..ab882b388 100644 --- a/src/test/java/org/opensearch/neuralsearch/util/HybridQueryUtilTests.java +++ b/src/test/java/org/opensearch/neuralsearch/util/HybridQueryUtilTests.java @@ -6,20 +6,29 @@ import lombok.SneakyThrows; import org.apache.lucene.search.Query; +import org.opensearch.Version; +import org.opensearch.cluster.metadata.IndexMetadata; +import org.opensearch.common.UUIDs; +import org.opensearch.common.settings.Settings; +import org.opensearch.index.IndexSettings; import org.opensearch.index.mapper.MapperService; import org.opensearch.index.mapper.TextFieldMapper; import org.opensearch.index.query.QueryBuilders; import org.opensearch.index.query.QueryShardContext; +import org.opensearch.index.remote.RemoteStoreEnums; import org.opensearch.neuralsearch.query.HybridQuery; import org.opensearch.neuralsearch.query.HybridQueryBuilder; +import org.opensearch.neuralsearch.query.HybridQueryContext; import org.opensearch.neuralsearch.query.OpenSearchQueryTestCase; import org.opensearch.search.internal.SearchContext; import java.util.List; +import java.util.Map; import static org.mockito.ArgumentMatchers.eq; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.when; +import static org.opensearch.index.remote.RemoteStoreEnums.PathType.HASHED_PREFIX; public class 
HybridQueryUtilTests extends OpenSearchQueryTestCase { @@ -34,6 +43,7 @@ public void testIsHybridQueryCheck_whenQueryIsHybridQueryInstance_thenSuccess() QueryShardContext mockQueryShardContext = mock(QueryShardContext.class); TextFieldMapper.TextFieldType fieldType = (TextFieldMapper.TextFieldType) createMapperService().fieldType(TEXT_FIELD_NAME); when(mockQueryShardContext.fieldMapper(eq(TEXT_FIELD_NAME))).thenReturn(fieldType); + HybridQueryContext hybridQueryContext = HybridQueryContext.builder().paginationDepth(10).build(); HybridQuery query = new HybridQuery( List.of( @@ -45,7 +55,8 @@ public void testIsHybridQueryCheck_whenQueryIsHybridQueryInstance_thenSuccess() .rewrite(mockQueryShardContext) .toQuery(mockQueryShardContext), QueryBuilders.termQuery(TEXT_FIELD_NAME, TERM_QUERY_TEXT).toQuery(mockQueryShardContext) - ) + ), + hybridQueryContext ); SearchContext searchContext = mock(SearchContext.class); @@ -58,13 +69,17 @@ public void testIsHybridQueryCheck_whenHybridWrappedIntoBoolAndNoNested_thenSucc MapperService mapperService = createMapperService(); TextFieldMapper.TextFieldType fieldType = (TextFieldMapper.TextFieldType) mapperService.fieldType(TEXT_FIELD_NAME); when(mockQueryShardContext.fieldMapper(eq(TEXT_FIELD_NAME))).thenReturn(fieldType); + IndexMetadata indexMetadata = getIndexMetadata(); + Settings settings = Settings.builder().put(IndexMetadata.SETTING_NUMBER_OF_SHARDS, Integer.toString(1)).build(); + IndexSettings indexSettings = new IndexSettings(indexMetadata, settings); + when(mockQueryShardContext.getIndexSettings()).thenReturn(indexSettings); HybridQueryBuilder hybridQueryBuilder = new HybridQueryBuilder(); hybridQueryBuilder.add(QueryBuilders.termQuery(TEXT_FIELD_NAME, TERM_QUERY_TEXT)); hybridQueryBuilder.add( QueryBuilders.rangeQuery(RANGE_FIELD).from(FROM_TEXT).to(TO_TEXT).rewrite(mockQueryShardContext).rewrite(mockQueryShardContext) ); - + hybridQueryBuilder.paginationDepth(10); Query booleanQuery = QueryBuilders.boolQuery() 
.should(hybridQueryBuilder) .should(QueryBuilders.termQuery(TEXT_FIELD_NAME, TERM_QUERY_TEXT)) @@ -97,4 +112,25 @@ public void testIsHybridQueryCheck_whenNoHybridQuery_thenSuccess() { assertFalse(HybridQueryUtil.isHybridQuery(booleanQuery, searchContext)); } + + /** Test helper: builds an IndexMetadata for an index named "test" with one shard, zero replicas, a freshly generated index UUID, Version.CURRENT as the created version, and remote-store custom data attached. @return the assembled IndexMetadata */ private static IndexMetadata getIndexMetadata() { + /* Remote-store custom data: path type HASHED_PREFIX, path hash algorithm FNV_1A_BASE64, translog metadata flag "false". NOTE(review): the raw Map type here presumably lost its generic parameters when this diff was extracted; confirm against the repository. */ Map remoteCustomData = Map.of( + RemoteStoreEnums.PathType.NAME, + HASHED_PREFIX.name(), + RemoteStoreEnums.PathHashAlgorithm.NAME, + RemoteStoreEnums.PathHashAlgorithm.FNV_1A_BASE64.name(), + IndexMetadata.TRANSLOG_METADATA_KEY, + "false" + ); + /* Index-level settings: created-version and a random base64 UUID so each call yields a distinct index identity. */ Settings idxSettings = Settings.builder() + .put(IndexMetadata.SETTING_VERSION_CREATED, Version.CURRENT) + .put(IndexMetadata.SETTING_INDEX_UUID, UUIDs.randomBase64UUID()) + .build(); + IndexMetadata indexMetadata = new IndexMetadata.Builder("test").settings(idxSettings) + .numberOfShards(1) + .numberOfReplicas(0) + .putCustom(IndexMetadata.REMOTE_STORE_CUSTOM_KEY, remoteCustomData) + .build(); + return indexMetadata; + } } diff --git a/src/testFixtures/java/org/opensearch/neuralsearch/BaseNeuralSearchIT.java b/src/testFixtures/java/org/opensearch/neuralsearch/BaseNeuralSearchIT.java index f4d4a3c40..509527aeb 100644 --- a/src/testFixtures/java/org/opensearch/neuralsearch/BaseNeuralSearchIT.java +++ b/src/testFixtures/java/org/opensearch/neuralsearch/BaseNeuralSearchIT.java @@ -600,13 +600,11 @@ protected Map search( if (requestParams != null && !requestParams.isEmpty()) { requestParams.forEach(request::addParameter); } - logger.info("Sorting request " + builder.toString()); request.setJsonEntity(builder.toString()); Response response = client().performRequest(request); assertEquals(request.getEndpoint() + ": failed", RestStatus.OK, RestStatus.fromCode(response.getStatusLine().getStatusCode())); String responseBody = EntityUtils.toString(response.getEntity()); - logger.info("Response " + responseBody); return XContentHelper.convertToMap(XContentType.JSON.xContent(), responseBody, false); }