Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Nonstreaming API #85

Open
wants to merge 10 commits into
base: main
Choose a base branch
from
2 changes: 1 addition & 1 deletion apps/desktop/src-tauri/src/inference/completion.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ pub struct CompletionRequest {

sampler: Option<String>,

stream: Option<bool>,
pub stream: Option<bool>,

max_tokens: Option<usize>,

Expand Down
53 changes: 38 additions & 15 deletions apps/desktop/src-tauri/src/inference/process.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,11 @@ pub type ModelGuard = Arc<Mutex<Option<Box<dyn Model>>>>;
pub struct InferenceThreadRequest {
pub token_sender: Sender<Bytes>,
pub abort_flag: Arc<RwLock<bool>>,

pub model_guard: ModelGuard,
pub completion_request: CompletionRequest,
pub nonstream_completion_tokens: Arc<Mutex<String>>,
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we can make this private if we use it as a trait state for the non-stream feature. Making it pub would allow others to inspect it while it's writing/locked, which could potentially deadlock the Mutex writer if we're not careful... :d

pub stream: bool,
pub tx: Option<Sender<()>>,
}

impl InferenceThreadRequest {
Expand Down Expand Up @@ -77,7 +79,7 @@ fn get_inference_params(
}

// Perhaps might be better to clone the model for each thread...
pub fn start(req: InferenceThreadRequest) -> JoinHandle<()> {
pub fn start<'a>(req: InferenceThreadRequest) -> JoinHandle<()> {
println!("Spawning inference thread...");
actix_web::rt::task::spawn_blocking(move || {
let mut rng = req.completion_request.get_rng();
Expand All @@ -86,6 +88,8 @@ pub fn start(req: InferenceThreadRequest) -> JoinHandle<()> {

let mut token_utf8_buf = TokenUtf8Buffer::new();
let guard = req.model_guard.lock();
let stream_enabled = req.stream;
let mut nonstream_res_str_buf = req.nonstream_completion_tokens.lock();

let model = match guard.as_ref() {
Some(m) => m,
Expand All @@ -105,7 +109,10 @@ pub fn start(req: InferenceThreadRequest) -> JoinHandle<()> {
let start_at = std::time::SystemTime::now();

println!("Feeding prompt ...");
req.send_event("FEEDING_PROMPT");

if stream_enabled {
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we do this check at the trait level instead. This way we can unify the interface call (in this file), and handle the stream/non-stream logic at the trait implementation level instead, would make it much nicer and more cohesive :)

req.send_event("FEEDING_PROMPT");
}

match session.feed_prompt::<Infallible, Prompt>(
model.as_ref(),
Expand All @@ -118,7 +125,9 @@ pub fn start(req: InferenceThreadRequest) -> JoinHandle<()> {
}

if let Some(token) = token_utf8_buf.push(t) {
req.send_comment(format!("Processing token: {:?}", token).as_str());
if stream_enabled {
req.send_comment(format!("Processing token: {:?}", token).as_str());
}
}

Ok(InferenceFeedback::Continue)
Expand All @@ -138,8 +147,10 @@ pub fn start(req: InferenceThreadRequest) -> JoinHandle<()> {
}
};

req.send_comment("Generating tokens ...");
req.send_event("GENERATING_TOKENS");
if stream_enabled {
req.send_comment("Generating tokens ...");
req.send_event("GENERATING_TOKENS");
}

// Reset the utf8 buf
token_utf8_buf = TokenUtf8Buffer::new();
Expand Down Expand Up @@ -176,14 +187,19 @@ pub fn start(req: InferenceThreadRequest) -> JoinHandle<()> {

// Buffer the token until it's valid UTF-8, then call the callback.
if let Some(tokens) = token_utf8_buf.push(&token) {
match req
.token_sender
.send(CompletionResponse::to_data_bytes(tokens))
{
Ok(_) => {}
Err(_) => {
break;
if req.stream {
match req
.token_sender
.send(CompletionResponse::to_data_bytes(tokens))
{
Ok(_) => {}
Err(_) => {
break;
}
}
} else {
//Collect tokens into str buffer
*nonstream_res_str_buf += &tokens;
}
}

Expand All @@ -195,8 +211,15 @@ pub fn start(req: InferenceThreadRequest) -> JoinHandle<()> {

println!("Inference stats: {:?}", stats);

if !req.token_sender.is_disconnected() {
req.send_done();
if stream_enabled {
if !req.token_sender.is_disconnected() {
req.send_done();
}
} else {
if let Some(tx) = req.tx {
//Tell server thread that inference completed, and let it respond
let _ = tx.send(());
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we need that _ or can we just call send here?

}
}

// TODO: Might make this into a callback later, for now we just abuse the singleton
Expand Down
69 changes: 50 additions & 19 deletions apps/desktop/src-tauri/src/inference/server.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
use actix_cors::Cors;
use actix_web::dev::ServerHandle;
use actix_web::web::{Bytes, Json};

use actix_web::{get, post, App, HttpResponse, HttpServer, Responder};
use parking_lot::{Mutex, RwLock};
use serde::Serialize;
use serde_json::json;

use std::sync::{
atomic::{AtomicBool, Ordering},
Expand Down Expand Up @@ -58,24 +58,55 @@ async fn post_completions(payload: Json<CompletionRequest>) -> impl Responder {

let (token_sender, receiver) = flume::unbounded::<Bytes>();

HttpResponse::Ok()
.append_header(("Content-Type", "text/event-stream"))
.append_header(("Cache-Control", "no-cache"))
.keep_alive()
.streaming({
let abort_flag = Arc::new(RwLock::new(false));

AbortStream::new(
receiver,
abort_flag.clone(),
start(InferenceThreadRequest {
model_guard: model_guard.clone(),
abort_flag: abort_flag.clone(),
token_sender,
completion_request: payload.0,
}),
)
})
if let Some(true) = payload.stream {
Copy link
Owner

@louisgv louisgv Aug 2, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This should be payload.0.stream I think, since it's a JSON.

If we can reconcile our trait above, we can infer the stream boolean via the completion_request as well, skipping a couple of lookup hoop!

HttpResponse::Ok()
.append_header(("Content-Type", "text/event-stream"))
.append_header(("Cache-Control", "no-cache"))
.keep_alive()
.streaming({
let abort_flag = Arc::new(RwLock::new(false));
let str_buffer = Arc::new(Mutex::new(String::new()));

AbortStream::new(
receiver,
abort_flag.clone(),
start(InferenceThreadRequest {
model_guard: model_guard.clone(),
abort_flag: abort_flag.clone(),
token_sender,
completion_request: payload.0,
nonstream_completion_tokens: str_buffer.clone(),
stream: true,
tx: None,
}),
Comment on lines +73 to +81
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I have this idea which I think would make this nicer - we can create the InferenceThreadRequest before the isStream check actually, since it's non-blocking state. We can then do

let request = InferenceThreadRequest {
            model_guard: model_guard.clone(),
            abort_flag: abort_flag.clone(),
            token_sender,
            completion_request: payload.0,
            nonstream_completion_tokens: str_buffer.clone(),          
}

if request.isStream() {} else {} 

And the .isStream is a trait public method we expose via InferenceThreadRequest, which basically return completion_request.stream

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I really like your attention to detail and design thinking! I will try to implement this one, I agree, it is indeed cleaner.

)
})
} else {
let abort_flag = Arc::new(RwLock::new(false));
let completion_tokens = Arc::new(Mutex::new(String::new()));
let (tx, rx) = flume::unbounded::<()>();
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I wonder if we can make the tokensender generic, so that we can reuse that argument. The token_sender and the tx serve very similar function here, we just need to reconcile the Byte/String type. That'd make for nicer interface I think

start(InferenceThreadRequest {
model_guard: model_guard.clone(),
abort_flag: abort_flag.clone(),
token_sender,
completion_request: payload.0,
nonstream_completion_tokens: completion_tokens.clone(),
stream: false,
tx: Some(tx),
});

rx.recv().unwrap();
Copy link
Owner

@louisgv louisgv Aug 2, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should match for error and return HTTP error here IMO, otherwise would be hard to triage :d


let locked_str_buffer = completion_tokens.lock();
let completion_body = json!({
"completion": locked_str_buffer.clone()
});

HttpResponse::Ok()
.append_header(("Content-Type", "text/plain"))
Copy link
Owner

@louisgv louisgv Aug 2, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should return application/json type here instead I think, it helps the client know to do JSON chunk parsing as needed as well based on that header type

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think that makes sense, yes. Will fix those. Thanks for looking at my code.

.append_header(("Cache-Control", "no-cache"))
.json(completion_body)
}
}

#[tauri::command]
Expand Down