Implement VitsSVC resume training / finetune feature #95

Merged
10 commits merged on Jan 22, 2024
1 change: 1 addition & 0 deletions config/vitssvc.json
@@ -7,6 +7,7 @@
"extract_mel": true,
"extract_linear_spec": true,
"extract_audio": true,
"mel_min_max_norm": true,
"use_linear": true,
"use_mel": true,
"use_audio": true,
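The new `mel_min_max_norm` switch presumably enables min-max normalization of the extracted mel spectrograms during preprocessing. A minimal sketch of what such a step could look like, assuming dataset-level min/max statistics computed at feature-extraction time (the helper names and the [0, 1] target range are assumptions, not Amphion's actual implementation):

```python
import numpy as np

def mel_min_max_normalize(mel: np.ndarray, mel_min: float, mel_max: float) -> np.ndarray:
    """Scale a mel spectrogram into [0, 1] using dataset-level min/max statistics."""
    return (mel - mel_min) / (mel_max - mel_min + 1e-8)

def mel_min_max_denormalize(mel_norm: np.ndarray, mel_min: float, mel_max: float) -> np.ndarray:
    """Invert the normalization before vocoding or visualization."""
    return mel_norm * (mel_max - mel_min + 1e-8) + mel_min
```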
55 changes: 18 additions & 37 deletions egs/svc/VitsSVC/exp_config.json
@@ -16,6 +16,7 @@
"svcc": "[SVCC dataset path]",
"vctk": "[VCTK dataset path]"
},
"use_custom_dataset": [],
// TODO: Fill in the output log path. The default value is "Amphion/ckpts/svc"
"log_dir": "ckpts/svc",
"preprocess": {
@@ -34,13 +35,13 @@
"win_size": 2048,
"segment_size": 8192,
"n_mel": 100,
"sample_rate": 44100,
"sample_rate": 24000,

// Config for features extraction
"extract_mel": true,
"extract_pitch": true,
"pitch_extractor": "parselmouth",
"extract_energy": false,
"extract_energy": true,
"extract_uv": true,
"extract_linear_spec": true,
"extract_audio": true,
@@ -54,6 +55,11 @@
"whisper_sample_rate": 16000,
"whisper_frameshift": 0.01,
"whisper_downsample_rate": 2,
// wenet
"extract_wenet_feature": false,
"wenet_downsample_rate": 4,
"wenet_frameshift": 0.01,
"wenet_sample_rate": 16000,
// Fill in the content-based pretrained model's path
"contentvec_file": "pretrained/contentvec/checkpoint_best_legacy_500.pt",
"wenet_model_path": "pretrained/wenet/20220506_u2pp_conformer_exp/final.pt",
@@ -67,6 +73,7 @@
"use_spkid": true,
"use_contentvec": true,
"use_whisper": true,
"use_wenet": false,
"use_text": false,
"use_phone": false,

@@ -84,30 +91,24 @@
"condition_encoder": {
// Config for features usage
"merge_mode": "add",
"input_melody_dim": 1,
"use_log_f0": true,
"n_bins_melody": 256,
//# Quantization (0 means no quantization)
"output_melody_dim": 192,

"use_log_loudness": true,
"use_contentvec": true,
"use_whisper": true,
"use_mert": false,
"use_wenet": false,
"whisper_dim": 1024,
"contentvec_dim": 256,
"content_encoder_dim": 192,
"output_singer_dim": 192,
"singer_table_size": 512,
"output_content_dim": 192,
"use_spkid": true,

"wenet_dim": 512,
"output_melody_dim": 384,
"output_loudness_dim": 384,
"content_encoder_dim": 384,
"output_singer_dim": 384,
"output_content_dim": 384,
"pitch_max": 1100.0,
"pitch_min": 50.0,
},
"vits": {
"inter_channels": 192,
"hidden_channels": 192,
"inter_channels": 384,
"hidden_channels": 384,
"filter_channels": 256,
"n_heads": 2,
"n_layers": 6,
@@ -135,26 +136,6 @@
3,
2
],
"run_eval": [
true,
true
],
"adamw": {
"lr": 2.0e-4
},
"reducelronplateau": {
"factor": 0.8,
"patience": 30,
"min_lr": 1.0e-4
},
"dataloader": {
"num_worker": 8,
"pin_memory": true
},
"sampler": {
"holistic_shuffle": false,
"drop_last": true
}
},
"inference": {
"batch_size": 1,
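Note how the condition-encoder output dimensions (output_melody_dim, output_loudness_dim, output_content_dim, output_singer_dim, content_encoder_dim) move from 192 to 384 together with vits.inter_channels and hidden_channels. Since "merge_mode" is "add", the encoded condition features are presumably summed element-wise, so they must all share one width, and that width has to match what the VITS prior encoder expects. A rough sketch of the constraint (tensor names are illustrative, not the repo's actual variables):

```python
import torch

D = 384  # must equal vits.inter_channels / hidden_channels in this config

# Per-feature projections all map into the same D-dimensional space.
melody = torch.randn(2, 100, D)    # output_melody_dim
loudness = torch.randn(2, 100, D)  # output_loudness_dim
content = torch.randn(2, 100, D)   # output_content_dim (whisper/contentvec after projection)
singer = torch.randn(2, 1, D)      # output_singer_dim, broadcast over time

# "merge_mode": "add" -> element-wise sum, which only works when every dim agrees.
condition = melody + loudness + content + singer
assert condition.shape == (2, 100, D)  # fed to a prior encoder with hidden_channels == D
```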
10 changes: 4 additions & 6 deletions models/svc/vits/vits.py
@@ -191,7 +191,6 @@ def forward(self, data):
"""

# TODO: elegantly handle the dimensions
c = data["contentvec_feat"].transpose(1, 2)
spec = data["linear"].transpose(1, 2)

g = data["spk_id"]
@@ -201,9 +200,9 @@
spec_lengths = data["target_len"]
f0 = data["frame_pitch"]

x_mask = torch.unsqueeze(sequence_mask(c_lengths, c.size(2)), 1).to(c.dtype)
# condition_encoder ver.
x = self.condition_encoder(data).transpose(1, 2)
x_mask = torch.unsqueeze(sequence_mask(c_lengths, f0.size(1)), 1).to(x.dtype)

# prior encoder
z_ptemp, m_p, logs_p, _ = self.enc_p(x, x_mask)
@@ -240,24 +239,23 @@ def forward(self, data):
@torch.no_grad()
def infer(self, data, noise_scale=0.35, seed=52468):
# c, f0, uv, g
c = data["contentvec_feat"].transpose(1, 2)
f0 = data["frame_pitch"]
g = data["spk_id"]

if c.device == torch.device("cuda"):
if f0.device == torch.device("cuda"):
torch.cuda.manual_seed_all(seed)
else:
torch.manual_seed(seed)

c_lengths = (torch.ones(c.size(0)) * c.size(-1)).to(c.device)
c_lengths = (torch.ones(f0.size(0)) * f0.size(-1)).to(f0.device)

if g.dim() == 1:
g = g.unsqueeze(0)
g = self.emb_g(g).transpose(1, 2)

x_mask = torch.unsqueeze(sequence_mask(c_lengths, c.size(2)), 1).to(c.dtype)
# condition_encoder ver.
x = self.condition_encoder(data).transpose(1, 2)
x_mask = torch.unsqueeze(sequence_mask(c_lengths, f0.size(1)), 1).to(x.dtype)

z_p, m_p, logs_p, c_mask = self.enc_p(x, x_mask, noice_scale=noise_scale)
z = self.flow(z_p, c_mask, g=g, reverse=True)
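With this change, both forward and infer size the frame mask from the frame-level pitch rather than from the ContentVec feature, so contentvec_feat is no longer needed just to build x_mask; the condition-encoder output x already carries the content information. A minimal sketch of the masking logic, with sequence_mask written out here only for illustration (the repository ships its own implementation):

```python
import torch

def sequence_mask(lengths, max_length=None):
    """Boolean mask of shape (B, T): True where t < lengths[b]."""
    if max_length is None:
        max_length = int(lengths.max())
    t = torch.arange(max_length, device=lengths.device)
    return t.unsqueeze(0) < lengths.unsqueeze(1)

# frame_pitch: (B, T) frame-level F0, one value per frame
f0 = torch.rand(4, 250)
c_lengths = (torch.ones(f0.size(0)) * f0.size(-1)).to(f0.device)

# x: (B, C, T) condition-encoder output; the mask broadcasts over channels
x = torch.randn(4, 384, 250)
x_mask = torch.unsqueeze(sequence_mask(c_lengths, f0.size(1)), 1).to(x.dtype)
assert x_mask.shape == (4, 1, 250)
```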
86 changes: 82 additions & 4 deletions models/svc/vits/vits_trainer.py
@@ -6,6 +6,8 @@
import torch
from torch.optim.lr_scheduler import ExponentialLR
from tqdm import tqdm
from pathlib import Path
import accelerate

# from models.svc.base import SVCTrainer
from models.svc.base.svc_dataset import SVCCollator, SVCDataset
@@ -30,6 +32,82 @@ def __init__(self, args, cfg):
self.singers = self._build_singer_lut()
TTSTrainer.__init__(self, args, cfg)

def _check_resume(self):
self.checkpoint_dir = os.path.join(self.exp_dir, "checkpoint")

if self.args.resume:
if self.args.resume_from_ckpt_path == "":
## Automatically resume according to the current experimental name
self.logger.info(
"Automatically resuming from latest checkpoint in {}...".format(
self.checkpoint_dir
)
)
start = time.monotonic_ns()
self.ckpt_path = self._load_model(
self.checkpoint_dir, None, self.args.resume_type
)
end = time.monotonic_ns()
self.logger.info(
f"Resuming from checkpoint done in {(end - start) / 1e6:.2f}ms"
)
self.checkpoints_path = json.load(
open(os.path.join(self.ckpt_path, "ckpts.json"), "r")
)
else:
## Resume from the given checkpoint path
if not os.path.exists(self.args.resume_from_ckpt_path):
raise ValueError(
"[Error] The resumed checkpoint path {} don't exist.".format(
self.args.resume_from_ckpt_path
)
)
self.logger.info(
"Resuming from {}...".format(self.args.resume_from_ckpt_path)
)
start = time.monotonic_ns()
self.ckpt_path = self._load_model(
self.checkpoint_dir,
self.args.resume_from_ckpt_path,
self.args.resume_type,
)
end = time.monotonic_ns()
self.logger.info(
f"Resuming from checkpoint done in {(end - start) / 1e6:.2f}ms"
)

def _load_model(self, checkpoint_dir, checkpoint_path=None, resume_type="resume"):
"""Load model from checkpoint. If a folder is given, it will
load the latest checkpoint in checkpoint_dir. If a path is given,
it will load the checkpoint specified by checkpoint_path.
**Only use this method after** ``accelerator.prepare()``.
"""
if checkpoint_path is None:
ls = [str(i) for i in Path(checkpoint_dir).glob("*")]
ls.sort(key=lambda x: int(x.split("_")[-3].split("-")[-1]), reverse=True)
checkpoint_path = ls[0]
self.logger.info("Load model from {}".format(checkpoint_path))

if resume_type == "resume":
self.accelerator.load_state(checkpoint_path)
self.epoch = int(checkpoint_path.split("_")[-3].split("-")[-1]) + 1
self.step = int(checkpoint_path.split("_")[-2].split("-")[-1]) + 1
elif resume_type == "finetune":
# TODO: figure out why only using "pytorch_model.bin" works
accelerate.load_checkpoint_and_dispatch(
self.accelerator.unwrap_model(self.model["generator"]),
os.path.join(checkpoint_path, "pytorch_model.bin"),
)
accelerate.load_checkpoint_and_dispatch(
self.accelerator.unwrap_model(self.model["discriminator"]),
os.path.join(checkpoint_path, "pytorch_model.bin"),
)
self.logger.info("Load model weights for finetune SUCCESS!")
else:
raise ValueError("Unsupported resume type: {}".format(resume_type))

return checkpoint_path

def _build_model(self):
net_g = SynthesizerTrn(
self.cfg.preprocess.n_fft // 2 + 1,
@@ -447,12 +525,12 @@ def _train_epoch(self):
return epoch_sum_loss, epoch_losses

def _build_singer_lut(self):
# custom for vitssvc, singers.json isn't saved in checkpoint
resumed_singer_path = None
if self.args.resume_from_ckpt_path and self.args.resume_from_ckpt_path != "":
resumed_singer_path = os.path.join(
self.args.resume_from_ckpt_path, self.cfg.preprocess.spk2id
)
if os.path.exists(os.path.join(self.exp_dir, self.cfg.preprocess.spk2id)):
ckpt_exp_dir = os.path.join(self.args.resume_from_ckpt_path, "../../")
resumed_singer_path = os.path.join(ckpt_exp_dir, self.cfg.preprocess.spk2id)
elif os.path.exists(os.path.join(self.exp_dir, self.cfg.preprocess.spk2id)):
resumed_singer_path = os.path.join(self.exp_dir, self.cfg.preprocess.spk2id)

if resumed_singer_path:
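The automatic-resume path in _load_model above sorts checkpoint folders and extracts the epoch and step with split("_")[-3].split("-")[-1] and split("_")[-2].split("-")[-1], which assumes the last three underscore-separated fields of the folder name are epoch, step, and loss. A small illustration of that parsing under an assumed naming scheme such as epoch-XXXX_step-XXXXXXX_loss-X.XXXX (the exact pattern Amphion writes is not shown in this diff, so treat the example names as hypothetical):

```python
# Hypothetical checkpoint folder names following an epoch-..._step-..._loss-... scheme.
ckpts = [
    "ckpts/svc/my_exp/checkpoint/epoch-0003_step-0015000_loss-1.4021",
    "ckpts/svc/my_exp/checkpoint/epoch-0012_step-0034000_loss-1.2345",
]

def epoch_of(path: str) -> int:
    # Counting fields from the end: [-3] ends with "epoch-0012" -> "0012" -> 12,
    # so underscores earlier in the path do not affect the result.
    return int(path.split("_")[-3].split("-")[-1])

def step_of(path: str) -> int:
    return int(path.split("_")[-2].split("-")[-1])

latest = sorted(ckpts, key=epoch_of, reverse=True)[0]
print(epoch_of(latest), step_of(latest))  # -> 12 34000
```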
10 changes: 5 additions & 5 deletions models/tts/base/tts_trainer.py
@@ -135,17 +135,17 @@ def __init__(self, args=None, cfg=None):
end = time.monotonic_ns()
self.logger.info(f"Building criterion done in {(end - start) / 1e6:.2f}ms")

# Resume or Finetune
with self.accelerator.main_process_first():
self._check_resume()

# accelerate prepare
self.logger.info("Initializing accelerate...")
start = time.monotonic_ns()
self._accelerator_prepare()
end = time.monotonic_ns()
self.logger.info(f"Initializing accelerate done in {(end - start) / 1e6:.2f}ms")

# Resume or Finetune
with self.accelerator.main_process_first():
self._check_resume()

# save config file path
self.config_save_path = os.path.join(self.exp_dir, "args.json")
self.device = self.accelerator.device
@@ -155,7 +155,7 @@
self.utt2spk_dict = self._build_utt2spk_dict()

# Only for TTS tasks
self.task_type = "TTS"
self.task_type = cfg.task_type.upper()
self.logger.info("Task type: {}".format(self.task_type))

def _check_resume(self):
Expand Down