Use candle_nn::LSTM in encodec. (#1051)
* Use candle_nn::LSTM in encodec.

* More Encodec implementation.

* Decoder implementation.
LaurentMazare authored Oct 7, 2023
1 parent a496760 commit d833527
Showing 1 changed file with 46 additions and 15 deletions.
candle-examples/examples/musicgen/encodec_model.rs

@@ -1,6 +1,6 @@
 use crate::nn::conv1d_weight_norm;
-use candle::{DType, IndexOp, Result, Tensor};
-use candle_nn::{conv1d, Conv1d, Conv1dConfig, Module, VarBuilder};
+use candle::{DType, IndexOp, Module, Result, Tensor};
+use candle_nn::{conv1d, Conv1d, Conv1dConfig, VarBuilder};
 
 // Encodec Model
 // https://github.com/huggingface/transformers/blob/main/src/transformers/models/encodec/modeling_encodec.py
@@ -199,25 +199,34 @@ impl EncodecResidualVectorQuantizer {
 // https://github.com/huggingface/transformers/blob/abaca9f9432a84cfaa95531de4c72334f38a42f2/src/transformers/models/encodec/modeling_encodec.py#L226
 #[derive(Debug)]
 struct EncodecLSTM {
-    layers: Vec<(Tensor, Tensor, Tensor, Tensor)>,
+    layers: Vec<candle_nn::LSTM>,
 }
 
 impl EncodecLSTM {
     fn load(dim: usize, vb: VarBuilder, cfg: &Config) -> Result<Self> {
         let vb = &vb.pp("lstm");
         let mut layers = vec![];
-        for i in 0..cfg.num_lstm_layers {
-            let w_hh = vb.get((4 * dim, dim), &format!("weight_hh_l{i}"))?;
-            let w_ih = vb.get((4 * dim, dim), &format!("weight_ih_l{i}"))?;
-            let b_hh = vb.get(4 * dim, &format!("bias_hh_l{i}"))?;
-            let b_ih = vb.get(4 * dim, &format!("bias_ih_l{i}"))?;
-            layers.push((w_hh, w_ih, b_hh, b_ih))
+        for layer_idx in 0..cfg.num_lstm_layers {
+            let config = candle_nn::LSTMConfig {
+                layer_idx,
+                ..Default::default()
+            };
+            let lstm = candle_nn::lstm(dim, dim, config, vb.clone())?;
+            layers.push(lstm)
         }
         Ok(Self { layers })
     }
+}
 
-    fn forward(&self, _xs: &Tensor) -> Result<Tensor> {
-        todo!()
+impl Module for EncodecLSTM {
+    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+        use candle_nn::RNN;
+        let mut xs = xs.clone();
+        for layer in self.layers.iter() {
+            let states = layer.seq(&xs)?;
+            xs = layer.states_to_tensor(&states)?;
+        }
+        Ok(xs)
     }
 }
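For reference, here is a minimal standalone sketch of the candle_nn LSTM API the new implementation relies on. The toy dimensions and the freshly initialized VarMap-backed VarBuilder are illustrative assumptions, not part of this commit (the example itself loads pretrained weights through vb.pp("lstm")):

use candle::{DType, Device, Result, Tensor};
use candle_nn::{RNN, VarBuilder, VarMap};

fn lstm_demo() -> Result<()> {
    let dev = Device::Cpu;
    // Fresh weights instead of the pretrained ones the model loads.
    let varmap = VarMap::new();
    let vb = VarBuilder::from_varmap(&varmap, DType::F32, &dev);
    // 8 features in, 8 out, default layer_idx 0: one layer of the stack above.
    let lstm = candle_nn::lstm(8, 8, candle_nn::LSTMConfig::default(), vb)?;
    // RNN::seq expects (batch, seq_len, features) and returns one state per step.
    let xs = Tensor::zeros((1, 4, 8), DType::F32, &dev)?;
    let states = lstm.seq(&xs)?;
    // Stack the per-step hidden states back into a (1, 4, 8) tensor, which is
    // exactly what EncodecLSTM::forward feeds to the next layer.
    let ys = lstm.states_to_tensor(&states)?;
    assert_eq!(ys.dims(), &[1, 4, 8]);
    Ok(())
}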

@@ -247,7 +256,9 @@ impl EncodecConvTranspose1d {
             bias,
         })
     }
+}
 
+impl Module for EncodecConvTranspose1d {
     fn forward(&self, _xs: &Tensor) -> Result<Tensor> {
         todo!()
     }
@@ -299,7 +310,9 @@ impl EncodecConv1d {
             conv,
         })
     }
+}
 
+impl Module for EncodecConv1d {
     fn forward(&self, xs: &Tensor) -> Result<Tensor> {
         // TODO: padding, depending on causal.
         let xs = self.conv.forward(xs)?;
@@ -340,7 +353,9 @@ impl EncodecResnetBlock {
             shortcut,
         })
     }
+}
 
+impl Module for EncodecResnetBlock {
     fn forward(&self, xs: &Tensor) -> Result<Tensor> {
         let residual = xs.clone();
         let xs = xs.elu(1.)?;
@@ -439,8 +454,17 @@ impl EncodecEncoder {
         })
     }
 
-    fn forward(&self, _xs: &Tensor) -> Result<Tensor> {
-        todo!()
+    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+        let mut xs = xs.apply(&self.init_conv)?;
+        for (resnets, conv) in self.sampling_layers.iter() {
+            for resnet in resnets.iter() {
+                xs = xs.apply(resnet)?;
+            }
+            xs = xs.elu(1.0)?.apply(conv)?;
+        }
+        xs.apply(&self.final_lstm)?
+            .elu(1.0)?
+            .apply(&self.final_conv)
     }
 }
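The encoder's forward pass chains everything through Tensor::apply, which is just sugar for calling a Module's forward on the tensor. A small self-contained sketch of that pattern, with a made-up Scale module standing in for the conv/resnet layers:

use candle::{Device, Module, Result, Tensor};

// Hypothetical toy module: multiplies its input by a constant.
struct Scale(f64);

impl Module for Scale {
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        xs.affine(self.0, 0.)
    }
}

fn apply_demo() -> Result<()> {
    let xs = Tensor::new(&[1f32, 2., 3.], &Device::Cpu)?;
    // xs.apply(&m) is m.forward(&xs), so stages chain left to right with ?.
    let ys = xs.apply(&Scale(2.))?.apply(&Scale(0.5))?;
    assert_eq!(ys.to_vec1::<f32>()?, vec![1f32, 2., 3.]);
    Ok(())
}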

@@ -507,8 +531,15 @@ impl EncodecDecoder {
         })
     }
 
-    fn forward(&self, _xs: &Tensor) -> Result<Tensor> {
-        todo!()
+    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+        let mut xs = xs.apply(&self.init_conv)?.apply(&self.init_lstm)?;
+        for (conv, resnets) in self.sampling_layers.iter() {
+            xs = xs.elu(1.)?.apply(conv)?;
+            for resnet in resnets.iter() {
+                xs = xs.apply(resnet)?
+            }
+        }
+        xs.elu(1.)?.apply(&self.final_conv)
     }
 }
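The decoder mirrors the encoder: the encoder runs each group of resnets before its downsampling conv, while the decoder runs each upsampling conv before its resnets. A hedged sketch of wiring the two halves together inside this file; in the real model the residual vector quantizer sits between them, so skipping it here is purely illustrative:

// Illustrative only: a quantizer-free round trip through the two modules above.
fn autoencode(encoder: &EncodecEncoder, decoder: &EncodecDecoder, audio: &Tensor) -> Result<Tensor> {
    // audio: (batch, channels, samples) -> latents -> reconstructed audio.
    let latents = encoder.forward(audio)?;
    decoder.forward(&latents)
}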

