From c677585b4deade825f7ea87e600b16f3ec5c8b33 Mon Sep 17 00:00:00 2001 From: laa Date: Thu, 2 Nov 2023 14:19:23 +0100 Subject: [PATCH] First version of VectoriaDB client and benchmark. --- vectoriadb-bench/build.gradle.kts | 20 +- .../vectoriadb/bench/Sift1MBench.java | 178 ++++++++++++++++++ .../vectoriadb/client/VectoriaDBClient.java | 120 ++++++++++-- 3 files changed, 299 insertions(+), 19 deletions(-) create mode 100644 vectoriadb-bench/src/main/java/jetbrains/vectoriadb/bench/Sift1MBench.java diff --git a/vectoriadb-bench/build.gradle.kts b/vectoriadb-bench/build.gradle.kts index 3bc767c17..3809099ea 100644 --- a/vectoriadb-bench/build.gradle.kts +++ b/vectoriadb-bench/build.gradle.kts @@ -1,5 +1,23 @@ dependencies { implementation(libs.commons.net) + implementation(project(":vectoriadb-java-client")) +} - implementation(project(":vectoriadb-index")) +tasks { + register("runSift1MBench") { + group = "application" + mainClass = "jetbrains.vectoriadb.bench.Sift1MBench" + classpath = sourceSets["main"].runtimeClasspath + jvmArgs = listOf( + "--add-modules", + "jdk.incubator.vector", + "-Djava.awt.headless=true", + "--enable-preview" + ) + systemProperties = mapOf( + "bench.path" to (project.findProperty("bench.path")), + "vectoriadb.host" to (project.findProperty("vectoriadb.host")), + "vectoriadb.port" to (project.findProperty("vectoriadb.port")) + ) + } } \ No newline at end of file diff --git a/vectoriadb-bench/src/main/java/jetbrains/vectoriadb/bench/Sift1MBench.java b/vectoriadb-bench/src/main/java/jetbrains/vectoriadb/bench/Sift1MBench.java new file mode 100644 index 000000000..b4b5906bd --- /dev/null +++ b/vectoriadb-bench/src/main/java/jetbrains/vectoriadb/bench/Sift1MBench.java @@ -0,0 +1,178 @@ +/* + * Copyright ${inceptionYear} - ${year} ${owner} + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package jetbrains.vectoriadb.bench; + +import jetbrains.vectoriadb.client.Distance; +import jetbrains.vectoriadb.client.IndexBuildStatusListener; +import jetbrains.vectoriadb.client.IndexState; +import jetbrains.vectoriadb.client.VectoriaDBClient; + +import java.nio.file.Path; +import java.util.List; +import java.util.Objects; +import java.util.concurrent.atomic.AtomicBoolean; + +public class Sift1MBench { + public static void main(String[] args) { + try { + var benchPathStr = System.getProperty("bench.path"); + var benchPath = Path.of(Objects.requireNonNullElse(benchPathStr, ".")); + var rootDir = benchPath.resolve("sift1m"); + var siftArchiveName = "sift.tar.gz"; + var vectorDimensions = 128; + + System.out.println("Working directory: " + rootDir.toAbsolutePath()); + + var siftArchivePath = BenchUtils.downloadBenchFile(rootDir, siftArchiveName); + BenchUtils.extractTarGzArchive(rootDir, siftArchivePath); + + var siftDataName = "sift_base.fvecs"; + var vectors = BenchUtils.readFVectors(rootDir.resolve(siftDataName), vectorDimensions); + + var indexName = "sift1m"; + System.out.printf("%d data vectors loaded with dimension %d, building index %s...%n", + vectors.length, vectorDimensions, indexName); + + var vectoriaDBHost = System.getProperty("vectoriadb.host", "localhost"); + Objects.requireNonNull(vectoriaDBHost, "Server host is not provided"); + + var vectoriaDBPort = Integer.parseInt(System.getProperty("vectoriadb.port", "9090")); + var client = new VectoriaDBClient(vectoriaDBHost, vectoriaDBPort); + + + var ts1 = System.currentTimeMillis(); + client.createIndex(indexName, Distance.L2); + var ts2 = System.currentTimeMillis(); + System.out.printf("Index %s created in %d ms%n", indexName, ts2 - ts1); + + ts1 = System.currentTimeMillis(); + client.uploadVectors(indexName, vectors); + ts2 = System.currentTimeMillis(); + System.out.printf("%d vectors uploaded in %d ms%n", vectors.length, ts2 - ts1); + + ts1 = System.currentTimeMillis(); + client.buildIndex(indexName); + + var stopPrintStatus = new AtomicBoolean(); + + client.buildStatusAsync((name, phases) -> { + printStatus(name, phases); + return !stopPrintStatus.get(); + }); + + while (true) { + var indexState = client.indexState(indexName); + if (indexState != IndexState.BUILDING && indexState != IndexState.BUILT && + indexState != IndexState.IN_BUILD_QUEUE) { + throw new IllegalStateException("Unexpected index state: " + indexState); + } + + if (indexState == IndexState.BUILT) { + break; + } + } + + ts2 = System.currentTimeMillis(); + System.out.printf("Index %s built in %d ms%n", indexName, ts2 - ts1); + + var queryFileName = "sift_query.fvecs"; + + System.out.println("Reading queries..."); + var queryFile = rootDir.resolve(queryFileName); + var queryVectors = BenchUtils.readFVectors(queryFile, vectorDimensions); + + System.out.println(queryVectors.length + " queries are read"); + System.out.println("Reading ground truth..."); + + var groundTruthFileName = "sift_groundtruth.ivecs"; + var groundTruthFile = rootDir.resolve(groundTruthFileName); + var groundTruth = BenchUtils.readIVectors(groundTruthFile, 100); + + System.out.println("Ground truth is read, searching..."); + System.out.println("Warming up ..."); + + for (int i = 0; i < 10; i++) { + for (float[] vector : queryVectors) { + client.findNearestNeighbours(indexName, vector, 1); + } + } + + System.out.println("Benchmark ..."); + for (int i = 0; i < 50; i++) { + ts1 = System.nanoTime(); + var errorsCount = 0; + for (var index = 0; index < queryVectors.length; index++) { + var vector = queryVectors[index]; + + var result = client.findNearestNeighbours(indexName, vector, 1); + if (groundTruth[index][0] != result[0]) { + errorsCount++; + } + } + ts2 = System.nanoTime(); + var errorPercentage = errorsCount * 100.0 / queryVectors.length; + + System.out.printf("Avg. query time : %d us, errors: %f%% %n", + (ts2 - ts1) / 1000 / queryVectors.length, errorPercentage); + } + } catch (Exception e) { + //noinspection CallToPrintStackTrace + e.printStackTrace(); + throw new RuntimeException(e); + } + } + + private static void printStatus(String indexName, List phases) { + if (indexName == null || phases.isEmpty()) { + return; + } + + StringBuilder builder = new StringBuilder(); + builder.append(indexName).append(" : "); + + int counter = 0; + for (var phase : phases) { + if (counter > 0) { + builder.append(" -> "); + } + + builder.append(phase.name()); + var parameters = phase.parameters(); + + if (parameters.length > 0) { + builder.append(" "); + } + + for (int j = 0; j < parameters.length; j += 2) { + builder.append("{"); + builder.append(parameters[j]); + builder.append(":"); + builder.append(parameters[j + 1]); + builder.append("}"); + + if (j < parameters.length - 2) { + builder.append(", "); + } + } + if (phase.progress() >= 0) { + builder.append(" [").append(String.format("%.2f", phase.progress())).append("%]"); + } + counter++; + } + + System.out.println(builder); + } +} diff --git a/vectoriadb-java-client/src/main/java/jetbrains/vectoriadb/client/VectoriaDBClient.java b/vectoriadb-java-client/src/main/java/jetbrains/vectoriadb/client/VectoriaDBClient.java index 8b3e5f95d..a28dcfb47 100644 --- a/vectoriadb-java-client/src/main/java/jetbrains/vectoriadb/client/VectoriaDBClient.java +++ b/vectoriadb-java-client/src/main/java/jetbrains/vectoriadb/client/VectoriaDBClient.java @@ -94,7 +94,7 @@ public List listIndexes() { return response.getIndexNamesList(); } - IndexState indexState(String indexName) { + public IndexState indexState(String indexName) { var builder = IndexManagerOuterClass.IndexNameRequest.newBuilder(); builder.setIndexName(indexName); @@ -115,6 +115,14 @@ IndexState indexState(String indexName) { } public void uploadVectors(final String indexName, final Iterator vectors) { + uploadVectors(indexName, vectors, VectoriaDBClient::uploadVectorsList); + } + + public void uploadVectors(final String indexName, final float[][] vectors) { + uploadVectors(indexName, vectors, VectoriaDBClient::uploadVectorsArray); + } + + private void uploadVectors(String indexName, T vectors, VectorsUploader vectorsUploader) { var error = new Throwable[1]; var finishedLatch = new CountDownLatch(1); var responseObserver = new StreamObserver() { @@ -138,22 +146,7 @@ public void onCompleted() { var requestObserver = indexManagerAsyncStub.uploadData(responseObserver); try { - while (vectors.hasNext()) { - var vector = vectors.next(); - var builder = IndexManagerOuterClass.UploadDataRequest.newBuilder(); - builder.setIndexName(indexName); - - for (var value : vector) { - builder.addVectorComponents(value); - } - - var request = builder.build(); - requestObserver.onNext(request); - - if (finishedLatch.getCount() == 0) { - break; - } - } + vectorsUploader.uploadVectors(indexName, vectors, requestObserver, finishedLatch); } catch (RuntimeException e) { requestObserver.onError(e); throw e; @@ -173,6 +166,47 @@ public void onCompleted() { } } + private static void uploadVectorsList(String indexName, Iterator vectors, + StreamObserver requestObserver, + CountDownLatch finishedLatch) { + while (vectors.hasNext()) { + var vector = vectors.next(); + var builder = IndexManagerOuterClass.UploadDataRequest.newBuilder(); + builder.setIndexName(indexName); + + for (var value : vector) { + builder.addVectorComponents(value); + } + + var request = builder.build(); + requestObserver.onNext(request); + + if (finishedLatch.getCount() == 0) { + break; + } + } + } + + private static void uploadVectorsArray(String indexName, float[][] vectors, + StreamObserver requestObserver, + CountDownLatch finishedLatch) { + for (var vector : vectors) { + var builder = IndexManagerOuterClass.UploadDataRequest.newBuilder(); + builder.setIndexName(indexName); + + for (var value : vector) { + builder.addVectorComponents(value); + } + + var request = builder.build(); + requestObserver.onNext(request); + + if (finishedLatch.getCount() == 0) { + break; + } + } + } + public void switchToSearchMode() { var builder = Empty.newBuilder(); var request = builder.build(); @@ -187,11 +221,15 @@ public void switchToBuildMode() { indexManagerBlockingStub.switchToBuildMode(request); } - public int[] findNearestNeighbours(final String indexName, int k) { + public int[] findNearestNeighbours(final String indexName, final float[] vector, int k) { var builder = IndexManagerOuterClass.FindNearestNeighboursRequest.newBuilder(); builder.setIndexName(indexName); builder.setK(k); + for (var vectorComponent : vector) { + builder.addVectorComponents(vectorComponent); + } + var request = builder.build(); var response = indexManagerBlockingStub.findNearestNeighbours(request); @@ -226,4 +264,50 @@ public void buildStatus(IndexBuildStatusListener buildStatusListener) { } } } + + public void buildStatusAsync(IndexBuildStatusListener buildStatusListener) { + var builder = Empty.newBuilder(); + var request = builder.build(); + + try (var cancellation = Context.current().withCancellation()) { + indexManagerAsyncStub.buildStatus(request, new StreamObserver<>() { + @Override + public void onNext(IndexManagerOuterClass.BuildStatusResponse value) { + var indexName = value.getIndexName(); + + var phasesResponse = value.getPhasesList(); + var phases = new ArrayList(phasesResponse.size()); + + for (var phase : phasesResponse) { + var phaseName = phase.getName(); + var progress = phase.getCompletionPercentage(); + var parameters = phase.getParametersList().toArray(new String[0]); + + phases.add(new IndexBuildStatusListener.Phase(phaseName, progress, parameters)); + } + + if (!buildStatusListener.onIndexBuildStatusUpdate(indexName, phases)) { + cancellation.cancel(new InterruptedException("Cancelled by build status listener")); + } + } + + @Override + public void onError(Throwable t) { + logger.error("Error while getting build status", t); + cancellation.cancel(t); + } + + @Override + public void onCompleted() { + //ignore + } + }); + } + } + + private interface VectorsUploader { + void uploadVectors(String indexName, T vectors, + StreamObserver requestObserver, + CountDownLatch finishedLatch); + } }