diff --git a/README.md b/README.md index abccbe5..d17b506 100644 --- a/README.md +++ b/README.md @@ -11,7 +11,7 @@ Current Status - [X] Insert - [X] StreamInsert - [X] Search - - [ ] StreamSearch + - [X] StreamSearch - [ ] SearchByID - [ ] StreamSearchByID - [X] CreateIndex diff --git a/src/vald.rs b/src/vald.rs index fb3cdd6..45f5fdb 100644 --- a/src/vald.rs +++ b/src/vald.rs @@ -91,6 +91,31 @@ impl ValdImpl { Err(err) => Err(Status::internal(err.to_string())), } } + + fn search_impl( + &self, + request: &payload::v1::search::Request, + ) -> Result { + let vector: Vec = request.vector.iter().map(|f| { + *f as f64 + }).collect(); + let config = match &request.config { + Some(c) => c, + None => return Err(Status::invalid_argument("config is required.")), + }; + + let request_id = config.request_id.clone(); + let num: u64 = From::from(config.num); + + let results = self.ngt.lock().unwrap().search(vector, num, config.epsilon).unwrap(); + + let reply = payload::v1::search::Response{ + request_id, + results, + }; + + Ok(reply) + } } impl Clone for ValdImpl { @@ -163,26 +188,10 @@ impl Search for ValdImpl { &self, request: Request, ) -> Result, Status> { - let msg = request.get_ref(); - let vector: Vec = msg.vector.iter().map(|f| { - *f as f64 - }).collect(); - let config = match &msg.config { - Some(c) => c, - None => return Err(Status::invalid_argument("config is required.")), - }; - - let request_id = config.request_id.clone(); - let num: u64 = From::from(config.num); - - let results = self.ngt.lock().unwrap().search(vector, num, config.epsilon).unwrap(); - - let reply = payload::v1::search::Response{ - request_id, - results, - }; - - Ok(Response::new(reply)) + match self.search_impl(request.get_ref()) { + Ok(res) => Ok(Response::new(res)), + Err(err) => Err(err), + } } async fn search_by_id( @@ -198,7 +207,34 @@ impl Search for ValdImpl { &self, request: Request>, ) -> Result, Status> { - unimplemented!() + let mut stream = request.into_inner(); + let (mut tx, rx) = mpsc::channel(4); + let vald = self.clone(); + + tokio::spawn(async move { + while let Some(req) = stream.message().await.unwrap() { + let reply = match vald.search_impl(&req) { + Ok(res) => payload::v1::search::StreamResponse{ + payload: Some(payload::v1::search::stream_response::Payload::Response(res)), + }, + Err(st) => payload::v1::search::StreamResponse{ + payload: Some(payload::v1::search::stream_response::Payload::Error(errors::v1::errors::Rpc{ + r#type: "".to_string(), + msg: "".to_string(), + details: Vec::new(), + error: st.to_string(), + instance: "".to_string(), + status: 0, + roots: Vec::new(), + })), + }, + }; + + tx.send(Ok(reply)).await.unwrap(); + } + }); + + Ok(Response::new(rx)) } type StreamSearchByIDStream = mpsc::Receiver>;