diff --git a/candle-examples/examples/encodec/audio_io.rs b/candle-examples/examples/encodec/audio_io.rs
index 2103dd4adf..fa1a26fbf7 100644
--- a/candle-examples/examples/encodec/audio_io.rs
+++ b/candle-examples/examples/encodec/audio_io.rs
@@ -1,4 +1,3 @@
-#![allow(unused)]
 use anyhow::{Context, Result};
 use std::sync::{Arc, Mutex};
 
diff --git a/candle-examples/examples/mimi/audio_io.rs b/candle-examples/examples/mimi/audio_io.rs
index 2103dd4adf..fa1a26fbf7 100644
--- a/candle-examples/examples/mimi/audio_io.rs
+++ b/candle-examples/examples/mimi/audio_io.rs
@@ -1,4 +1,3 @@
-#![allow(unused)]
 use anyhow::{Context, Result};
 use std::sync::{Arc, Mutex};
 
diff --git a/candle-nn/src/rnn.rs b/candle-nn/src/rnn.rs
index b4b443c6b8..798db6ac4d 100644
--- a/candle-nn/src/rnn.rs
+++ b/candle-nn/src/rnn.rs
@@ -116,7 +116,7 @@ impl LSTMConfig {
 /// A Long Short-Term Memory (LSTM) layer.
 ///
 /// <https://en.wikipedia.org/wiki/Long_short-term_memory>
-#[allow(clippy::upper_case_acronyms, unused)]
+#[allow(clippy::upper_case_acronyms)]
 #[derive(Clone, Debug)]
 pub struct LSTM {
     w_ih: Tensor,
@@ -129,6 +129,62 @@ pub struct LSTM {
     dtype: DType,
 }
 
+impl LSTM {
+    /// Creates a LSTM layer.
+    pub fn new(
+        in_dim: usize,
+        hidden_dim: usize,
+        config: LSTMConfig,
+        vb: crate::VarBuilder,
+    ) -> Result<Self> {
+        let layer_idx = config.layer_idx;
+        let direction_str = match config.direction {
+            Direction::Forward => "",
+            Direction::Backward => "_reverse",
+        };
+        let w_ih = vb.get_with_hints(
+            (4 * hidden_dim, in_dim),
+            &format!("weight_ih_l{layer_idx}{direction_str}"), // Only a single layer is supported.
+            config.w_ih_init,
+        )?;
+        let w_hh = vb.get_with_hints(
+            (4 * hidden_dim, hidden_dim),
+            &format!("weight_hh_l{layer_idx}{direction_str}"), // Only a single layer is supported.
+            config.w_hh_init,
+        )?;
+        let b_ih = match config.b_ih_init {
+            Some(init) => Some(vb.get_with_hints(
+                4 * hidden_dim,
+                &format!("bias_ih_l{layer_idx}{direction_str}"),
+                init,
+            )?),
+            None => None,
+        };
+        let b_hh = match config.b_hh_init {
+            Some(init) => Some(vb.get_with_hints(
+                4 * hidden_dim,
+                &format!("bias_hh_l{layer_idx}{direction_str}"),
+                init,
+            )?),
+            None => None,
+        };
+        Ok(Self {
+            w_ih,
+            w_hh,
+            b_ih,
+            b_hh,
+            hidden_dim,
+            config,
+            device: vb.device().clone(),
+            dtype: vb.dtype(),
+        })
+    }
+
+    pub fn config(&self) -> &LSTMConfig {
+        &self.config
+    }
+}
+
 /// Creates a LSTM layer.
 pub fn lstm(
     in_dim: usize,
@@ -136,47 +192,7 @@ pub fn lstm(
     config: LSTMConfig,
     vb: crate::VarBuilder,
 ) -> Result<LSTM> {
-    let layer_idx = config.layer_idx;
-    let direction_str = match config.direction {
-        Direction::Forward => "",
-        Direction::Backward => "_reverse",
-    };
-    let w_ih = vb.get_with_hints(
-        (4 * hidden_dim, in_dim),
-        &format!("weight_ih_l{layer_idx}{direction_str}"), // Only a single layer is supported.
-        config.w_ih_init,
-    )?;
-    let w_hh = vb.get_with_hints(
-        (4 * hidden_dim, hidden_dim),
-        &format!("weight_hh_l{layer_idx}{direction_str}"), // Only a single layer is supported.
-        config.w_hh_init,
-    )?;
-    let b_ih = match config.b_ih_init {
-        Some(init) => Some(vb.get_with_hints(
-            4 * hidden_dim,
-            &format!("bias_ih_l{layer_idx}{direction_str}"),
-            init,
-        )?),
-        None => None,
-    };
-    let b_hh = match config.b_hh_init {
-        Some(init) => Some(vb.get_with_hints(
-            4 * hidden_dim,
-            &format!("bias_hh_l{layer_idx}{direction_str}"),
-            init,
-        )?),
-        None => None,
-    };
-    Ok(LSTM {
-        w_ih,
-        w_hh,
-        b_ih,
-        b_hh,
-        hidden_dim,
-        config,
-        device: vb.device().clone(),
-        dtype: vb.dtype(),
-    })
+    LSTM::new(in_dim, hidden_dim, config, vb)
 }
 
 impl RNN for LSTM {
@@ -270,7 +286,7 @@ impl GRUConfig {
 /// A Gated Recurrent Unit (GRU) layer.
 ///
 /// <https://en.wikipedia.org/wiki/Gated_recurrent_unit>
-#[allow(clippy::upper_case_acronyms, unused)]
+#[allow(clippy::upper_case_acronyms)]
 #[derive(Clone, Debug)]
 pub struct GRU {
     w_ih: Tensor,
@@ -283,41 +299,56 @@ pub struct GRU {
     dtype: DType,
 }
 
-/// Creates a GRU layer.
+impl GRU {
+    /// Creates a GRU layer.
+    pub fn new(
+        in_dim: usize,
+        hidden_dim: usize,
+        config: GRUConfig,
+        vb: crate::VarBuilder,
+    ) -> Result<Self> {
+        let w_ih = vb.get_with_hints(
+            (3 * hidden_dim, in_dim),
+            "weight_ih_l0", // Only a single layer is supported.
+            config.w_ih_init,
+        )?;
+        let w_hh = vb.get_with_hints(
+            (3 * hidden_dim, hidden_dim),
+            "weight_hh_l0", // Only a single layer is supported.
+            config.w_hh_init,
+        )?;
+        let b_ih = match config.b_ih_init {
+            Some(init) => Some(vb.get_with_hints(3 * hidden_dim, "bias_ih_l0", init)?),
+            None => None,
+        };
+        let b_hh = match config.b_hh_init {
+            Some(init) => Some(vb.get_with_hints(3 * hidden_dim, "bias_hh_l0", init)?),
+            None => None,
+        };
+        Ok(Self {
+            w_ih,
+            w_hh,
+            b_ih,
+            b_hh,
+            hidden_dim,
+            config,
+            device: vb.device().clone(),
+            dtype: vb.dtype(),
+        })
+    }
+
+    pub fn config(&self) -> &GRUConfig {
+        &self.config
+    }
+}
+
 pub fn gru(
     in_dim: usize,
     hidden_dim: usize,
     config: GRUConfig,
     vb: crate::VarBuilder,
 ) -> Result<GRU> {
-    let w_ih = vb.get_with_hints(
-        (3 * hidden_dim, in_dim),
-        "weight_ih_l0", // Only a single layer is supported.
-        config.w_ih_init,
-    )?;
-    let w_hh = vb.get_with_hints(
-        (3 * hidden_dim, hidden_dim),
-        "weight_hh_l0", // Only a single layer is supported.
-        config.w_hh_init,
-    )?;
-    let b_ih = match config.b_ih_init {
-        Some(init) => Some(vb.get_with_hints(3 * hidden_dim, "bias_ih_l0", init)?),
-        None => None,
-    };
-    let b_hh = match config.b_hh_init {
-        Some(init) => Some(vb.get_with_hints(3 * hidden_dim, "bias_hh_l0", init)?),
-        None => None,
-    };
-    Ok(GRU {
-        w_ih,
-        w_hh,
-        b_ih,
-        b_hh,
-        hidden_dim,
-        config,
-        device: vb.device().clone(),
-        dtype: vb.dtype(),
-    })
+    GRU::new(in_dim, hidden_dim, config, vb)
 }
 
 impl RNN for GRU {
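
Usage sketch (reviewer note, not part of the patch): the constructor bodies move unchanged from the free functions `lstm` / `gru` into inherent `LSTM::new` / `GRU::new` associated functions, the free functions become thin delegates, and the new `config()` accessors give the stored `config` field a reader, which is presumably what allows dropping the `unused` lint allowances. A minimal example of the new entry points, assuming the candle workspace's `candle` alias for `candle-core` and freshly initialized (untrained) weights from a `VarMap` rather than a checkpoint:

use candle::{DType, Device, Result, Tensor};
use candle_nn::{LSTM, LSTMConfig, RNN, VarBuilder, VarMap};

fn main() -> Result<()> {
    let dev = Device::Cpu;
    // Fresh variables; in practice the VarBuilder would load pretrained weights.
    let varmap = VarMap::new();
    let vb = VarBuilder::from_varmap(&varmap, DType::F32, &dev);

    // The associated constructor introduced by this patch; the free
    // function `candle_nn::lstm` now just forwards to it.
    let lstm = LSTM::new(8, 16, LSTMConfig::default(), vb)?;
    // New accessor: the layer configuration is now readable after construction.
    assert_eq!(lstm.config().layer_idx, 0);

    // Run the layer over a (batch, seq_len, features) input.
    let xs = Tensor::zeros((2, 5, 8), DType::F32, &dev)?;
    let states = lstm.seq(&xs)?;
    let hs = lstm.states_to_tensor(&states)?;
    println!("{:?}", hs.shape()); // hidden states stacked over the sequence
    Ok(())
}

`GRU::new` follows the same pattern with a `GRUConfig`, minus the direction/layer-index name suffix handling.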