From 5d5926712d9ca7d8b3240cd18e0817e71eb0b622 Mon Sep 17 00:00:00 2001 From: laa Date: Thu, 2 Nov 2023 08:32:13 +0100 Subject: [PATCH] Initial version of VectoriaDB client and addition of benchmarks project. --- settings.gradle.kts | 9 +- vectoriadb-bench/build.gradle.kts | 5 + .../vectoriadb/bench/BenchUtils.java | 188 +++++++++++++++ vectoriadb-java-client/build.gradle.kts | 9 + .../jetbrains/vectoriadb/client/Distance.java | 7 + .../client/IndexBuildStatusListener.java | 10 + .../vectoriadb/client/IndexMetadata.java | 4 + .../vectoriadb/client/IndexState.java | 12 + .../vectoriadb/client/VectoriaDBClient.java | 214 ++++++++++++++++++ 9 files changed, 451 insertions(+), 7 deletions(-) create mode 100644 vectoriadb-bench/build.gradle.kts create mode 100644 vectoriadb-bench/src/main/java/jetbrains/vectoriadb/bench/BenchUtils.java create mode 100644 vectoriadb-java-client/src/main/java/jetbrains/vectoriadb/client/Distance.java create mode 100644 vectoriadb-java-client/src/main/java/jetbrains/vectoriadb/client/IndexBuildStatusListener.java create mode 100644 vectoriadb-java-client/src/main/java/jetbrains/vectoriadb/client/IndexMetadata.java create mode 100644 vectoriadb-java-client/src/main/java/jetbrains/vectoriadb/client/IndexState.java create mode 100644 vectoriadb-java-client/src/main/java/jetbrains/vectoriadb/client/VectoriaDBClient.java diff --git a/settings.gradle.kts b/settings.gradle.kts index c4ca2e5a3..9b5cbe73c 100644 --- a/settings.gradle.kts +++ b/settings.gradle.kts @@ -8,7 +8,7 @@ pluginManagement { id("me.champeau.jmh") version ("0.7.1") id("com.google.protobuf") version ("0.9.4") id("org.springframework.boot") version ("3.1.5") - id("com.bmuschko.docker-remote-api") version("9.3.6") + id("com.bmuschko.docker-remote-api") version ("9.3.6") } repositories { maven(url = "https://cache-redirector.jetbrains.com/plugins.gradle.org/m2") @@ -176,11 +176,6 @@ include("vectoriadb-index") project(":vectoriadb-index").name = "vectoriadb-index" 
include("vectoriadb-server") -project(":vectoriadb-server").name = "vectoriadb-server" - include("vectoriadb-interface") -project(":vectoriadb-interface").name = "vectoriadb-interface" - include("vectoriadb-java-client") -project(":vectoriadb-java-client").name = "vectoriadb-java-client" -include("vectoriadb-docker") +include("vectoriadb-bench") diff --git a/vectoriadb-bench/build.gradle.kts b/vectoriadb-bench/build.gradle.kts new file mode 100644 index 000000000..3bc767c17 --- /dev/null +++ b/vectoriadb-bench/build.gradle.kts @@ -0,0 +1,5 @@ +dependencies { + implementation(libs.commons.net) + + implementation(project(":vectoriadb-index")) +} \ No newline at end of file diff --git a/vectoriadb-bench/src/main/java/jetbrains/vectoriadb/bench/BenchUtils.java b/vectoriadb-bench/src/main/java/jetbrains/vectoriadb/bench/BenchUtils.java new file mode 100644 index 000000000..9cd5b4e83 --- /dev/null +++ b/vectoriadb-bench/src/main/java/jetbrains/vectoriadb/bench/BenchUtils.java @@ -0,0 +1,188 @@ +package jetbrains.vectoriadb.bench; + +import org.apache.commons.compress.archivers.tar.TarArchiveInputStream; +import org.apache.commons.compress.compressors.gzip.GzipCompressorInputStream; +import org.apache.commons.compress.utils.IOUtils; +import org.apache.commons.net.ftp.FTP; +import org.apache.commons.net.ftp.FTPClient; + +import java.io.EOFException; +import java.io.IOException; +import java.nio.ByteBuffer; +import java.nio.ByteOrder; +import java.nio.channels.FileChannel; +import java.nio.file.Files; +import java.nio.file.Path; +import java.nio.file.StandardCopyOption; + +public class BenchUtils { + public static void extractTarGzArchive(Path rootDir, Path archivePath) throws IOException { + System.out.println("Extracting " + archivePath.getFileName() + " into " + rootDir); + + try (var fis = Files.newInputStream(archivePath)) { + try (var giz = new GzipCompressorInputStream(fis)) { + try (var tar = new TarArchiveInputStream(giz)) { + var entry = tar.getNextTarEntry(); + + 
while (entry != null) { + var name = entry.getName(); + if (name.endsWith(".fvecs") || name.endsWith(".ivecs")) { + System.out.printf("Extracting %s%n", name); + var file = rootDir.resolve(name); + if (!Files.exists(file.getParent())) { + Files.createDirectories(file.getParent()); + } + + try (var fos = Files.newOutputStream(file)) { + IOUtils.copy(tar, fos); + } + } + entry = tar.getNextTarEntry(); + } + } + } + } + + System.out.printf("%s extracted%n", archivePath.getFileName()); + } + + public static void extractGzArchive(Path targetPath, Path archivePath) throws IOException { + System.out.println("Extracting " + archivePath.getFileName() + " into " + targetPath.getFileName()); + + try (var fis = Files.newInputStream(archivePath)) { + try (var giz = new GzipCompressorInputStream(fis)) { + Files.copy(giz, targetPath, StandardCopyOption.REPLACE_EXISTING); + } + } + + System.out.printf("%s extracted%n", archivePath.getFileName()); + } + + + public static Path downloadBenchFile(Path rootDir, String benchArchiveName) throws IOException { + var benchArchivePath = rootDir.resolve(benchArchiveName); + + if (Files.exists(benchArchivePath)) { + System.out.println(benchArchiveName + " already exists in " + rootDir); + } else { + System.out.println("Downloading " + benchArchiveName + + " from ftp.irisa.fr into " + rootDir); + + var ftpClient = new FTPClient(); + ftpClient.connect("ftp.irisa.fr"); + ftpClient.enterLocalPassiveMode(); + var loggedIdn = ftpClient.login("anonymous", "anonymous"); + ftpClient.setFileType(FTP.BINARY_FILE_TYPE); + if (!loggedIdn) { + throw new IllegalStateException("Failed to login to ftp.irisa.fr"); + } + + System.out.println("Logged in to ftp.irisa.fr"); + try (var fos = Files.newOutputStream(benchArchivePath)) { + ftpClient.retrieveFile("/local/texmex/corpus/" + benchArchiveName, fos); + } finally { + ftpClient.logout(); + ftpClient.disconnect(); + } + + System.out.println(benchArchiveName + " downloaded"); + } + + return benchArchivePath; + } 
+ + public static float[][] readFVectors(Path path, int vectorDimensions) throws IOException { + try (var channel = FileChannel.open(path)) { + + var vectorBuffer = ByteBuffer.allocate(Float.BYTES * vectorDimensions + Integer.BYTES); + vectorBuffer.order(ByteOrder.LITTLE_ENDIAN); + + var vectorsCount = + (int) (channel.size() / (Float.BYTES * vectorDimensions + Integer.BYTES)); + var vectors = new float[vectorsCount][]; + for (var i = 0; i < vectorsCount; i++) { + vectorBuffer.rewind(); + readFully(channel, vectorBuffer); + vectorBuffer.rewind(); + + if (vectorBuffer.getInt() != vectorDimensions) { + throw new IllegalStateException("Vector dimensions mismatch"); + } + + var vector = new float[vectorDimensions]; + for (var j = 0; j < vector.length; j++) { + vector[j] = vectorBuffer.getFloat(); + } + vectors[i] = vector; + } + return vectors; + } + } + + public static float[][] readFBVectors(Path path, int vectorDimensions, int size) throws IOException { + try (var channel = FileChannel.open(path)) { + var vectorBuffer = ByteBuffer.allocate(vectorDimensions + Integer.BYTES); + vectorBuffer.order(ByteOrder.LITTLE_ENDIAN); + + var vectorsCount = + Math.min(size, (int) (channel.size() / (vectorDimensions + Integer.BYTES))); + var vectors = new float[vectorsCount][]; + + for (var i = 0; i < vectorsCount; i++) { + vectorBuffer.rewind(); + readFully(channel, vectorBuffer); + vectorBuffer.rewind(); + + if (vectorBuffer.getInt() != vectorDimensions) { + throw new IllegalStateException("Vector dimensions mismatch"); + } + + var vector = new float[vectorDimensions]; + for (var j = 0; j < vector.length; j++) { + vector[j] = vectorBuffer.get(); + } + vectors[i] = vector; + } + + return vectors; + } + } + + @SuppressWarnings("SameParameterValue") + public static int[][] readIVectors(Path siftSmallBase, int vectorDimensions) throws IOException { + try (var channel = FileChannel.open(siftSmallBase)) { + var vectorBuffer = ByteBuffer.allocate(Integer.BYTES * vectorDimensions + 
Integer.BYTES); + vectorBuffer.order(ByteOrder.LITTLE_ENDIAN); + + var vectorsCount = + (int) (channel.size() / ((long) Integer.BYTES * vectorDimensions + Integer.BYTES)); + var vectors = new int[vectorsCount][]; + for (var i = 0; i < vectorsCount; i++) { + vectorBuffer.rewind(); + readFully(channel, vectorBuffer); + vectorBuffer.rewind(); + + if (vectorBuffer.getInt() != vectorDimensions) { + throw new IllegalStateException("Vector dimensions mismatch"); + } + + var vector = new int[vectorDimensions]; + for (var j = 0; j < vector.length; j++) { + vector[j] = vectorBuffer.getInt(); + } + + vectors[i] = vector; + } + return vectors; + } + } + + private static void readFully(FileChannel siftSmallBaseChannel, ByteBuffer vectorBuffer) throws IOException { + while (vectorBuffer.remaining() > 0) { + var r = siftSmallBaseChannel.read(vectorBuffer); + if (r < 0) { + throw new EOFException(); + } + } + } +} diff --git a/vectoriadb-java-client/build.gradle.kts b/vectoriadb-java-client/build.gradle.kts index e69de29bb..7846a5452 100644 --- a/vectoriadb-java-client/build.gradle.kts +++ b/vectoriadb-java-client/build.gradle.kts @@ -0,0 +1,9 @@ +dependencies { + implementation(libs.grpc.java) + implementation(libs.grpc.protobuf) + implementation(libs.grpc.netty.shaded) + implementation(libs.grpc.stub) + implementation(libs.commons.net) + + implementation(project(":vectoriadb-interface")) +} \ No newline at end of file diff --git a/vectoriadb-java-client/src/main/java/jetbrains/vectoriadb/client/Distance.java b/vectoriadb-java-client/src/main/java/jetbrains/vectoriadb/client/Distance.java new file mode 100644 index 000000000..2d8e83a36 --- /dev/null +++ b/vectoriadb-java-client/src/main/java/jetbrains/vectoriadb/client/Distance.java @@ -0,0 +1,7 @@ +package jetbrains.vectoriadb.client; + +public enum Distance { + L2, + DOT, + COSINE +} diff --git a/vectoriadb-java-client/src/main/java/jetbrains/vectoriadb/client/IndexBuildStatusListener.java 
b/vectoriadb-java-client/src/main/java/jetbrains/vectoriadb/client/IndexBuildStatusListener.java new file mode 100644 index 000000000..9825f44e9 --- /dev/null +++ b/vectoriadb-java-client/src/main/java/jetbrains/vectoriadb/client/IndexBuildStatusListener.java @@ -0,0 +1,10 @@ +package jetbrains.vectoriadb.client; + +import java.util.List; + +public interface IndexBuildStatusListener { + boolean onIndexBuildStatusUpdate(String indexName, List phases); + + record Phase(String name, double progress, String... parameters) { + } +} diff --git a/vectoriadb-java-client/src/main/java/jetbrains/vectoriadb/client/IndexMetadata.java b/vectoriadb-java-client/src/main/java/jetbrains/vectoriadb/client/IndexMetadata.java new file mode 100644 index 000000000..b2d9b4054 --- /dev/null +++ b/vectoriadb-java-client/src/main/java/jetbrains/vectoriadb/client/IndexMetadata.java @@ -0,0 +1,4 @@ +package jetbrains.vectoriadb.client; + +public record IndexMetadata(int maximumConnectionsPerVertex, int maximumCandidatesReturned, int compressionRatio, float distanceMultiplier) { +} diff --git a/vectoriadb-java-client/src/main/java/jetbrains/vectoriadb/client/IndexState.java b/vectoriadb-java-client/src/main/java/jetbrains/vectoriadb/client/IndexState.java new file mode 100644 index 000000000..e2f53d935 --- /dev/null +++ b/vectoriadb-java-client/src/main/java/jetbrains/vectoriadb/client/IndexState.java @@ -0,0 +1,12 @@ +package jetbrains.vectoriadb.client; + +public enum IndexState { + CREATING, + CREATED, + UPLOADING, + UPLOADED, + IN_BUILD_QUEUE, + BUILDING, + BUILT, + BROKEN +} diff --git a/vectoriadb-java-client/src/main/java/jetbrains/vectoriadb/client/VectoriaDBClient.java b/vectoriadb-java-client/src/main/java/jetbrains/vectoriadb/client/VectoriaDBClient.java new file mode 100644 index 000000000..1c4755bbc --- /dev/null +++ b/vectoriadb-java-client/src/main/java/jetbrains/vectoriadb/client/VectoriaDBClient.java @@ -0,0 +1,214 @@ +package jetbrains.vectoriadb.client; + +import 
com.google.protobuf.Empty; +import io.grpc.Context; +import io.grpc.ManagedChannelBuilder; +import io.grpc.stub.StreamObserver; +import jetbrains.vectoriadb.service.base.IndexManagerGrpc; +import jetbrains.vectoriadb.service.base.IndexManagerOuterClass; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.util.ArrayList; +import java.util.Iterator; +import java.util.List; +import java.util.concurrent.CountDownLatch; + + +public final class VectoriaDBClient { + private static final Logger logger = LoggerFactory.getLogger(VectoriaDBClient.class); + private final IndexManagerGrpc.IndexManagerBlockingStub indexManagerBlockingStub; + private final IndexManagerGrpc.IndexManagerStub indexManagerAsyncStub; + + public VectoriaDBClient(String host) { + this(host, 9090); + } + + public VectoriaDBClient(String host, int port) { + var channel = ManagedChannelBuilder.forAddress(host, port).build(); + this.indexManagerBlockingStub = IndexManagerGrpc.newBlockingStub(channel); + this.indexManagerAsyncStub = IndexManagerGrpc.newStub(channel); + } + + public IndexMetadata createIndex(final String indexName, final Distance distance) { + var builder = IndexManagerOuterClass.CreateIndexRequest.newBuilder(); + builder.setIndexName(indexName); + + switch (distance) { + case L2: + builder.setDistance(IndexManagerOuterClass.Distance.L2); + break; + case DOT: + builder.setDistance(IndexManagerOuterClass.Distance.DOT); + break; + case COSINE: + builder.setDistance(IndexManagerOuterClass.Distance.COSINE); + break; + } + + var request = builder.build(); + var response = indexManagerBlockingStub.createIndex(request); + + return new IndexMetadata(response.getMaximumConnectionsPerVertex(), response.getMaximumCandidatesReturned(), + response.getCompressionRatio(), response.getDistanceMultiplier()); + } + + public void buildIndex(final String indexName) { + var builder = IndexManagerOuterClass.IndexNameRequest.newBuilder(); + builder.setIndexName(indexName); + + var request = 
builder.build(); + //noinspection ResultOfMethodCallIgnored + indexManagerBlockingStub.buildIndex(request); + } + + public void dropIndex(final String indexName) { + var builder = IndexManagerOuterClass.IndexNameRequest.newBuilder(); + builder.setIndexName(indexName); + + var request = builder.build(); + //noinspection ResultOfMethodCallIgnored + indexManagerBlockingStub.dropIndex(request); + } + + public List listIndexes() { + var builder = Empty.newBuilder(); + var request = builder.build(); + + var response = indexManagerBlockingStub.indexList(request); + return response.getIndexNamesList(); + } + + IndexState indexState(String indexName) { + var builder = IndexManagerOuterClass.IndexNameRequest.newBuilder(); + builder.setIndexName(indexName); + + var request = builder.build(); + var response = indexManagerBlockingStub.indexState(request); + + return switch (response.getState()) { + case CREATING -> IndexState.CREATING; + case CREATED -> IndexState.CREATED; + case UPLOADING -> IndexState.UPLOADING; + case UPLOADED -> IndexState.UPLOADED; + case IN_BUILD_QUEUE -> IndexState.IN_BUILD_QUEUE; + case BUILDING -> IndexState.BUILDING; + case BUILT -> IndexState.BUILT; + case BROKEN -> IndexState.BROKEN; + default -> throw new IllegalStateException("Unexpected value: " + response.getState()); + }; + } + + public void uploadVectors(final String indexName, final Iterator vectors) { + var error = new Throwable[1]; + var finishedLatch = new CountDownLatch(1); + var responseObserver = new StreamObserver() { + @Override + public void onNext(Empty value) { + //ignore + } + + @Override + public void onError(Throwable t) { + logger.error("Error while uploading vectors", t); + error[0] = t; + finishedLatch.countDown(); + } + + @Override + public void onCompleted() { + finishedLatch.countDown(); + } + }; + + var requestObserver = indexManagerAsyncStub.uploadData(responseObserver); + try { + while (vectors.hasNext()) { + var vector = vectors.next(); + var builder = 
IndexManagerOuterClass.UploadDataRequest.newBuilder(); + builder.setIndexName(indexName); + + for (var value : vector) { + builder.addVectorComponents(value); + } + + var request = builder.build(); + requestObserver.onNext(request); + + if (finishedLatch.getCount() == 0) { + break; + } + } + } catch (RuntimeException e) { + requestObserver.onError(e); + throw e; + } + + responseObserver.onCompleted(); + try { + finishedLatch.await(); + } catch (InterruptedException e) { + logger.error("Error while uploading vectors", e); + Thread.currentThread().interrupt(); + } + + if (error[0] != null) { + logger.error("Error while uploading vectors", error[0]); + throw new RuntimeException(error[0]); + } + } + + public void switchToSearchMode() { + var builder = Empty.newBuilder(); + var request = builder.build(); + //noinspection ResultOfMethodCallIgnored + indexManagerBlockingStub.switchToSearchMode(request); + } + + public void switchToBuildMode() { + var builder = Empty.newBuilder(); + var request = builder.build(); + //noinspection ResultOfMethodCallIgnored + indexManagerBlockingStub.switchToBuildMode(request); + } + + public int[] findNearestNeighbours(final String indexName, int k) { + var builder = IndexManagerOuterClass.FindNearestNeighboursRequest.newBuilder(); + builder.setIndexName(indexName); + builder.setK(k); + + var request = builder.build(); + var response = indexManagerBlockingStub.findNearestNeighbours(request); + + return response.getIdsList().stream().mapToInt(Integer::intValue).toArray(); + } + + public void buildStatus(IndexBuildStatusListener buildStatusListener) { + var builder = Empty.newBuilder(); + var request = builder.build(); + + try (var cancellation = Context.current().withCancellation()) { + var response = indexManagerBlockingStub.buildStatus(request); + while (response.hasNext()) { + var status = response.next(); + var indexName = status.getIndexName(); + + var phasesResponse = status.getPhasesList(); + var phases = new 
ArrayList(phasesResponse.size()); + + for (var phase : phasesResponse) { + var phaseName = phase.getName(); + var progress = phase.getCompletionPercentage(); + var parameters = phase.getParametersList().toArray(new String[0]); + + phases.add(new IndexBuildStatusListener.Phase(phaseName, progress, parameters)); + } + + if (!buildStatusListener.onIndexBuildStatusUpdate(indexName, phases)) { + cancellation.cancel(new InterruptedException("Cancelled by build status listener")); + break; + } + } + } + } +}