diff --git a/src/main/java/org/opensearch/neuralsearch/query/HybridQuery.java b/src/main/java/org/opensearch/neuralsearch/query/HybridQuery.java
index 8846f6977..01d271cdd 100644
--- a/src/main/java/org/opensearch/neuralsearch/query/HybridQuery.java
+++ b/src/main/java/org/opensearch/neuralsearch/query/HybridQuery.java
@@ -12,7 +12,6 @@
 import java.util.List;
 import java.util.Objects;
 
-import org.apache.lucene.index.IndexReader;
 import org.apache.lucene.search.BooleanClause;
 import org.apache.lucene.search.BooleanQuery;
 import org.apache.lucene.search.IndexSearcher;
@@ -77,12 +76,12 @@ public String toString(String field) {
     /**
      * Re-writes queries into primitive queries. Callers are expected to call rewrite multiple times if necessary,
      * until the rewritten query is the same as the original query.
-     * @param reader
+     * @param indexSearcher
      * @return
      * @throws IOException
      */
     @Override
-    public Query rewrite(IndexReader reader) throws IOException {
+    public Query rewrite(IndexSearcher indexSearcher) throws IOException {
         if (subQueries.isEmpty()) {
             return new MatchNoDocsQuery("empty HybridQuery");
         }
@@ -90,7 +89,7 @@ public Query rewrite(IndexReader reader) throws IOException {
         boolean actuallyRewritten = false;
         List<Query> rewrittenSubQueries = new ArrayList<>();
         for (Query subQuery : subQueries) {
-            Query rewrittenSub = subQuery.rewrite(reader);
+            Query rewrittenSub = subQuery.rewrite(indexSearcher);
             /* we keep the rewritten sub-query unless it is equal to itself; rewriting may take multiple levels of
                recursive calls. Queries need to be rewritten from high-level clauses into lower-level clauses because
                low-level clauses perform better. For hybrid query we need to track rewrite progress for all sub-queries */
@@ -102,7 +101,7 @@ public Query rewrite(IndexReader reader) throws IOException {
             return new HybridQuery(rewrittenSubQueries);
         }
 
-        return super.rewrite(reader);
+        return super.rewrite(indexSearcher);
     }
 
     /**
diff --git a/src/main/java/org/opensearch/neuralsearch/query/HybridQueryBuilder.java b/src/main/java/org/opensearch/neuralsearch/query/HybridQueryBuilder.java
index aa4242c2e..46c087894 100644
--- a/src/main/java/org/opensearch/neuralsearch/query/HybridQueryBuilder.java
+++ b/src/main/java/org/opensearch/neuralsearch/query/HybridQueryBuilder.java
@@ -27,7 +27,6 @@
 import org.opensearch.index.query.QueryRewriteContext;
 import org.opensearch.index.query.QueryShardContext;
 import org.opensearch.index.query.QueryShardException;
-import org.opensearch.index.query.Rewriteable;
 import org.opensearch.index.query.QueryBuilderVisitor;
 
 import lombok.Getter;
@@ -290,7 +289,7 @@ private void writeQueries(StreamOutput out, List<? extends QueryBuilder> queries
     private Collection<Query> toQueries(Collection<QueryBuilder> queryBuilders, QueryShardContext context) throws QueryShardException {
         List<Query> queries = queryBuilders.stream().map(qb -> {
             try {
-                return Rewriteable.rewrite(qb, context).toQuery(context);
+                return qb.rewrite(context).toQuery(context);
             } catch (IOException e) {
                 throw new RuntimeException(e);
             }
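Note on the rewrite contract (annotation, not part of the patch): the javadoc above says callers must call `rewrite` repeatedly until it returns the same query. A minimal sketch of that fixpoint loop, which is also what `IndexSearcher#rewrite` does internally:

```java
import java.io.IOException;
import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.search.Query;

final class RewriteUtil {
    // Drive rewrite() to a fixpoint; rewrite() returns `this` once fully rewritten.
    static Query rewriteToFixpoint(IndexSearcher searcher, Query query) throws IOException {
        Query current = query;
        for (Query next = current.rewrite(searcher); next != current; next = current.rewrite(searcher)) {
            current = next;
        }
        return current;
    }
}
```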
diff --git a/src/main/java/org/opensearch/neuralsearch/query/HybridQueryScorer.java b/src/main/java/org/opensearch/neuralsearch/query/HybridQueryScorer.java
index 5abfd0b5e..179acb64f 100644
--- a/src/main/java/org/opensearch/neuralsearch/query/HybridQueryScorer.java
+++ b/src/main/java/org/opensearch/neuralsearch/query/HybridQueryScorer.java
@@ -6,6 +6,7 @@
 import java.io.IOException;
 import java.util.ArrayList;
+import java.util.Collection;
 import java.util.Collections;
 import java.util.HashMap;
 import java.util.List;
@@ -15,13 +16,17 @@
 import org.apache.lucene.search.DisiPriorityQueue;
 import org.apache.lucene.search.DisiWrapper;
-import org.apache.lucene.search.DisjunctionDISIApproximation;
 import org.apache.lucene.search.DocIdSetIterator;
 import org.apache.lucene.search.Query;
+import org.apache.lucene.search.ScoreMode;
 import org.apache.lucene.search.Scorer;
+import org.apache.lucene.search.TwoPhaseIterator;
 import org.apache.lucene.search.Weight;
 
 import lombok.Getter;
+import org.apache.lucene.util.PriorityQueue;
+
+import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS;
 
 /**
  * Class abstracts functionality of Scorer for hybrid query. When iterating over documents in increasing
@@ -40,12 +45,60 @@ public final class HybridQueryScorer extends Scorer {
 
     private final Map<Query, List<Integer>> queryToIndex;
 
+    private final DocIdSetIterator approximation;
+    private final HybridScorePropagator disjunctionBlockPropagator;
+    private final TwoPhase twoPhase;
+
     public HybridQueryScorer(Weight weight, List<Scorer> subScorers) throws IOException {
+        this(weight, subScorers, ScoreMode.TOP_SCORES);
+    }
+
+    public HybridQueryScorer(Weight weight, List<Scorer> subScorers, ScoreMode scoreMode) throws IOException {
         super(weight);
         this.subScorers = Collections.unmodifiableList(subScorers);
         subScores = new float[subScorers.size()];
         this.queryToIndex = mapQueryToIndex();
         this.subScorersPQ = initializeSubScorersPQ();
+        boolean needsScores = scoreMode != ScoreMode.COMPLETE_NO_SCORES;
+        this.approximation = new HybridDisjunctionDISIApproximation(this.subScorersPQ);
+        // block-max score propagation only pays off when the collector may raise the
+        // minimum competitive score, i.e. in TOP_SCORES mode
+        if (scoreMode == ScoreMode.TOP_SCORES) {
+            this.disjunctionBlockPropagator = new HybridScorePropagator(subScorers);
+        } else {
+            this.disjunctionBlockPropagator = null;
+        }
+
+        boolean hasApproximation = false;
+        float sumMatchCost = 0;
+        long sumApproxCost = 0;
+        // Compute matchCost as the average over the matchCost of the subScorers.
+        // This is weighted by the cost, which is an expected number of matching documents.
+        for (DisiWrapper w : subScorersPQ) {
+            long costWeight = (w.cost <= 1) ? 1 : w.cost;
+            sumApproxCost += costWeight;
+            if (w.twoPhaseView != null) {
+                hasApproximation = true;
+                sumMatchCost += w.matchCost * costWeight;
+            }
+        }
+        if (!hasApproximation) { // no sub scorer supports approximations
+            twoPhase = null;
+        } else {
+            final float matchCost = sumMatchCost / sumApproxCost;
+            twoPhase = new TwoPhase(approximation, matchCost, subScorersPQ, needsScores);
+        }
+    }
+
+    @Override
+    public int advanceShallow(int target) throws IOException {
+        if (disjunctionBlockPropagator != null) {
+            return disjunctionBlockPropagator.advanceShallow(target);
+        }
+        return super.advanceShallow(target);
     }
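The constructor above estimates `matchCost` as a cost-weighted average over the sub-scorers. A compact standalone sketch of that weighting (hypothetical helper; simplified in that the patch only adds clauses with a two-phase view to the numerator, while all clauses contribute to the denominator):

```java
final class MatchCostUtil {
    // Cost-weighted average of per-clause match costs, so expensive-but-rare
    // clauses do not dominate the overall estimate.
    static float weightedMatchCost(long[] costs, float[] matchCosts) {
        float sumMatchCost = 0;
        long sumCost = 0;
        for (int i = 0; i < costs.length; i++) {
            long costWeight = Math.max(1, costs[i]); // clamp so zero-cost clauses still count
            sumCost += costWeight;
            sumMatchCost += matchCosts[i] * costWeight;
        }
        return sumMatchCost / sumCost;
    }
}
```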
@@ -55,7 +108,7 @@
      */
     @Override
     public float score() throws IOException {
-        DisiWrapper topList = subScorersPQ.topList();
+        return score(getSubMatches());
+    }
+
+    private float score(DisiWrapper topList) throws IOException {
         float totalScore = 0.0f;
         for (DisiWrapper disiWrapper = topList; disiWrapper != null; disiWrapper = disiWrapper.next) {
             // check if this doc has a match in the sub-query; if not, its contribution is 0.0, so skip it
             if (disiWrapper.scorer.docID() == NO_MORE_DOCS) {
                 continue;
             }
             totalScore += disiWrapper.scorer.score();
         }
         return totalScore;
     }
 
+    DisiWrapper getSubMatches() throws IOException {
+        if (twoPhase == null) {
+            return subScorersPQ.topList();
+        } else {
+            return twoPhase.getSubMatches();
+        }
+    }
+
     /**
      * Return a DocIdSetIterator over matching documents.
      * @return DocIdSetIterator object
      */
     @Override
     public DocIdSetIterator iterator() {
-        return new DisjunctionDISIApproximation(this.subScorersPQ);
+        if (twoPhase != null) {
+            return TwoPhaseIterator.asDocIdSetIterator(twoPhase);
+        }
+        return approximation;
+    }
+
+    @Override
+    public TwoPhaseIterator twoPhaseIterator() {
+        return twoPhase;
     }
 
     /**
@@ -93,12 +188,28 @@ public float getMaxScore(int upTo) throws IOException {
         }).max(Float::compare).orElse(0.0f);
     }
 
+    @Override
+    public void setMinCompetitiveScore(float minScore) throws IOException {
+        if (disjunctionBlockPropagator != null) {
+            disjunctionBlockPropagator.setMinCompetitiveScore(minScore);
+        }
+
+        for (Scorer scorer : subScorers) {
+            if (Objects.nonNull(scorer)) {
+                scorer.setMinCompetitiveScore(minScore);
+            }
+        }
+    }
+
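For context (annotation, not part of the patch): a scorer that exposes a non-null `twoPhaseIterator()` is consumed in two steps; callers walk the cheap approximation and only pay `matches()` on candidate docs. A minimal sketch of that standard Lucene pattern, assuming an existing `Scorer` and `LeafCollector`:

```java
static void collectAll(Scorer scorer, LeafCollector collector) throws IOException {
    TwoPhaseIterator twoPhase = scorer.twoPhaseIterator();
    DocIdSetIterator disi = twoPhase == null ? scorer.iterator() : twoPhase.approximation();
    for (int doc = disi.nextDoc(); doc != DocIdSetIterator.NO_MORE_DOCS; doc = disi.nextDoc()) {
        if (twoPhase == null || twoPhase.matches()) { // confirm the candidate on the slow path
            collector.collect(doc);
        }
    }
}
```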
     /**
      * Returns the doc ID that is currently being scored.
      * @return document id
      */
     @Override
     public int docID() {
+        if (subScorersPQ.size() == 0) {
+            return NO_MORE_DOCS;
+        }
         return subScorersPQ.top().doc;
     }
 
@@ -169,4 +280,151 @@ private DisiPriorityQueue initializeSubScorersPQ() {
         }
         return subScorersPQ;
     }
+
+    @Override
+    public Collection<ChildScorable> getChildren() throws IOException {
+        List<ChildScorable> children = new ArrayList<>();
+        for (DisiWrapper scorer = getSubMatches(); scorer != null; scorer = scorer.next) {
+            children.add(new ChildScorable(scorer.scorer, "SHOULD"));
+        }
+        return children;
+    }
+
+    static class TwoPhase extends TwoPhaseIterator {
+        private final float matchCost;
+        // list of verified matches on the current doc
+        DisiWrapper verifiedMatches;
+        // priority queue of approximations on the current doc that have not been verified yet
+        final PriorityQueue<DisiWrapper> unverifiedMatches;
+        DisiPriorityQueue subScorers;
+        boolean needsScores;
+
+        private TwoPhase(DocIdSetIterator approximation, float matchCost, DisiPriorityQueue subScorers, boolean needsScores) {
+            super(approximation);
+            this.matchCost = matchCost;
+            this.subScorers = subScorers;
+            unverifiedMatches = new PriorityQueue<>(subScorers.size()) {
+                @Override
+                protected boolean lessThan(DisiWrapper a, DisiWrapper b) {
+                    return a.matchCost < b.matchCost;
+                }
+            };
+            this.needsScores = needsScores;
+        }
+
+        DisiWrapper getSubMatches() throws IOException {
+            // iteration order does not matter
+            for (DisiWrapper w : unverifiedMatches) {
+                if (w.twoPhaseView.matches()) {
+                    w.next = verifiedMatches;
+                    verifiedMatches = w;
+                }
+            }
+            unverifiedMatches.clear();
+            return verifiedMatches;
+        }
+
+        @Override
+        public boolean matches() throws IOException {
+            verifiedMatches = null;
+            unverifiedMatches.clear();
+
+            for (DisiWrapper w = subScorers.topList(); w != null;) {
+                DisiWrapper next = w.next;
+
+                if (w.twoPhaseView == null) {
+                    // implicitly verified, move it to verifiedMatches
+                    w.next = verifiedMatches;
+                    verifiedMatches = w;
+
+                    if (!needsScores) {
+                        // we can stop here
+                        return true;
+                    }
+                } else {
+                    unverifiedMatches.add(w);
+                }
+                w = next;
+            }
+
+            if (verifiedMatches != null) {
+                return true;
+            }
+
+            // verify subs that have a two-phase iterator, least-costly ones first
+            while (unverifiedMatches.size() > 0) {
+                DisiWrapper w = unverifiedMatches.pop();
+                if (w.twoPhaseView.matches()) {
+                    w.next = null;
+                    verifiedMatches = w;
+                    return true;
+                }
+            }
+
+            return false;
+        }
+
+        @Override
+        public float matchCost() {
+            return matchCost;
+        }
+    }
+
+    static class HybridDisjunctionDISIApproximation extends DocIdSetIterator {
+
+        final DisiPriorityQueue subIterators;
+        final long cost;
+
+        public HybridDisjunctionDISIApproximation(DisiPriorityQueue subIterators) {
+            this.subIterators = subIterators;
+            long cost = 0;
+            for (DisiWrapper w : subIterators) {
+                cost += w.cost;
+            }
+            this.cost = cost;
+        }
+
+        @Override
+        public long cost() {
+            return cost;
+        }
+
+        @Override
+        public int docID() {
+            if (subIterators.size() == 0) {
+                return NO_MORE_DOCS;
+            }
+            return subIterators.top().doc;
+        }
+
+        @Override
+        public int nextDoc() throws IOException {
+            if (subIterators.size() == 0) {
+                return NO_MORE_DOCS;
+            }
+            DisiWrapper top = subIterators.top();
+            final int doc = top.doc;
+            // advance every sub-iterator that sits on the current doc, then surface the new minimum
+            do {
+                top.doc = top.approximation.nextDoc();
+                top = subIterators.updateTop();
+            } while (top.doc == doc);
+
+            return top.doc;
+        }
+
+        @Override
+        public int advance(int target) throws IOException {
+            if (subIterators.size() == 0) {
+                return NO_MORE_DOCS;
+            }
+            DisiWrapper top = subIterators.top();
+            do {
+                top.doc = top.approximation.advance(target);
+                top = subIterators.updateTop();
+            } while (top.doc < target);
+
+            return top.doc;
+        }
+    }
 }
diff --git a/src/main/java/org/opensearch/neuralsearch/query/HybridQueryWeight.java b/src/main/java/org/opensearch/neuralsearch/query/HybridQueryWeight.java
index 69ee5015f..144167a9a 100644
--- a/src/main/java/org/opensearch/neuralsearch/query/HybridQueryWeight.java
+++ b/src/main/java/org/opensearch/neuralsearch/query/HybridQueryWeight.java
@@ -5,6 +5,7 @@
 package org.opensearch.neuralsearch.query;
 
 import java.io.IOException;
+import java.util.ArrayList;
 import java.util.List;
 import java.util.Objects;
 import java.util.stream.Collectors;
@@ -16,6 +17,7 @@
 import org.apache.lucene.search.MatchesUtils;
 import org.apache.lucene.search.ScoreMode;
 import org.apache.lucene.search.Scorer;
+import org.apache.lucene.search.ScorerSupplier;
 import org.apache.lucene.search.Weight;
 
 /**
@@ -23,18 +25,18 @@
  */
 public final class HybridQueryWeight extends Weight {
 
-    private final HybridQuery queries;
     // The Weights for our subqueries, in 1-1 correspondence
     private final List<Weight> weights;
 
     private final ScoreMode scoreMode;
 
+    static final int BOOLEAN_REWRITE_TERM_COUNT_THRESHOLD = 16;
+
     /**
      * Construct the Weight for this Query searched by searcher. Recursively construct subquery weights.
      */
     public HybridQueryWeight(HybridQuery hybridQuery, IndexSearcher searcher, ScoreMode scoreMode, float boost) throws IOException {
         super(hybridQuery);
-        this.queries = hybridQuery;
         weights = hybridQuery.getSubQueries().stream().map(q -> {
             try {
                 return searcher.createWeight(q, scoreMode, boost);
@@ -65,6 +67,65 @@ public Matches matches(LeafReaderContext context, int doc) throws IOException {
         return MatchesUtils.fromSubMatches(mis);
     }
 
+    @Override
+    public ScorerSupplier scorerSupplier(LeafReaderContext context) throws IOException {
+        List<ScorerSupplier> scorerSuppliers = new ArrayList<>();
+        for (Weight w : weights) {
+            ScorerSupplier ss = w.scorerSupplier(context);
+            scorerSuppliers.add(ss);
+        }
+
+        if (scorerSuppliers.isEmpty()) {
+            return null;
+        }
+        final Weight thisWeight = this;
+        return new ScorerSupplier() {
+            private long cost = -1;
+
+            @Override
+            public Scorer get(long leadCost) throws IOException {
+                List<Scorer> tScorers = new ArrayList<>();
+                for (ScorerSupplier ss : scorerSuppliers) {
+                    // keep positional correspondence with sub-queries: add null when a
+                    // sub-query has no matches in this segment
+                    tScorers.add(Objects.nonNull(ss) ? ss.get(leadCost) : null);
+                }
+                return new HybridQueryScorer(thisWeight, tScorers, scoreMode);
+            }
+
+            @Override
+            public long cost() {
+                if (cost == -1) {
+                    long cost = 0;
+                    for (ScorerSupplier ss : scorerSuppliers) {
+                        if (Objects.nonNull(ss)) {
+                            cost += ss.cost();
+                        }
+                    }
+                    this.cost = cost;
+                }
+                return cost;
+            }
+
+            @Override
+            public void setTopLevelScoringClause() throws IOException {
+                for (ScorerSupplier ss : scorerSuppliers) {
+                    // sub scorers need to be able to skip too as calls to setMinCompetitiveScore get propagated
+                    if (Objects.nonNull(ss)) {
+                        ss.setTopLevelScoringClause();
+                    }
+                }
+            }
+        };
+    }
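Annotation (not part of the patch): a sketch of how the new supplier is consumed, mirroring what Lucene's search code does; the surrounding variables are assumed to exist.

```java
ScorerSupplier supplier = hybridQueryWeight.scorerSupplier(leafReaderContext);
if (supplier != null) {
    supplier.setTopLevelScoringClause();    // lets sub-scorers honor setMinCompetitiveScore
    long leadCost = supplier.cost();        // estimated number of matching documents
    Scorer scorer = supplier.get(leadCost); // construct the scorer, sized for the lead cost
    // ... iterate scorer.iterator() and collect hits ...
}
```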
     /**
      * Create the scorer used to score our associated Query
      *
@@ -75,7 +136,7 @@
      */
     @Override
     public Scorer scorer(LeafReaderContext context) throws IOException {
-        List<Scorer> scorers = weights.stream().map(w -> {
-            try {
-                return w.scorer(context);
-            } catch (IOException e) {
-                throw new RuntimeException(e);
-            }
-        }).collect(Collectors.toList());
-        if (scorers.stream().allMatch(Objects::isNull)) {
-            return null;
-        }
-        return new HybridQueryScorer(this, scorers);
+        ScorerSupplier supplier = scorerSupplier(context);
+        if (supplier == null) {
+            return null;
+        }
+        supplier.setTopLevelScoringClause();
+        return supplier.get(Long.MAX_VALUE);
     }
 
     /**
@@ -98,6 +166,11 @@ public Scorer scorer(LeafReaderContext context) throws IOException {
      */
     @Override
     public boolean isCacheable(LeafReaderContext ctx) {
+        if (weights.size() > BOOLEAN_REWRITE_TERM_COUNT_THRESHOLD) {
+            // Disallow caching large queries to not encourage users to build large queries
+            return false;
+        }
         return weights.stream().allMatch(w -> w.isCacheable(ctx));
     }
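Annotation (not part of the patch): the `HybridScorePropagator` added below implements the scorer side of the block-max skipping protocol. A rough sketch of that protocol from the consumer's perspective; `minCompetitive` would come from a top-k collector's priority queue, which is an assumption here:

```java
scorer.setMinCompetitiveScore(minCompetitive);     // anything scoring below this can be skipped
int upTo = scorer.advanceShallow(scorer.docID());  // end of the current block-max range
if (scorer.getMaxScore(upTo) < minCompetitive) {
    // no doc in [docID(), upTo] can make the top-k; jump past the whole block
    scorer.iterator().advance(upTo + 1);
}
```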
diff --git a/src/main/java/org/opensearch/neuralsearch/query/HybridScorePropagator.java b/src/main/java/org/opensearch/neuralsearch/query/HybridScorePropagator.java
new file mode 100644
index 000000000..92e1bbf7e
--- /dev/null
+++ b/src/main/java/org/opensearch/neuralsearch/query/HybridScorePropagator.java
@@ -0,0 +1,91 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ */
+package org.opensearch.neuralsearch.query;
+
+import org.apache.lucene.search.DocIdSetIterator;
+import org.apache.lucene.search.Scorer;
+
+import java.io.IOException;
+import java.util.Arrays;
+import java.util.Collection;
+import java.util.Comparator;
+import java.util.Objects;
+
+/**
+ * Propagates block boundaries and minimum competitive scores to the sub-scorers of a hybrid
+ * query, keeping sub-scorers ordered by their maximum possible score.
+ */
+public class HybridScorePropagator {
+
+    private static final Comparator<Scorer> MAX_SCORE_COMPARATOR = Comparator.comparing((Scorer s) -> {
+        try {
+            return s.getMaxScore(DocIdSetIterator.NO_MORE_DOCS);
+        } catch (IOException e) {
+            throw new RuntimeException(e);
+        }
+    }).thenComparing(s -> s.iterator().cost());
+
+    private final Scorer[] scorers;
+    private final float[] maxScores;
+    private int leadIndex = 0;
+
+    HybridScorePropagator(Collection<Scorer> scorers) throws IOException {
+        this.scorers = scorers.stream().filter(Objects::nonNull).toArray(Scorer[]::new);
+        for (Scorer scorer : this.scorers) {
+            scorer.advanceShallow(0);
+        }
+        Arrays.sort(this.scorers, MAX_SCORE_COMPARATOR);
+
+        maxScores = new float[this.scorers.length];
+        for (int i = 0; i < this.scorers.length; ++i) {
+            maxScores[i] = this.scorers[i].getMaxScore(DocIdSetIterator.NO_MORE_DOCS);
+        }
+    }
+
+    /** See {@link Scorer#advanceShallow(int)}. */
+    int advanceShallow(int target) throws IOException {
+        // For scorers that are below the lead index, just propagate.
+        for (int i = 0; i < leadIndex; ++i) {
+            Scorer s = scorers[i];
+            if (s.docID() < target) {
+                s.advanceShallow(target);
+            }
+        }
+
+        // For scorers above the lead index, we take the minimum boundary.
+        Scorer leadScorer = scorers[leadIndex];
+        int upTo = leadScorer.advanceShallow(Math.max(leadScorer.docID(), target));
+
+        for (int i = leadIndex + 1; i < scorers.length; ++i) {
+            Scorer scorer = scorers[i];
+            if (scorer.docID() <= target) {
+                upTo = Math.min(scorer.advanceShallow(target), upTo);
+            }
+        }
+
+        // If the maximum scoring clauses are beyond `target`, then we use their
+        // docID as a boundary. It helps not consider them when computing the
+        // maximum score and get a lower score upper bound.
+        for (int i = scorers.length - 1; i > leadIndex; --i) {
+            Scorer scorer = scorers[i];
+            if (scorer.docID() > target) {
+                upTo = Math.min(upTo, scorer.docID() - 1);
+            } else {
+                break;
+            }
+        }
+
+        return upTo;
+    }
+
+    /**
+     * Set the minimum competitive score to filter out clauses that score less than this threshold.
+     *
+     * @see Scorer#setMinCompetitiveScore
+     */
+    void setMinCompetitiveScore(float minScore) throws IOException {
+        // Update the lead index if necessary
+        while (leadIndex < maxScores.length - 1 && minScore > maxScores[leadIndex]) {
+            leadIndex++;
+        }
+    }
+}
diff --git a/src/main/java/org/opensearch/neuralsearch/search/HybridTopScoreDocCollector.java b/src/main/java/org/opensearch/neuralsearch/search/HybridTopScoreDocCollector.java
index 8b7a12d29..9190bfeac 100644
--- a/src/main/java/org/opensearch/neuralsearch/search/HybridTopScoreDocCollector.java
+++ b/src/main/java/org/opensearch/neuralsearch/search/HybridTopScoreDocCollector.java
@@ -19,7 +19,6 @@
 import org.apache.lucene.search.ScoreDoc;
 import org.apache.lucene.search.ScoreMode;
 import org.apache.lucene.search.TopDocs;
-import org.apache.lucene.search.TopScoreDocCollector;
 import org.apache.lucene.search.TotalHits;
 import org.apache.lucene.util.PriorityQueue;
 import org.opensearch.neuralsearch.query.HybridQueryScorer;
@@ -47,16 +46,35 @@ public HybridTopScoreDocCollector(int numHits, HitsThresholdChecker hitsThreshol
 
     @Override
-    public LeafCollector getLeafCollector(LeafReaderContext context) throws IOException {
+    public LeafCollector getLeafCollector(LeafReaderContext context) {
         docBase = context.docBase;
 
-        return new TopScoreDocCollector.ScorerLeafCollector() {
+        return new LeafCollector() {
             HybridQueryScorer compoundQueryScorer;
 
             @Override
             public void setScorer(Scorable scorer) throws IOException {
-                super.setScorer(scorer);
-                compoundQueryScorer = (HybridQueryScorer) scorer;
+                if (scorer instanceof HybridQueryScorer) {
+                    compoundQueryScorer = (HybridQueryScorer) scorer;
+                } else {
+                    // core may wrap the scorer (e.g. for filtering or profiling), so
+                    // search the tree of child scorers for the hybrid query scorer
+                    compoundQueryScorer = getHybridQueryScorer(scorer);
+                }
+            }
+
+            private HybridQueryScorer getHybridQueryScorer(final Scorable scorer) throws IOException {
+                if (scorer == null) {
+                    return null;
+                }
+                if (scorer instanceof HybridQueryScorer) {
+                    return (HybridQueryScorer) scorer;
+                }
+                for (Scorable.ChildScorable childScorable : scorer.getChildren()) {
+                    HybridQueryScorer hybridQueryScorer = getHybridQueryScorer(childScorable.child);
+                    if (hybridQueryScorer != null) {
+                        return hybridQueryScorer;
+                    }
+                }
+                return null;
             }
 
             @Override
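Annotation (not part of the patch, wrapper is hypothetical): core may hand the leaf collector a wrapped `Scorable` rather than the `HybridQueryScorer` itself. Any wrapper that reports its delegate via `getChildren()` is enough for the recursive lookup above to find the hybrid scorer:

```java
Scorable wrapped = new Scorable() {
    @Override
    public float score() throws IOException {
        return hybridScorer.score(); // delegate scoring to the wrapped scorer
    }

    @Override
    public int docID() {
        return hybridScorer.docID();
    }

    @Override
    public Collection<ChildScorable> getChildren() {
        return List.of(new ChildScorable(hybridScorer, "MUST"));
    }
};
leafCollector.setScorer(wrapped); // setScorer() above unwraps via getChildren()
```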
diff --git a/src/main/java/org/opensearch/neuralsearch/search/query/HybridCollectorManager.java b/src/main/java/org/opensearch/neuralsearch/search/query/HybridCollectorManager.java
new file mode 100644
index 000000000..285216d41
--- /dev/null
+++ b/src/main/java/org/opensearch/neuralsearch/search/query/HybridCollectorManager.java
@@ -0,0 +1,248 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ */
+package org.opensearch.neuralsearch.search.query;
+
+import lombok.RequiredArgsConstructor;
+import org.apache.lucene.index.IndexReader;
+import org.apache.lucene.search.Collector;
+import org.apache.lucene.search.CollectorManager;
+import org.apache.lucene.search.Query;
+import org.apache.lucene.search.ScoreDoc;
+import org.apache.lucene.search.ScoreMode;
+import org.apache.lucene.search.TopDocs;
+import org.apache.lucene.search.TotalHits;
+import org.apache.lucene.search.Weight;
+import org.opensearch.common.lucene.search.TopDocsAndMaxScore;
+import org.opensearch.neuralsearch.search.HitsThresholdChecker;
+import org.opensearch.neuralsearch.search.HybridTopScoreDocCollector;
+import org.opensearch.search.DocValueFormat;
+import org.opensearch.search.internal.ContextIndexSearcher;
+import org.opensearch.search.internal.SearchContext;
+import org.opensearch.search.query.MultiCollectorWrapper;
+import org.opensearch.search.query.QuerySearchResult;
+import org.opensearch.search.query.ReduceableSearchResult;
+import org.opensearch.search.sort.SortAndFormats;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Collection;
+import java.util.HashSet;
+import java.util.List;
+import java.util.Objects;
+import java.util.Optional;
+import java.util.Set;
+import java.util.stream.Collectors;
+
+import static org.opensearch.neuralsearch.search.util.HybridSearchResultFormatUtil.createDelimiterElementForHybridSearchResults;
+import static org.opensearch.neuralsearch.search.util.HybridSearchResultFormatUtil.createStartStopElementForHybridSearchResults;
+
+/**
+ * Collector manager for hybrid query: creates a {@link HybridTopScoreDocCollector} per search
+ * slice and reduces collected results into a single {@link ReduceableSearchResult}.
+ */
+@RequiredArgsConstructor
+public abstract class HybridCollectorManager implements CollectorManager<Collector, ReduceableSearchResult> {
+
+    private final int numHits;
+    private final HitsThresholdChecker hitsThresholdChecker;
+    private final boolean isSingleShard;
+    private final int trackTotalHitsUpTo;
+    private final SortAndFormats sortAndFormats;
+    private final Optional<Weight> filteringWeightOptional;
+
+    public static CollectorManager<Collector, ReduceableSearchResult> createHybridCollectorManager(final SearchContext searchContext)
+        throws IOException {
+        final IndexReader reader = searchContext.searcher().getIndexReader();
+        final int totalNumDocs = Math.max(0, reader.numDocs());
+        boolean isSingleShard = searchContext.numberOfShards() == 1;
+        int numDocs = Math.min(searchContext.from() + searchContext.size(), totalNumDocs);
+        int trackTotalHitsUpTo = searchContext.trackTotalHitsUpTo();
+
+        Weight filterWeight = null;
+        // check for post filter
+        if (Objects.nonNull(searchContext.parsedPostFilter())) {
+            Query filterQuery = searchContext.parsedPostFilter().query();
+            ContextIndexSearcher searcher = searchContext.searcher();
+            filterWeight = searcher.createWeight(searcher.rewrite(filterQuery), ScoreMode.COMPLETE_NO_SCORES, 1f);
+        }
+
+        return searchContext.shouldUseConcurrentSearch()
+            ? new HybridCollectorConcurrentSearchManager(
+                numDocs,
+                new HitsThresholdChecker(Math.max(numDocs, searchContext.trackTotalHitsUpTo())),
+                isSingleShard,
+                trackTotalHitsUpTo,
+                searchContext.sort(),
+                filterWeight
+            )
+            : new HybridCollectorNonConcurrentManager(
+                numDocs,
+                new HitsThresholdChecker(Math.max(numDocs, searchContext.trackTotalHitsUpTo())),
+                isSingleShard,
+                trackTotalHitsUpTo,
+                searchContext.sort(),
+                filterWeight
+            );
+    }
+
+    @Override
+    public abstract Collector newCollector();
+
+    Collector getCollector() {
+        return new HybridTopScoreDocCollector(numHits, hitsThresholdChecker);
+    }
+
+    @Override
+    public ReduceableSearchResult reduce(Collection<Collector> collectors) {
+        final List<HybridTopScoreDocCollector> hybridTopScoreDocCollectors = new ArrayList<>();
+
+        // collectors may arrive wrapped in a MultiCollectorWrapper, so unwrap them first
+        for (final Collector collector : collectors) {
+            if (collector instanceof MultiCollectorWrapper) {
+                for (final Collector sub : ((MultiCollectorWrapper) collector).getCollectors()) {
+                    if (sub instanceof HybridTopScoreDocCollector) {
+                        hybridTopScoreDocCollectors.add((HybridTopScoreDocCollector) sub);
+                    }
+                }
+            } else if (collector instanceof HybridTopScoreDocCollector) {
+                hybridTopScoreDocCollectors.add((HybridTopScoreDocCollector) collector);
+            }
+        }
+
+        if (!hybridTopScoreDocCollectors.isEmpty()) {
+            HybridTopScoreDocCollector hybridTopScoreDocCollector = hybridTopScoreDocCollectors.stream()
+                .findFirst()
+                .orElseThrow(() -> new IllegalStateException("cannot collect results of hybrid search query"));
+            List<TopDocs> topDocs = hybridTopScoreDocCollector.topDocs();
+            TopDocs newTopDocs = getNewTopDocs(getTotalHits(this.trackTotalHitsUpTo, topDocs, isSingleShard), topDocs);
+            float maxScore = getMaxScore(topDocs);
+            TopDocsAndMaxScore topDocsAndMaxScore = new TopDocsAndMaxScore(newTopDocs, maxScore);
+            return (QuerySearchResult result) -> result.topDocs(topDocsAndMaxScore, getSortValueFormats(sortAndFormats));
+        }
+        throw new IllegalStateException("cannot collect results of hybrid search query, there are no proper score collectors");
+    }
+
+    private TopDocs getNewTopDocs(final TotalHits totalHits, final List<TopDocs> topDocs) {
+        ScoreDoc[] scoreDocs = new ScoreDoc[0];
+        if (Objects.nonNull(topDocs)) {
+            // for a single shard case we need to do score processing at coordinator level.
+            // this is a workaround for current core behaviour: for a single shard the fetch phase is
+            // executed right after the query phase, and processors are called after the actual fetch
+            // find any valid doc Id, or set it to -1 if there is not a single match
+            int delimiterDocId = topDocs.stream()
+                .filter(Objects::nonNull)
+                .filter(topDoc -> Objects.nonNull(topDoc.scoreDocs))
+                .map(topDoc -> topDoc.scoreDocs)
+                .filter(scoreDoc -> scoreDoc.length > 0)
+                .map(scoreDoc -> scoreDoc[0].doc)
+                .findFirst()
+                .orElse(-1);
+            if (delimiterDocId == -1) {
+                return new TopDocs(totalHits, scoreDocs);
+            }
+            // format scores using the following template:
+            // doc_id | magic_number_1
+            // doc_id | magic_number_2
+            // ...
+            // doc_id | magic_number_2
+            // ...
+            // doc_id | magic_number_2
+            // ...
+            // doc_id | magic_number_1
+            List<ScoreDoc> result = new ArrayList<>();
+            result.add(createStartStopElementForHybridSearchResults(delimiterDocId));
+            for (TopDocs topDoc : topDocs) {
+                if (Objects.isNull(topDoc) || Objects.isNull(topDoc.scoreDocs)) {
+                    result.add(createDelimiterElementForHybridSearchResults(delimiterDocId));
+                    continue;
+                }
+                result.add(createDelimiterElementForHybridSearchResults(delimiterDocId));
+                result.addAll(Arrays.asList(topDoc.scoreDocs));
+            }
+            result.add(createStartStopElementForHybridSearchResults(delimiterDocId));
+            scoreDocs = result.stream().map(doc -> new ScoreDoc(doc.doc, doc.score, doc.shardIndex)).toArray(ScoreDoc[]::new);
+        }
+        return new TopDocs(totalHits, scoreDocs);
+    }
+
+    private TotalHits getTotalHits(int trackTotalHitsUpTo, final List<TopDocs> topDocs, final boolean isSingleShard) {
+        final TotalHits.Relation relation = trackTotalHitsUpTo == SearchContext.TRACK_TOTAL_HITS_DISABLED
+            ? TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO
+            : TotalHits.Relation.EQUAL_TO;
+        if (topDocs == null || topDocs.isEmpty()) {
+            return new TotalHits(0, relation);
+        }
+
+        List<ScoreDoc[]> scoreDocs = topDocs.stream()
+            .map(topDoc -> topDoc.scoreDocs)
+            .filter(Objects::nonNull)
+            .collect(Collectors.toList());
+        Set<Integer> uniqueDocIds = new HashSet<>();
+        for (ScoreDoc[] scoreDocsArray : scoreDocs) {
+            uniqueDocIds.addAll(Arrays.stream(scoreDocsArray).map(scoreDoc -> scoreDoc.doc).collect(Collectors.toList()));
+        }
+        long maxTotalHits = uniqueDocIds.size();
+
+        return new TotalHits(maxTotalHits, relation);
+    }
+
+    private float getMaxScore(final List<TopDocs> topDocs) {
+        if (topDocs.isEmpty()) {
+            return 0.0f;
+        }
+        return topDocs.stream()
+            .map(docs -> docs.scoreDocs.length == 0 ? new ScoreDoc(-1, 0.0f) : docs.scoreDocs[0])
+            .map(scoreDoc -> scoreDoc.score)
+            .max(Float::compare)
+            .get();
+    }
+
+    private DocValueFormat[] getSortValueFormats(final SortAndFormats sortAndFormats) {
+        return sortAndFormats == null ? null : sortAndFormats.formats;
+    }
+
+    static class HybridCollectorNonConcurrentManager extends HybridCollectorManager {
+        Collector maxScoreCollector;
+
+        public HybridCollectorNonConcurrentManager(
+            int numHits,
+            HitsThresholdChecker hitsThresholdChecker,
+            boolean isSingleShard,
+            int trackTotalHitsUpTo,
+            SortAndFormats sortAndFormats,
+            Weight filteringWeight
+        ) {
+            super(numHits, hitsThresholdChecker, isSingleShard, trackTotalHitsUpTo, sortAndFormats, Optional.ofNullable(filteringWeight));
+        }
+
+        @Override
+        public Collector newCollector() {
+            // hand out the same instance to the search phase and to the reduce call that
+            // follows it, then reset, so reduce sees the hits collected during search
+            if (Objects.isNull(maxScoreCollector)) {
+                maxScoreCollector = getCollector();
+                return maxScoreCollector;
+            } else {
+                Collector toReturnCollector = maxScoreCollector;
+                maxScoreCollector = null;
+                return toReturnCollector;
+            }
+        }
+    }
+
+    static class HybridCollectorConcurrentSearchManager extends HybridCollectorManager {
+
+        public HybridCollectorConcurrentSearchManager(
+            int numHits,
+            HitsThresholdChecker hitsThresholdChecker,
+            boolean isSingleShard,
+            int trackTotalHitsUpTo,
+            SortAndFormats sortAndFormats,
+            Weight filteringWeight
+        ) {
+            super(numHits, hitsThresholdChecker, isSingleShard, trackTotalHitsUpTo, sortAndFormats, Optional.ofNullable(filteringWeight));
+        }
+
+        @Override
+        public Collector newCollector() {
+            return getCollector();
+        }
+    }
+}
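Annotation (not part of the patch): a sketch of the `CollectorManager` contract this class plugs into; the searcher creates one collector per slice via `newCollector()`, runs them, then `reduce()` merges slice results. `searchContext`, `query`, and `querySearchResult` are assumed to exist.

```java
CollectorManager<Collector, ReduceableSearchResult> manager =
    HybridCollectorManager.createHybridCollectorManager(searchContext);
ReduceableSearchResult merged = searchContext.searcher().search(query, manager);
merged.reduce(querySearchResult); // writes the merged TopDocs and max score into the shard result
```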
diff --git a/src/main/java/org/opensearch/neuralsearch/search/query/HybridQueryPhaseSearcher.java b/src/main/java/org/opensearch/neuralsearch/search/query/HybridQueryPhaseSearcher.java
index bf05fdc9d..ac2b75661 100644
--- a/src/main/java/org/opensearch/neuralsearch/search/query/HybridQueryPhaseSearcher.java
+++ b/src/main/java/org/opensearch/neuralsearch/search/query/HybridQueryPhaseSearcher.java
@@ -4,46 +4,33 @@
 package org.opensearch.neuralsearch.search.query;
 
-import static org.opensearch.neuralsearch.search.util.HybridSearchResultFormatUtil.createDelimiterElementForHybridSearchResults;
-import static org.opensearch.neuralsearch.search.util.HybridSearchResultFormatUtil.createStartStopElementForHybridSearchResults;
-import static org.opensearch.search.query.TopDocsCollectorContext.createTopDocsCollectorContext;
-
 import java.io.IOException;
-import java.util.ArrayList;
-import java.util.Arrays;
+import java.util.Collection;
 import java.util.LinkedList;
 import java.util.List;
-import java.util.Objects;
+import java.util.Map;
 
-import org.apache.lucene.index.IndexReader;
+import lombok.AllArgsConstructor;
 import org.apache.lucene.search.BooleanClause;
 import org.apache.lucene.search.BooleanQuery;
+import org.apache.lucene.search.Collector;
+import org.apache.lucene.search.CollectorManager;
 import org.apache.lucene.search.FieldExistsQuery;
 import org.apache.lucene.search.Query;
-import org.apache.lucene.search.ScoreDoc;
-import org.apache.lucene.search.TopDocs;
-import org.apache.lucene.search.TotalHitCountCollector;
-import org.apache.lucene.search.TotalHits;
-import org.opensearch.common.lucene.search.TopDocsAndMaxScore;
 import org.opensearch.common.settings.Settings;
 import org.opensearch.index.mapper.MapperService;
 import org.opensearch.index.mapper.SeqNoFieldMapper;
 import org.opensearch.index.search.NestedHelper;
 import org.opensearch.neuralsearch.query.HybridQuery;
-import org.opensearch.neuralsearch.search.HitsThresholdChecker;
-import org.opensearch.neuralsearch.search.HybridTopScoreDocCollector;
-import org.opensearch.search.DocValueFormat;
+import org.opensearch.search.aggregations.AggregationProcessor;
 import org.opensearch.search.internal.ContextIndexSearcher;
 import org.opensearch.search.internal.SearchContext;
 import org.opensearch.search.query.QueryCollectorContext;
 import org.opensearch.search.query.QueryPhase;
+import org.opensearch.search.query.QueryPhaseExecutionException;
 import org.opensearch.search.query.QueryPhaseSearcherWrapper;
 import org.opensearch.search.query.QuerySearchResult;
-import org.opensearch.search.query.TopDocsCollectorContext;
-import org.opensearch.search.rescore.RescoreContext;
-import org.opensearch.search.sort.SortAndFormats;
-
-import com.google.common.annotations.VisibleForTesting;
+import org.opensearch.search.query.ReduceableSearchResult;
 
 import lombok.extern.log4j.Log4j2;
@@ -66,12 +53,13 @@ public boolean searchWith(
         final boolean hasFilterCollector,
         final boolean hasTimeout
     ) throws IOException {
-        if (isHybridQuery(query, searchContext)) {
+        if (!isHybridQuery(query, searchContext)) {
+            validateQuery(searchContext, query);
+            return super.searchWith(searchContext, searcher, query, collectors, hasFilterCollector, hasTimeout);
+        } else {
             Query hybridQuery = extractHybridQuery(searchContext, query);
-            return searchWithCollector(searchContext, searcher, hybridQuery, collectors, hasFilterCollector, hasTimeout);
+            return super.searchWith(searchContext, searcher, hybridQuery, collectors, hasFilterCollector, hasTimeout);
         }
-        validateQuery(searchContext, query);
-        return super.searchWith(searchContext, searcher, query, collectors, hasFilterCollector, hasTimeout);
     }
 
     private boolean isHybridQuery(final Query query, final SearchContext searchContext) {
@@ -103,7 +91,7 @@ private boolean isHybridQuery(final Query query, final SearchContext searchConte
         // we have already checked that the query is an instance of BooleanQuery in the higher-level else-if condition
         return ((BooleanQuery) query).clauses()
             .stream()
-            .filter(clause -> clause.getQuery() instanceof HybridQuery == false)
+            .filter(clause -> !(clause.getQuery() instanceof HybridQuery))
             .allMatch(clause -> {
                 return clause.getOccur() == BooleanClause.Occur.FILTER
                     && clause.getQuery() instanceof FieldExistsQuery
@@ -180,152 +168,68 @@ private void validateNestedBooleanQuery(final Query query, final int level) {
         }
     }
 
-    @VisibleForTesting
-    protected boolean searchWithCollector(
-        final SearchContext searchContext,
-        final ContextIndexSearcher searcher,
-        final Query query,
-        final LinkedList<QueryCollectorContext> collectors,
-        final boolean hasFilterCollector,
-        final boolean hasTimeout
-    ) throws IOException {
-        log.debug("searching with custom doc collector, shard {}", searchContext.shardTarget().getShardId());
-
-        final TopDocsCollectorContext topDocsFactory = createTopDocsCollectorContext(searchContext, hasFilterCollector);
-        collectors.addFirst(topDocsFactory);
-        if (searchContext.size() == 0) {
-            final TotalHitCountCollector collector = new TotalHitCountCollector();
-            searcher.search(query, collector);
-            return false;
-        }
-        final IndexReader reader = searchContext.searcher().getIndexReader();
-        int totalNumDocs = Math.max(0, reader.numDocs());
-        int numDocs = Math.min(searchContext.from() + searchContext.size(), totalNumDocs);
-        final boolean shouldRescore = !searchContext.rescore().isEmpty();
-        if (shouldRescore) {
-            for (RescoreContext rescoreContext : searchContext.rescore()) {
-                numDocs = Math.max(numDocs, rescoreContext.getWindowSize());
-            }
-        }
+    private int getMaxDepthLimit(final SearchContext searchContext) {
+        Settings indexSettings = searchContext.getQueryShardContext().getIndexSettings().getSettings();
+        return MapperService.INDEX_MAPPING_DEPTH_LIMIT_SETTING.get(indexSettings).intValue();
+    }
-        final QuerySearchResult queryResult = searchContext.queryResult();
-
-        final HybridTopScoreDocCollector collector = new HybridTopScoreDocCollector(
-            numDocs,
-            new HitsThresholdChecker(Math.max(numDocs, searchContext.trackTotalHitsUpTo()))
-        );
-
-        searcher.search(query, collector);
-
-        if (searchContext.terminateAfter() != SearchContext.DEFAULT_TERMINATE_AFTER && queryResult.terminatedEarly() == null) {
-            queryResult.terminatedEarly(false);
-        }
-
-        setTopDocsInQueryResult(queryResult, collector, searchContext);
-
-        return shouldRescore;
-    }
-
-    private void setTopDocsInQueryResult(
-        final QuerySearchResult queryResult,
-        final HybridTopScoreDocCollector collector,
-        final SearchContext searchContext
-    ) {
-        final List<TopDocs> topDocs = collector.topDocs();
-        final float maxScore = getMaxScore(topDocs);
-        final boolean isSingleShard = searchContext.numberOfShards() == 1;
-        final TopDocs newTopDocs = getNewTopDocs(getTotalHits(searchContext, topDocs, isSingleShard), topDocs);
-        final TopDocsAndMaxScore topDocsAndMaxScore = new TopDocsAndMaxScore(newTopDocs, maxScore);
-        queryResult.topDocs(topDocsAndMaxScore, getSortValueFormats(searchContext.sort()));
-    }
-
-    private TopDocs getNewTopDocs(final TotalHits totalHits, final List<TopDocs> topDocs) {
-        ScoreDoc[] scoreDocs = new ScoreDoc[0];
-        if (Objects.nonNull(topDocs)) {
-            // for a single shard case we need to do score processing at coordinator level.
-            // this is workaround for current core behaviour, for single shard fetch phase is executed
-            // right after query phase and processors are called after actual fetch is done
-            // find any valid doc Id, or set it to -1 if there is not a single match
-            int delimiterDocId = topDocs.stream()
-                .filter(Objects::nonNull)
-                .filter(topDoc -> Objects.nonNull(topDoc.scoreDocs))
-                .map(topDoc -> topDoc.scoreDocs)
-                .filter(scoreDoc -> scoreDoc.length > 0)
-                .map(scoreDoc -> scoreDoc[0].doc)
-                .findFirst()
-                .orElse(-1);
-            if (delimiterDocId == -1) {
-                return new TopDocs(totalHits, scoreDocs);
-            }
-            // format scores using following template:
-            // doc_id | magic_number_1
-            // doc_id | magic_number_2
-            // ...
-            // doc_id | magic_number_2
-            // ...
-            // doc_id | magic_number_2
-            // ...
-            // doc_id | magic_number_1
-            List<ScoreDoc> result = new ArrayList<>();
-            result.add(createStartStopElementForHybridSearchResults(delimiterDocId));
-            for (TopDocs topDoc : topDocs) {
-                if (Objects.isNull(topDoc) || Objects.isNull(topDoc.scoreDocs)) {
-                    result.add(createDelimiterElementForHybridSearchResults(delimiterDocId));
-                    continue;
-                }
-                result.add(createDelimiterElementForHybridSearchResults(delimiterDocId));
-                result.addAll(Arrays.asList(topDoc.scoreDocs));
-            }
-            result.add(createStartStopElementForHybridSearchResults(delimiterDocId));
-            scoreDocs = result.stream().map(doc -> new ScoreDoc(doc.doc, doc.score, doc.shardIndex)).toArray(ScoreDoc[]::new);
-        }
-        return new TopDocs(totalHits, scoreDocs);
-    }
-
+    @Override
+    public AggregationProcessor aggregationProcessor(SearchContext searchContext) {
+        AggregationProcessor coreAggProcessor = super.aggregationProcessor(searchContext);
+        return new HybridAggregationProcessor(coreAggProcessor);
+    }
+
+    @AllArgsConstructor
+    public class HybridAggregationProcessor implements AggregationProcessor {
+
+        private final AggregationProcessor delegateAggsProcessor;
+
+        @Override
+        public void preProcess(SearchContext context) {
+            delegateAggsProcessor.preProcess(context);
+
+            if (isHybridQuery(context.query(), context)) {
+                // adding collector manager for hybrid query
+                CollectorManager<? extends Collector, ReduceableSearchResult> collectorManager;
+                try {
+                    collectorManager = HybridCollectorManager.createHybridCollectorManager(context);
+                } catch (IOException e) {
+                    throw new RuntimeException(e);
+                }
+                Map<Class<?>, CollectorManager<? extends Collector, ReduceableSearchResult>> collectorManagersByManagerClass = context
                    .queryCollectorManagers();
+                collectorManagersByManagerClass.put(HybridCollectorManager.class, collectorManager);
+            }
+        }
-    private TotalHits getTotalHits(final SearchContext searchContext, final List<TopDocs> topDocs, final boolean isSingleShard) {
-        int trackTotalHitsUpTo = searchContext.trackTotalHitsUpTo();
-        final TotalHits.Relation relation = trackTotalHitsUpTo == SearchContext.TRACK_TOTAL_HITS_DISABLED
-            ? TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO
-            : TotalHits.Relation.EQUAL_TO;
-        if (topDocs == null || topDocs.isEmpty()) {
-            return new TotalHits(0, relation);
-        }
-        long maxTotalHits = topDocs.get(0).totalHits.value;
-        int totalSize = 0;
-        for (TopDocs topDoc : topDocs) {
-            maxTotalHits = Math.max(maxTotalHits, topDoc.totalHits.value);
-            if (isSingleShard) {
-                totalSize += topDoc.totalHits.value + 1;
-            }
-        }
-        // add 1 qty per each sub-query and + 2 for start and stop delimiters
-        totalSize += 2;
-        if (isSingleShard) {
-            // for single shard we need to update total size as this is how many docs are fetched in Fetch phase
-            searchContext.size(totalSize);
-        }
-
-        return new TotalHits(maxTotalHits, relation);
-    }
-
-    private float getMaxScore(final List<TopDocs> topDocs) {
-        if (topDocs.isEmpty()) {
-            return 0.0f;
-        } else {
-            return topDocs.stream()
-                .map(docs -> docs.scoreDocs.length == 0 ? new ScoreDoc(-1, 0.0f) : docs.scoreDocs[0])
-                .map(scoreDoc -> scoreDoc.score)
-                .max(Float::compare)
-                .get();
-        }
-    }
-
-    private DocValueFormat[] getSortValueFormats(final SortAndFormats sortAndFormats) {
-        return sortAndFormats == null ? null : sortAndFormats.formats;
-    }
-
-    private int getMaxDepthLimit(final SearchContext searchContext) {
-        Settings indexSettings = searchContext.getQueryShardContext().getIndexSettings().getSettings();
-        return MapperService.INDEX_MAPPING_DEPTH_LIMIT_SETTING.get(indexSettings).intValue();
-    }
+        @Override
+        public void postProcess(SearchContext context) {
+            if (isHybridQuery(context.query(), context)) {
+                // for concurrent search the framework reduces collector results for us
+                if (!context.shouldUseConcurrentSearch()) {
+                    reduceCollectorResults(context);
+                }
+                updateQueryResult(context.queryResult(), context);
+            }
+
+            delegateAggsProcessor.postProcess(context);
+        }
+
+        @SuppressWarnings("unchecked")
+        private void reduceCollectorResults(SearchContext context) {
+            CollectorManager<Collector, ReduceableSearchResult> collectorManager = (CollectorManager<
+                Collector,
+                ReduceableSearchResult>) context.queryCollectorManagers().get(HybridCollectorManager.class);
+            try {
+                final Collection<Collector> collectors = List.of(collectorManager.newCollector());
+                collectorManager.reduce(collectors).reduce(context.queryResult());
+            } catch (IOException e) {
+                throw new QueryPhaseExecutionException(context.shardTarget(), "failed to execute hybrid query aggregation processor", e);
+            }
+        }
+
+        private void updateQueryResult(final QuerySearchResult queryResult, final SearchContext searchContext) {
+            boolean isSingleShard = searchContext.numberOfShards() == 1;
+            if (isSingleShard) {
+                // for a single shard the fetch phase runs right after the query phase, so the context size
+                // must match the number of elements (hits plus delimiters) in the compound top docs
+                searchContext.size(queryResult.queryResult().topDocs().topDocs.scoreDocs.length);
+            }
+        }
+    }
 }
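Annotation (not part of the patch): the compound-result formatting removed here now lives in `HybridCollectorManager#getNewTopDocs`. To make its `doc_id | magic_number` template concrete, a hypothetical expansion for a hybrid query with two sub-queries; doc ids, scores, and the marker labels below are invented for illustration, the actual marker scores come from `HybridSearchResultFormatUtil`:

```java
// start:  createStartStopElementForHybridSearchResults(delimiterDocId)
// delim:  createDelimiterElementForHybridSearchResults(delimiterDocId) -> sub-query #1 begins
//         doc=12, score=0.92
//         doc=3,  score=0.57
// delim:  -> sub-query #2 begins (an empty sub-result still gets its delimiter)
//         doc=12, score=0.88
// start:  closing marker
```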
diff --git a/src/test/java/org/opensearch/neuralsearch/search/query/HybridQueryPhaseSearcherTests.java b/src/test/java/org/opensearch/neuralsearch/search/query/HybridQueryPhaseSearcherTests.java
index e609eec05..8ef9dc6ee 100644
--- a/src/test/java/org/opensearch/neuralsearch/search/query/HybridQueryPhaseSearcherTests.java
+++ b/src/test/java/org/opensearch/neuralsearch/search/query/HybridQueryPhaseSearcherTests.java
@@ -159,7 +159,7 @@ public void testQueryType_whenQueryIsHybrid_thenCallHybridDocCollector() {
 
         releaseResources(directory, w, reader);
 
-        verify(hybridQueryPhaseSearcher, atLeastOnce()).searchWithCollector(any(), any(), any(), any(), anyBoolean(), anyBoolean());
+        verify(hybridQueryPhaseSearcher, atLeastOnce()).searchWith(any(), any(), any(), any(), anyBoolean(), anyBoolean());
     }
 
     @SneakyThrows
@@ -226,7 +226,7 @@ public void testQueryType_whenQueryIsNotHybrid_thenDoNotCallHybridDocCollector()
 
         releaseResources(directory, w, reader);
 
-        verify(hybridQueryPhaseSearcher, never()).searchWithCollector(any(), any(), any(), any(), anyBoolean(), anyBoolean());
+        verify(hybridQueryPhaseSearcher, never()).searchWith(any(), any(), any(), any(), anyBoolean(), anyBoolean());
     }
 
     @SneakyThrows