Skip to content

Commit

Permalink
More fixing of tests
Browse files Browse the repository at this point in the history
  • Loading branch information
Jelmer Kuperus committed Jan 1, 2025
1 parent df0406f commit 0fddd99
Show file tree
Hide file tree
Showing 10 changed files with 76 additions and 288 deletions.
4 changes: 3 additions & 1 deletion build.sbt
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@ ThisBuild / dynverSonatypeSnapshots := true

ThisBuild / versionScheme := Some("early-semver")

ThisBuild / resolvers += "Local Maven Repository" at "file://" + Path.userHome.absolutePath + "/.m2/repository"

lazy val publishSettings = Seq(
pomIncludeRepository := { _ => false },

Expand Down Expand Up @@ -42,7 +44,7 @@ lazy val publishSettings = Seq(
lazy val noPublishSettings =
publish / skip := true

val hnswLibVersion = "1.1.2"
val hnswLibVersion = "1.1.2+4-3fc68540+20241231-1737-SNAPSHOT"
val sparkVersion = settingKey[String]("Spark version")
val venvFolder = settingKey[String]("Venv folder")
val pythonVersion = settingKey[String]("Python version")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,10 @@ class IndexServer[TId, TVector, TItem <: Item[TId, TVector] with Product, TDista
server.awaitTermination()
}

def isTerminated(): Boolean = {
server.isTerminated
}

def shutdown(): Unit = {
Try(server.shutdown())
Try(executor.shutdown())
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ private[knn] trait IndexCreator extends IndexType {
* @tparam TDistance
* type of distance between items
* @return
* create an index
* the created index
*/
protected def createIndex[
TId,
Expand All @@ -94,6 +94,26 @@ private[knn] trait IndexCreator extends IndexType {
idSerializer: ObjectSerializer[TId],
itemSerializer: ObjectSerializer[TItem]
): TIndex[TId, TVector, TItem, TDistance]

/** Create an immutable empty index.
*
* @tparam TId
* type of the index item identifier
* @tparam TVector
* type of the index item vector
* @tparam TItem
* type of the index item
* @tparam TDistance
* type of distance between items
* @return
* the created index
*/
protected def emptyIndex[
TId,
TVector,
TItem <: Item[TId, TVector] with Product,
TDistance
]: TIndex[TId, TVector, TItem, TDistance]
}

private[knn] trait IndexLoader extends IndexType {
Expand Down Expand Up @@ -490,7 +510,6 @@ private[knn] abstract class KnnModelReader[TModel <: KnnModelBase[TModel]]
val indexRdd = sc
.makeRDD(partitionPaths)
.partitionBy(new PartitionIdPartitioner(metadata.numPartitions))
.withResources(profile)
.mapPartitions { it =>
val (partitionId, indexPath) = it.next()
val fs = indexPath.getFileSystem(serializableConfiguration.value)
Expand All @@ -501,6 +520,7 @@ private[knn] abstract class KnnModelReader[TModel <: KnnModelBase[TModel]]
logInfo(partitionId, s"Finished loading index from $indexPath")
Iterator(index)
}
.withResources(profile)

val servers = serve(metadata.uid, indexRdd, metadata.numPartitions, metadata.numReplicas, metadata.numThreads)

Expand Down Expand Up @@ -804,7 +824,6 @@ private[knn] abstract class KnnAlgorithm[TModel <: KnnModelBase[TModel]](overrid
val profile = new ResourceProfileBuilder().require(taskReqs).build()

val indexRdd = partitionedIndexItems
.withResources(profile) // TODO JELMER is this the correct place ?
.mapPartitions(
(it: Iterator[TItem]) => {

Expand All @@ -831,12 +850,15 @@ private[knn] abstract class KnnAlgorithm[TModel <: KnnModelBase[TModel]](overrid
logInfo(partitionId, s"started indexing ${items.size} items on host ${InetAddress.getLocalHost.getHostName}")

val index = existingIndexOption
.filter(_.nonEmpty)
.getOrElse(
createIndex[TId, TVector, TItem, TDistance](
items.head.dimensions,
items.size,
distanceFunctionFactory(getDistanceFunction)
)
items.headOption.fold(emptyIndex[TId, TVector, TItem, TDistance]) { item =>
createIndex[TId, TVector, TItem, TDistance](
item.dimensions,
items.size,
distanceFunctionFactory(getDistanceFunction)
)
}
)

index.addAll(
Expand All @@ -852,12 +874,15 @@ private[knn] abstract class KnnAlgorithm[TModel <: KnnModelBase[TModel]](overrid
},
preservesPartitioning = true
)
.withResources(profile)

val modelUid = uid + "_" + System.currentTimeMillis().toString

val registrations =
serve[TId, TVector, TItem, TDistance](modelUid, indexRdd, getNumPartitions, getNumReplicas, numThreads)

println(registrations)

createModel[TId, TVector, TItem, TDistance](
modelUid,
getNumPartitions,
Expand Down Expand Up @@ -904,8 +929,7 @@ private[knn] class IndexRunnable(uid: String, sparkContext: SparkContext, indexR
try {
indexRdd.count()
} catch {
case NonFatal(_) =>
// todo should i log this or not ? it produces a lot of noise
case NonFatal(e) =>
()
} finally {
sparkContext.clearJobGroup()
Expand Down Expand Up @@ -972,23 +996,33 @@ private[knn] trait IndexServing extends ModelLogging with IndexType {
replicaNum,
s"started index server on host ${serverAddress.getHostName}:${serverAddress.getPort}"
)

logInfo(
partitionNum,
replicaNum,
s"registering partition at ${registrationAddress.getHostName}:${registrationAddress.getPort}"
)
try {
RegistrationClient.register(registrationAddress, partitionNum, replicaNum, serverAddress)

logInfo(partitionNum, replicaNum, "awaiting requests")

while (!TaskContext.get().isInterrupted() && !server.isTerminated()) {
Thread.sleep(500)
}

logInfo(
partitionNum,
replicaNum,
s"registering partition at ${registrationAddress.getHostName}:${registrationAddress.getPort}"
s"Task canceled"
)

RegistrationClient.register(registrationAddress, partitionNum, replicaNum, serverAddress)

logInfo(partitionNum, replicaNum, "awaiting requests")
server.awaitTermination()
} finally {
server.shutdown()
}

true
}
.withResources(indexRdd.getResourceProfile())

// the count will never complete because the tasks start the index server
new Thread(new IndexRunnable(uid, sparkContext, serverRdd), s"knn-index-thread-$uid").start()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -157,4 +157,7 @@ class BruteForceSimilarity(override val uid: String)
): BruteForceIndex[TId, TVector, TItem, TDistance] =
BruteForceIndex[TId, TVector, TItem, TDistance](dimensions, distanceFunction)

override protected def emptyIndex[TId, TVector, TItem <: Item[TId, TVector] with Product, TDistance]
: BruteForceIndex[TId, TVector, TItem, TDistance] =
BruteForceIndex.empty[TId, TVector, TItem, TDistance]
}
Original file line number Diff line number Diff line change
Expand Up @@ -232,4 +232,8 @@ class HnswSimilarity(override val uid: String)
idSerializer,
itemSerializer
)

override protected def emptyIndex[TId, TVector, TItem <: Item[TId, TVector] with Product, TDistance]
: HnswIndex[TId, TVector, TItem, TDistance] =
HnswIndex.empty[TId, TVector, TItem, TDistance]
}
118 changes: 0 additions & 118 deletions hnswlib-spark/src/test/scala/ClientTest.scala

This file was deleted.

33 changes: 0 additions & 33 deletions hnswlib-spark/src/test/scala/RegistrationClientTest.scala

This file was deleted.

Loading

0 comments on commit 0fddd99

Please sign in to comment.