Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adds ef_search support for Lucene kNN queries #1748

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import java.util.Locale;
import java.util.Map;

import static org.opensearch.knn.common.KNNConstants.METHOD_PARAMETER_EF_SEARCH;
import static org.opensearch.knn.common.KNNConstants.VECTOR_DATA_TYPE_FIELD;
import static org.opensearch.knn.index.VectorDataType.SUPPORTED_VECTOR_DATA_TYPES;

Expand Down Expand Up @@ -101,12 +102,17 @@ public static Query create(CreateQueryRequest createQueryRequest) {
.build();
}

Integer requestEfSearch = null;
if (methodParameters != null && methodParameters.containsKey(METHOD_PARAMETER_EF_SEARCH)) {
requestEfSearch = (Integer) methodParameters.get(METHOD_PARAMETER_EF_SEARCH);
}
int luceneK = requestEfSearch == null ? k : Math.max(k, requestEfSearch);
log.debug(String.format("Creating Lucene k-NN query for index: %s \"\", field: %s \"\", k: %d", indexName, fieldName, k));
switch (vectorDataType) {
case BYTE:
return getKnnByteVectorQuery(fieldName, byteVector, k, filterQuery, parentFilter);
return getKnnByteVectorQuery(fieldName, byteVector, luceneK, filterQuery, parentFilter);
case FLOAT:
return getKnnFloatVectorQuery(fieldName, vector, k, filterQuery, parentFilter);
return getKnnFloatVectorQuery(fieldName, vector, luceneK, filterQuery, parentFilter);
default:
throw new IllegalArgumentException(
String.format(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,6 @@
import java.util.Map;
import java.util.function.Function;

import static org.opensearch.knn.index.IndexUtil.isClusterOnOrAfterMinRequiredVersion;
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why are we removing this ?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It was a leftover from this PR https://github.com/opensearch-project/k-NN/pull/1742/files. The import is not needed. Sorry for the confusion

import static org.opensearch.knn.index.query.KNNQueryBuilder.METHOD_PARAMS_FIELD;
import static org.opensearch.knn.index.query.KNNQueryBuilder.NAME;

Expand Down
2 changes: 1 addition & 1 deletion src/main/java/org/opensearch/knn/index/util/Lucene.java
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ public class Lucene extends JVMLibrary {
* @param distanceTransform Map of space type to distance transformation function
*/
Lucene(Map<String, KNNMethod> methods, String version, Map<SpaceType, Function<Float, Float>> distanceTransform) {
super(methods, Map.of(METHOD_HNSW, EngineSpecificMethodContext.EMPTY), version);
super(methods, Map.of(METHOD_HNSW, new DefaultHnswContext()), version);
this.distanceTransform = distanceTransform;
}

Expand Down
28 changes: 27 additions & 1 deletion src/test/java/org/opensearch/knn/index/LuceneEngineIT.java
Original file line number Diff line number Diff line change
Expand Up @@ -499,13 +499,18 @@ private void baseQueryTest(SpaceType spaceType) throws Exception {
}

validateQueries(spaceType, FIELD_NAME);
validateQueries(spaceType, FIELD_NAME, Map.of("ef_search", 100));
}

private void validateQueries(SpaceType spaceType, String fieldName) throws Exception {
validateQueries(spaceType, fieldName, null);
}

private void validateQueries(SpaceType spaceType, String fieldName, Map<String, ?> methodParameters) throws Exception {

int k = LuceneEngineIT.TEST_INDEX_VECTORS.length;
for (float[] queryVector : TEST_QUERY_VECTORS) {
Response response = searchKNNIndex(INDEX_NAME, new KNNQueryBuilder(fieldName, queryVector, k), k);
Response response = searchKNNIndex(INDEX_NAME, buildLuceneKSearchQuery(fieldName, k, queryVector, methodParameters), k);
String responseBody = EntityUtils.toString(response.getEntity());
List<KNNResult> knnResults = parseSearchResponse(responseBody, fieldName);
assertEquals(k, knnResults.size());
Expand All @@ -520,6 +525,27 @@ private void validateQueries(SpaceType spaceType, String fieldName) throws Excep
}
}

@SneakyThrows
private XContentBuilder buildLuceneKSearchQuery(String fieldName, int k, float[] vector, Map<String, ?> methodParams) {
XContentBuilder builder = XContentFactory.jsonBuilder()
.startObject()
.startObject("query")
.startObject("knn")
.startObject(fieldName)
.field("vector", vector)
.field("k", k);
if (methodParams != null) {
builder.startObject("method_parameters");
for (Map.Entry<String, ?> entry : methodParams.entrySet()) {
builder.field(entry.getKey(), entry.getValue());
}
builder.endObject();
}

builder.endObject().endObject().endObject().endObject();
return builder;
}

private List<float[]> queryResults(final float[] searchVector, final int k) throws Exception {
final String responseBody = EntityUtils.toString(
searchKNNIndex(INDEX_NAME, new KNNQueryBuilder(FIELD_NAME, searchVector, k), k).getEntity()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -825,14 +825,15 @@ public void testDoToQuery_WhenknnQueryWithFilterAndFaissEngine_thenSuccess() {
assertEquals(HNSW_METHOD_PARAMS, ((KNNQuery) query).getMethodParameters());
}

/** This test should be uncommented once we have nprobs. Considering engine instance is static its not possible to test this right now
public void testDoToQuery_ThrowsIllegalArgumentExceptionForUnknownMethodParameter() {

QueryShardContext mockQueryShardContext = mock(QueryShardContext.class);
KNNVectorFieldMapper.KNNVectorFieldType mockKNNVectorField = mock(KNNVectorFieldMapper.KNNVectorFieldType.class);
when(mockQueryShardContext.fieldMapper(anyString())).thenReturn(mockKNNVectorField);
when(mockKNNVectorField.getDimension()).thenReturn(4);
when(mockKNNVectorField.getKnnMethodContext()).thenReturn(
new KNNMethodContext(KNNEngine.NMSLIB, SpaceType.COSINESIMIL, new MethodComponentContext("hnsw", Map.of()))
new KNNMethodContext(KNNEngine.LUCENE, SpaceType.COSINESIMIL, new MethodComponentContext("hnsw", Map.of()))
);

float[] queryVector = { 1.0f, 2.0f, 3.0f, 4.0f };
Expand All @@ -844,7 +845,7 @@ public void testDoToQuery_ThrowsIllegalArgumentExceptionForUnknownMethodParamete
.build();

expectThrows(IllegalArgumentException.class, () -> knnQueryBuilder.doToQuery(mockQueryShardContext));
}
}**/

public void testDoToQuery_whenknnQueryWithFilterAndNmsLibEngine_thenException() {
float[] queryVector = { 1.0f, 2.0f, 3.0f, 4.0f };
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
package org.opensearch.knn.index.query;

import org.apache.lucene.index.Term;
import org.apache.lucene.search.KnnByteVectorQuery;
import org.apache.lucene.search.KnnFloatVectorQuery;
import org.apache.lucene.search.Query;
import org.apache.lucene.search.TermQuery;
Expand Down Expand Up @@ -85,6 +86,98 @@ public void testCreateLuceneDefaultQuery() {
}
}

public void testLuceneFloatVectorQuery() {
Query actualQuery1 = KNNQueryFactory.create(
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you add a function to create the query as it is duplicated multiple times

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It has a different request everytime. The create request itself is using a builder so wrapping it up with a function is moot

BaseQueryFactory.CreateQueryRequest.builder()
.knnEngine(KNNEngine.LUCENE)
.vector(testQueryVector)
.k(testK)
.indexName(testIndexName)
.fieldName(testFieldName)
.methodParameters(methodParameters)
.vectorDataType(VectorDataType.FLOAT)
.build()
);

// efsearch > k
Query expectedQuery1 = new KnnFloatVectorQuery(testFieldName, testQueryVector, 100, null);
assertEquals(expectedQuery1, actualQuery1);

// efsearch < k
actualQuery1 = KNNQueryFactory.create(
BaseQueryFactory.CreateQueryRequest.builder()
.knnEngine(KNNEngine.LUCENE)
.vector(testQueryVector)
.k(testK)
.indexName(testIndexName)
.fieldName(testFieldName)
.methodParameters(Map.of("ef_search", 1))
.vectorDataType(VectorDataType.FLOAT)
.build()
);
expectedQuery1 = new KnnFloatVectorQuery(testFieldName, testQueryVector, testK, null);
assertEquals(expectedQuery1, actualQuery1);

actualQuery1 = KNNQueryFactory.create(
BaseQueryFactory.CreateQueryRequest.builder()
.knnEngine(KNNEngine.LUCENE)
.vector(testQueryVector)
.k(testK)
.indexName(testIndexName)
.fieldName(testFieldName)
.vectorDataType(VectorDataType.FLOAT)
.build()
);
expectedQuery1 = new KnnFloatVectorQuery(testFieldName, testQueryVector, testK, null);
assertEquals(expectedQuery1, actualQuery1);
}

public void testLuceneByteVectorQuery() {
Query actualQuery1 = KNNQueryFactory.create(
BaseQueryFactory.CreateQueryRequest.builder()
.knnEngine(KNNEngine.LUCENE)
.byteVector(testByteQueryVector)
.k(testK)
.indexName(testIndexName)
.fieldName(testFieldName)
.methodParameters(methodParameters)
.vectorDataType(VectorDataType.BYTE)
.build()
);

// efsearch > k
Query expectedQuery1 = new KnnByteVectorQuery(testFieldName, testByteQueryVector, 100, null);
assertEquals(expectedQuery1, actualQuery1);

// efsearch < k
actualQuery1 = KNNQueryFactory.create(
BaseQueryFactory.CreateQueryRequest.builder()
.knnEngine(KNNEngine.LUCENE)
.byteVector(testByteQueryVector)
.k(testK)
.indexName(testIndexName)
.fieldName(testFieldName)
.methodParameters(Map.of("ef_search", 1))
.vectorDataType(VectorDataType.BYTE)
.build()
);
expectedQuery1 = new KnnByteVectorQuery(testFieldName, testByteQueryVector, testK, null);
assertEquals(expectedQuery1, actualQuery1);

actualQuery1 = KNNQueryFactory.create(
BaseQueryFactory.CreateQueryRequest.builder()
.knnEngine(KNNEngine.LUCENE)
.byteVector(testByteQueryVector)
.k(testK)
.indexName(testIndexName)
.fieldName(testFieldName)
.vectorDataType(VectorDataType.BYTE)
.build()
);
expectedQuery1 = new KnnByteVectorQuery(testFieldName, testByteQueryVector, testK, null);
assertEquals(expectedQuery1, actualQuery1);
}

public void testCreateLuceneQueryWithFilter() {
List<KNNEngine> luceneDefaultQueryEngineList = Arrays.stream(KNNEngine.values())
.filter(knnEngine -> !KNNEngine.getEnginesThatCreateCustomSegmentFiles().contains(knnEngine))
Expand Down
13 changes: 1 addition & 12 deletions src/testFixtures/java/org/opensearch/knn/KNNRestTestCase.java
Original file line number Diff line number Diff line change
Expand Up @@ -193,18 +193,7 @@ protected Response searchKNNIndex(String index, KNNQueryBuilder knnQueryBuilder,
XContentBuilder builder = XContentFactory.jsonBuilder().startObject().startObject("query");
knnQueryBuilder.doXContent(builder, ToXContent.EMPTY_PARAMS);
builder.endObject().endObject();

Request request = new Request("POST", "/" + index + "/_search");

request.addParameter("size", Integer.toString(resultSize));
request.addParameter("explain", Boolean.toString(true));
request.addParameter("search_type", "query_then_fetch");
request.setJsonEntity(builder.toString());

Response response = client().performRequest(request);
assertEquals(request.getEndpoint() + ": failed", RestStatus.OK, RestStatus.fromCode(response.getStatusLine().getStatusCode()));

return response;
return searchKNNIndex(index, builder, resultSize);
}

/**
Expand Down
Loading