diff --git a/egs/wham/WaveSplit/README.md b/egs/wham/WaveSplit/README.md index 6995c7a9c..3e51a096a 100644 --- a/egs/wham/WaveSplit/README.md +++ b/egs/wham/WaveSplit/README.md @@ -1,17 +1,11 @@ ### WaveSplit -things currently not clear: ---- -- not clear if different encoders are used for separation and speaker stack. (from image in the paper it seems so) -- what is embedding dimension ? It seems 512 but it is not explicit in the paper -- mask used (sigmoid ?) -- when speakers in an example < sep stack outputs loss is simply masked or an embedding for silence is used ? (Probably masked) -- is VAD used in WSJ02MiX/ WHAM for determining speech activity at frame level ? Some files can have pauses of even one second -- loss right now is prone to go NaN especially if we don't take the mean after l2-distances computation. - ---- -structure: -- train.py contains training loop (nets instantiation lines 48-60, training loop lines 100- 116) -- losses.py wavesplit losses -- wavesplit.py sep and speaker stacks nets -- wavesplitwham.py dataset parsing \ No newline at end of file + we train on 1 sec now. + + tried with 256 embedding dimension. + + still does not work with oracle embeddings. + + not clear how in sep stack loss at every layer is computed ( is the same output layer used in all ?). + Also no mention in the paper about output layer and that first conv has no skip connection. + \ No newline at end of file diff --git a/egs/wham/WaveSplit/local/conf.yml b/egs/wham/WaveSplit/local/conf.yml index 090f0aa39..912366f48 100644 --- a/egs/wham/WaveSplit/local/conf.yml +++ b/egs/wham/WaveSplit/local/conf.yml @@ -23,7 +23,7 @@ training: num_workers: 4 half_lr: yes early_stop: yes - gradient_clipping: 5 + gradient_clipping: 5000 # Optim config optim: optimizer: adam @@ -38,4 +38,4 @@ data: nondefault_nsrc: sample_rate: 8000 mode: min - segment: 0.750 + segment: 1 \ No newline at end of file diff --git a/egs/wham/WaveSplit/local/preprocess_wham.py b/egs/wham/WaveSplit/local/preprocess_wham.py index 36a4e2bc9..69b8fd419 100644 --- a/egs/wham/WaveSplit/local/preprocess_wham.py +++ b/egs/wham/WaveSplit/local/preprocess_wham.py @@ -14,8 +14,8 @@ def preprocess_task(task, in_dir, out_dir): examples = [] for mix in mix_both: filename = mix.split("/")[-1] - spk1_id = filename.split("_")[0][:3] - spk2_id = filename.split("_")[2][:3] + spk1_id = filename.split("_")[0] + spk2_id = filename.split("_")[2] length = len(sf.SoundFile(mix)) noise = os.path.join(in_dir, "noise", filename) @@ -33,8 +33,8 @@ def preprocess_task(task, in_dir, out_dir): examples = [] for mix in mix_clean: filename = mix.split("/")[-1] - spk1_id = filename.split("_")[0][:3] - spk2_id = filename.split("_")[2][:3] + spk1_id = filename.split("_")[0] + spk2_id = filename.split("_")[2] length = len(sf.SoundFile(mix)) s1 = os.path.join(in_dir, "s1", filename) @@ -51,7 +51,7 @@ def preprocess_task(task, in_dir, out_dir): examples = [] for mix in mix_single: filename = mix.split("/")[-1] - spk1_id = filename.split("_")[0][:3] + spk1_id = filename.split("_")[0] length = len(sf.SoundFile(mix)) s1 = os.path.join(in_dir, "s1", filename) diff --git a/egs/wham/WaveSplit/losses.py b/egs/wham/WaveSplit/losses.py index 0f147d05e..9374f28cb 100644 --- a/egs/wham/WaveSplit/losses.py +++ b/egs/wham/WaveSplit/losses.py @@ -3,7 +3,8 @@ import numpy as np from torch.nn import functional as F from itertools import permutations -from asteroid.losses.sdr import MultiSrcNegSDR +from asteroid.losses.sdr import MultiSrcNegSDR, SingleSrcNegSDR +from asteroid.losses import PITLossWrapper, PairwiseNegSDR,pairwise_neg_sisdr import math @@ -12,7 +13,7 @@ class ClippedSDR(nn.Module): def __init__(self, clip_value=-30): super(ClippedSDR, self).__init__() - self.snr = MultiSrcNegSDR("snr") + self.snr = PITLossWrapper(pairwise_neg_sisdr) self.clip_value = float(clip_value) def forward(self, est_targets, targets): @@ -23,12 +24,9 @@ def forward(self, est_targets, targets): class SpeakerVectorLoss(nn.Module): def __init__(self, n_speakers, embed_dim=32, learnable_emb=True, loss_type="global", - weight=10, distance_reg=0.3, gaussian_reg=0.2, return_oracle=True): + weight=2, distance_reg=0.3, gaussian_reg=0.2, return_oracle=False): super(SpeakerVectorLoss, self).__init__() - - # not clear how embeddings are initialized. - self.learnable_emb = learnable_emb self.loss_type = loss_type self.weight = float(weight) @@ -38,36 +36,30 @@ def __init__(self, n_speakers, embed_dim=32, learnable_emb=True, loss_type="glob assert loss_type in ["distance", "global", "local"] - # I initialize embeddings to be on unit sphere as speaker stack uses euclidean normalization - - spk_emb = torch.rand((n_speakers, embed_dim)) - norms = torch.sum(spk_emb ** 2, -1, keepdim=True).sqrt() - spk_emb = spk_emb / norms # generate points on n-dimensional unit sphere + spk_emb = torch.eye(max(n_speakers, embed_dim)) # one-hot init works better according to Neil + spk_emb = spk_emb[:n_speakers, :embed_dim] if learnable_emb == True: self.spk_embeddings = nn.Parameter(spk_emb) else: self.register_buffer("spk_embeddings", spk_emb) - if loss_type != "dist": - self.alpha = nn.Parameter(torch.Tensor([1.])) # not clear how these are initialized... + if loss_type != "distance": + self.alpha = nn.Parameter(torch.Tensor([1.])) self.beta = nn.Parameter(torch.Tensor([0.])) - - ### losses go to NaN if I follow strictly the formulas maybe I am missing something... - @staticmethod def _l_dist_speaker(c_spk_vec_perm, spk_embeddings, spk_labels, spk_mask): utt_embeddings = spk_embeddings[spk_labels].unsqueeze(-1) * spk_mask.unsqueeze(2) c_spk = c_spk_vec_perm[:, 0] pair_dist = ((c_spk.unsqueeze(1) - c_spk_vec_perm)**2).sum(2) - pair_dist = pair_dist[:, 1:].sqrt() - distance = ((c_spk_vec_perm - utt_embeddings)**2).sum(2).sqrt() - return (distance + F.relu(1. - pair_dist).sum(1).unsqueeze(1)).sum(1) + pair_dist = pair_dist[:, 1:] + distance = ((c_spk_vec_perm - utt_embeddings)**2).sum(dim=(1,2)) + return distance + F.relu(1. - pair_dist).sum(dim=(1)) def _l_local_speaker(self, c_spk_vec_perm, spk_embeddings, spk_labels, spk_mask): - + raise NotImplemented utt_embeddings = spk_embeddings[spk_labels].unsqueeze(-1) * spk_mask.unsqueeze(2) alpha = torch.clamp(self.alpha, 1e-8) @@ -79,25 +71,23 @@ def _l_local_speaker(self, c_spk_vec_perm, spk_embeddings, spk_labels, spk_mask) return out.sum(1) def _l_global_speaker(self, c_spk_vec_perm, spk_embeddings, spk_labels, spk_mask): - + raise NotImplemented utt_embeddings = spk_embeddings[spk_labels].unsqueeze(-1) * spk_mask.unsqueeze(2) alpha = torch.clamp(self.alpha, 1e-8) - distance_utt = alpha*((c_spk_vec_perm - utt_embeddings)**2).sum(2).sqrt() + self.beta + distance_utt = alpha*((c_spk_vec_perm - utt_embeddings)**2).sum(2) + self.beta B, src, embed_dim, frames = c_spk_vec_perm.size() spk_embeddings = spk_embeddings.reshape(1, spk_embeddings.shape[0], embed_dim, 1).expand(B, -1, -1, frames) distances = alpha * ((c_spk_vec_perm.unsqueeze(1) - spk_embeddings.unsqueeze(2)) ** 2).sum(3).sqrt() + self.beta # exp normalize trick - with torch.no_grad(): - b = torch.max(distances, dim=1, keepdim=True)[0] - out = -distance_utt + b.squeeze(1) - torch.log(torch.exp(-distances + b).sum(1)) - return out.sum(1) + #with torch.no_grad(): + # b = torch.max(distances, dim=1, keepdim=True)[0] + #out = -distance_utt + b.squeeze(1) - torch.log(torch.exp(-distances + b).sum(1)) + #return out.sum(1) - def forward(self, speaker_vectors, spk_mask, spk_labels): - # spk_mask ideally would be the speaker activty at frame level. Because WHAM speakers can be considered always two and active we fix this for now. - # mask with ones and zeros B, SRC, FRAMES + def forward(self, speaker_vectors, spk_mask, spk_labels): if self.gaussian_reg: noise = torch.randn(self.spk_embeddings.size(), device=speaker_vectors.device)*math.sqrt(self.gaussian_reg) @@ -105,16 +95,13 @@ def forward(self, speaker_vectors, spk_mask, spk_labels): else: spk_embeddings = self.spk_embeddings - if self.learnable_emb or self.gaussian_reg: # re project on unit sphere after noise has been applied and before computing the distance reg + if self.learnable_emb or self.gaussian_reg: # re project on unit sphere spk_embeddings = spk_embeddings / torch.sum(spk_embeddings ** 2, -1, keepdim=True).sqrt() if self.distance_reg: - pairwise_dist = ((spk_embeddings.unsqueeze(0) - spk_embeddings.unsqueeze(1))**2).sum(-1) - idx = torch.arange(0, pairwise_dist.shape[0]) - pairwise_dist[idx, idx] = np.inf # masking with itself - pairwise_dist = pairwise_dist.sqrt() + pairwise_dist = (torch.abs(spk_embeddings.unsqueeze(0) - spk_embeddings.unsqueeze(1))).mean(-1).fill_diagonal_(np.inf) distance_reg = -torch.sum(torch.min(torch.log(pairwise_dist), dim=-1)[0]) # speaker vectors B, n_src, dim, frames @@ -145,10 +132,8 @@ def forward(self, speaker_vectors, spk_mask, spk_labels): min_loss_perm = min_loss_perm.transpose(0, 1).reshape(B, n_src, 1, frames).expand(-1, -1, embed_dim, -1) # tot_loss - spk_loss = self.weight*min_loss.mean() if self.distance_reg: - spk_loss += self.distance_reg*distance_reg reordered_sources = torch.gather(speaker_vectors, dim=1, index=min_loss_perm) @@ -160,23 +145,24 @@ def forward(self, speaker_vectors, spk_mask, spk_labels): if __name__ == "__main__": + n_speakers = 101 + emb_speaker = 256 # testing exp normalize average - distances = torch.ones((1, 101, 4000))*99 - with torch.no_grad(): - b = torch.max(distances, dim=1, keepdim=True)[0] - out = b.squeeze(1) - torch.log(torch.exp(-distances + b).sum(1)) - out2 = - torch.log(torch.exp(-distances).sum(1)) + #distances = torch.ones((1, 101, 4000)) + #with torch.no_grad(): + # b = torch.max(distances, dim=1, keepdim=True)[0] + #out = b.squeeze(1) - torch.log(torch.exp(-distances + b).sum(1)) + #out2 = - torch.log(torch.exp(-distances).sum(1)) - loss_spk = SpeakerVectorLoss(1000, 32, loss_type="distance") # 1000 speakers in training set + loss_spk = SpeakerVectorLoss(n_speakers, emb_speaker, loss_type="global") - speaker_vectors = torch.rand(2, 3, 32, 200) + speaker_vectors = torch.rand(2, 3, emb_speaker, 200) speaker_labels = torch.from_numpy(np.array([[1, 2, 0], [5, 2, 10]])) speaker_mask = torch.randint(0, 2, (2, 3, 200)) # silence where there are no speakers actually thi is test speaker_mask[:, -1, :] = speaker_mask[:, -1, :]*0 loss_spk(speaker_vectors, speaker_mask, speaker_labels) - c = ClippedSDR(-30) a = torch.rand((2, 3, 200)) print(c(a, a)) diff --git a/egs/wham/WaveSplit/run.sh b/egs/wham/WaveSplit/run.sh index b068011ed..74483eb14 100755 --- a/egs/wham/WaveSplit/run.sh +++ b/egs/wham/WaveSplit/run.sh @@ -42,10 +42,8 @@ mode=min nondefault_src= # If you want to train a network with 3 output streams for example. # Training -batch_size=1 -num_workers=8 -kernel_size=16 -stride=8 +batch_size=4 +num_workers=4 #optimizer=adam lr=0.001 epochs=400 @@ -134,8 +132,6 @@ if [[ $stage -le 3 ]]; then --epochs $epochs \ --batch_size $batch_size \ --num_workers $num_workers \ - --kernel_size $kernel_size \ - --stride $stride \ --exp_dir ${expdir}/ | tee logs/train_${tag}.log fi diff --git a/egs/wham/WaveSplit/train.py b/egs/wham/WaveSplit/train.py index 30cc2cf46..41cb645b7 100644 --- a/egs/wham/WaveSplit/train.py +++ b/egs/wham/WaveSplit/train.py @@ -14,6 +14,7 @@ from asteroid.engine.system import System from asteroid.losses import PITLossWrapper, pairwise_neg_sisdr + from losses import SpeakerVectorLoss, ClippedSDR from wavesplit import SpeakerStack, SeparationStack from asteroid.filterbanks import make_enc_dec @@ -39,24 +40,19 @@ class Wavesplit(pl.LightningModule): # redefinition - def __init__(self, train_loader, + def __init__(self, spk_stack, sep_stack, optimizer, spk_loss, sep_loss, train_loader, val_loader=None, scheduler=None, config=None): super().__init__() - # instantiation of stacks optimizers etc - # NOTE: I use separated encoders for speaker and sep stack as it is not specified in the paper... - - self.enc_spk, self.dec = make_enc_dec("free", 512, 16, 8) - self.enc_sep = deepcopy(self.enc_spk) - self.spk_stack = SpeakerStack(512, 2, 14, 1) - self.sep_stack = SeparationStack(512, 8, 3, mask_act="sigmoid") # original is 10 4 but I have not enought GPU mem + #self.spk_stack = SpeakerStack(256, 2, 1, 1) - self.spk_loss = SpeakerVectorLoss(101, 512, True, "global") # 512 spk embedding - self.sep_loss = ClippedSDR(-30) - params = list(self.enc_spk.parameters()) + list(self.enc_sep.parameters()) + list(self.dec.parameters()) + \ - list(self.spk_stack.parameters()) + list(self.sep_stack.parameters()) + list(self.spk_loss.parameters()) - self.optimizer = torch.optim.Adam(params, lr=0.002) # optimizer i think also is not specified i use adam + #self.spk_loss = SpeakerVectorLoss(101, 256, False, "distance", 10) + self.spk_stack = spk_stack + self.sep_stack = sep_stack + self.optimizer = optimizer + self.sep_loss = sep_loss + self.spk_loss = spk_loss self.train_loader = train_loader self.val_loader = val_loader @@ -69,6 +65,13 @@ def __init__(self, train_loader, # See https://github.com/pytorch/pytorch/issues/33140 self.hparams = Namespace(**self.none_to_string(flatten_dict(config))) + #n_speakers = 100 + #embed_dim = 128 + #spk_emb = torch.eye(max(n_speakers, embed_dim)) # one-hot init works better according to Neil + #spk_emb = spk_emb[:n_speakers, :embed_dim] + + #self.oracle = spk_emb.cuda() + def forward(self, *args, **kwargs): """ Applies forward pass of the model. @@ -98,27 +101,26 @@ def common_step(self, batch, batch_nb): `training_step` and `validation_step` instead. """ inputs, targets, spk_ids = batch - tf_rep = self.enc_spk(inputs) - spk_vectors = self.spk_stack(tf_rep) - B, src, embed, frames = spk_vectors.size() - - # torch.ones ideally would be the speaker activty at frame level. Because WHAM speakers can be considered always two and active we fix this for now. - spk_loss, spk_vectors, oracle = self.spk_loss(spk_vectors, torch.ones((B, src, frames)).to(spk_vectors.device), spk_ids) - tf_rep = self.enc_sep(inputs) - B, n_filters, frames = tf_rep.size() - tf_rep = tf_rep[:, None, ...].expand(-1, src, -1, -1).reshape(B*src, n_filters, frames) - masks = self.sep_stack(tf_rep, spk_vectors.reshape(B*src, embed, frames)) + spk_embed = self.spk_stack(inputs) - masked = tf_rep*masks - masked = masked.reshape(B, src, n_filters, frames) + spk_loss, reordered_embed = self.spk_loss(spk_embed, torch.ones((spk_embed.shape[0], + spk_embed.shape[1],spk_embed.shape[-1])).to(spk_embed.device), spk_ids) + reordered_embed = reordered_embed.mean(-1) - masked = self.pad_output_to_inp(self.dec(masked), inputs) - - sep_loss = self.sep_loss(masked, targets).mean() + #reordered_embed = self.oracle[spk_ids] + b, n_spk, spk_vec_size = reordered_embed.size() + separated = self.sep_stack(inputs, torch.cat((reordered_embed[:, 0], reordered_embed[:, 1]), 1)) + sep_loss = 0 + for o in separated: + o = self.pad_output_to_inp(o, inputs) + last = self.sep_loss(o, targets).mean() + sep_loss += last + spk_loss = sep_loss loss = sep_loss + spk_loss - return loss, spk_loss, sep_loss + + return loss, spk_loss, last @staticmethod def pad_output_to_inp(output, inp): @@ -259,7 +261,6 @@ def none_to_string(dic): dic[k] = str(v) return dic - def main(conf): train_set = WaveSplitWhamDataset(conf['data']['train_dir'], conf['data']['task'], @@ -273,18 +274,27 @@ def main(conf): batch_size=conf['training']['batch_size'], num_workers=conf['training']['num_workers'], drop_last=True) - val_loader = DataLoader(val_set, shuffle=True, + val_loader = DataLoader(val_set, shuffle=False, batch_size=conf['training']['batch_size'], num_workers=conf['training']['num_workers'], drop_last=True) # Update number of source values (It depends on the task) conf['masknet'].update({'n_src': train_set.n_src}) - + spk_stack = SpeakerStack(2, 256) # inner dim is 256 instead of 512 from paper to spare mem 13 layers as in the paper. + sep_stack = SeparationStack(2, 256, 512, 10, 1) # 40 layers. # Define model and optimizer in a local function (defined in the recipe). # Two advantages to this : re-instantiating the model and optimizer # for retraining and evaluating is straight-forward. # Define scheduler - + spk_loss = SpeakerVectorLoss(100, 256, loss_type="distance") # 100 speakers in WHAM dev and train, 256 embed dim + sep_loss = ClippedSDR(-30) + + params = list(spk_stack.parameters()) + list(sep_stack.parameters()) + list(spk_loss.parameters()) + optimizer = torch.optim.Adam(params, lr=0.003) + scheduler = None + if conf['training']['half_lr']: + scheduler = ReduceLROnPlateau(optimizer=optimizer, factor=0.5, + patience=5) # Just after instantiating, save the args. Easy loading in the future. exp_dir = conf['main_args']['exp_dir'] os.makedirs(exp_dir, exist_ok=True) @@ -292,8 +302,7 @@ def main(conf): with open(conf_path, 'w') as outfile: yaml.safe_dump(conf, outfile) - - system = Wavesplit(train_loader, val_loader, conf) + system = Wavesplit(spk_stack, sep_stack, optimizer, spk_loss, sep_loss, train_loader, val_loader, scheduler, conf) # Define callbacks checkpoint_dir = os.path.join(exp_dir, 'checkpoints/') checkpoint = ModelCheckpoint(checkpoint_dir, monitor='val_loss', @@ -312,8 +321,7 @@ def main(conf): early_stop_callback=early_stopping, default_save_path=exp_dir, gpus=conf['main_args']['gpus'], - distributed_backend='dp', - gradient_clip_val=conf['training']["gradient_clipping"]) + ) trainer.fit(system) with open(os.path.join(exp_dir, "best_k_models.json"), "w") as f: diff --git a/egs/wham/WaveSplit/wavesplit.py b/egs/wham/WaveSplit/wavesplit.py index 4a31e07e2..9ca172dc7 100644 --- a/egs/wham/WaveSplit/wavesplit.py +++ b/egs/wham/WaveSplit/wavesplit.py @@ -7,49 +7,47 @@ class Conv1DBlock(nn.Module): - def __init__(self, hid_chan, kernel_size, padding, + def __init__(self, in_chan, hid_chan, kernel_size, padding, dilation, norm_type="gLN"): super(Conv1DBlock, self).__init__() conv_norm = norms.get(norm_type) - depth_conv1d = nn.Conv1d(hid_chan, hid_chan, kernel_size, - padding=padding, dilation=dilation, - groups=hid_chan) + depth_conv1d = nn.Conv1d(in_chan, hid_chan, kernel_size, + padding=padding, dilation=dilation) self.out = nn.Sequential(depth_conv1d, nn.PReLU(), conv_norm(hid_chan)) - def forward(self, x): """ Input shape [batch, feats, seq]""" return self.out(x) + class SepConv1DBlock(nn.Module): - def __init__(self, in_chan_spk_vec, hid_chan, kernel_size, padding, + def __init__(self, in_chan, hid_chan, spk_vec_chan, kernel_size, padding, dilation, norm_type="gLN", use_FiLM=True): super(SepConv1DBlock, self).__init__() self.use_FiLM = use_FiLM conv_norm = norms.get(norm_type) - self.depth_conv1d = nn.Conv1d(hid_chan, hid_chan, kernel_size, - padding=padding, dilation=dilation, - groups=hid_chan) + self.depth_conv1d = nn.Conv1d(in_chan, hid_chan, kernel_size, + padding=padding, dilation=dilation) self.out = nn.Sequential(nn.PReLU(), conv_norm(hid_chan)) # FiLM conditioning if self.use_FiLM: - self.mul_lin = nn.Linear(in_chan_spk_vec, hid_chan) - self.add_lin = nn.Linear(in_chan_spk_vec, hid_chan) + self.mul_lin = nn.Linear(spk_vec_chan, hid_chan) + self.add_lin = nn.Linear(spk_vec_chan, hid_chan) def apply_conditioning(self, spk_vec, squeezed): - bias = self.add_lin(spk_vec.transpose(1, -1)).transpose(1, -1) + bias = self.add_lin(spk_vec) if self.use_FiLM: - mul = self.mul_lin(spk_vec.transpose(1, -1)).transpose(1, -1) - return mul*squeezed + bias + mul = self.mul_lin(spk_vec) + return mul.unsqueeze(-1)*squeezed + bias.unsqueeze(-1) else: - return squeezed + bias + return squeezed + bias.unsqueeze(-1) def forward(self, x, spk_vec): """ Input shape [batch, feats, seq]""" @@ -62,31 +60,32 @@ def forward(self, x, spk_vec): class SpeakerStack(nn.Module): # basically this is plain conv-tasnet remove this in future releases - def __init__(self, in_chan, n_src, n_blocks=14, n_repeats=1, + def __init__(self, n_src, embed_dim, n_blocks=14, n_repeats=1, kernel_size=3, norm_type="gLN"): super(SpeakerStack, self).__init__() - self.in_chan = in_chan + self.embed_dim = embed_dim self.n_src = n_src self.n_blocks = n_blocks self.n_repeats = n_repeats self.kernel_size = kernel_size self.norm_type = norm_type - #layer_norm = norms.get(norm_type)(in_chan) - #bottleneck_conv = nn.Conv1d(in_chan, bn_chan, 1) - #self.bottleneck = nn.Sequential(layer_norm, bottleneck_conv) # Succession of Conv1DBlock with exponentially increasing dilation. self.TCN = nn.ModuleList() for r in range(n_repeats): for x in range(n_blocks): padding = (kernel_size - 1) * 2 ** x // 2 - self.TCN.append(Conv1DBlock(in_chan, #TODO ask if also skip connections are used (probably not) + if r == 0 and x == 0: + in_chan = 1 + else: + in_chan = embed_dim + self.TCN.append(Conv1DBlock(in_chan, embed_dim, kernel_size, padding=padding, dilation=2 ** x, norm_type=norm_type)) - mask_conv = nn.Conv1d(in_chan, n_src * in_chan, 1) - self.mask_net = nn.Sequential(nn.PReLU(), mask_conv) + mask_conv = nn.Conv1d(embed_dim, n_src * embed_dim, 1) + self.mask_net = nn.Sequential(mask_conv) def forward(self, mixture_w): """ @@ -98,42 +97,35 @@ def forward(self, mixture_w): :class:`torch.Tensor`: estimated mask of shape [batch, n_src, n_filters, n_frames] """ - batch, n_filters, n_frames = mixture_w.size() - output = mixture_w - #output = self.bottleneck(mixture_w) + batch, n_frames = mixture_w.size() + output = mixture_w.unsqueeze(1) for i in range(len(self.TCN)): - residual = self.TCN[i](output) - output = output + residual + if i == 0: + output = self.TCN[i](output) + else: + residual = self.TCN[i](output) + output = output + residual emb = self.mask_net(output) - emb = emb.view(batch, self.n_src, self.in_chan, n_frames) + + emb = emb.view(batch, self.n_src, self.embed_dim, n_frames) emb = emb / torch.sqrt(torch.sum(emb**2, 2, keepdim=True)) return emb - def get_config(self): - config = { - 'in_chan': self.in_chan, - 'kernel_size': self.kernel_size, - 'n_blocks': self.n_blocks, - 'n_repeats': self.n_repeats, - 'n_src': self.n_src, - 'norm_type': self.norm_type, - } - return config - class SeparationStack(nn.Module): # basically this is plain conv-tasnet remove this in future releases - def __init__(self, in_chan, n_blocks=10, n_repeats=4, + def __init__(self, src, embed_dim=256, spk_vec_dim=512, n_blocks=10, n_repeats=4, kernel_size=3, - norm_type="gLN", mask_act="sigmoid"): + norm_type="gLN"): super(SeparationStack, self).__init__() - self.in_chan = in_chan self.n_blocks = n_blocks self.n_repeats = n_repeats self.kernel_size = kernel_size self.norm_type = norm_type + self.src = src + self.embed_dim = embed_dim # layer_norm = norms.get(norm_type)(in_chan) # bottleneck_conv = nn.Conv1d(in_chan, bn_chan, 1) @@ -142,19 +134,15 @@ def __init__(self, in_chan, n_blocks=10, n_repeats=4, self.TCN = nn.ModuleList() for r in range(n_repeats): for x in range(n_blocks): + if r == 0 and x == 0: + in_chan = 1 + else: + in_chan = embed_dim padding = (kernel_size - 1) * 2 ** x // 2 - self.TCN.append(SepConv1DBlock(in_chan, in_chan, # TODO ask if also skip connections are used (probably not) + self.TCN.append(SepConv1DBlock(in_chan, embed_dim, spk_vec_dim, kernel_size, padding=padding, dilation=2 ** x, norm_type=norm_type)) - mask_conv = nn.Conv1d(in_chan, in_chan, 1) - self.mask_net = nn.Sequential(nn.PReLU(), mask_conv) - - # Get activation function. - mask_nl_class = activations.get(mask_act) - if has_arg(mask_nl_class, 'dim'): - self.output_act = mask_nl_class(dim=1) - else: - self.output_act = mask_nl_class() + self.out = nn.Conv1d(embed_dim, 2, 1) def forward(self, mixture_w, spk_vectors): """ @@ -166,25 +154,25 @@ def forward(self, mixture_w, spk_vectors): :class:`torch.Tensor`: estimated mask of shape [batch, n_src, n_filters, n_frames] """ - output = mixture_w + output = mixture_w.unsqueeze(1) + outputs = [] # output = self.bottleneck(mixture_w) for i in range(len(self.TCN)): - residual = self.TCN[i](output, spk_vectors) - output = output + residual - mask = self.mask_net(output) - return self.output_act(mask) - - def get_config(self): - config = { - 'in_chan': self.in_chan, - 'kernel_size': self.kernel_size, - 'n_blocks': self.n_blocks, - 'n_repeats': self.n_repeats, - 'norm_type': self.norm_type, - } - return config - - - + if i == 0: + output = self.TCN[i](output, spk_vectors) + outputs.append(output) + else: + residual = self.TCN[i](output, spk_vectors) + output = output + residual + outputs.append(output) + + return [self.out(o) for o in outputs] + + +if __name__ == "__main__": + sep = SeparationStack(2, 256, 512, 10, 3, kernel_size=3) + wave = torch.rand((2, 16000)) + spk_vectors = torch.rand((2, 2, 256)) + out = sep(wave, spk_vectors.reshape(2, 2*256)) diff --git a/egs/wham/WaveSplit/wavesplitwham.py b/egs/wham/WaveSplit/wavesplitwham.py index e5985654d..d4db3a9ce 100644 --- a/egs/wham/WaveSplit/wavesplitwham.py +++ b/egs/wham/WaveSplit/wavesplitwham.py @@ -97,8 +97,7 @@ def __init__(self, json_dir, task, sample_rate=8000, segment=4.0, speakers = set() for ex in self.examples: for spk in ex["spk_id"]: - speakers.add(spk) - + speakers.add(spk[:3]) print("Total number of speakers {}".format(len(list(speakers)))) @@ -112,10 +111,9 @@ def __init__(self, json_dir, task, sample_rate=8000, segment=4.0, for ex in self.examples: new = [] for spk in ex["spk_id"]: - new.append(spk2indx[spk]) + new.append(spk2indx[spk[:3]]) ex["spk_id"] = new - def __len__(self): return len(self.examples)