Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add GraphIndexBuilder.rescore() for use by C* CompactionGraph #375

Merged
merged 7 commits into from
Dec 23, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -162,16 +162,23 @@ public GraphIndexBuilder(BuildScoreProvider scoreProvider,
ForkJoinPool simdExecutor,
ForkJoinPool parallelExecutor)
{
this.scoreProvider = scoreProvider;
this.dimension = dimension;
this.neighborOverflow = neighborOverflow;
this.alpha = alpha;
if (M <= 0) {
throw new IllegalArgumentException("maxConn must be positive");
}
if (beamWidth <= 0) {
throw new IllegalArgumentException("beamWidth must be positive");
}
if (neighborOverflow < 1.0f) {
throw new IllegalArgumentException("neighborOverflow must be >= 1.0");
}
if (alpha <= 0) {
throw new IllegalArgumentException("alpha must be positive");
}

this.scoreProvider = scoreProvider;
this.dimension = dimension;
this.neighborOverflow = neighborOverflow;
this.alpha = alpha;
this.beamWidth = beamWidth;
this.simdExecutor = simdExecutor;
this.parallelExecutor = parallelExecutor;
Expand All @@ -185,6 +192,42 @@ public GraphIndexBuilder(BuildScoreProvider scoreProvider,
this.concurrentScratch = ExplicitThreadLocal.withInitial(() -> new NodeArray(Math.max(beamWidth, M + 1)));
}

/**
 * Creates a new builder whose graph has the same nodes, edges and entry node as the graph of
 * {@code other}, with every edge score recomputed by {@code newProvider}.
 *
 * The new builder is constructed with {@code other}'s dimension, beam width, neighbor overflow,
 * alpha and executors; its max degree is taken from {@code other}'s graph.
 *
 * @param other the builder whose graph is copied
 * @param newProvider the score provider used both by the new builder and to re-score copied edges
 * @return the new builder holding the re-scored copy of the graph
 */
public static GraphIndexBuilder rescore(GraphIndexBuilder other, BuildScoreProvider newProvider) {
    var source = other.graph;
    var rescored = new GraphIndexBuilder(newProvider,
                                         other.dimension,
                                         source.maxDegree(),
                                         other.beamWidth,
                                         other.neighborOverflow,
                                         other.alpha,
                                         other.simdExecutor,
                                         other.parallelExecutor);

    // Walk every possible node id; ids that are not present in the source graph are skipped
    for (int node = 0; node < source.getIdUpperBound(); node++) {
        if (source.containsNode(node)) {
            var oldEdges = source.getNeighbors(node);
            var scorer = newProvider.searchProviderFor(node).scoreFunction();
            var rescoredEdges = new NodeArray(oldEdges.size());

            // Scores from the new provider need not follow the old ordering, so each edge
            // goes through insertSorted rather than addInOrder
            var edgeIt = oldEdges.iterator();
            while (edgeIt.hasNext()) {
                int target = edgeIt.nextInt();
                rescoredEdges.insertSorted(target, scorer.similarityTo(target));
            }

            rescored.graph.addNode(node, rescoredEdges);
        }
    }

    // The copy starts searches from the same node as the source graph
    rescored.graph.updateEntryNode(source.entry());

    return rescored;
}

public OnHeapGraphIndex build(RandomAccessVectorValues ravv) {
var vv = ravv.threadLocalSupplier();
int size = ravv.size();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
package io.github.jbellis.jvector.graph.similarity;

import io.github.jbellis.jvector.graph.RandomAccessVectorValues;
import io.github.jbellis.jvector.pq.BQVectors;
import io.github.jbellis.jvector.pq.PQVectors;
import io.github.jbellis.jvector.vector.VectorSimilarityFunction;
import io.github.jbellis.jvector.vector.VectorUtil;
Expand Down Expand Up @@ -178,4 +179,44 @@ public VectorFloat<?> approximateCentroid() {
};
}

/**
 * Returns a BuildScoreProvider that scores exclusively with the binary-quantized vectors in {@code bqv}.
 * Every score function it hands out reports itself as inexact.
 *
 * @param bqv the binary-quantized vectors to score against
 * @return a provider backed by {@code bqv}
 */
static BuildScoreProvider bqBuildScoreProvider(BQVectors bqv) {
    return new BuildScoreProvider() {
        @Override
        public boolean isExact() {
            return false;
        }

        @Override
        public VectorFloat<?> approximateCentroid() {
            // a freshly created vector of the original dimension stands in for the centroid;
            // per the original author's note, zeros are a decent approximation
            int originalDimension = bqv.getCompressor().getOriginalDimension();
            return vts.createFloatVector(originalDimension);
        }

        @Override
        public SearchScoreProvider searchProviderFor(VectorFloat<?> vector) {
            var queryScorer = bqv.scoreFunctionFor(vector, null);
            return new SearchScoreProvider(queryScorer);
        }

        @Override
        public SearchScoreProvider searchProviderFor(int node1) {
            // look up the anchor node's encoding once; it is reused for every comparison
            var anchor = bqv.get(node1);
            var nodeScorer = new ScoreFunction() {
                @Override
                public boolean isExact() {
                    return false;
                }

                @Override
                public float similarityTo(int node2) {
                    return bqv.similarityBetween(anchor, bqv.get(node2));
                }
            };
            return new SearchScoreProvider(nodeScorer);
        }

        @Override
        public SearchScoreProvider diversityProviderFor(int node1) {
            // diversity checks use the same node-to-node scoring as search
            return searchProviderFor(node1);
        }
    };
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -28,18 +28,18 @@
import java.util.Arrays;
import java.util.Objects;

public class BQVectors implements CompressedVectors {
private final BinaryQuantization bq;
private final long[][] compressedVectors;
public abstract class BQVectors implements CompressedVectors {
protected final BinaryQuantization bq;
protected long[][] compressedVectors;
protected int vectorCount;

public BQVectors(BinaryQuantization bq, long[][] compressedVectors) {
protected BQVectors(BinaryQuantization bq) {
this.bq = bq;
this.compressedVectors = compressedVectors;
}

@Override
public int count() {
return compressedVectors.length;
return vectorCount;
}

@Override
Expand Down Expand Up @@ -73,7 +73,7 @@ public static BQVectors load(RandomAccessReader in, long offset) throws IOExcept
}
var compressedVectors = new long[size][];
if (size == 0) {
return new BQVectors(bq, compressedVectors);
return new ImmutableBQVectors(bq, compressedVectors);
}
int compressedLength = in.readInt();
if (compressedLength < 0) {
Expand All @@ -88,7 +88,7 @@ public static BQVectors load(RandomAccessReader in, long offset) throws IOExcept
compressedVectors[i] = vector;
}

return new BQVectors(bq, compressedVectors);
return new ImmutableBQVectors(bq, compressedVectors);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ public static BinaryQuantization compute(RandomAccessVectorValues ravv, ForkJoin

@Override
public CompressedVectors createCompressedVectors(Object[] compressedVectors) {
return new BQVectors(this, (long[][]) compressedVectors);
return new ImmutableBQVectors(this, (long[][]) compressedVectors);
}

@Override
Expand All @@ -74,7 +74,7 @@ public CompressedVectors encodeAll(RandomAccessVectorValues ravv, ForkJoinPool s
})
.toArray(long[][]::new))
.join();
return new BQVectors(this, cv);
return new ImmutableBQVectors(this, cv);
}

/**
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
package io.github.jbellis.jvector.pq;

/**
 * A BQVectors whose compressed vectors are all supplied at construction time; this class adds
 * no methods that modify them.
 */
public class ImmutableBQVectors extends BQVectors {
    /**
     * Construct an immutable BQVectors instance with the given BinaryQuantization and compressed vectors.
     * @param bq the BinaryQuantization to use
     * @param compressedVectors the compressed vectors, one entry per ordinal; the array is stored
     *                          as-is (not copied) and must not be null
     */
    public ImmutableBQVectors(BinaryQuantization bq, long[][] compressedVectors) {
        super(bq);
        this.compressedVectors = compressedVectors;
        // count() in the base class reports vectorCount, so it must be set here;
        // otherwise it keeps its default of 0 regardless of how many vectors were supplied
        this.vectorCount = compressedVectors.length;
    }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
/*
* Copyright DataStax, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package io.github.jbellis.jvector.pq;

import io.github.jbellis.jvector.vector.types.ByteSequence;

/**
 * A PQVectors whose compressed data is supplied in full at construction time; this class adds
 * no methods that modify it.
 */
public class ImmutablePQVectors extends PQVectors {
    /**
     * Wraps already-compressed data in a PQVectors instance.
     *
     * @param pq the ProductQuantization the data was compressed with
     * @param compressedDataChunks the chunks holding the compressed vectors; stored as-is, not copied
     * @param vectorCount the number of vectors contained in the chunks
     * @param vectorsPerChunk the number of vectors stored in each chunk
     */
    public ImmutablePQVectors(ProductQuantization pq,
                              ByteSequence<?>[] compressedDataChunks,
                              int vectorCount,
                              int vectorsPerChunk)
    {
        super(pq);
        this.vectorCount = vectorCount;
        this.vectorsPerChunk = vectorsPerChunk;
        this.compressedDataChunks = compressedDataChunks;
    }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
package io.github.jbellis.jvector.pq;

/**
 * A BQVectors backed by a fixed-capacity array that is filled in one ordinal at a time.
 */
public class MutableBQVectors extends BQVectors implements MutableCompressedVectors<long[]> {
    /**
     * Creates an empty instance able to hold up to {@code maximumVectorCount} vectors.
     *
     * @param bq the BinaryQuantization to use
     * @param maximumVectorCount the capacity of the backing array, i.e. the largest number of
     *                           vectors this instance can store
     */
    public MutableBQVectors(BinaryQuantization bq, int maximumVectorCount) {
        super(bq);
        this.vectorCount = 0;
        this.compressedVectors = new long[maximumVectorCount][];
    }

    /**
     * Stores {@code vector} at {@code ordinal}. The array is stored as given; no encoding or
     * copying happens in this method.
     */
    @Override
    public void encodeAndSet(int ordinal, long[] vector) {
        store(ordinal, vector);
    }

    /**
     * Stores a newly allocated all-zero vector of the quantizer's compressed size at {@code ordinal}.
     */
    @Override
    public void setZero(int ordinal) {
        store(ordinal, new long[bq.compressedVectorSize()]);
    }

    // Writes the slot and grows vectorCount so that it always covers the highest ordinal written.
    private void store(int ordinal, long[] encoded) {
        compressedVectors[ordinal] = encoded;
        if (ordinal >= vectorCount) {
            vectorCount = ordinal + 1;
        }
    }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
package io.github.jbellis.jvector.pq;

/**
* A CompressedVectors whose entries can be written one ordinal at a time, either from a
* caller-supplied vector or as an explicit zero.
*
* @param <T> the type of the vector argument accepted by {@code encodeAndSet}
*/
public interface MutableCompressedVectors<T> extends CompressedVectors {
/**
* Encode the given vector and set it at the given ordinal. Done without unnecessary copying.
*
* It's the caller's responsibility to ensure there are no "holes" in the ordinals that are
* neither encoded nor set to zero.
*
* @param ordinal the ordinal to set
* @param vector the vector to encode and set
*/
void encodeAndSet(int ordinal, T vector);

/**
* Set the vector at the given ordinal to zero.
*
* It's the caller's responsibility to ensure there are no "holes" in the ordinals that are
* neither encoded nor set to zero.
*
* @param ordinal the ordinal to set
*/
void setZero(int ordinal);
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
/*
* Copyright DataStax, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package io.github.jbellis.jvector.pq;

import io.github.jbellis.jvector.vector.VectorizationProvider;
import io.github.jbellis.jvector.vector.types.ByteSequence;
import io.github.jbellis.jvector.vector.types.VectorFloat;
import io.github.jbellis.jvector.vector.types.VectorTypeSupport;

import static java.lang.Math.max;

/**
 * A PQVectors whose compressed data is preallocated for a fixed maximum number of vectors and
 * then filled in one ordinal at a time.
 */
public class MutablePQVectors extends PQVectors implements MutableCompressedVectors<VectorFloat<?>> {
    private static final VectorTypeSupport vectorTypeSupport = VectorizationProvider.getInstance().getVectorTypeSupport();

    /**
     * Construct a mutable PQVectors instance with the given ProductQuantization and maximum number of vectors that will be
     * stored in this instance. The vectors are split into chunks to avoid exceeding the maximum array size.
     * @param pq the ProductQuantization to use
     * @param maximumVectorCount the maximum number of vectors that will be stored in this instance; may be zero,
     *                           in which case no chunks are allocated
     * @throws IllegalArgumentException if maximumVectorCount is negative
     */
    public MutablePQVectors(ProductQuantization pq, int maximumVectorCount) {
        super(pq);
        if (maximumVectorCount < 0) {
            throw new IllegalArgumentException("maximumVectorCount must be non-negative, got " + maximumVectorCount);
        }
        this.vectorCount = 0;

        // Calculate if we need to split into multiple chunks
        int compressedDimension = pq.compressedVectorSize();
        long totalSize = (long) maximumVectorCount * compressedDimension;
        int vectorsThatFit = totalSize <= MAX_CHUNK_SIZE ? maximumVectorCount : MAX_CHUNK_SIZE / compressedDimension;
        // Keep this at least 1: a capacity of zero would otherwise make it 0 and the
        // divisions below would throw ArithmeticException
        this.vectorsPerChunk = max(1, vectorsThatFit);

        int fullSizeChunks = maximumVectorCount / vectorsPerChunk;
        int remainingVectors = maximumVectorCount % vectorsPerChunk;
        int totalChunks = remainingVectors == 0 ? fullSizeChunks : fullSizeChunks + 1;
        ByteSequence<?>[] chunks = new ByteSequence<?>[totalChunks];
        int chunkBytes = vectorsPerChunk * compressedDimension;
        for (int i = 0; i < fullSizeChunks; i++) {
            chunks[i] = vectorTypeSupport.createByteSequence(chunkBytes);
        }

        // Last chunk might be smaller
        if (totalChunks > fullSizeChunks) {
            chunks[fullSizeChunks] = vectorTypeSupport.createByteSequence(remainingVectors * compressedDimension);
        }

        this.compressedDataChunks = chunks;
    }

    /**
     * Encodes {@code vector} with the ProductQuantization directly into the storage for {@code ordinal}.
     */
    @Override
    public void encodeAndSet(int ordinal, VectorFloat<?> vector) {
        // vectorCount is raised before get(ordinal) is called; presumably get() validates the
        // ordinal against vectorCount, so keep this order (TODO confirm against PQVectors.get)
        vectorCount = max(vectorCount, ordinal + 1);
        pq.encodeTo(vector, get(ordinal));
    }

    /**
     * Zeroes the storage for {@code ordinal}.
     */
    @Override
    public void setZero(int ordinal) {
        // same ordering as encodeAndSet: grow vectorCount first, then touch the slot
        vectorCount = max(vectorCount, ordinal + 1);
        get(ordinal).zero();
    }
}
Loading
Loading