Update DiffComoSVC (#135)
Update comosvc hyperparameters
Lokshaw-Chau authored Feb 9, 2024
1 parent 1c66555 commit c3d2f77
Showing 4 changed files with 98 additions and 91 deletions.
8 changes: 4 additions & 4 deletions config/comosvc.json
@@ -127,7 +127,7 @@
"sigma_min": 0.002,
"sigma_max": 80,
"rho": 7,
"n_timesteps": 40,
"n_timesteps": 18,
},
"diffusion": {
// Diffusion steps encoder
@@ -154,7 +154,7 @@
"train": {
// Basic settings
"fast_steps": 0,
"batch_size": 32,
"batch_size": 64,
"gradient_accumulation_step": 1,
"max_epoch": -1,
// -1 means no limit
@@ -195,7 +195,7 @@
// Optimizer
"optimizer": "AdamW",
"adamw": {
"lr": 4.0e-4
"lr": 5.0e-5
// nn model lr
},
// LR Scheduler
@@ -204,7 +204,7 @@
"factor": 0.8,
"patience": 10,
// unit is epoch
"min_lr": 1.0e-4
"min_lr": 5.0e-6
}
},
"inference": {
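Taken together, the four value changes above are a distillation-friendly setup: a smaller sampler budget (`n_timesteps` 40 → 18), a doubled batch size, and an AdamW learning rate and scheduler floor lowered by roughly an order of magnitude. A minimal sketch of reading the updated values, assuming the key paths follow the hunks shown above (the trainer imports json5, so the `//` comments in this file parse):

```python
import json5  # json5 tolerates the // comments used in config/comosvc.json

with open("config/comosvc.json") as f:
    cfg = json5.load(f)

print(cfg["train"]["batch_size"])   # 64 after this commit (was 32)
print(cfg["train"]["adamw"]["lr"])  # 5.0e-5 (was 4.0e-4)
```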
84 changes: 49 additions & 35 deletions models/svc/comosvc/comosvc.py
@@ -3,8 +3,6 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

"""Adapted from https://github.com/zhenye234/CoMoSpeech"""

import torch
import torch.nn as nn
import copy
@@ -16,14 +14,12 @@

from models.svc.transformer.conformer import Conformer, BaseModule
from models.svc.diffusion.diffusion_wrapper import DiffusionWrapper
from models.svc.comosvc.utils import slice_segments, rand_ids_segments


class Consistency(nn.Module):
def __init__(self, cfg, distill=False):
super().__init__()
self.cfg = cfg
# self.denoise_fn = GradLogPEstimator2d(96)
self.denoise_fn = DiffusionWrapper(self.cfg)
self.cfg = cfg.model.comosvc
self.teacher = not distill
@@ -53,7 +49,7 @@ def init_consistency_training(self):
self.denoise_fn_ema = copy.deepcopy(self.denoise_fn)
self.denoise_fn_pretrained = copy.deepcopy(self.denoise_fn)

def EDMPrecond(self, x, sigma, cond, denoise_fn, mask, spk=None):
def EDMPrecond(self, x, sigma, cond, denoise_fn):
"""
karras diffusion reverse process
@@ -62,24 +58,29 @@ def EDMPrecond(self, x, sigma, cond, denoise_fn, mask, spk=None):
sigma: noise level [B x 1 x 1]
cond: output of conformer encoder [B x n_mel x L]
denoise_fn: denoiser neural network e.g. DilatedCNN
mask: mask of padded frames [B x n_mel x L]
Returns:
denoised mel-spectrogram [B x n_mel x L]
"""
sigma = sigma.reshape(-1, 1, 1)

c_skip = self.sigma_data**2 / (sigma**2 + self.sigma_data**2)
c_out = sigma * self.sigma_data / (sigma**2 + self.sigma_data**2).sqrt()
c_out = (
(sigma - self.sigma_min)
* self.sigma_data
/ (sigma**2 + self.sigma_data**2).sqrt()
)
c_in = 1 / (self.sigma_data**2 + sigma**2).sqrt()
c_noise = sigma.log() / 4

x_in = c_in * x
x_in = x_in.transpose(1, 2)
x = x.transpose(1, 2)
cond = cond.transpose(1, 2)
F_x = denoise_fn(x_in, c_noise.squeeze(), cond)
# F_x = denoise_fn((c_in * x), mask, cond, c_noise.flatten())
c_noise = c_noise.squeeze()
if c_noise.dim() == 0:
c_noise = c_noise.unsqueeze(0)
F_x = denoise_fn(x_in, c_noise, cond)
D_x = c_skip * x + c_out * (F_x)
D_x = D_x.transpose(1, 2)
return D_x
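The reworked `c_out` is the consistency-model boundary condition: scaling the network output by `(sigma - sigma_min)` rather than `sigma` makes the preconditioned denoiser collapse to the identity at the lowest noise level, which one-step consistency sampling relies on. A standalone check, with `sigma_data` assumed to be the EDM default of 0.5 (its value in this repo is not visible in the diff):

```python
import torch

sigma_data, sigma_min = 0.5, 0.002  # sigma_data assumed; sigma_min from config
sigma = torch.tensor(sigma_min).reshape(1, 1, 1)

c_skip = sigma_data**2 / (sigma**2 + sigma_data**2)
c_out = (sigma - sigma_min) * sigma_data / (sigma**2 + sigma_data**2).sqrt()

print(c_out.item())   # 0.0  -> the network term is gated off at sigma_min
print(c_skip.item())  # ~1.0 -> D(x, sigma_min) ≈ x, the boundary condition
```

The new `c_noise.dim() == 0` guard in the same hunk serves a narrower purpose: it keeps batch-size-1 calls from passing a 0-d tensor to the denoiser's step embedding after the `squeeze()`.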
@@ -99,7 +100,7 @@ def EDMLoss(self, x_start, cond, mask):

# follow Grad-TTS, start from Gaussian noise with mean cond and std I
noise = (torch.randn_like(x_start) + cond) * sigma
D_yn = self.EDMPrecond(x_start + noise, sigma, cond, self.denoise_fn, mask)
D_yn = self.EDMPrecond(x_start + noise, sigma, cond, self.denoise_fn)
loss = weight * ((D_yn - x_start) ** 2)
loss = torch.sum(loss * mask) / torch.sum(mask)
return loss
@@ -120,10 +121,6 @@ def edm_sampler(
S_min=0,
S_max=float("inf"),
S_noise=1,
# S_churn=40 ,S_min=0.05,S_max=50,S_noise=1.003,# S_churn=0, S_min=0, S_max=float('inf'), S_noise=1,
# S_churn=30 ,S_min=0.01,S_max=30,S_noise=1.007,
# S_churn=30 ,S_min=0.01,S_max=1,S_noise=1.007,
# S_churn=80 ,S_min=0.05,S_max=50,S_noise=1.003,
):
"""
karras diffusion sampler
@@ -138,9 +135,9 @@
denoised mel-spectrogram [B x n_mel x L]
"""
# Time step discretization.
step_indices = torch.arange(num_steps, device=latents.device)

num_steps = num_steps + 1
step_indices = torch.arange(num_steps, device=latents.device)
t_steps = (
sigma_max ** (1 / rho)
+ step_indices
@@ -169,10 +166,19 @@
t_hat**2 - t_cur**2
).sqrt() * S_noise * torch.randn_like(x_cur)
# Euler step.
denoised = self.EDMPrecond(x_hat, t_hat, cond, self.denoise_fn, nonpadding)
denoised = self.EDMPrecond(x_hat, t_hat, cond, self.denoise_fn)
d_cur = (x_hat - denoised) / t_hat
x_next = x_hat + (t_next - t_hat) * d_cur

# add Heun’s 2nd order method
# if i < num_steps - 1:
# t = torch.zeros((x_cur.shape[0], 1, 1), device=x_cur.device)
# t[:, 0, 0] = t_next
# #t_next = t
# denoised = self.EDMPrecond(x_next, t, cond, self.denoise_fn, nonpadding)
# d_prime = (x_next - denoised) / t_next
# x_next = x_hat + (t_next - t_hat) * (0.5 * d_cur + 0.5 * d_prime)

return x_next
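With `num_steps = num_steps + 1`, `num_steps` now counts integration intervals rather than grid points, so the 18-step budget from the config yields 19 sigma values running from `sigma_max` down to `sigma_min`. A self-contained sketch of the Karras schedule, assuming the denominator elided from the hunk above is the usual `num_steps - 1`:

```python
import torch

def karras_t_steps(num_steps=18, sigma_min=0.002, sigma_max=80.0, rho=7.0):
    # num_steps + 1 grid points -> num_steps Euler intervals.
    idx = torch.arange(num_steps + 1)
    return (
        sigma_max ** (1 / rho)
        + idx / num_steps * (sigma_min ** (1 / rho) - sigma_max ** (1 / rho))
    ) ** rho

t = karras_t_steps()
print(t[0].item(), t[-1].item())  # ~80.0 and ~0.002: endpoints land on the bounds
```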

def CTLoss_D(self, y, cond, mask):
@@ -195,31 +201,41 @@ def CTLoss_D(self, y, cond, mask):
z = torch.randn_like(y) + cond

tn_1 = self.t_steps[n + 1].reshape(-1, 1, 1).to(y.device)
f_theta = self.EDMPrecond(y + tn_1 * z, tn_1, cond, self.denoise_fn, mask)
f_theta = self.EDMPrecond(y + tn_1 * z, tn_1, cond, self.denoise_fn)

with torch.no_grad():
tn = self.t_steps[n].reshape(-1, 1, 1).to(y.device)

# euler step
x_hat = y + tn_1 * z
denoised = self.EDMPrecond(
x_hat, tn_1, cond, self.denoise_fn_pretrained, mask
)
denoised = self.EDMPrecond(x_hat, tn_1, cond, self.denoise_fn_pretrained)
d_cur = (x_hat - denoised) / tn_1
y_tn = x_hat + (tn - tn_1) * d_cur

f_theta_ema = self.EDMPrecond(y_tn, tn, cond, self.denoise_fn_ema, mask)
# Heun’s 2nd order method

denoised2 = self.EDMPrecond(y_tn, tn, cond, self.denoise_fn_pretrained)
d_prime = (y_tn - denoised2) / tn
y_tn = x_hat + (tn - tn_1) * (0.5 * d_cur + 0.5 * d_prime)

f_theta_ema = self.EDMPrecond(y_tn, tn, cond, self.denoise_fn_ema)

# loss = (f_theta - f_theta_ema.detach()) ** 2
# loss = torch.sum(loss * mask) / torch.sum(mask)
loss = self.ssim_loss(f_theta, f_theta_ema.detach())
loss = (f_theta - f_theta_ema.detach()) ** 2
loss = torch.sum(loss * mask) / torch.sum(mask)

# check nan
if torch.any(torch.isnan(loss)):
print("nan loss")
if torch.any(torch.isnan(f_theta)):
print("nan f_theta")
if torch.any(torch.isnan(f_theta_ema)):
print("nan f_theta_ema")

return loss
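This hunk changes the distillation loss in three ways: the teacher's Euler step gains the Heun (2nd-order) correction from the EDM sampler, the SSIM objective is replaced by a masked MSE against the EMA target, and NaN probes are added for debugging. The EMA update that keeps `denoise_fn_ema` trailing the student is not part of this hunk; a typical form, with the decay value assumed, looks like:

```python
import torch

@torch.no_grad()
def update_ema(student: torch.nn.Module, target: torch.nn.Module, mu: float = 0.999):
    # target <- mu * target + (1 - mu) * student, parameter by parameter
    for p, p_tgt in zip(student.parameters(), target.parameters()):
        p_tgt.mul_(mu).add_(p, alpha=1.0 - mu)
```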

def get_t_steps(self, N):
N = N + 1
step_indices = torch.arange(N) # , device=latents.device)
step_indices = torch.arange(N)
t_steps = (
self.sigma_min ** (1 / self.rho)
+ step_indices
@@ -252,17 +268,16 @@ def CT_sampler(self, latents, cond, nonpadding, t_steps=1):
t_steps = torch.as_tensor(t_steps).to(latents.device)
latents = latents * t_steps[0]
_t = torch.zeros((latents.shape[0], 1, 1), device=latents.device)
_t[:, 0, 0] = t_steps
x = self.EDMPrecond(latents, _t, cond, self.denoise_fn_ema, nonpadding)
_t[:, 0, 0] = t_steps[0]
x = self.EDMPrecond(latents, _t, cond, self.denoise_fn_ema)

for t in t_steps[1:-1]:
z = torch.randn_like(x) + cond
x_tn = x + (t**2 - self.sigma_min**2).sqrt() * z
_t = torch.zeros((x.shape[0], 1, 1), device=x.device)
_t[:, 0, 0] = t
t = _t
print(t)
x = self.EDMPrecond(x_tn, t, cond, self.denoise_fn_ema, nonpadding)
x = self.EDMPrecond(x_tn, t, cond, self.denoise_fn_ema)
return x
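In the multistep branch, each consistency step returns a sample whose residual noise level is `sigma_min`, so re-noising with standard deviation `sqrt(t**2 - sigma_min**2)` (around the conditioner mean, following the Grad-TTS convention) lands the sample at exactly level `t` before the next one-step denoise, since independent Gaussian variances add. A quick numeric check with illustrative levels:

```python
import torch

sigma_min = 0.002
for t in (2.0, 0.5, 0.1):
    std = (torch.tensor(t) ** 2 - sigma_min**2).sqrt().item()
    level = (sigma_min**2 + std**2) ** 0.5
    print(f"t={t}: re-noise std={std:.6f}, resulting level={level:.6f}")
```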

def forward(self, x, nonpadding, cond, t_steps=1, infer=False):
@@ -335,10 +350,10 @@ def forward(self, x_mask, x, n_timesteps, temperature=1.0):
decoder_outputs = decoder_outputs.transpose(1, 2)
return encoder_outputs, decoder_outputs

def compute_loss(self, x_mask, x, mel, out_size=None, skip_diff=False):
def compute_loss(self, x_mask, x, mel, skip_diff=False):
"""
Computes 2 losses:
1. prior loss: loss between mel-spectrogram and encoder outputs.
1. prior loss: loss between mel-spectrogram and encoder outputs. (l2 and ssim loss)
2. diffusion loss: loss between gaussian noise and its reconstruction by diffusion-based decoder.
Args:
@@ -349,9 +364,11 @@ def compute_loss(self, x_mask, x, mel, out_size=None, skip_diff=False):

mu_x = self.encoder(x, x_mask)
# prior loss
x_mask = x_mask.repeat(1, 1, mel.shape[-1])
prior_loss = torch.sum(
0.5 * ((mel - mu_x) ** 2 + math.log(2 * math.pi)) * x_mask
)

prior_loss = prior_loss / (torch.sum(x_mask) * self.cfg.model.comosvc.n_mel)
# ssim loss
ssim_loss = self.ssim_loss(mu_x, mel)
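After the normalization change, the prior term is the mean per-bin Gaussian negative log-likelihood of the mel under N(mu_x, I), counted only over non-padded frames. A shape-explicit sketch with assumed dimensions (the repo's exact mask layout is not visible in the diff):

```python
import math
import torch

B, n_mel, L = 2, 100, 400
mel = torch.randn(B, n_mel, L)
mu_x = torch.randn(B, n_mel, L)
frame_mask = torch.ones(B, 1, L)  # 1 on real frames, 0 on padding

nll = 0.5 * ((mel - mu_x) ** 2 + math.log(2 * math.pi))
prior_loss = (nll * frame_mask).sum() / (frame_mask.sum() * n_mel)
print(prior_loss)  # ≈ 0.5 * (2 + log(2*pi)) ≈ 1.92 for random inputs
```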
@@ -366,10 +383,7 @@ def compute_loss(self, x_mask, x, mel, out_size=None, skip_diff=False):

# Cut a small segment of mel-spectrogram in order to increase batch size
else:
if self.distill:
mu_y = mu_x.detach()
else:
mu_y = mu_x
mu_y = mu_x
mask_y = x_mask

diff_loss = self.decoder(mel, mask_y, mu_y, infer=False)
66 changes: 45 additions & 21 deletions models/svc/comosvc/comosvc_trainer.py
@@ -6,7 +6,6 @@
import torch
import os
import json5
from collections import OrderedDict
from tqdm import tqdm
import json
import shutil
@@ -25,28 +24,17 @@ def __init__(self, args=None, cfg=None):
SVCTrainer.__init__(self, args, cfg)
self.distill = cfg.model.comosvc.distill
self.skip_diff = True
if self.distill: # and args.resume is None:
self.teacher_model_path = cfg.model.teacher_model_path
self.teacher_state_dict = self._load_teacher_state_dict()
self._load_teacher_model(self.teacher_state_dict)
self.acoustic_mapper.decoder.init_consistency_training()

### Following are methods only for comoSVC models ###
def _load_teacher_state_dict(self):

def _load_teacher_model(self, model):
r"""Load teacher model from checkpoint file."""
self.checkpoint_file = self.teacher_model_path
print("Load teacher acoustic model from {}".format(self.checkpoint_file))
raw_state_dict = torch.load(self.checkpoint_file) # , map_location=self.device)
return raw_state_dict

def _load_teacher_model(self, state_dict):
raw_dict = state_dict
clean_dict = OrderedDict()
for k, v in raw_dict.items():
if k.startswith("module."):
clean_dict[k[7:]] = v
else:
clean_dict[k] = v
self.model.load_state_dict(clean_dict)
self.logger.info(
"Load teacher acoustic model from {}".format(self.checkpoint_file)
)
raw_dict = torch.load(self.checkpoint_file)
model.load_state_dict(raw_dict)

def _build_model(self):
r"""Build the model for training. This function is called in ``__init__`` function."""
@@ -57,8 +45,40 @@ def _build_model(self):
self.condition_encoder = ConditionEncoder(self.cfg.model.condition_encoder)
self.acoustic_mapper = ComoSVC(self.cfg)
model = torch.nn.ModuleList([self.condition_encoder, self.acoustic_mapper])
if self.cfg.model.comosvc.distill:
if not self.args.resume:
# do not load teacher model when resume
self.teacher_model_path = self.cfg.model.teacher_model_path
self._load_teacher_model(model)
# build teacher & target decoder and freeze teacher
self.acoustic_mapper.decoder.init_consistency_training()
self.freeze_net(self.condition_encoder)
self.freeze_net(self.acoustic_mapper.encoder)
self.freeze_net(self.acoustic_mapper.decoder.denoise_fn_pretrained)
self.freeze_net(self.acoustic_mapper.decoder.denoise_fn_ema)
return model

def freeze_net(self, model):
r"""Freeze the model for training."""
for name, param in model.named_parameters():
param.requires_grad = False

def __build_optimizer(self):
r"""Build optimizer for training. This function is called in ``__init__`` function."""

if self.cfg.train.optimizer.lower() == "adamw":
optimizer = torch.optim.AdamW(
params=filter(lambda p: p.requires_grad, self.model.parameters()),
**self.cfg.train.adamw,
)

else:
raise NotImplementedError(
"Not support optimizer: {}".format(self.cfg.train.optimizer)
)

return optimizer
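The freezing in `_build_model` and the `requires_grad` filter here work as a pair: the teacher, EMA copy, condition encoder, and conformer encoder stay inside the `ModuleList` (so checkpoints keep their weights), but AdamW only ever sees the student's parameters. A toy illustration of the pattern, not the repo's classes:

```python
import torch

student, teacher = torch.nn.Linear(8, 8), torch.nn.Linear(8, 8)
for p in teacher.parameters():
    p.requires_grad = False  # frozen, as freeze_net does

model = torch.nn.ModuleList([student, teacher])
opt = torch.optim.AdamW(
    params=filter(lambda p: p.requires_grad, model.parameters()), lr=5.0e-5
)
n_opt = sum(p.numel() for g in opt.param_groups for p in g["params"])
print(n_opt)  # 72: only the student's 8*8 + 8 parameters are optimized
```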

def _forward_step(self, batch):
r"""Forward step for training and inference. This function is called
in ``_train_step`` & ``_test_step`` function.
@@ -124,7 +144,7 @@ def _train_epoch(self):
for k, v in loss.items():
key = "Step/Train Loss/{}".format(k)
log_info[key] = v
log_info["Step/Learning Rate"]: self.optimizer.param_groups[0]["lr"]
log_info["Step/Learning Rate"] = self.optimizer.param_groups[0]["lr"]
self.accelerator.log(
log_info,
step=self.step,
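The learning-rate logging fix above is a one-character bug with no traceback: with `:` instead of `=`, the statement parses as a variable annotation, which CPython evaluates for side effects but never assigns, so the key was silently never written. A minimal reproduction:

```python
log_info = {}
log_info["Step/Learning Rate"]: 5.0e-5  # annotation statement: stores nothing
print(log_info)  # {}

log_info["Step/Learning Rate"] = 5.0e-5  # assignment: the value is stored
print(log_info)  # {'Step/Learning Rate': 5e-05}
```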
@@ -197,13 +217,16 @@ def train_loop(self):
self.epoch, self.step, train_loss
),
)
self.tmp_checkpoint_save_path = path
self.accelerator.save_state(path)
print(f"save checkpoint in {path}")
json.dump(
self.checkpoints_path,
open(os.path.join(path, "ckpts.json"), "w"),
ensure_ascii=False,
indent=4,
)
self._save_auxiliary_states()

# Remove old checkpoints
to_remove = []
@@ -247,6 +270,7 @@ def train_loop(self):
),
)
)
self._save_auxiliary_states()
self.accelerator.end_training()

@torch.inference_mode()
