diff --git a/rust-executor/src/ai_service/mod.rs b/rust-executor/src/ai_service/mod.rs
index c04faa111..e4bc86824 100644
--- a/rust-executor/src/ai_service/mod.rs
+++ b/rust-executor/src/ai_service/mod.rs
@@ -32,6 +32,8 @@ use log::error;
 
 pub type Result = std::result::Result;
 
+static WHISPER_MODEL: WhisperSource = WhisperSource::Small;
+
 lazy_static! {
     static ref AI_SERVICE: Arc<Mutex<Option<AIService>>> = Arc::new(Mutex::new(None));
 }
@@ -130,7 +132,7 @@ async fn handle_progress(model_id: String, loading: ModelLoadingProgress) {
     } else {
         "Loaded".to_string()
     };
-    println!("Progress update: {}% for model {}", progress, model_id); // Add logging
+    //println!("Progress update: {}% for model {}", progress, model_id); // Add logging
 
     publish_model_status(model_id.clone(), progress, &status, false, false).await;
 }
@@ -289,9 +291,28 @@ impl AIService {
             // Local TinyLlama models
             "llama_tiny" => Llama::builder().with_source(LlamaSource::tiny_llama_1_1b()),
             "llama_7b" => Llama::builder().with_source(LlamaSource::llama_7b()),
+            "llama_7b_chat" => Llama::builder().with_source(LlamaSource::llama_7b_chat()),
+            "llama_7b_code" => Llama::builder().with_source(LlamaSource::llama_7b_code()),
             "llama_8b" => Llama::builder().with_source(LlamaSource::llama_8b()),
+            "llama_8b_chat" => Llama::builder().with_source(LlamaSource::llama_8b_chat()),
+            "llama_3_1_8b_chat" => Llama::builder().with_source(LlamaSource::llama_3_1_8b_chat()),
             "llama_13b" => Llama::builder().with_source(LlamaSource::llama_13b()),
+            "llama_13b_chat" => Llama::builder().with_source(LlamaSource::llama_13b_chat()),
+            "llama_13b_code" => Llama::builder().with_source(LlamaSource::llama_13b_code()),
+            "llama_34b_code" => Llama::builder().with_source(LlamaSource::llama_34b_code()),
             "llama_70b" => Llama::builder().with_source(LlamaSource::llama_70b()),
+            "mistral_7b" => Llama::builder().with_source(LlamaSource::mistral_7b()),
+            "mistral_7b_instruct" => {
+                Llama::builder().with_source(LlamaSource::mistral_7b_instruct())
+            }
+            "mistral_7b_instruct_2" => {
+                Llama::builder().with_source(LlamaSource::mistral_7b_instruct_2())
+            }
+            "solar_10_7b" => Llama::builder().with_source(LlamaSource::solar_10_7b()),
+            "solar_10_7b_instruct" => {
+                Llama::builder().with_source(LlamaSource::solar_10_7b_instruct())
+            }
+
             // Handle unknown models
             _ => {
                 log::error!("Unknown model string: {}", model_size_string);
@@ -308,7 +329,7 @@ impl AIService {
             .build_with_loading_handler({
                 let model_id = model_id.clone();
                 move |progress| {
-                    futures::executor::block_on(handle_progress(model_id.clone(), progress));
+                    tokio::spawn(handle_progress(model_id.clone(), progress));
                 }
             })
             .await?;
@@ -520,6 +541,13 @@ impl AIService {
                 }
                 LlmModel::Local(ref mut llama) => {
                     if let Some(task) = tasks.get(&prompt_request.task_id) {
+                        rt.block_on(publish_model_status(
+                            model_config.id.clone(),
+                            100.0,
+                            "Running inference...",
+                            true,
+                            true,
+                        ));
                         let mut maybe_result: Option<String> = None;
                         let mut tries = 0;
                         while maybe_result.is_none() && tries < 20 {
@@ -536,12 +564,27 @@ impl AIService {
                                     log::error!(
                                         "Llama panicked with: {:?}. 
 Trying again..",
                                         e
-                                    )
+                                    );
+                                    rt.block_on(publish_model_status(
+                                        model_config.id.clone(),
+                                        100.0,
+                                        "Panicked while running inference - trying again...",
+                                        true,
+                                        true,
+                                    ));
                                 }
                                 Ok(result) => maybe_result = Some(result),
                             }
                         }
+                        rt.block_on(publish_model_status(
+                            model_config.id.clone(),
+                            100.0,
+                            "Ready",
+                            true,
+                            true,
+                        ));
+
                         if let Some(result) = maybe_result {
                             let _ = prompt_request.result_sender.send(Ok(result));
                         } else {
@@ -808,7 +851,7 @@ impl AIService {
 
         rt.block_on(async {
             let maybe_model = WhisperBuilder::default()
-                .with_source(WhisperSource::Base)
+                .with_source(WHISPER_MODEL)
                 .with_device(Device::Cpu)
                 .build()
                 .await;
@@ -906,7 +949,7 @@ impl AIService {
             publish_model_status(id.clone(), 0.0, "Loading", false, false).await;
 
             let _ = WhisperBuilder::default()
-                .with_source(WhisperSource::Base)
+                .with_source(WHISPER_MODEL)
                 .with_device(Device::Cpu)
                 .build_with_loading_handler({
                     let name = id.clone();
diff --git a/rust-executor/src/js_core/mod.rs b/rust-executor/src/js_core/mod.rs
index 0ddbc9a02..5ec46e25d 100644
--- a/rust-executor/src/js_core/mod.rs
+++ b/rust-executor/src/js_core/mod.rs
@@ -137,7 +137,7 @@ impl std::fmt::Display for ExternWrapper {
 
 impl JsCore {
     pub fn new() -> Self {
-        deno_core::v8::V8::set_flags_from_string("--no-opt --turbo-disable-all");
+        deno_core::v8::V8::set_flags_from_string("--no-opt");
         JsCore {
             #[allow(clippy::arc_with_non_send_sync)]
             worker: Arc::new(TokioMutex::new(MainWorker::from_options(
diff --git a/rust-executor/src/lib.rs b/rust-executor/src/lib.rs
index c6083e592..35b2a8ab2 100644
--- a/rust-executor/src/lib.rs
+++ b/rust-executor/src/lib.rs
@@ -42,7 +42,7 @@ use libc::{sigaction, sigemptyset, sighandler_t, SA_ONSTACK, SIGURG};
 use std::ptr;
 
 extern "C" fn handle_sigurg(_: libc::c_int) {
-    println!("Received SIGURG signal, but ignoring it.");
+    //println!("Received SIGURG signal, but ignoring it.");
 }
 
 /// Runs the GraphQL server and the deno core runtime
diff --git a/ui/package.json b/ui/package.json
index 09ba62854..cbcd97103 100644
--- a/ui/package.json
+++ b/ui/package.json
@@ -41,7 +41,7 @@
     "change-ui-version": "powershell -ExecutionPolicy Bypass -File ./scripts/patch-prerelease-tags.ps1",
     "package-ad4m": "run-script-os",
     "package-ad4m:windows": "cargo clean && pnpm run build && pnpm tauri build --verbose",
-    "package-ad4m:macos": "pnpm run build && pnpm tauri build --verbose",
+    "package-ad4m:macos": "pnpm run build && pnpm tauri build --verbose --features metal",
     "package-ad4m:linux": "pnpm run build && pnpm tauri build --verbose"
   },
   "eslintConfig": {
diff --git a/ui/src-tauri/Cargo.toml b/ui/src-tauri/Cargo.toml
index c9110ecbd..03289092c 100644
--- a/ui/src-tauri/Cargo.toml
+++ b/ui/src-tauri/Cargo.toml
@@ -58,6 +58,10 @@ dirs = "5"
 # by default Tauri runs in production mode
 # when `tauri dev` runs it is executed with `cargo run --no-default-features` if `devPath` is an URL
 default = [ "custom-protocol"]
+# Pass metal and cuda features through to ad4m-executor
+metal = ["ad4m-executor/metal"]
+cuda = ["ad4m-executor/cuda"]
+
 dev = []
 
 # this feature is used used for production builds where `devPath` points to the filesystem
diff --git a/ui/src-tauri/src/util.rs b/ui/src-tauri/src/util.rs
index b38ef3321..0180550db 100644
--- a/ui/src-tauri/src/util.rs
+++ b/ui/src-tauri/src/util.rs
@@ -37,7 +37,7 @@ pub fn create_main_window(app: &AppHandle) {
     let new_ad4m_window = WebviewWindowBuilder::new(app, "AD4M", WebviewUrl::App(url.into()))
         .center()
         .focused(true)
-        .inner_size(1000.0, 700.0)
+        .inner_size(1000.0, 1200.0)
         .title("ADAM 
 Launcher");
     let _ = new_ad4m_window.build();
diff --git a/ui/src/components/Login.tsx b/ui/src/components/Login.tsx
index 1c7ec5a0c..dc515da25 100644
--- a/ui/src/components/Login.tsx
+++ b/ui/src/components/Login.tsx
@@ -95,7 +95,7 @@ const Login = () => {
     const llm = { name: "LLM Model 1", modelType: "LLM" } as ModelInput;
     if (aiMode === "Local") {
       llm.local = {
-        fileName: "llama_7b",
+        fileName: "solar_10_7b_instruct",
         tokenizerSource: "",
         modelParameters: "",
       };
diff --git a/ui/src/components/ModelModal.tsx b/ui/src/components/ModelModal.tsx
index 18ef42d08..edf7cb3b8 100644
--- a/ui/src/components/ModelModal.tsx
+++ b/ui/src/components/ModelModal.tsx
@@ -7,9 +7,21 @@ const AITypes = ["LLM", "EMBEDDING", "TRANSCRIPTION"];
 const llmModels = [
   "External API",
   // "tiny_llama_1_1b",
+  "mistral_7b",
+  "mistral_7b_instruct",
+  "mistral_7b_instruct_2",
+  "solar_10_7b",
+  "solar_10_7b_instruct",
   "llama_7b",
+  "llama_7b_chat",
+  "llama_7b_code",
   "llama_8b",
+  "llama_8b_chat",
+  "llama_3_1_8b_chat",
   "llama_13b",
+  "llama_13b_chat",
+  "llama_13b_code",
+  "llama_34b_code",
   "llama_70b",
 ];
 const transcriptionModels = ["whisper"];
@@ -28,7 +40,7 @@ export default function ModelModal(props: {
   const [newModelNameError, setNewModelNameError] = useState(false);
   const [newModelType, setNewModelType] = useState("LLM");
   const [newModels, setNewModels] = useState(llmModels);
-  const [newModel, setNewModel] = useState("llama_7b");
+  const [newModel, setNewModel] = useState("llama_8b");
   const [apiUrl, setApiUrl] = useState("https://api.openai.com/v1");
   const [apiKey, setApiKey] = useState("");
   const [apiUrlError, setApiUrlError] = useState(false);
@@ -165,7 +177,7 @@ export default function ModelModal(props: {
     setNewModelType(type);
     if (type === "LLM") {
       setNewModels(llmModels);
-      setNewModel("llama_7b");
+      setNewModel("llama_8b");
     } else if (type === "EMBEDDING") {
       setNewModels(embeddingModels);
       setNewModel("bert");