Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Backport 2.19] Adds debug logs for KNNQuery and KNNWeight #2471

Merged
merged 1 commit into from
Jan 30, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 22 additions & 1 deletion src/main/java/org/opensearch/knn/index/query/KNNQuery.java
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import lombok.EqualsAndHashCode;
import lombok.Getter;
import lombok.Setter;
import lombok.extern.log4j.Log4j2;
import org.apache.lucene.search.BooleanClause;
import org.apache.lucene.search.BooleanQuery;
import org.apache.lucene.search.FieldExistsQuery;
Expand All @@ -19,6 +20,7 @@
import org.apache.lucene.search.ScoreMode;
import org.apache.lucene.search.Weight;
import org.apache.lucene.search.join.BitSetProducer;
import org.opensearch.common.StopWatch;
import org.opensearch.knn.index.KNNSettings;
import org.opensearch.knn.index.VectorDataType;
import org.opensearch.knn.index.query.rescore.RescoreContext;
Expand All @@ -32,6 +34,7 @@
* Custom KNN query. Query is used for KNNEngines that create their own custom segment files. These files need to be
* loaded and queried in a custom manner throughout the query path.
*/
@Log4j2
@Getter
@Builder
@AllArgsConstructor
Expand All @@ -45,14 +48,17 @@ public class KNNQuery extends Query {
private final String indexName;
private final VectorDataType vectorDataType;
private final RescoreContext rescoreContext;

@Setter
private Query filterQuery;
@Getter
private BitSetProducer parentsFilter;
private Float radius;
private Context context;

// Note: ideally query should not have to deal with shard level information. Adding it for logging purposes only
// TODO: ThreadContext does not work with logger; remove this from here once it's figured out
private int shardId;

public KNNQuery(
final String field,
final float[] queryVector,
Expand Down Expand Up @@ -168,7 +174,22 @@ public Weight createWeight(IndexSearcher searcher, ScoreMode scoreMode, float bo
if (!KNNSettings.isKNNPluginEnabled()) {
throw new IllegalStateException("KNN plugin is disabled. To enable update knn.plugin.enabled to true");
}
StopWatch stopWatch = null;
if (log.isDebugEnabled()) {
stopWatch = new StopWatch().start();
}

final Weight filterWeight = getFilterWeight(searcher);
if (log.isDebugEnabled() && stopWatch != null) {
stopWatch.stop();
log.debug(
"Creating filter weight, Shard: [{}], field: [{}] took in nanos: [{}]",
shardId,
field,
stopWatch.totalTime().nanos()
);
}

if (filterWeight != null) {
return new KNNWeight(this, boost, filterWeight);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,9 +52,11 @@ public static Query create(CreateQueryRequest createQueryRequest) {
final KNNEngine knnEngine = createQueryRequest.getKnnEngine();
final boolean expandNested = createQueryRequest.getExpandNested().orElse(false);
BitSetProducer parentFilter = null;
int shardId = -1;
if (createQueryRequest.getContext().isPresent()) {
QueryShardContext context = createQueryRequest.getContext().get();
parentFilter = context.getParentFilter();
shardId = context.getShardId();
}

if (parentFilter == null && expandNested) {
Expand Down Expand Up @@ -93,6 +95,7 @@ public static Query create(CreateQueryRequest createQueryRequest) {
.filterQuery(validatedFilterQuery)
.vectorDataType(vectorDataType)
.rescoreContext(rescoreContext)
.shardId(shardId)
.build();
break;
default:
Expand All @@ -106,6 +109,7 @@ public static Query create(CreateQueryRequest createQueryRequest) {
.filterQuery(validatedFilterQuery)
.vectorDataType(vectorDataType)
.rescoreContext(rescoreContext)
.shardId(shardId)
.build();
}

Expand Down
39 changes: 34 additions & 5 deletions src/main/java/org/opensearch/knn/index/query/KNNWeight.java
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@
import org.apache.lucene.util.BitSetIterator;
import org.apache.lucene.util.Bits;
import org.apache.lucene.util.FixedBitSet;
import org.opensearch.common.Nullable;
import org.opensearch.common.StopWatch;
import org.opensearch.common.lucene.Lucene;
import org.opensearch.knn.common.FieldInfoExtractor;
import org.opensearch.knn.common.KNNConstants;
Expand Down Expand Up @@ -129,7 +131,13 @@ public Scorer scorer(LeafReaderContext context) throws IOException {
* @return A Map of docId to scores for top k results
*/
public PerLeafResult searchLeaf(LeafReaderContext context, int k) throws IOException {
final SegmentReader reader = Lucene.segmentReader(context.reader());
final String segmentName = reader.getSegmentName();

StopWatch stopWatch = startStopWatch();
final BitSet filterBitSet = getFilteredDocsBitSet(context);
stopStopWatchAndLog(stopWatch, "FilterBitSet creation", segmentName);

final int maxDoc = context.reader().maxDoc();
int cardinality = filterBitSet.cardinality();
// We don't need to go to JNI layer if no documents are found which satisfy the filters
Expand All @@ -153,7 +161,10 @@ public PerLeafResult searchLeaf(LeafReaderContext context, int k) throws IOExcep
* so that it will not do a bitset look up in bottom search layer.
*/
final BitSet annFilter = (filterWeight != null && cardinality == maxDoc) ? null : filterBitSet;
final Map<Integer, Float> docIdsToScoreMap = doANNSearch(context, annFilter, cardinality, k);

StopWatch annStopWatch = startStopWatch();
final Map<Integer, Float> docIdsToScoreMap = doANNSearch(reader, context, annFilter, cardinality, k);
stopStopWatchAndLog(annStopWatch, "ANN search", segmentName);

// See whether we have to perform exact search based on approx search results
// This is required if there are no native engine files or if approximate search returned
Expand All @@ -166,6 +177,14 @@ public PerLeafResult searchLeaf(LeafReaderContext context, int k) throws IOExcep
return new PerLeafResult(filterWeight == null ? null : filterBitSet, docIdsToScoreMap);
}

/**
 * Stops the supplied stop watch and writes a debug log line with the elapsed time.
 * Does nothing when debug logging is disabled or when no stop watch was supplied.
 *
 * @param stopWatch     the running stop watch to stop; may be {@code null}, in which case nothing is logged
 * @param prefixMessage text placed at the start of the log message, ahead of the shard/segment/field details
 * @param segmentName   name of the segment reported in the log line
 */
private void stopStopWatchAndLog(@Nullable final StopWatch stopWatch, final String prefixMessage, String segmentName) {
    // Guard clause: skip both the stop and the log call unless there is something to report.
    if (!log.isDebugEnabled() || stopWatch == null) {
        return;
    }
    stopWatch.stop();
    log.debug(
        prefixMessage + " shard: [{}], segment: [{}], field: [{}], time in nanos:[{}] ",
        knnQuery.getShardId(),
        segmentName,
        knnQuery.getField(),
        stopWatch.totalTime().nanos()
    );
}

private BitSet getFilteredDocsBitSet(final LeafReaderContext ctx) throws IOException {
if (this.filterWeight == null) {
return new FixedBitSet(0);
Expand Down Expand Up @@ -222,7 +241,7 @@ private Map<Integer, Float> doExactSearch(
final LeafReaderContext context,
final DocIdSetIterator acceptedDocs,
final long numberOfAcceptedDocs,
int k
final int k
) throws IOException {
final ExactSearcherContextBuilder exactSearcherContextBuilder = ExactSearcher.ExactSearcherContext.builder()
.isParentHits(true)
Expand All @@ -237,13 +256,12 @@ private Map<Integer, Float> doExactSearch(
}

private Map<Integer, Float> doANNSearch(
final SegmentReader reader,
final LeafReaderContext context,
final BitSet filterIdsBitSet,
final int cardinality,
final int k
) throws IOException {
final SegmentReader reader = Lucene.segmentReader(context.reader());

FieldInfo fieldInfo = FieldInfoExtractor.getFieldInfo(reader, knnQuery.getField());

if (fieldInfo == null) {
Expand Down Expand Up @@ -402,7 +420,11 @@ public Map<Integer, Float> exactSearch(
final LeafReaderContext leafReaderContext,
final ExactSearcher.ExactSearcherContext exactSearcherContext
) throws IOException {
return exactSearcher.searchLeaf(leafReaderContext, exactSearcherContext);
StopWatch stopWatch = startStopWatch();
Map<Integer, Float> exactSearchResults = exactSearcher.searchLeaf(leafReaderContext, exactSearcherContext);
final SegmentReader reader = Lucene.segmentReader(leafReaderContext.reader());
stopStopWatchAndLog(stopWatch, "Exact search", reader.getSegmentName());
return exactSearchResults;
}

@Override
Expand Down Expand Up @@ -523,4 +545,11 @@ private boolean isMissingNativeEngineFiles(LeafReaderContext context) {
);
return engineFiles.isEmpty();
}

/**
 * Creates and starts a stop watch, but only when debug logging is enabled.
 *
 * @return a started {@link StopWatch} when debug logging is enabled, otherwise {@code null}
 */
private StopWatch startStopWatch() {
    // Avoid allocating a stop watch when the timing would never be logged.
    return log.isDebugEnabled() ? new StopWatch().start() : null;
}
}
Loading