Skip to content

Commit

Permalink
Adds ef_search support for Lucene kNN queries
Browse files Browse the repository at this point in the history
Signed-off-by: Tejas Shah <[email protected]>
  • Loading branch information
shatejas committed Jun 13, 2024
1 parent 3ed8dfa commit 8a33cd0
Show file tree
Hide file tree
Showing 7 changed files with 133 additions and 19 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import java.util.Locale;
import java.util.Map;

import static org.opensearch.knn.common.KNNConstants.METHOD_PARAMETER_EF_SEARCH;
import static org.opensearch.knn.common.KNNConstants.VECTOR_DATA_TYPE_FIELD;
import static org.opensearch.knn.index.VectorDataType.SUPPORTED_VECTOR_DATA_TYPES;

Expand Down Expand Up @@ -101,12 +102,17 @@ public static Query create(CreateQueryRequest createQueryRequest) {
.build();
}

Integer requestEfSearch = null;
if (methodParameters != null && methodParameters.containsKey(METHOD_PARAMETER_EF_SEARCH)) {
requestEfSearch = (Integer) methodParameters.get(METHOD_PARAMETER_EF_SEARCH);
}
int luceneK = requestEfSearch == null ? k : Math.max(k, requestEfSearch);
log.debug(String.format("Creating Lucene k-NN query for index: %s \"\", field: %s \"\", k: %d", indexName, fieldName, k));
switch (vectorDataType) {
case BYTE:
return getKnnByteVectorQuery(fieldName, byteVector, k, filterQuery, parentFilter);
return getKnnByteVectorQuery(fieldName, byteVector, luceneK, filterQuery, parentFilter);
case FLOAT:
return getKnnFloatVectorQuery(fieldName, vector, k, filterQuery, parentFilter);
return getKnnFloatVectorQuery(fieldName, vector, luceneK, filterQuery, parentFilter);
default:
throw new IllegalArgumentException(
String.format(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,6 @@
import java.util.Map;
import java.util.function.Function;

import static org.opensearch.knn.index.IndexUtil.isClusterOnOrAfterMinRequiredVersion;
import static org.opensearch.knn.index.query.KNNQueryBuilder.METHOD_PARAMS_FIELD;
import static org.opensearch.knn.index.query.KNNQueryBuilder.NAME;

Expand Down
2 changes: 1 addition & 1 deletion src/main/java/org/opensearch/knn/index/util/Lucene.java
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ public class Lucene extends JVMLibrary {
* @param distanceTransform Map of space type to distance transformation function
*/
Lucene(Map<String, KNNMethod> methods, String version, Map<SpaceType, Function<Float, Float>> distanceTransform) {
super(methods, Map.of(METHOD_HNSW, EngineSpecificMethodContext.EMPTY), version);
super(methods, Map.of(METHOD_HNSW, new DefaultHnswContext()), version);
this.distanceTransform = distanceTransform;
}

Expand Down
28 changes: 27 additions & 1 deletion src/test/java/org/opensearch/knn/index/LuceneEngineIT.java
Original file line number Diff line number Diff line change
Expand Up @@ -499,13 +499,18 @@ private void baseQueryTest(SpaceType spaceType) throws Exception {
}

validateQueries(spaceType, FIELD_NAME);
validateQueries(spaceType, FIELD_NAME, Map.of("ef_search", 100));
}

/**
 * Convenience overload that validates queries without any query-time method parameters.
 * Delegates to the three-argument variant, passing {@code null} for the method parameters.
 *
 * @param spaceType space type forwarded unchanged to the three-argument overload
 * @param fieldName field name forwarded unchanged to the three-argument overload
 * @throws Exception propagated from the three-argument overload
 */
private void validateQueries(SpaceType spaceType, String fieldName) throws Exception {
validateQueries(spaceType, fieldName, null);
}

private void validateQueries(SpaceType spaceType, String fieldName, Map<String, ?> methodParameters) throws Exception {

int k = LuceneEngineIT.TEST_INDEX_VECTORS.length;
for (float[] queryVector : TEST_QUERY_VECTORS) {
Response response = searchKNNIndex(INDEX_NAME, new KNNQueryBuilder(fieldName, queryVector, k), k);
Response response = searchKNNIndex(INDEX_NAME, buildLuceneKSearchQuery(k, queryVector, methodParameters), k);
String responseBody = EntityUtils.toString(response.getEntity());
List<KNNResult> knnResults = parseSearchResponse(responseBody, fieldName);
assertEquals(k, knnResults.size());
Expand All @@ -520,6 +525,27 @@ private void validateQueries(SpaceType spaceType, String fieldName) throws Excep
}
}

/**
 * Assembles the JSON body of a k-NN search request of the form
 * {@code {"query": {"knn": {<FIELD_NAME>: {"vector": ..., "k": ..., "method_parameters": {...}}}}}}.
 * The {@code method_parameters} object is emitted only when {@code methodParams} is non-null.
 *
 * @param k            value written to the {@code k} field
 * @param vector       value written to the {@code vector} field
 * @param methodParams entries copied one by one into {@code method_parameters}; may be {@code null}
 * @return the builder with every object it opened closed again
 */
@SneakyThrows
private XContentBuilder buildLuceneKSearchQuery(int k, float[] vector, Map<String, ?> methodParams) {
    XContentBuilder requestBody = XContentFactory.jsonBuilder();

    // Open root -> "query" -> "knn" -> <field>.
    requestBody.startObject();
    requestBody.startObject("query");
    requestBody.startObject("knn");
    requestBody.startObject(FIELD_NAME);

    requestBody.field("vector", vector);
    requestBody.field("k", k);

    if (methodParams != null) {
        requestBody.startObject("method_parameters");
        for (Map.Entry<String, ?> param : methodParams.entrySet()) {
            requestBody.field(param.getKey(), param.getValue());
        }
        requestBody.endObject();
    }

    // Close <field>, "knn", "query" and the root object, in that order.
    requestBody.endObject();
    requestBody.endObject();
    requestBody.endObject();
    requestBody.endObject();
    return requestBody;
}

private List<float[]> queryResults(final float[] searchVector, final int k) throws Exception {
final String responseBody = EntityUtils.toString(
searchKNNIndex(INDEX_NAME, new KNNQueryBuilder(FIELD_NAME, searchVector, k), k).getEntity()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -825,14 +825,15 @@ public void testDoToQuery_WhenknnQueryWithFilterAndFaissEngine_thenSuccess() {
assertEquals(HNSW_METHOD_PARAMS, ((KNNQuery) query).getMethodParameters());
}

/** This test should be uncommented once we have nprobes. Considering the engine instance is static, it's not possible to test this right now
public void testDoToQuery_ThrowsIllegalArgumentExceptionForUnknownMethodParameter() {
QueryShardContext mockQueryShardContext = mock(QueryShardContext.class);
KNNVectorFieldMapper.KNNVectorFieldType mockKNNVectorField = mock(KNNVectorFieldMapper.KNNVectorFieldType.class);
when(mockQueryShardContext.fieldMapper(anyString())).thenReturn(mockKNNVectorField);
when(mockKNNVectorField.getDimension()).thenReturn(4);
when(mockKNNVectorField.getKnnMethodContext()).thenReturn(
new KNNMethodContext(KNNEngine.NMSLIB, SpaceType.COSINESIMIL, new MethodComponentContext("hnsw", Map.of()))
new KNNMethodContext(KNNEngine.LUCENE, SpaceType.COSINESIMIL, new MethodComponentContext("hnsw", Map.of()))
);
float[] queryVector = { 1.0f, 2.0f, 3.0f, 4.0f };
Expand All @@ -844,7 +845,7 @@ public void testDoToQuery_ThrowsIllegalArgumentExceptionForUnknownMethodParamete
.build();
expectThrows(IllegalArgumentException.class, () -> knnQueryBuilder.doToQuery(mockQueryShardContext));
}
}**/

public void testDoToQuery_whenknnQueryWithFilterAndNmsLibEngine_thenException() {
float[] queryVector = { 1.0f, 2.0f, 3.0f, 4.0f };
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
package org.opensearch.knn.index.query;

import org.apache.lucene.index.Term;
import org.apache.lucene.search.KnnByteVectorQuery;
import org.apache.lucene.search.KnnFloatVectorQuery;
import org.apache.lucene.search.Query;
import org.apache.lucene.search.TermQuery;
Expand Down Expand Up @@ -85,6 +86,98 @@ public void testCreateLuceneDefaultQuery() {
}
}

/**
 * Checks which {@code k} ends up in the {@link KnnFloatVectorQuery} produced for the Lucene engine
 * in three scenarios: ef_search larger than k, ef_search smaller than k, and no method parameters.
 */
public void testLuceneFloatVectorQuery() {
    // Scenario 1: efsearch > k. Uses the shared methodParameters fixture; the expected query carries 100.
    final Query queryWithLargeEfSearch = KNNQueryFactory.create(
        BaseQueryFactory.CreateQueryRequest.builder()
            .knnEngine(KNNEngine.LUCENE)
            .vector(testQueryVector)
            .k(testK)
            .indexName(testIndexName)
            .fieldName(testFieldName)
            .methodParameters(methodParameters)
            .vectorDataType(VectorDataType.FLOAT)
            .build()
    );
    assertEquals(new KnnFloatVectorQuery(testFieldName, testQueryVector, 100, null), queryWithLargeEfSearch);

    // Scenario 2: efsearch < k. The expected query keeps testK.
    final Query queryWithSmallEfSearch = KNNQueryFactory.create(
        BaseQueryFactory.CreateQueryRequest.builder()
            .knnEngine(KNNEngine.LUCENE)
            .vector(testQueryVector)
            .k(testK)
            .indexName(testIndexName)
            .fieldName(testFieldName)
            .methodParameters(Map.of("ef_search", 1))
            .vectorDataType(VectorDataType.FLOAT)
            .build()
    );
    assertEquals(new KnnFloatVectorQuery(testFieldName, testQueryVector, testK, null), queryWithSmallEfSearch);

    // Scenario 3: no method parameters set on the request. The expected query keeps testK.
    final Query queryWithoutMethodParameters = KNNQueryFactory.create(
        BaseQueryFactory.CreateQueryRequest.builder()
            .knnEngine(KNNEngine.LUCENE)
            .vector(testQueryVector)
            .k(testK)
            .indexName(testIndexName)
            .fieldName(testFieldName)
            .vectorDataType(VectorDataType.FLOAT)
            .build()
    );
    assertEquals(new KnnFloatVectorQuery(testFieldName, testQueryVector, testK, null), queryWithoutMethodParameters);
}

/**
 * Checks which {@code k} ends up in the {@link KnnByteVectorQuery} produced for the Lucene engine
 * in three scenarios: ef_search larger than k, ef_search smaller than k, and no method parameters.
 */
public void testLuceneByteVectorQuery() {
    // Scenario 1: efsearch > k. Uses the shared methodParameters fixture; the expected query carries 100.
    final Query queryWithLargeEfSearch = KNNQueryFactory.create(
        BaseQueryFactory.CreateQueryRequest.builder()
            .knnEngine(KNNEngine.LUCENE)
            .byteVector(testByteQueryVector)
            .k(testK)
            .indexName(testIndexName)
            .fieldName(testFieldName)
            .methodParameters(methodParameters)
            .vectorDataType(VectorDataType.BYTE)
            .build()
    );
    assertEquals(new KnnByteVectorQuery(testFieldName, testByteQueryVector, 100, null), queryWithLargeEfSearch);

    // Scenario 2: efsearch < k. The expected query keeps testK.
    final Query queryWithSmallEfSearch = KNNQueryFactory.create(
        BaseQueryFactory.CreateQueryRequest.builder()
            .knnEngine(KNNEngine.LUCENE)
            .byteVector(testByteQueryVector)
            .k(testK)
            .indexName(testIndexName)
            .fieldName(testFieldName)
            .methodParameters(Map.of("ef_search", 1))
            .vectorDataType(VectorDataType.BYTE)
            .build()
    );
    assertEquals(new KnnByteVectorQuery(testFieldName, testByteQueryVector, testK, null), queryWithSmallEfSearch);

    // Scenario 3: no method parameters set on the request. The expected query keeps testK.
    final Query queryWithoutMethodParameters = KNNQueryFactory.create(
        BaseQueryFactory.CreateQueryRequest.builder()
            .knnEngine(KNNEngine.LUCENE)
            .byteVector(testByteQueryVector)
            .k(testK)
            .indexName(testIndexName)
            .fieldName(testFieldName)
            .vectorDataType(VectorDataType.BYTE)
            .build()
    );
    assertEquals(new KnnByteVectorQuery(testFieldName, testByteQueryVector, testK, null), queryWithoutMethodParameters);
}

public void testCreateLuceneQueryWithFilter() {
List<KNNEngine> luceneDefaultQueryEngineList = Arrays.stream(KNNEngine.values())
.filter(knnEngine -> !KNNEngine.getEnginesThatCreateCustomSegmentFiles().contains(knnEngine))
Expand Down
13 changes: 1 addition & 12 deletions src/testFixtures/java/org/opensearch/knn/KNNRestTestCase.java
Original file line number Diff line number Diff line change
Expand Up @@ -193,18 +193,7 @@ protected Response searchKNNIndex(String index, KNNQueryBuilder knnQueryBuilder,
XContentBuilder builder = XContentFactory.jsonBuilder().startObject().startObject("query");
knnQueryBuilder.doXContent(builder, ToXContent.EMPTY_PARAMS);
builder.endObject().endObject();

Request request = new Request("POST", "/" + index + "/_search");

request.addParameter("size", Integer.toString(resultSize));
request.addParameter("explain", Boolean.toString(true));
request.addParameter("search_type", "query_then_fetch");
request.setJsonEntity(builder.toString());

Response response = client().performRequest(request);
assertEquals(request.getEndpoint() + ": failed", RestStatus.OK, RestStatus.fromCode(response.getStatusLine().getStatusCode()));

return response;
return searchKNNIndex(index, builder, resultSize);
}

/**
Expand Down

0 comments on commit 8a33cd0

Please sign in to comment.