From 25973c07bf5f204c024ea5f5eaae728ef73de76a Mon Sep 17 00:00:00 2001 From: laa Date: Fri, 1 Dec 2023 04:16:46 +0100 Subject: [PATCH] Ids of vectors are associated and stored inside of vertex record on disk. Presence of vector id is mandatory. Id is represented is byte array of size of 16 bytes. --- .../vectoriadb/bench/Sift1MBench.java | 11 +- .../jetbrains/vectoriadb/index/DataStore.java | 14 +- .../vectoriadb/index/IndexBuilder.java | 99 ++++++++---- .../vectoriadb/index/IndexReader.java | 23 ++- .../vectoriadb/index/VectorReader.java | 2 + .../vectoriadb/index/bench/BenchUtils.java | 20 ++- .../index/bench/PrepareBigANNBench.java | 4 +- .../index/bench/PrepareRandomVectorBench.java | 9 ++ .../index/bench/RunBigANNBench.java | 16 +- .../index/bench/RunRandomVectorBench.java | 5 +- .../vectoriadb/index/diskcache/DiskCache.java | 149 +++++++++++++----- .../vectoriadb/index/DiskANNTest.java | 57 +++++-- .../vectoriadb/index/L2PQKMeansTest.java | 11 ++ .../src/main/proto/IndexManager.proto | 7 +- .../vectoriadb/client/VectoriaDBClient.java | 82 +++++++--- .../server/IndexManagerServiceImpl.java | 15 +- .../vectoriadb/server/IndexManagerTest.java | 42 +++-- 17 files changed, 430 insertions(+), 136 deletions(-) diff --git a/vectoriadb-bench/src/main/java/jetbrains/vectoriadb/bench/Sift1MBench.java b/vectoriadb-bench/src/main/java/jetbrains/vectoriadb/bench/Sift1MBench.java index 15c3c25b5..7fb28d7d5 100644 --- a/vectoriadb-bench/src/main/java/jetbrains/vectoriadb/bench/Sift1MBench.java +++ b/vectoriadb-bench/src/main/java/jetbrains/vectoriadb/bench/Sift1MBench.java @@ -46,7 +46,12 @@ public static void main(String[] args) { var siftDir = rootDir.resolve("sift"); var siftDataName = "sift_base.fvecs"; + var vectors = BenchUtils.readFVectors(siftDir.resolve(siftDataName), vectorDimensions); + var ids = new int[vectors.length]; + for (int i = 0; i < ids.length; i++) { + ids[i] = i; + } var indexName = "sift1m"; System.out.printf("%d data vectors loaded with dimension %d, 
building index %s...%n", @@ -73,7 +78,7 @@ public static void main(String[] args) { ts1 = System.currentTimeMillis(); - client.uploadVectors(indexName, vectors, (current, count) -> { + client.uploadVectors(indexName, vectors, ids, (current, count) -> { if (current >= 0 && current < Integer.MAX_VALUE) { if (current % 1_000 == 0) { System.out.printf("%d vectors uploaded out of %d%n", current, count); @@ -133,7 +138,7 @@ public static void main(String[] args) { System.out.printf("Iteration %d out of 5 %n", (i + 1)); for (int j = 0; j < queryVectors.length; j++) { - var vector = queryVectors[j]; + var vector = queryVectors[j]; client.findNearestNeighbours(indexName, vector, 1); if ((j + 1) % 1_000 == 0) { @@ -149,7 +154,7 @@ public static void main(String[] args) { for (var index = 0; index < queryVectors.length; index++) { var vector = queryVectors[index]; - var result = client.findNearestNeighbours(indexName, vector, 1); + var result = client.findIntNearestNeighbours(indexName, vector, 1); if (groundTruth[index][0] != result[0]) { errorsCount++; } diff --git a/vectoriadb-index/src/main/java/jetbrains/vectoriadb/index/DataStore.java b/vectoriadb-index/src/main/java/jetbrains/vectoriadb/index/DataStore.java index 83766ccf9..db635aef3 100644 --- a/vectoriadb-index/src/main/java/jetbrains/vectoriadb/index/DataStore.java +++ b/vectoriadb-index/src/main/java/jetbrains/vectoriadb/index/DataStore.java @@ -37,8 +37,9 @@ private DataStore(int dimensions, DistanceFunction distanceFunction, FileChannel this.channel = channel; this.distanceFunction = distanceFunction; - var vectorSize = dimensions * Float.BYTES; - var bufferSize = Math.min(64 * 1024 * 1024 / vectorSize, 1) * vectorSize; + //record contains vector and its associated id + var recordSize = dimensions * Float.BYTES + IndexBuilder.VECTOR_ID_SIZE; + var bufferSize = Math.min(64 * 1024 * 1024 / recordSize, 1) * recordSize; this.buffer = ByteBuffer.allocate(bufferSize).order(ByteOrder.nativeOrder()); 
this.preprocessingResult = new float[dimensions]; @@ -54,7 +55,12 @@ public static DataStore create(final String name, final int dimensions, return new DataStore(dimensions, distanceFunction, channel); } - public void add(final float[] vector) throws IOException { + public void add(final float[] vector, @NotNull byte[] id) throws IOException { + if (id.length != IndexBuilder.VECTOR_ID_SIZE) { + throw new IllegalArgumentException("Vector id size should be equal to " + IndexBuilder.VECTOR_ID_SIZE + + ". Vector id size : " + id.length); + } + var vectorToStore = distanceFunction.preProcess(vector, preprocessingResult); if (buffer.remaining() == 0) { @@ -70,6 +76,8 @@ public void add(final float[] vector) throws IOException { for (var component : vectorToStore) { buffer.putFloat(component); } + + buffer.put(id); } public static Path dataLocation(@NotNull final String name, final Path dataDirectoryPath) { diff --git a/vectoriadb-index/src/main/java/jetbrains/vectoriadb/index/IndexBuilder.java b/vectoriadb-index/src/main/java/jetbrains/vectoriadb/index/IndexBuilder.java index d4b701edb..006a70752 100644 --- a/vectoriadb-index/src/main/java/jetbrains/vectoriadb/index/IndexBuilder.java +++ b/vectoriadb-index/src/main/java/jetbrains/vectoriadb/index/IndexBuilder.java @@ -59,6 +59,11 @@ import java.util.concurrent.atomic.AtomicLongArray; public final class IndexBuilder { + /** + * Maximum size of id of related vector in bytes. 
+ */ + public static final int VECTOR_ID_SIZE = 16; + public static final int DEFAULT_MAX_CONNECTIONS_PER_VERTEX = 128; public static final int DEFAULT_MAX_AMOUNT_OF_CANDIDATES = 128; public static final float DEFAULT_DISTANCE_MULTIPLIER = 2.0f; @@ -95,14 +100,15 @@ public static void buildIndex(String name, int vectorsDimension, int compression var verticesCountPerPage = pageStructure.verticesCountPerPage(); var vertexRecordSize = pageStructure.vertexRecordSize(); var recordVectorsOffset = pageStructure.recordVectorsOffset(); + var diskRecordVectorIdOffset = pageStructure.recordIdOffset(); var recordEdgesOffset = pageStructure.recordEdgesOffset(); var recordEdgesCountOffset = pageStructure.recordEdgesCountOffset(); try (var vectorReader = new MmapVectorReader(vectorsDimension, dataStoreFilePath)) { try (var arena = Arena.ofShared()) { - var size = vectorReader.size(); - if (size == 0) { + var vectorsCount = vectorReader.size(); + if (vectorsCount == 0) { logger.info("Vector index " + name + ". There are no vectors to index. 
Stopping index build."); return; } @@ -141,7 +147,7 @@ public static void buildIndex(String name, int vectorsDimension, int compression var partitions = (int) Math.max(1, - 3 * calculateGraphPartitionSize(2L * verticesCount, + 3 * calculateGraphPartitionSizeInRAM(2L * verticesCount, maxConnectionsPerVertex, vectorsDimension) / memoryConsumption); var totalPartitionsSize = 0; @@ -173,7 +179,7 @@ public static void buildIndex(String name, int vectorsDimension, int compression } } - checkRequestedFreeSpace(indexDirectoryPath, size, totalPartitionsSize, maxConnectionsPerVertex, + checkRequestedFreeSpace(indexDirectoryPath, vectorsCount, totalPartitionsSize, maxConnectionsPerVertex, pageSize, verticesCountPerPage); var avgPartitionSize = totalPartitionsSize / partitions; @@ -186,11 +192,11 @@ public static void buildIndex(String name, int vectorsDimension, int compression } var endPartition = System.nanoTime(); - var maxPartitionSizeBytes = calculateGraphPartitionSize(maxPartitionSize, maxConnectionsPerVertex, + var maxPartitionSizeBytes = calculateGraphPartitionSizeInRAM(maxPartitionSize, maxConnectionsPerVertex, vectorsDimension); long maxPartitionSizeKBytes = maxPartitionSizeBytes / 1024; long minPartitionSizeKBytes = - calculateGraphPartitionSize(minPartitionSize, maxConnectionsPerVertex, vectorsDimension) / 1024; + calculateGraphPartitionSizeInRAM(minPartitionSize, maxConnectionsPerVertex, vectorsDimension) / 1024; //noinspection IntegerDivisionInFloatingPointContext logger.info("Splitting vectors into {} partitions has been finished. Max. 
partition size {} vertexes " + @@ -245,7 +251,7 @@ public static void buildIndex(String name, int vectorsDimension, int compression Files.delete(graphFilePath); } - var diskCache = initFile(graphFilePath, size, verticesCountPerPage, + var diskCache = initFile(graphFilePath, vectorsCount, verticesCountPerPage, pageSize, arena); var graphs = new MMapedGraph[partitions]; @@ -261,9 +267,12 @@ public static void buildIndex(String name, int vectorsDimension, int compression var partition = dmPartitions[i]; var partitionSize = (int) partition.byteSize() / Integer.BYTES; - var graph = new MMapedGraph(partitionSize, i, name, indexDirectoryPath, maxConnectionsPerVertex, - vectorsDimension, distanceFunction, maxAmountOfCandidates, pageSize, - vertexRecordSize, recordVectorsOffset, diskCache); + var graph = new MMapedGraph(partitionSize, i, name, indexDirectoryPath, + maxConnectionsPerVertex, vectorsDimension, distanceFunction, + maxAmountOfCandidates, pageSize, + vertexRecordSize, recordVectorsOffset, + diskRecordVectorIdOffset, + diskCache); progressTracker.pushPhase("building search graph for partition " + i, "partition size", String.valueOf(partitionSize)); try { @@ -271,10 +280,14 @@ public static void buildIndex(String name, int vectorsDimension, int compression var vectorIndex = partition.getAtIndex(ValueLayout.JAVA_INT, j); var vector = vectorReader.read(vectorIndex); - graph.addVector(vectorIndex, vector); + var vectorId = vectorReader.id(vectorIndex); + + graph.addVector(vectorIndex, vector, vectorId); - var currentDistance = distanceFunction.computeDistance(vector, 0, centroid, + var currentDistance = distanceFunction.computeDistance(vector, + 0, centroid, 0, vectorsDimension); + if (currentDistance < medoidMinDistance) { medoidMinDistance = currentDistance; medoidMinIndex = vectorIndex; @@ -663,7 +676,7 @@ private static void permuteIndexes(MemorySegment indexes, UniformRandomProvider private static MemorySegment initFile(Path path, int globalVertexCount, int 
verticesPerPage, int pageSize, Arena arena) throws IOException { - var fileLength = calculateRequestedFileLength(globalVertexCount, pageSize, verticesPerPage); + var fileLength = calculateSearchGraphFileSize(globalVertexCount, pageSize, verticesPerPage); MemorySegment diskCache; try (var rwFile = new RandomAccessFile(path.toFile(), "rw")) { rwFile.setLength(fileLength); @@ -676,17 +689,24 @@ private static MemorySegment initFile(Path path, int globalVertexCount, int vert } - private static void checkRequestedFreeSpace(Path dbPath, int size, int totalPartitionsSize, + private static void checkRequestedFreeSpace(Path dbPath, int vectorsCount, int totalPartitionsSize, int maxConnectionsPerVertex, int pageSize, int verticesPerPage) throws IOException { var fileStore = Files.getFileStore(dbPath); var usableSpace = fileStore.getUsableSpace(); - var requiredGraphSpace = calculateRequestedFileLength(size, pageSize, verticesPerPage); - //space needed for mmap files to store edges and global indexes of all partitions. + var requiredGraphSpace = calculateSearchGraphFileSize(vectorsCount, pageSize, verticesPerPage); + + //During index build data are stored in several files. + //All vectors are stored into the final file according to their dedicated positions along with their ids. + //Edges are stored in separate files for each partition, backed by mmap and then merged into the final file + //during the merge step. + //That is done to merge search graphs of several partitions into one in memory restricted environment. + //Space needed for MMAP files to store edges and global indexes. 
+ var requiredPartitionsSpace = - (long) totalPartitionsSize * (maxConnectionsPerVertex + 1) * Integer.BYTES + - (long) totalPartitionsSize * Integer.BYTES; + (long) totalPartitionsSize * ((long) (maxConnectionsPerVertex + 1) * Integer.BYTES + + Integer.BYTES); var requiredSpace = requiredGraphSpace + requiredPartitionsSpace; if (requiredSpace > usableSpace * 0.9) { @@ -696,7 +716,7 @@ private static void checkRequestedFreeSpace(Path dbPath, int size, int totalPart } } - private static long calculateRequestedFileLength(long verticesCount, int pageSize, int verticesPerPage) { + private static long calculateSearchGraphFileSize(long verticesCount, int pageSize, int verticesPerPage) { var pagesToWrite = pagesToWrite(verticesCount, verticesPerPage); return (long) pagesToWrite * pageSize; } @@ -705,12 +725,13 @@ private static int pagesToWrite(long verticesCount, int verticesPerPage) { return (int) (verticesCount + verticesPerPage - 1) / verticesPerPage; } - private static long calculateGraphPartitionSize(long partitionSize, int maxConnectionsPerVertex, int vectorDim) { + private static long calculateGraphPartitionSizeInRAM(long partitionSize, int maxConnectionsPerVertex, int vectorDim) { //1. edges - //2. global indexes - //3. vertex records - return partitionSize * (maxConnectionsPerVertex + 1) * Integer.BYTES + - partitionSize * Integer.BYTES + partitionSize * vectorDim * Float.BYTES; + //2. vector position + //3. vector + //4. 
vector id + return partitionSize * ((long) (maxConnectionsPerVertex + 1) * Integer.BYTES + + Integer.BYTES + (long) vectorDim * Float.BYTES + VECTOR_ID_SIZE); } private static final class MMapedGraph implements AutoCloseable { @@ -718,6 +739,9 @@ private static final class MMapedGraph implements AutoCloseable { private final MemorySegment edges; private final MemorySegment vectors; private final MemorySegment globalIndexes; + + private final MemorySegment ids; + @Nullable private AtomicLongArray edgeVersions; private final Arena edgesArena; @@ -734,21 +758,25 @@ private static final class MMapedGraph implements AutoCloseable { private final int pageSize; private final int vertexRecordSize; private final int diskRecordVectorsOffset; + private final int diskRecordVectorIdOffset; private final MemorySegment diskCache; private MMapedGraph(int capacity, int id, String name, Path path, int maxConnectionsPerVertex, int vectorDimensions, DistanceFunction distanceFunction, int maxAmountOfCandidates, int pageSize, int vertexRecordSize, int diskRecordVectorsOffset, + int diskRecordVectorIdOffset, MemorySegment diskCache) throws IOException { this(capacity, false, id, name, path, maxConnectionsPerVertex, vectorDimensions, - distanceFunction, maxAmountOfCandidates, pageSize, vertexRecordSize, diskRecordVectorsOffset, diskCache); + distanceFunction, maxAmountOfCandidates, pageSize, vertexRecordSize, diskRecordVectorsOffset, + diskRecordVectorIdOffset, diskCache); } private MMapedGraph(int capacity, boolean skipVectors, int id, String name, Path path, int maxConnectionsPerVertex, int vectorDimensions, DistanceFunction distanceFunction, int maxAmountOfCandidates, int pageSize, int vertexRecordSize, int diskRecordVectorsOffset, + int diskRecordVectorIdOffset, MemorySegment diskCache) throws IOException { this.edgeVersions = new AtomicLongArray(capacity); this.name = name; @@ -761,6 +789,7 @@ private MMapedGraph(int capacity, boolean skipVectors, int id, String name, Path 
this.pageSize = pageSize; this.vertexRecordSize = vertexRecordSize; this.diskRecordVectorsOffset = diskRecordVectorsOffset; + this.diskRecordVectorIdOffset = diskRecordVectorIdOffset; this.diskCache = diskCache; this.edgesArena = Arena.ofShared(); @@ -793,8 +822,11 @@ private MMapedGraph(int capacity, boolean skipVectors, int id, String name, Path var vectorsLayout = MemoryLayout.sequenceLayout( (long) capacity * this.vectorDimensions, ValueLayout.JAVA_FLOAT); this.vectors = vectorsArena.allocate(vectorsLayout); + this.ids = vectorsArena.allocateArray(ValueLayout.JAVA_BYTE, + (long) capacity * VECTOR_ID_SIZE); } else { vectors = null; + this.ids = null; } } @@ -855,12 +887,15 @@ private int calculateMedoid() { } - private void addVector(int globalIndex, MemorySegment vector) { + private void addVector(int globalIndex, MemorySegment vector, MemorySegment id) { var index = (long) size * vectorDimensions; MemorySegment.copy(vector, 0, vectors, index * Float.BYTES, (long) vectorDimensions * Float.BYTES); + MemorySegment.copy(id, 0, ids, + (long) size * VECTOR_ID_SIZE, + VECTOR_ID_SIZE); globalIndexes.setAtIndex(ValueLayout.JAVA_INT, size, globalIndex); size++; @@ -1354,6 +1389,9 @@ private void saveVectorsToDisk() { var recordOffset = localPageOffset * vertexRecordSize + Long.BYTES + pageOffset; + diskCache.asSlice(recordOffset + diskRecordVectorIdOffset).copyFrom( + ids.asSlice(i * VECTOR_ID_SIZE, VECTOR_ID_SIZE)); + for (long j = 0; j < vectorDimensions; j++, vectorsIndex++) { var vectorItem = vectors.get(ValueLayout.JAVA_FLOAT, vectorsIndex * Float.BYTES); @@ -1481,8 +1519,9 @@ private static final class MmapVectorReader implements VectorReader { public MmapVectorReader(final int vectorDimensions, Path path) throws IOException { this.vectorDimensions = vectorDimensions; - this.recordSize = Float.BYTES * vectorDimensions; + //record size = vector size + vector id + this.recordSize = Float.BYTES * vectorDimensions + IndexBuilder.VECTOR_ID_SIZE; arena = 
Arena.ofShared(); @@ -1502,6 +1541,12 @@ public MemorySegment read(int index) { return segment.asSlice((long) index * recordSize, (long) Float.BYTES * vectorDimensions); } + @Override + public MemorySegment id(int index) { + return segment.asSlice((long) index * recordSize + (long) Float.BYTES * vectorDimensions, + IndexBuilder.VECTOR_ID_SIZE); + } + @Override public void close() { arena.close(); diff --git a/vectoriadb-index/src/main/java/jetbrains/vectoriadb/index/IndexReader.java b/vectoriadb-index/src/main/java/jetbrains/vectoriadb/index/IndexReader.java index aeaa4ba0d..34a9fc165 100644 --- a/vectoriadb-index/src/main/java/jetbrains/vectoriadb/index/IndexReader.java +++ b/vectoriadb-index/src/main/java/jetbrains/vectoriadb/index/IndexReader.java @@ -27,6 +27,7 @@ import java.io.BufferedInputStream; import java.io.DataInputStream; import java.io.IOException; +import java.lang.foreign.MemorySegment; import java.nio.file.Files; import java.nio.file.Path; import java.nio.file.StandardOpenOption; @@ -112,7 +113,7 @@ public IndexReader(String name, int vectorDim, int maxConnectionsPerVertex, int logger.info("Vector index {} has been initialized.", name); } - public void nearest(float[] vector, int[] result, int resultSize) { + public byte[][] nearest(float[] vector, int resultSize) { if (closed) { throw new IllegalStateException("Index is closed"); } @@ -242,7 +243,25 @@ public void nearest(float[] vector, int[] result, int resultSize) { diskCache.unlock(id, vertexToPreload[i], graphFilePath); } - nearestCandidates.vertexIndices(result, resultSize); + var vertexIndexes = new int[resultSize]; + nearestCandidates.vertexIndices(vertexIndexes, resultSize); + + var result = new byte[resultSize][]; + for (int i = 0; i < resultSize; i++) { + var vertexIndex = vertexIndexes[i]; + var inMemoryPageIndex = diskCache.readLock(id, vertexIndex, graphFilePath); + try { + var vectorIdOffset = diskCache.vectorIdOffset(inMemoryPageIndex, vertexIndex); + var vectorId = new 
byte[IndexBuilder.VECTOR_ID_SIZE]; + MemorySegment.copy(diskCache.pages, vectorIdOffset, MemorySegment.ofArray(vectorId), 0, + IndexBuilder.VECTOR_ID_SIZE); + result[i] = vectorId; + } finally { + diskCache.unlock(id, vertexIndex, graphFilePath); + } + } + + return result; } private void preloadVertices(BoundedGreedyVertexPriorityQueue nearestCandidates, int[] vertexToPreload) { diff --git a/vectoriadb-index/src/main/java/jetbrains/vectoriadb/index/VectorReader.java b/vectoriadb-index/src/main/java/jetbrains/vectoriadb/index/VectorReader.java index c19ddf168..5bdccfb4a 100644 --- a/vectoriadb-index/src/main/java/jetbrains/vectoriadb/index/VectorReader.java +++ b/vectoriadb-index/src/main/java/jetbrains/vectoriadb/index/VectorReader.java @@ -21,4 +21,6 @@ public interface VectorReader extends AutoCloseable { int size(); MemorySegment read(int index); + + MemorySegment id(int index); } diff --git a/vectoriadb-index/src/main/java/jetbrains/vectoriadb/index/bench/BenchUtils.java b/vectoriadb-index/src/main/java/jetbrains/vectoriadb/index/bench/BenchUtils.java index 9dd33b69f..93a88f933 100644 --- a/vectoriadb-index/src/main/java/jetbrains/vectoriadb/index/bench/BenchUtils.java +++ b/vectoriadb-index/src/main/java/jetbrains/vectoriadb/index/bench/BenchUtils.java @@ -60,8 +60,15 @@ static void runSiftBenchmarks( var indexName = "test_index"; try (var dataBuilder = DataStore.create(indexName, 128, L2DistanceFunction.INSTANCE, dbDir)) { - for (var vector : vectors) { - dataBuilder.add(vector); + for (int i = 0; i < vectors.length; i++) { + var buffer = ByteBuffer.allocate(IndexBuilder.VECTOR_ID_SIZE); + buffer.order(ByteOrder.LITTLE_ENDIAN); + + buffer.putInt(i); + buffer.rewind(); + + var vector = vectors[i]; + dataBuilder.add(vector, buffer.array()); } } @@ -91,11 +98,9 @@ static void runSiftBenchmarks( //give GC chance to collect garbage Thread.sleep(60 * 1000); - - var result = new int[1]; for (int i = 0; i < 10; i++) { for (float[] vector : queryVectors) { - 
indexReader.nearest(vector, result, 1); + indexReader.nearest(vector, 1); } } @@ -110,8 +115,9 @@ static void runSiftBenchmarks( var errorsCount = 0; for (var index = 0; index < queryVectors.length; index++) { var vector = queryVectors[index]; - indexReader.nearest(vector, result, 1); - if (groundTruth[index][0] != result[0]) { + var rawId = indexReader.nearest(vector, 1); + + if (groundTruth[index][0] != ByteBuffer.wrap(rawId[0]).order(ByteOrder.LITTLE_ENDIAN).getInt()) { errorsCount++; } } diff --git a/vectoriadb-index/src/main/java/jetbrains/vectoriadb/index/bench/PrepareBigANNBench.java b/vectoriadb-index/src/main/java/jetbrains/vectoriadb/index/bench/PrepareBigANNBench.java index 67fa138bc..a49e0b420 100644 --- a/vectoriadb-index/src/main/java/jetbrains/vectoriadb/index/bench/PrepareBigANNBench.java +++ b/vectoriadb-index/src/main/java/jetbrains/vectoriadb/index/bench/PrepareBigANNBench.java @@ -97,7 +97,9 @@ public static void main(String[] args) { vector[j] = buffer.get(); } - dataBuilder.add(vector); + var id = ByteBuffer.allocate(Integer.BYTES). 
+ order(ByteOrder.LITTLE_ENDIAN).putInt((int) i).array(); + dataBuilder.add(vector, id); } } } diff --git a/vectoriadb-index/src/main/java/jetbrains/vectoriadb/index/bench/PrepareRandomVectorBench.java b/vectoriadb-index/src/main/java/jetbrains/vectoriadb/index/bench/PrepareRandomVectorBench.java index fce74d80d..90269fd18 100644 --- a/vectoriadb-index/src/main/java/jetbrains/vectoriadb/index/bench/PrepareRandomVectorBench.java +++ b/vectoriadb-index/src/main/java/jetbrains/vectoriadb/index/bench/PrepareRandomVectorBench.java @@ -209,6 +209,15 @@ public MemorySegment read(int index) { return segment.asSlice((long) index * recordSize, (long) Float.BYTES * vectorDimensions); } + @Override + public MemorySegment id(int index) { + var buffer = ByteBuffer.allocate(IndexBuilder.VECTOR_ID_SIZE); + buffer.order(ByteOrder.LITTLE_ENDIAN); + buffer.putInt(index); + + return MemorySegment.ofBuffer(buffer); + } + @Override public void close() { arena.close(); diff --git a/vectoriadb-index/src/main/java/jetbrains/vectoriadb/index/bench/RunBigANNBench.java b/vectoriadb-index/src/main/java/jetbrains/vectoriadb/index/bench/RunBigANNBench.java index 60bc216e9..8b835b69b 100644 --- a/vectoriadb-index/src/main/java/jetbrains/vectoriadb/index/bench/RunBigANNBench.java +++ b/vectoriadb-index/src/main/java/jetbrains/vectoriadb/index/bench/RunBigANNBench.java @@ -20,6 +20,8 @@ import jetbrains.vectoriadb.index.IndexReader; import jetbrains.vectoriadb.index.diskcache.DiskCache; +import java.nio.ByteBuffer; +import java.nio.ByteOrder; import java.nio.file.Files; import java.nio.file.Path; import java.util.Objects; @@ -80,10 +82,9 @@ public static void main(String[] args) throws Exception { System.out.println("Warming up ..."); - var result = new int[1]; for (int i = 0; i < 50; i++) { for (float[] vector : m1QueryVectors) { - indexReader.nearest(vector, result, 1); + indexReader.nearest(vector, 1); } } } @@ -101,11 +102,18 @@ public static void main(String[] args) throws Exception { 
bigAnnDbDir, Distance.DOT, diskCache)) { System.out.println("Running BigANN bench..."); - var result = new int[recallCount]; + var start = System.nanoTime(); for (int i = 0; i < bigAnnQueryVectors.length; i++) { float[] vector = bigAnnQueryVectors[i]; - indexReader.nearest(vector, result, recallCount); + var rawIds = indexReader.nearest(vector, recallCount); + + var result = new int[recallCount]; + for (int j = 0; j < rawIds.length; j++) { + var rawId = rawIds[j]; + result[j] = ByteBuffer.wrap(rawId).order(ByteOrder.LITTLE_ENDIAN).getInt(); + } + totalRecall += recall(result, bigAnnGroundTruth[i], recallCount); } var end = System.nanoTime(); diff --git a/vectoriadb-index/src/main/java/jetbrains/vectoriadb/index/bench/RunRandomVectorBench.java b/vectoriadb-index/src/main/java/jetbrains/vectoriadb/index/bench/RunRandomVectorBench.java index 47e0f3c41..171823cd0 100644 --- a/vectoriadb-index/src/main/java/jetbrains/vectoriadb/index/bench/RunRandomVectorBench.java +++ b/vectoriadb-index/src/main/java/jetbrains/vectoriadb/index/bench/RunRandomVectorBench.java @@ -57,7 +57,6 @@ public static void main(String[] args) throws Exception { System.out.println("Running queries..."); var errors = 0; - var result = new int[1]; try (var diskCache = new DiskCache(400 * 1024 * 1024, vectorDimensions, IndexBuilder.DEFAULT_MAX_CONNECTIONS_PER_VERTEX)) { @@ -72,8 +71,8 @@ public static void main(String[] args) throws Exception { MemorySegment.copy(queryVectorSegment, ValueLayout.JAVA_FLOAT, 0, queryVector, 0, vectorDimensions); - indexReader.nearest(queryVector, result, 1); - if (result[0] != groundTruth[index]) { + var rawIds = indexReader.nearest(queryVector, 1); + if (ByteBuffer.wrap(rawIds[0]).order(ByteOrder.LITTLE_ENDIAN).getInt() != groundTruth[index]) { errors++; } } diff --git a/vectoriadb-index/src/main/java/jetbrains/vectoriadb/index/diskcache/DiskCache.java b/vectoriadb-index/src/main/java/jetbrains/vectoriadb/index/diskcache/DiskCache.java index ca48eac71..22460ebe9 100644 
--- a/vectoriadb-index/src/main/java/jetbrains/vectoriadb/index/diskcache/DiskCache.java +++ b/vectoriadb-index/src/main/java/jetbrains/vectoriadb/index/diskcache/DiskCache.java @@ -19,6 +19,7 @@ import com.sun.nio.file.ExtendedOpenOption; import it.unimi.dsi.fastutil.Hash; import it.unimi.dsi.fastutil.objects.ObjectOpenHashSet; +import jetbrains.vectoriadb.index.IndexBuilder; import jetbrains.vectoriadb.index.util.collections.BlockingLongArrayQueue; import jetbrains.vectoriadb.index.util.collections.NonBlockingHashMapLongLong; import org.jctools.maps.NonBlockingHashMapLong; @@ -133,25 +134,45 @@ public final class DiskCache extends BLCHeader.DrainStatusRef implements AutoClo private static final Logger logger = LoggerFactory.getLogger(DiskCache.class); public static final int DISK_BLOCK_SIZE = 4 * 1024; private static final int NCPU = Runtime.getRuntime().availableProcessors(); - /** The initial capacity of the write buffer. */ + /** + * The initial capacity of the write buffer. + */ private static final int ADD_BUFFER_MIN = 4; - /** The maximum capacity of the write buffer. */ + /** + * The maximum capacity of the write buffer. + */ private static final int ADD_BUFFER_MAX = 128 * Integer.highestOneBit(NCPU - 1) << 1; - /** The number of attempts to insert into the write buffer before yielding. */ + /** + * The number of attempts to insert into the write buffer before yielding. + */ private static final int ADD_BUFFER_RETRIES = 100; - /** The initial percent of the maximum weighted capacity dedicated to the main space. */ + /** + * The initial percent of the maximum weighted capacity dedicated to the main space. + */ private static final double PERCENT_MAIN = 0.99d; - /** The percent of the maximum weighted capacity dedicated to the main's protected space. */ + /** + * The percent of the maximum weighted capacity dedicated to the main's protected space. 
+ */ private static final double PERCENT_MAIN_PROTECTED = 0.80d; - /** The difference in hit rates that restarts the climber. */ + /** + * The difference in hit rates that restarts the climber. + */ private static final double HILL_CLIMBER_RESTART_THRESHOLD = 0.05d; - /** The percent of the total size to adapt the window by. */ + /** + * The percent of the total size to adapt the window by. + */ private static final double HILL_CLIMBER_STEP_PERCENT = 0.0625d; - /** The rate to decrease the step size to adapt by. */ + /** + * The rate to decrease the step size to adapt by. + */ private static final double HILL_CLIMBER_STEP_DECAY_RATE = 0.98d; - /** The minimum popularity for allowing randomized admission. */ + /** + * The minimum popularity for allowing randomized admission. + */ private static final int ADMIT_HASHDOS_THRESHOLD = 6; - /** The maximum number of entries that can be transferred between queues. */ + /** + * The maximum number of entries that can be transferred between queues. 
+ */ private static final int QUEUE_TRANSFER_THRESHOLD = 1_000; private static final long WARN_AFTER_LOCK_WAIT_NANOS = TimeUnit.SECONDS.toNanos(30); @@ -216,6 +237,8 @@ public final class DiskCache extends BLCHeader.DrainStatusRef implements AutoClo private long fileOffsetTracker; private final long vectorRecordOffset; + + private final long vectorIdRecordOffset; private final long edgesCountOffset; private final long edgesOffset; @@ -261,6 +284,7 @@ public DiskCache(long cacheSize, int vectorDim, int maxConnectionsPerVertex) { this.edgesOffset = pagesStructure.pageStructure.recordEdgesOffset; this.verticesCountPerPage = pagesStructure.pageStructure.verticesCountPerPage; this.vectorRecordOffset = pagesStructure.pageStructure.recordVectorsOffset; + this.vectorIdRecordOffset = pagesStructure.pageStructure.recordIdOffset; int cachePagesCount = pagesStructure.cachePagesCount; @@ -323,6 +347,12 @@ public long vectorOffset(long inMemoryPageIndex, long vertexIndex) { return inMemoryPageIndex * pageSize + recordOffset + vectorRecordOffset; } + public long vectorIdOffset(long inMemoryPageIndex, long vertexIndex) { + //we add offset of the page version + var recordOffset = (vertexIndex % verticesCountPerPage) * vertexRecordSize + Long.BYTES; + return inMemoryPageIndex * pageSize + recordOffset + vectorIdRecordOffset; + } + public int fetchEdges(long indexId, long vertexIndex, int[] edges, Path filePath) { var recordOffset = (vertexIndex % verticesCountPerPage) * vertexRecordSize + Long.BYTES; var edgeCountOffset = recordOffset + edgesCountOffset; @@ -624,9 +654,9 @@ private Future schedulePagePreLoading(long indexId, long pageIndex, long p * Adds a node to the policy and the data store. If an existing node is found, then its value is * updated if allowed. * - * @param key key with which the specified value is to be associated + * @param key key with which the specified value is to be associated * @param inMemoryPageIndex index of the page inside the cache memory. 
- * @param pageVersion version of page during the caching of the page. + * @param pageVersion version of page during the caching of the page. */ private void add(long key, long inMemoryPageIndex, long pageVersion) { Node node = null; @@ -734,7 +764,9 @@ private void scheduleAfterWrite() { } - /** Acquires the eviction lock. */ + /** + * Acquires the eviction lock. + */ private void lock() { long remainingNanos = WARN_AFTER_LOCK_WAIT_NANOS; long end = System.nanoTime() + remainingNanos; @@ -881,7 +913,9 @@ private void onAddTask(Node node) { } } - /** Adapts the eviction policy to towards the optimal recency / frequency configuration. */ + /** + * Adapts the eviction policy to towards the optimal recency / frequency configuration. + */ @GuardedBy("evictionLock") private void climb() { determineAdjustment(); @@ -895,7 +929,9 @@ private void climb() { } } - /** Decreases the size of the admission window and increases the main's protected region. */ + /** + * Decreases the size of the admission window and increases the main's protected region. + */ @GuardedBy("evictionLock") private void decreaseWindow() { if (windowMaximum <= 1) { @@ -927,7 +963,9 @@ private void decreaseWindow() { } - /** Transfers the nodes from the protected to the probation region if it exceeds the maximum. */ + /** + * Transfers the nodes from the protected to the probation region if it exceeds the maximum. + */ @GuardedBy("evictionLock") private void demoteFromMainProtected() { long mainProtectedMaximum = this.mainProtectedMaximum; @@ -1004,7 +1042,9 @@ private void increaseWindow() { } - /** Calculates the amount to adapt the window by and sets {@link #adjustment} accordingly. */ + /** + * Calculates the amount to adapt the window by and sets {@link #adjustment} accordingly. 
+ */ @GuardedBy("evictionLock") private void determineAdjustment() { int requestCount = hitsInSample + missesInSample; @@ -1028,7 +1068,9 @@ private void determineAdjustment() { } - /** Evicts entries if the cache exceeds the maximum. */ + /** + * Evicts entries if the cache exceeds the maximum. + */ @GuardedBy("evictionLock") private void evictEntries() { lockedNodes.clear(); @@ -1181,7 +1223,7 @@ private void evictFromMain(@Nullable Node candidate, @NotNull ObjectOpenHashSet< * are admitted. * * @param candidateKey the key for the entry being proposed for long term retention - * @param victimKey the key for the entry chosen by the eviction policy for replacement + * @param victimKey the key for the entry chosen by the eviction policy for replacement * @return if the candidate should be admitted and the victim ejected */ @GuardedBy("evictionLock") @@ -1285,7 +1327,9 @@ private void addPageToFreePagesQueue(Node node) { } } - /** Logs if the node cannot be found in the map but is still alive. */ + /** + * Logs if the node cannot be found in the map but is still alive. + */ private void logIfAlive(Node node) { if (node.isAlive()) { String message = brokenEqualityMessage(node.getKey()); @@ -1293,7 +1337,9 @@ private void logIfAlive(Node node) { } } - /** Returns the formatted broken equality error message. */ + /** + * Returns the formatted broken equality error message. + */ private static String brokenEqualityMessage(long key) { return String.format(US, "An invalid state was detected, occurring when the key's equals or " + "hashCode was modified while residing in the cache. This violation of the Map " @@ -1361,7 +1407,9 @@ private Node lockForEviction(Node node, @NotNull ObjectOpenHashSet lockedN } - /** Drains the write buffer. */ + /** + * Drains the write buffer. 
+ */ @GuardedBy("evictionLock") private void drainWriteBuffer() { for (int i = 0; i <= ADD_BUFFER_MAX; i++) { @@ -1377,14 +1425,18 @@ private void drainWriteBuffer() { setDrainStatusOpaque(PROCESSING_TO_REQUIRED); } - /** Drains the read buffer. */ + /** + * Drains the read buffer. + */ @GuardedBy("evictionLock") private void drainReadBuffer() { readBuffer.drainTo(this::onAccess); } - /** Updates the node's location in the page replacement policy. */ + /** + * Updates the node's location in the page replacement policy. + */ @GuardedBy("evictionLock") private void onAccess(Node node) { var key = node.getKey(); @@ -1402,7 +1454,9 @@ private void onAccess(Node node) { } - /** Promote the node from probation to protected on an access. */ + /** + * Promote the node from probation to protected on an access. + */ @GuardedBy("evictionLock") private void reorderProbation(Node node) { if (!accessOrderProbationDeque.contains(node)) { @@ -1439,7 +1493,9 @@ public boolean assertAllUnlocked() { return true; } - /** Updates the node's location in the policy's deque. */ + /** + * Updates the node's location in the policy's deque. + */ private static void reorder(LinkedDeque deque, Node node) { // An entry may be scheduled for reordering despite having been removed. This can occur when the // entry was concurrently read while a writer was removing it. If the entry is no longer linked @@ -1450,7 +1506,9 @@ private static void reorder(LinkedDeque deque, Node node) { } - /** A reusable task that performs the maintenance work; used to avoid wrapping by ForkJoinPool. */ + /** + * A reusable task that performs the maintenance work; used to avoid wrapping by ForkJoinPool. 
+ */ private static final class PerformCleanupTask extends ForkJoinTask implements Runnable { private final WeakReference reference; @@ -1536,6 +1594,7 @@ private static PagesStructure createPagesStructure(long totalSize, int vectorDim @NotNull public static PageStructure createPageStructure(int vectorDim, int maxConnectionsPerVertex) { var vertexLayout = MemoryLayout.structLayout( + MemoryLayout.sequenceLayout(IndexBuilder.VECTOR_ID_SIZE, ValueLayout.JAVA_BYTE).withName("id"), MemoryLayout.sequenceLayout(vectorDim, ValueLayout.JAVA_FLOAT).withName("vector"), MemoryLayout.sequenceLayout(maxConnectionsPerVertex, ValueLayout.JAVA_INT).withName("edges"), ValueLayout.JAVA_INT.withName("edgesCount") @@ -1563,13 +1622,15 @@ public static PageStructure createPageStructure(int vectorDim, int maxConnection MemoryLayout.paddingLayout(paddingSpace)); assert pageSize == pageLayout.byteSize(); + var idOffset = (int) vertexLayout.byteOffset(MemoryLayout.PathElement.groupElement("id")); var recordVectorOffset = (int) vertexLayout.byteOffset(MemoryLayout.PathElement.groupElement("vector")); var recordEdgesOffset = (int) vertexLayout.byteOffset(MemoryLayout.PathElement.groupElement("edges")); var recordEdgesCountOffset = (int) vertexLayout.byteOffset(MemoryLayout.PathElement.groupElement("edgesCount")); return new PageStructure(pageSize, verticesCountPerPage, vertexRecordSize, - recordVectorOffset, recordEdgesOffset, recordEdgesCountOffset, pageLayout); + recordVectorOffset, recordEdgesOffset, recordEdgesCountOffset, idOffset, + pageLayout); } private record PagesStructure(int cachePagesCount, int preLoadersCount, int allocatedPagesCount, @@ -1578,12 +1639,16 @@ private record PagesStructure(int cachePagesCount, int preLoadersCount, int allo public record PageStructure(int pageSize, int verticesCountPerPage, int vertexRecordSize, int recordVectorsOffset, - int recordEdgesOffset, int recordEdgesCountOffset, MemoryLayout pageLayout) { + int recordEdgesOffset, int 
recordEdgesCountOffset, + int recordIdOffset, + MemoryLayout pageLayout) { } } -/** The namespace for field padding through inheritance. */ +/** + * The namespace for field padding through inheritance. + */ final class BLCHeader { @SuppressWarnings("unused") @@ -1605,20 +1670,32 @@ static class PadDrainStatus { byte p112, p113, p114, p115, p116, p117, p118, p119; } - /** Enforces a memory layout to avoid false sharing by padding the drain status. */ + /** + * Enforces a memory layout to avoid false sharing by padding the drain status. + */ abstract static class DrainStatusRef extends BLCHeader.PadDrainStatus { static final VarHandle DRAIN_STATUS; - /** A drain is not taking place. */ + /** + * A drain is not taking place. + */ static final int IDLE = 0; - /** A drain is required due to a pending write modification. */ + /** + * A drain is required due to a pending write modification. + */ static final int REQUIRED = 1; - /** A drain is in progress and will transition to idle. */ + /** + * A drain is in progress and will transition to idle. + */ static final int PROCESSING_TO_IDLE = 2; - /** A drain is in progress and will transition to required. */ + /** + * A drain is in progress and will transition to required. + */ static final int PROCESSING_TO_REQUIRED = 3; - /** The draining status of the buffers. */ + /** + * The draining status of the buffers. 
+ */ volatile int drainStatus = IDLE; /** diff --git a/vectoriadb-index/src/test/java/jetbrains/vectoriadb/index/DiskANNTest.java b/vectoriadb-index/src/test/java/jetbrains/vectoriadb/index/DiskANNTest.java index ce477b5be..43d830b32 100644 --- a/vectoriadb-index/src/test/java/jetbrains/vectoriadb/index/DiskANNTest.java +++ b/vectoriadb-index/src/test/java/jetbrains/vectoriadb/index/DiskANNTest.java @@ -25,6 +25,8 @@ import org.junit.Test; import java.io.IOException; +import java.nio.ByteBuffer; +import java.nio.ByteOrder; import java.nio.file.Files; import java.nio.file.Path; import java.util.Arrays; @@ -58,8 +60,11 @@ public void testFindLoadedVectorsL2Distance() throws Exception { var indexName = "test_index"; var ts1 = System.nanoTime(); try (var dataBuilder = DataStore.create(indexName, vectorDimensions, L2DistanceFunction.INSTANCE, dbDir)) { - for (var vector : vectors) { - dataBuilder.add(vector); + for (int n = 0; n < vectorsCount; n++) { + var vector = vectors[n]; + var id = ByteBuffer.allocate(IndexBuilder.VECTOR_ID_SIZE). 
+ order(ByteOrder.LITTLE_ENDIAN).putInt(n).array(); + dataBuilder.add(vector, id); } } @@ -76,8 +81,13 @@ public void testFindLoadedVectorsL2Distance() throws Exception { ts1 = System.nanoTime(); for (var j = 0; j < vectorsCount; j++) { var vector = queries[j]; + var rawIds = indexReader.nearest(vector, recallCount); var result = new int[recallCount]; - indexReader.nearest(vector, result, recallCount); + + for (int n = 0; n < recallCount; n++) { + result[n] = ByteBuffer.wrap(rawIds[n]).order(ByteOrder.LITTLE_ENDIAN).getInt(); + } + totalRecall += recall(result, groundTruth[j]); if ((j + 1) % 1_000 == 0) { @@ -124,8 +134,11 @@ public void testFindLoadedVectorsDotDistance() throws Exception { var indexName = "test_index"; try (var dataStore = DataStore.create(indexName, vectorDimensions, DotDistanceFunction.INSTANCE, dbDir)) { - for (var vector : vectors) { - dataStore.add(vector); + for (var n = 0; n < vectorsCount; n++) { + var id = ByteBuffer.allocate(IndexBuilder.VECTOR_ID_SIZE). + order(ByteOrder.LITTLE_ENDIAN).putInt(n).array(); + var vector = vectors[n]; + dataStore.add(vector, id); } } @@ -142,8 +155,13 @@ public void testFindLoadedVectorsDotDistance() throws Exception { ts1 = System.nanoTime(); for (var j = 0; j < vectorsCount; j++) { var vector = queryVectors[j]; + + var rawIds = indexReader.nearest(vector, recallCount); var result = new int[recallCount]; - indexReader.nearest(vector, result, recallCount); + + for (int n = 0; n < recallCount; n++) { + result[n] = ByteBuffer.wrap(rawIds[n]).order(ByteOrder.LITTLE_ENDIAN).getInt(); + } totalRecall += recall(result, groundTruth[j]); if ((j + 1) % 1_000 == 0) { @@ -192,8 +210,11 @@ public void testFindLoadedVectorsCosineDistance() throws Exception { var indexName = "test_index"; try (var dataBuilder = DataStore.create(indexName, vectorDimensions, CosineDistanceFunction.INSTANCE, dbDir)) { - for (var vector : vectors) { - dataBuilder.add(vector); + for (int n = 0; n < vectorsCount; n++) { + var vector = 
vectors[n]; + var id = ByteBuffer.allocate(IndexBuilder.VECTOR_ID_SIZE). + order(ByteOrder.LITTLE_ENDIAN).putInt(n).array(); + dataBuilder.add(vector, id); } } @@ -212,9 +233,13 @@ public void testFindLoadedVectorsCosineDistance() throws Exception { ts1 = System.nanoTime(); for (var j = 0; j < vectorsCount; j++) { var vector = queryVectors[j]; + + var rawIds = indexReader.nearest(vector, recallCount); var result = new int[recallCount]; - indexReader.nearest(vector, result, recallCount); + for (var n = 0; n < recallCount; n++) { + result[n] = ByteBuffer.wrap(rawIds[n]).order(ByteOrder.LITTLE_ENDIAN).getInt(); + } totalRecall += recall(result, groundTruth[j]); if ((j + 1) % 1_000 == 0) { @@ -350,8 +375,11 @@ private void runSiftBenchmarks( var indexName = "test_index"; try (var dataBuilder = DataStore.create(indexName, 128, L2DistanceFunction.INSTANCE, dbDir)) { - for (var vector : vectors) { - dataBuilder.add(vector); + for (int n = 0; n < vectors.length; n++) { + var vector = vectors[n]; + var id = ByteBuffer.allocate(IndexBuilder.VECTOR_ID_SIZE). 
+ order(ByteOrder.LITTLE_ENDIAN).putInt(n).array(); + dataBuilder.add(vector, id); } } @@ -369,11 +397,10 @@ private void runSiftBenchmarks( ts1 = System.nanoTime(); for (var index = 0; index < queryVectors.length; index++) { var vector = queryVectors[index]; - var result = new int[1]; - indexReader.nearest(vector, result, 1); - Assert.assertEquals("j = " + index, 1, result.length); - if (groundTruth[index][0] != result[0]) { + var rawIds = indexReader.nearest(vector, 1); + Assert.assertEquals("j = " + index, 1, rawIds.length); + if (groundTruth[index][0] != ByteBuffer.wrap(rawIds[0]).order(ByteOrder.LITTLE_ENDIAN).getInt()) { errorsCount++; } } diff --git a/vectoriadb-index/src/test/java/jetbrains/vectoriadb/index/L2PQKMeansTest.java b/vectoriadb-index/src/test/java/jetbrains/vectoriadb/index/L2PQKMeansTest.java index 41ff43cb8..133f98d94 100644 --- a/vectoriadb-index/src/test/java/jetbrains/vectoriadb/index/L2PQKMeansTest.java +++ b/vectoriadb-index/src/test/java/jetbrains/vectoriadb/index/L2PQKMeansTest.java @@ -24,6 +24,8 @@ import java.lang.foreign.Arena; import java.lang.foreign.MemorySegment; import java.lang.foreign.ValueLayout; +import java.nio.ByteBuffer; +import java.nio.ByteOrder; import java.util.ArrayList; import java.util.Random; @@ -185,6 +187,15 @@ public MemorySegment read(int index) { return vectorSegment; } + @Override + public MemorySegment id(int index) { + var buffer = ByteBuffer.allocate(IndexBuilder.VECTOR_ID_SIZE); + buffer.order(ByteOrder.LITTLE_ENDIAN); + buffer.putInt(index); + + return MemorySegment.ofBuffer(buffer); + } + @Override public void close() { } diff --git a/vectoriadb-interface/src/main/proto/IndexManager.proto b/vectoriadb-interface/src/main/proto/IndexManager.proto index 92a50b439..8ce0adcd3 100644 --- a/vectoriadb-interface/src/main/proto/IndexManager.proto +++ b/vectoriadb-interface/src/main/proto/IndexManager.proto @@ -95,12 +95,17 @@ message FindNearestNeighboursRequest { } message FindNearestNeighboursResponse { - 
repeated int32 ids = 1; + repeated VectorId ids = 1; +} + +message VectorId { + bytes id = 1; } message UploadVectorsRequest { string index_name = 1; repeated float vector_components = 2; + VectorId id = 3; } message IndexListResponse { diff --git a/vectoriadb-java-client/src/main/java/jetbrains/vectoriadb/client/VectoriaDBClient.java b/vectoriadb-java-client/src/main/java/jetbrains/vectoriadb/client/VectoriaDBClient.java index f5becabda..6a9c3bd2d 100644 --- a/vectoriadb-java-client/src/main/java/jetbrains/vectoriadb/client/VectoriaDBClient.java +++ b/vectoriadb-java-client/src/main/java/jetbrains/vectoriadb/client/VectoriaDBClient.java @@ -15,6 +15,7 @@ */ package jetbrains.vectoriadb.client; +import com.google.protobuf.ByteString; import com.google.protobuf.Empty; import io.grpc.Context; import io.grpc.ManagedChannelBuilder; @@ -27,6 +28,8 @@ import org.slf4j.LoggerFactory; import javax.annotation.Nullable; +import java.nio.ByteBuffer; +import java.nio.ByteOrder; import java.util.ArrayList; import java.util.Iterator; import java.util.List; @@ -36,8 +39,10 @@ import java.util.concurrent.locks.ReentrantLock; import java.util.function.BiConsumer; - +@SuppressWarnings({"UnusedReturnValue", "unused"}) public final class VectoriaDBClient { + public static final int VECTOR_ID_SIZE = 16; + private static final Logger logger = LoggerFactory.getLogger(VectoriaDBClient.class); private final IndexManagerGrpc.IndexManagerBlockingStub indexManagerBlockingStub; private final IndexManagerGrpc.IndexManagerStub indexManagerAsyncStub; @@ -123,26 +128,40 @@ public IndexState retrieveIndexState(String indexName) { }; } - public void uploadVectors(final String indexName, final Iterator vectors, + public void uploadVectors(final String indexName, final Iterator vectors, final Iterator ids, @Nullable BiConsumer progressIndicator) { - uploadVectors(indexName, vectors, VectoriaDBClient::uploadVectorsList, progressIndicator); + uploadVectors(indexName, vectors, ids, 
VectoriaDBClient::uploadVectorsList, progressIndicator); } - public void uploadVectors(final String indexName, final Iterator vectors) { - uploadVectors(indexName, vectors, VectoriaDBClient::uploadVectorsList, null); + public void uploadVectors(final String indexName, final Iterator vectors, final Iterator ids) { + uploadVectors(indexName, vectors, ids, VectoriaDBClient::uploadVectorsList, null); } - public void uploadVectors(final String indexName, final float[][] vectors, + public void uploadVectors(final String indexName, final float[][] vectors, final byte[][] ids, @Nullable BiConsumer progressIndicator) { - uploadVectors(indexName, vectors, VectoriaDBClient::uploadVectorsArray, progressIndicator); + uploadVectors(indexName, vectors, ids, VectoriaDBClient::uploadVectorsArray, progressIndicator); } - public void uploadVectors(final String indexName, final float[][] vectors) { - uploadVectors(indexName, vectors, VectoriaDBClient::uploadVectorsArray, null); + public void uploadVectors(final String indexName, final float[][] vectors, byte[][] ids) { + uploadVectors(indexName, vectors, ids, VectoriaDBClient::uploadVectorsArray, null); } - private void uploadVectors(String indexName, T vectors, VectorsUploader vectorsUploader, - @Nullable BiConsumer progressIndicator) { + public void uploadVectors(final String indexName, final float[][] vectors, final int[] ids, + @Nullable BiConsumer progressIndicator) { + var rawIds = new byte[ids.length][VECTOR_ID_SIZE]; + for (int i = 0; i < ids.length; i++) { + //little endian because that is the most widely used integer representation format across CPU architectures + var buffer = ByteBuffer.allocate(VECTOR_ID_SIZE).order(ByteOrder.LITTLE_ENDIAN); + buffer.putInt(ids[i]); + + rawIds[i] = buffer.array(); + } + + uploadVectors(indexName, vectors, rawIds, VectoriaDBClient::uploadVectorsArray, progressIndicator); + } + + private void uploadVectors(String indexName, T vectors, U ids, VectorsUploader vectorsUploader, + @Nullable BiConsumer 
progressIndicator) { var error = new Throwable[1]; var finishedLatch = new CountDownLatch(1); var onReadyHandler = new OnReadyHandler(); @@ -172,7 +191,7 @@ public void onCompleted() { var requestObserver = indexManagerAsyncStub.uploadVectors(responseObserver); try { - vectorsUploader.uploadVectors(indexName, vectors, requestObserver, finishedLatch, onReadyHandler, + vectorsUploader.uploadVectors(indexName, vectors, ids, requestObserver, finishedLatch, onReadyHandler, progressIndicator); } catch (RuntimeException e) { requestObserver.onError(e); @@ -197,7 +216,7 @@ public void onCompleted() { } } - private static void uploadVectorsList(String indexName, Iterator vectors, + private static void uploadVectorsList(String indexName, Iterator vectors, Iterator ids, StreamObserver requestObserver, CountDownLatch finishedLatch, OnReadyHandler onReadyHandler, @@ -206,6 +225,8 @@ private static void uploadVectorsList(String indexName, Iterator vector while (vectors.hasNext()) { onReadyHandler.callWhenReady(() -> { var vector = vectors.next(); + var id = ids.next(); + var builder = IndexManagerOuterClass.UploadVectorsRequest.newBuilder(); builder.setIndexName(indexName); @@ -213,6 +234,8 @@ private static void uploadVectorsList(String indexName, Iterator vector builder.addVectorComponents(value); } + builder.setId(IndexManagerOuterClass.VectorId.newBuilder().setId(ByteString.copyFrom(id)).build()); + var request = builder.build(); requestObserver.onNext(request); if (progressIndicator != null) { @@ -227,21 +250,27 @@ private static void uploadVectorsList(String indexName, Iterator vector } } - private static void uploadVectorsArray(String indexName, float[][] vectors, + private static void uploadVectorsArray(String indexName, float[][] vectors, byte[][] ids, StreamObserver requestObserver, CountDownLatch finishedLatch, OnReadyHandler onReadyHandler, @Nullable BiConsumer progressIndicator) { var counter = new int[1]; - for (var vector : vectors) { + for (int i = 0; i < 
vectors.length; i++) { + var vector = vectors[i]; + var id = ids[i]; + onReadyHandler.callWhenReady(() -> { var builder = IndexManagerOuterClass.UploadVectorsRequest.newBuilder(); + builder.setIndexName(indexName); for (var value : vector) { builder.addVectorComponents(value); } + builder.setId(IndexManagerOuterClass.VectorId.newBuilder().setId(ByteString.copyFrom(id)).build()); + var request = builder.build(); requestObserver.onNext(request); if (progressIndicator != null) { @@ -270,7 +299,22 @@ public void switchToBuildMode() { indexManagerBlockingStub.switchToBuildMode(request); } - public int[] findNearestNeighbours(final String indexName, final float[] vector, int k) { + public int[] findIntNearestNeighbours(final String indexName, final float[] vector, int k) { + var rawIds = findNearestNeighbours(indexName, vector, k); + var result = new int[rawIds.length]; + + var buffer = ByteBuffer.allocate(VECTOR_ID_SIZE).order(ByteOrder.LITTLE_ENDIAN); + for (int i = 0; i < rawIds.length; i++) { + buffer.clear(); + buffer.put(rawIds[i]); + buffer.flip(); + result[i] = buffer.getInt(); + } + + return result; + } + + public byte[][] findNearestNeighbours(final String indexName, final float[] vector, int k) { var builder = IndexManagerOuterClass.FindNearestNeighboursRequest.newBuilder(); builder.setIndexName(indexName); builder.setK(k); @@ -282,7 +326,7 @@ public int[] findNearestNeighbours(final String indexName, final float[] vector, var request = builder.build(); var response = indexManagerBlockingStub.findNearestNeighbours(request); - return response.getIdsList().stream().mapToInt(Integer::intValue).toArray(); + return response.getIdsList().stream().map(vectorId -> vectorId.getId().toByteArray()).toArray(byte[][]::new); } public void buildStatus(IndexBuildStatusListener buildStatusListener) { @@ -354,8 +398,8 @@ public void onCompleted() { } } - private interface VectorsUploader { - void uploadVectors(String indexName, T vectors, + private interface VectorsUploader { + 
void uploadVectors(String indexName, T vectors, U ids, StreamObserver requestObserver, CountDownLatch finishedLatch, OnReadyHandler onReadyHandler, @Nullable BiConsumer progressIndicator); diff --git a/vectoriadb-server/src/main/java/jetbrains/vectoriadb/server/IndexManagerServiceImpl.java b/vectoriadb-server/src/main/java/jetbrains/vectoriadb/server/IndexManagerServiceImpl.java index 7972e3652..29010e982 100644 --- a/vectoriadb-server/src/main/java/jetbrains/vectoriadb/server/IndexManagerServiceImpl.java +++ b/vectoriadb-server/src/main/java/jetbrains/vectoriadb/server/IndexManagerServiceImpl.java @@ -15,6 +15,7 @@ */ package jetbrains.vectoriadb.server; +import com.google.protobuf.ByteString; import com.google.protobuf.Empty; import io.grpc.Context; import io.grpc.Status; @@ -864,18 +865,20 @@ public void findNearestNeighbours(IndexManagerOuterClass.FindNearestNeighboursRe var neighboursCount = request.getK(); var queryVector = request.getVectorComponentsList(); - var result = new int[neighboursCount]; var vector = new float[dimensions]; for (int i = 0; i < dimensions; i++) { vector[i] = queryVector.get(i); } - indexReader.nearest(vector, result, neighboursCount); + var ids = indexReader.nearest(vector, neighboursCount); - for (var vectorIndex : result) { - responseBuilder.addIds(vectorIndex); + for (var id : ids) { + var vectorId = IndexManagerOuterClass.VectorId.newBuilder(); + vectorId.setId(ByteString.copyFrom(id)); + + responseBuilder.addIds(vectorId); } } catch (Exception e) { logger.error("Failed to find nearest neighbours", e); @@ -1184,8 +1187,8 @@ public void onNext(IndexManagerOuterClass.UploadVectorsRequest value) { vector[i] = value.getVectorComponents(i); } try { - store.add(vector); - } catch (IOException e) { + store.add(vector, value.getId().toByteArray()); + } catch (Exception e) { var msg = "Failed to add vector to index " + indexName; logger.error(msg, e); diff --git 
a/vectoriadb-server/src/test/java/jetbrains/vectoriadb/server/IndexManagerTest.java b/vectoriadb-server/src/test/java/jetbrains/vectoriadb/server/IndexManagerTest.java index bd004f00a..ffb2ad6fb 100644 --- a/vectoriadb-server/src/test/java/jetbrains/vectoriadb/server/IndexManagerTest.java +++ b/vectoriadb-server/src/test/java/jetbrains/vectoriadb/server/IndexManagerTest.java @@ -15,6 +15,7 @@ */ package jetbrains.vectoriadb.server; +import com.google.protobuf.ByteString; import com.google.protobuf.Empty; import io.grpc.internal.testing.StreamRecorder; import jetbrains.vectoriadb.index.CosineDistanceFunction; @@ -33,6 +34,8 @@ import org.springframework.mock.env.MockEnvironment; import java.io.IOException; +import java.nio.ByteBuffer; +import java.nio.ByteOrder; import java.nio.file.Files; import java.nio.file.Path; import java.util.ArrayList; @@ -334,9 +337,15 @@ public void testShutDownAndReload() throws Exception { var rng = RandomSource.XO_RO_SHI_RO_128_PP.create(); var vectors = new float[10][64]; generateUniqueVectorSet(vectors, rng); + var ids = new byte[10][]; + + for (int i = 0; i < ids.length; i++) { + ids[i] = new byte[16]; + ByteBuffer.wrap(ids[i]).order(ByteOrder.LITTLE_ENDIAN).putInt(i); + } createIndex(uploadedIndex, indexManagerService, IndexManagerOuterClass.Distance.L2); - uploadVectors(uploadedIndex, vectors, indexManagerService); + uploadVectors(uploadedIndex, vectors, ids, indexManagerService); generateIndex(builtIndex, L2DistanceFunction.INSTANCE, 64, 10, indexManagerService); @@ -400,18 +409,22 @@ private static void executeInServiceContext(boolean preDeleteDirectories, boolea } } - private static void uploadVectors(String indexName, float[][] vectors, + private static void uploadVectors(String indexName, float[][] vectors, byte[][] ids, IndexManagerServiceImpl indexManagerService) throws Exception { var vectorsUploadRecorder = StreamRecorder.create(); var request = indexManagerService.uploadVectors(vectorsUploadRecorder); try { - for (var 
vector : vectors) { + for (var i = 0; i < vectors.length; i++) { + var vector = vectors[i]; + var id = ids[i]; + var builder = IndexManagerOuterClass.UploadVectorsRequest.newBuilder(); builder.setIndexName(indexName); for (var component : vector) { builder.addVectorComponents(component); } + builder.setId(IndexManagerOuterClass.VectorId.newBuilder().setId(ByteString.copyFrom(id)).build()); request.onNext(builder.build()); if (vectorsUploadRecorder.getError() != null) { @@ -518,8 +531,8 @@ private static void switchToSearchMode(IndexManagerServiceImpl indexManagerServi checkCompleteness(switchToSearchModeRecorder); } - private static int[] findNearestNeighbours(IndexManagerServiceImpl indexManagerService, - float[] queryVector, String indexName, int k) throws Exception { + private static byte[][] findNearestNeighbours(IndexManagerServiceImpl indexManagerService, + float[] queryVector, String indexName, int k) throws Exception { var findNearestVectorsRecorder = StreamRecorder.create(); var builder = IndexManagerOuterClass.FindNearestNeighboursRequest.newBuilder(); builder.setIndexName(indexName); @@ -535,10 +548,10 @@ private static int[] findNearestNeighbours(IndexManagerServiceImpl indexManagerS var response = findNearestVectorsRecorder.getValues().get(0); var nearestVectors = response.getIdsList(); - var result = new int[nearestVectors.size()]; + var result = new byte[nearestVectors.size()][]; for (int i = 0; i < nearestVectors.size(); i++) { - result[i] = nearestVectors.get(i); + result[i] = nearestVectors.get(i).getId().toByteArray(); } return result; @@ -602,9 +615,14 @@ private static float[][] generateIndex(String indexName, DistanceFunction distan generateUniqueVectorSet(vectors, rng); + var ids = new byte[vectorsCount][]; + for (int i = 0; i < ids.length; i++) { + ids[i] = new byte[16]; + ByteBuffer.wrap(ids[i]).order(ByteOrder.LITTLE_ENDIAN).putInt(i); + } createIndex(indexName, indexManagerService, distance); - uploadVectors(indexName, vectors, 
indexManagerService); + uploadVectors(indexName, vectors, ids, indexManagerService); var ts1 = System.nanoTime(); buildIndex(indexName, indexManagerService); @@ -630,7 +648,13 @@ private static void searchNeighbours(String indexName, int vectorsCount, int vec for (var j = 0; j < vectorsCount; j++) { var vector = queries[j]; - var result = findNearestNeighbours(indexManagerService, vector, indexName, recallCount); + var rawIds = findNearestNeighbours(indexManagerService, vector, indexName, recallCount); + + var result = new int[rawIds.length]; + for (int i = 0; i < rawIds.length; i++) { + result[i] = ByteBuffer.wrap(rawIds[i]).order(ByteOrder.LITTLE_ENDIAN).getInt(); + } + totalRecall += recall(result, groundTruth[j]); if ((j + 1) % 1_000 == 0) {