Skip to content

Commit

Permalink
WIP: add static metrics to SAI for simple ANN queries
Browse files Browse the repository at this point in the history
  • Loading branch information
michaeljmarshall committed Jan 22, 2025
1 parent 715cdab commit 5740890
Show file tree
Hide file tree
Showing 7 changed files with 89 additions and 5 deletions.
22 changes: 22 additions & 0 deletions src/java/org/apache/cassandra/index/sai/QueryContext.java
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,9 @@ public class QueryContext
private final LongAdder annNodesVisited = new LongAdder();
private float annRerankFloor = 0.0f; // only called from single-threaded setup code

// Running total, in nanoseconds, of the time reported through addRowMaterializationDuration(long).
// NOTE(review): the "add" prefix in the field name looks accidental; consider renaming together with its users.
private long addRowMaterializationDurationNanos = 0L;
// Running total, in nanoseconds, of the time reported through addAnnSearchDuration(long).
private long annSearchDurationNanos = 0L;

private final LongAdder shadowedPrimaryKeyCount = new LongAdder();

// Determines the order of using indexes for filtering and sorting.
Expand Down Expand Up @@ -205,6 +208,15 @@ public long annNodesVisited()
{
return annNodesVisited.longValue();
}
/**
 * @return the total time, in nanoseconds, accumulated so far through
 *         {@link #addRowMaterializationDuration(long)}
 */
public long rowMaterializationDuration()
{
    return this.addRowMaterializationDurationNanos;
}

/**
 * @return the total time, in nanoseconds, accumulated so far through
 *         {@link #addAnnSearchDuration(long)}
 */
public long annSearchDuration()
{
    return this.annSearchDurationNanos;
}

public FilterSortOrder filterSortOrder()
{
Expand Down Expand Up @@ -244,6 +256,16 @@ public void updateAnnRerankFloor(float observedFloor)
annRerankFloor = max(annRerankFloor, observedFloor);
}

/**
 * Adds the given elapsed time to the row materialization total for this query.
 * <p>
 * NOTE(review): this is a plain, non-atomic update of a {@code long}, unlike the
 * {@code LongAdder}-backed counters in this class; confirm callers are single-threaded.
 *
 * @param nanos elapsed time to add, in nanoseconds
 */
public void addRowMaterializationDuration(long nanos)
{
    this.addRowMaterializationDurationNanos = this.addRowMaterializationDurationNanos + nanos;
}

/**
 * Adds the given elapsed time to the ANN search total for this query.
 * <p>
 * NOTE(review): this is a plain, non-atomic update of a {@code long}, unlike the
 * {@code LongAdder}-backed counters in this class; confirm callers are single-threaded.
 *
 * @param nanos elapsed time to add, in nanoseconds
 */
public void addAnnSearchDuration(long nanos)
{
    this.annSearchDurationNanos = this.annSearchDurationNanos + nanos;
}

