diff --git a/candle-examples/examples/musicgen/encodec_model.rs b/candle-examples/examples/musicgen/encodec_model.rs
index bf33d49def..095c90a9c8 100644
--- a/candle-examples/examples/musicgen/encodec_model.rs
+++ b/candle-examples/examples/musicgen/encodec_model.rs
@@ -1,6 +1,6 @@
 use crate::nn::conv1d_weight_norm;
-use candle::{DType, IndexOp, Result, Tensor};
-use candle_nn::{conv1d, Conv1d, Conv1dConfig, Module, VarBuilder};
+use candle::{DType, IndexOp, Module, Result, Tensor};
+use candle_nn::{conv1d, Conv1d, Conv1dConfig, VarBuilder};
 
 // Encodec Model
 // https://github.com/huggingface/transformers/blob/main/src/transformers/models/encodec/modeling_encodec.py
@@ -199,25 +199,34 @@ impl EncodecResidualVectorQuantizer {
 // https://github.com/huggingface/transformers/blob/abaca9f9432a84cfaa95531de4c72334f38a42f2/src/transformers/models/encodec/modeling_encodec.py#L226
 #[derive(Debug)]
 struct EncodecLSTM {
-    layers: Vec<(Tensor, Tensor, Tensor, Tensor)>,
+    layers: Vec<candle_nn::LSTM>,
 }
 
 impl EncodecLSTM {
     fn load(dim: usize, vb: VarBuilder, cfg: &Config) -> Result<Self> {
         let vb = &vb.pp("lstm");
         let mut layers = vec![];
-        for i in 0..cfg.num_lstm_layers {
-            let w_hh = vb.get((4 * dim, dim), &format!("weight_hh_l{i}"))?;
-            let w_ih = vb.get((4 * dim, dim), &format!("weight_ih_l{i}"))?;
-            let b_hh = vb.get(4 * dim, &format!("bias_hh_l{i}"))?;
-            let b_ih = vb.get(4 * dim, &format!("bias_ih_l{i}"))?;
-            layers.push((w_hh, w_ih, b_hh, b_ih))
+        for layer_idx in 0..cfg.num_lstm_layers {
+            let config = candle_nn::LSTMConfig {
+                layer_idx,
+                ..Default::default()
+            };
+            let lstm = candle_nn::lstm(dim, dim, config, vb.clone())?;
+            layers.push(lstm)
         }
         Ok(Self { layers })
     }
+}
 
-    fn forward(&self, _xs: &Tensor) -> Result<Tensor> {
-        todo!()
+impl Module for EncodecLSTM {
+    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+        use candle_nn::RNN;
+        let mut xs = xs.clone();
+        for layer in self.layers.iter() {
+            let states = layer.seq(&xs)?;
+            xs = layer.states_to_tensor(&states)?;
+        }
+        Ok(xs)
     }
 }
 
@@ -247,7 +256,9 @@ impl EncodecConvTranspose1d {
             bias,
         })
     }
+}
 
+impl Module for EncodecConvTranspose1d {
     fn forward(&self, _xs: &Tensor) -> Result<Tensor> {
         todo!()
     }
@@ -299,7 +310,9 @@ impl EncodecConv1d {
             conv,
         })
     }
+}
 
+impl Module for EncodecConv1d {
     fn forward(&self, xs: &Tensor) -> Result<Tensor> {
         // TODO: padding, depending on causal.
         let xs = self.conv.forward(xs)?;
@@ -340,7 +353,9 @@ impl EncodecResnetBlock {
             shortcut,
         })
     }
+}
 
+impl Module for EncodecResnetBlock {
     fn forward(&self, xs: &Tensor) -> Result<Tensor> {
         let residual = xs.clone();
         let xs = xs.elu(1.)?;
@@ -439,8 +454,17 @@ impl EncodecEncoder {
         })
     }
 
-    fn forward(&self, _xs: &Tensor) -> Result<Tensor> {
-        todo!()
+    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+        let mut xs = xs.apply(&self.init_conv)?;
+        for (resnets, conv) in self.sampling_layers.iter() {
+            for resnet in resnets.iter() {
+                xs = xs.apply(resnet)?;
+            }
+            xs = xs.elu(1.0)?.apply(conv)?;
+        }
+        xs.apply(&self.final_lstm)?
+            .elu(1.0)?
+            .apply(&self.final_conv)
     }
 }
 
@@ -507,8 +531,15 @@ impl EncodecDecoder {
         })
     }
 
-    fn forward(&self, _xs: &Tensor) -> Result<Tensor> {
-        todo!()
+    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+        let mut xs = xs.apply(&self.init_conv)?.apply(&self.init_lstm)?;
+        for (conv, resnets) in self.sampling_layers.iter() {
+            xs = xs.elu(1.)?.apply(conv)?;
+            for resnet in resnets.iter() {
+                xs = xs.apply(resnet)?
+            }
+        }
+        xs.elu(1.)?.apply(&self.final_conv)
     }
 }
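
Not part of the patch itself, but for context: a minimal standalone sketch of the stacked-LSTM pattern the new `EncodecLSTM` adopts, assuming only the `candle_nn::lstm`, `LSTMConfig`, and `RNN` APIs that the diff already uses. The helper name `run_lstm_stack` and the toy dimensions are illustrative, not part of the example crate:

```rust
// Sketch: the stacked-LSTM forward pattern from the patch, in isolation.
use candle::{DType, Device, Result, Tensor};
use candle_nn::{RNN, VarBuilder, VarMap};

fn run_lstm_stack(dim: usize, num_layers: usize, xs: &Tensor, vb: VarBuilder) -> Result<Tensor> {
    let mut layers = Vec::with_capacity(num_layers);
    for layer_idx in 0..num_layers {
        // layer_idx selects the PyTorch-style weight names,
        // e.g. weight_ih_l0, weight_ih_l1, ...
        let config = candle_nn::LSTMConfig {
            layer_idx,
            ..Default::default()
        };
        layers.push(candle_nn::lstm(dim, dim, config, vb.clone())?);
    }
    let mut xs = xs.clone();
    for layer in layers.iter() {
        // seq unrolls the LSTM over the time axis; states_to_tensor stacks the
        // per-step hidden states back into a (batch, seq_len, dim) tensor.
        let states = layer.seq(&xs)?;
        xs = layer.states_to_tensor(&states)?;
    }
    Ok(xs)
}

fn main() -> Result<()> {
    let dev = Device::Cpu;
    let varmap = VarMap::new();
    let vb = VarBuilder::from_varmap(&varmap, DType::F32, &dev);
    let xs = Tensor::zeros((1, 5, 8), DType::F32, &dev)?; // (batch, seq_len, dim)
    let ys = run_lstm_stack(8, 2, &xs, vb)?;
    println!("{:?}", ys.shape());
    Ok(())
}
```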
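Similarly hedged context on the second change in the patch: the `forward` methods move into `impl Module` blocks because `Tensor::apply` accepts any `Module`, which is what lets the new encoder/decoder forwards chain stages as `xs.apply(...)`. The `Scale` struct below is a hypothetical stand-in for blocks like `EncodecConv1d`:

```rust
// Sketch: implementing Module makes a block usable with Tensor::apply chaining.
use candle::{DType, Device, Module, Result, Tensor};

// Hypothetical stand-in for a block such as EncodecConv1d or EncodecResnetBlock.
struct Scale(f64);

impl Module for Scale {
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        xs.affine(self.0, 0.)
    }
}

fn main() -> Result<()> {
    let xs = Tensor::ones((2, 3), DType::F32, &Device::Cpu)?;
    // Mirrors the patch's xs.apply(&self.init_conv)?.apply(&self.init_lstm)? style.
    let ys = xs.apply(&Scale(2.))?.apply(&Scale(3.))?;
    println!("{ys}"); // every element is 6.0
    Ok(())
}
```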