diff --git a/jvector-base/src/main/java/io/github/jbellis/jvector/pq/MutablePQVectors.java b/jvector-base/src/main/java/io/github/jbellis/jvector/pq/MutablePQVectors.java index d0fee5b26..b9e5f165d 100644 --- a/jvector-base/src/main/java/io/github/jbellis/jvector/pq/MutablePQVectors.java +++ b/jvector-base/src/main/java/io/github/jbellis/jvector/pq/MutablePQVectors.java @@ -41,15 +41,18 @@ public MutablePQVectors(ProductQuantization pq, int maximumVectorCount) { long totalSize = (long) maximumVectorCount * compressedDimension; this.vectorsPerChunk = totalSize <= MAX_CHUNK_SIZE ? maximumVectorCount : MAX_CHUNK_SIZE / compressedDimension; - int numChunks = maximumVectorCount / vectorsPerChunk; - ByteSequence[] chunks = new ByteSequence[numChunks]; - int chunkSize = vectorsPerChunk * compressedDimension; - for (int i = 0; i < numChunks - 1; i++) - chunks[i] = vectorTypeSupport.createByteSequence(chunkSize); + int fullSizeChunks = maximumVectorCount / vectorsPerChunk; + int totalChunks = maximumVectorCount % vectorsPerChunk == 0 ? fullSizeChunks : fullSizeChunks + 1; + ByteSequence[] chunks = new ByteSequence[totalChunks]; + int chunkBytes = vectorsPerChunk * compressedDimension; + for (int i = 0; i < fullSizeChunks; i++) + chunks[i] = vectorTypeSupport.createByteSequence(chunkBytes); // Last chunk might be smaller - int remainingVectors = maximumVectorCount - (vectorsPerChunk * (numChunks - 1)); - chunks[numChunks - 1] = vectorTypeSupport.createByteSequence(remainingVectors * compressedDimension); + if (totalChunks > fullSizeChunks) { + int remainingVectors = maximumVectorCount % vectorsPerChunk; + chunks[fullSizeChunks] = vectorTypeSupport.createByteSequence(remainingVectors * compressedDimension); + } this.compressedDataChunks = chunks; } diff --git a/jvector-tests/src/test/java/io/github/jbellis/jvector/pq/TestProductQuantization.java b/jvector-tests/src/test/java/io/github/jbellis/jvector/pq/TestProductQuantization.java index 23e238481..da4774c30 100644 --- a/jvector-tests/src/test/java/io/github/jbellis/jvector/pq/TestProductQuantization.java +++ b/jvector-tests/src/test/java/io/github/jbellis/jvector/pq/TestProductQuantization.java @@ -250,4 +250,14 @@ public void testSaveVersion0() throws Exception { var contents2 = Files.readAllBytes(fileOut.toPath()); assertArrayEquals(contents1, contents2); } + + @Test + public void testMutablePQVectors() { + // test that MPVQ gets the math right in an allocation edge case + var R = getRandom(); + VectorFloat[] vectors = generate(2 * DEFAULT_CLUSTERS, 2, 1_000); + var ravv = new ListRandomAccessVectorValues(List.of(vectors), vectors[0].length()); + var pq = ProductQuantization.compute(ravv, 1, DEFAULT_CLUSTERS, false); + var pqv = new MutablePQVectors(pq, Integer.MAX_VALUE); + } }