Skip to content

Commit

Permalink
fix running heavy computationn threads in asyncronous runtime, better…
Browse files Browse the repository at this point in the history
… threading performace
  • Loading branch information
andychenbruce committed Mar 15, 2024
1 parent 761346e commit 99f73bb
Show file tree
Hide file tree
Showing 3 changed files with 44 additions and 16 deletions.
2 changes: 2 additions & 0 deletions backend/crates/server_backend/src/andy_error.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
#[derive(thiserror::Error, Debug)]
pub enum AndyError {
#[error("thead joining error")]
ThreadJoin(#[from] tokio::task::JoinError),
#[error("hyper library error")]
Hyper(#[from] hyper::Error),
#[error("serde json (de)serialization error")]
Expand Down
6 changes: 5 additions & 1 deletion backend/crates/server_backend/src/server/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -417,7 +417,11 @@ async fn create_deck_pdf(
let url = data_url::DataUrl::process(&info.file_bytes_base64)?;
let (body, _fragment) = url.decode_to_vec()?;

let text = pdf_parser::extract_text(&body)?;
//spawn pdf parsing in a new thread so it doesn't block the executor
let text = tokio::task::spawn_blocking(move || {
Ok::<String, AndyError>(pdf_parser::extract_text(&body)?)
})
.await??;

enum SliceType<'a> {
Question(&'a str),
Expand Down
52 changes: 37 additions & 15 deletions backend/crates/server_backend/src/server/search_engine/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@ const VECTOR_SIZE: u64 = 384;

#[derive(thiserror::Error, Debug)]
pub enum SearchEngineError {
#[error("thread join error")]
ThreadJoin(#[from] tokio::task::JoinError),
#[error("qdrant client error")]
Qdrant(#[from] anyhow::Error),
#[error("embedder error")]
Expand All @@ -29,7 +31,7 @@ pub enum SearchEngineError {

pub struct SearchEngine {
client: Option<QdrantClient>,
embedder: Option<sentence_embedder::SentenceEmbedder>,
embedder: Option<std::sync::Arc<tokio::sync::Mutex<sentence_embedder::SentenceEmbedder>>>,
}

impl SearchEngine {
Expand All @@ -44,7 +46,7 @@ impl SearchEngine {
) -> Result<Self, SearchEngineError> {
let embedder = match embedder_path.map(sentence_embedder::SentenceEmbedder::new) {
None => None,
Some(x) => Some(x?),
Some(x) => Some(std::sync::Arc::new(tokio::sync::Mutex::new(x?))),
};
let client = match qdrant_addr {
Some(addr) => {
Expand Down Expand Up @@ -114,8 +116,14 @@ impl SearchEngine {
.map_err(SearchEngineError::PayloadConversion)?;

let sentences: Vec<_> = cards.into_iter().map(format_card).collect();
//spawn pdf parsing in a new thread so it doesn't block the executor

let embedder_arc = self.get_embedder()?;
let vectors = tokio::task::spawn_blocking(move || {
Ok::<Vec<_>, SearchEngineError>(embedder_arc.blocking_lock().run(sentences)?)
})
.await??;

let vectors = self.get_embedder()?.run(sentences)?;
let points: Vec<_> = vectors
.into_iter()
.enumerate()
Expand All @@ -135,14 +143,19 @@ impl SearchEngine {
prompt: &str,
num_results: u64,
) -> Result<Vec<(super::database::UserId, super::database::DeckId)>, SearchEngineError> {
let vector = self
.get_embedder()?
.run(vec![prompt.to_owned()])?
.pop()
.ok_or(SearchEngineError::EmbedderLength {
expected: 1,
got: 0,
})?;
let prompt_mem = prompt.to_owned();
let embedder_arc = self.get_embedder()?;
let vector = tokio::task::spawn_blocking(move || {
embedder_arc
.blocking_lock()
.run(vec![prompt_mem])?
.pop()
.ok_or(SearchEngineError::EmbedderLength {
expected: 1,
got: 0,
})
})
.await??;

let search_result = self
.get_client()?
Expand All @@ -161,10 +174,13 @@ impl SearchEngine {
}

fn get_embedder(
&mut self,
) -> Result<&mut sentence_embedder::SentenceEmbedder, SearchEngineError> {
&self,
) -> Result<
std::sync::Arc<tokio::sync::Mutex<sentence_embedder::SentenceEmbedder>>,
SearchEngineError,
> {
self.embedder
.as_mut()
.clone()
.ok_or(SearchEngineError::EmbedderNeverLoaded)
}
fn get_client(&mut self) -> Result<&mut QdrantClient, SearchEngineError> {
Expand Down Expand Up @@ -210,7 +226,11 @@ impl SearchEngine {
})
.collect::<Result<Vec<_>, SearchEngineError>>()?;

let vectors = self.get_embedder()?.run(sentences)?;
let embedder_arc = self.get_embedder()?;
let vectors = tokio::task::spawn_blocking(move || {
Ok::<Vec<_>, SearchEngineError>(embedder_arc.blocking_lock().run(sentences)?)
})
.await??;

let points: Vec<_> = vectors
.into_iter()
Expand All @@ -235,6 +255,8 @@ impl SearchEngine {
) -> Result<Vec<String>, SearchEngineError> {
let vector = self
.get_embedder()?
.lock()
.await
.run(vec![prompt.to_owned()])?
.pop()
.ok_or(SearchEngineError::EmbedderLength {
Expand Down

0 comments on commit 99f73bb

Please sign in to comment.