Skip to content

Commit

Permalink
msd + mpd black format
Browse files Browse the repository at this point in the history
  • Loading branch information
hansheng-zhang committed Jul 4, 2024
1 parent 0399414 commit 1f8c125
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 6 deletions.
9 changes: 6 additions & 3 deletions models/vocoders/gan/discriminator/mpd.py
Original file line number Diff line number Diff line change
Expand Up @@ -269,6 +269,7 @@ def forward(self, y, y_hat):

return outputs


class DiscriminatorP_JETS(torch.nn.Module):
def __init__(self, period, kernel_size=5, stride=3, use_spectral_norm=False):
super(DiscriminatorP_JETS, self).__init__()
Expand Down Expand Up @@ -347,6 +348,7 @@ def forward(self, x):

return x, fmap


class MultiPeriodDiscriminator_JETS(torch.nn.Module):
def __init__(self, use_spectral_norm=False):
super(MultiPeriodDiscriminator_JETS, self).__init__()
Expand All @@ -367,10 +369,11 @@ def forward(self, y):

return y_d_rs, fmap_rs

# JETS Multi-scale Multi-period discriminator module.

# JETS Multi-scale Multi-period discriminator module.
class MultiScaleMultiPeriodDiscriminator(torch.nn.Module):
"""HiFi-GAN multi-scale + multi-period discriminator module."""

def __init__(self, use_spectral_norm=False):
super(MultiScaleMultiPeriodDiscriminator, self).__init__()

Expand All @@ -385,4 +388,4 @@ def forward(self, y):
# mpd_outs = self.mpd(y, y_hat)
return msd_outs_d_rs + mpd_outs_d_rs
# ground_truth, generated
# return msd_outs + mpd_outs
# return msd_outs + mpd_outs
6 changes: 3 additions & 3 deletions models/vocoders/gan/discriminator/msd.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,7 @@ def forward(self, y, y_hat):

return y_d_rs, y_d_gs, fmap_rs, fmap_gs


class MultiScaleDiscriminator_JETS(nn.Module):
def __init__(self):
super(MultiScaleDiscriminator_JETS, self).__init__()
Expand All @@ -104,7 +105,7 @@ def __init__(self):
)

def forward(self, y):
y_d_rs = [] # p, y, groud-truth
y_d_rs = [] # p, y, groud-truth
fmap_rs = []

for i, d in enumerate(self.discriminators):
Expand All @@ -114,6 +115,5 @@ def forward(self, y):
y_d_rs.append(y_d_r)
fmap_rs.append(fmap_r)


return y_d_rs, fmap_rs
# fmap_rs is real, fmap_gs is generated.
# fmap_rs is real, fmap_gs is generated.

0 comments on commit 1f8c125

Please sign in to comment.