diff --git a/config/vitssvc.json b/config/vitssvc.json index 70687cbb..a2572b66 100644 --- a/config/vitssvc.json +++ b/config/vitssvc.json @@ -3,37 +3,57 @@ "model_type": "VITS", "task_type": "svc", "preprocess": { - "extract_phone": false, + // Config for features extraction "extract_mel": true, + "extract_pitch": true, + "pitch_extractor": "parselmouth", + "extract_energy": true, + "extract_uv": true, "extract_linear_spec": true, "extract_audio": true, + + "mel_min_max_norm": true, + // Config for features usage "use_linear": true, "use_mel": true, + "use_min_max_norm_mel": false, "use_audio": true, + "use_frame_pitch": true, + "use_uv": true, + "use_spkid": true, + "use_contentvec": false, + "use_whisper": false, + "use_wenet": false, "use_text": false, - "use_phone": true, - + "use_phone": false, + "fmin": 0, - "fmax": null, + "fmax": 12000, "f0_min": 50, "f0_max": 1100, // f0_bin in sovits "pitch_bin": 256, // filter_length in sovits - "n_fft": 2048, + "n_fft": 1024, // hop_length in sovits - "hop_size": 512, + "hop_size": 256, // win_length in sovits - "win_size": 2048, + "win_size": 1024, "segment_size": 8192, "n_mel": 100, - "sample_rate": 44100, + "sample_rate": 24000, "mel_min_max_stats_dir": "mel_min_max_stats", "whisper_dir": "whisper", "contentvec_dir": "contentvec", "wenet_dir": "wenet", "mert_dir": "mert", + + // Meta file + "train_file": "train.json", + "valid_file": "test.json", + "spk2id": "singers.json", + "utt2spk": "utt2singer" }, "model": { "condition_encoder": { @@ -41,12 +61,11 @@ "input_melody_dim": 1, "use_log_f0": true, "n_bins_melody": 256, - //# Quantization (0 for not quantization) - "output_melody_dim": 196, + "output_melody_dim": 384, "input_loudness_dim": 1, - "use_log_loudness": false, + "use_log_loudness": true, "n_bins_loudness": 256, - "output_loudness_dim": 196, + "output_loudness_dim": 384, "use_whisper": false, "use_contentvec": false, "use_wenet": false, @@ -55,17 +74,20 @@ "contentvec_dim": 256, "mert_dim": 256, 
"wenet_dim": 512, - "content_encoder_dim": 196, - "output_singer_dim": 196, + "content_encoder_dim": 384, "singer_table_size": 512, - "output_content_dim": 196, - "use_spkid": true + "output_singer_dim": 384, + "output_content_dim": 384, + "use_spkid": true, + + "pitch_max": 1100.0, + "pitch_min": 50.0, }, "vits": { "filter_channels": 256, "gin_channels": 256, - "hidden_channels": 192, - "inter_channels": 192, + "hidden_channels": 384, + "inter_channels": 384, "kernel_size": 3, "n_flow_layer": 4, "n_heads": 2, @@ -73,7 +95,6 @@ "n_layers_q": 3, "n_speakers": 512, "p_dropout": 0.1, - "ssl_dim": 256, "use_spectral_norm": false, }, "generator": "hifigan", @@ -86,10 +107,10 @@ 11 ], "upsample_rates": [ - 8,8,2,2,2 + 8,8,2,2 ], "upsample_kernel_sizes": [ - 16,16,4,4,4 + 16,16,4,4 ], "upsample_initial_channel": 512, "resblock_dilation_sizes": [ @@ -99,7 +120,7 @@ ] }, "melgan": { - "ratios": [8, 8, 2, 2, 2], + "ratios": [8, 8, 2, 2], "ngf": 32, "n_residual_layers": 3, "num_D": 3, @@ -112,10 +133,10 @@ "activation": "snakebeta", "snake_logscale": true, "upsample_rates": [ - 8,8,2,2,2, + 8,8,2,2 ], "upsample_kernel_sizes": [ - 16,16,4,4,4, + 16,16,4,4 ], "upsample_initial_channel": 512, "resblock_kernel_sizes": [ @@ -133,10 +154,10 @@ "resblock": "1", "harmonic_num": 8, "upsample_rates": [ - 8,8,2,2,2, + 8,8,2,2 ], "upsample_kernel_sizes": [ - 16,16,4,4,4, + 16,16,4,4 ], "upsample_initial_channel": 768, "resblock_kernel_sizes": [ diff --git a/egs/svc/VitsSVC/exp_config.json b/egs/svc/VitsSVC/exp_config.json index bd3b4481..3310f0b3 100644 --- a/egs/svc/VitsSVC/exp_config.json +++ b/egs/svc/VitsSVC/exp_config.json @@ -16,34 +16,16 @@ "svcc": "[SVCC dataset path]", "vctk": "[VCTK dataset path]" }, + "use_custom_dataset": [], // TODO: Fill in the output log path. The default value is "Amphion/ckpts/svc" "log_dir": "ckpts/svc", "preprocess": { // TODO: Fill in the output data path. 
The default value is "Amphion/data" "processed_dir": "data", - "f0_min": 50, - "f0_max": 1100, - // f0_bin in sovits - "pitch_bin": 256, - // filter_length in sovits - "n_fft": 2048, - // hop_length in sovits - "hop_size": 512, - // win_length in sovits - "win_size": 2048, - "segment_size": 8192, "n_mel": 100, - "sample_rate": 44100, + "sample_rate": 24000, - // Config for features extraction - "extract_mel": true, - "extract_pitch": true, - "pitch_extractor": "parselmouth", - "extract_energy": false, - "extract_uv": true, - "extract_linear_spec": true, - "extract_audio": true, // contentvec "extract_contentvec_feature": true, "contentvec_sample_rate": 16000, @@ -54,66 +36,48 @@ "whisper_sample_rate": 16000, "whisper_frameshift": 0.01, "whisper_downsample_rate": 2, + // wenet + "extract_wenet_feature": true, + "wenet_downsample_rate": 4, + "wenet_frameshift": 0.01, + "wenet_sample_rate": 16000, // Fill in the content-based pretrained model's path "contentvec_file": "pretrained/contentvec/checkpoint_best_legacy_500.pt", "wenet_model_path": "pretrained/wenet/20220506_u2pp_conformer_exp/final.pt", "wenet_config": "pretrained/wenet/20220506_u2pp_conformer_exp/train.yaml", "whisper_model": "medium", "whisper_model_path": "pretrained/whisper/medium.pt", - // Config for features usage - "use_mel": true, - "use_frame_pitch": true, - "use_uv": true, - "use_spkid": true, + "use_contentvec": true, "use_whisper": true, - "use_text": false, - "use_phone": false, - + "use_wenet": false, + // Extract content features using dataloader "pin_memory": true, "num_workers": 8, "content_feature_batch_size": 16, - // Meta file - "train_file": "train.json", - "valid_file": "test.json", - "spk2id": "singers.json", - "utt2spk": "utt2singer" + }, "model": { "condition_encoder": { // Config for features usage "merge_mode": "add", - "input_melody_dim": 1, - "use_log_f0": true, - "n_bins_melody": 256, - //# Quantization (0 for not quantization) - "output_melody_dim": 192, - + 
"use_log_loudness": true, "use_contentvec": true, "use_whisper": true, - "use_mert": false, "use_wenet": false, "whisper_dim": 1024, "contentvec_dim": 256, - "content_encoder_dim": 192, - "output_singer_dim": 192, - "singer_table_size": 512, - "output_content_dim": 192, - "use_spkid": true, - - "pitch_max": 1100.0, - "pitch_min": 50.0, + "wenet_dim": 512, }, "vits": { - "inter_channels": 192, - "hidden_channels": 192, + "inter_channels": 384, + "hidden_channels": 384, "filter_channels": 256, "n_heads": 2, "n_layers": 6, "kernel_size": 3, "p_dropout": 0.1, - "ssl_dim": 256, "n_flow_layer": 4, "n_layers_q": 3, "gin_channels": 256, @@ -135,26 +99,6 @@ 3, 2 ], - "run_eval": [ - true, - true - ], - "adamw": { - "lr": 2.0e-4 - }, - "reducelronplateau": { - "factor": 0.8, - "patience": 30, - "min_lr": 1.0e-4 - }, - "dataloader": { - "num_worker": 8, - "pin_memory": true - }, - "sampler": { - "holistic_shuffle": false, - "drop_last": true - } }, "inference": { "batch_size": 1, diff --git a/models/base/new_trainer.py b/models/base/new_trainer.py index 3bfd6a73..4200e765 100644 --- a/models/base/new_trainer.py +++ b/models/base/new_trainer.py @@ -115,8 +115,8 @@ def __init__(self, args=None, cfg=None): with self.accelerator.main_process_first(): self.logger.info("Building optimizer and scheduler...") start = time.monotonic_ns() - self.optimizer = self.__build_optimizer() - self.scheduler = self.__build_scheduler() + self.optimizer = self._build_optimizer() + self.scheduler = self._build_scheduler() end = time.monotonic_ns() self.logger.info( f"Building optimizer and scheduler done in {(end - start) / 1e6:.2f}ms" @@ -125,19 +125,7 @@ def __init__(self, args=None, cfg=None): # accelerate prepare self.logger.info("Initializing accelerate...") start = time.monotonic_ns() - ( - self.train_dataloader, - self.valid_dataloader, - self.model, - self.optimizer, - self.scheduler, - ) = self.accelerator.prepare( - self.train_dataloader, - self.valid_dataloader, - self.model, - 
self.optimizer, - self.scheduler, - ) + self._accelerator_prepare() end = time.monotonic_ns() self.logger.info(f"Initializing accelerate done in {(end - start) / 1e6:.2f}ms") @@ -160,7 +148,7 @@ def __init__(self, args=None, cfg=None): ) ) start = time.monotonic_ns() - ckpt_path = self.__load_model( + ckpt_path = self._load_model( checkpoint_dir=self.checkpoint_dir, resume_type=args.resume_type ) end = time.monotonic_ns() @@ -182,7 +170,7 @@ def __init__(self, args=None, cfg=None): "Resuming from {}...".format(args.resume_from_ckpt_path) ) start = time.monotonic_ns() - ckpt_path = self.__load_model( + ckpt_path = self._load_model( checkpoint_path=args.resume_from_ckpt_path, resume_type=args.resume_type, ) @@ -194,6 +182,21 @@ def __init__(self, args=None, cfg=None): # save config file path self.config_save_path = os.path.join(self.exp_dir, "args.json") + def _accelerator_prepare(self): + ( + self.train_dataloader, + self.valid_dataloader, + self.model, + self.optimizer, + self.scheduler, + ) = self.accelerator.prepare( + self.train_dataloader, + self.valid_dataloader, + self.model, + self.optimizer, + self.scheduler, + ) + ### Following are abstract methods that should be implemented in child classes ### @abstractmethod def _build_dataset(self): @@ -422,7 +425,7 @@ def _valid_step(self, batch): """ return self._forward_step(batch) - def __load_model( + def _load_model( self, checkpoint_dir: str = None, checkpoint_path: str = None, @@ -546,7 +549,7 @@ def _check_nan(self, loss, y_pred, y_gt): ## Following are private methods ## ## !!! These are inconvenient for GAN-based model training. It'd be better to move these to svc_trainer.py if needed. 
- def __build_optimizer(self): + def _build_optimizer(self): r"""Build optimizer for model.""" # Make case-insensitive matching if self.cfg.train.optimizer.lower() == "adadelta": @@ -604,7 +607,7 @@ def __build_optimizer(self): ) return optimizer - def __build_scheduler(self): + def _build_scheduler(self): r"""Build scheduler for optimizer.""" # Make case-insensitive matching if self.cfg.train.scheduler.lower() == "lambdalr": diff --git a/models/svc/base/svc_dataset.py b/models/svc/base/svc_dataset.py index 9a66c03c..c4f908f4 100644 --- a/models/svc/base/svc_dataset.py +++ b/models/svc/base/svc_dataset.py @@ -191,7 +191,7 @@ def __init__(self, args, cfg, infer_type): spk2id_path = os.path.join(args.acoustics_dir, cfg.preprocess.spk2id) # utt2sp_path = os.path.join(self.data_root, cfg.preprocess.utt2spk) - with open(spk2id_path, "r") as f: + with open(spk2id_path, "r", encoding="utf-8") as f: self.spk2id = json.load(f) # print("self.spk2id", self.spk2id) @@ -224,9 +224,11 @@ def __init__(self, args, cfg, infer_type): cfg.preprocess.pitch_dir, "statistics.json", ) - self.target_pitch_median = json.load(open(target_f0_statistics_path, "r"))[ - f"{self.target_dataset}_{self.target_singer}" - ]["voiced_positions"]["median"] + self.target_pitch_median = json.load( + open(target_f0_statistics_path, "r", encoding="utf-8") + )[f"{self.target_dataset}_{self.target_singer}"]["voiced_positions"][ + "median" + ] # Source F0 median (if infer from file) if infer_type == "from_file": @@ -238,7 +240,7 @@ def __init__(self, args, cfg, infer_type): "statistics.json", ) self.source_pitch_median = json.load( - open(source_f0_statistics_path, "r") + open(source_f0_statistics_path, "r", encoding="utf-8") )[f"{source_audio_name}_{source_audio_name}"]["voiced_positions"][ "median" ] diff --git a/models/svc/base/svc_trainer.py b/models/svc/base/svc_trainer.py index a2a093a8..1c6588ed 100644 --- a/models/svc/base/svc_trainer.py +++ b/models/svc/base/svc_trainer.py @@ -70,7 +70,9 @@ def 
_save_auxiliary_states(self): To save the singer's look-up table in the checkpoint saving path """ with open( - os.path.join(self.tmp_checkpoint_save_path, self.cfg.preprocess.spk2id), "w" + os.path.join(self.tmp_checkpoint_save_path, self.cfg.preprocess.spk2id), + "w", + encoding="utf-8", ) as f: json.dump(self.singers, f, indent=4, ensure_ascii=False) diff --git a/models/svc/vits/vits.py b/models/svc/vits/vits.py index c6c40728..baa857ee 100644 --- a/models/svc/vits/vits.py +++ b/models/svc/vits/vits.py @@ -96,7 +96,6 @@ def __init__(self, spec_channels, segment_size, cfg): self.n_layers = cfg.model.vits.n_layers self.kernel_size = cfg.model.vits.kernel_size self.p_dropout = cfg.model.vits.p_dropout - self.ssl_dim = cfg.model.vits.ssl_dim self.n_flow_layer = cfg.model.vits.n_flow_layer self.gin_channels = cfg.model.vits.gin_channels self.n_speakers = cfg.model.vits.n_speakers @@ -191,7 +190,6 @@ def forward(self, data): """ # TODO: elegantly handle the dimensions - c = data["contentvec_feat"].transpose(1, 2) spec = data["linear"].transpose(1, 2) g = data["spk_id"] @@ -201,9 +199,9 @@ def forward(self, data): spec_lengths = data["target_len"] f0 = data["frame_pitch"] - x_mask = torch.unsqueeze(sequence_mask(c_lengths, c.size(2)), 1).to(c.dtype) # condition_encoder ver. 
x = self.condition_encoder(data).transpose(1, 2) + x_mask = torch.unsqueeze(sequence_mask(c_lengths, f0.size(1)), 1).to(x.dtype) # prior encoder z_ptemp, m_p, logs_p, _ = self.enc_p(x, x_mask) @@ -240,24 +238,23 @@ def forward(self, data): @torch.no_grad() def infer(self, data, noise_scale=0.35, seed=52468): # c, f0, uv, g - c = data["contentvec_feat"].transpose(1, 2) f0 = data["frame_pitch"] g = data["spk_id"] - if c.device == torch.device("cuda"): + if f0.device == torch.device("cuda"): torch.cuda.manual_seed_all(seed) else: torch.manual_seed(seed) - c_lengths = (torch.ones(c.size(0)) * c.size(-1)).to(c.device) + c_lengths = (torch.ones(f0.size(0)) * f0.size(-1)).to(f0.device) if g.dim() == 1: g = g.unsqueeze(0) g = self.emb_g(g).transpose(1, 2) - x_mask = torch.unsqueeze(sequence_mask(c_lengths, c.size(2)), 1).to(c.dtype) # condition_encoder ver. x = self.condition_encoder(data).transpose(1, 2) + x_mask = torch.unsqueeze(sequence_mask(c_lengths, f0.size(1)), 1).to(x.dtype) z_p, m_p, logs_p, c_mask = self.enc_p(x, x_mask, noice_scale=noise_scale) z = self.flow(z_p, c_mask, g=g, reverse=True) diff --git a/models/svc/vits/vits_inference.py b/models/svc/vits/vits_inference.py index 3180f8c8..f1e4c1ac 100644 --- a/models/svc/vits/vits_inference.py +++ b/models/svc/vits/vits_inference.py @@ -9,6 +9,7 @@ import numpy as np from tqdm import tqdm import torch +from torch.utils.data import DataLoader from models.svc.base import SVCInference from models.svc.vits.vits import SynthesizerTrn @@ -47,14 +48,31 @@ def build_save_dir(self, dataset, speaker): print("Saving to ", save_dir) return save_dir + def _build_dataloader(self): + datasets, collate = self._build_test_dataset() + self.test_dataset = datasets(self.args, self.cfg, self.infer_type) + self.test_collate = collate(self.cfg) + self.test_batch_size = min( + self.cfg.inference.batch_size, len(self.test_dataset.metadata) + ) + test_dataloader = DataLoader( + self.test_dataset, + collate_fn=self.test_collate, + 
num_workers=1, + batch_size=self.test_batch_size, + shuffle=False, + ) + return test_dataloader + @torch.inference_mode() def inference(self): res = [] for i, batch in enumerate(self.test_dataloader): pred_audio_list = self._inference_each_batch(batch) - for it, wav in zip(self.test_dataset.metadata, pred_audio_list): - uid = it["Uid"] + for j, wav in enumerate(pred_audio_list): + uid = self.test_dataset.metadata[i * self.test_batch_size + j]["Uid"] file = os.path.join(self.args.output_dir, f"{uid}.wav") + print(f"Saving {file}") wav = wav.numpy(force=True) save_audio( diff --git a/models/svc/vits/vits_trainer.py b/models/svc/vits/vits_trainer.py index 618fd223..766a5faf 100644 --- a/models/svc/vits/vits_trainer.py +++ b/models/svc/vits/vits_trainer.py @@ -6,11 +6,14 @@ import torch from torch.optim.lr_scheduler import ExponentialLR from tqdm import tqdm +from pathlib import Path +import shutil +import accelerate # from models.svc.base import SVCTrainer from models.svc.base.svc_dataset import SVCCollator, SVCDataset from models.svc.vits.vits import * -from models.tts.base import TTSTrainer +from models.svc.base import SVCTrainer from utils.mel import mel_spectrogram_torch import json @@ -20,15 +23,79 @@ ) -class VitsSVCTrainer(TTSTrainer): +class VitsSVCTrainer(SVCTrainer): def __init__(self, args, cfg): self.args = args self.cfg = cfg - self._init_accelerator() - # Only for SVC tasks - with self.accelerator.main_process_first(): - self.singers = self._build_singer_lut() - TTSTrainer.__init__(self, args, cfg) + SVCTrainer.__init__(self, args, cfg) + + def _accelerator_prepare(self): + ( + self.train_dataloader, + self.valid_dataloader, + ) = self.accelerator.prepare( + self.train_dataloader, + self.valid_dataloader, + ) + if isinstance(self.model, dict): + for key in self.model.keys(): + self.model[key] = self.accelerator.prepare(self.model[key]) + else: + self.model = self.accelerator.prepare(self.model) + + if isinstance(self.optimizer, dict): + for key in 
self.optimizer.keys(): + self.optimizer[key] = self.accelerator.prepare(self.optimizer[key]) + else: + self.optimizer = self.accelerator.prepare(self.optimizer) + + if isinstance(self.scheduler, dict): + for key in self.scheduler.keys(): + self.scheduler[key] = self.accelerator.prepare(self.scheduler[key]) + else: + self.scheduler = self.accelerator.prepare(self.scheduler) + + def _load_model( + self, + checkpoint_dir: str = None, + checkpoint_path: str = None, + resume_type: str = "", + ): + r"""Load model from checkpoint. If checkpoint_path is None, it will + load the latest checkpoint in checkpoint_dir. If checkpoint_path is not + None, it will load the checkpoint specified by checkpoint_path. **Only use this + method after** ``accelerator.prepare()``. + """ + if checkpoint_path is None: + ls = [str(i) for i in Path(checkpoint_dir).glob("*")] + ls.sort(key=lambda x: int(x.split("_")[-3].split("-")[-1]), reverse=True) + checkpoint_path = ls[0] + self.logger.info("Resume from {}...".format(checkpoint_path)) + + if resume_type in ["resume", ""]: + # Load all the things, including model weights, optimizer, scheduler, and random states. 
+ self.accelerator.load_state(input_dir=checkpoint_path) + + # set epoch and step + self.epoch = int(checkpoint_path.split("_")[-3].split("-")[-1]) + 1 + self.step = int(checkpoint_path.split("_")[-2].split("-")[-1]) + 1 + + elif resume_type == "finetune": + # Load only the model weights + accelerate.load_checkpoint_and_dispatch( + self.accelerator.unwrap_model(self.model["generator"]), + os.path.join(checkpoint_path, "pytorch_model.bin"), + ) + accelerate.load_checkpoint_and_dispatch( + self.accelerator.unwrap_model(self.model["discriminator"]), + os.path.join(checkpoint_path, "pytorch_model_1.bin"), + ) + self.logger.info("Load model weights for finetune...") + + else: + raise ValueError("Resume_type must be `resume` or `finetune`.") + + return checkpoint_path def _build_model(self): net_g = SynthesizerTrn( @@ -319,6 +386,185 @@ def _valid_step(self, batch): valid_stats, ) + @torch.inference_mode() + def _valid_epoch(self): + r"""Testing epoch. Should return average loss of a batch (sample) over + one epoch. See ``train_loop`` for usage. 
+ """ + if isinstance(self.model, dict): + for key in self.model.keys(): + self.model[key].eval() + else: + self.model.eval() + + epoch_sum_loss = 0.0 + epoch_losses = dict() + for batch in tqdm( + self.valid_dataloader, + desc=f"Validating Epoch {self.epoch}", + unit="batch", + colour="GREEN", + leave=False, + dynamic_ncols=True, + smoothing=0.04, + disable=not self.accelerator.is_main_process, + ): + total_loss, valid_losses, valid_stats = self._valid_step(batch) + epoch_sum_loss += total_loss + if isinstance(valid_losses, dict): + for key, value in valid_losses.items(): + if key not in epoch_losses.keys(): + epoch_losses[key] = value + else: + epoch_losses[key] += value + + epoch_sum_loss = epoch_sum_loss / len(self.valid_dataloader) + for key in epoch_losses.keys(): + epoch_losses[key] = epoch_losses[key] / len(self.valid_dataloader) + + self.accelerator.wait_for_everyone() + + return epoch_sum_loss, epoch_losses + + ### THIS IS MAIN ENTRY ### + def train_loop(self): + r"""Training loop. 
The public entry of training process.""" + # Wait everyone to prepare before we move on + self.accelerator.wait_for_everyone() + # dump config file + if self.accelerator.is_main_process: + self.__dump_cfg(self.config_save_path) + + # self.optimizer.zero_grad() + # Wait to ensure good to go + + self.accelerator.wait_for_everyone() + while self.epoch < self.max_epoch: + self.logger.info("\n") + self.logger.info("-" * 32) + self.logger.info("Epoch {}: ".format(self.epoch)) + + # Do training & validating epoch + train_total_loss, train_losses = self._train_epoch() + if isinstance(train_losses, dict): + for key, loss in train_losses.items(): + self.logger.info(" |- Train/{} Loss: {:.6f}".format(key, loss)) + self.accelerator.log( + {"Epoch/Train {} Loss".format(key): loss}, + step=self.epoch, + ) + + valid_total_loss, valid_losses = self._valid_epoch() + if isinstance(valid_losses, dict): + for key, loss in valid_losses.items(): + self.logger.info(" |- Valid/{} Loss: {:.6f}".format(key, loss)) + self.accelerator.log( + {"Epoch/Train {} Loss".format(key): loss}, + step=self.epoch, + ) + + self.logger.info(" |- Train/Loss: {:.6f}".format(train_total_loss)) + self.logger.info(" |- Valid/Loss: {:.6f}".format(valid_total_loss)) + self.accelerator.log( + { + "Epoch/Train Loss": train_total_loss, + "Epoch/Valid Loss": valid_total_loss, + }, + step=self.epoch, + ) + + self.accelerator.wait_for_everyone() + + # Check if hit save_checkpoint_stride and run_eval + run_eval = False + if self.accelerator.is_main_process: + save_checkpoint = False + hit_dix = [] + for i, num in enumerate(self.save_checkpoint_stride): + if self.epoch % num == 0: + save_checkpoint = True + hit_dix.append(i) + run_eval |= self.run_eval[i] + + self.accelerator.wait_for_everyone() + if self.accelerator.is_main_process and save_checkpoint: + path = os.path.join( + self.checkpoint_dir, + "epoch-{:04d}_step-{:07d}_loss-{:.6f}".format( + self.epoch, self.step, train_total_loss + ), + ) + 
self.tmp_checkpoint_save_path = path + self.accelerator.save_state(path) + + json.dump( + self.checkpoints_path, + open(os.path.join(path, "ckpts.json"), "w"), + ensure_ascii=False, + indent=4, + ) + self._save_auxiliary_states() + + # Remove old checkpoints + to_remove = [] + for idx in hit_dix: + self.checkpoints_path[idx].append(path) + while len(self.checkpoints_path[idx]) > self.keep_last[idx]: + to_remove.append((idx, self.checkpoints_path[idx].pop(0))) + + # Search conflicts + total = set() + for i in self.checkpoints_path: + total |= set(i) + do_remove = set() + for idx, path in to_remove[::-1]: + if path in total: + self.checkpoints_path[idx].insert(0, path) + else: + do_remove.add(path) + + # Remove old checkpoints + for path in do_remove: + shutil.rmtree(path, ignore_errors=True) + self.logger.debug(f"Remove old checkpoint: {path}") + + self.accelerator.wait_for_everyone() + if run_eval: + # TODO: run evaluation + pass + + # Update info for each epoch + self.epoch += 1 + + # Finish training and save final checkpoint + self.accelerator.wait_for_everyone() + if self.accelerator.is_main_process: + path = os.path.join( + self.checkpoint_dir, + "final_epoch-{:04d}_step-{:07d}_loss-{:.6f}".format( + self.epoch, self.step, valid_total_loss + ), + ) + self.tmp_checkpoint_save_path = path + self.accelerator.save_state( + os.path.join( + self.checkpoint_dir, + "final_epoch-{:04d}_step-{:07d}_loss-{:.6f}".format( + self.epoch, self.step, valid_total_loss + ), + ) + ) + + json.dump( + self.checkpoints_path, + open(os.path.join(path, "ckpts.json"), "w"), + ensure_ascii=False, + indent=4, + ) + self._save_auxiliary_states() + + self.accelerator.end_training() + def _train_step(self, batch): r"""Forward step for training and inference. This function is called in ``_train_step`` & ``_test_step`` function. 
@@ -446,38 +692,13 @@ def _train_epoch(self): return epoch_sum_loss, epoch_losses - def _build_singer_lut(self): - resumed_singer_path = None - if self.args.resume_from_ckpt_path and self.args.resume_from_ckpt_path != "": - resumed_singer_path = os.path.join( - self.args.resume_from_ckpt_path, self.cfg.preprocess.spk2id - ) - if os.path.exists(os.path.join(self.exp_dir, self.cfg.preprocess.spk2id)): - resumed_singer_path = os.path.join(self.exp_dir, self.cfg.preprocess.spk2id) - - if resumed_singer_path: - with open(resumed_singer_path, "r") as f: - singers = json.load(f) - else: - singers = dict() - - for dataset in self.cfg.dataset: - singer_lut_path = os.path.join( - self.cfg.preprocess.processed_dir, dataset, self.cfg.preprocess.spk2id - ) - with open(singer_lut_path, "r") as singer_lut_path: - singer_lut = json.load(singer_lut_path) - for singer in singer_lut.keys(): - if singer not in singers: - singers[singer] = len(singers) - - with open( - os.path.join(self.exp_dir, self.cfg.preprocess.spk2id), "w" - ) as singer_file: - json.dump(singers, singer_file, indent=4, ensure_ascii=False) - print( - "singers have been dumped to {}".format( - os.path.join(self.exp_dir, self.cfg.preprocess.spk2id) - ) + def __dump_cfg(self, path): + os.makedirs(os.path.dirname(path), exist_ok=True) + json5.dump( + self.cfg, + open(path, "w"), + indent=4, + sort_keys=True, + ensure_ascii=False, + quote_keys=True, ) - return singers