diff --git a/CHANGELOG.md b/CHANGELOG.md
index b75fa82d1..b6099524e 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -44,6 +44,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
 * Fixing it to retrieve space_type from index setting when both method and top level don't have the value. [#2374](https://github.com/opensearch-project/k-NN/pull/2374)
 * Fixing the bug where setting rescore as false for on_disk knn_vector query is a no-op (#2399)[https://github.com/opensearch-project/k-NN/pull/2399]
 * Fixing the bug to prevent index.knn setting from being modified or removed on restore snapshot (#2445)[https://github.com/opensearch-project/k-NN/pull/2445]
+* Fix Faiss byte vector efficient filter bug [#2448](https://github.com/opensearch-project/k-NN/pull/2448)
 ### Infrastructure
 * Updated C++ version in JNI from c++11 to c++17 [#2259](https://github.com/opensearch-project/k-NN/pull/2259)
 * Upgrade bytebuddy and objenesis version to match OpenSearch core and, update github ci runner for macos [#2279](https://github.com/opensearch-project/k-NN/pull/2279)
diff --git a/src/main/java/org/opensearch/knn/index/query/KNNWeight.java b/src/main/java/org/opensearch/knn/index/query/KNNWeight.java
index 37b5cc9ad..d5912d758 100644
--- a/src/main/java/org/opensearch/knn/index/query/KNNWeight.java
+++ b/src/main/java/org/opensearch/knn/index/query/KNNWeight.java
@@ -45,6 +45,7 @@
 import java.util.Arrays;
 import java.util.Collections;
 import java.util.List;
+import java.util.Locale;
 import java.util.Map;
 import java.util.concurrent.ExecutionException;
 import java.util.stream.Collectors;
@@ -438,9 +439,27 @@ private boolean isFilteredExactSearchPreferred(final int filterIdsCount) {
          * TODO we can have a different MAX_DISTANCE_COMPUTATIONS for binary index as computation cost for binary index
          * is cheaper than computation cost for non binary vector
          */
-        return KNNConstants.MAX_DISTANCE_COMPUTATIONS >= filterIdsCount * (knnQuery.getVectorDataType() == VectorDataType.FLOAT
-            ? knnQuery.getQueryVector().length
-            : knnQuery.getByteQueryVector().length);
+        return KNNConstants.MAX_DISTANCE_COMPUTATIONS >= filterIdsCount * getQueryVectorLength();
+    }
+
+    /**
+     * Returns the length of the query vector, choosing the array by the query's vector data type:
+     * FLOAT and BYTE read the float query vector, only BINARY reads the byte query vector.
+     *
+     * @return length of query vector
+     * @throws IllegalArgumentException if the vector data type is not FLOAT, BYTE or BINARY
+     */
+    private int getQueryVectorLength() {
+        final VectorDataType vectorDataType = knnQuery.getVectorDataType();
+        if (vectorDataType == VectorDataType.FLOAT || vectorDataType == VectorDataType.BYTE) {
+            return knnQuery.getQueryVector().length;
+        }
+        if (vectorDataType == VectorDataType.BINARY) {
+            return knnQuery.getByteQueryVector().length;
+        }
+        throw new IllegalArgumentException(
+            String.format(Locale.ROOT, "[%s] datatype is not supported for k-NN query vector", vectorDataType.getValue())
+        );
     }
 
     /**
diff --git a/src/test/java/org/opensearch/knn/index/VectorDataTypeIT.java b/src/test/java/org/opensearch/knn/index/VectorDataTypeIT.java
index 07b2be40d..b99559158 100644
--- a/src/test/java/org/opensearch/knn/index/VectorDataTypeIT.java
+++ b/src/test/java/org/opensearch/knn/index/VectorDataTypeIT.java
@@ -5,6 +5,7 @@
 
 package org.opensearch.knn.index;
 
+import com.google.common.collect.ImmutableMap;
 import lombok.SneakyThrows;
 import org.apache.http.util.EntityUtils;
 import org.junit.After;
@@ -17,6 +18,7 @@
 import org.opensearch.core.xcontent.XContentBuilder;
 import org.opensearch.index.query.MatchAllQueryBuilder;
 import org.opensearch.index.query.QueryBuilder;
+import org.opensearch.index.query.QueryBuilders;
 import org.opensearch.knn.KNNRestTestCase;
 import org.opensearch.knn.KNNResult;
 import org.opensearch.knn.common.KNNConstants;
@@ -25,11 +27,13 @@
 import org.opensearch.script.Script;
 
 import java.util.ArrayList;
+import java.util.Arrays;
 import java.util.Collections;
 import java.util.HashMap;
 import java.util.List;
 import java.util.Locale;
 import java.util.Map;
+import java.util.stream.Collectors;
 
 import static org.opensearch.knn.common.KNNConstants.DIMENSION;
 import static org.opensearch.knn.common.KNNConstants.ENCODER_SQ;
@@ -62,6 +66,7 @@ public class VectorDataTypeIT extends KNNRestTestCase {
     private static final String KNN_VECTOR_TYPE = "knn_vector";
     private static final int EF_CONSTRUCTION = 128;
     private static final int M = 16;
+    private static final String COLOR_FIELD_NAME = "color";
     private static final QueryBuilder MATCH_ALL_QUERY_BUILDER = new MatchAllQueryBuilder();
 
     @After
@@ -666,6 +671,109 @@ public void testIVFByteVector_whenIndexedAndQueried_thenSucceed() {
         deleteModel(modelId);
     }
 
+    /**
+     * Runs filtered k-NN queries against the Faiss byte vector index and checks the returned doc ids when
+     * k exceeds the filtered doc count, when k limits the filtered docs, and when the filter matches no doc.
+     */
+    @SneakyThrows
+    public void testQueryWithFilterFaissByteVector_withDifferentCombination_thenSuccess() {
+        setupKNNFaissByteIndexForFilterQuery();
+        final Byte[] searchVector = { 6, 6, 4 };
+        // K > filteredResults
+        int kGreaterThanFilterResult = 5;
+        List<String> expectedDocIds = Arrays.asList("1", "3");
+        final Response response = searchKNNIndex(
+            INDEX_NAME,
+            new KNNQueryBuilder(
+                FIELD_NAME,
+                convertByteToFloatArray(searchVector),
+                kGreaterThanFilterResult,
+                QueryBuilders.termQuery(COLOR_FIELD_NAME, "red")
+            ),
+            kGreaterThanFilterResult
+        );
+        final String responseBody = EntityUtils.toString(response.getEntity());
+        final List<KNNResult> knnResults = parseSearchResponse(responseBody, FIELD_NAME);
+
+        assertEquals(expectedDocIds.size(), knnResults.size());
+        assertTrue(knnResults.stream().map(KNNResult::getDocId).collect(Collectors.toList()).containsAll(expectedDocIds));
+
+        // K Limits Filter results
+        int kLimitsFilterResult = 1;
+        List<String> expectedDocIdsKLimitsFilterResult = List.of("1");
+        final Response responseKLimitsFilterResult = searchKNNIndex(
+            INDEX_NAME,
+            new KNNQueryBuilder(
+                FIELD_NAME,
+                convertByteToFloatArray(searchVector),
+                kLimitsFilterResult,
+                QueryBuilders.termQuery(COLOR_FIELD_NAME, "red")
+            ),
+            kLimitsFilterResult
+        );
+        final String responseBodyKLimitsFilterResult = EntityUtils.toString(responseKLimitsFilterResult.getEntity());
+        final List<KNNResult> knnResultsKLimitsFilterResult = parseSearchResponse(responseBodyKLimitsFilterResult, FIELD_NAME);
+
+        assertEquals(expectedDocIdsKLimitsFilterResult.size(), knnResultsKLimitsFilterResult.size());
+        assertTrue(
+            knnResultsKLimitsFilterResult.stream()
+                .map(KNNResult::getDocId)
+                .collect(Collectors.toList())
+                .containsAll(expectedDocIdsKLimitsFilterResult)
+        );
+
+        // Empty filter docIds
+        int k = 10;
+        final Response emptyFilterResponse = searchKNNIndex(
+            INDEX_NAME,
+            new KNNQueryBuilder(
+                FIELD_NAME,
+                convertByteToFloatArray(searchVector),
+                k,
+                QueryBuilders.termQuery(COLOR_FIELD_NAME, "color_not_present")
+            ),
+            k
+        );
+        final String responseBodyForEmptyDocIds = EntityUtils.toString(emptyFilterResponse.getEntity());
+        final List<KNNResult> emptyKNNFilteredResultsFromResponse = parseSearchResponse(responseBodyForEmptyDocIds, FIELD_NAME);
+
+        assertEquals(0, emptyKNNFilteredResultsFromResponse.size());
+    }
+
+    /**
+     * Creates a 3-dimensional Faiss HNSW byte vector index with L2 space type and indexes three docs:
+     * "1" and "3" tagged with color "red", "2" tagged with color "green".
+     *
+     * @throws Exception propagated from the mapping builder and the index/doc helper calls
+     */
+    protected void setupKNNFaissByteIndexForFilterQuery() throws Exception {
+        // Create Mappings
+        XContentBuilder builder = XContentFactory.jsonBuilder()
+            .startObject()
+            .startObject(PROPERTIES_FIELD)
+            .startObject(FIELD_NAME)
+            .field("type", "knn_vector")
+            .field("dimension", 3)
+            .field(VECTOR_DATA_TYPE_FIELD, VectorDataType.BYTE.getValue())
+            .startObject(KNN_METHOD)
+            .field(NAME, METHOD_HNSW)
+            .field(METHOD_PARAMETER_SPACE_TYPE, SpaceType.L2)
+            .field(KNN_ENGINE, KNNEngine.FAISS.getName())
+            .endObject()
+            .endObject()
+            .endObject()
+            .endObject();
+        final String mapping = builder.toString();
+
+        createKnnIndex(INDEX_NAME, getKNNDefaultIndexSettings(), mapping);
+
+        addKnnDocWithAttributes(INDEX_NAME, "1", FIELD_NAME, new Byte[] { 6, 7, 3 }, ImmutableMap.of(COLOR_FIELD_NAME, "red"));
+        addKnnDocWithAttributes(INDEX_NAME, "2", FIELD_NAME, new Byte[] { 3, 2, 4 }, ImmutableMap.of(COLOR_FIELD_NAME, "green"));
+        addKnnDocWithAttributes(INDEX_NAME, "3", FIELD_NAME, new Byte[] { 4, 5, 7 }, ImmutableMap.of(COLOR_FIELD_NAME, "red"));
+
+        refreshIndex(INDEX_NAME);
+    }
+
@SneakyThrows private void ingestL2ByteTestData() { Byte[] b1 = { 6, 6 }; diff --git a/src/testFixtures/java/org/opensearch/knn/KNNRestTestCase.java b/src/testFixtures/java/org/opensearch/knn/KNNRestTestCase.java index 572ff4c5e..381b368c0 100644 --- a/src/testFixtures/java/org/opensearch/knn/KNNRestTestCase.java +++ b/src/testFixtures/java/org/opensearch/knn/KNNRestTestCase.java @@ -1910,11 +1910,11 @@ protected void addKnnDocWithAttributes(String docId, float[] vector, Map void addKnnDocWithAttributes( String indexName, String docId, String vectorFieldName, - float[] vector, + T vector, Map fieldValues ) throws IOException { Request request = new Request("POST", "/" + indexName + "/_doc/" + docId + "?refresh=true");