diff --git a/src/main/java/org/opensearch/knn/index/query/KNNQuery.java b/src/main/java/org/opensearch/knn/index/query/KNNQuery.java index 1a03f4b99..3a4201bff 100644 --- a/src/main/java/org/opensearch/knn/index/query/KNNQuery.java +++ b/src/main/java/org/opensearch/knn/index/query/KNNQuery.java @@ -10,6 +10,7 @@ import lombok.EqualsAndHashCode; import lombok.Getter; import lombok.Setter; +import lombok.extern.log4j.Log4j2; import org.apache.lucene.search.BooleanClause; import org.apache.lucene.search.BooleanQuery; import org.apache.lucene.search.FieldExistsQuery; @@ -19,6 +20,7 @@ import org.apache.lucene.search.ScoreMode; import org.apache.lucene.search.Weight; import org.apache.lucene.search.join.BitSetProducer; +import org.opensearch.common.StopWatch; import org.opensearch.knn.index.KNNSettings; import org.opensearch.knn.index.VectorDataType; import org.opensearch.knn.index.query.rescore.RescoreContext; @@ -32,6 +34,7 @@ * Custom KNN query. Query is used for KNNEngine's that create their own custom segment files. These files need to be * loaded and queried in a custom manner throughout the query path. */ +@Log4j2 @Getter @Builder @AllArgsConstructor @@ -45,7 +48,6 @@ public class KNNQuery extends Query { private final String indexName; private final VectorDataType vectorDataType; private final RescoreContext rescoreContext; - @Setter private Query filterQuery; @Getter @@ -53,6 +55,10 @@ public class KNNQuery extends Query { private Float radius; private Context context; + // Note: ideally query should not have to deal with shard level information. Adding it for logging purposes only + // TODO: ThreadContext does not work with logger, remove this from here once its figured out + private int shardId; + public KNNQuery( final String field, final float[] queryVector, @@ -168,7 +174,22 @@ public Weight createWeight(IndexSearcher searcher, ScoreMode scoreMode, float bo if (!KNNSettings.isKNNPluginEnabled()) { throw new IllegalStateException("KNN plugin is disabled. To enable update knn.plugin.enabled to true"); } + StopWatch stopWatch = null; + if (log.isDebugEnabled()) { + stopWatch = new StopWatch().start(); + } + final Weight filterWeight = getFilterWeight(searcher); + if (log.isDebugEnabled() && stopWatch != null) { + stopWatch.stop(); + log.debug( + "Creating filter weight, Shard: [{}], field: [{}] took in nanos: [{}]", + shardId, + field, + stopWatch.totalTime().nanos() + ); + } + if (filterWeight != null) { return new KNNWeight(this, boost, filterWeight); } diff --git a/src/main/java/org/opensearch/knn/index/query/KNNQueryFactory.java b/src/main/java/org/opensearch/knn/index/query/KNNQueryFactory.java index 8e6c97f05..498a1e602 100644 --- a/src/main/java/org/opensearch/knn/index/query/KNNQueryFactory.java +++ b/src/main/java/org/opensearch/knn/index/query/KNNQueryFactory.java @@ -52,9 +52,11 @@ public static Query create(CreateQueryRequest createQueryRequest) { final KNNEngine knnEngine = createQueryRequest.getKnnEngine(); final boolean expandNested = createQueryRequest.getExpandNested().orElse(false); BitSetProducer parentFilter = null; + int shardId = -1; if (createQueryRequest.getContext().isPresent()) { QueryShardContext context = createQueryRequest.getContext().get(); parentFilter = context.getParentFilter(); + shardId = context.getShardId(); } if (parentFilter == null && expandNested) { @@ -93,6 +95,7 @@ public static Query create(CreateQueryRequest createQueryRequest) { .filterQuery(validatedFilterQuery) .vectorDataType(vectorDataType) .rescoreContext(rescoreContext) + .shardId(shardId) .build(); break; default: @@ -106,6 +109,7 @@ public static Query create(CreateQueryRequest createQueryRequest) { .filterQuery(validatedFilterQuery) .vectorDataType(vectorDataType) .rescoreContext(rescoreContext) + .shardId(shardId) .build(); } diff --git a/src/main/java/org/opensearch/knn/index/query/KNNWeight.java b/src/main/java/org/opensearch/knn/index/query/KNNWeight.java index d5912d758..d07d57bda 100644 --- a/src/main/java/org/opensearch/knn/index/query/KNNWeight.java +++ b/src/main/java/org/opensearch/knn/index/query/KNNWeight.java @@ -20,6 +20,8 @@ import org.apache.lucene.util.BitSetIterator; import org.apache.lucene.util.Bits; import org.apache.lucene.util.FixedBitSet; +import org.opensearch.common.Nullable; +import org.opensearch.common.StopWatch; import org.opensearch.common.lucene.Lucene; import org.opensearch.knn.common.FieldInfoExtractor; import org.opensearch.knn.common.KNNConstants; @@ -129,7 +131,13 @@ public Scorer scorer(LeafReaderContext context) throws IOException { * @return A Map of docId to scores for top k results */ public PerLeafResult searchLeaf(LeafReaderContext context, int k) throws IOException { + final SegmentReader reader = Lucene.segmentReader(context.reader()); + final String segmentName = reader.getSegmentName(); + + StopWatch stopWatch = startStopWatch(); final BitSet filterBitSet = getFilteredDocsBitSet(context); + stopStopWatchAndLog(stopWatch, "FilterBitSet creation", segmentName); + final int maxDoc = context.reader().maxDoc(); int cardinality = filterBitSet.cardinality(); // We don't need to go to JNI layer if no documents are found which satisfy the filters @@ -153,7 +161,10 @@ public PerLeafResult searchLeaf(LeafReaderContext context, int k) throws IOExcep * so that it will not do a bitset look up in bottom search layer. */ final BitSet annFilter = (filterWeight != null && cardinality == maxDoc) ? null : filterBitSet; - final Map docIdsToScoreMap = doANNSearch(context, annFilter, cardinality, k); + + StopWatch annStopWatch = startStopWatch(); + final Map docIdsToScoreMap = doANNSearch(reader, context, annFilter, cardinality, k); + stopStopWatchAndLog(annStopWatch, "ANN search", segmentName); // See whether we have to perform exact search based on approx search results // This is required if there are no native engine files or if approximate search returned @@ -166,6 +177,14 @@ public PerLeafResult searchLeaf(LeafReaderContext context, int k) throws IOExcep return new PerLeafResult(filterWeight == null ? null : filterBitSet, docIdsToScoreMap); } + private void stopStopWatchAndLog(@Nullable final StopWatch stopWatch, final String prefixMessage, String segmentName) { + if (log.isDebugEnabled() && stopWatch != null) { + stopWatch.stop(); + final String logMessage = prefixMessage + " shard: [{}], segment: [{}], field: [{}], time in nanos:[{}] "; + log.debug(logMessage, knnQuery.getShardId(), segmentName, knnQuery.getField(), stopWatch.totalTime().nanos()); + } + } + private BitSet getFilteredDocsBitSet(final LeafReaderContext ctx) throws IOException { if (this.filterWeight == null) { return new FixedBitSet(0); @@ -222,7 +241,7 @@ private Map doExactSearch( final LeafReaderContext context, final DocIdSetIterator acceptedDocs, final long numberOfAcceptedDocs, - int k + final int k ) throws IOException { final ExactSearcherContextBuilder exactSearcherContextBuilder = ExactSearcher.ExactSearcherContext.builder() .isParentHits(true) @@ -237,13 +256,12 @@ private Map doExactSearch( } private Map doANNSearch( + final SegmentReader reader, final LeafReaderContext context, final BitSet filterIdsBitSet, final int cardinality, final int k ) throws IOException { - final SegmentReader reader = Lucene.segmentReader(context.reader()); - FieldInfo fieldInfo = FieldInfoExtractor.getFieldInfo(reader, knnQuery.getField()); if (fieldInfo == null) { @@ -402,7 +420,11 @@ public Map exactSearch( final LeafReaderContext leafReaderContext, final ExactSearcher.ExactSearcherContext exactSearcherContext ) throws IOException { - return exactSearcher.searchLeaf(leafReaderContext, exactSearcherContext); + StopWatch stopWatch = startStopWatch(); + Map exactSearchResults = exactSearcher.searchLeaf(leafReaderContext, exactSearcherContext); + final SegmentReader reader = Lucene.segmentReader(leafReaderContext.reader()); + stopStopWatchAndLog(stopWatch, "Exact search", reader.getSegmentName()); + return exactSearchResults; } @Override @@ -523,4 +545,11 @@ private boolean isMissingNativeEngineFiles(LeafReaderContext context) { ); return engineFiles.isEmpty(); } + + private StopWatch startStopWatch() { + if (log.isDebugEnabled()) { + return new StopWatch().start(); + } + return null; + } }