diff --git a/src/main/java/org/opensearch/knn/index/query/KNNQueryFactory.java b/src/main/java/org/opensearch/knn/index/query/KNNQueryFactory.java index dab2e08c8..ee03a73c4 100644 --- a/src/main/java/org/opensearch/knn/index/query/KNNQueryFactory.java +++ b/src/main/java/org/opensearch/knn/index/query/KNNQueryFactory.java @@ -15,6 +15,7 @@ import org.opensearch.index.query.QueryShardContext; import org.opensearch.knn.index.VectorDataType; import org.opensearch.knn.index.engine.KNNEngine; +import org.opensearch.knn.index.query.lucene.LuceneEngineKnnVectorQuery; import org.opensearch.knn.index.query.nativelib.NativeEngineKnnVectorQuery; import org.opensearch.knn.index.query.rescore.RescoreContext; @@ -106,9 +107,9 @@ public static Query create(CreateQueryRequest createQueryRequest) { log.debug(String.format("Creating Lucene k-NN query for index: %s \"\", field: %s \"\", k: %d", indexName, fieldName, k)); switch (vectorDataType) { case BYTE: - return getKnnByteVectorQuery(fieldName, byteVector, luceneK, filterQuery, parentFilter); + return new LuceneEngineKnnVectorQuery(getKnnByteVectorQuery(fieldName, byteVector, luceneK, filterQuery, parentFilter)); case FLOAT: - return getKnnFloatVectorQuery(fieldName, vector, luceneK, filterQuery, parentFilter); + return new LuceneEngineKnnVectorQuery(getKnnFloatVectorQuery(fieldName, vector, luceneK, filterQuery, parentFilter)); default: throw new IllegalArgumentException( String.format( diff --git a/src/main/java/org/opensearch/knn/index/query/lucene/LuceneEngineKnnVectorQuery.java b/src/main/java/org/opensearch/knn/index/query/lucene/LuceneEngineKnnVectorQuery.java new file mode 100644 index 000000000..04972eab4 --- /dev/null +++ b/src/main/java/org/opensearch/knn/index/query/lucene/LuceneEngineKnnVectorQuery.java @@ -0,0 +1,67 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.index.query.lucene; + +import lombok.AllArgsConstructor; +import lombok.extern.log4j.Log4j2; +import org.apache.lucene.search.IndexSearcher; +import org.apache.lucene.search.Query; +import org.apache.lucene.search.QueryVisitor; +import org.apache.lucene.search.ScoreMode; +import org.apache.lucene.search.Weight; + +import java.io.IOException; + +/** + * LuceneEngineKnnVectorQuery is a wrapper around a vector queries for the Lucene engine. + * This enables us to defer rewrites until weight creation to optimize repeated execution + * of Lucene based k-NN queries. + */ +@AllArgsConstructor +@Log4j2 +public class LuceneEngineKnnVectorQuery extends Query { + private final Query luceneQuery; + + /* + Prevents repeated rewrites of the query for the Lucene engine. + */ + @Override + public Query rewrite(IndexSearcher indexSearcher) { + return this; + } + + /* + Rewrites the query just before weight creation. + */ + @Override + public Weight createWeight(IndexSearcher searcher, ScoreMode scoreMode, float boost) throws IOException { + Query rewrittenQuery = luceneQuery.rewrite(searcher); + return rewrittenQuery.createWeight(searcher, scoreMode, boost); + } + + @Override + public String toString(String s) { + return luceneQuery.toString(); + } + + @Override + public void visit(QueryVisitor queryVisitor) { + queryVisitor.visitLeaf(this); + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + LuceneEngineKnnVectorQuery otherQuery = (LuceneEngineKnnVectorQuery) o; + return luceneQuery.equals(otherQuery.luceneQuery); + } + + @Override + public int hashCode() { + return luceneQuery.hashCode(); + } +} diff --git a/src/test/java/org/opensearch/knn/index/query/lucene/LuceneEngineKnnVectorQueryTests.java b/src/test/java/org/opensearch/knn/index/query/lucene/LuceneEngineKnnVectorQueryTests.java new file mode 100644 index 000000000..10be94890 --- /dev/null +++ b/src/test/java/org/opensearch/knn/index/query/lucene/LuceneEngineKnnVectorQueryTests.java @@ -0,0 +1,95 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.index.query.lucene; + +import org.apache.lucene.search.IndexSearcher; +import org.apache.lucene.search.Query; +import org.apache.lucene.search.QueryVisitor; +import org.apache.lucene.search.ScoreMode; +import org.apache.lucene.search.Weight; +import org.mockito.InjectMocks; +import org.mockito.Mock; +import org.mockito.Spy; +import org.opensearch.test.OpenSearchTestCase; + +import static org.mockito.ArgumentMatchers.*; +import static org.mockito.Mockito.verifyNoInteractions; +import static org.mockito.Mockito.when; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.MockitoAnnotations.openMocks; + +public class LuceneEngineKnnVectorQueryTests extends OpenSearchTestCase { + + @Mock + IndexSearcher indexSearcher; + + @Mock + Query luceneQuery; + + @Mock + Weight weight; + + @Mock + QueryVisitor queryVisitor; + + @Spy + @InjectMocks + LuceneEngineKnnVectorQuery objectUnderTest; + + @Override + public void setUp() throws Exception { + super.setUp(); + openMocks(this); + when(luceneQuery.rewrite(any(IndexSearcher.class))).thenReturn(luceneQuery); + when(luceneQuery.createWeight(any(IndexSearcher.class), any(ScoreMode.class), anyFloat())).thenReturn(weight); + } + + public void testRewrite() { + objectUnderTest.rewrite(indexSearcher); + objectUnderTest.rewrite(indexSearcher); + objectUnderTest.rewrite(indexSearcher); + verifyNoInteractions(luceneQuery); + verify(objectUnderTest, times(3)).rewrite(indexSearcher); + } + + public void testCreateWeight() throws Exception { + objectUnderTest.rewrite(indexSearcher); + objectUnderTest.rewrite(indexSearcher); + objectUnderTest.rewrite(indexSearcher); + verifyNoInteractions(luceneQuery); + Weight actualWeight = objectUnderTest.createWeight(indexSearcher, ScoreMode.TOP_DOCS, 1.0f); + verify(luceneQuery, times(1)).rewrite(indexSearcher); + verify(objectUnderTest, times(3)).rewrite(indexSearcher); + assertEquals(weight, actualWeight); + } + + public void testVisit() { + objectUnderTest.visit(queryVisitor); + verify(queryVisitor).visitLeaf(objectUnderTest); + } + + public void testEquals() { + LuceneEngineKnnVectorQuery mainQuery = new LuceneEngineKnnVectorQuery(luceneQuery); + LuceneEngineKnnVectorQuery otherQuery = new LuceneEngineKnnVectorQuery(luceneQuery); + assertEquals(mainQuery, otherQuery); + assertEquals(mainQuery, mainQuery); + assertNotEquals(mainQuery, null); + assertNotEquals(mainQuery, new Object()); + LuceneEngineKnnVectorQuery otherQuery2 = new LuceneEngineKnnVectorQuery(null); + assertNotEquals(mainQuery, otherQuery2); + } + + public void testHashCode() { + LuceneEngineKnnVectorQuery mainQuery = new LuceneEngineKnnVectorQuery(luceneQuery); + assertEquals(mainQuery.hashCode(), luceneQuery.hashCode()); + } + + public void testToString() { + LuceneEngineKnnVectorQuery mainQuery = new LuceneEngineKnnVectorQuery(luceneQuery); + assertEquals(mainQuery.toString(), luceneQuery.toString()); + } +}