diff --git a/egs/wham/WaveSplit/README.md b/egs/wham/WaveSplit/README.md index 6995c7a9c..2e26c2f03 100644 --- a/egs/wham/WaveSplit/README.md +++ b/egs/wham/WaveSplit/README.md @@ -1,7 +1,7 @@ ### WaveSplit -things currently not clear: ---- +#### Currently not clear: + - not clear if different encoders are used for separation and speaker stack. (from image in the paper it seems so) - what is embedding dimension ? It seems 512 but it is not explicit in the paper - mask used (sigmoid ?) @@ -10,8 +10,10 @@ things currently not clear: - loss right now is prone to go NaN especially if we don't take the mean after l2-distances computation. --- -structure: -- train.py contains training loop (nets instantiation lines 48-60, training loop lines 100- 116) +#### Structure: +- train.py contains training loop (nets instantiation +[lines 48-60](https://github.com/mpariente/asteroid/pull/70/files#diff-f69bcb61820a4a7cfc8fda9a554c251cR49), training loop lines +[100- 116](https://github.com/mpariente/asteroid/pull/70/files#diff-f69bcb61820a4a7cfc8fda9a554c251cR100)) - losses.py wavesplit losses - wavesplit.py sep and speaker stacks nets - wavesplitwham.py dataset parsing \ No newline at end of file diff --git a/egs/wham/WaveSplit/losses.py b/egs/wham/WaveSplit/losses.py index 0f147d05e..2a5a9a4a0 100644 --- a/egs/wham/WaveSplit/losses.py +++ b/egs/wham/WaveSplit/losses.py @@ -8,7 +8,6 @@ class ClippedSDR(nn.Module): - def __init__(self, clip_value=-30): super(ClippedSDR, self).__init__() @@ -16,17 +15,14 @@ def __init__(self, clip_value=-30): self.clip_value = float(clip_value) def forward(self, est_targets, targets): - return torch.clamp(self.snr(est_targets, targets), min=self.clip_value) class SpeakerVectorLoss(nn.Module): - def __init__(self, n_speakers, embed_dim=32, learnable_emb=True, loss_type="global", weight=10, distance_reg=0.3, gaussian_reg=0.2, return_oracle=True): super(SpeakerVectorLoss, self).__init__() - # not clear how embeddings are initialized. 
self.learnable_emb = learnable_emb @@ -35,7 +31,6 @@ def __init__(self, n_speakers, embed_dim=32, learnable_emb=True, loss_type="glob self.distance_reg = float(distance_reg) self.gaussian_reg = float(gaussian_reg) self.return_oracle = return_oracle - assert loss_type in ["distance", "global", "local"] # I initialize embeddings to be on unit sphere as speaker stack uses euclidean normalization @@ -53,7 +48,6 @@ def __init__(self, n_speakers, embed_dim=32, learnable_emb=True, loss_type="glob self.alpha = nn.Parameter(torch.Tensor([1.])) # not clear how these are initialized... self.beta = nn.Parameter(torch.Tensor([0.])) - ### losses go to NaN if I follow strictly the formulas maybe I am missing something... @staticmethod @@ -96,7 +90,9 @@ def _l_global_speaker(self, c_spk_vec_perm, spk_embeddings, spk_labels, spk_mask def forward(self, speaker_vectors, spk_mask, spk_labels): - # spk_mask ideally would be the speaker activty at frame level. Because WHAM speakers can be considered always two and active we fix this for now. + # spk_mask ideally would be the speaker activity at frame level. + # Because WHAM speakers can be considered always two and active we + # fix this for now. 
# mask with ones and zeros B, SRC, FRAMES if self.gaussian_reg: @@ -180,19 +176,3 @@ def forward(self, speaker_vectors, spk_mask, spk_labels): c = ClippedSDR(-30) a = torch.rand((2, 3, 200)) print(c(a, a)) - - - - - - - - - - - - - - - - diff --git a/egs/wham/WaveSplit/train.py b/egs/wham/WaveSplit/train.py index 30cc2cf46..163ae6a1a 100644 --- a/egs/wham/WaveSplit/train.py +++ b/egs/wham/WaveSplit/train.py @@ -34,17 +34,14 @@ parser.add_argument('--exp_dir', default='exp/tmp', help='Full path to save best validation model') -warnings.simplefilter("ignore", UserWarning) - class Wavesplit(pl.LightningModule): # redefinition - def __init__(self, train_loader, val_loader=None, scheduler=None, config=None): super().__init__() - # instantiation of stacks optimizers etc - # NOTE: I use separated encoders for speaker and sep stack as it is not specified in the paper... + # NOTE: I use separated encoders for speaker and sep stack + # as it is not specified in the paper... self.enc_spk, self.dec = make_enc_dec("free", 512, 16, 8) self.enc_sep = deepcopy(self.enc_spk) @@ -75,7 +72,6 @@ def forward(self, *args, **kwargs): Returns: :class:`torch.Tensor` """ - return self.model(*args, **kwargs) def common_step(self, batch, batch_nb): @@ -102,8 +98,13 @@ def common_step(self, batch, batch_nb): spk_vectors = self.spk_stack(tf_rep) B, src, embed, frames = spk_vectors.size() - # torch.ones ideally would be the speaker activty at frame level. Because WHAM speakers can be considered always two and active we fix this for now. - spk_loss, spk_vectors, oracle = self.spk_loss(spk_vectors, torch.ones((B, src, frames)).to(spk_vectors.device), spk_ids) + # torch.ones ideally would be the speaker activity at frame level. + # Because WHAM speakers can be considered always two and active we + # fix this for now. 
+ spk_loss, spk_vectors, oracle = self.spk_loss( + spk_vectors, torch.ones((B, src, frames)).to(spk_vectors.device), + spk_ids + ) tf_rep = self.enc_sep(inputs) B, n_filters, frames = tf_rep.size() tf_rep = tf_rep[:, None, ...].expand(-1, src, -1, -1).reshape(B*src, n_filters, frames) @@ -218,32 +219,11 @@ def train_dataloader(self): def val_dataloader(self): return self.val_loader - @pl.data_loader - def tng_dataloader(self): # pragma: no cover - """ Deprecated.""" - pass - def on_save_checkpoint(self, checkpoint): """ Overwrite if you want to save more things in the checkpoint.""" checkpoint['training_config'] = self.config return checkpoint - def on_batch_start(self, batch): - """ Overwrite if needed. Called by pytorch-lightning""" - pass - - def on_batch_end(self): - """ Overwrite if needed. Called by pytorch-lightning""" - pass - - def on_epoch_start(self): - """ Overwrite if needed. Called by pytorch-lightning""" - pass - - def on_epoch_end(self): - """ Overwrite if needed. Called by pytorch-lightning""" - pass - @staticmethod def none_to_string(dic): """ Converts `None` to ``'None'`` to be handled by torch summary writer. 
diff --git a/egs/wham/WaveSplit/wavesplit.py b/egs/wham/WaveSplit/wavesplit.py index 4a31e07e2..d2708b8e1 100644 --- a/egs/wham/WaveSplit/wavesplit.py +++ b/egs/wham/WaveSplit/wavesplit.py @@ -6,7 +6,6 @@ class Conv1DBlock(nn.Module): - def __init__(self, hid_chan, kernel_size, padding, dilation, norm_type="gLN"): super(Conv1DBlock, self).__init__() @@ -24,12 +23,11 @@ def forward(self, x): return self.out(x) -class SepConv1DBlock(nn.Module): +class SepConv1DBlock(nn.Module): def __init__(self, in_chan_spk_vec, hid_chan, kernel_size, padding, dilation, norm_type="gLN", use_FiLM=True): super(SepConv1DBlock, self).__init__() - self.use_FiLM = use_FiLM conv_norm = norms.get(norm_type) self.depth_conv1d = nn.Conv1d(hid_chan, hid_chan, kernel_size, @@ -61,7 +59,6 @@ def forward(self, x, spk_vec): class SpeakerStack(nn.Module): # basically this is plain conv-tasnet remove this in future releases - def __init__(self, in_chan, n_src, n_blocks=14, n_repeats=1, kernel_size=3, norm_type="gLN"): @@ -183,8 +180,3 @@ def get_config(self): 'norm_type': self.norm_type, } return config - - - - -