Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix Faiss byte vector efficient filter bug #2448

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
* Fixing it to retrieve space_type from index setting when both method and top level don't have the value. [#2374](https://github.com/opensearch-project/k-NN/pull/2374)
* Fixing the bug where setting rescore as false for on_disk knn_vector query is a no-op [#2399](https://github.com/opensearch-project/k-NN/pull/2399)
* Fixing the bug to prevent index.knn setting from being modified or removed on restore snapshot [#2445](https://github.com/opensearch-project/k-NN/pull/2445)
* Fix Faiss byte vector efficient filter bug [#2448](https://github.com/opensearch-project/k-NN/pull/2448)
### Infrastructure
* Updated C++ version in JNI from c++11 to c++17 [#2259](https://github.com/opensearch-project/k-NN/pull/2259)
* Upgrade bytebuddy and objenesis version to match OpenSearch core and, update github ci runner for macos [#2279](https://github.com/opensearch-project/k-NN/pull/2279)
Expand Down
21 changes: 18 additions & 3 deletions src/main/java/org/opensearch/knn/index/query/KNNWeight.java
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.concurrent.ExecutionException;
import java.util.stream.Collectors;
Expand Down Expand Up @@ -438,9 +439,23 @@ private boolean isFilteredExactSearchPreferred(final int filterIdsCount) {
* TODO we can have a different MAX_DISTANCE_COMPUTATIONS for binary index as computation cost for binary index
* is cheaper than computation cost for non binary vector
*/
return KNNConstants.MAX_DISTANCE_COMPUTATIONS >= filterIdsCount * (knnQuery.getVectorDataType() == VectorDataType.FLOAT
? knnQuery.getQueryVector().length
: knnQuery.getByteQueryVector().length);
return KNNConstants.MAX_DISTANCE_COMPUTATIONS >= filterIdsCount * getQueryVectorLength();
}

/**
 * Resolves the number of elements in the query vector held by {@code knnQuery}.
 * <p>
 * BINARY queries are measured through {@code getByteQueryVector()}, while FLOAT and BYTE
 * queries are measured through {@code getQueryVector()}.
 *
 * @return length of query vector
 * @throws IllegalArgumentException if the query's vector data type is none of FLOAT, BYTE or BINARY
 */
private int getQueryVectorLength() {
    final VectorDataType queryDataType = knnQuery.getVectorDataType();

    // Binary queries carry their payload in the byte query vector.
    if (queryDataType == VectorDataType.BINARY) {
        return knnQuery.getByteQueryVector().length;
    }

    // Float and byte queries are both measured on the regular query vector.
    final boolean usesRegularQueryVector = queryDataType == VectorDataType.FLOAT || queryDataType == VectorDataType.BYTE;
    if (usesRegularQueryVector) {
        return knnQuery.getQueryVector().length;
    }

    final String unsupportedMessage = String.format(
        Locale.ROOT,
        "[%s] datatype is not supported for k-NN query vector",
        queryDataType.getValue()
    );
    throw new IllegalArgumentException(unsupportedMessage);
}

/**
Expand Down
98 changes: 98 additions & 0 deletions src/test/java/org/opensearch/knn/index/VectorDataTypeIT.java
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

package org.opensearch.knn.index;

import com.google.common.collect.ImmutableMap;
import lombok.SneakyThrows;
import org.apache.http.util.EntityUtils;
import org.junit.After;
Expand All @@ -17,6 +18,7 @@
import org.opensearch.core.xcontent.XContentBuilder;
import org.opensearch.index.query.MatchAllQueryBuilder;
import org.opensearch.index.query.QueryBuilder;
import org.opensearch.index.query.QueryBuilders;
import org.opensearch.knn.KNNRestTestCase;
import org.opensearch.knn.KNNResult;
import org.opensearch.knn.common.KNNConstants;
Expand All @@ -25,11 +27,13 @@
import org.opensearch.script.Script;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.stream.Collectors;

import static org.opensearch.knn.common.KNNConstants.DIMENSION;
import static org.opensearch.knn.common.KNNConstants.ENCODER_SQ;
Expand Down Expand Up @@ -62,6 +66,7 @@ public class VectorDataTypeIT extends KNNRestTestCase {
private static final String KNN_VECTOR_TYPE = "knn_vector";
private static final int EF_CONSTRUCTION = 128;
private static final int M = 16;
private static final String COLOR_FIELD_NAME = "color";
private static final QueryBuilder MATCH_ALL_QUERY_BUILDER = new MatchAllQueryBuilder();

@After
Expand Down Expand Up @@ -666,6 +671,99 @@ public void testIVFByteVector_whenIndexedAndQueried_thenSucceed() {
deleteModel(modelId);
}

/**
 * Runs filtered k-NN queries against the Faiss byte-vector index built by
 * {@link #setupKNNFaissByteIndexForFilterQuery()} and checks three scenarios:
 * <ol>
 *   <li>k is larger than the number of documents matching the filter;</li>
 *   <li>k is smaller than the number of documents matching the filter;</li>
 *   <li>the filter matches no document at all.</li>
 * </ol>
 */
@SneakyThrows
public void testQueryWithFilterFaissByteVector_withDifferentCombination_thenSuccess() {
    setupKNNFaissByteIndexForFilterQuery();
    final Byte[] searchVector = { 6, 6, 4 };

    // K > filteredResults: both "red" documents must come back even though k is 5.
    final int kGreaterThanFilterResult = 5;
    final List<String> expectedDocIds = Arrays.asList("1", "3");
    final List<String> docIdsKGreaterThanFilterResult = searchFilteredByteVectorDocIds(
        searchVector,
        kGreaterThanFilterResult,
        "red",
        kGreaterThanFilterResult
    );

    assertEquals(expectedDocIds.size(), docIdsKGreaterThanFilterResult.size());
    assertTrue(docIdsKGreaterThanFilterResult.containsAll(expectedDocIds));

    // K Limits Filter results: only the single "red" document expected for k = 1 is returned.
    final int kLimitsFilterResult = 1;
    final List<String> expectedDocIdsKLimitsFilterResult = List.of("1");
    final List<String> docIdsKLimitsFilterResult = searchFilteredByteVectorDocIds(
        searchVector,
        kLimitsFilterResult,
        "red",
        kLimitsFilterResult
    );

    assertEquals(expectedDocIdsKLimitsFilterResult.size(), docIdsKLimitsFilterResult.size());
    assertTrue(docIdsKLimitsFilterResult.containsAll(expectedDocIdsKLimitsFilterResult));

    // Empty filter docIds: the query uses this scenario's own k rather than the k left over
    // from the previous scenario, so the query k and the requested size agree.
    final int k = 10;
    final List<String> docIdsForEmptyFilter = searchFilteredByteVectorDocIds(searchVector, k, "color_not_present", k);

    assertEquals(0, docIdsForEmptyFilter.size());
}

/**
 * Executes a k-NN query on {@code FIELD_NAME} of {@code INDEX_NAME}, restricted by a term filter on
 * {@code COLOR_FIELD_NAME}, and returns the ids of the documents in the response.
 *
 * @param queryVector query vector; converted with {@code convertByteToFloatArray} before being sent
 * @param queryK      the k passed to the {@link KNNQueryBuilder}
 * @param color       value the color field must match
 * @param resultSize  result size passed to {@code searchKNNIndex}
 * @return document ids parsed from the search response, in response order
 * @throws Exception if the search request or the response parsing fails
 */
private List<String> searchFilteredByteVectorDocIds(
    final Byte[] queryVector,
    final int queryK,
    final String color,
    final int resultSize
) throws Exception {
    final Response response = searchKNNIndex(
        INDEX_NAME,
        new KNNQueryBuilder(FIELD_NAME, convertByteToFloatArray(queryVector), queryK, QueryBuilders.termQuery(COLOR_FIELD_NAME, color)),
        resultSize
    );
    final String responseBody = EntityUtils.toString(response.getEntity());
    final List<KNNResult> knnResults = parseSearchResponse(responseBody, FIELD_NAME);
    return knnResults.stream().map(KNNResult::getDocId).collect(Collectors.toList());
}

/**
 * Creates {@code INDEX_NAME} with a 3-dimensional byte {@code knn_vector} field (HNSW, L2, Faiss engine)
 * and indexes three documents tagged with a color attribute: documents "1" and "3" are "red",
 * document "2" is "green". The index is refreshed before returning.
 *
 * @throws Exception if building the mapping, creating the index or indexing a document fails
 */
protected void setupKNNFaissByteIndexForFilterQuery() throws Exception {
    // Mapping: properties -> vector field -> method definition.
    final XContentBuilder mappingBuilder = XContentFactory.jsonBuilder()
        .startObject()
        .startObject(PROPERTIES_FIELD)
        .startObject(FIELD_NAME)
        .field("type", "knn_vector")
        .field("dimension", 3)
        .field(VECTOR_DATA_TYPE_FIELD, VectorDataType.BYTE.getValue())
        .startObject(KNN_METHOD)
        .field(NAME, METHOD_HNSW)
        .field(METHOD_PARAMETER_SPACE_TYPE, SpaceType.L2)
        .field(KNN_ENGINE, KNNEngine.FAISS.getName())
        .endObject()
        .endObject()
        .endObject()
        .endObject();
    final String indexMapping = mappingBuilder.toString();

    createKnnIndex(INDEX_NAME, getKNNDefaultIndexSettings(), indexMapping);

    // Two "red" documents and one "green" document, so a color filter yields a strict subset.
    final Map<String, String> redAttributes = ImmutableMap.of(COLOR_FIELD_NAME, "red");
    final Map<String, String> greenAttributes = ImmutableMap.of(COLOR_FIELD_NAME, "green");
    addKnnDocWithAttributes(INDEX_NAME, "1", FIELD_NAME, new Byte[] { 6, 7, 3 }, redAttributes);
    addKnnDocWithAttributes(INDEX_NAME, "2", FIELD_NAME, new Byte[] { 3, 2, 4 }, greenAttributes);
    addKnnDocWithAttributes(INDEX_NAME, "3", FIELD_NAME, new Byte[] { 4, 5, 7 }, redAttributes);

    refreshIndex(INDEX_NAME);
}

@SneakyThrows
private void ingestL2ByteTestData() {
Byte[] b1 = { 6, 6 };
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1910,11 +1910,11 @@ protected void addKnnDocWithAttributes(String docId, float[] vector, Map<String,
assertEquals(request.getEndpoint() + ": failed", RestStatus.OK, RestStatus.fromCode(response.getStatusLine().getStatusCode()));
}

protected void addKnnDocWithAttributes(
protected <T> void addKnnDocWithAttributes(
String indexName,
String docId,
String vectorFieldName,
float[] vector,
T vector,
Map<String, String> fieldValues
) throws IOException {
Request request = new Request("POST", "/" + indexName + "/_doc/" + docId + "?refresh=true");
Expand Down
Loading