/**
* Determines the order of filtering and sorting operations.
* Currently used only by vector search.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,15 @@

import java.util.Arrays;
import java.util.Iterator;
import java.util.concurrent.TimeUnit;
import java.util.function.IntConsumer;

import io.github.jbellis.jvector.graph.GraphSearcher;
import io.github.jbellis.jvector.graph.SearchResult;
import io.micrometer.core.instrument.DistributionSummary;
import io.micrometer.core.instrument.Metrics;
import io.micrometer.core.instrument.Timer;
import org.apache.cassandra.index.sai.QueryContext;
import org.apache.cassandra.tracing.Tracing;
import org.apache.cassandra.utils.AbstractIterator;

Expand All @@ -41,16 +46,22 @@ public class AutoResumingNodeScoreIterator extends AbstractIterator<SearchResult
private final int rerankK;
private final boolean inMemory;
private final String source;
private final QueryContext context;
private final IntConsumer nodesVisitedConsumer;
private Iterator<SearchResult.NodeScore> nodeScores;
private int cumulativeNodesVisited;
private int resumes = 0;

// Times each searcher.resume(...) call. The meter name must match the "initial" phase timer
// registered in CassandraDiskAnn ("sai_ann_search") so that the "phase" tag splits a single meter;
// the unit is deliberately not part of the name because the registry appends its own base unit.
private static final Timer annResumeSearchNanos = Metrics.timer("sai_ann_search", "phase", "resume");
// Number of resume calls performed by one iterator; recorded once when the iterator is closed.
private static final DistributionSummary annResumeCount = Metrics.summary("sai_ann_per_index_resume_count");

/**
* Create a new {@link AutoResumingNodeScoreIterator} that iterates over the provided {@link SearchResult}.
* If the {@link SearchResult} is consumed, it retrieves the next {@link SearchResult} until the search returns
* no more results.
* @param searcher the {@link GraphSearcher} to use to resume search.
* @param result the first {@link SearchResult} to iterate over
* @param queryContext the query context to use for the search
* @param nodesVisitedConsumer a consumer that accepts the total number of nodes visited
* @param limit the limit to pass to the {@link GraphSearcher} when resuming search
* @param rerankK the rerankK to pass to the {@link GraphSearcher} when resuming search
Expand All @@ -60,6 +71,7 @@ public class AutoResumingNodeScoreIterator extends AbstractIterator<SearchResult
public AutoResumingNodeScoreIterator(GraphSearcher searcher,
GraphSearcherAccessManager accessManager,
SearchResult result,
QueryContext queryContext,
IntConsumer nodesVisitedConsumer,
int limit,
int rerankK,
Expand All @@ -70,6 +82,7 @@ public AutoResumingNodeScoreIterator(GraphSearcher searcher,
this.accessManager = accessManager;
this.nodeScores = Arrays.stream(result.getNodes()).iterator();
this.cumulativeNodesVisited = result.getVisitedCount();
this.context = queryContext;
this.nodesVisitedConsumer = nodesVisitedConsumer;
this.limit = max(1, limit / 2); // we shouldn't need as many results on resume
this.rerankK = rerankK;
Expand All @@ -83,7 +96,13 @@ protected SearchResult.NodeScore computeNext()
if (nodeScores.hasNext())
return nodeScores.next();

var start = System.nanoTime();
var nextResult = searcher.resume(limit, rerankK);
var duration = System.nanoTime() - start;
annResumeSearchNanos.record(duration, TimeUnit.NANOSECONDS);
context.addAnnSearchDuration(duration);

resumes++;
maybeLogTrace(nextResult);
cumulativeNodesVisited += nextResult.getVisitedCount();
// If the next result is empty, we are done searching.
Expand All @@ -102,6 +121,7 @@ private void maybeLogTrace(SearchResult result)
/**
 * Reports the cumulative number of visited nodes to {@code nodesVisitedConsumer}, records the
 * number of resumes in the {@code annResumeCount} summary, and releases {@code accessManager}.
 * The release happens in a {@code finally} block so that a failure while reporting statistics
 * cannot leave the access manager held.
 */
