Skip to content

Commit

Permalink
✨ Add StreamSearch
Browse files Browse the repository at this point in the history
Signed-off-by: Rintaro Okamura <[email protected]>
  • Loading branch information
rinx committed Jan 16, 2021
1 parent 0c0bf15 commit 67c26f2
Show file tree
Hide file tree
Showing 2 changed files with 58 additions and 22 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ Current Status
- [X] Insert
- [X] StreamInsert
- [X] Search
- [ ] StreamSearch
- [X] StreamSearch
- [ ] SearchByID
- [ ] StreamSearchByID
- [X] CreateIndex
78 changes: 57 additions & 21 deletions src/vald.rs
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,31 @@ impl ValdImpl {
Err(err) => Err(Status::internal(err.to_string())),
}
}

fn search_impl(
&self,
request: &payload::v1::search::Request,
) -> Result<payload::v1::search::Response, Status> {
let vector: Vec<f64> = request.vector.iter().map(|f| {
*f as f64
}).collect();
let config = match &request.config {
Some(c) => c,
None => return Err(Status::invalid_argument("config is required.")),
};

let request_id = config.request_id.clone();
let num: u64 = From::from(config.num);

let results = self.ngt.lock().unwrap().search(vector, num, config.epsilon).unwrap();

let reply = payload::v1::search::Response{
request_id,
results,
};

Ok(reply)
}
}

impl Clone for ValdImpl {
Expand Down Expand Up @@ -163,26 +188,10 @@ impl Search for ValdImpl {
&self,
request: Request<payload::v1::search::Request>,
) -> Result<Response<payload::v1::search::Response>, Status> {
let msg = request.get_ref();
let vector: Vec<f64> = msg.vector.iter().map(|f| {
*f as f64
}).collect();
let config = match &msg.config {
Some(c) => c,
None => return Err(Status::invalid_argument("config is required.")),
};

let request_id = config.request_id.clone();
let num: u64 = From::from(config.num);

let results = self.ngt.lock().unwrap().search(vector, num, config.epsilon).unwrap();

let reply = payload::v1::search::Response{
request_id,
results,
};

Ok(Response::new(reply))
match self.search_impl(request.get_ref()) {
Ok(res) => Ok(Response::new(res)),
Err(err) => Err(err),
}
}

async fn search_by_id(
Expand All @@ -198,7 +207,34 @@ impl Search for ValdImpl {
&self,
request: Request<Streaming<payload::v1::search::Request>>,
) -> Result<Response<Self::StreamSearchStream>, Status> {
unimplemented!()
let mut stream = request.into_inner();
let (mut tx, rx) = mpsc::channel(4);
let vald = self.clone();

tokio::spawn(async move {
while let Some(req) = stream.message().await.unwrap() {
let reply = match vald.search_impl(&req) {
Ok(res) => payload::v1::search::StreamResponse{
payload: Some(payload::v1::search::stream_response::Payload::Response(res)),
},
Err(st) => payload::v1::search::StreamResponse{
payload: Some(payload::v1::search::stream_response::Payload::Error(errors::v1::errors::Rpc{
r#type: "".to_string(),
msg: "".to_string(),
details: Vec::new(),
error: st.to_string(),
instance: "".to_string(),
status: 0,
roots: Vec::new(),
})),
},
};

tx.send(Ok(reply)).await.unwrap();
}
});

Ok(Response::new(rx))
}

type StreamSearchByIDStream = mpsc::Receiver<Result<payload::v1::search::StreamResponse, Status>>;
Expand Down

0 comments on commit 67c26f2

Please sign in to comment.