Skip to content

Commit

Permalink
Add timeout support to AbstractVectorSimilarityQuery (apache#13285)
Browse files Browse the repository at this point in the history
Co-authored-by: Kaival Parikh <[email protected]>
  • Loading branch information
kaivalnp and Kaival Parikh authored Aug 6, 2024
1 parent 43c8011 commit e0e5d81
Show file tree
Hide file tree
Showing 5 changed files with 235 additions and 54 deletions.
3 changes: 3 additions & 0 deletions lucene/CHANGES.txt
Original file line number Diff line number Diff line change
Expand Up @@ -288,6 +288,9 @@ Improvements

* GITHUB#13625: Remove BitSet#nextSetBit code duplication. (Greg Miller)

* GITHUB#13285: Early terminate graph searches of AbstractVectorSimilarityQuery to follow timeout set from
IndexSearcher#setTimeout(QueryTimeout). (Kaival Parikh)

Optimizations
---------------------

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@
import java.util.Objects;
import org.apache.lucene.index.LeafReader;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.index.QueryTimeout;
import org.apache.lucene.search.knn.KnnCollectorManager;
import org.apache.lucene.util.BitSet;
import org.apache.lucene.util.BitSetIterator;
import org.apache.lucene.util.Bits;
Expand Down Expand Up @@ -58,10 +60,19 @@ abstract class AbstractVectorSimilarityQuery extends Query {
this.filter = filter;
}

protected KnnCollectorManager getKnnCollectorManager() {
return (visitedLimit, context) ->
new VectorSimilarityCollector(traversalSimilarity, resultSimilarity, visitedLimit);
}

abstract VectorScorer createVectorScorer(LeafReaderContext context) throws IOException;

protected abstract TopDocs approximateSearch(
LeafReaderContext context, Bits acceptDocs, int visitLimit) throws IOException;
LeafReaderContext context,
Bits acceptDocs,
int visitLimit,
KnnCollectorManager knnCollectorManager)
throws IOException;

