Merge pull request #539 from coasys/change-whisper-model
Tune model config and selection & other misc. clean-ups and fixes
lucksus authored Dec 17, 2024
2 parents 1bd7b14 + 8c7fc27 commit 3d23650
Showing 8 changed files with 71 additions and 12 deletions.
53 changes: 48 additions & 5 deletions rust-executor/src/ai_service/mod.rs
@@ -32,6 +32,8 @@ use log::error;

pub type Result<T> = std::result::Result<T, AnyError>;

static WHISPER_MODEL: WhisperSource = WhisperSource::Small;

lazy_static! {
static ref AI_SERVICE: Arc<Mutex<Option<AIService>>> = Arc::new(Mutex::new(None));
}
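
A quick note on the hunk above: the new WHISPER_MODEL static is read by both WhisperBuilder call sites later in this file, replacing the hard-coded WhisperSource::Base, so the switch to the Small model is a one-line change. A self-contained sketch of that shape, with stub types standing in for the real WhisperSource/WhisperBuilder so it compiles on its own:

// Stand-ins for the real audio-model types, which come from the AI crate used
// by the executor and have more variants and options than shown here.
#[derive(Clone, Copy, Debug)]
enum WhisperSource {
    Base,
    Small,
}

struct WhisperBuilder {
    source: WhisperSource,
}

impl WhisperBuilder {
    fn default() -> Self {
        Self { source: WhisperSource::Base }
    }
    fn with_source(mut self, source: WhisperSource) -> Self {
        self.source = source;
        self
    }
    fn build(self) -> String {
        format!("whisper loaded from {:?}", self.source)
    }
}

// One shared constant: every build site reads this, so changing the
// transcription model is a single-line edit.
static WHISPER_MODEL: WhisperSource = WhisperSource::Small;

fn main() {
    // Both the transcription-stream path and the preload path build the same way.
    println!("{}", WhisperBuilder::default().with_source(WHISPER_MODEL).build());
}
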
@@ -130,7 +132,7 @@ async fn handle_progress(model_id: String, loading: ModelLoadingProgress) {
} else {
"Loaded".to_string()
};
println!("Progress update: {}% for model {}", progress, model_id); // Add logging
//println!("Progress update: {}% for model {}", progress, model_id); // Add logging
publish_model_status(model_id.clone(), progress, &status, false, false).await;
}

@@ -289,9 +291,28 @@ impl AIService {
// Local TinyLlama models
"llama_tiny" => Llama::builder().with_source(LlamaSource::tiny_llama_1_1b()),
"llama_7b" => Llama::builder().with_source(LlamaSource::llama_7b()),
"llama_7b_chat" => Llama::builder().with_source(LlamaSource::llama_7b_chat()),
"llama_7b_code" => Llama::builder().with_source(LlamaSource::llama_7b_code()),
"llama_8b" => Llama::builder().with_source(LlamaSource::llama_8b()),
"llama_8b_chat" => Llama::builder().with_source(LlamaSource::llama_8b_chat()),
"llama_3_1_8b_chat" => Llama::builder().with_source(LlamaSource::llama_3_1_8b_chat()),
"llama_13b" => Llama::builder().with_source(LlamaSource::llama_13b()),
"llama_13b_chat" => Llama::builder().with_source(LlamaSource::llama_13b_chat()),
"llama_13b_code" => Llama::builder().with_source(LlamaSource::llama_13b_code()),
"llama_34b_code" => Llama::builder().with_source(LlamaSource::llama_34b_code()),
"llama_70b" => Llama::builder().with_source(LlamaSource::llama_70b()),
"mistral_7b" => Llama::builder().with_source(LlamaSource::mistral_7b()),
"mistral_7b_instruct" => {
Llama::builder().with_source(LlamaSource::mistral_7b_instruct())
}
"mistral_7b_instruct_2" => {
Llama::builder().with_source(LlamaSource::mistral_7b_instruct_2())
}
"solar_10_7b" => Llama::builder().with_source(LlamaSource::solar_10_7b()),
"solar_10_7b_instruct" => {
Llama::builder().with_source(LlamaSource::solar_10_7b_instruct())
}

// Handle unknown models
_ => {
log::error!("Unknown model string: {}", model_size_string);
@@ -308,7 +329,7 @@ impl AIService {
.build_with_loading_handler({
let model_id = model_id.clone();
move |progress| {
futures::executor::block_on(handle_progress(model_id.clone(), progress));
tokio::spawn(handle_progress(model_id.clone(), progress));
}
})
.await?;
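
A side note on the change just above: the loading-progress callback is synchronous, so publishing the status through futures::executor::block_on blocked the caller on every update (and can misbehave inside an async context), whereas tokio::spawn only queues the async publish on the runtime. A minimal sketch of that pattern, assuming a Tokio runtime is already running and using a hypothetical report_progress in place of handle_progress/publish_model_status:

use std::time::Duration;

// Hypothetical async status publisher standing in for handle_progress().
async fn report_progress(model_id: String, progress: f32) {
    println!("model {} at {:.0}%", model_id, progress);
}

// The loading handler stays cheap: it only enqueues the async publish on the
// runtime instead of blocking until the publish completes.
fn make_loading_handler(model_id: String) -> impl FnMut(f32) + Send + 'static {
    move |progress| {
        let model_id = model_id.clone();
        tokio::spawn(report_progress(model_id, progress));
    }
}

#[tokio::main]
async fn main() {
    let mut on_progress = make_loading_handler("llama_8b_chat".to_string());
    for p in [10.0_f32, 50.0, 100.0] {
        on_progress(p);
    }
    // Give the spawned tasks a moment to run in this toy example.
    tokio::time::sleep(Duration::from_millis(50)).await;
}
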
@@ -520,6 +541,13 @@ impl AIService {
}
LlmModel::Local(ref mut llama) => {
if let Some(task) = tasks.get(&prompt_request.task_id) {
rt.block_on(publish_model_status(
model_config.id.clone(),
100.0,
"Running inference...",
true,
true,
));
let mut maybe_result: Option<String> = None;
let mut tries = 0;
while maybe_result.is_none() && tries < 20 {
@@ -536,12 +564,27 @@ impl AIService {
log::error!(
"Llama panicked with: {:?}. Trying again..",
e
)
);
rt.block_on(publish_model_status(
model_config.id.clone(),
100.0,
"Panicked while running inference - trying again...",
true,
true,
));
}
Ok(result) => maybe_result = Some(result),
}
}

rt.block_on(publish_model_status(
model_config.id.clone(),
100.0,
"Ready",
true,
true,
));

if let Some(result) = maybe_result {
let _ = prompt_request.result_sender.send(Ok(result));
} else {
@@ -808,7 +851,7 @@ impl AIService {

rt.block_on(async {
let maybe_model = WhisperBuilder::default()
.with_source(WhisperSource::Base)
.with_source(WHISPER_MODEL)
.with_device(Device::Cpu)
.build()
.await;
@@ -906,7 +949,7 @@ impl AIService {
publish_model_status(id.clone(), 0.0, "Loading", false, false).await;

let _ = WhisperBuilder::default()
.with_source(WhisperSource::Base)
.with_source(WHISPER_MODEL)
.with_device(Device::Cpu)
.build_with_loading_handler({
let name = id.clone();
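
Taken together, the @@ -520 and @@ -536 hunks above wrap the local Llama call in status updates: publish "Running inference..." before the attempt, retry a bounded number of times if the model panics, and publish "Ready" afterwards. A rough, self-contained sketch of that control flow, with a hypothetical publish_status standing in for publish_model_status (which also takes a progress value and two flags):

// Hypothetical stand-in for publish_model_status().
fn publish_status(model_id: &str, status: &str) {
    println!("[{}] {}", model_id, status);
}

// The retry shape used around the local model call: announce that inference
// is running, retry a bounded number of times on failure, and report "Ready"
// again afterwards whether or not a result was produced.
fn prompt_with_retries<F>(model_id: &str, mut run_inference: F) -> Option<String>
where
    F: FnMut() -> Result<String, String>,
{
    publish_status(model_id, "Running inference...");
    let mut result = None;
    let mut tries = 0;
    while result.is_none() && tries < 20 {
        tries += 1;
        match run_inference() {
            Ok(answer) => result = Some(answer),
            Err(e) => publish_status(
                model_id,
                &format!("Panicked while running inference - trying again... ({})", e),
            ),
        }
    }
    publish_status(model_id, "Ready");
    result
}

fn main() {
    // Toy inference that fails twice before succeeding.
    let mut failures_left = 2;
    let answer = prompt_with_retries("llama_8b_chat", || {
        if failures_left > 0 {
            failures_left -= 1;
            Err("model panicked".to_string())
        } else {
            Ok("42".to_string())
        }
    });
    println!("{:?}", answer);
}
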
2 changes: 1 addition & 1 deletion rust-executor/src/js_core/mod.rs
@@ -137,7 +137,7 @@ impl std::fmt::Display for ExternWrapper {

impl JsCore {
pub fn new() -> Self {
deno_core::v8::V8::set_flags_from_string("--no-opt --turbo-disable-all");
deno_core::v8::V8::set_flags_from_string("--no-opt");
JsCore {
#[allow(clippy::arc_with_non_send_sync)]
worker: Arc::new(TokioMutex::new(MainWorker::from_options(
2 changes: 1 addition & 1 deletion rust-executor/src/lib.rs
@@ -42,7 +42,7 @@ use libc::{sigaction, sigemptyset, sighandler_t, SA_ONSTACK, SIGURG};
use std::ptr;

extern "C" fn handle_sigurg(_: libc::c_int) {
println!("Received SIGURG signal, but ignoring it.");
//println!("Received SIGURG signal, but ignoring it.");
}

/// Runs the GraphQL server and the deno core runtime
2 changes: 1 addition & 1 deletion ui/package.json
@@ -41,7 +41,7 @@
"change-ui-version": "powershell -ExecutionPolicy Bypass -File ./scripts/patch-prerelease-tags.ps1",
"package-ad4m": "run-script-os",
"package-ad4m:windows": "cargo clean && pnpm run build && pnpm tauri build --verbose",
"package-ad4m:macos": "pnpm run build && pnpm tauri build --verbose",
"package-ad4m:macos": "pnpm run build && pnpm tauri build --verbose --features metal",
"package-ad4m:linux": "pnpm run build && pnpm tauri build --verbose"
},
"eslintConfig": {
4 changes: 4 additions & 0 deletions ui/src-tauri/Cargo.toml
@@ -58,6 +58,10 @@ dirs = "5"
# by default Tauri runs in production mode
# when `tauri dev` runs it is executed with `cargo run --no-default-features` if `devPath` is an URL
default = [ "custom-protocol"]
# Pass metal and cuda features through to ad4m-executor
metal = ["ad4m-executor/metal"]
cuda = ["ad4m-executor/cuda"]


dev = []
# this feature is used used for production builds where `devPath` points to the filesystem
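
The two feature entries above only forward GPU support to the ad4m-executor dependency; that is what the new --features metal flag in the macOS package-ad4m script switches on. Below is a sketch of how a forwarded feature is typically consumed downstream - the gating shown is an assumption for illustration, not code from ad4m-executor:

// Hypothetical downstream use of the forwarded features: choose an
// accelerator at compile time depending on which feature was enabled.
#[cfg(feature = "metal")]
fn inference_device() -> &'static str {
    "metal"
}

#[cfg(feature = "cuda")]
fn inference_device() -> &'static str {
    "cuda"
}

#[cfg(not(any(feature = "metal", feature = "cuda")))]
fn inference_device() -> &'static str {
    "cpu"
}

fn main() {
    // Built via `pnpm tauri build --features metal`, this would print "metal".
    println!("running inference on {}", inference_device());
}
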
2 changes: 1 addition & 1 deletion ui/src-tauri/src/util.rs
@@ -37,7 +37,7 @@ pub fn create_main_window(app: &AppHandle<Wry>) {
let new_ad4m_window = WebviewWindowBuilder::new(app, "AD4M", WebviewUrl::App(url.into()))
.center()
.focused(true)
.inner_size(1000.0, 700.0)
.inner_size(1000.0, 1200.0)
.title("ADAM Launcher");

let _ = new_ad4m_window.build();
2 changes: 1 addition & 1 deletion ui/src/components/Login.tsx
@@ -95,7 +95,7 @@ const Login = () => {
const llm = { name: "LLM Model 1", modelType: "LLM" } as ModelInput;
if (aiMode === "Local") {
llm.local = {
fileName: "llama_7b",
fileName: "solar_10_7b_instruct",
tokenizerSource: "",
modelParameters: "",
};
16 changes: 14 additions & 2 deletions ui/src/components/ModelModal.tsx
@@ -7,9 +7,21 @@ const AITypes = ["LLM", "EMBEDDING", "TRANSCRIPTION"];
const llmModels = [
"External API",
// "tiny_llama_1_1b",
"mistral_7b",
"mistral_7b_instruct",
"mistral_7b_instruct_2",
"solar_10_7b",
"solar_10_7b_instruct",
"llama_7b",
"llama_7b_chat",
"llama_7b_code",
"llama_8b",
"llama_8b_chat",
"llama_3_1_8b_chat",
"llama_13b",
"llama_13b_chat",
"llama_13b_code",
"llama_34b_code",
"llama_70b",
];
const transcriptionModels = ["whisper"];
@@ -28,7 +40,7 @@ export default function ModelModal(props: {
const [newModelNameError, setNewModelNameError] = useState(false);
const [newModelType, setNewModelType] = useState("LLM");
const [newModels, setNewModels] = useState<any[]>(llmModels);
const [newModel, setNewModel] = useState("llama_7b");
const [newModel, setNewModel] = useState("llama_8b");
const [apiUrl, setApiUrl] = useState("https://api.openai.com/v1");
const [apiKey, setApiKey] = useState("");
const [apiUrlError, setApiUrlError] = useState(false);
@@ -165,7 +177,7 @@ export default function ModelModal(props: {
setNewModelType(type);
if (type === "LLM") {
setNewModels(llmModels);
setNewModel("llama_7b");
setNewModel("llama_8b");
} else if (type === "EMBEDDING") {
setNewModels(embeddingModels);
setNewModel("bert");