diff --git a/models/vocoders/gan/discriminator/mpd.py b/models/vocoders/gan/discriminator/mpd.py index a2d7aa36..51e4e792 100644 --- a/models/vocoders/gan/discriminator/mpd.py +++ b/models/vocoders/gan/discriminator/mpd.py @@ -269,6 +269,7 @@ def forward(self, y, y_hat): return outputs + class DiscriminatorP_JETS(torch.nn.Module): def __init__(self, period, kernel_size=5, stride=3, use_spectral_norm=False): super(DiscriminatorP_JETS, self).__init__() @@ -347,6 +348,7 @@ def forward(self, x): return x, fmap + class MultiPeriodDiscriminator_JETS(torch.nn.Module): def __init__(self, use_spectral_norm=False): super(MultiPeriodDiscriminator_JETS, self).__init__() @@ -367,10 +369,11 @@ def forward(self, y): return y_d_rs, fmap_rs -# JETS Multi-scale Multi-period discriminator module. + +# JETS Multi-scale Multi-period discriminator module. class MultiScaleMultiPeriodDiscriminator(torch.nn.Module): """HiFi-GAN multi-scale + multi-period discriminator module.""" - + def __init__(self, use_spectral_norm=False): super(MultiScaleMultiPeriodDiscriminator, self).__init__() @@ -385,4 +388,4 @@ def forward(self, y): # mpd_outs = self.mpd(y, y_hat) return msd_outs_d_rs + mpd_outs_d_rs # ground_truth, generated - # return msd_outs + mpd_outs \ No newline at end of file + # return msd_outs + mpd_outs diff --git a/models/vocoders/gan/discriminator/msd.py b/models/vocoders/gan/discriminator/msd.py index fbac562f..0c13c27d 100644 --- a/models/vocoders/gan/discriminator/msd.py +++ b/models/vocoders/gan/discriminator/msd.py @@ -87,6 +87,7 @@ def forward(self, y, y_hat): return y_d_rs, y_d_gs, fmap_rs, fmap_gs + class MultiScaleDiscriminator_JETS(nn.Module): def __init__(self): super(MultiScaleDiscriminator_JETS, self).__init__() @@ -104,7 +105,7 @@ def __init__(self): ) def forward(self, y): - y_d_rs = [] # p, y, groud-truth + y_d_rs = [] # p, y, groud-truth fmap_rs = [] for i, d in enumerate(self.discriminators): @@ -114,6 +115,5 @@ def forward(self, y): y_d_rs.append(y_d_r) fmap_rs.append(fmap_r) - return y_d_rs, fmap_rs - # fmap_rs is real, fmap_gs is generated. + # fmap_rs is real, fmap_gs is generated.