@Override
public Weight createWeight(IndexSearcher searcher, ScoreMode scoreMode, float boost)
Expand All @@ -72,6 +83,10 @@ public Weight createWeight(IndexSearcher searcher, ScoreMode scoreMode, float bo
? null
: searcher.createWeight(searcher.rewrite(filter), ScoreMode.COMPLETE_NO_SCORES, 1);

final QueryTimeout queryTimeout = searcher.getTimeout();
final TimeLimitingKnnCollectorManager timeLimitingKnnCollectorManager =
new TimeLimitingKnnCollectorManager(getKnnCollectorManager(), queryTimeout);

@Override
public Explanation explain(LeafReaderContext context, int doc) throws IOException {
if (filterWeight != null) {
Expand Down Expand Up @@ -103,16 +118,14 @@ public Explanation explain(LeafReaderContext context, int doc) throws IOExceptio
public ScorerSupplier scorerSupplier(LeafReaderContext context) throws IOException {
LeafReader leafReader = context.reader();
Bits liveDocs = leafReader.getLiveDocs();
final Scorer vectorSimilarityScorer;

// If there is no filter
if (filterWeight == null) {
// Return exhaustive results
TopDocs results = approximateSearch(context, liveDocs, Integer.MAX_VALUE);
if (results.scoreDocs.length == 0) {
return null;
}
vectorSimilarityScorer =
VectorSimilarityScorer.fromScoreDocs(this, boost, results.scoreDocs);
TopDocs results =
approximateSearch(
context, liveDocs, Integer.MAX_VALUE, timeLimitingKnnCollectorManager);
return VectorSimilarityScorerSupplier.fromScoreDocs(boost, results.scoreDocs);
} else {
Scorer scorer = filterWeight.scorer(context);
if (scorer == null) {
Expand Down Expand Up @@ -143,27 +156,23 @@ protected boolean match(int doc) {
}

// Perform an approximate search
TopDocs results = approximateSearch(context, acceptDocs, cardinality);
TopDocs results =
approximateSearch(context, acceptDocs, cardinality, timeLimitingKnnCollectorManager);

// If the limit was exhausted
if (results.totalHits.relation == TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO) {
// Return a lazy-loading iterator
vectorSimilarityScorer =
VectorSimilarityScorer.fromAcceptDocs(
this,
boost,
createVectorScorer(context),
new BitSetIterator(acceptDocs, cardinality),
resultSimilarity);
} else if (results.scoreDocs.length == 0) {
return null;
} else {
if (results.totalHits.relation == TotalHits.Relation.EQUAL_TO
// Return partial results only when timeout is met
|| (queryTimeout != null && queryTimeout.shouldExit())) {
// Return an iterator over the collected results
vectorSimilarityScorer =
VectorSimilarityScorer.fromScoreDocs(this, boost, results.scoreDocs);
return VectorSimilarityScorerSupplier.fromScoreDocs(boost, results.scoreDocs);
} else {
// Return a lazy-loading iterator
return VectorSimilarityScorerSupplier.fromAcceptDocs(
boost,
createVectorScorer(context),
new BitSetIterator(acceptDocs, cardinality),
resultSimilarity);
}
}
return new DefaultScorerSupplier(vectorSimilarityScorer);
}

@Override
Expand Down Expand Up @@ -197,16 +206,20 @@ public int hashCode() {
return Objects.hash(field, traversalSimilarity, resultSimilarity, filter);
}

private static class VectorSimilarityScorer extends Scorer {
private static class VectorSimilarityScorerSupplier extends ScorerSupplier {
final DocIdSetIterator iterator;
final float[] cachedScore;

VectorSimilarityScorer(DocIdSetIterator iterator, float[] cachedScore) {
VectorSimilarityScorerSupplier(DocIdSetIterator iterator, float[] cachedScore) {
this.iterator = iterator;
this.cachedScore = cachedScore;
}

static VectorSimilarityScorer fromScoreDocs(Weight weight, float boost, ScoreDoc[] scoreDocs) {
static VectorSimilarityScorerSupplier fromScoreDocs(float boost, ScoreDoc[] scoreDocs) {
if (scoreDocs.length == 0) {
return null;
}

// Sort in ascending order of docid
Arrays.sort(scoreDocs, Comparator.comparingInt(scoreDoc -> scoreDoc.doc));

Expand Down Expand Up @@ -252,18 +265,15 @@ public long cost() {
}
};

return new VectorSimilarityScorer(iterator, cachedScore);
return new VectorSimilarityScorerSupplier(iterator, cachedScore);
}

static VectorSimilarityScorer fromAcceptDocs(
Weight weight,
float boost,
VectorScorer scorer,
DocIdSetIterator acceptDocs,
float threshold) {
static VectorSimilarityScorerSupplier fromAcceptDocs(
float boost, VectorScorer scorer, DocIdSetIterator acceptDocs, float threshold) {
if (scorer == null) {
return null;
}

float[] cachedScore = new float[1];
DocIdSetIterator vectorIterator = scorer.iterator();
DocIdSetIterator conjunction =
Expand All @@ -281,27 +291,37 @@ protected boolean match(int doc) throws IOException {
}
};

return new VectorSimilarityScorer(iterator, cachedScore);
return new VectorSimilarityScorerSupplier(iterator, cachedScore);
}

@Override
public int docID() {
return iterator.docID();
}
public Scorer get(long leadCost) {
return new Scorer() {
@Override
public int docID() {
return iterator.docID();
}

@Override
public DocIdSetIterator iterator() {
return iterator;
}
@Override
public DocIdSetIterator iterator() {
return iterator;
}

@Override
public float getMaxScore(int upTo) {
return Float.POSITIVE_INFINITY;
@Override
public float getMaxScore(int upTo) {
return Float.POSITIVE_INFINITY;
}

@Override
public float score() {
return cachedScore[0];
}
};
}

@Override
public float score() {
return cachedScore[0];
public long cost() {
return iterator.cost();
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import org.apache.lucene.document.KnnByteVectorField;
import org.apache.lucene.index.ByteVectorValues;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.search.knn.KnnCollectorManager;
import org.apache.lucene.util.Bits;

/**
Expand Down Expand Up @@ -106,10 +107,13 @@ VectorScorer createVectorScorer(LeafReaderContext context) throws IOException {

@Override
@SuppressWarnings("resource")
protected TopDocs approximateSearch(LeafReaderContext context, Bits acceptDocs, int visitLimit)
protected TopDocs approximateSearch(
LeafReaderContext context,
Bits acceptDocs,
int visitLimit,
KnnCollectorManager knnCollectorManager)
throws IOException {
KnnCollector collector =
new VectorSimilarityCollector(traversalSimilarity, resultSimilarity, visitLimit);
KnnCollector collector = knnCollectorManager.newCollector(visitLimit, context);
context.reader().searchNearestVectors(field, target, collector, acceptDocs);
return collector.topDocs();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import org.apache.lucene.document.KnnFloatVectorField;
import org.apache.lucene.index.FloatVectorValues;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.search.knn.KnnCollectorManager;
import org.apache.lucene.util.Bits;
import org.apache.lucene.util.VectorUtil;

Expand Down Expand Up @@ -108,10 +109,13 @@ VectorScorer createVectorScorer(LeafReaderContext context) throws IOException {

@Override
@SuppressWarnings("resource")
protected TopDocs approximateSearch(LeafReaderContext context, Bits acceptDocs, int visitLimit)
protected TopDocs approximateSearch(
LeafReaderContext context,
Bits acceptDocs,
int visitLimit,
KnnCollectorManager knnCollectorManager)
throws IOException {
KnnCollector collector =
new VectorSimilarityCollector(traversalSimilarity, resultSimilarity, visitLimit);
KnnCollector collector = knnCollectorManager.newCollector(visitLimit, context);
context.reader().searchNearestVectors(field, target, collector, acceptDocs);
return collector.topDocs();
}
Expand Down
Loading

0 comments on commit e0e5d81

Please sign in to comment.