Commit

repackage
Jelmer Kuperus committed Dec 30, 2024
1 parent d2c2764 commit 2637200
Showing 15 changed files with 94 additions and 76 deletions.
5 changes: 3 additions & 2 deletions build.sbt
@@ -84,8 +84,9 @@ lazy val hnswlibSpark = (project in file("hnswlib-spark"))
oldStrategy(x)
},
assembly / assemblyShadeRules := Seq(
ShadeRule.rename("com.google.protobuf.**" -> "shaded.com.google.protobuf.@1").inAll,
ShadeRule.rename("com.google.common.**" -> "shaded.com.google.common.@1").inAll,
ShadeRule.rename("google.**" -> "shaded.google.@1").inAll,
ShadeRule.rename("com.google.**" -> "shaded.com.google.@1").inAll,
ShadeRule.rename("io.grpc.**" -> "shaded.io.grpc.@1").inAll,
ShadeRule.rename("io.netty.**" -> "shaded.io.netty.@1").inAll
),
Compile / PB.targets := Seq(
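The updated rules widen the com.google shading from the protobuf- and guava-specific packages to everything under google.** and com.google.**. As a minimal sketch of how sbt-assembly rename patterns behave (an annotated illustration, not a verbatim excerpt of this build.sbt):

// "com.google.**" matches any class below com.google, and "@1" splices the
// matched remainder back in, so com.google.protobuf.Message ends up as
// shaded.com.google.protobuf.Message inside the assembled jar.
assembly / assemblyShadeRules := Seq(
  ShadeRule.rename("com.google.**" -> "shaded.com.google.@1").inAll,
  ShadeRule.rename("io.grpc.**" -> "shaded.io.grpc.@1").inAll
)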
@@ -1,10 +1,11 @@
package com.github.jelmerk.server.registration
package com.github.jelmerk.registration.client

import java.net.{InetSocketAddress, SocketAddress}

import scala.concurrent.Await
import scala.concurrent.duration.Duration

import com.github.jelmerk.server.registration.{RegisterRequest, RegisterResponse, RegistrationServiceGrpc}
import io.grpc.netty.NettyChannelBuilder

object RegistrationClient {
@@ -1,10 +1,11 @@
package com.github.jelmerk.server.registration
package com.github.jelmerk.registration.server

import java.net.InetSocketAddress
import java.util.concurrent.{ConcurrentHashMap, CountDownLatch}

import scala.concurrent.Future

import com.github.jelmerk.server.registration.{RegisterRequest, RegisterResponse}
import com.github.jelmerk.server.registration.RegistrationServiceGrpc.RegistrationService

class DefaultRegistrationService(val registrationLatch: CountDownLatch) extends RegistrationService {
@@ -1,4 +1,4 @@
package com.github.jelmerk.server.registration
package com.github.jelmerk.registration.server

import java.net.{InetAddress, InetSocketAddress}
import java.util.concurrent.{CountDownLatch, Executors}
@@ -7,9 +7,9 @@ import scala.concurrent.ExecutionContext
import scala.jdk.CollectionConverters.mapAsScalaConcurrentMapConverter
import scala.util.Try

import com.github.jelmerk.server.registration.RegistrationServiceGrpc
import io.grpc.netty.NettyServerBuilder

// TODO this is all a bit messy
class RegistrationServer(numPartitions: Int, numReplicas: Int) {

private val executor = Executors.newSingleThreadExecutor()
@@ -18,7 +18,7 @@ class RegistrationServer(numPartitions: Int, numReplicas: Int) {

private val registrationLatch = new CountDownLatch(numPartitions + (numReplicas * numPartitions))
private val service = new DefaultRegistrationService(registrationLatch)
// Build the gRPC server

private val server = NettyServerBuilder
.forAddress(new InetSocketAddress(InetAddress.getLocalHost, 0))
.addService(RegistrationServiceGrpc.bindService(service, executionContext))
@@ -33,9 +33,9 @@ class RegistrationServer(numPartitions: Int, numReplicas: Int) {
service.registrations.asScala.toMap
}

def shutdownNow(): Unit = {
Try(server.shutdownNow())
Try(executor.shutdownNow())
def shutdown(): Unit = {
Try(server.shutdown())
Try(executor.shutdown())
}

}
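This hunk swaps the forceful shutdownNow() calls for graceful shutdown(): with io.grpc, shutdown() stops accepting new calls but lets in-flight RPCs finish, while shutdownNow() cancels them. A hypothetical helper (not part of this commit) showing the usual graceful-then-forceful pattern:

import java.util.concurrent.TimeUnit

import io.grpc.Server

// Hypothetical helper: ask the server to stop, wait a bounded time for active
// calls to drain, and only then fall back to a hard shutdown.
def stopGracefully(server: Server, timeoutSeconds: Long = 30): Unit = {
  server.shutdown()
  if (!server.awaitTermination(timeoutSeconds, TimeUnit.SECONDS)) {
    server.shutdownNow()
  }
}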
@@ -1,4 +1,4 @@
package com.github.jelmerk.spark.knn
package com.github.jelmerk.serving.client

import java.net.InetSocketAddress
import java.util.concurrent.{Executors, LinkedBlockingQueue}
@@ -10,8 +10,8 @@ import scala.concurrent.duration.Duration
import scala.language.implicitConversions
import scala.util.Random

import com.github.jelmerk.registration.server.PartitionAndReplica
import com.github.jelmerk.server.index._
import com.github.jelmerk.server.registration.PartitionAndReplica
import io.grpc.netty.NettyChannelBuilder
import io.grpc.stub.StreamObserver

@@ -37,7 +37,7 @@ class IndexClient[TId, TVector, TDistance](
private val partitionClients = grpcClients.toList
.sortBy { case (partitionAndReplica, _) => (partitionAndReplica.partitionNum, partitionAndReplica.replicaNum) }
.foldLeft(Map.empty[Int, Seq[IndexServiceGrpc.IndexServiceStub]]) {
case (acc, (PartitionAndReplica(partitionNum, replicaNum), client)) =>
case (acc, (PartitionAndReplica(partitionNum, _), client)) =>
val old = acc.getOrElse(partitionNum, Seq.empty[IndexServiceGrpc.IndexServiceStub])
acc.updated(partitionNum, old :+ client)
}
@@ -165,3 +165,16 @@ class IndexClient[TId, TVector, TDistance](
}

}

class IndexClientFactory[TId, TVector, TDistance](
vectorConverter: TVector => SearchRequest.Vector,
idExtractor: Result => TId,
distanceExtractor: Result => TDistance,
distanceOrdering: Ordering[TDistance]
) extends Serializable {

def create(servers: Map[PartitionAndReplica, InetSocketAddress]): IndexClient[TId, TVector, TDistance] = {
new IndexClient(servers, vectorConverter, idExtractor, distanceExtractor, distanceOrdering)
}

}
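IndexClientFactory is new in this commit and is Serializable so it can travel inside Spark closures. A hypothetical sketch of the intended call pattern (the names and lambdas below are placeholders, not code from this repository):

import java.net.InetSocketAddress

import com.github.jelmerk.registration.server.PartitionAndReplica

// Built once on the driver; the converter/extractor functions are placeholders.
val factory = new IndexClientFactory[String, Array[Float], Float](
  vectorConverter = vector => ???,     // turn a query vector into a SearchRequest.Vector
  idExtractor = result => ???,         // pull the neighbor id out of a Result
  distanceExtractor = result => ???,   // pull the distance out of a Result
  distanceOrdering = Ordering[Float]
)

// Shipped to executors, which open their own gRPC channels to the index servers.
def onExecutor(servers: Map[PartitionAndReplica, InetSocketAddress]): Unit = {
  val client = factory.create(servers)
  // issue search requests through `client`, then release its resources
}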
@@ -0,0 +1,15 @@
package com.github.jelmerk.serving.client

/** Neighbor of an item.
*
* @param neighbor
* identifies the neighbor
* @param distance
* distance to the item
*
* @tparam TId
* type of the index item identifier
* @tparam TDistance
* type of distance
*/
case class Neighbor[TId, TDistance](neighbor: TId, distance: TDistance) // TODO move this
@@ -1,9 +1,10 @@
package com.github.jelmerk.server.index
package com.github.jelmerk.serving.server

import scala.concurrent.{ExecutionContext, Future}
import scala.language.implicitConversions

import com.github.jelmerk.knn.scalalike.{Index, Item}
import com.github.jelmerk.server.index._
import com.github.jelmerk.server.index.IndexServiceGrpc.IndexService
import io.grpc.stub.StreamObserver
import org.apache.commons.io.output.CountingOutputStream
@@ -1,4 +1,4 @@
package com.github.jelmerk.server.index
package com.github.jelmerk.serving.server

import java.net.{InetAddress, InetSocketAddress}
import java.util.concurrent.{LinkedBlockingQueue, ThreadPoolExecutor, TimeUnit}
@@ -7,6 +7,7 @@ import scala.concurrent.ExecutionContext
import scala.util.Try

import com.github.jelmerk.knn.scalalike.{Index, Item}
import com.github.jelmerk.server.index.{IndexServiceGrpc, Result, SearchRequest}
import io.grpc.netty.NettyServerBuilder
import org.apache.hadoop.conf.Configuration

@@ -46,9 +47,9 @@ class IndexServer[TId, TVector, TItem <: Item[TId, TVector] with Product, TDistance
server.awaitTermination()
}

def shutdownNow(): Unit = {
Try(server.shutdownNow())
Try(executor.shutdownNow())
def shutdown(): Unit = {
Try(server.shutdown())
Try(executor.shutdown())
}
}


This file was deleted.

@@ -10,8 +10,10 @@ import scala.util.control.NonFatal

import com.github.jelmerk.knn.ObjectSerializer
import com.github.jelmerk.knn.scalalike._
import com.github.jelmerk.server.index.IndexServerFactory
import com.github.jelmerk.server.registration.{PartitionAndReplica, RegistrationClient, RegistrationServerFactory}
import com.github.jelmerk.registration.client.RegistrationClient
import com.github.jelmerk.registration.server.{PartitionAndReplica, RegistrationServerFactory}
import com.github.jelmerk.serving.client.IndexClientFactory
import com.github.jelmerk.serving.server.IndexServerFactory
import com.github.jelmerk.spark.util.SerializableConfiguration
import org.apache.hadoop.fs.Path
import org.apache.spark.{Partitioner, SparkContext, TaskContext}
@@ -43,20 +45,6 @@ private[knn] case class ModelMetaData(
paramMap: Map[String, Any]
)

/** Neighbor of an item.
*
* @param neighbor
* identifies the neighbor
* @param distance
* distance to the item
*
* @tparam TId
* type of the index item identifier
* @tparam TDistance
* type of distance
*/
private[knn] case class Neighbor[TId, TDistance](neighbor: TId, distance: TDistance)

private[knn] trait IndexType {

/** Type of index. */
@@ -430,7 +418,7 @@ object KnnModelReader {
* @tparam TModel
* type of model
*/
private[knn] abstract class KnnModelReader[TModel <: KnnModelBase[TModel]: ClassTag]
private[knn] abstract class KnnModelReader[TModel <: KnnModelBase[TModel]]
extends MLReader[TModel]
with IndexLoader
with IndexServing
@@ -501,7 +489,7 @@ private[knn] abstract class KnnModelReader[TModel <: KnnModelBase[TModel]: ClassTag

val indexRdd = sc
.makeRDD(partitionPaths)
.partitionBy(new PartitionIdPassthrough(metadata.numPartitions))
.partitionBy(new PartitionIdPartitioner(metadata.numPartitions))
.withResources(profile)
.mapPartitions { it =>
val (partitionId, indexPath) = it.next()
@@ -791,7 +779,7 @@ private[knn] abstract class KnnAlgorithm[TModel <: KnnModelBase[TModel]](override
)
.as[(Int, TItem)]
.rdd
.partitionBy(new PartitionIdPassthrough(getNumPartitions))
.partitionBy(new PartitionIdPartitioner(getNumPartitions))
.values
else
dataset
@@ -886,17 +874,23 @@ private[knn] abstract class KnnAlgorithm[TModel <: KnnModelBase[TModel]](override

}

/** Partitioner that uses precomputed partitions
/** Partitioner that uses precomputed partitions. Each partition id is its own partition
*
* @param numPartitions
* number of partitions
*/
private[knn] class PartitionIdPassthrough(override val numPartitions: Int) extends Partitioner {
private[knn] class PartitionIdPartitioner(override val numPartitions: Int) extends Partitioner {
override def getPartition(key: Any): Int = key.asInstanceOf[Int]
}

// TODO rename this
private[knn] class PartitionReplicaPartitioner(partitions: Int, replicas: Int) extends Partitioner {
/** Partitioner that uses precomputed partitions. Each unique partition and replica combination is its own partition
*
* @param partitions
* the total number of partitions
* @param replicas
* the total number of replicas
*/
private[knn] class PartitionReplicaIdPartitioner(partitions: Int, replicas: Int) extends Partitioner {

override def numPartitions: Int = partitions + (replicas * partitions)

@@ -928,7 +922,9 @@ private[knn] trait ModelLogging extends Logging {
protected def logInfo(partition: Int, message: String): Unit = println(f"partition $partition%04d: $message")

// protected def logInfo(partition: Int, replica: Int, message: String): Unit = logInfo(f"partition $partition%04d replica $partition%04d: $message")
protected def logInfo(partition: Int, replica: Int, message: String): Unit = println(f"partition $partition%04d replica $partition%04d: $message")
protected def logInfo(partition: Int, replica: Int, message: String): Unit = println(
f"partition $partition%04d replica $partition%04d: $message"
)
}

private[knn] trait IndexServing extends ModelLogging with IndexType {
Expand Down Expand Up @@ -956,7 +952,7 @@ private[knn] trait IndexServing extends ModelLogging with IndexType {
}

val replicaRdd =
if (numReplicas > 0) keyedIndexRdd.partitionBy(new PartitionReplicaPartitioner(numPartitions, numReplicas))
if (numReplicas > 0) keyedIndexRdd.partitionBy(new PartitionReplicaIdPartitioner(numPartitions, numReplicas))
else keyedIndexRdd

val server = RegistrationServerFactory.create(numPartitions, numReplicas)
Expand All @@ -974,32 +970,35 @@ private[knn] trait IndexServing extends ModelLogging with IndexType {

val serverAddress = server.address

logInfo(partitionNum, replicaNum, s"started index server on host ${serverAddress.getHostName}:${serverAddress.getPort}")

logInfo(
partitionNum,
replicaNum,
s"registering partition at ${registrationAddress.getHostName}:${registrationAddress.getPort}"
s"started index server on host ${serverAddress.getHostName}:${serverAddress.getPort}"
)
try {
logInfo(
partitionNum,
replicaNum,
s"registering partition at ${registrationAddress.getHostName}:${registrationAddress.getPort}"
)

RegistrationClient.register(registrationAddress, partitionNum, replicaNum, serverAddress)
RegistrationClient.register(registrationAddress, partitionNum, replicaNum, serverAddress)

logInfo(partitionNum, replicaNum, "awaiting requests")
try {
logInfo(partitionNum, replicaNum, "awaiting requests")
server.awaitTermination()
} finally {
server.shutdownNow()
server.shutdown()
}

true
}

// the count will never complete because the tasks start the index server
new Thread(new IndexRunnable(uid, sparkContext, serverRdd)).start()
new Thread(new IndexRunnable(uid, sparkContext, serverRdd), s"knn-index-thread-$uid").start()

server.awaitRegistrations()
} finally {
server.shutdownNow()
server.shutdown()
}

}
@@ -2,7 +2,8 @@ package com.github.jelmerk.spark.knn

import java.net.InetSocketAddress

import com.github.jelmerk.server.registration.PartitionAndReplica
import com.github.jelmerk.registration.server.PartitionAndReplica
import com.github.jelmerk.serving.client.{IndexClientFactory, Neighbor}

class QueryIterator[TId, TVector, TDistance, TQueryId](
indices: Map[PartitionAndReplica, InetSocketAddress],
@@ -9,7 +9,8 @@ import scala.reflect.runtime.universe._
import com.github.jelmerk.knn.ObjectSerializer
import com.github.jelmerk.knn.scalalike.{DistanceFunction, Item}
import com.github.jelmerk.knn.scalalike.bruteforce.BruteForceIndex
import com.github.jelmerk.server.registration.PartitionAndReplica
import com.github.jelmerk.registration.server.PartitionAndReplica
import com.github.jelmerk.serving.client.IndexClientFactory
import com.github.jelmerk.spark.knn._
import org.apache.spark.SparkContext
import org.apache.spark.ml.param._
@@ -9,7 +9,8 @@ import scala.reflect.runtime.universe._
import com.github.jelmerk.knn
import com.github.jelmerk.knn.scalalike.{DistanceFunction, Item}
import com.github.jelmerk.knn.scalalike.hnsw._
import com.github.jelmerk.server.registration.PartitionAndReplica
import com.github.jelmerk.registration.server.PartitionAndReplica
import com.github.jelmerk.serving.client.IndexClientFactory
import com.github.jelmerk.spark.knn._
import org.apache.spark.SparkContext
import org.apache.spark.ml.param._
@@ -37,11 +37,12 @@ import com.github.jelmerk.server.index.{
DenseVector,
DoubleArrayVector,
FloatArrayVector,
IndexServerFactory,
Result,
SearchRequest,
SparseVector
}
import com.github.jelmerk.serving.client.IndexClientFactory
import com.github.jelmerk.serving.server.IndexServerFactory
import com.github.jelmerk.spark.linalg.functions.VectorDistanceFunctions
import org.apache.spark.ml.linalg.{DenseVector => SparkDenseVector, SparseVector => SparkSparseVector, Vector, Vectors}
