From d2c2764a2cd1ea0db5cb62e842c05ba7803afc3c Mon Sep 17 00:00:00 2001 From: Jelmer Kuperus Date: Thu, 28 Nov 2024 07:26:06 +0100 Subject: [PATCH] Redesign spark integration to take advantage of resource profiles. --- .github/workflows/ci.yml | 6 - .github/workflows/publish.yml | 6 - build.sbt | 48 +- hnswlib-spark/src/main/protobuf/index.proto | 65 + .../src/main/protobuf/registration.proto | 18 + .../src/main/python/pyspark_hnsw/knn.py | 24 +- .../server/index/DefaultIndexService.scala | 64 + .../server/index/IndexServerFactory.scala | 76 ++ .../DefaultRegistrationService.scala | 28 + .../registration/RegistrationClient.scala | 41 + .../RegistrationServerFactory.scala | 47 + .../jelmerk/spark/knn/IndexClient.scala | 167 +++ .../spark/knn/IndexClientFactory.scala | 19 + .../jelmerk/spark/knn/KnnAlgorithm.scala | 1087 +++++++---------- .../jelmerk/spark/knn/QueryIterator.scala | 32 + .../knn/bruteforce/BruteForceSimilarity.scala | 124 +- .../spark/knn/hnsw/HnswSimilarity.scala | 126 +- .../com/github/jelmerk/spark/knn/knn.scala | 313 ++++- hnswlib-spark/src/test/scala/ClientTest.scala | 118 ++ .../test/scala/RegistrationClientTest.scala | 33 + .../test/scala/RegistrationServerTest.scala | 59 + hnswlib-spark/src/test/scala/ServerTest.scala | 55 + project/plugins.sbt | 3 + 23 files changed, 1795 insertions(+), 764 deletions(-) create mode 100644 hnswlib-spark/src/main/protobuf/index.proto create mode 100644 hnswlib-spark/src/main/protobuf/registration.proto create mode 100644 hnswlib-spark/src/main/scala/com/github/jelmerk/server/index/DefaultIndexService.scala create mode 100644 hnswlib-spark/src/main/scala/com/github/jelmerk/server/index/IndexServerFactory.scala create mode 100644 hnswlib-spark/src/main/scala/com/github/jelmerk/server/registration/DefaultRegistrationService.scala create mode 100644 hnswlib-spark/src/main/scala/com/github/jelmerk/server/registration/RegistrationClient.scala create mode 100644 hnswlib-spark/src/main/scala/com/github/jelmerk/server/registration/RegistrationServerFactory.scala create mode 100644 hnswlib-spark/src/main/scala/com/github/jelmerk/spark/knn/IndexClient.scala create mode 100644 hnswlib-spark/src/main/scala/com/github/jelmerk/spark/knn/IndexClientFactory.scala create mode 100644 hnswlib-spark/src/main/scala/com/github/jelmerk/spark/knn/QueryIterator.scala create mode 100644 hnswlib-spark/src/test/scala/ClientTest.scala create mode 100644 hnswlib-spark/src/test/scala/RegistrationClientTest.scala create mode 100644 hnswlib-spark/src/test/scala/RegistrationServerTest.scala create mode 100644 hnswlib-spark/src/test/scala/ServerTest.scala diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 0196e7e..485b898 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -20,11 +20,6 @@ jobs: fail-fast: false matrix: spark: - - 2.4.8 - - 3.0.2 - - 3.1.3 - - 3.2.4 - - 3.3.2 - 3.4.1 - 3.5.0 env: @@ -39,7 +34,6 @@ jobs: - uses: actions/setup-python@v5 with: python-version: | - 3.7 3.9 - name: Build and test run: | diff --git a/.github/workflows/publish.yml b/.github/workflows/publish.yml index 7c8d74e..38ae08f 100644 --- a/.github/workflows/publish.yml +++ b/.github/workflows/publish.yml @@ -17,11 +17,6 @@ jobs: fail-fast: false matrix: spark: - - 2.4.8 - - 3.0.2 - - 3.1.3 - - 3.2.4 - - 3.3.2 - 3.4.1 - 3.5.0 @@ -39,7 +34,6 @@ jobs: - uses: actions/setup-python@v5 with: python-version: | - 3.7 3.9 - name: Import GPG Key uses: crazy-max/ghaction-import-gpg@v6 diff --git a/build.sbt b/build.sbt index c65413a..1c83563 100644 --- 
a/build.sbt +++ b/build.sbt @@ -1,5 +1,7 @@ import Path.relativeTo import sys.process.* +import scalapb.compiler.Version.scalapbVersion +import scalapb.compiler.Version.grpcJavaVersion ThisBuild / organization := "com.github.jelmerk" ThisBuild / scalaVersion := "2.12.18" @@ -59,15 +61,7 @@ lazy val hnswlibSpark = (project in file("hnswlib-spark")) .settings( name := s"hnswlib-spark_${sparkVersion.value.split('.').take(2).mkString("_")}", publishSettings, - crossScalaVersions := { - if (sparkVersion.value >= "3.2.0") { - Seq("2.12.18", "2.13.10") - } else if (sparkVersion.value >= "3.0.0") { - Seq("2.12.18") - } else { - Seq("2.12.18", "2.11.12") - } - }, + crossScalaVersions := Seq("2.12.18", "2.13.10"), autoScalaLibrary := false, Compile / unmanagedSourceDirectories += baseDirectory.value / "src" / "main" / "python", Test / unmanagedSourceDirectories += baseDirectory.value / "src" / "test" / "python", @@ -83,9 +77,24 @@ lazy val hnswlibSpark = (project in file("hnswlib-spark")) assembly / assemblyOption ~= { _.withIncludeScala(false) }, - sparkVersion := sys.props.getOrElse("sparkVersion", "3.3.2"), + assembly / assemblyMergeStrategy := { + case PathList("META-INF", "io.netty.versions.properties") => MergeStrategy.first + case x => + val oldStrategy = (ThisBuild / assemblyMergeStrategy).value + oldStrategy(x) + }, + assembly / assemblyShadeRules := Seq( + ShadeRule.rename("com.google.protobuf.**" -> "shaded.com.google.protobuf.@1").inAll, + ShadeRule.rename("com.google.common.**" -> "shaded.com.google.common.@1").inAll, + ShadeRule.rename("io.netty.**" -> "shaded.io.netty.@1").inAll + ), + Compile / PB.targets := Seq( + scalapb.gen() -> (Compile / sourceManaged).value / "scalapb" + ), + sparkVersion := sys.props.getOrElse("sparkVersion", "3.4.1"), +// sparkVersion := sys.props.getOrElse("sparkVersion", "3.5.3"), venvFolder := s"${baseDirectory.value}/.venv", - pythonVersion := (if (scalaVersion.value == "2.11.12") "python3.7" else "python3.9"), + pythonVersion := "python3.9", createVirtualEnv := { val ret = ( s"${pythonVersion.value} -m venv ${venvFolder.value}" #&& @@ -128,12 +137,15 @@ lazy val hnswlibSpark = (project in file("hnswlib-spark")) }, flake8 := flake8.dependsOn(createVirtualEnv).value, libraryDependencies ++= Seq( - "com.github.jelmerk" % "hnswlib-utils" % hnswLibVersion, - "com.github.jelmerk" % "hnswlib-core-jdk17" % hnswLibVersion, - "com.github.jelmerk" %% "hnswlib-scala" % hnswLibVersion, - "org.apache.spark" %% "spark-hive" % sparkVersion.value % Provided, - "org.apache.spark" %% "spark-mllib" % sparkVersion.value % Provided, - "com.holdenkarau" %% "spark-testing-base" % s"${sparkVersion.value}_1.4.7" % Test, - "org.scalatest" %% "scalatest" % "3.2.17" % Test + "com.github.jelmerk" % "hnswlib-utils" % hnswLibVersion, + "com.github.jelmerk" % "hnswlib-core-jdk17" % hnswLibVersion, + "com.github.jelmerk" %% "hnswlib-scala" % hnswLibVersion, + "com.thesamet.scalapb" %% "scalapb-runtime-grpc" % scalapbVersion, + "com.thesamet.scalapb" %% "scalapb-runtime" % scalapbVersion % "protobuf", + "io.grpc" % "grpc-netty" % grpcJavaVersion, + "org.apache.spark" %% "spark-hive" % sparkVersion.value % Provided, + "org.apache.spark" %% "spark-mllib" % sparkVersion.value % Provided, + "com.holdenkarau" %% "spark-testing-base" % s"${sparkVersion.value}_1.4.7" % Test, + "org.scalatest" %% "scalatest" % "3.2.17" % Test ) ) \ No newline at end of file diff --git a/hnswlib-spark/src/main/protobuf/index.proto b/hnswlib-spark/src/main/protobuf/index.proto new file mode 100644 index 
0000000..c81fb2f --- /dev/null +++ b/hnswlib-spark/src/main/protobuf/index.proto @@ -0,0 +1,65 @@ +syntax = "proto3"; + +import "scalapb/scalapb.proto"; + +package com.github.jelmerk.server; + +service IndexService { + rpc Search (stream SearchRequest) returns (stream SearchResponse) {} + rpc SaveIndex (SaveIndexRequest) returns (SaveIndexResponse) {} +} + +message SearchRequest { + + oneof vector { + FloatArrayVector float_array_vector = 4; + DoubleArrayVector double_array_vector = 5; + SparseVector sparse_vector = 6; + DenseVector dense_vector = 7; + } + + int32 k = 8; +} + +message SearchResponse { + repeated Result results = 1; +} + +message Result { + oneof id { + string string_id = 1; + int64 long_id = 2; + int32 int_id = 3; + } + + oneof distance { + float float_distance = 4; + double double_distance = 5; + } +} + +message FloatArrayVector { + repeated float values = 1 [(scalapb.field).collection_type="Array"]; +} + +message DoubleArrayVector { + repeated double values = 1 [(scalapb.field).collection_type="Array"]; +} + +message SparseVector { + int32 size = 1; + repeated int32 indices = 2 [(scalapb.field).collection_type="Array"]; + repeated double values = 3 [(scalapb.field).collection_type="Array"]; +} + +message DenseVector { + repeated double values = 1 [(scalapb.field).collection_type="Array"]; +} + +message SaveIndexRequest { + string path = 1; +} + +message SaveIndexResponse { + int64 bytes_written = 1; +} \ No newline at end of file diff --git a/hnswlib-spark/src/main/protobuf/registration.proto b/hnswlib-spark/src/main/protobuf/registration.proto new file mode 100644 index 0000000..ced9439 --- /dev/null +++ b/hnswlib-spark/src/main/protobuf/registration.proto @@ -0,0 +1,18 @@ +syntax = "proto3"; + +package com.github.jelmerk.server; + +service RegistrationService { + rpc Register (RegisterRequest) returns ( RegisterResponse) {} +} + +message RegisterRequest { + int32 partition_num = 1; + int32 replica_num = 2; + string host = 3; + int32 port = 4; +} + +message RegisterResponse { +} + diff --git a/hnswlib-spark/src/main/python/pyspark_hnsw/knn.py b/hnswlib-spark/src/main/python/pyspark_hnsw/knn.py index d1ac2f1..a65d394 100644 --- a/hnswlib-spark/src/main/python/pyspark_hnsw/knn.py +++ b/hnswlib-spark/src/main/python/pyspark_hnsw/knn.py @@ -1,3 +1,5 @@ +from typing import Any, Dict, Optional, TYPE_CHECKING + from pyspark.ml.wrapper import JavaEstimator, JavaModel from pyspark.ml.param.shared import ( Params, @@ -10,8 +12,12 @@ # noinspection PyProtectedMember from pyspark import keyword_only +# noinspection PyProtectedMember from pyspark.ml.util import JavaMLReadable, JavaMLWritable, MLReader, _jvm +if TYPE_CHECKING: + from py4j.java_gateway import JavaObject + __all__ = [ "HnswSimilarity", "HnswSimilarityModel", @@ -31,6 +37,7 @@ def __init__(self, clazz, java_class): self._clazz = clazz self._jread = self._load_java_obj(java_class).read() + # noinspection PyProtectedMember def load(self, path): """Load the ML instance from the input path.""" java_obj = self._jread.load(path) @@ -132,25 +139,25 @@ def getK(self): """ return self.getOrDefault(self.k) - def getExcludeSelf(self): + def getExcludeSelf(self) -> bool: """ Gets the value of excludeSelf or its default value. """ return self.getOrDefault(self.excludeSelf) - def getSimilarityThreshold(self): + def getSimilarityThreshold(self) -> float: """ Gets the value of similarityThreshold or its default value. 
""" return self.getOrDefault(self.similarityThreshold) - def getOutputFormat(self): + def getOutputFormat(self) -> str: """ Gets the value of outputFormat or its default value. """ return self.getOrDefault(self.outputFormat) - def getNumReplicas(self): + def getNumReplicas(self) -> int: """ Gets the value of numReplicas or its default value. """ @@ -294,6 +301,9 @@ class BruteForceSimilarity(JavaEstimator, _KnnParams, JavaMLReadable, JavaMLWrit Exact nearest neighbour search. """ + _input_kwargs: Dict[str, Any] + + # noinspection PyUnusedLocal @keyword_only def __init__( self, @@ -410,6 +420,7 @@ def setInitialModelPath(self, value): """ return self._set(initialModelPath=value) + # noinspection PyUnusedLocal @keyword_only def setParams( self, @@ -507,6 +518,7 @@ class HnswSimilarity(JavaEstimator, _HnswParams, JavaMLReadable, JavaMLWritable) Approximate nearest neighbour search. """ + # noinspection PyUnusedLocal @keyword_only def __init__( self, @@ -647,7 +659,9 @@ def setInitialModelPath(self, value): """ return self._set(initialModelPath=value) - @keyword_only + # noinspection PyUnusedLocal + @keywor + d_only def setParams( self, identifierCol="id", diff --git a/hnswlib-spark/src/main/scala/com/github/jelmerk/server/index/DefaultIndexService.scala b/hnswlib-spark/src/main/scala/com/github/jelmerk/server/index/DefaultIndexService.scala new file mode 100644 index 0000000..bb481cb --- /dev/null +++ b/hnswlib-spark/src/main/scala/com/github/jelmerk/server/index/DefaultIndexService.scala @@ -0,0 +1,64 @@ +package com.github.jelmerk.server.index + +import scala.concurrent.{ExecutionContext, Future} +import scala.language.implicitConversions + +import com.github.jelmerk.knn.scalalike.{Index, Item} +import com.github.jelmerk.server.index.IndexServiceGrpc.IndexService +import io.grpc.stub.StreamObserver +import org.apache.commons.io.output.CountingOutputStream +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.fs.Path + +class DefaultIndexService[TId, TVector, TItem <: Item[TId, TVector], TDistance]( + index: Index[TId, TVector, TItem, TDistance], + hadoopConfiguration: Configuration, + vectorExtractor: SearchRequest => TVector, + resultIdConverter: TId => Result.Id, + resultDistanceConverter: TDistance => Result.Distance +)(implicit + executionContext: ExecutionContext +) extends IndexService { + + override def search(responseObserver: StreamObserver[SearchResponse]): StreamObserver[SearchRequest] = { + new StreamObserver[SearchRequest] { + override def onNext(request: SearchRequest): Unit = { + + val vector = vectorExtractor(request) + val nearest = index.findNearest(vector, request.k) + + val results = nearest.map { searchResult => + val id = resultIdConverter(searchResult.item.id) + val distance = resultDistanceConverter(searchResult.distance) + + Result(id, distance) + } + + val response = SearchResponse(results) + responseObserver.onNext(response) + } + + override def onError(t: Throwable): Unit = { + responseObserver.onError(t) + } + + override def onCompleted(): Unit = { + responseObserver.onCompleted() + } + } + } + + override def saveIndex(request: SaveIndexRequest): Future[SaveIndexResponse] = Future { + val path = new Path(request.path) + val fileSystem = path.getFileSystem(hadoopConfiguration) + + val outputStream = fileSystem.create(path) + + val countingOutputStream = new CountingOutputStream(outputStream) + + index.save(countingOutputStream) + + SaveIndexResponse(bytesWritten = countingOutputStream.getCount) + } + +} diff --git 
a/hnswlib-spark/src/main/scala/com/github/jelmerk/server/index/IndexServerFactory.scala b/hnswlib-spark/src/main/scala/com/github/jelmerk/server/index/IndexServerFactory.scala new file mode 100644 index 0000000..6a3e25d --- /dev/null +++ b/hnswlib-spark/src/main/scala/com/github/jelmerk/server/index/IndexServerFactory.scala @@ -0,0 +1,76 @@ +package com.github.jelmerk.server.index + +import java.net.{InetAddress, InetSocketAddress} +import java.util.concurrent.{LinkedBlockingQueue, ThreadPoolExecutor, TimeUnit} + +import scala.concurrent.ExecutionContext +import scala.util.Try + +import com.github.jelmerk.knn.scalalike.{Index, Item} +import io.grpc.netty.NettyServerBuilder +import org.apache.hadoop.conf.Configuration + +class IndexServer[TId, TVector, TItem <: Item[TId, TVector] with Product, TDistance]( + vectorExtractor: SearchRequest => TVector, + resultIdConverter: TId => Result.Id, + resultDistanceConverter: TDistance => Result.Distance, + index: Index[TId, TVector, TItem, TDistance], + hadoopConfig: Configuration, + threads: Int +) { + private val executor = new ThreadPoolExecutor( + threads, + threads, + 0L, + TimeUnit.MILLISECONDS, + new LinkedBlockingQueue[Runnable]() + ) + + private val executionContext: ExecutionContext = ExecutionContext.fromExecutor(executor) + + private implicit val ec: ExecutionContext = ExecutionContext.global + private val service = + new DefaultIndexService(index, hadoopConfig, vectorExtractor, resultIdConverter, resultDistanceConverter) + + // Build the gRPC server + private val server = NettyServerBuilder + .forAddress(new InetSocketAddress(InetAddress.getLocalHost, 0)) + .addService(IndexServiceGrpc.bindService(service, executionContext)) + .build() + + def start(): Unit = server.start() + + def address: InetSocketAddress = server.getListenSockets.get(0).asInstanceOf[InetSocketAddress] // TODO CLEANUP + + def awaitTermination(): Unit = { + server.awaitTermination() + } + + def shutdownNow(): Unit = { + Try(server.shutdownNow()) + Try(executor.shutdownNow()) + } +} + +class IndexServerFactory[TId, TVector, TItem <: Item[TId, TVector] with Product, TDistance]( + vectorExtractor: SearchRequest => TVector, + resultIdConverter: TId => Result.Id, + resultDistanceConverter: TDistance => Result.Distance +) extends Serializable { + + def create( + index: Index[TId, TVector, TItem, TDistance], + hadoopConfig: Configuration, + threads: Int + ): IndexServer[TId, TVector, TItem, TDistance] = { + new IndexServer[TId, TVector, TItem, TDistance]( + vectorExtractor, + resultIdConverter, + resultDistanceConverter, + index, + hadoopConfig, + threads + ) + + } +} diff --git a/hnswlib-spark/src/main/scala/com/github/jelmerk/server/registration/DefaultRegistrationService.scala b/hnswlib-spark/src/main/scala/com/github/jelmerk/server/registration/DefaultRegistrationService.scala new file mode 100644 index 0000000..03c2c72 --- /dev/null +++ b/hnswlib-spark/src/main/scala/com/github/jelmerk/server/registration/DefaultRegistrationService.scala @@ -0,0 +1,28 @@ +package com.github.jelmerk.server.registration + +import java.net.InetSocketAddress +import java.util.concurrent.{ConcurrentHashMap, CountDownLatch} + +import scala.concurrent.Future + +import com.github.jelmerk.server.registration.RegistrationServiceGrpc.RegistrationService + +class DefaultRegistrationService(val registrationLatch: CountDownLatch) extends RegistrationService { + + val registrations = new ConcurrentHashMap[PartitionAndReplica, InetSocketAddress]() + + override def register(request: RegisterRequest): 
Future[RegisterResponse] = { + + val key = PartitionAndReplica(request.partitionNum, request.replicaNum) + val previousValue = registrations.put(key, new InetSocketAddress(request.host, request.port)) + + if (previousValue == null) { + registrationLatch.countDown() + } + + Future.successful(RegisterResponse()) + } + +} + +case class PartitionAndReplica(partitionNum: Int, replicaNum: Int) diff --git a/hnswlib-spark/src/main/scala/com/github/jelmerk/server/registration/RegistrationClient.scala b/hnswlib-spark/src/main/scala/com/github/jelmerk/server/registration/RegistrationClient.scala new file mode 100644 index 0000000..4a0d1ae --- /dev/null +++ b/hnswlib-spark/src/main/scala/com/github/jelmerk/server/registration/RegistrationClient.scala @@ -0,0 +1,41 @@ +package com.github.jelmerk.server.registration + +import java.net.{InetSocketAddress, SocketAddress} + +import scala.concurrent.Await +import scala.concurrent.duration.Duration + +import io.grpc.netty.NettyChannelBuilder + +object RegistrationClient { + + def register( + server: SocketAddress, + partitionNo: Int, + replicaNo: Int, + indexServerAddress: InetSocketAddress + ): RegisterResponse = { + val channel = NettyChannelBuilder + .forAddress(server) + .usePlaintext + .build() + + try { + val client = RegistrationServiceGrpc.stub(channel) + + val request = RegisterRequest( + partitionNum = partitionNo, + replicaNum = replicaNo, + indexServerAddress.getHostName, + indexServerAddress.getPort + ) + + val response = client.register(request) + + Await.result(response, Duration.Inf) + } finally { + channel.shutdownNow() + } + + } +} diff --git a/hnswlib-spark/src/main/scala/com/github/jelmerk/server/registration/RegistrationServerFactory.scala b/hnswlib-spark/src/main/scala/com/github/jelmerk/server/registration/RegistrationServerFactory.scala new file mode 100644 index 0000000..6a0884b --- /dev/null +++ b/hnswlib-spark/src/main/scala/com/github/jelmerk/server/registration/RegistrationServerFactory.scala @@ -0,0 +1,47 @@ +package com.github.jelmerk.server.registration + +import java.net.{InetAddress, InetSocketAddress} +import java.util.concurrent.{CountDownLatch, Executors} + +import scala.concurrent.ExecutionContext +import scala.jdk.CollectionConverters.mapAsScalaConcurrentMapConverter +import scala.util.Try + +import io.grpc.netty.NettyServerBuilder + +// TODO this is all a bit messy +class RegistrationServer(numPartitions: Int, numReplicas: Int) { + + private val executor = Executors.newSingleThreadExecutor() + + private val executionContext: ExecutionContext = ExecutionContext.fromExecutor(executor) + + private val registrationLatch = new CountDownLatch(numPartitions + (numReplicas * numPartitions)) + private val service = new DefaultRegistrationService(registrationLatch) + // Build the gRPC server + private val server = NettyServerBuilder + .forAddress(new InetSocketAddress(InetAddress.getLocalHost, 0)) + .addService(RegistrationServiceGrpc.bindService(service, executionContext)) + .build() + + def start(): Unit = server.start() + + def address: InetSocketAddress = server.getListenSockets.get(0).asInstanceOf[InetSocketAddress] // TODO CLEANUP + + def awaitRegistrations(): Map[PartitionAndReplica, InetSocketAddress] = { + service.registrationLatch.await() + service.registrations.asScala.toMap + } + + def shutdownNow(): Unit = { + Try(server.shutdownNow()) + Try(executor.shutdownNow()) + } + +} + +object RegistrationServerFactory { + + def create(numPartitions: Int, numReplicas: Int): RegistrationServer = + new 
RegistrationServer(numPartitions, numReplicas) +} diff --git a/hnswlib-spark/src/main/scala/com/github/jelmerk/spark/knn/IndexClient.scala b/hnswlib-spark/src/main/scala/com/github/jelmerk/spark/knn/IndexClient.scala new file mode 100644 index 0000000..a4c3d08 --- /dev/null +++ b/hnswlib-spark/src/main/scala/com/github/jelmerk/spark/knn/IndexClient.scala @@ -0,0 +1,167 @@ +package com.github.jelmerk.spark.knn + +import java.net.InetSocketAddress +import java.util.concurrent.{Executors, LinkedBlockingQueue} +import java.util.concurrent.atomic.{AtomicBoolean, AtomicInteger} + +import scala.concurrent.{Await, Future} +import scala.concurrent.ExecutionContext.Implicits.global +import scala.concurrent.duration.Duration +import scala.language.implicitConversions +import scala.util.Random + +import com.github.jelmerk.server.index._ +import com.github.jelmerk.server.registration.PartitionAndReplica +import io.grpc.netty.NettyChannelBuilder +import io.grpc.stub.StreamObserver + +class IndexClient[TId, TVector, TDistance]( + indexAddresses: Map[PartitionAndReplica, InetSocketAddress], + vectorConverter: TVector => SearchRequest.Vector, + idExtractor: Result => TId, + distanceExtractor: Result => TDistance, + distanceOrdering: Ordering[TDistance] +) { + + private val random = new Random() + + private val (channels, grpcClients) = indexAddresses.map { case (key, address) => + val channel = NettyChannelBuilder + .forAddress(address) + .usePlaintext + .build() + + (channel, (key, IndexServiceGrpc.stub(channel))) + }.unzip + + private val partitionClients = grpcClients.toList + .sortBy { case (partitionAndReplica, _) => (partitionAndReplica.partitionNum, partitionAndReplica.replicaNum) } + .foldLeft(Map.empty[Int, Seq[IndexServiceGrpc.IndexServiceStub]]) { + case (acc, (PartitionAndReplica(partitionNum, replicaNum), client)) => + val old = acc.getOrElse(partitionNum, Seq.empty[IndexServiceGrpc.IndexServiceStub]) + acc.updated(partitionNum, old :+ client) + } + + private val threadPool = Executors.newFixedThreadPool(1) + + def search[TQueryId]( + batch: Seq[(Seq[Int], TQueryId, TVector)], + k: Int + ): Iterator[(TQueryId, Seq[Neighbor[TId, TDistance]])] = { + + val randomClient = partitionClients.map { case (_, clients) => clients(random.nextInt(clients.size)) } + + val (requestObservers, responseIterators) = randomClient.zipWithIndex.toArray.map { case (client, partition) => + // TODO this is kind of inefficient + val partitionCount = batch.count { case (partitions, _, _) => partitions.contains(partition) } + + val responseStreamObserver = new StreamObserverAdapter[SearchResponse](partitionCount) + val requestStreamObserver = client.search(responseStreamObserver) + + (requestStreamObserver, responseStreamObserver: Iterator[SearchResponse]) + }.unzip + + threadPool.submit(new Runnable { + override def run(): Unit = { + val batchIter = batch.iterator + for { + (queryPartitions, _, vector) <- batchIter + last = !batchIter.hasNext + (observer, observerPartition) <- requestObservers.zipWithIndex + } { + if (queryPartitions.contains(observerPartition)) { + val request = SearchRequest( + vector = vectorConverter(vector), + k = k + ) + observer.onNext(request) + } + if (last) { + observer.onCompleted() + } + } + } + }) + + val expectations: Iterator[(Seq[Int], TQueryId)] = batch.map { case (partitions, id, _) => + partitions -> id + }.iterator + + new ResultsIterator(expectations, responseIterators: Array[Iterator[SearchResponse]], k) + } + + def saveIndex(path: String): Unit = { + val futures = 
partitionClients.flatMap { case (partition, clients) => + // only the primary replica saves the index + clients.headOption.map { client => + val request = SaveIndexRequest(s"$path/$partition") + client.saveIndex(request) + } + } + + val responses = Await.result(Future.sequence(futures), Duration.Inf) // TODO not sure if inf is smart + responses.foreach(println) // TODO remove + } + + def shutdown(): Unit = { + channels.foreach(_.shutdownNow()) + threadPool.shutdownNow() + } + + private class StreamObserverAdapter[T](expected: Int) extends StreamObserver[T] with Iterator[T] { + + private val queue = new LinkedBlockingQueue[Either[Throwable, T]] + private val counter = new AtomicInteger() + private val done = new AtomicBoolean(false) + + // ======================================== StreamObserver ======================================== + + override def onNext(value: T): Unit = { + queue.add(Right(value)) + counter.incrementAndGet() + } + + override def onError(t: Throwable): Unit = { + queue.add(Left(t)) + done.set(true) + } + + override def onCompleted(): Unit = { + done.set(true) + } + + // ========================================== Iterator ========================================== + + override def hasNext: Boolean = { + !queue.isEmpty || (counter.get() < expected && !done.get()) + } + + override def next(): T = queue.take() match { + case Right(value) => value + case Left(t) => throw t + } + } + + private class ResultsIterator[TQueryId]( + iterator: Iterator[(Seq[Int], TQueryId)], + partitionIterators: Array[Iterator[SearchResponse]], + k: Int + ) extends Iterator[(TQueryId, Seq[Neighbor[TId, TDistance]])] { + + override def hasNext: Boolean = iterator.hasNext + + override def next(): (TQueryId, Seq[Neighbor[TId, TDistance]]) = { + val (partitions, queryId) = iterator.next() + + val responses = partitions.map(partitionIterators.apply).map(_.next()) + + val allResults = for { + response <- responses + result <- response.results + } yield Neighbor(idExtractor(result), distanceExtractor(result)) + + queryId -> allResults.sortBy(_.distance)(distanceOrdering).take(k) + } + } + +} diff --git a/hnswlib-spark/src/main/scala/com/github/jelmerk/spark/knn/IndexClientFactory.scala b/hnswlib-spark/src/main/scala/com/github/jelmerk/spark/knn/IndexClientFactory.scala new file mode 100644 index 0000000..3016eb2 --- /dev/null +++ b/hnswlib-spark/src/main/scala/com/github/jelmerk/spark/knn/IndexClientFactory.scala @@ -0,0 +1,19 @@ +package com.github.jelmerk.spark.knn + +import java.net.InetSocketAddress + +import com.github.jelmerk.server.index.{Result, SearchRequest} +import com.github.jelmerk.server.registration.PartitionAndReplica + +class IndexClientFactory[TId, TVector, TDistance]( + vectorConverter: TVector => SearchRequest.Vector, + idExtractor: Result => TId, + distanceExtractor: Result => TDistance, + distanceOrdering: Ordering[TDistance] +) extends Serializable { + + def create(servers: Map[PartitionAndReplica, InetSocketAddress]): IndexClient[TId, TVector, TDistance] = { + new IndexClient(servers, vectorConverter, idExtractor, distanceExtractor, distanceOrdering) + } + +} diff --git a/hnswlib-spark/src/main/scala/com/github/jelmerk/spark/knn/KnnAlgorithm.scala b/hnswlib-spark/src/main/scala/com/github/jelmerk/spark/knn/KnnAlgorithm.scala index c51e9dc..397d9c3 100644 --- a/hnswlib-spark/src/main/scala/com/github/jelmerk/spark/knn/KnnAlgorithm.scala +++ b/hnswlib-spark/src/main/scala/com/github/jelmerk/spark/knn/KnnAlgorithm.scala @@ -1,33 +1,20 @@ package com.github.jelmerk.spark.knn import 
java.io.InputStream -import java.net.InetAddress -import java.util.concurrent.{ - CountDownLatch, - ExecutionException, - FutureTask, - LinkedBlockingQueue, - ThreadLocalRandom, - ThreadPoolExecutor, - TimeUnit -} +import java.net.{InetAddress, InetSocketAddress} -import scala.Seq -import scala.annotation.tailrec import scala.language.{higherKinds, implicitConversions} import scala.reflect.ClassTag import scala.reflect.runtime.universe._ -import scala.util.Try import scala.util.control.NonFatal -import com.github.jelmerk.knn.{Jdk17DistanceFunctions, ObjectSerializer} +import com.github.jelmerk.knn.ObjectSerializer import com.github.jelmerk.knn.scalalike._ -import com.github.jelmerk.knn.scalalike.jdk17DistanceFunctions._ -import com.github.jelmerk.knn.util.NamedThreadFactory -import com.github.jelmerk.spark.linalg.functions.VectorDistanceFunctions +import com.github.jelmerk.server.index.IndexServerFactory +import com.github.jelmerk.server.registration.{PartitionAndReplica, RegistrationClient, RegistrationServerFactory} import com.github.jelmerk.spark.util.SerializableConfiguration -import org.apache.hadoop.fs.{FileUtil, Path} -import org.apache.spark.{Partitioner, TaskContext} +import org.apache.hadoop.fs.Path +import org.apache.spark.{Partitioner, SparkContext, TaskContext} import org.apache.spark.internal.Logging import org.apache.spark.ml.{Estimator, Model} import org.apache.spark.ml.linalg.SQLDataTypes.VectorType @@ -35,50 +22,26 @@ import org.apache.spark.ml.linalg.Vector import org.apache.spark.ml.param._ import org.apache.spark.ml.param.shared.{HasFeaturesCol, HasPredictionCol} import org.apache.spark.ml.util.{MLReader, MLWriter} -import org.apache.spark.scheduler.{SparkListener, SparkListenerApplicationEnd} +import org.apache.spark.rdd.RDD +import org.apache.spark.resource.{ResourceProfileBuilder, TaskResourceRequests} import org.apache.spark.sql._ -import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder import org.apache.spark.sql.functions._ import org.apache.spark.sql.types._ import org.json4s._ -import org.json4s.jackson.JsonMethods._ - -private[knn] case class IntDoubleArrayIndexItem(id: Int, vector: Array[Double]) extends Item[Int, Array[Double]] { - override def dimensions: Int = vector.length -} - -private[knn] case class LongDoubleArrayIndexItem(id: Long, vector: Array[Double]) extends Item[Long, Array[Double]] { - override def dimensions: Int = vector.length -} - -private[knn] case class StringDoubleArrayIndexItem(id: String, vector: Array[Double]) - extends Item[String, Array[Double]] { - override def dimensions: Int = vector.length -} - -private[knn] case class IntFloatArrayIndexItem(id: Int, vector: Array[Float]) extends Item[Int, Array[Float]] { - override def dimensions: Int = vector.length -} - -private[knn] case class LongFloatArrayIndexItem(id: Long, vector: Array[Float]) extends Item[Long, Array[Float]] { - override def dimensions: Int = vector.length -} - -private[knn] case class StringFloatArrayIndexItem(id: String, vector: Array[Float]) extends Item[String, Array[Float]] { - override def dimensions: Int = vector.length -} - -private[knn] case class IntVectorIndexItem(id: Int, vector: Vector) extends Item[Int, Vector] { - override def dimensions: Int = vector.size -} - -private[knn] case class LongVectorIndexItem(id: Long, vector: Vector) extends Item[Long, Vector] { - override def dimensions: Int = vector.size -} - -private[knn] case class StringVectorIndexItem(id: String, vector: Vector) extends Item[String, Vector] { - override def dimensions: Int = 
vector.size -} +import org.json4s.jackson.Serialization.{read, write} + +private[knn] case class ModelMetaData( + `class`: String, + timestamp: Long, + sparkVersion: String, + uid: String, + identifierType: String, + vectorType: String, + numPartitions: Int, + numReplicas: Int, + numThreads: Int, + paramMap: Map[String, Any] +) /** Neighbor of an item. * @@ -94,8 +57,119 @@ private[knn] case class StringVectorIndexItem(id: String, vector: Vector) extend */ private[knn] case class Neighbor[TId, TDistance](neighbor: TId, distance: TDistance) -/** Common params for KnnAlgorithm and KnnModel. - */ +private[knn] trait IndexType { + + /** Type of index. */ + protected type TIndex[TId, TVector, TItem <: Item[TId, TVector], TDistance] <: Index[TId, TVector, TItem, TDistance] + + protected implicit def indexClassTag[TId: ClassTag, TVector: ClassTag, TItem <: Item[ + TId, + TVector + ]: ClassTag, TDistance: ClassTag]: ClassTag[TIndex[TId, TVector, TItem, TDistance]] +} + +private[knn] trait IndexCreator extends IndexType { + + /** Create the index used to do the nearest neighbor search. + * + * @param dimensions + * dimensionality of the items stored in the index + * @param maxItemCount + * maximum number of items the index can hold + * @param distanceFunction + * the distance function + * @param distanceOrdering + * the distance ordering + * @param idSerializer + * invoked for serializing ids when saving the index + * @param itemSerializer + * invoked for serializing items when saving items + * + * @tparam TId + * type of the index item identifier + * @tparam TVector + * type of the index item vector + * @tparam TItem + * type of the index item + * @tparam TDistance + * type of distance between items + * @return + * create an index + */ + protected def createIndex[ + TId, + TVector, + TItem <: Item[TId, TVector] with Product, + TDistance + ](dimensions: Int, maxItemCount: Int, distanceFunction: DistanceFunction[TVector, TDistance])(implicit + distanceOrdering: Ordering[TDistance], + idSerializer: ObjectSerializer[TId], + itemSerializer: ObjectSerializer[TItem] + ): TIndex[TId, TVector, TItem, TDistance] +} + +private[knn] trait IndexLoader extends IndexType { + + /** Load an index + * + * @param inputStream + * InputStream to restore the index from + * @param minCapacity + * loaded index needs to have space for at least this man additional items + * + * @tparam TId + * type of the index item identifier + * @tparam TVector + * type of the index item vector + * @tparam TItem + * type of the index item + * @tparam TDistance + * type of distance between items + * @return + * create an index + */ + protected def loadIndex[TId, TVector, TItem <: Item[TId, TVector] with Product, TDistance]( + inputStream: InputStream, + minCapacity: Int + ): TIndex[TId, TVector, TItem, TDistance] +} + +private[knn] trait ModelCreator[TModel <: KnnModelBase[TModel]] { + + /** Creates the model to be returned from fitting the data. 
+ * + * @param uid + * identifier + * @param indices + * map of index servers + * @tparam TId + * type of the index item identifier + * @tparam TVector + * type of the index item vector + * @tparam TItem + * type of the index item + * @tparam TDistance + * type of distance between items + * @return + * model + */ + protected def createModel[ + TId: TypeTag, + TVector: TypeTag, + TItem <: Item[TId, TVector] with Product: TypeTag, + TDistance: TypeTag + ]( + uid: String, + numPartitions: Int, + numReplicas: Int, + numThreads: Int, + sparkContext: SparkContext, + indices: Map[PartitionAndReplica, InetSocketAddress], + clientFactory: IndexClientFactory[TId, TVector, TDistance] + ): TModel +} + +/** Common params for KnnAlgorithm and KnnModel. */ private[knn] trait KnnModelParams extends Params with HasFeaturesCol with HasPredictionCol { /** Param for the column name for the query identifier. @@ -125,8 +199,8 @@ private[knn] trait KnnModelParams extends Params with HasFeaturesCol with HasPre /** @group getParam */ final def getK: Int = $(k) - /** Param that indicates whether to not return the a candidate when it's identifier equals the query identifier - * Default: false + /** Param that indicates whether to not return the candidate when it's identifier equals the query identifier Default: + * false * * @group param */ @@ -146,26 +220,6 @@ private[knn] trait KnnModelParams extends Params with HasFeaturesCol with HasPre /** @group getParam */ final def getSimilarityThreshold: Double = $(similarityThreshold) - /** Param that specifies the number of index replicas to create when querying the index. More replicas means you can - * execute more queries in parallel at the expense of increased resource usage. Default: 0 - * - * @group param - */ - final val numReplicas = new IntParam(this, "numReplicas", "number of index replicas to create when querying") - - /** @group getParam */ - final def getNumReplicas: Int = $(numReplicas) - - /** Param that specifies the number of threads to use. Default: number of processors available to the Java virtual - * machine - * - * @group param - */ - final val parallelism = new IntParam(this, "parallelism", "number of threads to use") - - /** @group getParam */ - final def getParallelism: Int = $(parallelism) - /** Param for the output format to produce. One of "full", "minimal" Setting this to minimal is more efficient when * all you need is the identifier with its neighbors * @@ -173,7 +227,12 @@ private[knn] trait KnnModelParams extends Params with HasFeaturesCol with HasPre * * @group param */ - final val outputFormat = new Param[String](this, "outputFormat", "output format to produce") + final val outputFormat = new Param[String]( + this, + "outputFormat", + "output format to produce", + ParamValidators.inArray(Array("full", "minimal")) + ) /** @group getParam */ final def getOutputFormat: String = $(outputFormat) @@ -230,13 +289,33 @@ private[knn] trait KnnAlgorithmParams extends KnnModelParams { /** @group getParam */ final def getIdentifierCol: String = $(identifierCol) - /** Number of partitions (default: 1) + /** Number of partitions */ final val numPartitions = new IntParam(this, "numPartitions", "number of partitions", ParamValidators.gt(0)) /** @group getParam */ final def getNumPartitions: Int = $(numPartitions) + /** Param that specifies the number of index replicas to create when querying the index. More replicas means you can + * execute more queries in parallel at the expense of increased resource usage. 
Default: 0 + * + * @group param + */ + final val numReplicas = + new IntParam(this, "numReplicas", "number of index replicas to create when querying", ParamValidators.gtEq(0)) + + /** @group getParam */ + final def getNumReplicas: Int = $(numReplicas) + + /** Param that specifies the number of threads to use. + * + * @group param + */ + final val numThreads = new IntParam(this, "numThreads", "number of threads to use per index", ParamValidators.gt(0)) + + /** @group getParam */ + final def getNumThreads: Int = $(numThreads) + /** Param for the distance function to use. One of "bray-curtis", "canberra", "cosine", "correlation", "euclidean", * "inner-product", "manhattan" or the fully qualified classname of a distance function Default: "cosine" * @@ -254,14 +333,18 @@ private[knn] trait KnnAlgorithmParams extends KnnModelParams { /** @group getParam */ final def getPartitionCol: String = $(partitionCol) - /** Param to the initial model. All the vectors from the initial model will included in the final output model. + /** Param to the initial model. All the vectors from the initial model will be included in the final output model. */ final val initialModelPath = new Param[String](this, "initialModelPath", "path to the initial model") /** @group getParam */ final def getInitialModelPath: String = $(initialModelPath) - setDefault(identifierCol -> "id", distanceFunction -> "cosine", numPartitions -> 1, numReplicas -> 0) + setDefault(identifierCol -> "id", distanceFunction -> "cosine", numReplicas -> 0) +} + +object KnnModelWriter { + private implicit val format: Formats = DefaultFormats.withLong } /** Persists a knn model. @@ -292,61 +375,41 @@ private[knn] class KnnModelWriter[ ](instance: TModel with KnnModelOps[TModel, TId, TVector, TItem, TDistance, TIndex]) extends MLWriter { + import KnnModelWriter._ + override protected def saveImpl(path: String): Unit = { - val params = JObject( - instance - .extractParamMap() - .toSeq - .toList - // cannot use parse because of incompatibilities between json4s 3.2.11 used by spark 2.3 and 3.6.6 used by spark 2.4 - .map { case ParamPair(param, value) => - val fieldName = param.name - val fieldValue = mapper.readValue(param.jsonEncode(value), classOf[JValue]) - JField(fieldName, fieldValue) - } - ) - val metaData = JObject( - List( - JField("class", JString(instance.getClass.getName)), - JField("timestamp", JLong(System.currentTimeMillis())), - JField("sparkVersion", JString(sc.version)), - JField("uid", JString(instance.uid)), - JField("identifierType", JString(typeDescription[TId])), - JField("vectorType", JString(typeDescription[TVector])), - JField("partitions", JInt(instance.getNumPartitions)), - JField("paramMap", params) - ) + val metadata = ModelMetaData( + `class` = instance.getClass.getName, + timestamp = System.currentTimeMillis(), + sparkVersion = sc.version, + uid = instance.uid, + identifierType = typeDescription[TId], + vectorType = typeDescription[TVector], + numPartitions = instance.numPartitions, + numReplicas = instance.numReplicas, + numThreads = instance.numThreads, + paramMap = toMap(instance.extractParamMap()) ) val metadataPath = new Path(path, "metadata").toString - sc.parallelize(Seq(compact(metaData)), numSlices = 1).saveAsTextFile(metadataPath) + sc.parallelize(Seq(write(metadata)), numSlices = 1).saveAsTextFile(metadataPath) val indicesPath = new Path(path, "indices").toString - val modelOutputDir = instance.outputDir - - val serializableConfiguration = new SerializableConfiguration(sc.hadoopConfiguration) - - sc.range(start = 
0, end = instance.getNumPartitions).foreach { partitionId => - val originPath = new Path(modelOutputDir, partitionId.toString) - val originFileSystem = originPath.getFileSystem(serializableConfiguration.value) - - if (originFileSystem.exists(originPath)) { - val destinationPath = new Path(indicesPath, partitionId.toString) - val destinationFileSystem = destinationPath.getFileSystem(serializableConfiguration.value) - FileUtil.copy( - originFileSystem, - originPath, - destinationFileSystem, - destinationPath, - false, - serializableConfiguration.value - ) - } + val client = instance.clientFactory.create(instance.indexAddresses) + try { + client.saveIndex(indicesPath) + } finally { + client.shutdown() } } + private def toMap(paramMap: ParamMap): Map[String, Any] = + paramMap.toSeq.map { case ParamPair(param, value) => param.name -> value }.toMap + + // TODO should i make this an implicit like elsewhere + private def typeDescription[T: TypeTag] = typeOf[T] match { case t if t =:= typeOf[Int] => "int" case t if t =:= typeOf[Long] => "long" @@ -358,17 +421,23 @@ private[knn] class KnnModelWriter[ } } +object KnnModelReader { + private implicit val format: Formats = DefaultFormats.withLong +} + /** Reads a knn model from persistent storage. * - * @param ev - * classtag * @tparam TModel * type of model */ -private[knn] abstract class KnnModelReader[TModel <: KnnModelBase[TModel]](implicit ev: ClassTag[TModel]) - extends MLReader[TModel] { +private[knn] abstract class KnnModelReader[TModel <: KnnModelBase[TModel]: ClassTag] + extends MLReader[TModel] + with IndexLoader + with IndexServing + with ModelCreator[TModel] + with Serializable { - private implicit val format: Formats = DefaultFormats + import KnnModelReader._ override def load(path: String): TModel = { @@ -376,83 +445,95 @@ private[knn] abstract class KnnModelReader[TModel <: KnnModelBase[TModel]](impli val metadataStr = sc.textFile(metadataPath, 1).first() - // cannot use parse because of incompatibilities between json4s 3.2.11 used by spark 2.3 and 3.6.6 used by spark 2.4 - val metadata = mapper.readValue(metadataStr, classOf[JValue]) + val metadata = read[ModelMetaData](metadataStr) - val uid = (metadata \ "uid").extract[String] - - val identifierType = (metadata \ "identifierType").extract[String] - val vectorType = (metadata \ "vectorType").extract[String] - val partitions = (metadata \ "partitions").extract[Int] - - val paramMap = (metadata \ "paramMap").extract[JObject] - - val indicesPath = new Path(path, "indices").toString - - val model = (identifierType, vectorType) match { + val model = (metadata.identifierType, metadata.vectorType) match { case ("int", "float_array") => - createModel[Int, Array[Float], IntFloatArrayIndexItem, Float](uid, indicesPath, partitions) + typedLoad[Int, Array[Float], IntFloatArrayIndexItem, Float](path, metadata) case ("int", "double_array") => - createModel[Int, Array[Double], IntDoubleArrayIndexItem, Double](uid, indicesPath, partitions) - case ("int", "vector") => createModel[Int, Vector, IntVectorIndexItem, Double](uid, indicesPath, partitions) + typedLoad[Int, Array[Double], IntDoubleArrayIndexItem, Double](path, metadata) + case ("int", "vector") => + typedLoad[Int, Vector, IntVectorIndexItem, Double](path, metadata) case ("long", "float_array") => - createModel[Long, Array[Float], LongFloatArrayIndexItem, Float](uid, indicesPath, partitions) + typedLoad[Long, Array[Float], LongFloatArrayIndexItem, Float](path, metadata) case ("long", "double_array") => - createModel[Long, Array[Double], 
LongDoubleArrayIndexItem, Double](uid, indicesPath, partitions) - case ("long", "vector") => createModel[Long, Vector, LongVectorIndexItem, Double](uid, indicesPath, partitions) + typedLoad[Long, Array[Double], LongDoubleArrayIndexItem, Double](path, metadata) + case ("long", "vector") => + typedLoad[Long, Vector, LongVectorIndexItem, Double](path, metadata) case ("string", "float_array") => - createModel[String, Array[Float], StringFloatArrayIndexItem, Float](uid, indicesPath, partitions) + typedLoad[String, Array[Float], StringFloatArrayIndexItem, Float](path, metadata) case ("string", "double_array") => - createModel[String, Array[Double], StringDoubleArrayIndexItem, Double](uid, indicesPath, partitions) + typedLoad[String, Array[Double], StringDoubleArrayIndexItem, Double](path, metadata) case ("string", "vector") => - createModel[String, Vector, StringVectorIndexItem, Double](uid, indicesPath, partitions) - case _ => + typedLoad[String, Vector, StringVectorIndexItem, Double](path, metadata) + case (identifierType, vectorType) => throw new IllegalStateException( s"Cannot create model for identifier type $identifierType and vector type $vectorType." ) } - paramMap.obj.foreach { case (paramName, jsonValue) => + model + + } + + private def typedLoad[ + TId: TypeTag: ClassTag, + TVector: TypeTag: ClassTag, + TItem <: Item[TId, TVector] with Product: TypeTag: ClassTag, + TDistance: TypeTag: ClassTag + ](path: String, metadata: ModelMetaData)(implicit + indexServerFactory: IndexServerFactory[TId, TVector, TItem, TDistance], + clientFactory: IndexClientFactory[TId, TVector, TDistance] + ): TModel = { + + val indicesPath = new Path(path, "indices") + + val taskReqs = new TaskResourceRequests().cpus(metadata.numThreads) + val profile = new ResourceProfileBuilder().require(taskReqs).build() + + val serializableConfiguration = new SerializableConfiguration(sc.hadoopConfiguration) + + val partitionPaths = (0 until metadata.numPartitions).map { partitionId => + partitionId -> new Path(indicesPath, partitionId.toString) + } + + val indexRdd = sc + .makeRDD(partitionPaths) + .partitionBy(new PartitionIdPassthrough(metadata.numPartitions)) + .withResources(profile) + .mapPartitions { it => + val (partitionId, indexPath) = it.next() + val fs = indexPath.getFileSystem(serializableConfiguration.value) + + logInfo(partitionId, s"Loading index from $indexPath") + val inputStream = fs.open(indexPath) + val index = loadIndex[TId, TVector, TItem, TDistance](inputStream, 0) + logInfo(partitionId, s"Finished loading index from $indexPath") + Iterator(index) + } + + val servers = serve(metadata.uid, indexRdd, metadata.numPartitions, metadata.numReplicas, metadata.numThreads) + + val model = createModel( + metadata.uid, + metadata.numPartitions, + metadata.numReplicas, + metadata.numThreads, + sc, + servers, + clientFactory + ) + + metadata.paramMap.foreach { case (paramName, value) => val param = model.getParam(paramName) - model.set(param, param.jsonDecode(compact(render(jsonValue)))) + model.set(param, value) } model } - /** Creates the model to be returned from fitting the data. 
- * - * @param uid - * identifier - * @param outputDir - * directory containing the persisted indices - * @param numPartitions - * number of index partitions - * - * @tparam TId - * type of the index item identifier - * @tparam TVector - * type of the index item vector - * @tparam TItem - * type of the index item - * @tparam TDistance - * type of distance between items - * @return - * model - */ - protected def createModel[ - TId: TypeTag, - TVector: TypeTag, - TItem <: Item[TId, TVector] with Product: TypeTag, - TDistance: TypeTag - ](uid: String, outputDir: String, numPartitions: Int)(implicit - ev: ClassTag[TId], - evVector: ClassTag[TVector], - distanceNumeric: Numeric[TDistance] - ): TModel - } /** Base class for nearest neighbor search models. @@ -462,9 +543,7 @@ private[knn] abstract class KnnModelReader[TModel <: KnnModelBase[TModel]](impli */ private[knn] abstract class KnnModelBase[TModel <: KnnModelBase[TModel]] extends Model[TModel] with KnnModelParams { - private[knn] def outputDir: String - - def getNumPartitions: Int + private[knn] def sparkContext: SparkContext /** @group setParam */ def setQueryIdentifierCol(value: String): this.type = set(queryIdentifierCol, value) @@ -487,15 +566,16 @@ private[knn] abstract class KnnModelBase[TModel <: KnnModelBase[TModel]] extends /** @group setParam */ def setSimilarityThreshold(value: Double): this.type = set(similarityThreshold, value) - /** @group setParam */ - def setNumReplicas(value: Int): this.type = set(numReplicas, value) - - /** @group setParam */ - def setParallelism(value: Int): this.type = set(parallelism, value) - /** @group setParam */ def setOutputFormat(value: String): this.type = set(outputFormat, value) + override def finalize(): Unit = { + destroy() + } + + def destroy(): Unit = { + sparkContext.cancelJobGroup(uid) + } } /** Contains the core knn search logic @@ -523,15 +603,22 @@ private[knn] trait KnnModelOps[ ] { this: TModel with KnnModelParams => + private[knn] def numPartitions: Int + + private[knn] def numReplicas: Int + + private[knn] def numThreads: Int + + private[knn] def indexAddresses: Map[PartitionAndReplica, InetSocketAddress] + + private[knn] def clientFactory: IndexClientFactory[TId, TVector, TDistance] + protected def loadIndex(in: InputStream): TIndex protected def typedTransform(dataset: Dataset[_])(implicit tId: TypeTag[TId], tVector: TypeTag[TVector], - tDistance: TypeTag[TDistance], - evId: ClassTag[TId], - evVector: ClassTag[TVector], - distanceNumeric: Numeric[TDistance] + tDistance: TypeTag[TDistance] ): DataFrame = { if (!isSet(queryIdentifierCol) && getExcludeSelf) { @@ -544,229 +631,29 @@ private[knn] trait KnnModelOps[ .drop("_query_id") } - protected def typedTransformWithQueryCol[TQueryId](dataset: Dataset[_], queryIdCol: String)(implicit + private def typedTransformWithQueryCol[TQueryId](dataset: Dataset[_], queryIdCol: String)(implicit tId: TypeTag[TId], tVector: TypeTag[TVector], tDistance: TypeTag[TDistance], - tQueryId: TypeTag[TQueryId], - evId: ClassTag[TId], - evVector: ClassTag[TVector], - evQueryId: ClassTag[TQueryId], - distanceNumeric: Numeric[TDistance] + tQueryId: TypeTag[TQueryId] ): DataFrame = { import dataset.sparkSession.implicits._ - import distanceNumeric._ - - implicit val encoder: Encoder[TQueryId] = ExpressionEncoder() - implicit val neighborOrdering: Ordering[Neighbor[TId, TDistance]] = Ordering.by(_.distance) - - val serializableHadoopConfiguration = new SerializableConfiguration( - dataset.sparkSession.sparkContext.hadoopConfiguration - ) // construct the 
queries to the distributed indices. when query partitions are specified we only query those partitions // otherwise we query all partitions val logicalPartitionAndQueries = if (isDefined(queryPartitionsCol)) - dataset - .select( - col(getQueryPartitionsCol), - col(queryIdCol), - col(getFeaturesCol) - ) - .as[(Seq[Int], TQueryId, TVector)] - .rdd - .flatMap { case (queryPartitions, queryId, vector) => - queryPartitions.map { partition => (partition, (queryId, vector)) } - } - else - dataset - .select( - col(queryIdCol), - col(getFeaturesCol) - ) - .as[(TQueryId, TVector)] - .rdd - .flatMap { case (queryId, vector) => - Range(0, getNumPartitions).map { partition => - (partition, (queryId, vector)) - } - } + dataset.select(col(getQueryPartitionsCol), col(queryIdCol), col(getFeaturesCol)) + else dataset.select(sequence(lit(0), lit(numPartitions - 1)), col(queryIdCol), col(getFeaturesCol)) - val numPartitionCopies = getNumReplicas + 1 + val localIndexAddr = indexAddresses + val localClientFactory = clientFactory + val k = getK - val physicalPartitionAndQueries = logicalPartitionAndQueries - .map { case (partition, (queryId, vector)) => - val randomCopy = ThreadLocalRandom.current().nextInt(numPartitionCopies) - val physicalPartition = (partition * numPartitionCopies) + randomCopy - (physicalPartition, (queryId, vector)) - } - .partitionBy(new PartitionIdPassthrough(getNumPartitions * numPartitionCopies)) - - val numThreads = - if (isSet(parallelism) && getParallelism <= 0) sys.runtime.availableProcessors - else if (isSet(parallelism)) getParallelism - else dataset.sparkSession.sparkContext.getConf.getInt("spark.task.cpus", defaultValue = 1) - - val neighborsOnAllQueryPartitions = physicalPartitionAndQueries - .mapPartitions { queriesWithPartition => - val queries = queriesWithPartition.map(_._2) - - // load the partitioned index and execute all queries. 
- - val physicalPartitionId = TaskContext.getPartitionId() - - val logicalPartitionId = physicalPartitionId / numPartitionCopies - val replica = physicalPartitionId % numPartitionCopies - - val indexPath = new Path(outputDir, logicalPartitionId.toString) - - val fileSystem = indexPath.getFileSystem(serializableHadoopConfiguration.value) - - if (!fileSystem.exists(indexPath)) Iterator.empty - else { - - logInfo( - logicalPartitionId, - replica, - s"started loading index from $indexPath on host ${InetAddress.getLocalHost.getHostName}" - ) - val index = loadIndex(fileSystem.open(indexPath)) - logInfo( - logicalPartitionId, - replica, - s"finished loading index from $indexPath on host ${InetAddress.getLocalHost.getHostName}" - ) - - // execute queries in parallel on multiple threads - new Iterator[(TQueryId, Seq[Neighbor[TId, TDistance]])] { - - private[this] var first = true - private[this] var count = 0 - - private[this] val batchSize = 1000 - private[this] val queue = - new LinkedBlockingQueue[(TQueryId, Seq[Neighbor[TId, TDistance]])](batchSize * numThreads) - private[this] val executorService = new ThreadPoolExecutor( - numThreads, - numThreads, - 60L, - TimeUnit.SECONDS, - new LinkedBlockingQueue[Runnable], - new NamedThreadFactory("searcher-%d") - ) { - override def afterExecute(r: Runnable, t: Throwable): Unit = { - super.afterExecute(r, t) - - Option(t) - .orElse { - r match { - case t: FutureTask[_] => - Try(t.get()).failed.toOption.map { - case e: ExecutionException => e.getCause - case e: InterruptedException => - Thread.currentThread().interrupt() - e - case NonFatal(e) => e - } - case _ => None - } - } - .foreach { e => - logError("Error in worker.", e) - } - } - } - executorService.allowCoreThreadTimeOut(true) - - private[this] val activeWorkers = new CountDownLatch(numThreads) - Range(0, numThreads).map(_ => new Worker(queries, activeWorkers, batchSize)).foreach(executorService.submit) - - override def hasNext: Boolean = { - if (!queue.isEmpty) true - else if (queries.synchronized { queries.hasNext }) true - else { - // in theory all workers could have just picked up the last new work but not started processing any of it. - if (!activeWorkers.await(2, TimeUnit.MINUTES)) { - throw new IllegalStateException("Workers failed to complete.") - } - !queue.isEmpty - } - } - - override def next(): (TQueryId, Seq[Neighbor[TId, TDistance]]) = { - if (first) { - logInfo( - logicalPartitionId, - replica, - s"started querying on host ${InetAddress.getLocalHost.getHostName} with ${sys.runtime.availableProcessors} available processors." 
- ) - first = false - } - - val value = queue.poll(1, TimeUnit.MINUTES) - - count += 1 - - if (!hasNext) { - logInfo( - logicalPartitionId, - replica, - s"finished querying $count items on host ${InetAddress.getLocalHost.getHostName}" - ) - - executorService.shutdown() - } - - value - } - - class Worker(queries: Iterator[(TQueryId, TVector)], activeWorkers: CountDownLatch, batchSize: Int) - extends Runnable { - - private[this] var work = List.empty[(TQueryId, TVector)] - - private[this] val fetchSize = - if (getExcludeSelf) getK + 1 - else getK - - @tailrec final override def run(): Unit = { - - work.foreach { case (id, vector) => - val neighbors = index - .findNearest(vector, fetchSize) - .collect { - case SearchResult(item, distance) - if (!getExcludeSelf || item.id != id) && (getSimilarityThreshold < 0 || distance.toDouble < getSimilarityThreshold) => - Neighbor[TId, TDistance](item.id, distance) - } - - queue.put(id -> neighbors) - } - - work = queries.synchronized { - queries.take(batchSize).toList - } - - if (work.nonEmpty) { - run() - } else { - activeWorkers.countDown() - } - } - } - } - } - } - .toDS() - - // take the best k results from all partitions - - val topNeighbors = neighborsOnAllQueryPartitions - .groupByKey { case (queryId, _) => queryId } - .flatMapGroups { (queryId, groups) => - val allNeighbors = groups.flatMap { case (_, neighbors) => neighbors }.toList - Iterator.single(queryId -> allNeighbors.sortBy(_.distance).take(getK)) + val topNeighbors = logicalPartitionAndQueries + .as[(Seq[Int], TQueryId, TVector)] + .mapPartitions { it => + new QueryIterator(localIndexAddr, localClientFactory, it, batchSize = 100, k) } .toDF(queryIdCol, getPredictionCol) @@ -783,27 +670,16 @@ private[knn] trait KnnModelOps[ validateAndTransformSchema(schema, idDataType) } - private def logInfo(partition: Int, replica: Int, message: String): Unit = - logInfo(f"partition $partition%04d replica $replica%04d: $message") - } private[knn] abstract class KnnAlgorithm[TModel <: KnnModelBase[TModel]](override val uid: String) extends Estimator[TModel] - with KnnAlgorithmParams { - - /** Type of index. - * - * @tparam TId - * Type of the external identifier of an item - * @tparam TVector - * Type of the vector to perform distance calculation on - * @tparam TItem - * Type of items stored in the index - * @tparam TDistance - * Type of distance between items (expect any numeric type: float, double, int, ..) - */ - protected type TIndex[TId, TVector, TItem <: Item[TId, TVector], TDistance] <: Index[TId, TVector, TItem, TDistance] + with ModelLogging + with KnnAlgorithmParams + with IndexCreator + with IndexLoader + with ModelCreator[TModel] + with IndexServing { /** @group setParam */ def setIdentifierCol(value: String): this.type = set(identifierCol, value) @@ -842,7 +718,7 @@ private[knn] abstract class KnnAlgorithm[TModel <: KnnModelBase[TModel]](overrid def setNumReplicas(value: Int): this.type = set(numReplicas, value) /** @group setParam */ - def setParallelism(value: Int): this.type = set(parallelism, value) + def setNumThreads(value: Int): this.type = set(numThreads, value) /** @group setParam */ def setOutputFormat(value: String): this.type = set(outputFormat, value) @@ -884,126 +760,26 @@ private[knn] abstract class KnnAlgorithm[TModel <: KnnModelBase[TModel]](overrid override def copy(extra: ParamMap): Estimator[TModel] = defaultCopy(extra) - /** Create the index used to do the nearest neighbor search. 
- * - * @param dimensions - * dimensionality of the items stored in the index - * @param maxItemCount - * maximum number of items the index can hold - * @param distanceFunction - * the distance function - * @param distanceOrdering - * the distance ordering - * @param idSerializer - * invoked for serializing ids when saving the index - * @param itemSerializer - * invoked for serializing items when saving items - * - * @tparam TId - * type of the index item identifier - * @tparam TVector - * type of the index item vector - * @tparam TItem - * type of the index item - * @tparam TDistance - * type of distance between items - * @return - * create an index - */ - protected def createIndex[ - TId, - TVector, - TItem <: Item[TId, TVector] with Product, - TDistance - ](dimensions: Int, maxItemCount: Int, distanceFunction: DistanceFunction[TVector, TDistance])(implicit - distanceOrdering: Ordering[TDistance], - idSerializer: ObjectSerializer[TId], - itemSerializer: ObjectSerializer[TItem] - ): TIndex[TId, TVector, TItem, TDistance] - - /** Load an index - * - * @param inputStream - * InputStream to restore the index from - * @param minCapacity - * loaded index needs to have space for at least this man additional items - * - * @tparam TId - * type of the index item identifier - * @tparam TVector - * type of the index item vector - * @tparam TItem - * type of the index item - * @tparam TDistance - * type of distance between items - * @return - * create an index - */ - protected def loadIndex[TId, TVector, TItem <: Item[TId, TVector] with Product, TDistance]( - inputStream: InputStream, - minCapacity: Int - ): TIndex[TId, TVector, TItem, TDistance] - - /** Creates the model to be returned from fitting the data. - * - * @param uid - * identifier - * @param outputDir - * directory containing the persisted indices - * @param numPartitions - * number of index partitions - * - * @tparam TId - * type of the index item identifier - * @tparam TVector - * type of the index item vector - * @tparam TItem - * type of the index item - * @tparam TDistance - * type of distance between items - * @return - * model - */ - protected def createModel[ - TId: TypeTag, - TVector: TypeTag, - TItem <: Item[TId, TVector] with Product: TypeTag, - TDistance: TypeTag - ](uid: String, outputDir: String, numPartitions: Int)(implicit - ev: ClassTag[TId], - evVector: ClassTag[TVector], - distanceNumeric: Numeric[TDistance] - ): TModel - private def typedFit[ - TId: TypeTag, - TVector: TypeTag, - TItem <: Item[TId, TVector] with Product: TypeTag, - TDistance: TypeTag + TId: TypeTag: ClassTag, + TVector: TypeTag: ClassTag, + TItem <: Item[TId, TVector] with Product: TypeTag: ClassTag, + TDistance: TypeTag: ClassTag ](dataset: Dataset[_])(implicit - ev: ClassTag[TId], - evVector: ClassTag[TVector], - evItem: ClassTag[TItem], distanceNumeric: Numeric[TDistance], distanceFunctionFactory: String => DistanceFunction[TVector, TDistance], idSerializer: ObjectSerializer[TId], - itemSerializer: ObjectSerializer[TItem] + itemSerializer: ObjectSerializer[TItem], + indexServerFactory: IndexServerFactory[TId, TVector, TItem, TDistance], + indexClientFactory: IndexClientFactory[TId, TVector, TDistance] ): TModel = { val sc = dataset.sparkSession val sparkContext = sc.sparkContext - val serializableHadoopConfiguration = new SerializableConfiguration(sparkContext.hadoopConfiguration) - import sc.implicits._ - val cacheFolder = sparkContext.getConf.get(key = "spark.hnswlib.settings.index.cache_folder", defaultValue = "/tmp") - - val outputDir = new 
Path(cacheFolder, s"${uid}_${System.currentTimeMillis()}").toString - - sparkContext.addSparkListener(new CleanupListener(outputDir, serializableHadoopConfiguration)) - - // read the id and vector from the input dataset and and repartition them over numPartitions amount of partitions. + // read the id and vector from the input dataset, repartition them over numPartitions amount of partitions. // if the data is pre-partitioned by the user repartition the input data by the user defined partition key, use a // hash of the item id otherwise. val partitionedIndexItems = { @@ -1017,21 +793,18 @@ private[knn] abstract class KnnAlgorithm[TModel <: KnnModelBase[TModel]](overrid .rdd .partitionBy(new PartitionIdPassthrough(getNumPartitions)) .values - .toDS else dataset .select(col(getIdentifierCol).as("id"), col(getFeaturesCol).as("vector")) .as[TItem] - .repartition(getNumPartitions, $"id") + .rdd + .repartition(getNumPartitions) } // On each partition collect all the items into memory and construct the HNSW indices. // Save these indices to the hadoop filesystem - val numThreads = - if (isSet(parallelism) && getParallelism <= 0) sys.runtime.availableProcessors - else if (isSet(parallelism)) getParallelism - else dataset.sparkSession.sparkContext.getConf.getInt("spark.task.cpus", defaultValue = 1) + val numThreads = getNumThreads val initialModelOutputDir = if (isSet(initialModelPath)) Some(new Path(getInitialModelPath, "indices").toString) @@ -1039,14 +812,19 @@ private[knn] abstract class KnnAlgorithm[TModel <: KnnModelBase[TModel]](overrid val serializableConfiguration = new SerializableConfiguration(sparkContext.hadoopConfiguration) - partitionedIndexItems - .foreachPartition { it: Iterator[TItem] => - if (it.hasNext) { - val partitionId = TaskContext.getPartitionId() + val taskReqs = new TaskResourceRequests().cpus(numThreads) + val profile = new ResourceProfileBuilder().require(taskReqs).build() + + val numReplicas = getNumReplicas + + val indexRdd = partitionedIndexItems + .withResources(profile) // TODO JELMER is this the correct place ? 
+ .mapPartitions( + (it: Iterator[TItem]) => { val items = it.toSeq - logInfo(partitionId, s"started indexing ${items.size} items on host ${InetAddress.getLocalHost.getHostName}") + val partitionId = TaskContext.getPartitionId() val existingIndexOption = initialModelOutputDir .flatMap { dir => @@ -1054,6 +832,7 @@ private[knn] abstract class KnnAlgorithm[TModel <: KnnModelBase[TModel]](overrid val fs = indexPath.getFileSystem(serializableConfiguration.value) if (fs.exists(indexPath)) Some { + logInfo(partitionId, s"Loading existing index from $indexPath") val inputStream = fs.open(indexPath) loadIndex[TId, TVector, TItem, TDistance](inputStream, items.size) } @@ -1063,6 +842,8 @@ private[knn] abstract class KnnAlgorithm[TModel <: KnnModelBase[TModel]](overrid } } + logInfo(partitionId, s"started indexing ${items.size} items on host ${InetAddress.getLocalHost.getHostName}") + val index = existingIndexOption .getOrElse( createIndex[TId, TVector, TItem, TDistance]( @@ -1079,97 +860,147 @@ private[knn] abstract class KnnAlgorithm[TModel <: KnnModelBase[TModel]](overrid numThreads = numThreads ) - logInfo(partitionId, s"finished indexing ${items.size} items on host ${InetAddress.getLocalHost.getHostName}") + logInfo(partitionId, s"finished indexing ${items.size} items") - val path = new Path(outputDir, partitionId.toString) - val fileSystem = path.getFileSystem(serializableHadoopConfiguration.value) + Iterator(index) + }, + preservesPartitioning = true + ) - val outputStream = fileSystem.create(path) + val modelUid = uid + "_" + System.currentTimeMillis().toString - logInfo(partitionId, s"started saving index to $path on host ${InetAddress.getLocalHost.getHostName}") + val registrations = + serve[TId, TVector, TItem, TDistance](modelUid, indexRdd, getNumPartitions, getNumReplicas, numThreads) + println(registrations) // TODO remove this - index.save(outputStream) + createModel[TId, TVector, TItem, TDistance]( + modelUid, + getNumPartitions, + getNumReplicas, + getNumThreads, + sparkContext, + registrations, + indexClientFactory + ) + } - logInfo(partitionId, s"finished saving index to $path on host ${InetAddress.getLocalHost.getHostName}") - } - } +} - createModel[TId, TVector, TItem, TDistance](uid, outputDir, getNumPartitions) - } +/** Partitioner that uses precomputed partitions + * + * @param numPartitions + * number of partitions + */ +private[knn] class PartitionIdPassthrough(override val numPartitions: Int) extends Partitioner { + override def getPartition(key: Any): Int = key.asInstanceOf[Int] +} - private def logInfo(partition: Int, message: String): Unit = logInfo(f"partition $partition%04d: $message") - - implicit private def floatArrayDistanceFunction(name: String): DistanceFunction[Array[Float], Float] = - (name, vectorApiAvailable) match { - case ("bray-curtis", true) => vectorFloat128BrayCurtisDistance - case ("bray-curtis", _) => floatBrayCurtisDistance - case ("canberra", true) => vectorFloat128CanberraDistance - case ("canberra", _) => floatCanberraDistance - case ("correlation", _) => floatCorrelationDistance - case ("cosine", true) => vectorFloat128CosineDistance - case ("cosine", _) => floatCosineDistance - case ("euclidean", true) => vectorFloat128EuclideanDistance - case ("euclidean", _) => floatEuclideanDistance - case ("inner-product", true) => vectorFloat128InnerProduct - case ("inner-product", _) => floatInnerProduct - case ("manhattan", true) => vectorFloat128ManhattanDistance - case ("manhattan", _) => floatManhattanDistance - case (value, _) => 
userDistanceFunction(value) - } +// TODO rename this +private[knn] class PartitionReplicaPartitioner(partitions: Int, replicas: Int) extends Partitioner { - implicit private def doubleArrayDistanceFunction(name: String): DistanceFunction[Array[Double], Double] = name match { - case "bray-curtis" => doubleBrayCurtisDistance - case "canberra" => doubleCanberraDistance - case "correlation" => doubleCorrelationDistance - case "cosine" => doubleCosineDistance - case "euclidean" => doubleEuclideanDistance - case "inner-product" => doubleInnerProduct - case "manhattan" => doubleManhattanDistance - case value => userDistanceFunction(value) - } + override def numPartitions: Int = partitions + (replicas * partitions) - implicit private def vectorDistanceFunction(name: String): DistanceFunction[Vector, Double] = name match { - case "bray-curtis" => VectorDistanceFunctions.brayCurtisDistance - case "canberra" => VectorDistanceFunctions.canberraDistance - case "correlation" => VectorDistanceFunctions.correlationDistance - case "cosine" => VectorDistanceFunctions.cosineDistance - case "euclidean" => VectorDistanceFunctions.euclideanDistance - case "inner-product" => VectorDistanceFunctions.innerProduct - case "manhattan" => VectorDistanceFunctions.manhattanDistance - case value => userDistanceFunction(value) + override def getPartition(key: Any): Int = { + val (partition, replica) = key.asInstanceOf[(Int, Int)] + partition + (replica * partitions) } +} - private def vectorApiAvailable: Boolean = try { - val _ = Jdk17DistanceFunctions.VECTOR_FLOAT_128_COSINE_DISTANCE - true - } catch { - case _: Throwable => false +private[knn] class IndexRunnable(uid: String, sparkContext: SparkContext, indexRdd: RDD[_]) extends Runnable { + + override def run(): Unit = { + sparkContext.setJobGroup(uid, "job group that holds the indices") + try { + indexRdd.count() + } catch { + case NonFatal(_) => + // todo should i log this or not ? 
it produces a lot of noise
+        ()
+    } finally {
+      sparkContext.clearJobGroup()
+    }
   }
 
-  private def userDistanceFunction[TVector, TDistance](name: String): DistanceFunction[TVector, TDistance] =
-    Try(Class.forName(name).getDeclaredConstructor().newInstance()).toOption
-      .collect { case f: DistanceFunction[TVector @unchecked, TDistance @unchecked] => f }
-      .getOrElse(throw new IllegalArgumentException(s"$name is not a valid distance functions."))
 }
 
-private[knn] class CleanupListener(dir: String, serializableConfiguration: SerializableConfiguration)
-    extends SparkListener
-    with Logging {
-  override def onApplicationEnd(applicationEnd: SparkListenerApplicationEnd): Unit = {
+private[knn] trait ModelLogging extends Logging {
+//  protected def logInfo(partition: Int, message: String): Unit = logInfo(f"partition $partition%04d: $message")
+  protected def logInfo(partition: Int, message: String): Unit = println(f"partition $partition%04d: $message")
 
-    val path = new Path(dir)
-    val fileSystem = path.getFileSystem(serializableConfiguration.value)
-
-    logInfo(s"Deleting files below $dir")
-    fileSystem.delete(path, true)
-  }
+//  protected def logInfo(partition: Int, replica: Int, message: String): Unit = logInfo(f"partition $partition%04d replica $replica%04d: $message")
+  protected def logInfo(partition: Int, replica: Int, message: String): Unit = println(f"partition $partition%04d replica $replica%04d: $message")
 }
 
-/** Partitioner that uses precomputed partitions
-  *
-  * @param numPartitions
-  *   number of partitions
-  */
-private[knn] class PartitionIdPassthrough(override val numPartitions: Int) extends Partitioner {
-  override def getPartition(key: Any): Int = key.asInstanceOf[Int]
+private[knn] trait IndexServing extends ModelLogging with IndexType {
+
+  protected def serve[
+      TId: ClassTag,
+      TVector: ClassTag,
+      TItem <: Item[TId, TVector] with Product: ClassTag,
+      TDistance: ClassTag
+  ](
+      uid: String,
+      indexRdd: RDD[TIndex[TId, TVector, TItem, TDistance]],
+      numPartitions: Int,
+      numReplicas: Int,
+      numThreads: Int
+  )(implicit
+      indexServerFactory: IndexServerFactory[TId, TVector, TItem, TDistance]
+  ): Map[PartitionAndReplica, InetSocketAddress] = {
+
+    val sparkContext = indexRdd.sparkContext
+    val serializableConfiguration = new SerializableConfiguration(sparkContext.hadoopConfiguration)
+
+    val keyedIndexRdd = indexRdd.flatMap { index =>
+      Range.inclusive(0, numReplicas).map { replica => (TaskContext.getPartitionId(), replica) -> index }
+    }
+
+    val replicaRdd =
+      if (numReplicas > 0) keyedIndexRdd.partitionBy(new PartitionReplicaPartitioner(numPartitions, numReplicas))
+      else keyedIndexRdd
+
+    val server = RegistrationServerFactory.create(numPartitions, numReplicas)
+    server.start()
+    try {
+      val registrationAddress = server.address
+
+      logInfo(s"Started registration server on ${registrationAddress.getHostName}:${registrationAddress.getPort}")
+
+      val serverRdd = replicaRdd
+        .map { case ((partitionNum, replicaNum), index) =>
+          val server = indexServerFactory.create(index, serializableConfiguration.value, numThreads)
+
+          server.start()
+
+          val serverAddress = server.address
+
+          logInfo(partitionNum, replicaNum, s"started index server on host ${serverAddress.getHostName}:${serverAddress.getPort}")
+
+          logInfo(
+            partitionNum,
+            replicaNum,
+            s"registering partition at ${registrationAddress.getHostName}:${registrationAddress.getPort}"
+          )
+
+          RegistrationClient.register(registrationAddress, partitionNum, replicaNum, serverAddress)
+
+          logInfo(partitionNum, replicaNum, "awaiting
requests") + try { + server.awaitTermination() + } finally { + server.shutdownNow() + } + + true + } + + // the count will never complete because the tasks start the index server + new Thread(new IndexRunnable(uid, sparkContext, serverRdd)).start() + + server.awaitRegistrations() + } finally { + server.shutdownNow() + } + + } } diff --git a/hnswlib-spark/src/main/scala/com/github/jelmerk/spark/knn/QueryIterator.scala b/hnswlib-spark/src/main/scala/com/github/jelmerk/spark/knn/QueryIterator.scala new file mode 100644 index 0000000..0279cca --- /dev/null +++ b/hnswlib-spark/src/main/scala/com/github/jelmerk/spark/knn/QueryIterator.scala @@ -0,0 +1,32 @@ +package com.github.jelmerk.spark.knn + +import java.net.InetSocketAddress + +import com.github.jelmerk.server.registration.PartitionAndReplica + +class QueryIterator[TId, TVector, TDistance, TQueryId]( + indices: Map[PartitionAndReplica, InetSocketAddress], + indexClientFactory: IndexClientFactory[TId, TVector, TDistance], + records: Iterator[(Seq[Int], TQueryId, TVector)], + batchSize: Int, + k: Int +) extends Iterator[(TQueryId, Seq[Neighbor[TId, TDistance]])] { + + private val client = indexClientFactory.create(indices) + + private val iterator = records + .grouped(batchSize) + .map(batch => client.search[TQueryId](batch, k)) + .reduce((a, b) => a ++ b) + + override def hasNext: Boolean = iterator.hasNext + + override def next(): (TQueryId, Seq[Neighbor[TId, TDistance]]) = { + val result = iterator.next() + + if (!hasNext) { + client.shutdown() + } + result + } +} diff --git a/hnswlib-spark/src/main/scala/com/github/jelmerk/spark/knn/bruteforce/BruteForceSimilarity.scala b/hnswlib-spark/src/main/scala/com/github/jelmerk/spark/knn/bruteforce/BruteForceSimilarity.scala index 9d758a8..94703f7 100644 --- a/hnswlib-spark/src/main/scala/com/github/jelmerk/spark/knn/bruteforce/BruteForceSimilarity.scala +++ b/hnswlib-spark/src/main/scala/com/github/jelmerk/spark/knn/bruteforce/BruteForceSimilarity.scala @@ -1,6 +1,7 @@ package com.github.jelmerk.spark.knn.bruteforce import java.io.InputStream +import java.net.InetSocketAddress import scala.reflect.ClassTag import scala.reflect.runtime.universe._ @@ -8,37 +9,70 @@ import scala.reflect.runtime.universe._ import com.github.jelmerk.knn.ObjectSerializer import com.github.jelmerk.knn.scalalike.{DistanceFunction, Item} import com.github.jelmerk.knn.scalalike.bruteforce.BruteForceIndex +import com.github.jelmerk.server.registration.PartitionAndReplica import com.github.jelmerk.spark.knn._ +import org.apache.spark.SparkContext import org.apache.spark.ml.param._ import org.apache.spark.ml.util.{Identifiable, MLReadable, MLReader, MLWritable, MLWriter} import org.apache.spark.sql.{DataFrame, Dataset} import org.apache.spark.sql.types.StructType -/** Companion class for BruteForceSimilarityModel. 
- */ -object BruteForceSimilarityModel extends MLReadable[BruteForceSimilarityModel] { +private[bruteforce] trait BruteForceIndexType extends IndexType { + protected override type TIndex[TId, TVector, TItem <: Item[TId, TVector], TDistance] = + BruteForceIndex[TId, TVector, TItem, TDistance] - private[knn] class BruteForceModelReader extends KnnModelReader[BruteForceSimilarityModel] { + protected override implicit def indexClassTag[TId: ClassTag, TVector: ClassTag, TItem <: Item[ + TId, + TVector + ]: ClassTag, TDistance: ClassTag]: ClassTag[TIndex[TId, TVector, TItem, TDistance]] = + ClassTag(classOf[BruteForceIndex[TId, TVector, TItem, TDistance]]) +} + +private[bruteforce] trait BruteForceIndexLoader extends IndexLoader with BruteForceIndexType { + protected def loadIndex[TId, TVector, TItem <: Item[TId, TVector] with Product, TDistance]( + inputStream: InputStream, + minCapacity: Int + ): BruteForceIndex[TId, TVector, TItem, TDistance] = BruteForceIndex.loadFromInputStream(inputStream) +} - override protected def createModel[ - TId: TypeTag, - TVector: TypeTag, - TItem <: Item[TId, TVector] with Product: TypeTag, - TDistance: TypeTag - ](uid: String, outputDir: String, numPartitions: Int)(implicit - evId: ClassTag[TId], - evVector: ClassTag[TVector], - distanceNumeric: Numeric[TDistance] - ): BruteForceSimilarityModel = - new BruteForceSimilarityModelImpl[TId, TVector, TItem, TDistance](uid, outputDir, numPartitions) +private[bruteforce] trait BruteForceModelCreator extends ModelCreator[BruteForceSimilarityModel] { + protected def createModel[ + TId: TypeTag, + TVector: TypeTag, + TItem <: Item[TId, TVector] with Product: TypeTag, + TDistance: TypeTag + ]( + uid: String, + numPartitions: Int, + numReplicas: Int, + numThreads: Int, + sparkContext: SparkContext, + indices: Map[PartitionAndReplica, InetSocketAddress], + clientFactory: IndexClientFactory[TId, TVector, TDistance] + ): BruteForceSimilarityModel = + new BruteForceSimilarityModelImpl[TId, TVector, TItem, TDistance]( + uid, + numPartitions, + numReplicas, + numThreads, + sparkContext, + indices, + clientFactory + ) +} - } +/** Companion class for BruteForceSimilarityModel. */ +object BruteForceSimilarityModel extends MLReadable[BruteForceSimilarityModel] { + + private[knn] class BruteForceModelReader + extends KnnModelReader[BruteForceSimilarityModel] + with BruteForceModelCreator + with BruteForceIndexLoader override def read: MLReader[BruteForceSimilarityModel] = new BruteForceModelReader } -/** Model produced by `BruteForceSimilarity`. - */ +/** Model produced by `BruteForceSimilarity`. 
*/ abstract class BruteForceSimilarityModel extends KnnModelBase[BruteForceSimilarityModel] with KnnModelParams @@ -49,10 +83,14 @@ private[knn] class BruteForceSimilarityModelImpl[ TVector: TypeTag, TItem <: Item[TId, TVector] with Product: TypeTag, TDistance: TypeTag -](override val uid: String, val outputDir: String, numPartitions: Int)(implicit - evId: ClassTag[TId], - evVector: ClassTag[TVector], - distanceNumeric: Numeric[TDistance] +]( + override val uid: String, + val numPartitions: Int, + val numReplicas: Int, + val numThreads: Int, + val sparkContext: SparkContext, + val indexAddresses: Map[PartitionAndReplica, InetSocketAddress], + val clientFactory: IndexClientFactory[TId, TVector, TDistance] ) extends BruteForceSimilarityModel with KnnModelOps[ BruteForceSimilarityModel, @@ -63,12 +101,18 @@ private[knn] class BruteForceSimilarityModelImpl[ BruteForceIndex[TId, TVector, TItem, TDistance] ] { - override def getNumPartitions: Int = numPartitions - override def transform(dataset: Dataset[_]): DataFrame = typedTransform(dataset) override def copy(extra: ParamMap): BruteForceSimilarityModel = { - val copied = new BruteForceSimilarityModelImpl[TId, TVector, TItem, TDistance](uid, outputDir, numPartitions) + val copied = new BruteForceSimilarityModelImpl[TId, TVector, TItem, TDistance]( + uid, + numPartitions, + numReplicas, + numThreads, + sparkContext, + indexAddresses, + clientFactory + ) copyValues(copied, extra).setParent(parent) } @@ -94,10 +138,10 @@ private[knn] class BruteForceSimilarityModelImpl[ * @param uid * identifier */ -class BruteForceSimilarity(override val uid: String) extends KnnAlgorithm[BruteForceSimilarityModel](uid) { - - override protected type TIndex[TId, TVector, TItem <: Item[TId, TVector], TDistance] = - BruteForceIndex[TId, TVector, TItem, TDistance] +class BruteForceSimilarity(override val uid: String) + extends KnnAlgorithm[BruteForceSimilarityModel](uid) + with BruteForceModelCreator + with BruteForceIndexLoader { def this() = this(Identifiable.randomUID("brute_force")) @@ -110,26 +154,6 @@ class BruteForceSimilarity(override val uid: String) extends KnnAlgorithm[BruteF idSerializer: ObjectSerializer[TId], itemSerializer: ObjectSerializer[TItem] ): BruteForceIndex[TId, TVector, TItem, TDistance] = - BruteForceIndex[TId, TVector, TItem, TDistance]( - dimensions, - distanceFunction - ) - - override protected def loadIndex[TId, TVector, TItem <: Item[TId, TVector] with Product, TDistance]( - inputStream: InputStream, - minCapacity: Int - ): BruteForceIndex[TId, TVector, TItem, TDistance] = BruteForceIndex.loadFromInputStream(inputStream) - - override protected def createModel[ - TId: TypeTag, - TVector: TypeTag, - TItem <: Item[TId, TVector] with Product: TypeTag, - TDistance: TypeTag - ](uid: String, outputDir: String, numPartitions: Int)(implicit - evId: ClassTag[TId], - evVector: ClassTag[TVector], - distanceNumeric: Numeric[TDistance] - ): BruteForceSimilarityModel = - new BruteForceSimilarityModelImpl[TId, TVector, TItem, TDistance](uid, outputDir, numPartitions) + BruteForceIndex[TId, TVector, TItem, TDistance](dimensions, distanceFunction) } diff --git a/hnswlib-spark/src/main/scala/com/github/jelmerk/spark/knn/hnsw/HnswSimilarity.scala b/hnswlib-spark/src/main/scala/com/github/jelmerk/spark/knn/hnsw/HnswSimilarity.scala index 310d5ba..0851fdf 100644 --- a/hnswlib-spark/src/main/scala/com/github/jelmerk/spark/knn/hnsw/HnswSimilarity.scala +++ b/hnswlib-spark/src/main/scala/com/github/jelmerk/spark/knn/hnsw/HnswSimilarity.scala @@ -1,6 +1,7 @@ 
package com.github.jelmerk.spark.knn.hnsw import java.io.InputStream +import java.net.InetSocketAddress import scala.reflect.ClassTag import scala.reflect.runtime.universe._ @@ -8,12 +9,63 @@ import scala.reflect.runtime.universe._ import com.github.jelmerk.knn import com.github.jelmerk.knn.scalalike.{DistanceFunction, Item} import com.github.jelmerk.knn.scalalike.hnsw._ +import com.github.jelmerk.server.registration.PartitionAndReplica import com.github.jelmerk.spark.knn._ +import org.apache.spark.SparkContext import org.apache.spark.ml.param._ -import org.apache.spark.ml.util.{Identifiable, MLReadable, MLReader, MLWritable, MLWriter} +import org.apache.spark.ml.util._ import org.apache.spark.sql.{DataFrame, Dataset} import org.apache.spark.sql.types.StructType +private[hnsw] trait HnswIndexType extends IndexType { + protected override type TIndex[TId, TVector, TItem <: Item[TId, TVector], TDistance] = + HnswIndex[TId, TVector, TItem, TDistance] + + protected override implicit def indexClassTag[TId: ClassTag, TVector: ClassTag, TItem <: Item[ + TId, + TVector + ]: ClassTag, TDistance: ClassTag]: ClassTag[TIndex[TId, TVector, TItem, TDistance]] = + ClassTag(classOf[HnswIndex[TId, TVector, TItem, TDistance]]) + +} + +private[hnsw] trait HnswIndexLoader extends IndexLoader with HnswIndexType { + protected override def loadIndex[TId, TVector, TItem <: Item[TId, TVector] with Product, TDistance]( + inputStream: InputStream, + minCapacity: Int + ): HnswIndex[TId, TVector, TItem, TDistance] = { + val index = HnswIndex.loadFromInputStream[TId, TVector, TItem, TDistance](inputStream) + index.resize(index.maxItemCount + minCapacity) + index + } +} + +private[hnsw] trait HnswModelCreator extends ModelCreator[HnswSimilarityModel] { + protected def createModel[ + TId: TypeTag, + TVector: TypeTag, + TItem <: Item[TId, TVector] with Product: TypeTag, + TDistance: TypeTag + ]( + uid: String, + numPartitions: Int, + numReplicas: Int, + numThreads: Int, + sparkContext: SparkContext, + indices: Map[PartitionAndReplica, InetSocketAddress], + clientFactory: IndexClientFactory[TId, TVector, TDistance] + ): HnswSimilarityModel = + new HnswSimilarityModelImpl[TId, TVector, TItem, TDistance]( + uid, + numPartitions, + numReplicas, + numThreads, + sparkContext, + indices, + clientFactory + ) +} + private[hnsw] trait HnswParams extends KnnAlgorithmParams with HnswModelParams { /** The number of bi-directional links created for every new element during construction. 
@@ -74,20 +126,10 @@ private[hnsw] trait HnswModelParams extends KnnModelParams { */ object HnswSimilarityModel extends MLReadable[HnswSimilarityModel] { - private[hnsw] class HnswModelReader extends KnnModelReader[HnswSimilarityModel] { - - override protected def createModel[ - TId: TypeTag, - TVector: TypeTag, - TItem <: Item[TId, TVector] with Product: TypeTag, - TDistance: TypeTag - ](uid: String, outputDir: String, numPartitions: Int)(implicit - evId: ClassTag[TId], - evVector: ClassTag[TVector], - distanceNumeric: Numeric[TDistance] - ): HnswSimilarityModel = - new HnswSimilarityModelImpl[TId, TVector, TItem, TDistance](uid, outputDir, numPartitions) - } + private[hnsw] class HnswModelReader + extends KnnModelReader[HnswSimilarityModel] + with HnswIndexLoader + with HnswModelCreator override def read: MLReader[HnswSimilarityModel] = new HnswModelReader @@ -107,19 +149,29 @@ private[knn] class HnswSimilarityModelImpl[ TVector: TypeTag, TItem <: Item[TId, TVector] with Product: TypeTag, TDistance: TypeTag -](override val uid: String, val outputDir: String, numPartitions: Int)(implicit - evId: ClassTag[TId], - evVector: ClassTag[TVector], - distanceNumeric: Numeric[TDistance] +]( + override val uid: String, + val numPartitions: Int, + val numReplicas: Int, + val numThreads: Int, + val sparkContext: SparkContext, + val indexAddresses: Map[PartitionAndReplica, InetSocketAddress], + val clientFactory: IndexClientFactory[TId, TVector, TDistance] ) extends HnswSimilarityModel with KnnModelOps[HnswSimilarityModel, TId, TVector, TItem, TDistance, HnswIndex[TId, TVector, TItem, TDistance]] { - override def getNumPartitions: Int = numPartitions - override def transform(dataset: Dataset[_]): DataFrame = typedTransform(dataset) override def copy(extra: ParamMap): HnswSimilarityModel = { - val copied = new HnswSimilarityModelImpl[TId, TVector, TItem, TDistance](uid, outputDir, numPartitions) + val copied = new HnswSimilarityModelImpl[TId, TVector, TItem, TDistance]( + uid, + numPartitions, + numReplicas, + numThreads, + sparkContext, + indexAddresses, + clientFactory + ) copyValues(copied, extra).setParent(parent) } @@ -142,10 +194,11 @@ private[knn] class HnswSimilarityModelImpl[ * @param uid * identifier */ -class HnswSimilarity(override val uid: String) extends KnnAlgorithm[HnswSimilarityModel](uid) with HnswParams { - - override protected type TIndex[TId, TVector, TItem <: Item[TId, TVector], TDistance] = - HnswIndex[TId, TVector, TItem, TDistance] +class HnswSimilarity(override val uid: String) + extends KnnAlgorithm[HnswSimilarityModel](uid) + with HnswParams + with HnswIndexLoader + with HnswModelCreator { def this() = this(Identifiable.randomUID("hnsw")) @@ -178,25 +231,4 @@ class HnswSimilarity(override val uid: String) extends KnnAlgorithm[HnswSimilari idSerializer, itemSerializer ) - - override protected def loadIndex[TId, TVector, TItem <: Item[TId, TVector] with Product, TDistance]( - inputStream: InputStream, - minCapacity: Int - ): HnswIndex[TId, TVector, TItem, TDistance] = { - val index = HnswIndex.loadFromInputStream[TId, TVector, TItem, TDistance](inputStream) - index.resize(index.maxItemCount + minCapacity) - index - } - - override protected def createModel[ - TId: TypeTag, - TVector: TypeTag, - TItem <: Item[TId, TVector] with Product: TypeTag, - TDistance: TypeTag - ](uid: String, outputDir: String, numPartitions: Int)(implicit - evId: ClassTag[TId], - evVector: ClassTag[TVector], - distanceNumeric: Numeric[TDistance] - ): HnswSimilarityModel = - new 
HnswSimilarityModelImpl[TId, TVector, TItem, TDistance](uid, outputDir, numPartitions) } diff --git a/hnswlib-spark/src/main/scala/com/github/jelmerk/spark/knn/knn.scala b/hnswlib-spark/src/main/scala/com/github/jelmerk/spark/knn/knn.scala index 39bdd78..9bcc884 100644 --- a/hnswlib-spark/src/main/scala/com/github/jelmerk/spark/knn/knn.scala +++ b/hnswlib-spark/src/main/scala/com/github/jelmerk/spark/knn/knn.scala @@ -2,11 +2,89 @@ package com.github.jelmerk.spark import java.io.{ObjectInput, ObjectOutput} -import com.github.jelmerk.knn.scalalike.ObjectSerializer -import org.apache.spark.ml.linalg.{DenseVector, SparseVector, Vector, Vectors} +import scala.language.implicitConversions +import scala.util.Try + +import com.github.jelmerk.knn.Jdk17DistanceFunctions +import com.github.jelmerk.knn.scalalike.{ + doubleBrayCurtisDistance, + doubleCanberraDistance, + doubleCorrelationDistance, + doubleCosineDistance, + doubleEuclideanDistance, + doubleInnerProduct, + doubleManhattanDistance, + floatBrayCurtisDistance, + floatCanberraDistance, + floatCorrelationDistance, + floatCosineDistance, + floatEuclideanDistance, + floatInnerProduct, + floatManhattanDistance, + DistanceFunction, + Item, + ObjectSerializer +} +import com.github.jelmerk.knn.scalalike.jdk17DistanceFunctions.{ + vectorFloat128BrayCurtisDistance, + vectorFloat128CanberraDistance, + vectorFloat128CosineDistance, + vectorFloat128EuclideanDistance, + vectorFloat128InnerProduct, + vectorFloat128ManhattanDistance +} +import com.github.jelmerk.server.index.{ + DenseVector, + DoubleArrayVector, + FloatArrayVector, + IndexServerFactory, + Result, + SearchRequest, + SparseVector +} +import com.github.jelmerk.spark.linalg.functions.VectorDistanceFunctions +import org.apache.spark.ml.linalg.{DenseVector => SparkDenseVector, SparseVector => SparkSparseVector, Vector, Vectors} package object knn { + private[knn] case class IntDoubleArrayIndexItem(id: Int, vector: Array[Double]) extends Item[Int, Array[Double]] { + override def dimensions: Int = vector.length + } + + private[knn] case class LongDoubleArrayIndexItem(id: Long, vector: Array[Double]) extends Item[Long, Array[Double]] { + override def dimensions: Int = vector.length + } + + private[knn] case class StringDoubleArrayIndexItem(id: String, vector: Array[Double]) + extends Item[String, Array[Double]] { + override def dimensions: Int = vector.length + } + + private[knn] case class IntFloatArrayIndexItem(id: Int, vector: Array[Float]) extends Item[Int, Array[Float]] { + override def dimensions: Int = vector.length + } + + private[knn] case class LongFloatArrayIndexItem(id: Long, vector: Array[Float]) extends Item[Long, Array[Float]] { + override def dimensions: Int = vector.length + } + + private[knn] case class StringFloatArrayIndexItem(id: String, vector: Array[Float]) + extends Item[String, Array[Float]] { + override def dimensions: Int = vector.length + } + + private[knn] case class IntVectorIndexItem(id: Int, vector: Vector) extends Item[Int, Vector] { + override def dimensions: Int = vector.size + } + + private[knn] case class LongVectorIndexItem(id: Long, vector: Vector) extends Item[Long, Vector] { + override def dimensions: Int = vector.size + } + + private[knn] case class StringVectorIndexItem(id: String, vector: Vector) extends Item[String, Vector] { + override def dimensions: Int = vector.size + } + private[knn] implicit object StringSerializer extends ObjectSerializer[String] { override def write(item: String, out: ObjectOutput): Unit = out.writeUTF(item) override def read(in: 
ObjectInput): String = in.readUTF() @@ -58,12 +136,12 @@ package object knn { private[knn] implicit object VectorSerializer extends ObjectSerializer[Vector] { override def write(item: Vector, out: ObjectOutput): Unit = item match { - case v: DenseVector => + case v: SparkDenseVector => out.writeBoolean(true) out.writeInt(v.size) v.values.foreach(out.writeDouble) - case v: SparseVector => + case v: SparkSparseVector => out.writeBoolean(false) out.writeInt(v.size) out.writeInt(v.indices.length) @@ -220,4 +298,231 @@ package object knn { } } + private[knn] implicit object IntVectorIndexServerFactory + extends IndexServerFactory[Int, Vector, IntVectorIndexItem, Double]( + extractVector, + convertIntId, + convertDoubleDistance + ) + + private[knn] implicit object LongVectorIndexServerFactory + extends IndexServerFactory[Long, Vector, LongVectorIndexItem, Double]( + extractVector, + convertLongId, + convertDoubleDistance + ) + + private[knn] implicit object StringVectorIndexServerFactory + extends IndexServerFactory[String, Vector, StringVectorIndexItem, Double]( + extractVector, + convertStringId, + convertDoubleDistance + ) + + private[knn] implicit object IntFloatArrayIndexServerFactory + extends IndexServerFactory[Int, Array[Float], IntFloatArrayIndexItem, Float]( + extractFloatArray, + convertIntId, + convertFloatDistance + ) + + private[knn] implicit object LongFloatArrayIndexServerFactory + extends IndexServerFactory[Long, Array[Float], LongFloatArrayIndexItem, Float]( + extractFloatArray, + convertLongId, + convertFloatDistance + ) + + private[knn] implicit object StringFloatArrayIndexServerFactory + extends IndexServerFactory[String, Array[Float], StringFloatArrayIndexItem, Float]( + extractFloatArray, + convertStringId, + convertFloatDistance + ) + + private[knn] implicit object IntDoubleArrayIndexServerFactory + extends IndexServerFactory[Int, Array[Double], IntDoubleArrayIndexItem, Double]( + extractDoubleArray, + convertIntId, + convertDoubleDistance + ) + + private[knn] implicit object LongDoubleArrayIndexServerFactory + extends IndexServerFactory[Long, Array[Double], LongDoubleArrayIndexItem, Double]( + extractDoubleArray, + convertLongId, + convertDoubleDistance + ) + + private[knn] implicit object StringDoubleArrayIndexServerFactory + extends IndexServerFactory[String, Array[Double], StringDoubleArrayIndexItem, Double]( + extractDoubleArray, + convertStringId, + convertDoubleDistance + ) + + private[knn] implicit object IntVectorIndexClientFactory + extends IndexClientFactory[Int, Vector, Double]( + convertVector, + extractIntId, + extractDoubleDistance, + Ordering.Double + ) + + private[knn] implicit object LongVectorIndexClientFactory + extends IndexClientFactory[Long, Vector, Double]( + convertVector, + extractLongId, + extractDoubleDistance, + Ordering.Double + ) + + private[knn] implicit object StringVectorIndexClientFactory + extends IndexClientFactory[String, Vector, Double]( + convertVector, + extractStringId, + extractDoubleDistance, + Ordering.Double + ) + + private[knn] implicit object IntFloatArrayIndexClientFactory + extends IndexClientFactory[Int, Array[Float], Float]( + convertFloatArray, + extractIntId, + extractFloatDistance, + Ordering.Float + ) + + private[knn] implicit object LongFloatArrayIndexClientFactory + extends IndexClientFactory[Long, Array[Float], Float]( + convertFloatArray, + extractLongId, + extractFloatDistance, + Ordering.Float + ) + + private[knn] implicit object StringFloatArrayIndexClientFactory + extends IndexClientFactory[String, 
Array[Float], Float](
+        convertFloatArray,
+        extractStringId,
+        extractFloatDistance,
+        Ordering.Float
+      )
+
+  private[knn] implicit object IntDoubleArrayIndexClientFactory
+      extends IndexClientFactory[Int, Array[Double], Double](
+        convertDoubleArray,
+        extractIntId,
+        extractDoubleDistance,
+        Ordering.Double
+      )
+
+  private[knn] implicit object LongDoubleArrayIndexClientFactory
+      extends IndexClientFactory[Long, Array[Double], Double](
+        convertDoubleArray,
+        extractLongId,
+        extractDoubleDistance,
+        Ordering.Double
+      )
+
+  private[knn] implicit object StringDoubleArrayIndexClientFactory
+      extends IndexClientFactory[String, Array[Double], Double](
+        convertDoubleArray,
+        extractStringId,
+        extractDoubleDistance,
+        Ordering.Double
+      )
+
+  private[knn] def convertFloatArray(vector: Array[Float]): SearchRequest.Vector =
+    SearchRequest.Vector.FloatArrayVector(FloatArrayVector(vector))
+
+  private[knn] def convertDoubleArray(vector: Array[Double]): SearchRequest.Vector =
+    SearchRequest.Vector.DoubleArrayVector(DoubleArrayVector(vector))
+
+  private[knn] def convertVector(vector: Vector): SearchRequest.Vector = vector match {
+    case v: SparkDenseVector  => SearchRequest.Vector.DenseVector(DenseVector(v.values))
+    case v: SparkSparseVector => SearchRequest.Vector.SparseVector(SparseVector(vector.size, v.indices, v.values))
+  }
+
+  private[knn] def extractDoubleDistance(result: Result): Double = result.getDoubleDistance
+
+  private[knn] def extractFloatDistance(result: Result): Float = result.getFloatDistance
+
+  private[knn] def extractStringId(result: Result): String = result.getStringId
+
+  private[knn] def extractLongId(result: Result): Long = result.getLongId
+
+  private[knn] def extractIntId(result: Result): Int = result.getIntId
+
+  private[knn] def extractFloatArray(request: SearchRequest): Array[Float] = request.vector.floatArrayVector
+    .map(_.values)
+    .orNull
+
+  private[knn] def extractDoubleArray(request: SearchRequest): Array[Double] = request.vector.doubleArrayVector
+    .map(_.values)
+    .orNull
+
+  private[knn] def extractVector(request: SearchRequest): Vector =
+    if (request.vector.isDenseVector) request.vector.denseVector.map { v => new SparkDenseVector(v.values) }.orNull
+    else request.vector.sparseVector.map { v => new SparkSparseVector(v.size, v.indices, v.values) }.orNull
+
+  private[knn] def convertStringId(value: String): Result.Id = Result.Id.StringId(value)
+  private[knn] def convertLongId(value: Long): Result.Id = Result.Id.LongId(value)
+  private[knn] def convertIntId(value: Int): Result.Id = Result.Id.IntId(value)
+
+  private[knn] def convertFloatDistance(value: Float): Result.Distance = Result.Distance.FloatDistance(value)
+  private[knn] def convertDoubleDistance(value: Double): Result.Distance = Result.Distance.DoubleDistance(value)
+
+  implicit private[knn] def floatArrayDistanceFunction(name: String): DistanceFunction[Array[Float], Float] =
+    (name, vectorApiAvailable) match {
+      case ("bray-curtis", true)   => vectorFloat128BrayCurtisDistance
+      case ("bray-curtis", _)      => floatBrayCurtisDistance
+      case ("canberra", true)      => vectorFloat128CanberraDistance
+      case ("canberra", _)         => floatCanberraDistance
+      case ("correlation", _)      => floatCorrelationDistance
+      case ("cosine", true)        => vectorFloat128CosineDistance
+      case ("cosine", _)           => floatCosineDistance
+      case ("euclidean", true)     => vectorFloat128EuclideanDistance
+      case ("euclidean", _)        => floatEuclideanDistance
+      case ("inner-product", true) => vectorFloat128InnerProduct
+      case ("inner-product", _)    => floatInnerProduct
+      case ("manhattan", true)     => vectorFloat128ManhattanDistance
+      case ("manhattan", _)        => floatManhattanDistance
+      case (value, _)              => userDistanceFunction(value)
+    }
+
+  implicit private[knn] def doubleArrayDistanceFunction(name: String): DistanceFunction[Array[Double], Double] =
+    name match {
+      case "bray-curtis"   => doubleBrayCurtisDistance
+      case "canberra"      => doubleCanberraDistance
+      case "correlation"   => doubleCorrelationDistance
+      case "cosine"        => doubleCosineDistance
+      case "euclidean"     => doubleEuclideanDistance
+      case "inner-product" => doubleInnerProduct
+      case "manhattan"     => doubleManhattanDistance
+      case value           => userDistanceFunction(value)
+    }
+
+  implicit private[knn] def vectorDistanceFunction(name: String): DistanceFunction[Vector, Double] = name match {
+    case "bray-curtis"   => VectorDistanceFunctions.brayCurtisDistance
+    case "canberra"      => VectorDistanceFunctions.canberraDistance
+    case "correlation"   => VectorDistanceFunctions.correlationDistance
+    case "cosine"        => VectorDistanceFunctions.cosineDistance
+    case "euclidean"     => VectorDistanceFunctions.euclideanDistance
+    case "inner-product" => VectorDistanceFunctions.innerProduct
+    case "manhattan"     => VectorDistanceFunctions.manhattanDistance
+    case value           => userDistanceFunction(value)
+  }
+
+  private def vectorApiAvailable: Boolean = try {
+    val _ = Jdk17DistanceFunctions.VECTOR_FLOAT_128_COSINE_DISTANCE
+    true
+  } catch {
+    case _: Throwable => false
+  }
+
+  private def userDistanceFunction[TVector, TDistance](name: String): DistanceFunction[TVector, TDistance] =
+    Try(Class.forName(name).getDeclaredConstructor().newInstance()).toOption
+      .collect { case f: DistanceFunction[TVector @unchecked, TDistance @unchecked] => f }
+      .getOrElse(throw new IllegalArgumentException(s"$name is not a valid distance function."))
 }
diff --git a/hnswlib-spark/src/test/scala/ClientTest.scala b/hnswlib-spark/src/test/scala/ClientTest.scala
new file mode 100644
index 0000000..9299af7
--- /dev/null
+++ b/hnswlib-spark/src/test/scala/ClientTest.scala
@@ -0,0 +1,118 @@
+import java.util.concurrent.{Executors, LinkedBlockingQueue}
+import java.util.concurrent.atomic.{AtomicBoolean, AtomicInteger}
+
+import com.github.jelmerk.server.index.{FloatArrayVector, IndexServiceGrpc, SearchRequest, SearchResponse}
+import io.grpc.netty.NettyChannelBuilder
+import io.grpc.stub.StreamObserver
+
+object ClientTest {
+
+  def main(args: Array[String]): Unit = {
+
+    println("start")
+    val clients = Range.inclusive(1, 4).map { shard =>
+      val channel = NettyChannelBuilder
+        .forAddress("127.0.0.1", 8080)
+        .usePlaintext
+        .build()
+
+      IndexServiceGrpc.stub(channel)
+
+    }
+
+    println("created clients")
+
+    val threadPool = Executors.newFixedThreadPool(1)
+
+    val requests = Iterator.fill(100)(
+      SearchRequest(
+        vector = SearchRequest.Vector.FloatArrayVector(FloatArrayVector(Array(0.1f, 0.2f))),
+        k = 10
+      )
+    )
+
+    val batches = requests.grouped(10)
+
+    val abc = batches
+      .map { batch =>
+        val (requestObservers, responseIterators) = clients.map { client =>
+          val responseStreamObserver = new MyStreamObserver[SearchResponse](batch.size)
+          val requestStreamObserver = client.search(responseStreamObserver)
+
+          (requestStreamObserver, responseStreamObserver)
+        }.unzip
+
+        threadPool.submit(new Runnable {
+          override def run(): Unit = {
+            val batchIter = batch.iterator
+            for {
+              request <- batchIter
+              last = !batchIter.hasNext
+              observer <- requestObservers
+            } {
+              observer.onNext(request)
+              if (last) {
+                observer.onCompleted()
+              }
+            }
+          }
+        })
+
+        val
compoundIterator = new CompoundIterator(responseIterators) + + compoundIterator + .map { responses => + // TODO harcoded to float distance now + val results = responses.flatMap(_.results).sortBy(_.distance.floatDistance).take(10) + SearchResponse(results) + } + + } + .reduce((a, b) => a ++ b) + + abc.foreach(println) + + threadPool.shutdownNow() + +// abc.foreach(println) + + println("done.") + System.in.read() + } +} + +class MyStreamObserver[T](expected: Int) extends StreamObserver[T] with Iterator[T] { + + private val queue = new LinkedBlockingQueue[Either[Throwable, T]] + private val counter = new AtomicInteger() + private val done = new AtomicBoolean(false) + + override def onNext(value: T): Unit = { + queue.add(Right(value)) + counter.incrementAndGet() + } + + override def onError(t: Throwable): Unit = { + queue.add(Left(t)) + done.set(true) + } + + override def onCompleted(): Unit = { + done.set(true) + } + + override def hasNext: Boolean = { + !queue.isEmpty || (counter.get() < expected && !done.get()) + } + + override def next(): T = queue.take() match { + case Right(value) => value + case Left(t) => throw t + } +} + +class CompoundIterator[T](iterators: Seq[Iterator[T]]) extends Iterator[Seq[T]] { + override def hasNext: Boolean = iterators.forall(_.hasNext) + + override def next(): Seq[T] = iterators.map(_.next()) +} diff --git a/hnswlib-spark/src/test/scala/RegistrationClientTest.scala b/hnswlib-spark/src/test/scala/RegistrationClientTest.scala new file mode 100644 index 0000000..433d85c --- /dev/null +++ b/hnswlib-spark/src/test/scala/RegistrationClientTest.scala @@ -0,0 +1,33 @@ +import scala.concurrent.Await +import scala.concurrent.duration.Duration + +import com.github.jelmerk.server.registration.{RegisterRequest, RegistrationServiceGrpc} +import io.grpc.netty.NettyChannelBuilder + +object RegistrationClientTest { + + def main(args: Array[String]): Unit = { + + val channel = NettyChannelBuilder + .forAddress("127.0.0.1", 8000) + .usePlaintext + .build() + + val client = RegistrationServiceGrpc.stub(channel) + + (1 to 4).foreach { partitionNo => + val response = client.register( + RegisterRequest( + partitionNum = partitionNo, + replicaNum = 1, + "localhost", + 123 + ) + ) + + Await.result(response, Duration.Inf) + + } + + } +} diff --git a/hnswlib-spark/src/test/scala/RegistrationServerTest.scala b/hnswlib-spark/src/test/scala/RegistrationServerTest.scala new file mode 100644 index 0000000..a96355f --- /dev/null +++ b/hnswlib-spark/src/test/scala/RegistrationServerTest.scala @@ -0,0 +1,59 @@ +//import com.github.jelmerk.server.registration.{RegisterRequest, RegisterResponse, RegistrationServiceGrpc} +//import com.github.jelmerk.server.registration.RegistrationServiceGrpc.RegistrationService +//import io.grpc.netty.NettyServerBuilder +// +//import java.net.InetSocketAddress +//import java.util.concurrent.{ConcurrentHashMap, CountDownLatch, Executors} +//import scala.concurrent.{ExecutionContext, Future} +//import scala.collection.concurrent +//import scala.collection.JavaConverters._ +// +//object RegistrationServerTest { +// +// +// private class DefaultRegistrationService(registrationLatch: CountDownLatch) extends RegistrationService { +// +// val registrations = new ConcurrentHashMap[Int, InetSocketAddress]() +// +// override def register(request: RegisterRequest): Future[RegisterResponse] = { +// +// val previousValue = registrations.put(request.partitionNum, new InetSocketAddress(request.host, request.port)) +// +// if (previousValue == null) { +// 
registrationLatch.countDown() +// } +// +// Future.successful(RegisterResponse()) +// } +// +// } +// +// def startRegistrationServerAndAwaitRegistrations(numPartitions: Int, host: String, port: Int): Map[PartitionAndReplica, InetSocketAddress] = { +// val registrationLatch = new CountDownLatch(numPartitions) +// +// val executor = Executors.newSingleThreadExecutor() +// +// val executionContext: ExecutionContext = ExecutionContext.fromExecutor(executor) +// +// val service = new DefaultRegistrationService(registrationLatch) +// // Build the gRPC server +// val server = NettyServerBuilder +// .forAddress(new InetSocketAddress(host, port)) +// .addService(RegistrationServiceGrpc.bindService(service, executionContext)) +// .build() +// +// server.start() +// +// registrationLatch.await() +// +// server.shutdownNow() +// executor.shutdownNow() +// +// service.registrations.asScala +// } +// +// def main(args: Array[String]): Unit = { +// val registrations = startRegistrationServerAndAwaitRegistrations(4, "localhost", 8000) +// println(registrations) +// } +//} diff --git a/hnswlib-spark/src/test/scala/ServerTest.scala b/hnswlib-spark/src/test/scala/ServerTest.scala new file mode 100644 index 0000000..24708ef --- /dev/null +++ b/hnswlib-spark/src/test/scala/ServerTest.scala @@ -0,0 +1,55 @@ +import java.util.concurrent.{LinkedBlockingQueue, ThreadPoolExecutor, TimeUnit} + +import scala.concurrent.ExecutionContext + +import com.github.jelmerk.knn.scalalike.{floatCosineDistance, Item} +import com.github.jelmerk.knn.scalalike.hnsw.HnswIndex +import com.github.jelmerk.server.index.{DefaultIndexService, IndexServiceGrpc} +import io.grpc.netty.NettyServerBuilder +import org.apache.hadoop.conf.Configuration + +case class StringFloatArrayIndexItem(id: String, vector: Array[Float]) extends Item[String, Array[Float]] { + override def dimensions: Int = vector.length +} + +object ServerTest { + + def main(args: Array[String]): Unit = { + +// val index = HnswIndex[String, Array[Float], StringFloatArrayIndexItem, Float]( +// dimensions = 2, +// distanceFunction = floatCosineDistance, +// maxItemCount = 1000 +// ) +// +// index.add(StringFloatArrayIndexItem("1", Array(0.1f, 0.8f))) +// index.add(StringFloatArrayIndexItem("2", Array(0.3f, 0.2f))) +// index.add(StringFloatArrayIndexItem("3", Array(0.1f, 0.5f))) +// +// import DefaultIndexService._ +// +// val hadoopConfig = new Configuration() +// +// val executor = new ThreadPoolExecutor( +// 4, 4, 0L, TimeUnit.MILLISECONDS, +// new LinkedBlockingQueue[Runnable]() +// ) +// +// val executionContext: ExecutionContext = ExecutionContext.fromExecutor(executor) +// +// implicit val ec: ExecutionContext = ExecutionContext.global +// val service = new DefaultIndexService(index, hadoopConfig) +// +// // Build the gRPC server +// val server = NettyServerBuilder +// .forPort(8080) +// .addService(IndexServiceGrpc.bindService(service, executionContext)) +// .build() +// +// server.start() +// +// server.awaitTermination() + + } + +} diff --git a/project/plugins.sbt b/project/plugins.sbt index 7f9f2dc..d5bf29c 100644 --- a/project/plugins.sbt +++ b/project/plugins.sbt @@ -3,3 +3,6 @@ addSbtPlugin("com.github.sbt" % "sbt-dynver" % "5.0.1") addSbtPlugin("com.github.sbt" % "sbt-pgp" % "2.2.1") addSbtPlugin("org.xerial.sbt" % "sbt-sonatype" % "3.10.0") addSbtPlugin("org.scalameta" % "sbt-scalafmt" % "2.4.6") +addSbtPlugin("com.thesamet" % "sbt-protoc" % "1.0.6") + +libraryDependencies += "com.thesamet.scalapb" %% "compilerplugin" % "0.11.11" \ No newline at end of file
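
Usage sketch (illustrative, not part of the patch): a minimal example of driving the redesigned estimator from a Spark application. It only relies on the setters visible in this patch (setIdentifierCol, setNumReplicas, setNumThreads); the column names, the toy data, and the "features" default column follow common Spark ML conventions and are assumptions for illustration.

    import org.apache.spark.sql.SparkSession
    import com.github.jelmerk.spark.knn.hnsw.HnswSimilarity

    object HnswUsageSketch {
      def main(args: Array[String]): Unit = {
        val spark = SparkSession.builder().appName("hnsw-usage-sketch").getOrCreate()
        import spark.implicits._

        // toy input: an identifier column and a float-array feature column (hypothetical data)
        val items = Seq(
          (1L, Array(0.1f, 0.9f)),
          (2L, Array(0.3f, 0.2f)),
          (3L, Array(0.8f, 0.1f))
        ).toDF("id", "features")

        val hnsw = new HnswSimilarity()
          .setIdentifierCol("id") // setter shown in this patch
          .setNumReplicas(0)      // setter shown in this patch
          .setNumThreads(2)       // replaces setParallelism in this patch

        // fit builds the partitioned indices and serves them from long-running tasks via gRPC
        val model = hnsw.fit(items)

        // transform queries the served indices and returns the k nearest neighbours per row
        model.transform(items).show(truncate = false)

        spark.stop()
      }
    }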