Skip to content

Commit

Permalink
added more comments
Browse files Browse the repository at this point in the history
  • Loading branch information
popcornell committed Apr 7, 2020
1 parent c9dd0ed commit 8d1ca2e
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 0 deletions.
3 changes: 3 additions & 0 deletions egs/wham/WaveSplit/losses.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,9 @@ def _l_global_speaker(self, c_spk_vec_perm, spk_embeddings, spk_labels, spk_mask

def forward(self, speaker_vectors, spk_mask, spk_labels):

# spk_mask would ideally be the frame-level speaker activity. Because WHAM mixtures can be considered to always contain two active speakers, we fix it for now.
# It is a mask of ones and zeros with shape (B, SRC, FRAMES).

if self.gaussian_reg:
noise = torch.randn(self.spk_embeddings.size(), device=speaker_vectors.device)*math.sqrt(self.gaussian_reg)
spk_embeddings = self.spk_embeddings + noise
Expand Down
2 changes: 2 additions & 0 deletions egs/wham/WaveSplit/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,8 @@ def common_step(self, batch, batch_nb):
tf_rep = self.enc_spk(inputs)
spk_vectors = self.spk_stack(tf_rep)
B, src, embed, frames = spk_vectors.size()

# The torch.ones mask would ideally be the frame-level speaker activity. Because WHAM mixtures can be considered to always contain two active speakers, we fix it to all ones for now.
spk_loss, spk_vectors, oracle = self.spk_loss(spk_vectors, torch.ones((B, src, frames)).to(spk_vectors.device), spk_ids)
tf_rep = self.enc_sep(inputs)
B, n_filters, frames = tf_rep.size()
Expand Down

0 comments on commit 8d1ca2e

Please sign in to comment.