Skip to content

Commit

Permalink
Use request size when k is null to calculate the number of results to retrieve from each shard
Browse files Browse the repository at this point in the history
  • Loading branch information
carlosdelest committed Dec 18, 2024
1 parent e0763c2 commit e03f240
Show file tree
Hide file tree
Showing 7 changed files with 89 additions and 87 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -2008,6 +2008,7 @@ public Query createKnnQuery(
VectorData queryVector,
Integer k,
int numCands,
int requestSize,
Float numCandsFactor,
Query filter,
Float similarityThreshold,
Expand All @@ -2024,6 +2025,7 @@ public Query createKnnQuery(
queryVector.asFloatVector(),
k,
numCands,
requestSize,
numCandsFactor,
filter,
similarityThreshold,
Expand Down Expand Up @@ -2090,6 +2092,7 @@ private Query createKnnFloatQuery(
float[] queryVector,
Integer k,
int numCands,
int requestSize,
Float numCandsFactor,
Query filter,
Float similarityThreshold,
Expand Down Expand Up @@ -2127,7 +2130,7 @@ && isNotUnitVector(squaredMagnitude)) {
name(),
queryVector,
similarity.vectorSimilarityFunction(indexVersionCreated, ElementType.FLOAT),
k,
k == null ? requestSize : k,
knnQuery
);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,6 @@
import org.apache.lucene.search.DoubleValues;
import org.apache.lucene.search.DoubleValuesSource;
import org.apache.lucene.search.IndexSearcher;
import org.elasticsearch.search.profile.query.QueryProfiler;
import org.elasticsearch.search.vectors.QueryProfilerProvider;

import java.io.IOException;
import java.util.Arrays;
Expand All @@ -29,12 +27,11 @@
* DoubleValuesSource that is used to calculate scores according to a similarity function for a KnnFloatVectorField, using the
* original vector values stored in the index
*/
public class VectorSimilarityFloatValueSource extends DoubleValuesSource implements QueryProfilerProvider {
public class VectorSimilarityFloatValueSource extends DoubleValuesSource {

private final String field;
private final float[] target;
private final VectorSimilarityFunction vectorSimilarityFunction;
private long vectorOpsCount;

public VectorSimilarityFloatValueSource(String field, float[] target, VectorSimilarityFunction vectorSimilarityFunction) {
this.field = field;
Expand All @@ -52,7 +49,6 @@ public DoubleValues getValues(LeafReaderContext ctx, DoubleValues scores) throws
return new DoubleValues() {
@Override
public double doubleValue() throws IOException {
vectorOpsCount++;
return vectorSimilarityFunction.compare(target, vectorValues.vectorValue(iterator.index()));
}

Expand All @@ -73,11 +69,6 @@ public DoubleValuesSource rewrite(IndexSearcher reader) throws IOException {
return this;
}

@Override
public void profile(QueryProfiler queryProfiler) {
queryProfiler.addVectorOpsCount(vectorOpsCount);
}

@Override
public int hashCode() {
return Objects.hash(field, Arrays.hashCode(target), vectorSimilarityFunction);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -528,8 +528,8 @@ protected Query doToQuery(SearchExecutionContext context) throws IOException {
String parentPath = context.nestedLookup().getNestedParent(fieldName);
Float numCandidatesFactor = rescoreVectorBuilder() == null ? null : rescoreVectorBuilder.numCandidatesFactor();

BitSetProducer parentBitSet = null;
if (parentPath != null) {
final BitSetProducer parentBitSet;
final Query parentFilter;
NestedObjectMapper originalObjectMapper = context.nestedScope().getObjectMapper();
if (originalObjectMapper != null) {
Expand Down Expand Up @@ -558,17 +558,18 @@ protected Query doToQuery(SearchExecutionContext context) throws IOException {
// Now join the filterQuery & parentFilter to provide the matching blocks of children
filterQuery = new ToChildBlockJoinQuery(filterQuery, parentBitSet);
}
return vectorFieldType.createKnnQuery(
queryVector,
k,
adjustedNumCands,
numCandidatesFactor,
filterQuery,
vectorSimilarity,
parentBitSet
);
}
return vectorFieldType.createKnnQuery(queryVector, k, adjustedNumCands, numCandidatesFactor, filterQuery, vectorSimilarity, null);

return vectorFieldType.createKnnQuery(
queryVector,
k,
adjustedNumCands,
requestSize,
numCandidatesFactor,
filterQuery,
vectorSimilarity,
parentBitSet
);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,16 +32,15 @@ public class RescoreKnnVectorQuery extends Query implements QueryProfilerProvide
private final String fieldName;
private final float[] floatTarget;
private final VectorSimilarityFunction vectorSimilarityFunction;
private final Integer k;
private final int k;
private final Query innerQuery;

private QueryProfilerProvider vectorProfiling;
private long vectorOperations = 0;

public RescoreKnnVectorQuery(
String fieldName,
float[] floatTarget,
VectorSimilarityFunction vectorSimilarityFunction,
Integer k,
int k,
Query innerQuery
) {
this.fieldName = fieldName;
Expand All @@ -54,19 +53,12 @@ public RescoreKnnVectorQuery(
@Override
public Query rewrite(IndexSearcher searcher) throws IOException {
DoubleValuesSource valueSource = new VectorSimilarityFloatValueSource(fieldName, floatTarget, vectorSimilarityFunction);
// Vector similarity VectorSimilarityFloatValueSource keep track of the compared vectors - we need that in case we don't need
// to calculate top k and return directly the query to understand how many comparisons were done
vectorProfiling = (QueryProfilerProvider) valueSource;
FunctionScoreQuery functionScoreQuery = new FunctionScoreQuery(innerQuery, valueSource);
Query query = searcher.rewrite(functionScoreQuery);

if (k == null) {
// No need to calculate top k - let the request size limit the results.
return query;
}

// Retrieve top k documents from the rescored query
TopDocs topDocs = searcher.search(query, k);
vectorOperations = topDocs.totalHits.value();
ScoreDoc[] scoreDocs = topDocs.scoreDocs;
int[] docIds = new int[scoreDocs.length];
float[] scores = new float[scoreDocs.length];
Expand All @@ -82,7 +74,7 @@ public Query innerQuery() {
return innerQuery;
}

public Integer k() {
public int k() {
return k;
}

Expand All @@ -92,10 +84,7 @@ public void profile(QueryProfiler queryProfiler) {
queryProfilerProvider.profile(queryProfiler);
}

if (vectorProfiling == null) {
throw new IllegalStateException("Query should have been rewritten");
}
vectorProfiling.profile(queryProfiler);
queryProfiler.addVectorOpsCount(vectorOperations);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1674,7 +1674,7 @@ public void testByteVectorQueryBoundaries() throws IOException {

Exception e = expectThrows(
IllegalArgumentException.class,
() -> denseVectorFieldType.createKnnQuery(VectorData.fromFloats(new float[] { 128, 0, 0 }), 3, 3, null, null, null, null)
() -> denseVectorFieldType.createKnnQuery(VectorData.fromFloats(new float[] { 128, 0, 0 }), 3, 3, 3, null, null, null, null)
);
assertThat(
e.getMessage(),
Expand All @@ -1687,6 +1687,7 @@ public void testByteVectorQueryBoundaries() throws IOException {
VectorData.fromFloats(new float[] { 0.0f, 0f, -129.0f }),
3,
3,
3,
null,
null,
null,
Expand All @@ -1700,7 +1701,16 @@ public void testByteVectorQueryBoundaries() throws IOException {

e = expectThrows(
IllegalArgumentException.class,
() -> denseVectorFieldType.createKnnQuery(VectorData.fromFloats(new float[] { 0.0f, 0.5f, 0.0f }), 3, 3, null, null, null, null)
() -> denseVectorFieldType.createKnnQuery(
VectorData.fromFloats(new float[] { 0.0f, 0.5f, 0.0f }),
3,
3,
3,
null,
null,
null,
null
)
);
assertThat(
e.getMessage(),
Expand All @@ -1709,7 +1719,16 @@ public void testByteVectorQueryBoundaries() throws IOException {

e = expectThrows(
IllegalArgumentException.class,
() -> denseVectorFieldType.createKnnQuery(VectorData.fromFloats(new float[] { 0, 0.0f, -0.25f }), 3, 3, null, null, null, null)
() -> denseVectorFieldType.createKnnQuery(
VectorData.fromFloats(new float[] { 0, 0.0f, -0.25f }),
3,
3,
3,
null,
null,
null,
null
)
);
assertThat(
e.getMessage(),
Expand All @@ -1722,6 +1741,7 @@ public void testByteVectorQueryBoundaries() throws IOException {
VectorData.fromFloats(new float[] { Float.NaN, 0f, 0.0f }),
3,
3,
3,
null,
null,
null,
Expand All @@ -1736,6 +1756,7 @@ public void testByteVectorQueryBoundaries() throws IOException {
VectorData.fromFloats(new float[] { Float.POSITIVE_INFINITY, 0f, 0.0f }),
3,
3,
3,
null,
null,
null,
Expand All @@ -1753,6 +1774,7 @@ public void testByteVectorQueryBoundaries() throws IOException {
VectorData.fromFloats(new float[] { 0, Float.NEGATIVE_INFINITY, 0.0f }),
3,
3,
3,
null,
null,
null,
Expand Down Expand Up @@ -1787,6 +1809,7 @@ public void testFloatVectorQueryBoundaries() throws IOException {
VectorData.fromFloats(new float[] { Float.NaN, 0f, 0.0f }),
3,
3,
3,
null,
null,
null,
Expand All @@ -1801,6 +1824,7 @@ public void testFloatVectorQueryBoundaries() throws IOException {
VectorData.fromFloats(new float[] { Float.POSITIVE_INFINITY, 0f, 0.0f }),
3,
3,
3,
null,
null,
null,
Expand All @@ -1818,6 +1842,7 @@ public void testFloatVectorQueryBoundaries() throws IOException {
VectorData.fromFloats(new float[] { 0, Float.NEGATIVE_INFINITY, 0.0f }),
3,
3,
3,
null,
null,
null,
Expand Down
Loading

0 comments on commit e03f240

Please sign in to comment.