diff --git a/.gitignore b/.gitignore
index 125a113..70addff 100644
--- a/.gitignore
+++ b/.gitignore
@@ -4,6 +4,7 @@
 # configuration
 config/*
 !config/default.yaml
+temp-restore.yaml
 
 # logs, checkpoints
 chkpt/
diff --git a/hubconf.py b/hubconf.py
index f8a76a6..1031f6b 100644
--- a/hubconf.py
+++ b/hubconf.py
@@ -3,14 +3,14 @@
 from model.generator import Generator
 
 model_params = {
-    'nvidia_tacotron2_LJ11_epoch3200': {
+    'nvidia_tacotron2_LJ11_epoch6400': {
         'mel_channel': 80,
-        'model_url': 'https://github.com/seungwonpark/melgan/releases/download/v0.2-alpha/nvidia_tacotron2_LJ11_epoch3200_v02.pt',
+        'model_url': 'https://github.com/seungwonpark/melgan/releases/download/v0.3-alpha/nvidia_tacotron2_LJ11_epoch6400.pt',
     },
 }
 
 
-def melgan(model_name='nvidia_tacotron2_LJ11_epoch3200', pretrained=True, progress=True):
+def melgan(model_name='nvidia_tacotron2_LJ11_epoch6400', pretrained=True, progress=True):
     params = model_params[model_name]
     model = Generator(params['mel_channel'])
 
diff --git a/model/discriminator.py b/model/discriminator.py
index 46240c7..63d2736 100644
--- a/model/discriminator.py
+++ b/model/discriminator.py
@@ -9,28 +9,29 @@ def __init__(self):
 
         self.discriminator = nn.ModuleList([
             nn.Sequential(
-                nn.utils.weight_norm(nn.Conv1d(1, 16, kernel_size=15, stride=1, padding=7)),
-                nn.LeakyReLU(),
+                nn.ReflectionPad1d(7),
+                nn.utils.weight_norm(nn.Conv1d(1, 16, kernel_size=15, stride=1)),
+                nn.LeakyReLU(0.2, inplace=True),
             ),
             nn.Sequential(
                 nn.utils.weight_norm(nn.Conv1d(16, 64, kernel_size=41, stride=4, padding=20, groups=4)),
-                nn.LeakyReLU(),
+                nn.LeakyReLU(0.2, inplace=True),
             ),
             nn.Sequential(
                 nn.utils.weight_norm(nn.Conv1d(64, 256, kernel_size=41, stride=4, padding=20, groups=16)),
-                nn.LeakyReLU(),
+                nn.LeakyReLU(0.2, inplace=True),
             ),
             nn.Sequential(
                 nn.utils.weight_norm(nn.Conv1d(256, 1024, kernel_size=41, stride=4, padding=20, groups=64)),
-                nn.LeakyReLU(),
+                nn.LeakyReLU(0.2, inplace=True),
             ),
             nn.Sequential(
                 nn.utils.weight_norm(nn.Conv1d(1024, 1024, kernel_size=41, stride=4, padding=20, groups=256)),
-                nn.LeakyReLU(),
+                nn.LeakyReLU(0.2, inplace=True),
             ),
             nn.Sequential(
                 nn.utils.weight_norm(nn.Conv1d(1024, 1024, kernel_size=5, stride=1, padding=2)),
-                nn.LeakyReLU(),
+                nn.LeakyReLU(0.2, inplace=True),
             ),
             nn.utils.weight_norm(nn.Conv1d(1024, 1, kernel_size=3, stride=1, padding=1)),
         ])
@@ -58,3 +59,6 @@ def forward(self, x):
     for feat in features:
         print(feat.shape)
     print(score.shape)
+
+    pytorch_total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
+    print(pytorch_total_params)
\ No newline at end of file
diff --git a/model/generator.py b/model/generator.py
index 8f1c34d..74f7155 100644
--- a/model/generator.py
+++ b/model/generator.py
@@ -3,7 +3,7 @@
 import torch.nn.functional as F
 
 from .res_stack import ResStack
-#from res_stack import ResStack
+# from res_stack import ResStack
 
 MAX_WAV_VALUE = 32768.0
 
@@ -14,30 +14,32 @@ def __init__(self, mel_channel):
         self.mel_channel = mel_channel
 
         self.generator = nn.Sequential(
-            nn.utils.weight_norm(nn.Conv1d(mel_channel, 512, kernel_size=7, stride=1, padding=3)),
+            nn.ReflectionPad1d(3),
+            nn.utils.weight_norm(nn.Conv1d(mel_channel, 512, kernel_size=7, stride=1)),
 
-            nn.LeakyReLU(),
+            nn.LeakyReLU(0.2),
             nn.utils.weight_norm(nn.ConvTranspose1d(512, 256, kernel_size=16, stride=8, padding=4)),
 
             ResStack(256),
 
-            nn.LeakyReLU(),
+            nn.LeakyReLU(0.2),
             nn.utils.weight_norm(nn.ConvTranspose1d(256, 128, kernel_size=16, stride=8, padding=4)),
 
             ResStack(128),
 
-            nn.LeakyReLU(),
+            nn.LeakyReLU(0.2),
             nn.utils.weight_norm(nn.ConvTranspose1d(128, 64, kernel_size=4, stride=2, padding=1)),
 
             ResStack(64),
 
-            nn.LeakyReLU(),
+            nn.LeakyReLU(0.2),
             nn.utils.weight_norm(nn.ConvTranspose1d(64, 32, kernel_size=4, stride=2, padding=1)),
 
             ResStack(32),
 
-            nn.LeakyReLU(),
-            nn.utils.weight_norm(nn.Conv1d(32, 1, kernel_size=7, stride=1, padding=3)),
+            nn.LeakyReLU(0.2),
+            nn.ReflectionPad1d(3),
+            nn.utils.weight_norm(nn.Conv1d(32, 1, kernel_size=7, stride=1)),
             nn.Tanh(),
         )
 
@@ -84,11 +86,14 @@ def inference(self, mel):
     from res_stack import ResStack
 '''
 if __name__ == '__main__':
-    model = Generator(7)
+    model = Generator(80)
 
-    x = torch.randn(3, 7, 10)
+    x = torch.randn(3, 80, 10)
     print(x.shape)
 
     y = model(x)
     print(y.shape)
     assert y.shape == torch.Size([3, 1, 2560])
+
+    pytorch_total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
+    print(pytorch_total_params)
\ No newline at end of file
diff --git a/model/multiscale.py b/model/multiscale.py
index 4393be2..640b59e 100644
--- a/model/multiscale.py
+++ b/model/multiscale.py
@@ -16,7 +16,7 @@ def __init__(self):
 
         self.pooling = nn.ModuleList(
             [Identity()] +
-            [nn.AvgPool1d(kernel_size=4, stride=2, padding=2) for _ in range(1, 3)]
+            [nn.AvgPool1d(kernel_size=4, stride=2, padding=1, count_include_pad=False) for _ in range(1, 3)]
         )
 
     def forward(self, x):
diff --git a/model/res_stack.py b/model/res_stack.py
index 37d9fc3..0ccf175 100644
--- a/model/res_stack.py
+++ b/model/res_stack.py
@@ -8,22 +8,29 @@ class ResStack(nn.Module):
     def __init__(self, channel):
         super(ResStack, self).__init__()
 
-        self.layers = nn.ModuleList([
+        self.blocks = nn.ModuleList([
             nn.Sequential(
-                nn.LeakyReLU(),
-                nn.utils.weight_norm(nn.Conv1d(channel, channel, kernel_size=3, dilation=3**i, padding=3**i)),
-                nn.LeakyReLU(),
-                nn.utils.weight_norm(nn.Conv1d(channel, channel, kernel_size=3, dilation=1, padding=1)),
+                nn.LeakyReLU(0.2),
+                nn.ReflectionPad1d(3**i),
+                nn.utils.weight_norm(nn.Conv1d(channel, channel, kernel_size=3, dilation=3**i)),
+                nn.LeakyReLU(0.2),
+                nn.utils.weight_norm(nn.Conv1d(channel, channel, kernel_size=1)),
             )
             for i in range(3)
         ])
 
+        self.shortcuts = nn.ModuleList([
+            nn.utils.weight_norm(nn.Conv1d(channel, channel, kernel_size=1))
+            for i in range(3)
+        ])
+
     def forward(self, x):
-        for layer in self.layers:
-            x = x + layer(x)
+        for block, shortcut in zip(self.blocks, self.shortcuts):
+            x = shortcut(x) + block(x)
         return x
 
     def remove_weight_norm(self):
-        for layer in self.layers:
-            nn.utils.remove_weight_norm(layer[1])
-            nn.utils.remove_weight_norm(layer[3])
+        for block, shortcut in zip(self.blocks, self.shortcuts):
+            nn.utils.remove_weight_norm(block[2])
+            nn.utils.remove_weight_norm(block[4])
+            nn.utils.remove_weight_norm(shortcut)
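
Note (not part of the patch): a minimal sanity-check sketch of the updated modules, mirroring the __main__ blocks added above. It assumes the melgan repo root is on PYTHONPATH so that model.generator and model.res_stack import as in hubconf.py.

    import torch
    from model.generator import Generator  # repo-root import, as used in hubconf.py
    from model.res_stack import ResStack

    # The generator upsamples by 8 * 8 * 2 * 2 = 256x, so 10 mel frames -> 2560 samples.
    gen = Generator(80)
    mel = torch.randn(3, 80, 10)
    assert gen(mel).shape == torch.Size([3, 1, 2560])

    # Each residual block is now paired with a 1x1 weight-normalized shortcut, and the
    # ReflectionPad1d(3**i) before each dilated conv keeps the sequence length unchanged.
    stack = ResStack(32)
    feat = torch.randn(3, 32, 100)
    assert stack(feat).shape == feat.shape

    # Trainable parameter count, as printed by the new __main__ blocks.
    print(sum(p.numel() for p in gen.parameters() if p.requires_grad))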