diff --git a/build.sbt b/build.sbt
index f42ee30..ce4cd2c 100644
--- a/build.sbt
+++ b/build.sbt
@@ -12,6 +12,8 @@ ThisBuild / dynverSonatypeSnapshots := true
 ThisBuild / versionScheme := Some("early-semver")
 
+ThisBuild / resolvers += "Local Maven Repository" at "file://" + Path.userHome.absolutePath + "/.m2/repository"
+
 lazy val publishSettings = Seq(
   pomIncludeRepository := { _ => false },
@@ -42,7 +44,7 @@ lazy val publishSettings = Seq(
 lazy val noPublishSettings =
   publish / skip := true
 
-val hnswLibVersion = "1.1.2"
+val hnswLibVersion = "1.1.2+4-3fc68540+20241231-1737-SNAPSHOT"
 val sparkVersion = settingKey[String]("Spark version")
 val venvFolder = settingKey[String]("Venv folder")
 val pythonVersion = settingKey[String]("Python version")
diff --git a/hnswlib-spark/src/main/scala/com/github/jelmerk/serving/server/IndexServerFactory.scala b/hnswlib-spark/src/main/scala/com/github/jelmerk/serving/server/IndexServerFactory.scala
index 52b074b..cc1f474 100644
--- a/hnswlib-spark/src/main/scala/com/github/jelmerk/serving/server/IndexServerFactory.scala
+++ b/hnswlib-spark/src/main/scala/com/github/jelmerk/serving/server/IndexServerFactory.scala
@@ -47,6 +47,10 @@ class IndexServer[TId, TVector, TItem <: Item[TId, TVector] with Product, TDista
     server.awaitTermination()
   }
 
+  def isTerminated(): Boolean = {
+    server.isTerminated
+  }
+
   def shutdown(): Unit = {
     Try(server.shutdown())
     Try(executor.shutdown())
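
The new isTerminated() accessor lets a serving task poll the gRPC server for shutdown instead of blocking in awaitTermination(); the KnnAlgorithm change further down combines it with the task's interrupt flag. A minimal sketch of that polling loop, assuming it runs inside a Spark task with an IndexServer instance named server in scope:

  // cooperative shutdown: poll rather than block in awaitTermination();
  // TaskContext.get() is only non-null inside a running Spark task
  import org.apache.spark.TaskContext

  try {
    while (!TaskContext.get().isInterrupted() && !server.isTerminated()) {
      Thread.sleep(500)
    }
  } finally {
    server.shutdown()
  }
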
diff --git a/hnswlib-spark/src/main/scala/com/github/jelmerk/spark/knn/KnnAlgorithm.scala b/hnswlib-spark/src/main/scala/com/github/jelmerk/spark/knn/KnnAlgorithm.scala
index 4384c14..5617a76 100644
--- a/hnswlib-spark/src/main/scala/com/github/jelmerk/spark/knn/KnnAlgorithm.scala
+++ b/hnswlib-spark/src/main/scala/com/github/jelmerk/spark/knn/KnnAlgorithm.scala
@@ -82,7 +82,7 @@ private[knn] trait IndexCreator extends IndexType {
    * @tparam TDistance
    *   type of distance between items
    * @return
-   *   create an index
+   *   the created index
    */
   protected def createIndex[
       TId,
       TVector,
       TItem <: Item[TId, TVector] with Product,
       TDistance
   ](
       idSerializer: ObjectSerializer[TId],
       itemSerializer: ObjectSerializer[TItem]
   ): TIndex[TId, TVector, TItem, TDistance]
+
+  /** Create an immutable empty index.
+    *
+    * @tparam TId
+    *   type of the index item identifier
+    * @tparam TVector
+    *   type of the index item vector
+    * @tparam TItem
+    *   type of the index item
+    * @tparam TDistance
+    *   type of distance between items
+    * @return
+    *   the created index
+    */
+  protected def emptyIndex[
+      TId,
+      TVector,
+      TItem <: Item[TId, TVector] with Product,
+      TDistance
+  ]: TIndex[TId, TVector, TItem, TDistance]
 }
 
 private[knn] trait IndexLoader extends IndexType {
@@ -490,7 +510,6 @@ private[knn] abstract class KnnModelReader[TModel <: KnnModelBase[TModel]]
     val indexRdd = sc
       .makeRDD(partitionPaths)
       .partitionBy(new PartitionIdPartitioner(metadata.numPartitions))
-      .withResources(profile)
       .mapPartitions { it =>
         val (partitionId, indexPath) = it.next()
         val fs = indexPath.getFileSystem(serializableConfiguration.value)
@@ -501,6 +520,7 @@ private[knn] abstract class KnnModelReader[TModel <: KnnModelBase[TModel]]
         logInfo(partitionId, s"Finished loading index from $indexPath")
         Iterator(index)
       }
+      .withResources(profile)
 
     val servers = serve(metadata.uid, indexRdd, metadata.numPartitions, metadata.numReplicas, metadata.numThreads)
@@ -804,7 +824,6 @@ private[knn] abstract class KnnAlgorithm[TModel <: KnnModelBase[TModel]](overrid
     val profile = new ResourceProfileBuilder().require(taskReqs).build()
 
     val indexRdd = partitionedIndexItems
-      .withResources(profile)
       // TODO JELMER is this the correct place ?
       .mapPartitions(
         (it: Iterator[TItem]) => {
@@ -831,12 +850,15 @@ private[knn] abstract class KnnAlgorithm[TModel <: KnnModelBase[TModel]](overrid
           logInfo(partitionId, s"started indexing ${items.size} items on host ${InetAddress.getLocalHost.getHostName}")
 
           val index = existingIndexOption
+            .filter(_.nonEmpty)
             .getOrElse(
-              createIndex[TId, TVector, TItem, TDistance](
-                items.head.dimensions,
-                items.size,
-                distanceFunctionFactory(getDistanceFunction)
-              )
+              items.headOption.fold(emptyIndex[TId, TVector, TItem, TDistance]) { item =>
+                createIndex[TId, TVector, TItem, TDistance](
+                  item.dimensions,
+                  items.size,
+                  distanceFunctionFactory(getDistanceFunction)
+                )
+              }
             )
 
           index.addAll(
@@ -852,12 +874,15 @@ private[knn] abstract class KnnAlgorithm[TModel <: KnnModelBase[TModel]](overrid
         },
         preservesPartitioning = true
       )
+      .withResources(profile)
 
     val modelUid = uid + "_" + System.currentTimeMillis().toString
 
     val registrations =
       serve[TId, TVector, TItem, TDistance](modelUid, indexRdd, getNumPartitions, getNumReplicas, numThreads)
 
+    println(registrations)
+
     createModel[TId, TVector, TItem, TDistance](
       modelUid,
       getNumPartitions,
@@ -904,8 +929,7 @@ private[knn] class IndexRunnable(uid: String, sparkContext: SparkContext, indexR
     try {
       indexRdd.count()
     } catch {
-      case NonFatal(_) =>
-        // todo should i log this or not ? it produces a lot of noise
+      case NonFatal(e) => ()
     } finally {
       sparkContext.clearJobGroup()
     }
@@ -972,23 +996,33 @@ private[knn] trait IndexServing extends ModelLogging with IndexType {
             replicaNum,
             s"started index server on host ${serverAddress.getHostName}:${serverAddress.getPort}"
           )
+
+          logInfo(
+            partitionNum,
+            replicaNum,
+            s"registering partition at ${registrationAddress.getHostName}:${registrationAddress.getPort}"
+          )
           try {
+            RegistrationClient.register(registrationAddress, partitionNum, replicaNum, serverAddress)
+
+            logInfo(partitionNum, replicaNum, "awaiting requests")
+
+            while (!TaskContext.get().isInterrupted() && !server.isTerminated()) {
+              Thread.sleep(500)
+            }
+
             logInfo(
               partitionNum,
               replicaNum,
-              s"registering partition at ${registrationAddress.getHostName}:${registrationAddress.getPort}"
+              s"Task canceled"
             )
-
-            RegistrationClient.register(registrationAddress, partitionNum, replicaNum, serverAddress)
-
-            logInfo(partitionNum, replicaNum, "awaiting requests")
-            server.awaitTermination()
           } finally {
             server.shutdown()
           }
           true
         }
+        .withResources(indexRdd.getResourceProfile())
 
       // the count will never complete because the tasks start the index server
       new Thread(new IndexRunnable(uid, sparkContext, serverRdd), s"knn-index-thread-$uid").start()
diff --git a/hnswlib-spark/src/main/scala/com/github/jelmerk/spark/knn/bruteforce/BruteForceSimilarity.scala b/hnswlib-spark/src/main/scala/com/github/jelmerk/spark/knn/bruteforce/BruteForceSimilarity.scala
index 95f0f7f..6a255b9 100644
--- a/hnswlib-spark/src/main/scala/com/github/jelmerk/spark/knn/bruteforce/BruteForceSimilarity.scala
+++ b/hnswlib-spark/src/main/scala/com/github/jelmerk/spark/knn/bruteforce/BruteForceSimilarity.scala
@@ -157,4 +157,7 @@ class BruteForceSimilarity(override val uid: String)
   ): BruteForceIndex[TId, TVector, TItem, TDistance] =
     BruteForceIndex[TId, TVector, TItem, TDistance](dimensions, distanceFunction)
 
+  override protected def emptyIndex[TId, TVector, TItem <: Item[TId, TVector] with Product, TDistance]
+      : BruteForceIndex[TId, TVector, TItem, TDistance] =
+    BruteForceIndex.empty[TId, TVector, TItem, TDistance]
 }
diff --git a/hnswlib-spark/src/main/scala/com/github/jelmerk/spark/knn/hnsw/HnswSimilarity.scala b/hnswlib-spark/src/main/scala/com/github/jelmerk/spark/knn/hnsw/HnswSimilarity.scala
index aa3e8e3..1d8e2ae 100644
--- a/hnswlib-spark/src/main/scala/com/github/jelmerk/spark/knn/hnsw/HnswSimilarity.scala
+++ b/hnswlib-spark/src/main/scala/com/github/jelmerk/spark/knn/hnsw/HnswSimilarity.scala
@@ -232,4 +232,8 @@ class HnswSimilarity(override val uid: String)
       idSerializer,
       itemSerializer
     )
+
+  override protected def emptyIndex[TId, TVector, TItem <: Item[TId, TVector] with Product, TDistance]
+      : HnswIndex[TId, TVector, TItem, TDistance] =
+    HnswIndex.empty[TId, TVector, TItem, TDistance]
 }
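
These overrides back the headOption.fold branch added in KnnAlgorithm above: a partition that receives no items now falls back to an empty index instead of failing on items.head. A small self-contained sketch of that fallback with concrete types; it assumes the snapshot hnswlib build referenced in build.sbt, which is what provides HnswIndex.empty, and Point is a hypothetical item type:

  import com.github.jelmerk.knn.scalalike.{floatCosineDistance, Item}
  import com.github.jelmerk.knn.scalalike.hnsw.HnswIndex

  case class Point(id: String, vector: Array[Float]) extends Item[String, Array[Float]] {
    override def dimensions: Int = vector.length
  }

  // An empty partition yields an empty index; otherwise the first item fixes the
  // dimensionality and items.size serves as the capacity, as in the diff above.
  def indexFor(items: Seq[Point]): HnswIndex[String, Array[Float], Point, Float] =
    items.headOption.fold(HnswIndex.empty[String, Array[Float], Point, Float]) { first =>
      val index = HnswIndex[String, Array[Float], Point, Float](
        dimensions = first.dimensions,
        distanceFunction = floatCosineDistance,
        maxItemCount = items.size
      )
      items.foreach(index.add)
      index
    }
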
diff --git a/hnswlib-spark/src/test/scala/ClientTest.scala b/hnswlib-spark/src/test/scala/ClientTest.scala
deleted file mode 100644
index 9299af7..0000000
--- a/hnswlib-spark/src/test/scala/ClientTest.scala
+++ /dev/null
@@ -1,118 +0,0 @@
-import java.util.concurrent.{Executors, LinkedBlockingQueue}
-import java.util.concurrent.atomic.{AtomicBoolean, AtomicInteger}
-
-import com.github.jelmerk.server.index.{FloatArrayVector, IndexServiceGrpc, SearchRequest, SearchResponse}
-import io.grpc.netty.NettyChannelBuilder
-import io.grpc.stub.StreamObserver
-
-object ClientTest {
-
-  def main(args: Array[String]): Unit = {
-
-    println("start")
-    val clients = Range.inclusive(1, 4).map { shard =>
-      val channel = NettyChannelBuilder
-        .forAddress("127.0.0.1", 8080)
-        .usePlaintext
-        .build()
-
-      IndexServiceGrpc.stub(channel)
-
-    }
-
-    println("created clients")
-
-    val threadPool = Executors.newFixedThreadPool(1)
-
-    val requests = Iterator.fill(100)(
-      SearchRequest(
-        vector = SearchRequest.Vector.FloatArrayVector(FloatArrayVector(Array(0.1f, 0.2f))),
-        k = 10
-      )
-    )
-
-    val batches = requests.grouped(10)
-
-    val abc = batches
-      .map { batch =>
-        val (requestObservers, responseIterators) = clients.map { client =>
-          val responseStreamObserver = new MyStreamObserver[SearchResponse](batch.size)
-          val requestStreamObserver = client.search(responseStreamObserver)
-
-          (requestStreamObserver, responseStreamObserver)
-        }.unzip
-
-        threadPool.submit(new Runnable {
-          override def run(): Unit = {
-            val batchIter = batch.iterator
-            for {
-              request <- batchIter
-              last = !batchIter.hasNext
-              observer <- requestObservers
-            } {
-              observer.onNext(request)
-              if (last) {
-                observer.onCompleted()
-              }
-            }
-          }
-        })
-
-        val compoundIterator = new CompoundIterator(responseIterators)
-
-        compoundIterator
-          .map { responses =>
-            // TODO harcoded to float distance now
-            val results = responses.flatMap(_.results).sortBy(_.distance.floatDistance).take(10)
-            SearchResponse(results)
-          }
-
-      }
-      .reduce((a, b) => a ++ b)
-
-    abc.foreach(println)
-
-    threadPool.shutdownNow()
-
-// abc.foreach(println)
-
-    println("done.")
-    System.in.read()
-  }
-}
-
-class MyStreamObserver[T](expected: Int) extends StreamObserver[T] with Iterator[T] {
-
-  private val queue = new LinkedBlockingQueue[Either[Throwable, T]]
-  private val counter = new AtomicInteger()
-  private val done = new AtomicBoolean(false)
-
-  override def onNext(value: T): Unit = {
-    queue.add(Right(value))
-    counter.incrementAndGet()
-  }
-
-  override def onError(t: Throwable): Unit = {
-    queue.add(Left(t))
-    done.set(true)
-  }
-
-  override def onCompleted(): Unit = {
-    done.set(true)
-  }
-
-  override def hasNext: Boolean = {
-    !queue.isEmpty || (counter.get() < expected && !done.get())
-  }
-
-  override def next(): T = queue.take() match {
-    case Right(value) => value
-    case Left(t) => throw t
-  }
-}
-
-class CompoundIterator[T](iterators: Seq[Iterator[T]]) extends Iterator[Seq[T]] {
-  override def hasNext: Boolean = iterators.forall(_.hasNext)
-
-  override def next(): Seq[T] = iterators.map(_.next())
-}
diff --git a/hnswlib-spark/src/test/scala/RegistrationClientTest.scala b/hnswlib-spark/src/test/scala/RegistrationClientTest.scala
deleted file mode 100644
index 433d85c..0000000
--- a/hnswlib-spark/src/test/scala/RegistrationClientTest.scala
+++ /dev/null
@@ -1,33 +0,0 @@
-import scala.concurrent.Await
-import scala.concurrent.duration.Duration
-
-import com.github.jelmerk.server.registration.{RegisterRequest, RegistrationServiceGrpc}
-import io.grpc.netty.NettyChannelBuilder
-
-object RegistrationClientTest {
-
-  def main(args: Array[String]): Unit = {
-
-    val channel = NettyChannelBuilder
-      .forAddress("127.0.0.1", 8000)
-      .usePlaintext
-      .build()
-
-    val client = RegistrationServiceGrpc.stub(channel)
-
-    (1 to 4).foreach { partitionNo =>
-      val response = client.register(
-        RegisterRequest(
-          partitionNum = partitionNo,
-          replicaNum = 1,
-          "localhost",
-          123
-        )
-      )
-
-      Await.result(response, Duration.Inf)
-
-    }
-
-  }
-}
diff --git a/hnswlib-spark/src/test/scala/RegistrationServerTest.scala b/hnswlib-spark/src/test/scala/RegistrationServerTest.scala
deleted file mode 100644
index a96355f..0000000
--- a/hnswlib-spark/src/test/scala/RegistrationServerTest.scala
+++ /dev/null
@@ -1,59 +0,0 @@
-//import com.github.jelmerk.server.registration.{RegisterRequest, RegisterResponse, RegistrationServiceGrpc}
-//import com.github.jelmerk.server.registration.RegistrationServiceGrpc.RegistrationService
-//import io.grpc.netty.NettyServerBuilder
-//
-//import java.net.InetSocketAddress
-//import java.util.concurrent.{ConcurrentHashMap, CountDownLatch, Executors}
-//import scala.concurrent.{ExecutionContext, Future}
-//import scala.collection.concurrent
-//import scala.collection.JavaConverters._
-//
-//object RegistrationServerTest {
-//
-//
-//  private class DefaultRegistrationService(registrationLatch: CountDownLatch) extends RegistrationService {
-//
-//    val registrations = new ConcurrentHashMap[Int, InetSocketAddress]()
-//
-//    override def register(request: RegisterRequest): Future[RegisterResponse] = {
-//
-//      val previousValue = registrations.put(request.partitionNum, new InetSocketAddress(request.host, request.port))
-//
-//      if (previousValue == null) {
-//        registrationLatch.countDown()
-//      }
-//
-//      Future.successful(RegisterResponse())
-//    }
-//
-//  }
-//
-//  def startRegistrationServerAndAwaitRegistrations(numPartitions: Int, host: String, port: Int): Map[PartitionAndReplica, InetSocketAddress] = {
-//    val registrationLatch = new CountDownLatch(numPartitions)
-//
-//    val executor = Executors.newSingleThreadExecutor()
-//
-//    val executionContext: ExecutionContext = ExecutionContext.fromExecutor(executor)
-//
-//    val service = new DefaultRegistrationService(registrationLatch)
-//    // Build the gRPC server
-//    val server = NettyServerBuilder
-//      .forAddress(new InetSocketAddress(host, port))
-//      .addService(RegistrationServiceGrpc.bindService(service, executionContext))
-//      .build()
-//
-//    server.start()
-//
-//    registrationLatch.await()
-//
-//    server.shutdownNow()
-//    executor.shutdownNow()
-//
-//    service.registrations.asScala
-//  }
-//
-//  def main(args: Array[String]): Unit = {
-//    val registrations = startRegistrationServerAndAwaitRegistrations(4, "localhost", 8000)
-//    println(registrations)
-//  }
-//}
diff --git a/hnswlib-spark/src/test/scala/ServerTest.scala b/hnswlib-spark/src/test/scala/ServerTest.scala
deleted file mode 100644
index 340844a..0000000
--- a/hnswlib-spark/src/test/scala/ServerTest.scala
+++ /dev/null
@@ -1,56 +0,0 @@
-import java.util.concurrent.{LinkedBlockingQueue, ThreadPoolExecutor, TimeUnit}
-
-import scala.concurrent.ExecutionContext
-
-import com.github.jelmerk.knn.scalalike.{floatCosineDistance, Item}
-import com.github.jelmerk.knn.scalalike.hnsw.HnswIndex
-import com.github.jelmerk.server.index.IndexServiceGrpc
-import com.github.jelmerk.serving.server.DefaultIndexService
-import io.grpc.netty.NettyServerBuilder
-import org.apache.hadoop.conf.Configuration
-
-case class StringFloatArrayIndexItem(id: String, vector: Array[Float]) extends Item[String, Array[Float]] {
-  override def dimensions: Int = vector.length
-}
-
-object ServerTest {
-
-  def main(args: Array[String]): Unit = {
-
-//    val index = HnswIndex[String, Array[Float], StringFloatArrayIndexItem, Float](
-//      dimensions = 2,
-//      distanceFunction = floatCosineDistance,
-//      maxItemCount = 1000
-//    )
-//
-//    index.add(StringFloatArrayIndexItem("1", Array(0.1f, 0.8f)))
-//    index.add(StringFloatArrayIndexItem("2", Array(0.3f, 0.2f)))
-//    index.add(StringFloatArrayIndexItem("3", Array(0.1f, 0.5f)))
-//
-//    import DefaultIndexService._
-//
-//    val hadoopConfig = new Configuration()
-//
-//    val executor = new ThreadPoolExecutor(
-//      4, 4, 0L, TimeUnit.MILLISECONDS,
-//      new LinkedBlockingQueue[Runnable]()
-//    )
-//
-//    val executionContext: ExecutionContext = ExecutionContext.fromExecutor(executor)
-//
-//    implicit val ec: ExecutionContext = ExecutionContext.global
-//    val service = new DefaultIndexService(index, hadoopConfig)
-//
-//    // Build the gRPC server
-//    val server = NettyServerBuilder
-//      .forPort(8080)
-//      .addService(IndexServiceGrpc.bindService(service, executionContext))
-//      .build()
-//
-//    server.start()
-//
-//    server.awaitTermination()
-
-  }
-
-}
diff --git a/hnswlib-spark/src/test/scala/com/github/jelmerk/spark/knn/hnsw/HnswSimilaritySpec.scala b/hnswlib-spark/src/test/scala/com/github/jelmerk/spark/knn/hnsw/HnswSimilaritySpec.scala
index 972e364..c064c56 100644
--- a/hnswlib-spark/src/test/scala/com/github/jelmerk/spark/knn/hnsw/HnswSimilaritySpec.scala
+++ b/hnswlib-spark/src/test/scala/com/github/jelmerk/spark/knn/hnsw/HnswSimilaritySpec.scala
@@ -38,11 +38,11 @@ case class MinimalOutputRow[TId, TDistance](id: TId, neighbors: Seq[Neighbor[TId
 
 class HnswSimilaritySpec extends AnyFunSuite with DataFrameSuiteBase {
 
-  // for some reason kryo cannot serialize the hnswindex so configure it to make sure it never gets serialized
   override def conf: SparkConf = super.conf
     .set("spark.serializer", "org.apache.spark.serializer.KryoSerializer")
     .set("spark.kryo.registrator", "com.github.jelmerk.spark.HnswLibKryoRegistrator")
     .set("spark.speculation", "false")
+    .set("spark.ui.enabled", "true")
 
   test("prepartitioned data") {
 
@@ -79,6 +79,8 @@ class HnswSimilaritySpec extends AnyFunSuite with DataFrameSuiteBase {
       .head
 
     result.neighbors.size should be(2) // it couldn't see 3000000 because we only query partition 0
+
+    model.destroy()
   }
 
   test("find neighbors") {
@@ -206,8 +208,8 @@ class HnswSimilaritySpec extends AnyFunSuite with DataFrameSuiteBase {
       .setIdentifierCol("id")
       .setQueryIdentifierCol("id")
      .setFeaturesCol("vector")
-      .setNumPartitions(5)
-      .setNumReplicas(3)
+      .setNumPartitions(2)
+      .setNumReplicas(1)
       .setNumThreads(1)
       .setK(10)
       .setExcludeSelf(excludeSelf)
@@ -216,9 +218,14 @@ class HnswSimilaritySpec extends AnyFunSuite with DataFrameSuiteBase {
 
     val model = hnsw.fit(input).setPredictionCol("neighbors").setEf(10)
 
-    val result = model.transform(input)
+    try {
+      val result = model.transform(input)
 
-    validator(result)
+      validator(result)
+    } finally {
+      model.destroy()
+      Thread.sleep(5000)
+    }
  }
 }
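
The spec now tears the model down explicitly: fit() starts index-serving tasks on the executors (the count in IndexRunnable deliberately never completes), and destroy() is what this diff uses to stop them. A minimal usage sketch of that lifecycle, assuming an input DataFrame with id and vector columns as in the spec; the sleep only gives executors time to settle, mirroring the test:

  val hnsw = new HnswSimilarity()
    .setIdentifierCol("id")
    .setQueryIdentifierCol("id")
    .setFeaturesCol("vector")
    .setNumPartitions(2)
    .setNumReplicas(1)
    .setNumThreads(1)
    .setK(10)

  val model = hnsw.fit(input).setPredictionCol("neighbors")
  try {
    model.transform(input).show()
  } finally {
    model.destroy()     // stop the serving tasks started by fit
    Thread.sleep(5000)
  }
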