public void close()
{
    try
    {
        nodesVisitedConsumer.accept(cumulativeNodesVisited);
        annResumeCount.record(resumes);
    }
    finally
    {
        // Always release, even if the consumer or the metrics registry throws.
        accessManager.release();
    }
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@

package org.apache.cassandra.index.sai.disk.vector;

import io.micrometer.core.instrument.Metrics;
import io.micrometer.core.instrument.Timer;
import org.apache.cassandra.index.sai.utils.RowIdWithMeta;
import org.apache.cassandra.index.sai.utils.RowIdWithScore;
import org.apache.cassandra.io.util.FileUtils;
Expand Down Expand Up @@ -81,6 +83,8 @@ public static int compare(RowWithApproximateScore l, RowWithApproximateScore r)
private final int limit;
private int rerankedCount;

// Times each refill of the exact-score queue performed in computeNext().
private static final Timer rerankTimer = Metrics.timer("sai_brute_force_rerank_nanos");

/**
* @param approximateScoreQueue A priority queue of rows and their ordinal ordered by their approximate similarity scores
* @param reranker A function that takes a graph ordinal and returns the exact similarity score
Expand All @@ -105,13 +109,15 @@ public BruteForceRowIdIterator(SortingIterator<RowWithApproximateScore> approxim
protected RowIdWithScore computeNext() {
int consumed = rerankedCount - exactScoreQueue.size();
if (consumed >= limit) {
var timer = Timer.start();
// Refill the exactScoreQueue until it reaches topK exact scores, or the approximate score queue is empty
while (approximateScoreQueue.hasNext() && exactScoreQueue.size() < topK) {
RowWithApproximateScore rowOrdinalScore = approximateScoreQueue.next();
float score = reranker.similarityTo(rowOrdinalScore.ordinal);
exactScoreQueue.add(new RowIdWithScore(rowOrdinalScore.rowId, score));
}
rerankedCount = exactScoreQueue.size();
timer.stop(rerankTimer);
}
RowIdWithScore top = exactScoreQueue.pop();
return top == null ? endOfData() : top;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import java.io.IOException;
import java.util.Arrays;
import java.util.Set;
import java.util.concurrent.TimeUnit;
import java.util.function.IntConsumer;
import javax.annotation.Nullable;

Expand All @@ -41,6 +42,9 @@
import io.github.jbellis.jvector.util.ExplicitThreadLocal;
import io.github.jbellis.jvector.vector.VectorSimilarityFunction;
import io.github.jbellis.jvector.vector.types.VectorFloat;
import io.micrometer.core.instrument.DistributionSummary;
import io.micrometer.core.instrument.Metrics;
import io.micrometer.core.instrument.Timer;
import org.apache.cassandra.index.sai.IndexContext;
import org.apache.cassandra.index.sai.QueryContext;
import org.apache.cassandra.index.sai.SSTableContext;
Expand Down Expand Up @@ -82,6 +86,11 @@ public class CassandraDiskAnn

private final ExplicitThreadLocal<GraphSearcherAccessManager> searchers;

// Distribution of the query context's rerank floor, sampled after each initial graph search.
private static final DistributionSummary annRerankFloor = Metrics.summary("sai_ann_rerank_floor");
// Distribution of the rerankK value passed to each initial graph search.
private static final DistributionSummary annRerankK = Metrics.summary("sai_ann_rerank_k");
// Times the initial searcher.search(...) call. The "phase" tag is coordinated with the
// resume-search timer in AutoResumingNodeScoreIterator, which uses the value "resume".
private static final Timer annInitialSearch = Metrics.timer("sai_ann_search", "phase", "initial");

public CassandraDiskAnn(SSTableContext sstableContext, SegmentMetadata.ComponentMetadataMap componentMetadatas, PerIndexFiles indexFiles, IndexContext context, OrdinalsMapFactory omFactory) throws IOException
{
this.source = sstableContext.sstable().getId();
Expand Down Expand Up @@ -254,9 +263,20 @@ else if (compressedVectors == null)
var rr = view.rerankerFor(queryVector, similarityFunction);
ssp = new SearchScoreProvider(asf, rr);
}

var start = System.nanoTime();
var result = searcher.search(ssp, limit, rerankK, threshold, context.getAnnRerankFloor(), ordinalsMap.ignoringDeleted(acceptBits));
var duration = System.nanoTime() - start;
annInitialSearch.record(duration, TimeUnit.NANOSECONDS);

if (V3OnDiskFormat.ENABLE_RERANK_FLOOR)
context.updateAnnRerankFloor(result.getWorstApproximateScoreInTopK());

// Record temporary metrics.
annRerankFloor.record(context.getAnnRerankFloor());
annRerankK.record(rerankK);
context.addAnnSearchDuration(duration);

Tracing.trace("DiskANN search for {}/{} visited {} nodes, reranked {} to return {} results from {}",
limit, rerankK, result.getVisitedCount(), result.getRerankedCount(), result.getNodes().length, source);
if (threshold > 0)
Expand All @@ -269,7 +289,7 @@ else if (compressedVectors == null)
}
else
{
var nodeScores = new AutoResumingNodeScoreIterator(searcher, graphAccessManager, result, nodesVisitedConsumer, limit, rerankK, false, source.toString());
var nodeScores = new AutoResumingNodeScoreIterator(searcher, graphAccessManager, result, context, nodesVisitedConsumer, limit, rerankK, false, source.toString());
return new NodeScoreToRowIdWithScoreIterator(nodeScores, ordinalsMap.getRowIdsView());
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -341,7 +341,7 @@ public CloseableIterator<SearchResult.NodeScore> search(QueryContext context, Ve
graphAccessManager.release();
return CloseableIterator.wrap(Arrays.stream(result.getNodes()).iterator());
}
return new AutoResumingNodeScoreIterator(searcher, graphAccessManager, result, context::addAnnNodesVisited, limit, rerankK, true, source);
return new AutoResumingNodeScoreIterator(searcher, graphAccessManager, result, context, context::addAnnNodesVisited, limit, rerankK, true, source);
}
catch (Throwable t)
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,11 @@
package org.apache.cassandra.index.sai.metrics;

import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.LongAdder;

import com.codahale.metrics.Counter;
import com.codahale.metrics.Histogram;
import com.codahale.metrics.Timer;
import io.micrometer.core.instrument.DistributionSummary;
import org.apache.cassandra.index.sai.QueryContext;
import org.apache.cassandra.schema.TableMetadata;
import org.apache.cassandra.tracing.Tracing;
Expand Down Expand Up @@ -108,7 +108,11 @@ public class PerQueryMetrics extends AbstractMetrics
private final Histogram postingsSkips;
private final Histogram postingsDecodes;

private final LongAdder annNodesVisited = new LongAdder();
// Assumes there is a static Micrometer MeterRegistry that will subsequently be scraped
private final DistributionSummary annNodesVisited = io.micrometer.core.instrument.Metrics.summary("sai_ann_nodes_visited_per_query");
private final io.micrometer.core.instrument.Timer rowMaterializationDuration;
private final io.micrometer.core.instrument.Timer annSearchDuration;


public PerQueryMetrics(TableMetadata table)
{
Expand All @@ -131,6 +135,9 @@ public PerQueryMetrics(TableMetadata table)
rowsFiltered = Metrics.histogram(createMetricName("RowsFiltered"), false);

shadowedKeysScannedHistogram = Metrics.histogram(createMetricName("ShadowedKeysScannedHistogram"), false);

rowMaterializationDuration = io.micrometer.core.instrument.Metrics.timer("sai_row_materialization_duration");
annSearchDuration = io.micrometer.core.instrument.Metrics.timer("sai_ann_search_duration");
}

private void recordStringIndexCacheMetrics(QueryContext events)
Expand All @@ -149,7 +156,9 @@ private void recordNumericIndexCacheMetrics(QueryContext events)

/**
 * Records the ANN-related values carried by the query context: the number of graph nodes
 * visited, the row materialization duration and the ANN search duration (both in nanoseconds).
 *
 * @param queryContext the context of the query whose values are recorded
 */
private void recordAnnIndexMetrics(QueryContext queryContext)
{
// NOTE(review): this line and the next look like the before/after forms of the same statement
// from the commit diff (add(...) vs record(...)); presumably only the record(...) form should remain - confirm.
annNodesVisited.add(queryContext.annNodesVisited());
annNodesVisited.record(queryContext.annNodesVisited());
// Both durations are read from the context as nanosecond totals.
rowMaterializationDuration.record(queryContext.rowMaterializationDuration(), TimeUnit.NANOSECONDS);
annSearchDuration.record(queryContext.annSearchDuration(), TimeUnit.NANOSECONDS);
}

public void record(QueryContext queryContext)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -457,6 +457,8 @@ public static class ScoreOrderedResultRetriever extends AbstractIterator<Unfilte
private final int softLimit;
private int returnedRowCount = 0;

private long rowMaterializationNanos = 0;

private ScoreOrderedResultRetriever(CloseableIterator<PrimaryKeyWithSortKey> scoredPrimaryKeyIterator,
FilterTree filterTree,
QueryController controller,
Expand Down Expand Up @@ -585,6 +587,8 @@ public UnfilteredRowIterator readAndValidatePartition(PrimaryKey key, List<Prima
if (processedKeys.contains(key))
return null;

long start = System.nanoTime();

try (UnfilteredRowIterator partition = controller.getPartition(key, view, executionController))
{
queryContext.addPartitionsRead(1);
Expand Down Expand Up @@ -619,6 +623,8 @@ public UnfilteredRowIterator readAndValidatePartition(PrimaryKey key, List<Prima
}
}
return isRowValid ? new PrimaryKeyIterator(partition, staticRow, row) : null;
} finally {
rowMaterializationNanos += System.nanoTime() - start;
}
}

Expand Down Expand Up @@ -660,6 +666,7 @@ public void close()
{
FileUtils.closeQuietly(scoredPrimaryKeyIterator);
controller.finish();
queryContext.addRowMaterializationDuration(rowMaterializationNanos);
}
}

Expand Down

0 comments on commit 5740890

Please sign in to comment.