Skip to content

Commit

Permalink
Fix Faiss byte vector efficient filter bug
Browse files Browse the repository at this point in the history
Signed-off-by: Naveen Tatikonda <[email protected]>
  • Loading branch information
naveentatikonda committed Jan 27, 2025
1 parent 67cc345 commit cf68ce8
Show file tree
Hide file tree
Showing 4 changed files with 119 additions and 5 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
* Fixing it to retrieve space_type from index setting when both method and top level don't have the value. [#2374](https://github.com/opensearch-project/k-NN/pull/2374)
* Fixing the bug where setting rescore as false for on_disk knn_vector query is a no-op (#2399)[https://github.com/opensearch-project/k-NN/pull/2399]
* Fixing the bug to prevent index.knn setting from being modified or removed on restore snapshot (#2445)[https://github.com/opensearch-project/k-NN/pull/2445]
* Fix Faiss byte vector efficient filter bug (#2448)[https://github.com/opensearch-project/k-NN/pull/2448]
### Infrastructure
* Updated C++ version in JNI from c++11 to c++17 [#2259](https://github.com/opensearch-project/k-NN/pull/2259)
* Upgrade bytebuddy and objenesis version to match OpenSearch core and, update github ci runner for macos [#2279](https://github.com/opensearch-project/k-NN/pull/2279)
Expand Down
21 changes: 18 additions & 3 deletions src/main/java/org/opensearch/knn/index/query/KNNWeight.java
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.concurrent.ExecutionException;
import java.util.stream.Collectors;
Expand Down Expand Up @@ -438,9 +439,23 @@ private boolean isFilteredExactSearchPreferred(final int filterIdsCount) {
* TODO we can have a different MAX_DISTANCE_COMPUTATIONS for binary index as computation cost for binary index
* is cheaper than computation cost for non binary vector
*/
return KNNConstants.MAX_DISTANCE_COMPUTATIONS >= filterIdsCount * (knnQuery.getVectorDataType() == VectorDataType.FLOAT
? knnQuery.getQueryVector().length
: knnQuery.getByteQueryVector().length);
return KNNConstants.MAX_DISTANCE_COMPUTATIONS >= filterIdsCount * getQueryVectorLength();
}

/**
* Returns the length of query vector based on the query vector data type
* @return length of query vector
*/
private int getQueryVectorLength() {
if (knnQuery.getVectorDataType() == VectorDataType.FLOAT || knnQuery.getVectorDataType() == VectorDataType.BYTE) {
return knnQuery.getQueryVector().length;
}
if (knnQuery.getVectorDataType() == VectorDataType.BINARY) {
return knnQuery.getByteQueryVector().length;
}
throw new IllegalArgumentException(
String.format(Locale.ROOT, "[%s] datatype is not supported for k-NN query vector", knnQuery.getVectorDataType().getValue())
);
}

/**
Expand Down
98 changes: 98 additions & 0 deletions src/test/java/org/opensearch/knn/index/VectorDataTypeIT.java
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

package org.opensearch.knn.index;

import com.google.common.collect.ImmutableMap;
import lombok.SneakyThrows;
import org.apache.http.util.EntityUtils;
import org.junit.After;
Expand All @@ -17,6 +18,7 @@
import org.opensearch.core.xcontent.XContentBuilder;
import org.opensearch.index.query.MatchAllQueryBuilder;
import org.opensearch.index.query.QueryBuilder;
import org.opensearch.index.query.QueryBuilders;
import org.opensearch.knn.KNNRestTestCase;
import org.opensearch.knn.KNNResult;
import org.opensearch.knn.common.KNNConstants;
Expand All @@ -25,11 +27,13 @@
import org.opensearch.script.Script;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.stream.Collectors;

import static org.opensearch.knn.common.KNNConstants.DIMENSION;
import static org.opensearch.knn.common.KNNConstants.ENCODER_SQ;
Expand Down Expand Up @@ -62,6 +66,7 @@ public class VectorDataTypeIT extends KNNRestTestCase {
private static final String KNN_VECTOR_TYPE = "knn_vector";
private static final int EF_CONSTRUCTION = 128;
private static final int M = 16;
private static final String COLOR_FIELD_NAME = "color";
private static final QueryBuilder MATCH_ALL_QUERY_BUILDER = new MatchAllQueryBuilder();

@After
Expand Down Expand Up @@ -666,6 +671,99 @@ public void testIVFByteVector_whenIndexedAndQueried_thenSucceed() {
deleteModel(modelId);
}

@SneakyThrows
public void testQueryWithFilterFaissByteVector_withDifferentCombination_thenSuccess() {
setupKNNFaissByteIndexForFilterQuery();
final Byte[] searchVector = { 6, 6, 4 };
// K > filteredResults
int kGreaterThanFilterResult = 5;
List<String> expectedDocIds = Arrays.asList("1", "3");
final Response response = searchKNNIndex(
INDEX_NAME,
new KNNQueryBuilder(
FIELD_NAME,
convertByteToFloatArray(searchVector),
kGreaterThanFilterResult,
QueryBuilders.termQuery(COLOR_FIELD_NAME, "red")
),
kGreaterThanFilterResult
);
final String responseBody = EntityUtils.toString(response.getEntity());
final List<KNNResult> knnResults = parseSearchResponse(responseBody, FIELD_NAME);

assertEquals(expectedDocIds.size(), knnResults.size());
assertTrue(knnResults.stream().map(KNNResult::getDocId).collect(Collectors.toList()).containsAll(expectedDocIds));

// K Limits Filter results
int kLimitsFilterResult = 1;
List<String> expectedDocIdsKLimitsFilterResult = List.of("1");
final Response responseKLimitsFilterResult = searchKNNIndex(
INDEX_NAME,
new KNNQueryBuilder(
FIELD_NAME,
convertByteToFloatArray(searchVector),
kLimitsFilterResult,
QueryBuilders.termQuery(COLOR_FIELD_NAME, "red")
),
kLimitsFilterResult
);
final String responseBodyKLimitsFilterResult = EntityUtils.toString(responseKLimitsFilterResult.getEntity());
final List<KNNResult> knnResultsKLimitsFilterResult = parseSearchResponse(responseBodyKLimitsFilterResult, FIELD_NAME);

assertEquals(expectedDocIdsKLimitsFilterResult.size(), knnResultsKLimitsFilterResult.size());
assertTrue(
knnResultsKLimitsFilterResult.stream()
.map(KNNResult::getDocId)
.collect(Collectors.toList())
.containsAll(expectedDocIdsKLimitsFilterResult)
);

// Empty filter docIds
int k = 10;
final Response emptyFilterResponse = searchKNNIndex(
INDEX_NAME,
new KNNQueryBuilder(
FIELD_NAME,
convertByteToFloatArray(searchVector),
kLimitsFilterResult,
QueryBuilders.termQuery(COLOR_FIELD_NAME, "color_not_present")
),
k
);
final String responseBodyForEmptyDocIds = EntityUtils.toString(emptyFilterResponse.getEntity());
final List<KNNResult> emptyKNNFilteredResultsFromResponse = parseSearchResponse(responseBodyForEmptyDocIds, FIELD_NAME);

assertEquals(0, emptyKNNFilteredResultsFromResponse.size());
}

protected void setupKNNFaissByteIndexForFilterQuery() throws Exception {
// Create Mappings
XContentBuilder builder = XContentFactory.jsonBuilder()
.startObject()
.startObject(PROPERTIES_FIELD)
.startObject(FIELD_NAME)
.field("type", "knn_vector")
.field("dimension", 3)
.field(VECTOR_DATA_TYPE_FIELD, VectorDataType.BYTE.getValue())
.startObject(KNN_METHOD)
.field(NAME, METHOD_HNSW)
.field(METHOD_PARAMETER_SPACE_TYPE, SpaceType.L2)
.field(KNN_ENGINE, KNNEngine.FAISS.getName())
.endObject()
.endObject()
.endObject()
.endObject();
final String mapping = builder.toString();

createKnnIndex(INDEX_NAME, getKNNDefaultIndexSettings(), mapping);

addKnnDocWithAttributes(INDEX_NAME, "1", FIELD_NAME, new Byte[] { 6, 7, 3 }, ImmutableMap.of(COLOR_FIELD_NAME, "red"));
addKnnDocWithAttributes(INDEX_NAME, "2", FIELD_NAME, new Byte[] { 3, 2, 4 }, ImmutableMap.of(COLOR_FIELD_NAME, "green"));
addKnnDocWithAttributes(INDEX_NAME, "3", FIELD_NAME, new Byte[] { 4, 5, 7 }, ImmutableMap.of(COLOR_FIELD_NAME, "red"));

refreshIndex(INDEX_NAME);
}

@SneakyThrows
private void ingestL2ByteTestData() {
Byte[] b1 = { 6, 6 };
Expand Down
4 changes: 2 additions & 2 deletions src/testFixtures/java/org/opensearch/knn/KNNRestTestCase.java
Original file line number Diff line number Diff line change
Expand Up @@ -1910,11 +1910,11 @@ protected void addKnnDocWithAttributes(String docId, float[] vector, Map<String,
assertEquals(request.getEndpoint() + ": failed", RestStatus.OK, RestStatus.fromCode(response.getStatusLine().getStatusCode()));
}

protected void addKnnDocWithAttributes(
protected <T> void addKnnDocWithAttributes(
String indexName,
String docId,
String vectorFieldName,
float[] vector,
T vector,
Map<String, String> fieldValues
) throws IOException {
Request request = new Request("POST", "/" + indexName + "/_doc/" + docId + "?refresh=true");
Expand Down

0 comments on commit cf68ce8

Please sign in to comment.