diff --git a/build.sbt b/build.sbt index 1c83563..ce5ece2 100644 --- a/build.sbt +++ b/build.sbt @@ -84,9 +84,17 @@ lazy val hnswlibSpark = (project in file("hnswlib-spark")) oldStrategy(x) }, assembly / assemblyShadeRules := Seq( - ShadeRule.rename("com.google.protobuf.**" -> "shaded.com.google.protobuf.@1").inAll, - ShadeRule.rename("com.google.common.**" -> "shaded.com.google.common.@1").inAll, +// ShadeRule.rename("com.google.protobuf.**" -> "shaded.com.google.protobuf.@1").inAll, +// ShadeRule.rename("com.google.common.**" -> "shaded.com.google.common.@1").inAll, + + ShadeRule.rename("google.**" -> "shaded.google.@1").inAll, + ShadeRule.rename("com.google.**" -> "shaded.com.google.@1").inAll, + ShadeRule.rename("io.grpc.**" -> "shaded.io.grpc.@1").inAll, ShadeRule.rename("io.netty.**" -> "shaded.io.netty.@1").inAll + +// ShadeRule.rename("com.github.jelmerk.**" -> "com.github.jelmerk.@1").inAll, +// ShadeRule.rename("*" -> "shaded.@1").inAll + ), Compile / PB.targets := Seq( scalapb.gen() -> (Compile / sourceManaged).value / "scalapb" diff --git a/hnswlib-spark/src/main/scala/com/github/jelmerk/server/registration/RegistrationClient.scala b/hnswlib-spark/src/main/scala/com/github/jelmerk/registration/client/RegistrationClient.scala similarity index 84% rename from hnswlib-spark/src/main/scala/com/github/jelmerk/server/registration/RegistrationClient.scala rename to hnswlib-spark/src/main/scala/com/github/jelmerk/registration/client/RegistrationClient.scala index 4a0d1ae..b08219e 100644 --- a/hnswlib-spark/src/main/scala/com/github/jelmerk/server/registration/RegistrationClient.scala +++ b/hnswlib-spark/src/main/scala/com/github/jelmerk/registration/client/RegistrationClient.scala @@ -1,10 +1,11 @@ -package com.github.jelmerk.server.registration +package com.github.jelmerk.registration.client import java.net.{InetSocketAddress, SocketAddress} import scala.concurrent.Await import scala.concurrent.duration.Duration +import com.github.jelmerk.server.registration.{RegisterRequest, RegisterResponse, RegistrationServiceGrpc} import io.grpc.netty.NettyChannelBuilder object RegistrationClient { diff --git a/hnswlib-spark/src/main/scala/com/github/jelmerk/server/registration/DefaultRegistrationService.scala b/hnswlib-spark/src/main/scala/com/github/jelmerk/registration/server/DefaultRegistrationService.scala similarity index 87% rename from hnswlib-spark/src/main/scala/com/github/jelmerk/server/registration/DefaultRegistrationService.scala rename to hnswlib-spark/src/main/scala/com/github/jelmerk/registration/server/DefaultRegistrationService.scala index 03c2c72..c766b15 100644 --- a/hnswlib-spark/src/main/scala/com/github/jelmerk/server/registration/DefaultRegistrationService.scala +++ b/hnswlib-spark/src/main/scala/com/github/jelmerk/registration/server/DefaultRegistrationService.scala @@ -1,10 +1,11 @@ -package com.github.jelmerk.server.registration +package com.github.jelmerk.registration.server import java.net.InetSocketAddress import java.util.concurrent.{ConcurrentHashMap, CountDownLatch} import scala.concurrent.Future +import com.github.jelmerk.server.registration.{RegisterRequest, RegisterResponse} import com.github.jelmerk.server.registration.RegistrationServiceGrpc.RegistrationService class DefaultRegistrationService(val registrationLatch: CountDownLatch) extends RegistrationService { diff --git a/hnswlib-spark/src/main/scala/com/github/jelmerk/server/registration/RegistrationServerFactory.scala b/hnswlib-spark/src/main/scala/com/github/jelmerk/registration/server/RegistrationServerFactory.scala similarity index 87% rename from hnswlib-spark/src/main/scala/com/github/jelmerk/server/registration/RegistrationServerFactory.scala rename to hnswlib-spark/src/main/scala/com/github/jelmerk/registration/server/RegistrationServerFactory.scala index 6a0884b..8ebba7b 100644 --- a/hnswlib-spark/src/main/scala/com/github/jelmerk/server/registration/RegistrationServerFactory.scala +++ b/hnswlib-spark/src/main/scala/com/github/jelmerk/registration/server/RegistrationServerFactory.scala @@ -1,4 +1,4 @@ -package com.github.jelmerk.server.registration +package com.github.jelmerk.registration.server import java.net.{InetAddress, InetSocketAddress} import java.util.concurrent.{CountDownLatch, Executors} @@ -7,6 +7,7 @@ import scala.concurrent.ExecutionContext import scala.jdk.CollectionConverters.mapAsScalaConcurrentMapConverter import scala.util.Try +import com.github.jelmerk.server.registration.RegistrationServiceGrpc import io.grpc.netty.NettyServerBuilder // TODO this is all a bit messy @@ -33,9 +34,9 @@ class RegistrationServer(numPartitions: Int, numReplicas: Int) { service.registrations.asScala.toMap } - def shutdownNow(): Unit = { - Try(server.shutdownNow()) - Try(executor.shutdownNow()) + def shutdown(): Unit = { + Try(server.shutdown()) + Try(executor.shutdown()) } } diff --git a/hnswlib-spark/src/main/scala/com/github/jelmerk/spark/knn/IndexClient.scala b/hnswlib-spark/src/main/scala/com/github/jelmerk/serving/client/IndexClient.scala similarity index 98% rename from hnswlib-spark/src/main/scala/com/github/jelmerk/spark/knn/IndexClient.scala rename to hnswlib-spark/src/main/scala/com/github/jelmerk/serving/client/IndexClient.scala index a4c3d08..42c32ca 100644 --- a/hnswlib-spark/src/main/scala/com/github/jelmerk/spark/knn/IndexClient.scala +++ b/hnswlib-spark/src/main/scala/com/github/jelmerk/serving/client/IndexClient.scala @@ -1,4 +1,4 @@ -package com.github.jelmerk.spark.knn +package com.github.jelmerk.serving.client import java.net.InetSocketAddress import java.util.concurrent.{Executors, LinkedBlockingQueue} @@ -10,8 +10,8 @@ import scala.concurrent.duration.Duration import scala.language.implicitConversions import scala.util.Random +import com.github.jelmerk.registration.server.PartitionAndReplica import com.github.jelmerk.server.index._ -import com.github.jelmerk.server.registration.PartitionAndReplica import io.grpc.netty.NettyChannelBuilder import io.grpc.stub.StreamObserver diff --git a/hnswlib-spark/src/main/scala/com/github/jelmerk/spark/knn/IndexClientFactory.scala b/hnswlib-spark/src/main/scala/com/github/jelmerk/serving/client/IndexClientFactory.scala similarity index 84% rename from hnswlib-spark/src/main/scala/com/github/jelmerk/spark/knn/IndexClientFactory.scala rename to hnswlib-spark/src/main/scala/com/github/jelmerk/serving/client/IndexClientFactory.scala index 3016eb2..ab2ed2e 100644 --- a/hnswlib-spark/src/main/scala/com/github/jelmerk/spark/knn/IndexClientFactory.scala +++ b/hnswlib-spark/src/main/scala/com/github/jelmerk/serving/client/IndexClientFactory.scala @@ -1,9 +1,9 @@ -package com.github.jelmerk.spark.knn +package com.github.jelmerk.serving.client import java.net.InetSocketAddress +import com.github.jelmerk.registration.server.PartitionAndReplica import com.github.jelmerk.server.index.{Result, SearchRequest} -import com.github.jelmerk.server.registration.PartitionAndReplica class IndexClientFactory[TId, TVector, TDistance]( vectorConverter: TVector => SearchRequest.Vector, diff --git a/hnswlib-spark/src/main/scala/com/github/jelmerk/serving/client/Neighbor.scala b/hnswlib-spark/src/main/scala/com/github/jelmerk/serving/client/Neighbor.scala new file mode 100644 index 0000000..c3f3a8a --- /dev/null +++ b/hnswlib-spark/src/main/scala/com/github/jelmerk/serving/client/Neighbor.scala @@ -0,0 +1,15 @@ +package com.github.jelmerk.serving.client + +/** Neighbor of an item. + * + * @param neighbor + * identifies the neighbor + * @param distance + * distance to the item + * + * @tparam TId + * type of the index item identifier + * @tparam TDistance + * type of distance + */ +case class Neighbor[TId, TDistance](neighbor: TId, distance: TDistance) // TODO move this diff --git a/hnswlib-spark/src/main/scala/com/github/jelmerk/server/index/DefaultIndexService.scala b/hnswlib-spark/src/main/scala/com/github/jelmerk/serving/server/DefaultIndexService.scala similarity index 96% rename from hnswlib-spark/src/main/scala/com/github/jelmerk/server/index/DefaultIndexService.scala rename to hnswlib-spark/src/main/scala/com/github/jelmerk/serving/server/DefaultIndexService.scala index bb481cb..dcf0e37 100644 --- a/hnswlib-spark/src/main/scala/com/github/jelmerk/server/index/DefaultIndexService.scala +++ b/hnswlib-spark/src/main/scala/com/github/jelmerk/serving/server/DefaultIndexService.scala @@ -1,9 +1,10 @@ -package com.github.jelmerk.server.index +package com.github.jelmerk.serving.server import scala.concurrent.{ExecutionContext, Future} import scala.language.implicitConversions import com.github.jelmerk.knn.scalalike.{Index, Item} +import com.github.jelmerk.server.index._ import com.github.jelmerk.server.index.IndexServiceGrpc.IndexService import io.grpc.stub.StreamObserver import org.apache.commons.io.output.CountingOutputStream diff --git a/hnswlib-spark/src/main/scala/com/github/jelmerk/server/index/IndexServerFactory.scala b/hnswlib-spark/src/main/scala/com/github/jelmerk/serving/server/IndexServerFactory.scala similarity index 91% rename from hnswlib-spark/src/main/scala/com/github/jelmerk/server/index/IndexServerFactory.scala rename to hnswlib-spark/src/main/scala/com/github/jelmerk/serving/server/IndexServerFactory.scala index 6a3e25d..52b074b 100644 --- a/hnswlib-spark/src/main/scala/com/github/jelmerk/server/index/IndexServerFactory.scala +++ b/hnswlib-spark/src/main/scala/com/github/jelmerk/serving/server/IndexServerFactory.scala @@ -1,4 +1,4 @@ -package com.github.jelmerk.server.index +package com.github.jelmerk.serving.server import java.net.{InetAddress, InetSocketAddress} import java.util.concurrent.{LinkedBlockingQueue, ThreadPoolExecutor, TimeUnit} @@ -7,6 +7,7 @@ import scala.concurrent.ExecutionContext import scala.util.Try import com.github.jelmerk.knn.scalalike.{Index, Item} +import com.github.jelmerk.server.index.{IndexServiceGrpc, Result, SearchRequest} import io.grpc.netty.NettyServerBuilder import org.apache.hadoop.conf.Configuration @@ -46,9 +47,9 @@ class IndexServer[TId, TVector, TItem <: Item[TId, TVector] with Product, TDista server.awaitTermination() } - def shutdownNow(): Unit = { - Try(server.shutdownNow()) - Try(executor.shutdownNow()) + def shutdown(): Unit = { + Try(server.shutdown()) + Try(executor.shutdown()) } } diff --git a/hnswlib-spark/src/main/scala/com/github/jelmerk/spark/knn/KnnAlgorithm.scala b/hnswlib-spark/src/main/scala/com/github/jelmerk/spark/knn/KnnAlgorithm.scala index 397d9c3..e37d044 100644 --- a/hnswlib-spark/src/main/scala/com/github/jelmerk/spark/knn/KnnAlgorithm.scala +++ b/hnswlib-spark/src/main/scala/com/github/jelmerk/spark/knn/KnnAlgorithm.scala @@ -10,8 +10,10 @@ import scala.util.control.NonFatal import com.github.jelmerk.knn.ObjectSerializer import com.github.jelmerk.knn.scalalike._ -import com.github.jelmerk.server.index.IndexServerFactory -import com.github.jelmerk.server.registration.{PartitionAndReplica, RegistrationClient, RegistrationServerFactory} +import com.github.jelmerk.registration.client.RegistrationClient +import com.github.jelmerk.registration.server.{PartitionAndReplica, RegistrationServerFactory} +import com.github.jelmerk.serving.client.IndexClientFactory +import com.github.jelmerk.serving.server.IndexServerFactory import com.github.jelmerk.spark.util.SerializableConfiguration import org.apache.hadoop.fs.Path import org.apache.spark.{Partitioner, SparkContext, TaskContext} @@ -43,20 +45,6 @@ private[knn] case class ModelMetaData( paramMap: Map[String, Any] ) -/** Neighbor of an item. - * - * @param neighbor - * identifies the neighbor - * @param distance - * distance to the item - * - * @tparam TId - * type of the index item identifier - * @tparam TDistance - * type of distance - */ -private[knn] case class Neighbor[TId, TDistance](neighbor: TId, distance: TDistance) - private[knn] trait IndexType { /** Type of index. */ @@ -430,7 +418,7 @@ object KnnModelReader { * @tparam TModel * type of model */ -private[knn] abstract class KnnModelReader[TModel <: KnnModelBase[TModel]: ClassTag] +private[knn] abstract class KnnModelReader[TModel <: KnnModelBase[TModel]] extends MLReader[TModel] with IndexLoader with IndexServing @@ -501,7 +489,7 @@ private[knn] abstract class KnnModelReader[TModel <: KnnModelBase[TModel]: Class val indexRdd = sc .makeRDD(partitionPaths) - .partitionBy(new PartitionIdPassthrough(metadata.numPartitions)) + .partitionBy(new PartitionIdPartitioner(metadata.numPartitions)) .withResources(profile) .mapPartitions { it => val (partitionId, indexPath) = it.next() @@ -791,7 +779,7 @@ private[knn] abstract class KnnAlgorithm[TModel <: KnnModelBase[TModel]](overrid ) .as[(Int, TItem)] .rdd - .partitionBy(new PartitionIdPassthrough(getNumPartitions)) + .partitionBy(new PartitionIdPartitioner(getNumPartitions)) .values else dataset @@ -886,17 +874,23 @@ private[knn] abstract class KnnAlgorithm[TModel <: KnnModelBase[TModel]](overrid } -/** Partitioner that uses precomputed partitions +/** Partitioner that uses precomputed partitions. Each partition id is its own partition * * @param numPartitions * number of partitions */ -private[knn] class PartitionIdPassthrough(override val numPartitions: Int) extends Partitioner { +private[knn] class PartitionIdPartitioner(override val numPartitions: Int) extends Partitioner { override def getPartition(key: Any): Int = key.asInstanceOf[Int] } -// TODO rename this -private[knn] class PartitionReplicaPartitioner(partitions: Int, replicas: Int) extends Partitioner { +/** Partitioner that uses precomputed partitions. Each unique partition and replica combination is own partition + * + * @param partitions + * the total number of partitions + * @param replicas + * the total number of replicas + */ +private[knn] class PartitionReplicaIdPartitioner(partitions: Int, replicas: Int) extends Partitioner { override def numPartitions: Int = partitions + (replicas * partitions) @@ -928,7 +922,9 @@ private[knn] trait ModelLogging extends Logging { protected def logInfo(partition: Int, message: String): Unit = println(f"partition $partition%04d: $message") // protected def logInfo(partition: Int, replica: Int, message: String): Unit = logInfo(f"partition $partition%04d replica $partition%04d: $message") - protected def logInfo(partition: Int, replica: Int, message: String): Unit = println(f"partition $partition%04d replica $partition%04d: $message") + protected def logInfo(partition: Int, replica: Int, message: String): Unit = println( + f"partition $partition%04d replica $partition%04d: $message" + ) } private[knn] trait IndexServing extends ModelLogging with IndexType { @@ -956,7 +952,7 @@ private[knn] trait IndexServing extends ModelLogging with IndexType { } val replicaRdd = - if (numReplicas > 0) keyedIndexRdd.partitionBy(new PartitionReplicaPartitioner(numPartitions, numReplicas)) + if (numReplicas > 0) keyedIndexRdd.partitionBy(new PartitionReplicaIdPartitioner(numPartitions, numReplicas)) else keyedIndexRdd val server = RegistrationServerFactory.create(numPartitions, numReplicas) @@ -974,32 +970,35 @@ private[knn] trait IndexServing extends ModelLogging with IndexType { val serverAddress = server.address - logInfo(partitionNum, replicaNum, s"started index server on host ${serverAddress.getHostName}:${serverAddress.getPort}") - logInfo( partitionNum, replicaNum, - s"registering partition at ${registrationAddress.getHostName}:${registrationAddress.getPort}" + s"started index server on host ${serverAddress.getHostName}:${serverAddress.getPort}" ) + try { + logInfo( + partitionNum, + replicaNum, + s"registering partition at ${registrationAddress.getHostName}:${registrationAddress.getPort}" + ) - RegistrationClient.register(registrationAddress, partitionNum, replicaNum, serverAddress) + RegistrationClient.register(registrationAddress, partitionNum, replicaNum, serverAddress) - logInfo(partitionNum, replicaNum, "awaiting requests") - try { + logInfo(partitionNum, replicaNum, "awaiting requests") server.awaitTermination() } finally { - server.shutdownNow() + server.shutdown() } true } // the count will never complete because the tasks start the index server - new Thread(new IndexRunnable(uid, sparkContext, serverRdd)).start() + new Thread(new IndexRunnable(uid, sparkContext, serverRdd), s"knn-index-thread-$uid").start() server.awaitRegistrations() } finally { - server.shutdownNow() + server.shutdown() } } diff --git a/hnswlib-spark/src/main/scala/com/github/jelmerk/spark/knn/QueryIterator.scala b/hnswlib-spark/src/main/scala/com/github/jelmerk/spark/knn/QueryIterator.scala index 0279cca..b8ede92 100644 --- a/hnswlib-spark/src/main/scala/com/github/jelmerk/spark/knn/QueryIterator.scala +++ b/hnswlib-spark/src/main/scala/com/github/jelmerk/spark/knn/QueryIterator.scala @@ -2,7 +2,8 @@ package com.github.jelmerk.spark.knn import java.net.InetSocketAddress -import com.github.jelmerk.server.registration.PartitionAndReplica +import com.github.jelmerk.registration.server.PartitionAndReplica +import com.github.jelmerk.serving.client.{IndexClientFactory, Neighbor} class QueryIterator[TId, TVector, TDistance, TQueryId]( indices: Map[PartitionAndReplica, InetSocketAddress], diff --git a/hnswlib-spark/src/main/scala/com/github/jelmerk/spark/knn/bruteforce/BruteForceSimilarity.scala b/hnswlib-spark/src/main/scala/com/github/jelmerk/spark/knn/bruteforce/BruteForceSimilarity.scala index 94703f7..95f0f7f 100644 --- a/hnswlib-spark/src/main/scala/com/github/jelmerk/spark/knn/bruteforce/BruteForceSimilarity.scala +++ b/hnswlib-spark/src/main/scala/com/github/jelmerk/spark/knn/bruteforce/BruteForceSimilarity.scala @@ -9,7 +9,8 @@ import scala.reflect.runtime.universe._ import com.github.jelmerk.knn.ObjectSerializer import com.github.jelmerk.knn.scalalike.{DistanceFunction, Item} import com.github.jelmerk.knn.scalalike.bruteforce.BruteForceIndex -import com.github.jelmerk.server.registration.PartitionAndReplica +import com.github.jelmerk.registration.server.PartitionAndReplica +import com.github.jelmerk.serving.client.IndexClientFactory import com.github.jelmerk.spark.knn._ import org.apache.spark.SparkContext import org.apache.spark.ml.param._ diff --git a/hnswlib-spark/src/main/scala/com/github/jelmerk/spark/knn/hnsw/HnswSimilarity.scala b/hnswlib-spark/src/main/scala/com/github/jelmerk/spark/knn/hnsw/HnswSimilarity.scala index 0851fdf..aa3e8e3 100644 --- a/hnswlib-spark/src/main/scala/com/github/jelmerk/spark/knn/hnsw/HnswSimilarity.scala +++ b/hnswlib-spark/src/main/scala/com/github/jelmerk/spark/knn/hnsw/HnswSimilarity.scala @@ -9,7 +9,8 @@ import scala.reflect.runtime.universe._ import com.github.jelmerk.knn import com.github.jelmerk.knn.scalalike.{DistanceFunction, Item} import com.github.jelmerk.knn.scalalike.hnsw._ -import com.github.jelmerk.server.registration.PartitionAndReplica +import com.github.jelmerk.registration.server.PartitionAndReplica +import com.github.jelmerk.serving.client.IndexClientFactory import com.github.jelmerk.spark.knn._ import org.apache.spark.SparkContext import org.apache.spark.ml.param._ diff --git a/hnswlib-spark/src/main/scala/com/github/jelmerk/spark/knn/knn.scala b/hnswlib-spark/src/main/scala/com/github/jelmerk/spark/knn/knn.scala index 9bcc884..73348b5 100644 --- a/hnswlib-spark/src/main/scala/com/github/jelmerk/spark/knn/knn.scala +++ b/hnswlib-spark/src/main/scala/com/github/jelmerk/spark/knn/knn.scala @@ -37,11 +37,12 @@ import com.github.jelmerk.server.index.{ DenseVector, DoubleArrayVector, FloatArrayVector, - IndexServerFactory, Result, SearchRequest, SparseVector } +import com.github.jelmerk.serving.client.IndexClientFactory +import com.github.jelmerk.serving.server.IndexServerFactory import com.github.jelmerk.spark.linalg.functions.VectorDistanceFunctions import org.apache.spark.ml.linalg.{DenseVector => SparkDenseVector, SparseVector => SparkSparseVector, Vector, Vectors} diff --git a/hnswlib-spark/src/test/scala/ServerTest.scala b/hnswlib-spark/src/test/scala/ServerTest.scala index 24708ef..340844a 100644 --- a/hnswlib-spark/src/test/scala/ServerTest.scala +++ b/hnswlib-spark/src/test/scala/ServerTest.scala @@ -4,7 +4,8 @@ import scala.concurrent.ExecutionContext import com.github.jelmerk.knn.scalalike.{floatCosineDistance, Item} import com.github.jelmerk.knn.scalalike.hnsw.HnswIndex -import com.github.jelmerk.server.index.{DefaultIndexService, IndexServiceGrpc} +import com.github.jelmerk.server.index.IndexServiceGrpc +import com.github.jelmerk.serving.server.DefaultIndexService import io.grpc.netty.NettyServerBuilder import org.apache.hadoop.conf.Configuration