Adds debug logs for KNNQuery and KNNWeight (#2466)
* Adds debug logs for KNNQuery and KNNWeight

Signed-off-by: Tejas Shah <[email protected]>

* Adds check to see if log is enabled to start and stop StopWatch

Signed-off-by: Tejas Shah <[email protected]>

* Addressing comments on the PR

Signed-off-by: Tejas Shah <[email protected]>

* Adds shard and segment info in the logs

Signed-off-by: Tejas Shah <[email protected]>

* Removes unnecessary segment name param from exact search

Signed-off-by: Tejas Shah <[email protected]>

* Fixes the build

Signed-off-by: Tejas Shah <[email protected]>

---------

Signed-off-by: Tejas Shah <[email protected]>
(cherry picked from commit f322e27)
shatejas authored and github-actions[bot] committed Jan 30, 2025
1 parent 135f7eb commit bc88c29
Showing 3 changed files with 60 additions and 6 deletions.
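
The second and fourth bullets of the commit message describe the core pattern of the change: a StopWatch is created only when debug logging is enabled, and the elapsed time is logged together with shard, segment, and field information. The sketch below condenses that pattern into a standalone class; DebugTimedSearch, timedWork, and their parameters are illustrative names, while the StopWatch and logging calls mirror the diffs that follow (the real classes get their log field from Lombok's @Log4j2 rather than LogManager).

import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.opensearch.common.StopWatch;

// Illustrative standalone version of the debug-gated timing pattern applied in this commit.
final class DebugTimedSearch {
    private static final Logger log = LogManager.getLogger(DebugTimedSearch.class);

    // Start a StopWatch only when debug logging is enabled; otherwise skip the overhead entirely.
    private static StopWatch startStopWatch() {
        return log.isDebugEnabled() ? new StopWatch().start() : null;
    }

    // Stop the watch and emit the timing, but only if one was actually started.
    private static void stopStopWatchAndLog(StopWatch stopWatch, String prefix, int shardId, String segment, String field) {
        if (log.isDebugEnabled() && stopWatch != null) {
            stopWatch.stop();
            log.debug(
                prefix + " shard: [{}], segment: [{}], field: [{}], time in nanos: [{}]",
                shardId,
                segment,
                field,
                stopWatch.totalTime().nanos()
            );
        }
    }

    // Example call site: time a unit of work and log it at DEBUG with its shard/segment/field context.
    static void timedWork(int shardId, String segment, String field, Runnable work) {
        StopWatch stopWatch = startStopWatch();
        work.run();
        stopStopWatchAndLog(stopWatch, "ANN search", shardId, segment, field);
    }
}
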
23 changes: 22 additions & 1 deletion src/main/java/org/opensearch/knn/index/query/KNNQuery.java
@@ -10,6 +10,7 @@
import lombok.EqualsAndHashCode;
import lombok.Getter;
import lombok.Setter;
import lombok.extern.log4j.Log4j2;
import org.apache.lucene.search.BooleanClause;
import org.apache.lucene.search.BooleanQuery;
import org.apache.lucene.search.FieldExistsQuery;
@@ -19,6 +20,7 @@
import org.apache.lucene.search.ScoreMode;
import org.apache.lucene.search.Weight;
import org.apache.lucene.search.join.BitSetProducer;
import org.opensearch.common.StopWatch;
import org.opensearch.knn.index.KNNSettings;
import org.opensearch.knn.index.VectorDataType;
import org.opensearch.knn.index.query.rescore.RescoreContext;
@@ -32,6 +34,7 @@
* Custom KNN query. Query is used for KNNEngine's that create their own custom segment files. These files need to be
* loaded and queried in a custom manner throughout the query path.
*/
@Log4j2
@Getter
@Builder
@AllArgsConstructor
@@ -45,14 +48,17 @@ public class KNNQuery extends Query {
private final String indexName;
private final VectorDataType vectorDataType;
private final RescoreContext rescoreContext;

@Setter
private Query filterQuery;
@Getter
private BitSetProducer parentsFilter;
private Float radius;
private Context context;

// Note: ideally query should not have to deal with shard level information. Adding it for logging purposes only
// TODO: ThreadContext does not work with logger, remove this from here once its figured out
private int shardId;

public KNNQuery(
final String field,
final float[] queryVector,
@@ -168,7 +174,22 @@ public Weight createWeight(IndexSearcher searcher, ScoreMode scoreMode, float bo
if (!KNNSettings.isKNNPluginEnabled()) {
throw new IllegalStateException("KNN plugin is disabled. To enable update knn.plugin.enabled to true");
}
StopWatch stopWatch = null;
if (log.isDebugEnabled()) {
stopWatch = new StopWatch().start();
}

final Weight filterWeight = getFilterWeight(searcher);
if (log.isDebugEnabled() && stopWatch != null) {
stopWatch.stop();
log.debug(
"Creating filter weight, Shard: [{}], field: [{}] took in nanos: [{}]",
shardId,
field,
stopWatch.totalTime().nanos()
);
}

if (filterWeight != null) {
return new KNNWeight(this, boost, filterWeight);
}
4 changes: 4 additions & 0 deletions src/main/java/org/opensearch/knn/index/query/KNNQueryFactory.java
@@ -52,9 +52,11 @@ public static Query create(CreateQueryRequest createQueryRequest) {
final KNNEngine knnEngine = createQueryRequest.getKnnEngine();
final boolean expandNested = createQueryRequest.getExpandNested().orElse(false);
BitSetProducer parentFilter = null;
int shardId = -1;
if (createQueryRequest.getContext().isPresent()) {
QueryShardContext context = createQueryRequest.getContext().get();
parentFilter = context.getParentFilter();
shardId = context.getShardId();
}

if (parentFilter == null && expandNested) {
@@ -93,6 +95,7 @@ public static Query create(CreateQueryRequest createQueryRequest) {
.filterQuery(validatedFilterQuery)
.vectorDataType(vectorDataType)
.rescoreContext(rescoreContext)
.shardId(shardId)
.build();
break;
default:
@@ -106,6 +109,7 @@ public static Query create(CreateQueryRequest createQueryRequest) {
.filterQuery(validatedFilterQuery)
.vectorDataType(vectorDataType)
.rescoreContext(rescoreContext)
.shardId(shardId)
.build();
}

39 changes: 34 additions & 5 deletions src/main/java/org/opensearch/knn/index/query/KNNWeight.java
@@ -20,6 +20,8 @@
import org.apache.lucene.util.BitSetIterator;
import org.apache.lucene.util.Bits;
import org.apache.lucene.util.FixedBitSet;
import org.opensearch.common.Nullable;
import org.opensearch.common.StopWatch;
import org.opensearch.common.lucene.Lucene;
import org.opensearch.knn.common.FieldInfoExtractor;
import org.opensearch.knn.common.KNNConstants;
@@ -128,7 +130,13 @@ public Scorer scorer(LeafReaderContext context) throws IOException {
* @return A Map of docId to scores for top k results
*/
public PerLeafResult searchLeaf(LeafReaderContext context, int k) throws IOException {
final SegmentReader reader = Lucene.segmentReader(context.reader());
final String segmentName = reader.getSegmentName();

StopWatch stopWatch = startStopWatch();
final BitSet filterBitSet = getFilteredDocsBitSet(context);
stopStopWatchAndLog(stopWatch, "FilterBitSet creation", segmentName);

final int maxDoc = context.reader().maxDoc();
int cardinality = filterBitSet.cardinality();
// We don't need to go to JNI layer if no documents are found which satisfy the filters
@@ -152,7 +160,10 @@ public PerLeafResult searchLeaf(LeafReaderContext context, int k) throws IOExcep
* so that it will not do a bitset look up in bottom search layer.
*/
final BitSet annFilter = (filterWeight != null && cardinality == maxDoc) ? null : filterBitSet;
final Map<Integer, Float> docIdsToScoreMap = doANNSearch(context, annFilter, cardinality, k);

StopWatch annStopWatch = startStopWatch();
final Map<Integer, Float> docIdsToScoreMap = doANNSearch(reader, context, annFilter, cardinality, k);
stopStopWatchAndLog(annStopWatch, "ANN search", segmentName);

// See whether we have to perform exact search based on approx search results
// This is required if there are no native engine files or if approximate search returned
@@ -165,6 +176,14 @@ public PerLeafResult searchLeaf(LeafReaderContext context, int k) throws IOExcep
return new PerLeafResult(filterWeight == null ? null : filterBitSet, docIdsToScoreMap);
}

private void stopStopWatchAndLog(@Nullable final StopWatch stopWatch, final String prefixMessage, String segmentName) {
if (log.isDebugEnabled() && stopWatch != null) {
stopWatch.stop();
final String logMessage = prefixMessage + " shard: [{}], segment: [{}], field: [{}], time in nanos:[{}] ";
log.debug(logMessage, knnQuery.getShardId(), segmentName, knnQuery.getField(), stopWatch.totalTime().nanos());
}
}

private BitSet getFilteredDocsBitSet(final LeafReaderContext ctx) throws IOException {
if (this.filterWeight == null) {
return new FixedBitSet(0);
@@ -221,7 +240,7 @@ private Map<Integer, Float> doExactSearch(
final LeafReaderContext context,
final DocIdSetIterator acceptedDocs,
final long numberOfAcceptedDocs,
int k
final int k
) throws IOException {
final ExactSearcherContextBuilder exactSearcherContextBuilder = ExactSearcher.ExactSearcherContext.builder()
.isParentHits(true)
@@ -236,13 +255,12 @@ private Map<Integer, Float> doExactSearch(
}

private Map<Integer, Float> doANNSearch(
final SegmentReader reader,
final LeafReaderContext context,
final BitSet filterIdsBitSet,
final int cardinality,
final int k
) throws IOException {
final SegmentReader reader = Lucene.segmentReader(context.reader());

FieldInfo fieldInfo = FieldInfoExtractor.getFieldInfo(reader, knnQuery.getField());

if (fieldInfo == null) {
@@ -401,7 +419,11 @@ public Map<Integer, Float> exactSearch(
final LeafReaderContext leafReaderContext,
final ExactSearcher.ExactSearcherContext exactSearcherContext
) throws IOException {
return exactSearcher.searchLeaf(leafReaderContext, exactSearcherContext);
StopWatch stopWatch = startStopWatch();
Map<Integer, Float> exactSearchResults = exactSearcher.searchLeaf(leafReaderContext, exactSearcherContext);
final SegmentReader reader = Lucene.segmentReader(leafReaderContext.reader());
stopStopWatchAndLog(stopWatch, "Exact search", reader.getSegmentName());
return exactSearchResults;
}

@Override
@@ -508,4 +530,11 @@ private boolean isMissingNativeEngineFiles(LeafReaderContext context) {
);
return engineFiles.isEmpty();
}

private StopWatch startStopWatch() {
if (log.isDebugEnabled()) {
return new StopWatch().start();
}
return null;
}
}
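
The new timings are emitted at DEBUG, so nothing changes at the default log level. In a running cluster the level for org.opensearch.knn.index.query would normally be raised through OpenSearch's logging configuration or cluster settings; the snippet below is only a minimal, self-contained illustration of doing the same programmatically with Log4j2 (it assumes log4j-core is on the classpath, e.g. in a test).

import org.apache.logging.log4j.Level;
import org.apache.logging.log4j.core.config.Configurator;

public final class EnableKnnQueryDebugLogging {
    public static void main(String[] args) {
        // Raise the package containing KNNQuery and KNNWeight to DEBUG so the
        // StopWatch timings added in this commit are written to the log.
        Configurator.setLevel("org.opensearch.knn.index.query", Level.DEBUG);
    }
}
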
