From 20039d0ad5ee99f8e21ae2042cf1b97ece80f6bd Mon Sep 17 00:00:00 2001 From: Tejas Shah Date: Fri, 9 Aug 2024 12:48:38 -0700 Subject: [PATCH] Introduces NativeEngineKNNQuery which executes ANN on rewrite (#1877) (#1943) Signed-off-by: Tejas Shah (cherry picked from commit df7627c6e580843beb2361f5c2ec3519efd52280) --- CHANGELOG.md | 1 + .../common/featureflags/KNNFeatureFlags.java | 46 +++++ .../org/opensearch/knn/index/KNNSettings.java | 20 +- .../knn/index/query/KNNQueryFactory.java | 9 +- .../opensearch/knn/index/query/KNNScorer.java | 7 + .../opensearch/knn/index/query/KNNWeight.java | 24 ++- .../query/nativelib/DocAndScoreQuery.java | 185 +++++++++++++++++ .../nativelib/NativeEngineKnnVectorQuery.java | 140 +++++++++++++ .../featureflags/KNNFeatureFlagsTests.java | 34 ++++ .../knn/index/query/KNNQueryBuilderTests.java | 14 ++ .../knn/index/query/KNNQueryFactoryTests.java | 34 ++++ .../knn/index/query/KNNWeightTests.java | 4 +- .../nativelib/DocAndScoreQueryTests.java | 99 +++++++++ .../NativeEngineKNNVectorQueryIT.java | 190 ++++++++++++++++++ .../NativeEngineKNNVectorQueryTests.java | 156 ++++++++++++++ 15 files changed, 950 insertions(+), 13 deletions(-) create mode 100644 src/main/java/org/opensearch/knn/common/featureflags/KNNFeatureFlags.java create mode 100644 src/main/java/org/opensearch/knn/index/query/nativelib/DocAndScoreQuery.java create mode 100644 src/main/java/org/opensearch/knn/index/query/nativelib/NativeEngineKnnVectorQuery.java create mode 100644 src/test/java/org/opensearch/knn/common/featureflags/KNNFeatureFlagsTests.java create mode 100644 src/test/java/org/opensearch/knn/index/query/nativelib/DocAndScoreQueryTests.java create mode 100644 src/test/java/org/opensearch/knn/index/query/nativelib/NativeEngineKNNVectorQueryIT.java create mode 100644 src/test/java/org/opensearch/knn/index/query/nativelib/NativeEngineKNNVectorQueryTests.java diff --git a/CHANGELOG.md b/CHANGELOG.md index f9d715823..e10f3e065 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -29,3 +29,4 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), * Refactor method structure and definitions [#1920](https://github.com/opensearch-project/k-NN/pull/1920) * Refactor KNNVectorFieldType from KNNVectorFieldMapper to a separate class for better readability. [#1931](https://github.com/opensearch-project/k-NN/pull/1931) * Generalize lib interface to return context objects [#1925](https://github.com/opensearch-project/k-NN/pull/1925) +* Move k search k-NN query to re-write phase of vector search query for Native Engines [#1877](https://github.com/opensearch-project/k-NN/pull/1877) \ No newline at end of file diff --git a/src/main/java/org/opensearch/knn/common/featureflags/KNNFeatureFlags.java b/src/main/java/org/opensearch/knn/common/featureflags/KNNFeatureFlags.java new file mode 100644 index 000000000..21160fc2d --- /dev/null +++ b/src/main/java/org/opensearch/knn/common/featureflags/KNNFeatureFlags.java @@ -0,0 +1,46 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + */ + +package org.opensearch.knn.common.featureflags; + +import com.google.common.annotations.VisibleForTesting; +import lombok.experimental.UtilityClass; +import org.opensearch.common.settings.Setting; +import org.opensearch.knn.index.KNNSettings; + +import java.util.List; +import java.util.stream.Collectors; +import java.util.stream.Stream; + +import static org.opensearch.common.settings.Setting.Property.Dynamic; +import static org.opensearch.common.settings.Setting.Property.NodeScope; + +/** + * Class to manage KNN feature flags + */ +@UtilityClass +public class KNNFeatureFlags { + + // Feature flags + private static final String KNN_LAUNCH_QUERY_REWRITE_ENABLED = "knn.feature.query.rewrite.enabled"; + private static final boolean KNN_LAUNCH_QUERY_REWRITE_ENABLED_DEFAULT = true; + + @VisibleForTesting + public static final Setting KNN_LAUNCH_QUERY_REWRITE_ENABLED_SETTING = Setting.boolSetting( + KNN_LAUNCH_QUERY_REWRITE_ENABLED, + KNN_LAUNCH_QUERY_REWRITE_ENABLED_DEFAULT, + NodeScope, + Dynamic + ); + + public static List> getFeatureFlags() { + return Stream.of(KNN_LAUNCH_QUERY_REWRITE_ENABLED_SETTING).collect(Collectors.toUnmodifiableList()); + } + + public static boolean isKnnQueryRewriteEnabled() { + return Boolean.parseBoolean(KNNSettings.state().getSettingValue(KNN_LAUNCH_QUERY_REWRITE_ENABLED).toString()); + } +} diff --git a/src/main/java/org/opensearch/knn/index/KNNSettings.java b/src/main/java/org/opensearch/knn/index/KNNSettings.java index 3279e74bc..33c7ff410 100644 --- a/src/main/java/org/opensearch/knn/index/KNNSettings.java +++ b/src/main/java/org/opensearch/knn/index/KNNSettings.java @@ -9,15 +9,15 @@ import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; import org.opensearch.OpenSearchParseException; -import org.opensearch.cluster.metadata.IndexMetadata; -import org.opensearch.core.action.ActionListener; import org.opensearch.action.admin.cluster.settings.ClusterUpdateSettingsRequest; import org.opensearch.action.admin.cluster.settings.ClusterUpdateSettingsResponse; import org.opensearch.client.Client; +import org.opensearch.cluster.metadata.IndexMetadata; import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.settings.Setting; import org.opensearch.common.settings.Settings; import org.opensearch.common.unit.TimeValue; +import org.opensearch.core.action.ActionListener; import org.opensearch.core.common.unit.ByteSizeUnit; import org.opensearch.core.common.unit.ByteSizeValue; import org.opensearch.index.IndexModule; @@ -28,20 +28,22 @@ import org.opensearch.monitor.os.OsProbe; import java.security.InvalidParameterException; -import java.util.ArrayList; import java.util.Arrays; import java.util.HashMap; import java.util.List; import java.util.Map; import java.util.Objects; +import java.util.function.Function; import java.util.stream.Collectors; import java.util.stream.Stream; +import static java.util.stream.Collectors.toUnmodifiableMap; import static org.opensearch.common.settings.Setting.Property.Dynamic; import static org.opensearch.common.settings.Setting.Property.IndexScope; import static org.opensearch.common.settings.Setting.Property.NodeScope; import static org.opensearch.common.unit.MemorySizeValue.parseBytesSizeValueOrHeapRatio; import static org.opensearch.core.common.unit.ByteSizeValue.parseBytesSizeValue; +import static org.opensearch.knn.common.featureflags.KNNFeatureFlags.getFeatureFlags; /** * This class defines @@ -289,6 +291,9 @@ public class KNNSettings { } }; + private final static Map> FEATURE_FLAGS = getFeatureFlags().stream() + .collect(toUnmodifiableMap(Setting::getKey, Function.identity())); + private ClusterService clusterService; private Client client; @@ -326,7 +331,7 @@ private void setSettingsUpdateConsumers() { ); NativeMemoryCacheManager.getInstance().rebuildCache(builder.build()); - }, new ArrayList<>(dynamicCacheSettings.values())); + }, Stream.concat(dynamicCacheSettings.values().stream(), FEATURE_FLAGS.values().stream()).collect(Collectors.toUnmodifiableList())); } /** @@ -346,6 +351,10 @@ private Setting getSetting(String key) { return dynamicCacheSettings.get(key); } + if (FEATURE_FLAGS.containsKey(key)) { + return FEATURE_FLAGS.get(key); + } + if (KNN_CIRCUIT_BREAKER_TRIGGERED.equals(key)) { return KNN_CIRCUIT_BREAKER_TRIGGERED_SETTING; } @@ -390,7 +399,8 @@ public List> getSettings() { KNN_FAISS_AVX2_DISABLED_SETTING, KNN_VECTOR_STREAMING_MEMORY_LIMIT_PCT_SETTING ); - return Stream.concat(settings.stream(), dynamicCacheSettings.values().stream()).collect(Collectors.toList()); + return Stream.concat(settings.stream(), Stream.concat(getFeatureFlags().stream(), dynamicCacheSettings.values().stream())) + .collect(Collectors.toList()); } public static boolean isKNNPluginEnabled() { diff --git a/src/main/java/org/opensearch/knn/index/query/KNNQueryFactory.java b/src/main/java/org/opensearch/knn/index/query/KNNQueryFactory.java index ee9a12a41..f3161b2db 100644 --- a/src/main/java/org/opensearch/knn/index/query/KNNQueryFactory.java +++ b/src/main/java/org/opensearch/knn/index/query/KNNQueryFactory.java @@ -16,12 +16,14 @@ import org.opensearch.index.query.QueryShardContext; import org.opensearch.knn.index.VectorDataType; import org.opensearch.knn.index.engine.KNNEngine; +import org.opensearch.knn.index.query.nativelib.NativeEngineKnnVectorQuery; import java.util.Locale; import java.util.Map; import static org.opensearch.knn.common.KNNConstants.METHOD_PARAMETER_EF_SEARCH; import static org.opensearch.knn.common.KNNConstants.VECTOR_DATA_TYPE_FIELD; +import static org.opensearch.knn.common.featureflags.KNNFeatureFlags.isKnnQueryRewriteEnabled; import static org.opensearch.knn.index.VectorDataType.SUPPORTED_VECTOR_DATA_TYPES; /** @@ -98,9 +100,10 @@ public static Query create(CreateQueryRequest createQueryRequest) { methodParameters ); + KNNQuery knnQuery = null; switch (vectorDataType) { case BINARY: - return KNNQuery.builder() + knnQuery = KNNQuery.builder() .field(fieldName) .byteQueryVector(byteVector) .indexName(indexName) @@ -110,8 +113,9 @@ public static Query create(CreateQueryRequest createQueryRequest) { .filterQuery(validatedFilterQuery) .vectorDataType(vectorDataType) .build(); + break; default: - return KNNQuery.builder() + knnQuery = KNNQuery.builder() .field(fieldName) .queryVector(vector) .indexName(indexName) @@ -122,6 +126,7 @@ public static Query create(CreateQueryRequest createQueryRequest) { .vectorDataType(vectorDataType) .build(); } + return isKnnQueryRewriteEnabled() ? new NativeEngineKnnVectorQuery(knnQuery) : knnQuery; } Integer requestEfSearch = null; diff --git a/src/main/java/org/opensearch/knn/index/query/KNNScorer.java b/src/main/java/org/opensearch/knn/index/query/KNNScorer.java index 02dc86e80..99962d307 100644 --- a/src/main/java/org/opensearch/knn/index/query/KNNScorer.java +++ b/src/main/java/org/opensearch/knn/index/query/KNNScorer.java @@ -87,6 +87,13 @@ public float score() throws IOException { public int docID() { return docIdsIter.docID(); } + + @Override + public boolean equals(Object obj) { + if (!(obj instanceof Scorer)) return false; + return getWeight().equals(((Scorer) obj).getWeight()); + } }; + } } diff --git a/src/main/java/org/opensearch/knn/index/query/KNNWeight.java b/src/main/java/org/opensearch/knn/index/query/KNNWeight.java index f54d8328e..f88652525 100644 --- a/src/main/java/org/opensearch/knn/index/query/KNNWeight.java +++ b/src/main/java/org/opensearch/knn/index/query/KNNWeight.java @@ -108,6 +108,22 @@ public Explanation explain(LeafReaderContext context, int doc) { @Override public Scorer scorer(LeafReaderContext context) throws IOException { + final Map docIdToScoreMap = searchLeaf(context); + if (docIdToScoreMap.isEmpty()) { + return KNNScorer.emptyScorer(this); + } + + return convertSearchResponseToScorer(docIdToScoreMap); + } + + /** + * Executes k nearest neighbor search for a segment to get the top K results + * This is made public purely to be able to be reused in {@link org.opensearch.knn.index.query.nativelib.NativeEngineKnnVectorQuery} + * + * @param context LeafReaderContext + * @return A Map of docId to scores for top k results + */ + public Map searchLeaf(LeafReaderContext context) throws IOException { final BitSet filterBitSet = getFilteredDocsBitSet(context); int cardinality = filterBitSet.cardinality(); @@ -115,7 +131,7 @@ public Scorer scorer(LeafReaderContext context) throws IOException { // We should give this condition a deeper look that where it should be placed. For now I feel this is a good // place, if (filterWeight != null && cardinality == 0) { - return KNNScorer.emptyScorer(this); + return Collections.emptyMap(); } final Map docIdsToScoreMap = new HashMap<>(); @@ -129,7 +145,7 @@ public Scorer scorer(LeafReaderContext context) throws IOException { } else { Map annResults = doANNSearch(context, filterBitSet, cardinality); if (annResults == null) { - return null; + return Collections.emptyMap(); } if (canDoExactSearchAfterANNSearch(cardinality, annResults.size())) { log.debug( @@ -144,9 +160,9 @@ public Scorer scorer(LeafReaderContext context) throws IOException { docIdsToScoreMap.putAll(annResults); } if (docIdsToScoreMap.isEmpty()) { - return KNNScorer.emptyScorer(this); + return Collections.emptyMap(); } - return convertSearchResponseToScorer(docIdsToScoreMap); + return docIdsToScoreMap; } private BitSet getFilteredDocsBitSet(final LeafReaderContext ctx) throws IOException { diff --git a/src/main/java/org/opensearch/knn/index/query/nativelib/DocAndScoreQuery.java b/src/main/java/org/opensearch/knn/index/query/nativelib/DocAndScoreQuery.java new file mode 100644 index 000000000..f1a91d878 --- /dev/null +++ b/src/main/java/org/opensearch/knn/index/query/nativelib/DocAndScoreQuery.java @@ -0,0 +1,185 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.index.query.nativelib; + +import org.apache.lucene.index.LeafReaderContext; +import org.apache.lucene.search.DocIdSetIterator; +import org.apache.lucene.search.Explanation; +import org.apache.lucene.search.IndexSearcher; +import org.apache.lucene.search.Query; +import org.apache.lucene.search.QueryVisitor; +import org.apache.lucene.search.ScoreMode; +import org.apache.lucene.search.Scorer; +import org.apache.lucene.search.Weight; + +import java.io.IOException; +import java.util.Arrays; +import java.util.Objects; + +import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS; + +/** + * This is the same as {@link org.apache.lucene.search.AbstractKnnVectorQuery.DocAndScoreQuery} + */ +final class DocAndScoreQuery extends Query { + + private final int k; + private final int[] docs; + private final float[] scores; + private final int[] segmentStarts; + private final Object contextIdentity; + + DocAndScoreQuery(int k, int[] docs, float[] scores, int[] segmentStarts, Object contextIdentity) { + this.k = k; + this.docs = docs; + this.scores = scores; + this.segmentStarts = segmentStarts; + this.contextIdentity = contextIdentity; + } + + @Override + public Weight createWeight(IndexSearcher searcher, ScoreMode scoreMode, float boost) { + if (searcher.getIndexReader().getContext().id() != contextIdentity) { + throw new IllegalStateException("This DocAndScore query was created by a different reader"); + } + return new Weight(this) { + @Override + public Explanation explain(LeafReaderContext context, int doc) { + int found = Arrays.binarySearch(docs, doc + context.docBase); + if (found < 0) { + return Explanation.noMatch("not in top " + k); + } + return Explanation.match(scores[found] * boost, "within top " + k); + } + + @Override + public int count(LeafReaderContext context) { + return segmentStarts[context.ord + 1] - segmentStarts[context.ord]; + } + + @Override + public Scorer scorer(LeafReaderContext context) { + if (segmentStarts[context.ord] == segmentStarts[context.ord + 1]) { + return null; + } + return new Scorer(this) { + final int lower = segmentStarts[context.ord]; + final int upper = segmentStarts[context.ord + 1]; + int upTo = -1; + + @Override + public DocIdSetIterator iterator() { + return new DocIdSetIterator() { + @Override + public int docID() { + return docIdNoShadow(); + } + + @Override + public int nextDoc() { + if (upTo == -1) { + upTo = lower; + } else { + ++upTo; + } + return docIdNoShadow(); + } + + @Override + public int advance(int target) throws IOException { + return slowAdvance(target); + } + + @Override + public long cost() { + return upper - lower; + } + }; + } + + @Override + public float getMaxScore(int docId) { + docId += context.docBase; + float maxScore = 0; + for (int idx = Math.max(0, upTo); idx < upper && docs[idx] <= docId; idx++) { + maxScore = Math.max(maxScore, scores[idx]); + } + return maxScore * boost; + } + + @Override + public float score() { + return scores[upTo] * boost; + } + + @Override + public int advanceShallow(int docid) { + int start = Math.max(upTo, lower); + int docidIndex = Arrays.binarySearch(docs, start, upper, docid + context.docBase); + if (docidIndex < 0) { + docidIndex = -1 - docidIndex; + } + if (docidIndex >= upper) { + return NO_MORE_DOCS; + } + return docs[docidIndex]; + } + + /** + * move the implementation of docID() into a differently-named method so we can call it + * from DocIDSetIterator.docID() even though this class is anonymous + * + * @return the current docid + */ + private int docIdNoShadow() { + if (upTo == -1) { + return -1; + } + if (upTo >= upper) { + return NO_MORE_DOCS; + } + return docs[upTo] - context.docBase; + } + + @Override + public int docID() { + return docIdNoShadow(); + } + }; + } + + @Override + public boolean isCacheable(LeafReaderContext ctx) { + return true; + } + }; + } + + @Override + public String toString(String field) { + return "DocAndScore[" + k + "][docs:" + Arrays.toString(docs) + ", scores:" + Arrays.toString(scores) + "]"; + } + + @Override + public void visit(QueryVisitor visitor) { + visitor.visitLeaf(this); + } + + @Override + public boolean equals(Object obj) { + if (!sameClassAs(obj)) { + return false; + } + return contextIdentity == ((DocAndScoreQuery) obj).contextIdentity + && Arrays.equals(docs, ((DocAndScoreQuery) obj).docs) + && Arrays.equals(scores, ((DocAndScoreQuery) obj).scores); + } + + @Override + public int hashCode() { + return Objects.hash(classHash(), contextIdentity, Arrays.hashCode(docs), Arrays.hashCode(scores)); + } +} diff --git a/src/main/java/org/opensearch/knn/index/query/nativelib/NativeEngineKnnVectorQuery.java b/src/main/java/org/opensearch/knn/index/query/nativelib/NativeEngineKnnVectorQuery.java new file mode 100644 index 000000000..6b9a40a9c --- /dev/null +++ b/src/main/java/org/opensearch/knn/index/query/nativelib/NativeEngineKnnVectorQuery.java @@ -0,0 +1,140 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.index.query.nativelib; + +import lombok.Getter; +import lombok.RequiredArgsConstructor; +import org.apache.lucene.index.IndexReader; +import org.apache.lucene.index.LeafReaderContext; +import org.apache.lucene.search.IndexSearcher; +import org.apache.lucene.search.MatchNoDocsQuery; +import org.apache.lucene.search.Query; +import org.apache.lucene.search.QueryVisitor; +import org.apache.lucene.search.ScoreDoc; +import org.apache.lucene.search.ScoreMode; +import org.apache.lucene.search.TopDocs; +import org.apache.lucene.search.TotalHits; +import org.apache.lucene.util.Bits; +import org.opensearch.knn.index.query.KNNQuery; +import org.opensearch.knn.index.query.KNNWeight; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Comparator; +import java.util.List; +import java.util.Map; +import java.util.Objects; +import java.util.concurrent.Callable; + +/** + * {@link KNNQuery} executes approximate nearest neighbor search (ANN) on a segment level. + * {@link NativeEngineKnnVectorQuery} executes approximate nearest neighbor search but gives + * us the control to combine the top k results in each leaf and post process the results just + * for k-NN query if required. This is done by overriding rewrite method to execute ANN on each leaf + * {@link KNNQuery} does not give the ability to post process segment results. + */ +@Getter +@RequiredArgsConstructor +public class NativeEngineKnnVectorQuery extends Query { + + private final KNNQuery knnQuery; + + @Override + public Query rewrite(final IndexSearcher indexSearcher) throws IOException { + final IndexReader reader = indexSearcher.getIndexReader(); + final KNNWeight knnWeight = (KNNWeight) knnQuery.createWeight(indexSearcher, ScoreMode.COMPLETE, 1); + List leafReaderContexts = reader.leaves(); + + List> tasks = new ArrayList<>(leafReaderContexts.size()); + for (LeafReaderContext leafReaderContext : leafReaderContexts) { + tasks.add(() -> searchLeaf(leafReaderContext, knnWeight)); + } + TopDocs[] perLeafResults = indexSearcher.getTaskExecutor().invokeAll(tasks).toArray(TopDocs[]::new); + // TopDocs.merge requires perLeafResults to be sorted in descending order. + TopDocs topK = TopDocs.merge(knnQuery.getK(), perLeafResults); + if (topK.scoreDocs.length == 0) { + return new MatchNoDocsQuery(); + } + return createRewrittenQuery(reader, topK); + } + + private Query createRewrittenQuery(IndexReader reader, TopDocs topK) { + int len = topK.scoreDocs.length; + Arrays.sort(topK.scoreDocs, Comparator.comparingInt(a -> a.doc)); + int[] docs = new int[len]; + float[] scores = new float[len]; + for (int i = 0; i < len; i++) { + docs[i] = topK.scoreDocs[i].doc; + scores[i] = topK.scoreDocs[i].score; + } + int[] segmentStarts = findSegmentStarts(reader, docs); + return new DocAndScoreQuery(knnQuery.getK(), docs, scores, segmentStarts, reader.getContext().id()); + } + + private static int[] findSegmentStarts(IndexReader reader, int[] docs) { + int[] starts = new int[reader.leaves().size() + 1]; + starts[starts.length - 1] = docs.length; + if (starts.length == 2) { + return starts; + } + int resultIndex = 0; + for (int i = 1; i < starts.length - 1; i++) { + int upper = reader.leaves().get(i).docBase; + resultIndex = Arrays.binarySearch(docs, resultIndex, docs.length, upper); + if (resultIndex < 0) { + resultIndex = -1 - resultIndex; + } + starts[i] = resultIndex; + } + return starts; + } + + private TopDocs searchLeaf(LeafReaderContext ctx, KNNWeight queryWeight) throws IOException { + int totalHits = 0; + final Map leafDocScores = queryWeight.searchLeaf(ctx); + final List scoreDocs = new ArrayList<>(); + final Bits liveDocs = ctx.reader().getLiveDocs(); + + if (!leafDocScores.isEmpty()) { + final List> topScores = new ArrayList<>(leafDocScores.entrySet()); + topScores.sort(Map.Entry.comparingByValue().reversed()); + + for (Map.Entry entry : topScores) { + if (liveDocs == null || liveDocs.get(entry.getKey())) { + ScoreDoc scoreDoc = new ScoreDoc(entry.getKey() + ctx.docBase, entry.getValue()); + scoreDocs.add(scoreDoc); + totalHits++; + } + } + } + + return new TopDocs(new TotalHits(totalHits, TotalHits.Relation.EQUAL_TO), scoreDocs.toArray(ScoreDoc[]::new)); + } + + @Override + public String toString(String field) { + return this.getClass().getSimpleName() + "[" + field + "]..." + KNNQuery.class.getSimpleName() + "[" + knnQuery.toString() + "]"; + } + + @Override + public void visit(QueryVisitor visitor) { + visitor.visitLeaf(this); + } + + @Override + public boolean equals(Object obj) { + if (!sameClassAs(obj)) { + return false; + } + return knnQuery == ((NativeEngineKnnVectorQuery) obj).knnQuery; + } + + @Override + public int hashCode() { + return Objects.hash(classHash(), knnQuery.hashCode()); + } +} diff --git a/src/test/java/org/opensearch/knn/common/featureflags/KNNFeatureFlagsTests.java b/src/test/java/org/opensearch/knn/common/featureflags/KNNFeatureFlagsTests.java new file mode 100644 index 000000000..c3a8a1615 --- /dev/null +++ b/src/test/java/org/opensearch/knn/common/featureflags/KNNFeatureFlagsTests.java @@ -0,0 +1,34 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.common.featureflags; + +import org.mockito.Mock; +import org.opensearch.common.settings.ClusterSettings; +import org.opensearch.knn.KNNTestCase; +import org.opensearch.knn.index.KNNSettings; + +import static org.mockito.Mockito.when; +import static org.opensearch.knn.common.featureflags.KNNFeatureFlags.KNN_LAUNCH_QUERY_REWRITE_ENABLED_SETTING; +import static org.opensearch.knn.common.featureflags.KNNFeatureFlags.isKnnQueryRewriteEnabled; + +public class KNNFeatureFlagsTests extends KNNTestCase { + + @Mock + ClusterSettings clusterSettings; + + public void setUp() throws Exception { + super.setUp(); + when(clusterService.getClusterSettings()).thenReturn(clusterSettings); + KNNSettings.state().setClusterService(clusterService); + } + + public void testIsFeatureEnabled() throws Exception { + when(clusterSettings.get(KNN_LAUNCH_QUERY_REWRITE_ENABLED_SETTING)).thenReturn(false); + assertFalse(isKnnQueryRewriteEnabled()); + when(clusterSettings.get(KNN_LAUNCH_QUERY_REWRITE_ENABLED_SETTING)).thenReturn(true); + assertTrue(isKnnQueryRewriteEnabled()); + } +} diff --git a/src/test/java/org/opensearch/knn/index/query/KNNQueryBuilderTests.java b/src/test/java/org/opensearch/knn/index/query/KNNQueryBuilderTests.java index 0241a9afb..0b918bd9e 100644 --- a/src/test/java/org/opensearch/knn/index/query/KNNQueryBuilderTests.java +++ b/src/test/java/org/opensearch/knn/index/query/KNNQueryBuilderTests.java @@ -11,10 +11,12 @@ import org.apache.lucene.search.KnnFloatVectorQuery; import org.apache.lucene.search.MatchNoDocsQuery; import org.apache.lucene.search.Query; +import org.junit.Before; import org.opensearch.Version; import org.opensearch.cluster.ClusterModule; import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.io.stream.BytesStreamOutput; +import org.opensearch.common.settings.ClusterSettings; import org.opensearch.core.common.io.stream.NamedWriteableAwareStreamInput; import org.opensearch.core.common.io.stream.NamedWriteableRegistry; import org.opensearch.core.common.io.stream.StreamInput; @@ -31,6 +33,7 @@ import org.opensearch.knn.index.util.KNNClusterUtil; import org.opensearch.knn.index.engine.KNNMethodContext; import org.opensearch.knn.index.engine.MethodComponentContext; +import org.opensearch.knn.index.KNNSettings; import org.opensearch.knn.index.SpaceType; import org.opensearch.knn.index.VectorDataType; import org.opensearch.knn.index.engine.KNNEngine; @@ -51,6 +54,7 @@ import static org.mockito.Mockito.anyString; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.when; +import static org.opensearch.knn.common.featureflags.KNNFeatureFlags.KNN_LAUNCH_QUERY_REWRITE_ENABLED_SETTING; import static org.opensearch.knn.index.KNNClusterTestUtils.mockClusterService; import static org.opensearch.knn.index.engine.KNNEngine.ENGINES_SUPPORTING_RADIAL_SEARCH; @@ -67,6 +71,16 @@ public class KNNQueryBuilderTests extends KNNTestCase { protected static final String TEXT_FIELD_NAME = "some_field"; protected static final String TEXT_VALUE = "some_value"; + @Before + @Override + public void setUp() throws Exception { + super.setUp(); + ClusterSettings clusterSettings = mock(ClusterSettings.class); + when(clusterService.getClusterSettings()).thenReturn(clusterSettings); + when(clusterSettings.get(KNN_LAUNCH_QUERY_REWRITE_ENABLED_SETTING)).thenReturn(false); + KNNSettings.state().setClusterService(clusterService); + } + public void testInvalidK() { float[] queryVector = { 1.0f, 1.0f }; diff --git a/src/test/java/org/opensearch/knn/index/query/KNNQueryFactoryTests.java b/src/test/java/org/opensearch/knn/index/query/KNNQueryFactoryTests.java index c74a79946..7bacc7d10 100644 --- a/src/test/java/org/opensearch/knn/index/query/KNNQueryFactoryTests.java +++ b/src/test/java/org/opensearch/knn/index/query/KNNQueryFactoryTests.java @@ -14,8 +14,11 @@ import org.apache.lucene.search.join.DiversifyingChildrenByteKnnVectorQuery; import org.apache.lucene.search.join.DiversifyingChildrenFloatKnnVectorQuery; import org.apache.lucene.search.join.ToChildBlockJoinQuery; +import org.junit.Before; +import org.mockito.Mock; import org.mockito.MockedConstruction; import org.mockito.Mockito; +import org.opensearch.common.settings.ClusterSettings; import org.opensearch.index.mapper.MappedFieldType; import org.opensearch.index.mapper.MapperService; import org.opensearch.index.query.QueryBuilder; @@ -23,8 +26,10 @@ import org.opensearch.index.query.TermQueryBuilder; import org.opensearch.index.search.NestedHelper; import org.opensearch.knn.KNNTestCase; +import org.opensearch.knn.index.KNNSettings; import org.opensearch.knn.index.VectorDataType; import org.opensearch.knn.index.engine.KNNEngine; +import org.opensearch.knn.index.query.nativelib.NativeEngineKnnVectorQuery; import java.util.Arrays; import java.util.List; @@ -36,6 +41,7 @@ import static org.mockito.Mockito.when; import static org.opensearch.knn.common.KNNConstants.DEFAULT_VECTOR_DATA_TYPE_FIELD; import static org.opensearch.knn.common.KNNConstants.METHOD_PARAMETER_EF_SEARCH; +import static org.opensearch.knn.common.featureflags.KNNFeatureFlags.KNN_LAUNCH_QUERY_REWRITE_ENABLED_SETTING; public class KNNQueryFactoryTests extends KNNTestCase { private static final String FILTER_FILED_NAME = "foo"; @@ -50,8 +56,21 @@ public class KNNQueryFactoryTests extends KNNTestCase { private final int testK = 10; private final Map methodParameters = Map.of(METHOD_PARAMETER_EF_SEARCH, 100); + @Mock + ClusterSettings clusterSettings; + + @Before + @Override + public void setUp() throws Exception { + super.setUp(); + when(clusterService.getClusterSettings()).thenReturn(clusterSettings); + when(clusterSettings.get(KNN_LAUNCH_QUERY_REWRITE_ENABLED_SETTING)).thenReturn(false); + KNNSettings.state().setClusterService(clusterService); + } + public void testCreateCustomKNNQuery() { for (KNNEngine knnEngine : KNNEngine.getEnginesThatCreateCustomSegmentFiles()) { + when(clusterSettings.get(KNN_LAUNCH_QUERY_REWRITE_ENABLED_SETTING)).thenReturn(false); Query query = KNNQueryFactory.create( knnEngine, testIndexName, @@ -61,6 +80,15 @@ public void testCreateCustomKNNQuery() { DEFAULT_VECTOR_DATA_TYPE_FIELD ); assertTrue(query instanceof KNNQuery); + assertEquals(testIndexName, ((KNNQuery) query).getIndexName()); + assertEquals(testFieldName, ((KNNQuery) query).getField()); + assertEquals(testQueryVector, ((KNNQuery) query).getQueryVector()); + assertEquals(testK, ((KNNQuery) query).getK()); + + when(clusterSettings.get(KNN_LAUNCH_QUERY_REWRITE_ENABLED_SETTING)).thenReturn(true); + query = KNNQueryFactory.create(knnEngine, testIndexName, testFieldName, testQueryVector, testK, DEFAULT_VECTOR_DATA_TYPE_FIELD); + assertTrue(query instanceof NativeEngineKnnVectorQuery); + query = ((NativeEngineKnnVectorQuery) query).getKnnQuery(); assertEquals(testIndexName, ((KNNQuery) query).getIndexName()); assertEquals(testFieldName, ((KNNQuery) query).getField()); @@ -392,6 +420,7 @@ public void testCreate_whenBinary_thenSuccess() { when(mockQueryShardContext.fieldMapper(any())).thenReturn(testMapper); BitSetProducer parentFilter = mock(BitSetProducer.class); when(mockQueryShardContext.getParentFilter()).thenReturn(parentFilter); + final KNNQueryFactory.CreateQueryRequest createQueryRequest = KNNQueryFactory.CreateQueryRequest.builder() .knnEngine(KNNEngine.FAISS) .indexName(testIndexName) @@ -407,5 +436,10 @@ public void testCreate_whenBinary_thenSuccess() { assertTrue(query instanceof KNNQuery); assertNotNull(((KNNQuery) query).getByteQueryVector()); assertNull(((KNNQuery) query).getQueryVector()); + + when(clusterSettings.get(KNN_LAUNCH_QUERY_REWRITE_ENABLED_SETTING)).thenReturn(true); + query = KNNQueryFactory.create(createQueryRequest); + assertTrue(query instanceof NativeEngineKnnVectorQuery); } + } diff --git a/src/test/java/org/opensearch/knn/index/query/KNNWeightTests.java b/src/test/java/org/opensearch/knn/index/query/KNNWeightTests.java index c7077eace..c5abc964d 100644 --- a/src/test/java/org/opensearch/knn/index/query/KNNWeightTests.java +++ b/src/test/java/org/opensearch/knn/index/query/KNNWeightTests.java @@ -350,7 +350,7 @@ public void testShardWithoutFiles() { when(fieldInfos.fieldInfo(any())).thenReturn(fieldInfo); final Scorer knnScorer = knnWeight.scorer(leafReaderContext); - assertNull(knnScorer); + assertEquals(KNNScorer.emptyScorer(knnWeight), knnScorer); } @SneakyThrows @@ -394,7 +394,7 @@ public void testEmptyQueryResults() { when(fieldInfos.fieldInfo(any())).thenReturn(fieldInfo); final Scorer knnScorer = knnWeight.scorer(leafReaderContext); - assertNull(knnScorer); + assertEquals(KNNScorer.emptyScorer(knnWeight), knnScorer); } @SneakyThrows diff --git a/src/test/java/org/opensearch/knn/index/query/nativelib/DocAndScoreQueryTests.java b/src/test/java/org/opensearch/knn/index/query/nativelib/DocAndScoreQueryTests.java new file mode 100644 index 000000000..185cb5d47 --- /dev/null +++ b/src/test/java/org/opensearch/knn/index/query/nativelib/DocAndScoreQueryTests.java @@ -0,0 +1,99 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.index.query.nativelib; + +import lombok.SneakyThrows; +import org.apache.lucene.index.IndexReader; +import org.apache.lucene.index.IndexReaderContext; +import org.apache.lucene.index.LeafReaderContext; +import org.apache.lucene.search.DocIdSetIterator; +import org.apache.lucene.search.Explanation; +import org.apache.lucene.search.IndexSearcher; +import org.apache.lucene.search.ScoreMode; +import org.apache.lucene.search.Scorer; +import org.apache.lucene.search.Weight; +import org.mockito.Mock; +import org.opensearch.test.OpenSearchTestCase; + +import static org.mockito.Mockito.when; +import static org.mockito.MockitoAnnotations.openMocks; + +public class DocAndScoreQueryTests extends OpenSearchTestCase { + + @Mock + private LeafReaderContext leaf1; + @Mock + private IndexSearcher indexSearcher; + @Mock + private IndexReader reader; + @Mock + private IndexReaderContext readerContext; + + private DocAndScoreQuery objectUnderTest; + + @Override + public void setUp() throws Exception { + super.setUp(); + openMocks(this); + + when(indexSearcher.getIndexReader()).thenReturn(reader); + when(reader.getContext()).thenReturn(readerContext); + when(readerContext.id()).thenReturn(1); + } + + // Note: cannot test with multi leaf as there LeafReaderContext is readonly with no getters for some fields to mock + public void testScorer() throws Exception { + // Given + int[] expectedDocs = { 0, 1, 2, 3, 4 }; + float[] expectedScores = { 0.1f, 1.2f, 2.3f, 5.1f, 3.4f }; + int[] findSegments = { 0, 2, 5 }; + objectUnderTest = new DocAndScoreQuery(4, expectedDocs, expectedScores, findSegments, 1); + + // When + Scorer scorer1 = objectUnderTest.createWeight(indexSearcher, ScoreMode.COMPLETE, 1).scorer(leaf1); + DocIdSetIterator iterator1 = scorer1.iterator(); + Scorer scorer2 = objectUnderTest.createWeight(indexSearcher, ScoreMode.COMPLETE, 1).scorer(leaf1); + DocIdSetIterator iterator2 = scorer2.iterator(); + + int[] actualDocs = new int[2]; + float[] actualScores = new float[2]; + int index = 0; + while (iterator1.nextDoc() != DocIdSetIterator.NO_MORE_DOCS) { + actualDocs[index] = iterator1.docID(); + actualScores[index] = scorer1.score(); + ++index; + } + + // Then + assertEquals(2, iterator1.cost()); + assertArrayEquals(new int[] { 0, 1 }, actualDocs); + assertArrayEquals(new float[] { 0.1f, 1.2f }, actualScores, 0.0001f); + + assertEquals(1.2f, scorer2.getMaxScore(1), 0.0001f); + assertEquals(iterator2.advance(1), 1); + } + + @SneakyThrows + public void testWeight() { + // Given + int[] expectedDocs = { 0, 1, 2, 3, 4 }; + float[] expectedScores = { 0.1f, 1.2f, 2.3f, 5.1f, 3.4f }; + int[] findSegments = { 0, 2, 5 }; + Explanation expectedExplanation = Explanation.match(1.2f, "within top 4"); + + // When + objectUnderTest = new DocAndScoreQuery(4, expectedDocs, expectedScores, findSegments, 1); + Weight weight = objectUnderTest.createWeight(indexSearcher, ScoreMode.COMPLETE, 1); + Explanation explanation = weight.explain(leaf1, 1); + + // Then + assertEquals(objectUnderTest, weight.getQuery()); + assertTrue(weight.isCacheable(leaf1)); + assertEquals(2, weight.count(leaf1)); + assertEquals(expectedExplanation, explanation); + assertEquals(Explanation.noMatch("not in top 4"), weight.explain(leaf1, 9)); + } +} diff --git a/src/test/java/org/opensearch/knn/index/query/nativelib/NativeEngineKNNVectorQueryIT.java b/src/test/java/org/opensearch/knn/index/query/nativelib/NativeEngineKNNVectorQueryIT.java new file mode 100644 index 000000000..1d84fcb48 --- /dev/null +++ b/src/test/java/org/opensearch/knn/index/query/nativelib/NativeEngineKNNVectorQueryIT.java @@ -0,0 +1,190 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.index.query.nativelib; + +import com.carrotsearch.randomizedtesting.annotations.ParametersFactory; +import com.google.common.primitives.Floats; +import lombok.AllArgsConstructor; +import lombok.SneakyThrows; +import org.apache.http.util.EntityUtils; +import org.junit.BeforeClass; +import org.opensearch.client.Response; +import org.opensearch.common.xcontent.XContentFactory; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.knn.KNNRestTestCase; +import org.opensearch.knn.KNNResult; +import org.opensearch.knn.TestUtils; +import org.opensearch.knn.common.KNNConstants; +import org.opensearch.knn.index.FaissHNSWFlatE2EIT; +import org.opensearch.knn.index.SpaceType; +import org.opensearch.knn.index.engine.KNNEngine; +import org.opensearch.knn.index.query.KNNQueryBuilder; +import org.opensearch.knn.plugin.script.KNNScoringUtil; + +import java.io.IOException; +import java.net.URL; +import java.util.Arrays; +import java.util.Collection; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.TreeMap; +import java.util.concurrent.ThreadLocalRandom; + +import static com.carrotsearch.randomizedtesting.RandomizedTest.$; +import static com.carrotsearch.randomizedtesting.RandomizedTest.$$; +import static org.opensearch.knn.common.KNNConstants.KNN_ENGINE; +import static org.opensearch.knn.common.KNNConstants.METHOD_HNSW; +import static org.opensearch.knn.common.KNNConstants.METHOD_PARAMETER_EF_SEARCH; +import static org.opensearch.knn.common.KNNConstants.METHOD_PARAMETER_SPACE_TYPE; +import static org.opensearch.knn.common.KNNConstants.NAME; +import static org.opensearch.knn.common.KNNConstants.PARAMETERS; + +@AllArgsConstructor +public class NativeEngineKNNVectorQueryIT extends KNNRestTestCase { + + private String description; + private int k; + private Map methodParameters; + private boolean deleteRandomDocs; + + static TestUtils.TestData testData; + + @BeforeClass + public static void setUpClass() throws IOException { + if (FaissHNSWFlatE2EIT.class.getClassLoader() == null) { + throw new IllegalStateException("ClassLoader of FaissIT Class is null"); + } + URL testIndexVectors = FaissHNSWFlatE2EIT.class.getClassLoader().getResource("data/test_vectors_1000x128.json"); + URL testQueries = FaissHNSWFlatE2EIT.class.getClassLoader().getResource("data/test_queries_100x128.csv"); + assert testIndexVectors != null; + assert testQueries != null; + testData = new TestUtils.TestData(testIndexVectors.getPath(), testQueries.getPath()); + } + + @ParametersFactory(argumentFormatting = "description:%1$s; k:%2$s; efSearch:%3$s, deleteDocs:%4$s") + public static Collection parameters() { + return Arrays.asList( + $$( + $("test without deletedocs", 10, Map.of(METHOD_PARAMETER_EF_SEARCH, 300), false), + $("test with deletedocs", 10, Map.of(METHOD_PARAMETER_EF_SEARCH, 300), true) + ) + ); + } + + @SneakyThrows + public void testResultComparisonSanity() { + String indexName = "test-index-1"; + String fieldName = "test-field-1"; + + SpaceType spaceType = SpaceType.L2; + + Integer dimension = testData.indexData.vectors[0].length; + + // Create an index + XContentBuilder builder = XContentFactory.jsonBuilder() + .startObject() + .startObject("properties") + .startObject(fieldName) + .field("type", "knn_vector") + .field("dimension", dimension) + .startObject(KNNConstants.KNN_METHOD) + .field(NAME, METHOD_HNSW) + .field(METHOD_PARAMETER_SPACE_TYPE, spaceType.getValue()) + .field(KNN_ENGINE, KNNEngine.FAISS.getName()) + .startObject(PARAMETERS) + .field(KNNConstants.METHOD_PARAMETER_M, 16) + .field(KNNConstants.METHOD_PARAMETER_EF_CONSTRUCTION, 32) + .field(KNNConstants.METHOD_PARAMETER_EF_SEARCH, 32) + .endObject() + .endObject() + .endObject() + .endObject() + .endObject(); + + Map mappingMap = xContentBuilderToMap(builder); + String mapping = builder.toString(); + + createKnnIndex(indexName, mapping); + assertEquals(new TreeMap<>(mappingMap), new TreeMap<>(getIndexMappingAsMap(indexName))); + + // Index the test data + for (int i = 0; i < testData.indexData.docs.length; i++) { + addKnnDoc( + indexName, + Integer.toString(testData.indexData.docs[i]), + fieldName, + Floats.asList(testData.indexData.vectors[i]).toArray() + ); + } + + // Assert we have the right number of documents in the index + refreshAllNonSystemIndices(); + assertEquals(testData.indexData.docs.length, getDocCount(indexName)); + + // Delete few Docs + if (deleteRandomDocs) { + final Set docIdsToBeDeleted = new HashSet<>(); + while (docIdsToBeDeleted.size() < 10) { + docIdsToBeDeleted.add(randomInt(testData.indexData.docs.length - 1)); + } + + for (Integer id : docIdsToBeDeleted) { + deleteKnnDoc(indexName, Integer.toString(testData.indexData.docs[id])); + } + refreshAllNonSystemIndices(); + forceMergeKnnIndex(indexName, 3); + + assertEquals(testData.indexData.docs.length - 10, getDocCount(indexName)); + } + + int queryIndex = ThreadLocalRandom.current().nextInt(testData.queries.length); + // Test search queries + final KNNQueryBuilder queryBuilder = KNNQueryBuilder.builder() + .fieldName(fieldName) + .vector(testData.queries[queryIndex]) + .k(k) + .methodParameters(methodParameters) + .build(); + Response nativeEngineResponse = searchKNNIndex(indexName, queryBuilder, k); + String responseBody = EntityUtils.toString(nativeEngineResponse.getEntity()); + List nativeEngineKnnResults = parseSearchResponse(responseBody, fieldName); + assertEquals(k, nativeEngineKnnResults.size()); + + List actualScores = parseSearchResponseScore(responseBody, fieldName); + for (int j = 0; j < k; j++) { + float[] primitiveArray = nativeEngineKnnResults.get(j).getVector(); + assertEquals( + KNNEngine.FAISS.score(KNNScoringUtil.l2Squared(testData.queries[queryIndex], primitiveArray), spaceType), + actualScores.get(j), + 0.0001 + ); + } + + updateClusterSettings("knn.feature.query.rewrite.enabled", false); + Response launchControlDisabledResponse = searchKNNIndex(indexName, queryBuilder, k); + String launchControlDisabledResponseString = EntityUtils.toString(launchControlDisabledResponse.getEntity()); + List knnResults = parseSearchResponse(launchControlDisabledResponseString, fieldName); + assertEquals(k, knnResults.size()); + + assertEquals(nativeEngineKnnResults, knnResults); + + // Delete index + deleteKNNIndex(indexName); + + // Search every 5 seconds 14 times to confirm graph gets evicted + int intervals = 14; + for (int i = 0; i < intervals; i++) { + if (getTotalGraphsInCache() == 0) { + return; + } + Thread.sleep(5 * 1000); + } + + fail("Graphs are not getting evicted"); + } +} diff --git a/src/test/java/org/opensearch/knn/index/query/nativelib/NativeEngineKNNVectorQueryTests.java b/src/test/java/org/opensearch/knn/index/query/nativelib/NativeEngineKNNVectorQueryTests.java new file mode 100644 index 000000000..1e4b11a12 --- /dev/null +++ b/src/test/java/org/opensearch/knn/index/query/nativelib/NativeEngineKNNVectorQueryTests.java @@ -0,0 +1,156 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.index.query.nativelib; + +import lombok.SneakyThrows; +import org.apache.lucene.index.IndexReader; +import org.apache.lucene.index.IndexReaderContext; +import org.apache.lucene.index.LeafReader; +import org.apache.lucene.index.LeafReaderContext; +import org.apache.lucene.search.IndexSearcher; +import org.apache.lucene.search.MatchNoDocsQuery; +import org.apache.lucene.search.Query; +import org.apache.lucene.search.ScoreMode; +import org.apache.lucene.search.TaskExecutor; +import org.apache.lucene.search.TopDocs; +import org.apache.lucene.util.Bits; +import org.mockito.ArgumentMatchers; +import org.mockito.InjectMocks; +import org.mockito.Mock; +import org.opensearch.knn.index.query.KNNQuery; +import org.opensearch.knn.index.query.KNNWeight; +import org.opensearch.test.OpenSearchTestCase; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; +import java.util.Map; +import java.util.concurrent.Callable; + +import static org.mockito.ArgumentMatchers.anyInt; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; +import static org.mockito.MockitoAnnotations.openMocks; + +public class NativeEngineKNNVectorQueryTests extends OpenSearchTestCase { + + @Mock + private IndexSearcher searcher; + @Mock + private IndexReader reader; + @Mock + private KNNQuery knnQuery; + @Mock + private KNNWeight knnWeight; + @Mock + private TaskExecutor taskExecutor; + @Mock + private IndexReaderContext indexReaderContext; + @Mock + private LeafReaderContext leaf1; + @Mock + private LeafReaderContext leaf2; + @Mock + private LeafReader leafReader1; + @Mock + private LeafReader leafReader2; + + @InjectMocks + private NativeEngineKnnVectorQuery objectUnderTest; + + @Override + public void setUp() throws Exception { + super.setUp(); + openMocks(this); + + when(leaf1.reader()).thenReturn(leafReader1); + when(leaf2.reader()).thenReturn(leafReader2); + + when(searcher.getIndexReader()).thenReturn(reader); + when(knnQuery.createWeight(searcher, ScoreMode.COMPLETE, 1)).thenReturn(knnWeight); + + when(searcher.getTaskExecutor()).thenReturn(taskExecutor); + when(taskExecutor.invokeAll(ArgumentMatchers.>anyList())).thenAnswer(invocationOnMock -> { + List> callables = invocationOnMock.getArgument(0); + List topDocs = new ArrayList<>(); + for (Callable callable : callables) { + topDocs.add(callable.call()); + } + return topDocs; + }); + + when(reader.getContext()).thenReturn(indexReaderContext); + } + + @SneakyThrows + public void testMultiLeaf() { + // Given + List leaves = List.of(leaf1, leaf2); + when(reader.leaves()).thenReturn(leaves); + + when(knnWeight.searchLeaf(leaf1)).thenReturn(Map.of(0, 1.2f, 1, 5.1f, 2, 2.2f)); + when(knnWeight.searchLeaf(leaf2)).thenReturn(Map.of(4, 3.4f, 3, 5.1f)); + + // Making sure there is deleted docs in one of the segments + Bits liveDocs = mock(Bits.class); + when(leafReader1.getLiveDocs()).thenReturn(liveDocs); + when(leafReader2.getLiveDocs()).thenReturn(null); + + when(liveDocs.get(anyInt())).thenReturn(true); + when(liveDocs.get(2)).thenReturn(false); + when(liveDocs.get(1)).thenReturn(false); + + // k=4 to make sure we get topk results even if docs are deleted/less in one of the leaves + when(knnQuery.getK()).thenReturn(4); + + when(indexReaderContext.id()).thenReturn(1); + int[] expectedDocs = { 0, 3, 4 }; + float[] expectedScores = { 1.2f, 5.1f, 3.4f }; + int[] findSegments = { 0, 1, 3 }; + DocAndScoreQuery expected = new DocAndScoreQuery(4, expectedDocs, expectedScores, findSegments, 1); + + // When + Query actual = objectUnderTest.rewrite(searcher); + + // Then + assertEquals(expected, actual); + } + + @SneakyThrows + public void testSingleLeaf() { + // Given + List leaves = List.of(leaf1); + when(reader.leaves()).thenReturn(leaves); + when(knnWeight.searchLeaf(leaf1)).thenReturn(Map.of(0, 1.2f, 1, 5.1f, 2, 2.2f)); + when(knnQuery.getK()).thenReturn(4); + + when(indexReaderContext.id()).thenReturn(1); + int[] expectedDocs = { 0, 1, 2 }; + float[] expectedScores = { 1.2f, 5.1f, 2.2f }; + int[] findSegments = { 0, 3 }; + DocAndScoreQuery expected = new DocAndScoreQuery(4, expectedDocs, expectedScores, findSegments, 1); + + // When + Query actual = objectUnderTest.rewrite(searcher); + + // Then + assertEquals(expected, actual); + } + + @SneakyThrows + public void testNoMatch() { + // Given + List leaves = List.of(leaf1); + when(reader.leaves()).thenReturn(leaves); + when(knnWeight.searchLeaf(leaf1)).thenReturn(Collections.emptyMap()); + when(knnQuery.getK()).thenReturn(4); + // When + Query actual = objectUnderTest.rewrite(searcher); + + // Then + assertEquals(new MatchNoDocsQuery(), actual); + } +}