Skip to content

Commit

Permalink
First version of VectoriaDB client and benchmark.
Browse files Browse the repository at this point in the history
  • Loading branch information
andrii0lomakin committed Nov 2, 2023
1 parent ca2546d commit c677585
Show file tree
Hide file tree
Showing 3 changed files with 299 additions and 19 deletions.
20 changes: 19 additions & 1 deletion vectoriadb-bench/build.gradle.kts
Original file line number Diff line number Diff line change
@@ -1,5 +1,23 @@
// Libraries required by the benchmark module.
dependencies {
// Apache Commons Net (resolved through the version catalog entry `libs.commons.net`).
implementation(libs.commons.net)
// Java client module; the benchmark classes import jetbrains.vectoriadb.client.* from it.
implementation(project(":vectoriadb-java-client"))
}

implementation(project(":vectoriadb-index"))
// Registers the `runSift1MBench` task, which launches the SIFT1M benchmark in a separate JVM.
tasks {
register<JavaExec>("runSift1MBench") {
group = "application"
// Entry point executed by the task.
mainClass = "jetbrains.vectoriadb.bench.Sift1MBench"
// Run against the compiled main source set together with its runtime dependencies.
classpath = sourceSets["main"].runtimeClasspath
// Flags handed to the forked JVM: the incubating vector module, headless AWT and preview features.
jvmArgs = listOf(
"--add-modules",
"jdk.incubator.vector",
"-Djava.awt.headless=true",
"--enable-preview"
)
// Forward the Gradle project properties of the same names (e.g. -Pbench.path=...) to the benchmark as
// JVM system properties. findProperty yields null for a property that was not supplied.
systemProperties = mapOf(
"bench.path" to (project.findProperty("bench.path")),
"vectoriadb.host" to (project.findProperty("vectoriadb.host")),
"vectoriadb.port" to (project.findProperty("vectoriadb.port"))
)
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,178 @@
/*
* Copyright ${inceptionYear} - ${year} ${owner}
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package jetbrains.vectoriadb.bench;

import jetbrains.vectoriadb.client.Distance;
import jetbrains.vectoriadb.client.IndexBuildStatusListener;
import jetbrains.vectoriadb.client.IndexState;
import jetbrains.vectoriadb.client.VectoriaDBClient;

import java.nio.file.Path;
import java.util.List;
import java.util.Objects;
import java.util.concurrent.atomic.AtomicBoolean;

/**
 * Benchmark that loads the SIFT1M data set, uploads it to a VectoriaDB server, waits for the index
 * to be built and then measures the average latency and the error rate of nearest neighbour queries.
 *
 * <p>Configuration is read from JVM system properties:
 * <ul>
 *     <li>{@code bench.path} - directory that holds the {@code sift1m} working directory, defaults to {@code .}</li>
 *     <li>{@code vectoriadb.host} - server host, defaults to {@code localhost}</li>
 *     <li>{@code vectoriadb.port} - server port, defaults to {@code 9090}</li>
 * </ul>
 */
public class Sift1MBench {
    /** Dimension passed to the reader of the data and query vector files. */
    private static final int VECTOR_DIMENSIONS = 128;

    /** Dimension passed to the reader of the ground truth file. */
    private static final int GROUND_TRUTH_DIMENSIONS = 100;

    /** Number of untimed passes over all queries executed before the measurement starts. */
    private static final int WARM_UP_ITERATIONS = 10;

    /** Number of timed passes over all queries. */
    private static final int BENCHMARK_ITERATIONS = 50;

    /** Pause between two index state requests while waiting for the build to finish. */
    private static final long BUILD_STATE_POLL_INTERVAL_MS = 100;

    /**
     * Runs the whole benchmark: download, upload, index build, warm up and measurement.
     *
     * @param args ignored, the benchmark is configured through system properties
     * @throws RuntimeException wrapping any failure, after its stack trace has been printed
     */
    public static void main(String[] args) {
        try {
            var benchPath = Path.of(propertyOrDefault("bench.path", "."));
            var rootDir = benchPath.resolve("sift1m");
            var siftArchiveName = "sift.tar.gz";

            System.out.println("Working directory: " + rootDir.toAbsolutePath());

            var siftArchivePath = BenchUtils.downloadBenchFile(rootDir, siftArchiveName);
            BenchUtils.extractTarGzArchive(rootDir, siftArchivePath);

            var siftDataName = "sift_base.fvecs";
            var vectors = BenchUtils.readFVectors(rootDir.resolve(siftDataName), VECTOR_DIMENSIONS);

            var indexName = "sift1m";
            System.out.printf("%d data vectors loaded with dimension %d, building index %s...%n",
                    vectors.length, VECTOR_DIMENSIONS, indexName);

            var vectoriaDBHost = propertyOrDefault("vectoriadb.host", "localhost");
            var vectoriaDBPort = Integer.parseInt(propertyOrDefault("vectoriadb.port", "9090"));
            var client = new VectoriaDBClient(vectoriaDBHost, vectoriaDBPort);

            var startMillis = System.currentTimeMillis();
            client.createIndex(indexName, Distance.L2);
            var endMillis = System.currentTimeMillis();
            System.out.printf("Index %s created in %d ms%n", indexName, endMillis - startMillis);

            startMillis = System.currentTimeMillis();
            client.uploadVectors(indexName, vectors);
            endMillis = System.currentTimeMillis();
            System.out.printf("%d vectors uploaded in %d ms%n", vectors.length, endMillis - startMillis);

            startMillis = System.currentTimeMillis();
            client.buildIndex(indexName);

            // The listener keeps receiving updates for as long as it returns true.
            var stopPrintStatus = new AtomicBoolean();
            client.buildStatusAsync((name, phases) -> {
                printStatus(name, phases);
                return !stopPrintStatus.get();
            });

            try {
                awaitIndexBuilt(client, indexName);
            } finally {
                // Tell the status listener to stop, no matter whether the build succeeded or failed.
                stopPrintStatus.set(true);
            }

            endMillis = System.currentTimeMillis();
            System.out.printf("Index %s built in %d ms%n", indexName, endMillis - startMillis);

            var queryFileName = "sift_query.fvecs";

            System.out.println("Reading queries...");
            var queryFile = rootDir.resolve(queryFileName);
            var queryVectors = BenchUtils.readFVectors(queryFile, VECTOR_DIMENSIONS);

            System.out.println(queryVectors.length + " queries are read");
            System.out.println("Reading ground truth...");

            var groundTruthFileName = "sift_groundtruth.ivecs";
            var groundTruthFile = rootDir.resolve(groundTruthFileName);
            var groundTruth = BenchUtils.readIVectors(groundTruthFile, GROUND_TRUTH_DIMENSIONS);

            System.out.println("Ground truth is read, searching...");
            System.out.println("Warming up ...");

            for (int i = 0; i < WARM_UP_ITERATIONS; i++) {
                for (float[] vector : queryVectors) {
                    client.findNearestNeighbours(indexName, vector, 1);
                }
            }

            System.out.println("Benchmark ...");
            for (int i = 0; i < BENCHMARK_ITERATIONS; i++) {
                var startNanos = System.nanoTime();
                var errorsCount = 0;
                for (var index = 0; index < queryVectors.length; index++) {
                    var result = client.findNearestNeighbours(indexName, queryVectors[index], 1);

                    // An empty answer cannot match the ground truth, count it as an error instead of failing
                    // with an ArrayIndexOutOfBoundsException.
                    if (result.length == 0 || groundTruth[index][0] != result[0]) {
                        errorsCount++;
                    }
                }
                var endNanos = System.nanoTime();
                var errorPercentage = errorsCount * 100.0 / queryVectors.length;

                System.out.printf("Avg. query time : %d us, errors: %f%% %n",
                        (endNanos - startNanos) / 1000 / queryVectors.length, errorPercentage);
            }
        } catch (Exception e) {
            if (e instanceof InterruptedException) {
                // Keep the interrupt visible to whoever runs this thread.
                Thread.currentThread().interrupt();
            }

            //noinspection CallToPrintStackTrace
            e.printStackTrace();
            throw new RuntimeException(e);
        }
    }

    /**
     * Reads a system property, falling back to the default when the property is missing or blank.
     * A property defined without a value (plain {@code -Dname}) is reported by the JVM as an empty string,
     * which would otherwise break the parsing of the port number.
     *
     * @param name         name of the system property
     * @param defaultValue value used when the property is absent or blank
     * @return the property value or {@code defaultValue}
     */
    private static String propertyOrDefault(String name, String defaultValue) {
        var value = System.getProperty(name);
        return value == null || value.isBlank() ? defaultValue : value;
    }

    /**
     * Blocks until the index reports the {@link IndexState#BUILT} state, pausing between the state requests
     * so that the server is not flooded with them.
     *
     * @param client    client used to request the index state
     * @param indexName name of the index that is being built
     * @throws IllegalStateException if the index is in a state other than queued, building or built
     * @throws InterruptedException  if the thread is interrupted while waiting
     */
    private static void awaitIndexBuilt(VectoriaDBClient client, String indexName) throws InterruptedException {
        while (true) {
            var indexState = client.indexState(indexName);
            if (indexState == IndexState.BUILT) {
                return;
            }

            if (indexState != IndexState.BUILDING && indexState != IndexState.IN_BUILD_QUEUE) {
                throw new IllegalStateException("Unexpected index state: " + indexState);
            }

            Thread.sleep(BUILD_STATE_POLL_INTERVAL_MS);
        }
    }

    /**
     * Prints a single line that describes the build phases of an index in the form
     * {@code index : phase {key:value}, {key:value} [progress%] -> phase ...}.
     * Nothing is printed when the index name is {@code null} or when there are no phases.
     *
     * @param indexName name of the index the phases belong to
     * @param phases    phases to print, their parameters are read as consecutive key/value pairs
     */
    private static void printStatus(String indexName, List<IndexBuildStatusListener.Phase> phases) {
        if (indexName == null || phases == null || phases.isEmpty()) {
            return;
        }

        var builder = new StringBuilder();
        builder.append(indexName).append(" : ");

        var firstPhase = true;
        for (var phase : phases) {
            if (!firstPhase) {
                builder.append(" -> ");
            }
            firstPhase = false;

            builder.append(phase.name());
            var parameters = phase.parameters();

            if (parameters.length > 0) {
                builder.append(" ");
            }

            for (int j = 0; j < parameters.length; j += 2) {
                if (j > 0) {
                    builder.append(", ");
                }

                builder.append("{");
                builder.append(parameters[j]);
                // A trailing key without a value is printed on its own instead of reading past the array end.
                if (j + 1 < parameters.length) {
                    builder.append(":");
                    builder.append(parameters[j + 1]);
                }
                builder.append("}");
            }

            if (phase.progress() >= 0) {
                builder.append(" [").append(String.format("%.2f", phase.progress())).append("%]");
            }
        }

        System.out.println(builder);
    }
}
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ public List<String> listIndexes() {
return response.getIndexNamesList();
}

IndexState indexState(String indexName) {
public IndexState indexState(String indexName) {
var builder = IndexManagerOuterClass.IndexNameRequest.newBuilder();
builder.setIndexName(indexName);

Expand All @@ -115,6 +115,14 @@ IndexState indexState(String indexName) {
}

/**
 * Uploads the vectors supplied by an iterator to the index with the given name.
 * Delegates to the generic upload routine and lets {@code uploadVectorsList} feed the request stream.
 *
 * @param indexName name of the index that receives the vectors
 * @param vectors   iterator over the vectors to upload
 */
public void uploadVectors(final String indexName, final Iterator<float[]> vectors) {
uploadVectors(indexName, vectors, VectoriaDBClient::uploadVectorsList);
}

/**
 * Uploads the vectors stored in an array to the index with the given name.
 * Delegates to the generic upload routine and lets {@code uploadVectorsArray} feed the request stream.
 *
 * @param indexName name of the index that receives the vectors
 * @param vectors   vectors to upload, one vector per array row
 */
public void uploadVectors(final String indexName, final float[][] vectors) {
uploadVectors(indexName, vectors, VectoriaDBClient::uploadVectorsArray);
}

private <T> void uploadVectors(String indexName, T vectors, VectorsUploader<T> vectorsUploader) {
var error = new Throwable[1];
var finishedLatch = new CountDownLatch(1);
var responseObserver = new StreamObserver<Empty>() {
Expand All @@ -138,22 +146,7 @@ public void onCompleted() {

var requestObserver = indexManagerAsyncStub.uploadData(responseObserver);
try {
while (vectors.hasNext()) {
var vector = vectors.next();
var builder = IndexManagerOuterClass.UploadDataRequest.newBuilder();
builder.setIndexName(indexName);

for (var value : vector) {
builder.addVectorComponents(value);
}

var request = builder.build();
requestObserver.onNext(request);

if (finishedLatch.getCount() == 0) {
break;
}
}
vectorsUploader.uploadVectors(indexName, vectors, requestObserver, finishedLatch);
} catch (RuntimeException e) {
requestObserver.onError(e);
throw e;
Expand All @@ -173,6 +166,47 @@ public void onCompleted() {
}
}

/**
 * Sends every vector produced by the iterator as a separate upload request.
 * Sending stops early once the latch has reached zero
 * (presumably counted down by the response observer when the call is finished - confirm against the caller).
 *
 * @param indexName       name of the index that receives the vectors
 * @param vectors         iterator over the vectors to upload
 * @param requestObserver stream the upload requests are written to
 * @param finishedLatch   latch checked after every request to detect that the upload is over
 */
private static void uploadVectorsList(String indexName, Iterator<float[]> vectors,
                                      StreamObserver<IndexManagerOuterClass.UploadDataRequest> requestObserver,
                                      CountDownLatch finishedLatch) {
    while (vectors.hasNext()) {
        sendUploadRequest(indexName, vectors.next(), requestObserver);

        if (finishedLatch.getCount() == 0) {
            break;
        }
    }
}

/**
 * Sends every row of the array as a separate upload request.
 * Sending stops early once the latch has reached zero
 * (presumably counted down by the response observer when the call is finished - confirm against the caller).
 *
 * @param indexName       name of the index that receives the vectors
 * @param vectors         vectors to upload, one vector per array row
 * @param requestObserver stream the upload requests are written to
 * @param finishedLatch   latch checked after every request to detect that the upload is over
 */
private static void uploadVectorsArray(String indexName, float[][] vectors,
                                       StreamObserver<IndexManagerOuterClass.UploadDataRequest> requestObserver,
                                       CountDownLatch finishedLatch) {
    for (var vector : vectors) {
        sendUploadRequest(indexName, vector, requestObserver);

        if (finishedLatch.getCount() == 0) {
            break;
        }
    }
}

/**
 * Builds the upload request for a single vector and writes it to the request stream.
 * Shared by both upload strategies so that the request layout is defined in one place.
 *
 * @param indexName       name of the index stored in the request
 * @param vector          components of the vector stored in the request
 * @param requestObserver stream the request is written to
 */
private static void sendUploadRequest(String indexName, float[] vector,
                                      StreamObserver<IndexManagerOuterClass.UploadDataRequest> requestObserver) {
    var builder = IndexManagerOuterClass.UploadDataRequest.newBuilder();
    builder.setIndexName(indexName);

    for (var value : vector) {
        builder.addVectorComponents(value);
    }

    requestObserver.onNext(builder.build());
}

public void switchToSearchMode() {
var builder = Empty.newBuilder();
var request = builder.build();
Expand All @@ -187,11 +221,15 @@ public void switchToBuildMode() {
indexManagerBlockingStub.switchToBuildMode(request);
}

public int[] findNearestNeighbours(final String indexName, int k) {
public int[] findNearestNeighbours(final String indexName, final float[] vector, int k) {
var builder = IndexManagerOuterClass.FindNearestNeighboursRequest.newBuilder();
builder.setIndexName(indexName);
builder.setK(k);

for (var vectorComponent : vector) {
builder.addVectorComponents(vectorComponent);
}

var request = builder.build();
var response = indexManagerBlockingStub.findNearestNeighbours(request);

Expand Down Expand Up @@ -226,4 +264,50 @@ public void buildStatus(IndexBuildStatusListener buildStatusListener) {
}
}
}

/**
 * Subscribes to index build status updates without blocking the caller.
 * Every update is converted to a list of phases and handed to the listener. The subscription is cancelled
 * as soon as the listener returns {@code false}, when the stream fails, or when it completes.
 *
 * @param buildStatusListener receives the index name and its build phases for every status update and
 *                            returns {@code false} to stop the subscription
 */
public void buildStatusAsync(IndexBuildStatusListener buildStatusListener) {
    var request = Empty.newBuilder().build();

    // Deliberately not managed by try-with-resources: closing the context cancels it, and that would
    // happen as soon as this method returns, while the stream is still expected to deliver updates.
    var cancellation = Context.current().withCancellation();

    var responseObserver = new StreamObserver<IndexManagerOuterClass.BuildStatusResponse>() {
        @Override
        public void onNext(IndexManagerOuterClass.BuildStatusResponse value) {
            var indexName = value.getIndexName();

            var phasesResponse = value.getPhasesList();
            var phases = new ArrayList<IndexBuildStatusListener.Phase>(phasesResponse.size());

            for (var phase : phasesResponse) {
                var phaseName = phase.getName();
                var progress = phase.getCompletionPercentage();
                var parameters = phase.getParametersList().toArray(new String[0]);

                phases.add(new IndexBuildStatusListener.Phase(phaseName, progress, parameters));
            }

            if (!buildStatusListener.onIndexBuildStatusUpdate(indexName, phases)) {
                cancellation.cancel(new InterruptedException("Cancelled by build status listener"));
            }
        }

        @Override
        public void onError(Throwable t) {
            // A failure reported after the context has been cancelled is expected to be the result of
            // that cancellation, so it is not logged as an error.
            if (!cancellation.isCancelled()) {
                logger.error("Error while getting build status", t);
            }
            cancellation.cancel(t);
        }

        @Override
        public void onCompleted() {
            // The stream is over, cancel the context so that it does not outlive the call.
            cancellation.cancel(null);
        }
    };

    try {
        // The call has to be started while the cancellable context is current,
        // otherwise cancelling the context has no effect on the call.
        cancellation.run(() -> indexManagerAsyncStub.buildStatus(request, responseObserver));
    } catch (RuntimeException e) {
        cancellation.cancel(e);
        throw e;
    }
}

/**
 * Strategy that writes a collection of vectors to an upload request stream.
 * Implementations are supplied as method references, one per supported vector container type.
 *
 * @param <T> type of the container that holds the vectors
 */
@FunctionalInterface
private interface VectorsUploader<T> {
    /**
     * Writes the given vectors to the request stream.
     *
     * @param indexName       name of the index that receives the vectors
     * @param vectors         container with the vectors to upload
     * @param requestObserver stream the upload requests are written to
     * @param finishedLatch   latch that lets the implementation detect that the upload is already over
     */
    void uploadVectors(String indexName, T vectors,
                       StreamObserver<IndexManagerOuterClass.UploadDataRequest> requestObserver,
                       CountDownLatch finishedLatch);
}
}

0 comments on commit c677585

Please sign in to comment.