Skip to content

Commit

Permalink
Initial version, ported from test branch
Browse files Browse the repository at this point in the history
Signed-off-by: Martin Gaievski <[email protected]>
  • Loading branch information
martin-gaievski committed Mar 1, 2024
1 parent b97dbe8 commit 57c5c4f
Show file tree
Hide file tree
Showing 9 changed files with 769 additions and 179 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
import java.util.List;
import java.util.Objects;

import org.apache.lucene.index.IndexReader;
import org.apache.lucene.search.BooleanClause;
import org.apache.lucene.search.BooleanQuery;
import org.apache.lucene.search.IndexSearcher;
Expand Down Expand Up @@ -77,20 +76,20 @@ public String toString(String field) {
/**
* Re-writes queries into primitive queries. Callers are expected to call rewrite multiple times if necessary,
* until the rewritten query is the same as the original query.
* @param reader
* @param indexSearcher
* @return
* @throws IOException
*/
@Override
public Query rewrite(IndexReader reader) throws IOException {
public Query rewrite(IndexSearcher indexSearcher) throws IOException {
if (subQueries.isEmpty()) {
return new MatchNoDocsQuery("empty HybridQuery");
}

boolean actuallyRewritten = false;
List<Query> rewrittenSubQueries = new ArrayList<>();
for (Query subQuery : subQueries) {
Query rewrittenSub = subQuery.rewrite(reader);
Query rewrittenSub = subQuery.rewrite(indexSearcher);
/* we keep rewrite sub-query unless it's not equal to itself, it may take multiple levels of recursive calls
queries need to be rewritten from high-level clauses into lower-level clauses because low-level clauses
perform better. For hybrid query we need to track progress of re-write for all sub-queries */
Expand All @@ -102,7 +101,7 @@ public Query rewrite(IndexReader reader) throws IOException {
return new HybridQuery(rewrittenSubQueries);
}

return super.rewrite(reader);
return super.rewrite(indexSearcher);
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@
import org.opensearch.index.query.QueryRewriteContext;
import org.opensearch.index.query.QueryShardContext;
import org.opensearch.index.query.QueryShardException;
import org.opensearch.index.query.Rewriteable;
import org.opensearch.index.query.QueryBuilderVisitor;

import lombok.Getter;
Expand Down Expand Up @@ -290,7 +289,7 @@ private void writeQueries(StreamOutput out, List<? extends QueryBuilder> queries
private Collection<Query> toQueries(Collection<QueryBuilder> queryBuilders, QueryShardContext context) throws QueryShardException {
List<Query> queries = queryBuilders.stream().map(qb -> {
try {
return Rewriteable.rewrite(qb, context).toQuery(context);
return qb.rewrite(context).toQuery(context);
} catch (IOException e) {
throw new RuntimeException(e);
}
Expand Down
264 changes: 261 additions & 3 deletions src/main/java/org/opensearch/neuralsearch/query/HybridQueryScorer.java
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

import java.io.IOException;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
Expand All @@ -15,13 +16,17 @@

import org.apache.lucene.search.DisiPriorityQueue;
import org.apache.lucene.search.DisiWrapper;
import org.apache.lucene.search.DisjunctionDISIApproximation;
import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.search.Query;
import org.apache.lucene.search.ScoreMode;
import org.apache.lucene.search.Scorer;
import org.apache.lucene.search.TwoPhaseIterator;
import org.apache.lucene.search.Weight;

import lombok.Getter;
import org.apache.lucene.util.PriorityQueue;

import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS;

/**
* Class abstracts functionality of Scorer for hybrid query. When iterating over documents in increasing
Expand All @@ -40,12 +45,60 @@ public final class HybridQueryScorer extends Scorer {

private final Map<Query, List<Integer>> queryToIndex;

private final DocIdSetIterator approximation;
HybridScorePropagator disjunctionBlockPropagator;
private final TwoPhase twoPhase;

public HybridQueryScorer(Weight weight, List<Scorer> subScorers) throws IOException {
this(weight, subScorers, ScoreMode.TOP_SCORES);
}

public HybridQueryScorer(Weight weight, List<Scorer> subScorers, ScoreMode scoreMode) throws IOException {
super(weight);
// max
this.subScorers = Collections.unmodifiableList(subScorers);
// custom
subScores = new float[subScorers.size()];
this.queryToIndex = mapQueryToIndex();
// base
this.subScorersPQ = initializeSubScorersPQ();
// base
boolean needsScores = scoreMode != ScoreMode.COMPLETE_NO_SCORES;
this.approximation = new HybridDisjunctionDISIApproximation(this.subScorersPQ);
// max
if (scoreMode == ScoreMode.TOP_SCORES) {
this.disjunctionBlockPropagator = new HybridScorePropagator(subScorers);
} else {
this.disjunctionBlockPropagator = null;
}
// base
boolean hasApproximation = false;
float sumMatchCost = 0;
long sumApproxCost = 0;
// Compute matchCost as the average over the matchCost of the subScorers.
// This is weighted by the cost, which is an expected number of matching documents.
for (DisiWrapper w : subScorersPQ) {
long costWeight = (w.cost <= 1) ? 1 : w.cost;
sumApproxCost += costWeight;
if (w.twoPhaseView != null) {
hasApproximation = true;
sumMatchCost += w.matchCost * costWeight;
}
}
if (!hasApproximation) { // no sub scorer supports approximations
twoPhase = null;
} else {
final float matchCost = sumMatchCost / sumApproxCost;
twoPhase = new TwoPhase(approximation, matchCost, subScorersPQ, needsScores);
}
}

@Override
public int advanceShallow(int target) throws IOException {
if (disjunctionBlockPropagator != null) {
return disjunctionBlockPropagator.advanceShallow(target);
}
return super.advanceShallow(target);
}

/**
Expand All @@ -55,7 +108,7 @@ public HybridQueryScorer(Weight weight, List<Scorer> subScorers) throws IOExcept
*/
@Override
public float score() throws IOException {
DisiWrapper topList = subScorersPQ.topList();
/*DisiWrapper topList = subScorersPQ.topList();
float totalScore = 0.0f;
for (DisiWrapper disiWrapper = topList; disiWrapper != null; disiWrapper = disiWrapper.next) {
// check if this doc has match in the subQuery. If not, add score as 0.0 and continue
Expand All @@ -64,16 +117,58 @@ public float score() throws IOException {
}
totalScore += disiWrapper.scorer.score();
}
return totalScore;*/
return score(getSubMatches());
}

private float score(DisiWrapper topList) throws IOException {
/*float scoreMax = 0;
double otherScoreSum = 0;
for (DisiWrapper w = topList; w != null; w = w.next) {
float subScore = w.scorer.score();
if (subScore >= scoreMax) {
otherScoreSum += scoreMax;
scoreMax = subScore;
} else {
otherScoreSum += subScore;
}
}
return (float) (scoreMax + otherScoreSum);*/
float totalScore = 0.0f;
for (DisiWrapper disiWrapper = topList; disiWrapper != null; disiWrapper = disiWrapper.next) {
// check if this doc has match in the subQuery. If not, add score as 0.0 and continue
if (disiWrapper.scorer.docID() == NO_MORE_DOCS) {
continue;
}
totalScore += disiWrapper.scorer.score();
}
return totalScore;
}

DisiWrapper getSubMatches() throws IOException {
if (twoPhase == null) {
return subScorersPQ.topList();
} else {
return twoPhase.getSubMatches();
}
}

/**
* Return a DocIdSetIterator over matching documents.
* @return DocIdSetIterator object
*/
@Override
public DocIdSetIterator iterator() {
return new DisjunctionDISIApproximation(this.subScorersPQ);
if (twoPhase != null) {
return TwoPhaseIterator.asDocIdSetIterator(twoPhase);
} else {
return approximation;
}
}

@Override
public TwoPhaseIterator twoPhaseIterator() {
return twoPhase;
}

/**
Expand All @@ -93,12 +188,28 @@ public float getMaxScore(int upTo) throws IOException {
}).max(Float::compare).orElse(0.0f);
}

@Override
public void setMinCompetitiveScore(float minScore) throws IOException {
if (disjunctionBlockPropagator != null) {
disjunctionBlockPropagator.setMinCompetitiveScore(minScore);
}

for (Scorer scorer : subScorers) {
if (Objects.nonNull(scorer)) {
scorer.setMinCompetitiveScore(minScore);
}
}
}

/**
* Returns the doc ID that is currently being scored.
* @return document id
*/
@Override
public int docID() {
if (subScorersPQ.size() == 0) {
return NO_MORE_DOCS;
}
return subScorersPQ.top().doc;
}

Expand Down Expand Up @@ -169,4 +280,151 @@ private DisiPriorityQueue initializeSubScorersPQ() {
}
return subScorersPQ;
}

@Override
public Collection<ChildScorable> getChildren() throws IOException {
ArrayList<ChildScorable> children = new ArrayList<>();
for (DisiWrapper scorer = getSubMatches(); scorer != null; scorer = scorer.next) {
children.add(new ChildScorable(scorer.scorer, "SHOULD"));
}
return children;
}

static class TwoPhase extends TwoPhaseIterator {
private final float matchCost;
// list of verified matches on the current doc
DisiWrapper verifiedMatches;
// priority queue of approximations on the current doc that have not been verified yet
final PriorityQueue<DisiWrapper> unverifiedMatches;
DisiPriorityQueue subScorers;
boolean needsScores;

private TwoPhase(DocIdSetIterator approximation, float matchCost, DisiPriorityQueue subScorers, boolean needsScores) {
super(approximation);
this.matchCost = matchCost;
this.subScorers = subScorers;
unverifiedMatches = new PriorityQueue<>(subScorers.size()) {
@Override
protected boolean lessThan(DisiWrapper a, DisiWrapper b) {
return a.matchCost < b.matchCost;
}
};
this.needsScores = needsScores;
}

DisiWrapper getSubMatches() throws IOException {
// iteration order does not matter
for (DisiWrapper w : unverifiedMatches) {
if (w.twoPhaseView.matches()) {
w.next = verifiedMatches;
verifiedMatches = w;
}
}
unverifiedMatches.clear();
return verifiedMatches;
}

@Override
public boolean matches() throws IOException {
verifiedMatches = null;
unverifiedMatches.clear();

for (DisiWrapper w = subScorers.topList(); w != null;) {
DisiWrapper next = w.next;

if (w.twoPhaseView == null) {
// implicitly verified, move it to verifiedMatches
w.next = verifiedMatches;
verifiedMatches = w;

if (!needsScores) {
// we can stop here
return true;
}
} else {
unverifiedMatches.add(w);
}
w = next;
}

if (verifiedMatches != null) {
return true;
}

// verify subs that have an two-phase iterator
// least-costly ones first
while (unverifiedMatches.size() > 0) {
DisiWrapper w = unverifiedMatches.pop();
if (w.twoPhaseView.matches()) {
w.next = null;
verifiedMatches = w;
return true;
}
}

return false;
}

@Override
public float matchCost() {
return matchCost;
}
}

static class HybridDisjunctionDISIApproximation extends DocIdSetIterator {

final DisiPriorityQueue subIterators;
final long cost;

public HybridDisjunctionDISIApproximation(DisiPriorityQueue subIterators) {
this.subIterators = subIterators;
long cost = 0;
for (DisiWrapper w : subIterators) {
cost += w.cost;
}
this.cost = cost;
}

@Override
public long cost() {
return cost;
}

@Override
public int docID() {
if (subIterators.size() == 0) {
return NO_MORE_DOCS;
}
return subIterators.top().doc;
}

@Override
public int nextDoc() throws IOException {
if (subIterators.size() == 0) {
return NO_MORE_DOCS;
}
DisiWrapper top = subIterators.top();
final int doc = top.doc;
do {
top.doc = top.approximation.nextDoc();
top = subIterators.updateTop();
} while (top.doc == doc);

return top.doc;
}

@Override
public int advance(int target) throws IOException {
if (subIterators.size() == 0) {
return NO_MORE_DOCS;
}
DisiWrapper top = subIterators.top();
do {
top.doc = top.approximation.advance(target);
top = subIterators.updateTop();
} while (top.doc < target);

return top.doc;
}
}
}
Loading

0 comments on commit 57c5c4f

Please sign in to comment.