Redesign spark integration to take advantage of resource profiles.
Jelmer Kuperus committed Dec 30, 2024
1 parent 662691a commit d2c2764
Showing 23 changed files with 1,795 additions and 764 deletions.
6 changes: 0 additions & 6 deletions .github/workflows/ci.yml
@@ -20,11 +20,6 @@ jobs:
fail-fast: false
matrix:
spark:
- 2.4.8
- 3.0.2
- 3.1.3
- 3.2.4
- 3.3.2
- 3.4.1
- 3.5.0
env:
@@ -39,7 +34,6 @@ jobs:
- uses: actions/setup-python@v5
with:
python-version: |
3.7
3.9
- name: Build and test
run: |
6 changes: 0 additions & 6 deletions .github/workflows/publish.yml
@@ -17,11 +17,6 @@ jobs:
fail-fast: false
matrix:
spark:
- 2.4.8
- 3.0.2
- 3.1.3
- 3.2.4
- 3.3.2
- 3.4.1
- 3.5.0

@@ -39,7 +34,6 @@ jobs:
- uses: actions/setup-python@v5
with:
python-version: |
3.7
3.9
- name: Import GPG Key
uses: crazy-max/ghaction-import-gpg@v6
48 changes: 30 additions & 18 deletions build.sbt
@@ -1,5 +1,7 @@
import Path.relativeTo
import sys.process.*
import scalapb.compiler.Version.scalapbVersion
import scalapb.compiler.Version.grpcJavaVersion

ThisBuild / organization := "com.github.jelmerk"
ThisBuild / scalaVersion := "2.12.18"
@@ -59,15 +61,7 @@ lazy val hnswlibSpark = (project in file("hnswlib-spark"))
.settings(
name := s"hnswlib-spark_${sparkVersion.value.split('.').take(2).mkString("_")}",
publishSettings,
crossScalaVersions := {
if (sparkVersion.value >= "3.2.0") {
Seq("2.12.18", "2.13.10")
} else if (sparkVersion.value >= "3.0.0") {
Seq("2.12.18")
} else {
Seq("2.12.18", "2.11.12")
}
},
crossScalaVersions := Seq("2.12.18", "2.13.10"),
autoScalaLibrary := false,
Compile / unmanagedSourceDirectories += baseDirectory.value / "src" / "main" / "python",
Test / unmanagedSourceDirectories += baseDirectory.value / "src" / "test" / "python",
@@ -83,9 +77,24 @@ lazy val hnswlibSpark = (project in file("hnswlib-spark"))
assembly / assemblyOption ~= {
_.withIncludeScala(false)
},
sparkVersion := sys.props.getOrElse("sparkVersion", "3.3.2"),
assembly / assemblyMergeStrategy := {
case PathList("META-INF", "io.netty.versions.properties") => MergeStrategy.first
case x =>
val oldStrategy = (ThisBuild / assemblyMergeStrategy).value
oldStrategy(x)
},
assembly / assemblyShadeRules := Seq(
ShadeRule.rename("com.google.protobuf.**" -> "shaded.com.google.protobuf.@1").inAll,
ShadeRule.rename("com.google.common.**" -> "shaded.com.google.common.@1").inAll,
ShadeRule.rename("io.netty.**" -> "shaded.io.netty.@1").inAll
),
Compile / PB.targets := Seq(
scalapb.gen() -> (Compile / sourceManaged).value / "scalapb"
),
sparkVersion := sys.props.getOrElse("sparkVersion", "3.4.1"),
// sparkVersion := sys.props.getOrElse("sparkVersion", "3.5.3"),
venvFolder := s"${baseDirectory.value}/.venv",
pythonVersion := (if (scalaVersion.value == "2.11.12") "python3.7" else "python3.9"),
pythonVersion := "python3.9",
createVirtualEnv := {
val ret = (
s"${pythonVersion.value} -m venv ${venvFolder.value}" #&&
@@ -128,12 +137,15 @@ lazy val hnswlibSpark = (project in file("hnswlib-spark"))
},
flake8 := flake8.dependsOn(createVirtualEnv).value,
libraryDependencies ++= Seq(
"com.github.jelmerk" % "hnswlib-utils" % hnswLibVersion,
"com.github.jelmerk" % "hnswlib-core-jdk17" % hnswLibVersion,
"com.github.jelmerk" %% "hnswlib-scala" % hnswLibVersion,
"org.apache.spark" %% "spark-hive" % sparkVersion.value % Provided,
"org.apache.spark" %% "spark-mllib" % sparkVersion.value % Provided,
"com.holdenkarau" %% "spark-testing-base" % s"${sparkVersion.value}_1.4.7" % Test,
"org.scalatest" %% "scalatest" % "3.2.17" % Test
"com.github.jelmerk" % "hnswlib-utils" % hnswLibVersion,
"com.github.jelmerk" % "hnswlib-core-jdk17" % hnswLibVersion,
"com.github.jelmerk" %% "hnswlib-scala" % hnswLibVersion,
"com.thesamet.scalapb" %% "scalapb-runtime-grpc" % scalapbVersion,
"com.thesamet.scalapb" %% "scalapb-runtime" % scalapbVersion % "protobuf",
"io.grpc" % "grpc-netty" % grpcJavaVersion,
"org.apache.spark" %% "spark-hive" % sparkVersion.value % Provided,
"org.apache.spark" %% "spark-mllib" % sparkVersion.value % Provided,
"com.holdenkarau" %% "spark-testing-base" % s"${sparkVersion.value}_1.4.7" % Test,
"org.scalatest" %% "scalatest" % "3.2.17" % Test
)
)
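
Note that sparkVersion is now read from JVM system properties with 3.4.1 as the default, so a different Spark line can be selected at build time, for example with sbt -DsparkVersion=3.5.0 +test (the task shown is only an example invocation, not one prescribed by this commit).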
65 changes: 65 additions & 0 deletions hnswlib-spark/src/main/protobuf/index.proto
@@ -0,0 +1,65 @@
syntax = "proto3";

import "scalapb/scalapb.proto";

package com.github.jelmerk.server;

service IndexService {
rpc Search (stream SearchRequest) returns (stream SearchResponse) {}
rpc SaveIndex (SaveIndexRequest) returns (SaveIndexResponse) {}
}

message SearchRequest {

oneof vector {
FloatArrayVector float_array_vector = 4;
DoubleArrayVector double_array_vector = 5;
SparseVector sparse_vector = 6;
DenseVector dense_vector = 7;
}

int32 k = 8;
}

message SearchResponse {
repeated Result results = 1;
}

message Result {
oneof id {
string string_id = 1;
int64 long_id = 2;
int32 int_id = 3;
}

oneof distance {
float float_distance = 4;
double double_distance = 5;
}
}

message FloatArrayVector {
repeated float values = 1 [(scalapb.field).collection_type="Array"];
}

message DoubleArrayVector {
repeated double values = 1 [(scalapb.field).collection_type="Array"];
}

message SparseVector {
int32 size = 1;
repeated int32 indices = 2 [(scalapb.field).collection_type="Array"];
repeated double values = 3 [(scalapb.field).collection_type="Array"];
}

message DenseVector {
repeated double values = 1 [(scalapb.field).collection_type="Array"];
}

message SaveIndexRequest {
string path = 1;
}

message SaveIndexResponse {
int64 bytes_written = 1;
}
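
For orientation, here is a minimal client-side sketch of driving the bidirectional Search stream through the ScalaPB-generated stub. It assumes the generated classes live in com.github.jelmerk.server.index (matching the imports in the Scala server code further down) and uses a placeholder host and port; it is not part of this commit.

import java.util.concurrent.TimeUnit

import com.github.jelmerk.server.index.{FloatArrayVector, IndexServiceGrpc, SearchRequest, SearchResponse}
import io.grpc.ManagedChannelBuilder
import io.grpc.stub.StreamObserver

object SearchClientSketch {
  def main(args: Array[String]): Unit = {
    // Placeholder address of an executor-side index server.
    val channel = ManagedChannelBuilder.forAddress("localhost", 50051).usePlaintext().build()
    val stub    = IndexServiceGrpc.stub(channel)

    // One SearchResponse comes back per SearchRequest pushed onto the request stream.
    val responses = new StreamObserver[SearchResponse] {
      override def onNext(response: SearchResponse): Unit =
        response.results.foreach(result => println(s"${result.id} -> ${result.distance}"))
      override def onError(t: Throwable): Unit = t.printStackTrace()
      override def onCompleted(): Unit = println("search stream closed")
    }

    // Open the stream and query the 5 nearest neighbours of a single float vector.
    val requests = stub.search(responses)
    requests.onNext(
      SearchRequest(
        vector = SearchRequest.Vector.FloatArrayVector(FloatArrayVector(Array(0.1f, 0.2f, 0.3f))),
        k = 5
      )
    )
    requests.onCompleted()

    channel.shutdown()
    channel.awaitTermination(5, TimeUnit.SECONDS)
  }
}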
18 changes: 18 additions & 0 deletions hnswlib-spark/src/main/protobuf/registration.proto
@@ -0,0 +1,18 @@
syntax = "proto3";

package com.github.jelmerk.server;

service RegistrationService {
rpc Register (RegisterRequest) returns (RegisterResponse) {}
}

message RegisterRequest {
int32 partition_num = 1;
int32 replica_num = 2;
string host = 3;
int32 port = 4;
}

message RegisterResponse {
}
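
As a sketch of the intended flow, an executor-side index server could announce itself to the driver through the generated blocking stub roughly as follows. The com.github.jelmerk.server.registration package name, addresses and ports are assumptions for illustration, not taken from this commit.

import com.github.jelmerk.server.registration.{RegisterRequest, RegistrationServiceGrpc}
import io.grpc.ManagedChannelBuilder

object RegistrationSketch {
  def main(args: Array[String]): Unit = {
    // Placeholder address of the driver-side registration service.
    val channel = ManagedChannelBuilder.forAddress("driver-host", 50050).usePlaintext().build()
    val stub    = RegistrationServiceGrpc.blockingStub(channel)

    // Report which partition/replica this executor serves and where its index server listens.
    stub.register(RegisterRequest(partitionNum = 0, replicaNum = 0, host = "executor-host", port = 50051))

    channel.shutdown()
  }
}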

24 changes: 19 additions & 5 deletions hnswlib-spark/src/main/python/pyspark_hnsw/knn.py
@@ -1,3 +1,5 @@
from typing import Any, Dict, Optional, TYPE_CHECKING

from pyspark.ml.wrapper import JavaEstimator, JavaModel
from pyspark.ml.param.shared import (
Params,
@@ -10,8 +12,12 @@

# noinspection PyProtectedMember
from pyspark import keyword_only
# noinspection PyProtectedMember
from pyspark.ml.util import JavaMLReadable, JavaMLWritable, MLReader, _jvm

if TYPE_CHECKING:
from py4j.java_gateway import JavaObject

__all__ = [
"HnswSimilarity",
"HnswSimilarityModel",
@@ -31,6 +37,7 @@ def __init__(self, clazz, java_class):
self._clazz = clazz
self._jread = self._load_java_obj(java_class).read()

# noinspection PyProtectedMember
def load(self, path):
"""Load the ML instance from the input path."""
java_obj = self._jread.load(path)
@@ -132,25 +139,25 @@ def getK(self):
"""
return self.getOrDefault(self.k)

def getExcludeSelf(self):
def getExcludeSelf(self) -> bool:
"""
Gets the value of excludeSelf or its default value.
"""
return self.getOrDefault(self.excludeSelf)

def getSimilarityThreshold(self):
def getSimilarityThreshold(self) -> float:
"""
Gets the value of similarityThreshold or its default value.
"""
return self.getOrDefault(self.similarityThreshold)

def getOutputFormat(self):
def getOutputFormat(self) -> str:
"""
Gets the value of outputFormat or its default value.
"""
return self.getOrDefault(self.outputFormat)

def getNumReplicas(self):
def getNumReplicas(self) -> int:
"""
Gets the value of numReplicas or its default value.
"""
@@ -294,6 +301,9 @@ class BruteForceSimilarity(JavaEstimator, _KnnParams, JavaMLReadable, JavaMLWrit
Exact nearest neighbour search.
"""

_input_kwargs: Dict[str, Any]

# noinspection PyUnusedLocal
@keyword_only
def __init__(
self,
@@ -410,6 +420,7 @@ def setInitialModelPath(self, value):
"""
return self._set(initialModelPath=value)

# noinspection PyUnusedLocal
@keyword_only
def setParams(
self,
@@ -507,6 +518,7 @@ class HnswSimilarity(JavaEstimator, _HnswParams, JavaMLReadable, JavaMLWritable)
Approximate nearest neighbour search.
"""

# noinspection PyUnusedLocal
@keyword_only
def __init__(
self,
@@ -647,7 +659,9 @@ def setInitialModelPath(self, value):
"""
return self._set(initialModelPath=value)

@keyword_only
# noinspection PyUnusedLocal
@keyword_only
def setParams(
self,
identifierCol="id",
@@ -0,0 +1,64 @@
package com.github.jelmerk.server.index

import scala.concurrent.{ExecutionContext, Future}
import scala.language.implicitConversions

import com.github.jelmerk.knn.scalalike.{Index, Item}
import com.github.jelmerk.server.index.IndexServiceGrpc.IndexService
import io.grpc.stub.StreamObserver
import org.apache.commons.io.output.CountingOutputStream
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.Path

class DefaultIndexService[TId, TVector, TItem <: Item[TId, TVector], TDistance](
index: Index[TId, TVector, TItem, TDistance],
hadoopConfiguration: Configuration,
vectorExtractor: SearchRequest => TVector,
resultIdConverter: TId => Result.Id,
resultDistanceConverter: TDistance => Result.Distance
)(implicit
executionContext: ExecutionContext
) extends IndexService {

override def search(responseObserver: StreamObserver[SearchResponse]): StreamObserver[SearchRequest] = {
new StreamObserver[SearchRequest] {
override def onNext(request: SearchRequest): Unit = {

val vector = vectorExtractor(request)
val nearest = index.findNearest(vector, request.k)

val results = nearest.map { searchResult =>
val id = resultIdConverter(searchResult.item.id)
val distance = resultDistanceConverter(searchResult.distance)

Result(id, distance)
}

val response = SearchResponse(results)
responseObserver.onNext(response)
}

override def onError(t: Throwable): Unit = {
responseObserver.onError(t)
}

override def onCompleted(): Unit = {
responseObserver.onCompleted()
}
}
}

override def saveIndex(request: SaveIndexRequest): Future[SaveIndexResponse] = Future {
val path = new Path(request.path)
val fileSystem = path.getFileSystem(hadoopConfiguration)

val outputStream = fileSystem.create(path)

val countingOutputStream = new CountingOutputStream(outputStream)

index.save(countingOutputStream)

SaveIndexResponse(bytesWritten = countingOutputStream.getCount)
}

}
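
To show how DefaultIndexService fits together with the generated bindings, here is a minimal, hypothetical bootstrap that serves a float-vector index with String ids over gRPC. The port, execution context and converter functions are placeholders for illustration rather than the wiring this commit actually uses.

import scala.concurrent.ExecutionContext

import com.github.jelmerk.knn.scalalike.{Index, Item}
import com.github.jelmerk.server.index.{DefaultIndexService, IndexServiceGrpc, Result, SearchRequest}
import io.grpc.netty.NettyServerBuilder
import org.apache.hadoop.conf.Configuration

object IndexServerSketch {

  def serve[TItem <: Item[String, Array[Float]]](index: Index[String, Array[Float], TItem, Float]): Unit = {
    implicit val ec: ExecutionContext = ExecutionContext.global

    val service = new DefaultIndexService[String, Array[Float], TItem, Float](
      index,
      new Configuration(),
      // Extract the float array variant of the vector oneof; a full server would also
      // handle the double array, sparse and dense variants.
      vectorExtractor = (request: SearchRequest) => request.getFloatArrayVector.values,
      resultIdConverter = id => Result.Id.StringId(id),
      resultDistanceConverter = distance => Result.Distance.FloatDistance(distance)
    )

    val server = NettyServerBuilder
      .forPort(50051) // placeholder port
      .addService(IndexServiceGrpc.bindService(service, ec))
      .build()
      .start()

    server.awaitTermination()
  }
}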