-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Redesign spark integration to take advantage of resource profiles.
- Loading branch information
Jelmer Kuperus
committed
Dec 30, 2024
1 parent
662691a
commit d2c2764
Showing
23 changed files
with
1,795 additions
and
764 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,65 @@ | ||
syntax = "proto3"; | ||
|
||
import "scalapb/scalapb.proto"; | ||
|
||
package com.github.jelmerk.server; | ||
|
||
// gRPC service exposing two operations on an index: Search and SaveIndex. | ||
service IndexService { | ||
// Bidirectional streaming RPC: consumes a stream of SearchRequest and produces a stream of SearchResponse. | ||
rpc Search (stream SearchRequest) returns (stream SearchResponse) {} | ||
// Unary RPC: one SaveIndexRequest in, one SaveIndexResponse out. | ||
rpc SaveIndex (SaveIndexRequest) returns (SaveIndexResponse) {} | ||
} | ||
|
||
// A single query: one vector, in exactly one of four encodings, plus the parameter k. | ||
// Field numbers in this message start at 4; numbers 1-3 are not used here. | ||
message SearchRequest { | ||
|
||
// The query vector. At most one of these fields is set (oneof semantics). | ||
oneof vector { | ||
FloatArrayVector float_array_vector = 4; | ||
DoubleArrayVector double_array_vector = 5; | ||
SparseVector sparse_vector = 6; | ||
DenseVector dense_vector = 7; | ||
} | ||
|
||
// Passed as the second argument to index.findNearest by DefaultIndexService; | ||
// presumably the number of nearest neighbours to return - TODO confirm. | ||
int32 k = 8; | ||
} | ||
|
||
// Reply to a SearchRequest: a list of zero or more Result entries. | ||
// DefaultIndexService emits one SearchResponse for each SearchRequest it receives. | ||
message SearchResponse { | ||
repeated Result results = 1; | ||
} | ||
|
||
// One search hit: an id paired with a distance. | ||
message Result { | ||
// Identifier of the matched item, encoded as a string, a 64-bit integer or a 32-bit integer. | ||
oneof id { | ||
string string_id = 1; | ||
int64 long_id = 2; | ||
int32 int_id = 3; | ||
} | ||
|
||
// Distance reported for the matched item, encoded as a float or a double. | ||
oneof distance { | ||
float float_distance = 4; | ||
double double_distance = 5; | ||
} | ||
} | ||
|
||
// Vector of single-precision components. | ||
message FloatArrayVector { | ||
// collection_type="Array" is a ScalaPB field option selecting the Scala collection type generated for this repeated field. | ||
repeated float values = 1 [(scalapb.field).collection_type="Array"]; | ||
} | ||
|
||
// Vector of double-precision components. | ||
message DoubleArrayVector { | ||
// collection_type="Array" is a ScalaPB field option selecting the Scala collection type generated for this repeated field. | ||
repeated double values = 1 [(scalapb.field).collection_type="Array"]; | ||
} | ||
|
||
// Sparse vector given as a size plus two parallel repeated fields. | ||
// Presumably indices[i] is the position of values[i] in a vector of length `size` - TODO confirm against the vectorExtractor used by the server. | ||
message SparseVector { | ||
int32 size = 1; | ||
repeated int32 indices = 2 [(scalapb.field).collection_type="Array"]; | ||
repeated double values = 3 [(scalapb.field).collection_type="Array"]; | ||
} | ||
|
||
// Dense vector of double-precision components. | ||
// Declares the same single field as DoubleArrayVector, but is a separate case of the SearchRequest.vector oneof. | ||
message DenseVector { | ||
repeated double values = 1 [(scalapb.field).collection_type="Array"]; | ||
} | ||
|
||
// Request to persist the index. | ||
message SaveIndexRequest { | ||
// Destination path. DefaultIndexService wraps it in a Hadoop Path and resolves the file system from its Hadoop Configuration. | ||
string path = 1; | ||
} | ||
|
||
// Result of a SaveIndex call. | ||
message SaveIndexResponse { | ||
// DefaultIndexService fills this from the byte counter of the stream the index was written to. | ||
int64 bytes_written = 1; | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,18 @@ | ||
syntax = "proto3"; | ||
|
||
package com.github.jelmerk.server; | ||
|
||
// gRPC service with a single unary RPC, Register. | ||
service RegistrationService { | ||
// Unary RPC: one RegisterRequest in, one (empty) RegisterResponse out. | ||
rpc Register (RegisterRequest) returns ( RegisterResponse) {} | ||
} | ||
|
||
// Registration payload: a partition number, a replica number and a host/port pair. | ||
// Presumably announces which index partition/replica is being served at host:port - TODO confirm against the registration client and server. | ||
message RegisterRequest { | ||
int32 partition_num = 1; | ||
int32 replica_num = 2; | ||
string host = 3; | ||
int32 port = 4; | ||
} | ||
|
||
// Empty acknowledgement for Register; it declares no fields. | ||
message RegisterResponse { | ||
} | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
64 changes: 64 additions & 0 deletions
64
hnswlib-spark/src/main/scala/com/github/jelmerk/server/index/DefaultIndexService.scala
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,64 @@ | ||
package com.github.jelmerk.server.index | ||
|
||
import scala.concurrent.{ExecutionContext, Future} | ||
import scala.language.implicitConversions | ||
|
||
import com.github.jelmerk.knn.scalalike.{Index, Item} | ||
import com.github.jelmerk.server.index.IndexServiceGrpc.IndexService | ||
import io.grpc.stub.StreamObserver | ||
import org.apache.commons.io.output.CountingOutputStream | ||
import org.apache.hadoop.conf.Configuration | ||
import org.apache.hadoop.fs.Path | ||
|
||
/**
 * [[IndexService]] implementation that answers search requests from, and persists, a single [[Index]].
 *
 * @param index                   index that is queried by `search` and written out by `saveIndex`
 * @param hadoopConfiguration     Hadoop configuration used to resolve the file system of the save path
 * @param vectorExtractor         converts an incoming [[SearchRequest]] into the index's vector type
 * @param resultIdConverter       converts an item id into the protobuf `Result.Id` representation
 * @param resultDistanceConverter converts a distance into the protobuf `Result.Distance` representation
 * @param executionContext        execution context on which the `saveIndex` future runs
 * @tparam TId       type of the item identifiers held by the index
 * @tparam TVector   type of the vectors held by the index
 * @tparam TItem     type of the items held by the index
 * @tparam TDistance type of the distances produced by the index
 */
class DefaultIndexService[TId, TVector, TItem <: Item[TId, TVector], TDistance](
    index: Index[TId, TVector, TItem, TDistance],
    hadoopConfiguration: Configuration,
    vectorExtractor: SearchRequest => TVector,
    resultIdConverter: TId => Result.Id,
    resultDistanceConverter: TDistance => Result.Distance
)(implicit
    executionContext: ExecutionContext
) extends IndexService {

  /**
   * Handles the bidirectional `Search` stream.
   *
   * Every request received is answered, on the thread that delivers it, with exactly one
   * [[SearchResponse]]; errors and completion of the request stream are forwarded to the response stream.
   *
   * @param responseObserver observer the responses are written to
   * @return observer that receives the incoming requests
   */
  override def search(responseObserver: StreamObserver[SearchResponse]): StreamObserver[SearchRequest] = {
    new StreamObserver[SearchRequest] {
      override def onNext(request: SearchRequest): Unit = {

        val vector  = vectorExtractor(request)
        val nearest = index.findNearest(vector, request.k)

        val results = nearest.map { searchResult =>
          val id       = resultIdConverter(searchResult.item.id)
          val distance = resultDistanceConverter(searchResult.distance)

          Result(id, distance)
        }

        val response = SearchResponse(results)
        responseObserver.onNext(response)
      }

      override def onError(t: Throwable): Unit = {
        responseObserver.onError(t)
      }

      override def onCompleted(): Unit = {
        responseObserver.onCompleted()
      }
    }
  }

  /**
   * Writes the index to `request.path` on the file system resolved from the Hadoop configuration.
   *
   * @param request request carrying the destination path
   * @return future completing with the number of bytes written, or failing with whatever exception
   *         was raised while creating, writing or closing the output stream
   */
  override def saveIndex(request: SaveIndexRequest): Future[SaveIndexResponse] = Future {
    val path       = new Path(request.path)
    val fileSystem = path.getFileSystem(hadoopConfiguration)

    val countingOutputStream = new CountingOutputStream(fileSystem.create(path))

    try {
      index.save(countingOutputStream)
    } finally {
      // Always release the file handle, also when save fails part-way. Should save already have
      // closed the stream, this second close is harmless: java.io.Closeable specifies that closing
      // an already closed stream has no effect.
      countingOutputStream.close()
    }

    // getByteCount returns a Long. getCount returns an Int and throws ArithmeticException once more
    // than Int.MaxValue bytes have been written, which a large index can exceed.
    SaveIndexResponse(bytesWritten = countingOutputStream.getByteCount)
  }

}
Oops, something went wrong.