Skip to content

Commit

Permalink
Refactor the whisper microphone example.
Browse files (browse the repository at this point in the history)
  • Loading branch information
LaurentMazare committed Sep 30, 2024
1 parent aa35bf2 commit 399ec6d
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 13 deletions.
2 changes: 1 addition & 1 deletion candle-examples/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ mkl = ["dep:intel-mkl-src", "candle/mkl", "candle-nn/mkl", "candle-transformers/
nccl = ["cuda", "cudarc/nccl", "dep:half"]
onnx = ["candle-onnx"]
metal = ["candle/metal", "candle-nn/metal"]
microphone = ["cpal"]
microphone = ["cpal", "rubato"]
encodec = ["cpal", "symphonia", "rubato"]
mimi = ["cpal", "symphonia", "rubato"]
depth_anything_v2 = ["palette", "enterpolation"]
Expand Down
27 changes: 15 additions & 12 deletions candle-examples/examples/whisper-microphone/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -479,6 +479,10 @@ struct Args {
/// Print the full DecodingResult structure rather than just the text.
#[arg(long)]
verbose: bool,

/// The input device to use.
#[arg(long)]
device: Option<String>,
}

pub fn main() -> Result<()> {
Expand Down Expand Up @@ -543,13 +547,12 @@ pub fn main() -> Result<()> {
unsafe { VarBuilder::from_mmaped_safetensors(&[weights_filename], m::DTYPE, &device)? };
Model::Normal(m::model::Whisper::load(&vb, config.clone())?)
};
let language_token = None;
let mut dc = Decoder::new(
model,
tokenizer.clone(),
args.seed,
&device,
language_token,
/* language_token */ None,
args.task,
args.timestamps,
args.verbose,
Expand All @@ -565,26 +568,26 @@ pub fn main() -> Result<()> {

// Set up the input device and stream with the default input config.
let host = cpal::default_host();
let _device = "default";
let _device = if _device == "default" {
host.default_input_device()
} else {
host.input_devices()?
.find(|x| x.name().map(|y| y == _device).unwrap_or(false))
let audio_device = match args.device.as_ref() {
None => host.default_input_device(),
Some(device) => host
.input_devices()?
.find(|x| x.name().map_or(false, |y| &y == device)),
}
.expect("failed to find input device");
.expect("failed to find the audio input device");

let _config = _device
let audio_config = audio_device
.default_input_config()
.expect("Failed to get default input config");
println!("audio config {audio_config:?}");

let channel_count = _config.channels() as usize;
let channel_count = audio_config.channels() as usize;

let audio_ring_buffer = Arc::new(Mutex::new(Vec::new()));
let audio_ring_buffer_2 = audio_ring_buffer.clone();

std::thread::spawn(move || loop {
let data = record_audio(&_device, &_config, 300).unwrap();
let data = record_audio(&audio_device, &audio_config, 300).unwrap();
audio_ring_buffer.lock().unwrap().extend_from_slice(&data);
let max_len = data.len() * 16;
let data_len = data.len();
Expand Down

0 comments on commit 399ec6d

Please sign in to comment.