diff --git a/.gitignore b/.gitignore
index 029eac4..4d29eca 100644
--- a/.gitignore
+++ b/.gitignore
@@ -152,4 +152,5 @@ cython_debug/
 .idea/
 synthetic_wav/
 exp/
-**/*.wav
\ No newline at end of file
+**/*.wav
+.DS_Store
\ No newline at end of file
diff --git a/README.md b/README.md
index e59f2e7..8e86d43 100644
--- a/README.md
+++ b/README.md
@@ -72,9 +72,50 @@ Several notes:
* You can set `use_gt_dur` to `false` to turn on the MAS algorithm. In this setting, it is better to set `add_blank` to `true`.

## Generate Data for ReFlow and Perform Reflow
-TO BE DONE
+After training the model to some degree, it is ready for the flow rectification process.
Flow rectification requires generating data with the trained model, then using the (noise, data) pairs to train the model again.
As this process always involves the whole training dataset, it is recommended to run it on multiple GPUs for parallel decoding.
We provide a script to do this:
```shell
# Set CUDA_VISIBLE_DEVICES, or the program will use all available GPUs.
python generate_for_reflow.py -c configs/${your_yaml} -m ${model_name} \
    --EMA --max-utt-num 100000000 \
    --dataset train \
    --solver euler -t 10 \
    --gt-dur
# --EMA loads the latest EMA checkpoint
# --max-utt-num caps the number of utterances to decode (here set arbitrarily high to cover the whole set)
# --solver euler -t 10 specifies the ODE solver and the number of timesteps. Adaptive solvers like dopri5 can also be used.
# --gt-dur forces the model to use ground truth duration for decoding.
```
This will create `synthetic_wav/${model_name}/generate_for_reflow/train`, where `noise.scp` and `feats.scp` will be stored.
After decoding the training set, you can decode the validation set with `--dataset val`.

Then, specify the paths to these `feats.scp` and `noise.scp` files in a new configuration yaml, as in `lj_16k_gt_dur_reflow.yaml`:
```yaml
perform_reflow: true
...
data:
  train:
    feats_scp: "synthetic_wav/lj_16k_gt_dur/train/feats.scp"
    noise_scp: "synthetic_wav/lj_16k_gt_dur/train/noise.scp"
...
```

Now the model is ready for ReFlow training, using the same training script as before but the new yaml config.
Feel free to copy a trained checkpoint to the new log dir to resume from it.
It is also possible to change the model structure and train from scratch on the reflow data.

## Inference
-TO BE DONE
+Similar to generating data for reflow, model inference can be done by
```shell
python inference_dataset.py -c configs/${your_yaml} -m ${model_name} --EMA \
    --solver euler -t 10
```
This will synthesize mel-spectrograms for the validation set in your config, storing them at `synthetic_wav/${model_name}/tts_gt_spk/feats.scp`.
Speaker, speed and temperature can be specified; see the `tools.get_hparams_decode()` function for the complete set of options.

> TODO: VOCODER
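In the meantime, the stored mels can be vocoded with any 16 kHz mel vocoder. Below is a minimal sketch, assuming a pretrained HiFi-GAN generator exported as TorchScript; the checkpoint name `hifigan_16k.pt` and its `(B, 80, L)`-mel to `(B, 1, T)`-waveform interface are assumptions (this repo does not ship a vocoder), while the `(L, 80)` mel layout matches what `inference_dataset.py` stores:
```python
# Minimal vocoding sketch. Assumed: a TorchScript HiFi-GAN generator mapping
# (B, 80, L) mels to (B, 1, T) waveforms; "hifigan_16k.pt" is hypothetical.
import kaldiio
import soundfile as sf
import torch

feats = kaldiio.load_scp("synthetic_wav/lj_16k_gt_dur/tts_gt_spk/feats.scp")
vocoder = torch.jit.load("hifigan_16k.pt").eval()
with torch.no_grad():
    for utt in feats:
        mel = torch.from_numpy(feats[utt].T).unsqueeze(0).float()  # (L, 80) -> (1, 80, L)
        wav = vocoder(mel).squeeze().clamp(-1, 1).numpy()
        sf.write(f"{utt}.wav", wav, 16000)  # written to the current directory
```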
## Acknowledgement
During the development, the following repositories were referred to:
@@ -83,7 +124,7 @@ During the development, the following repositories were referred to:
* [VITS](https://github.com/jaywalnut310/vits), whose distributed bucket sampler is used.
* [CFM](https://github.com/atong01/conditional-flow-matching), for the ODE samplers.

-## Easter Eggs & Citation
+## 💡Easter Eggs & Citation
This repository also contains some experimental functionalities. ⚠️Warning: not guaranteed to be correct!
* **Voice conversion**. As GlowTTS can perform voice conversion via the disentangling property of normalizing flows, it is reasonable that flow matching can also perform it. The method `model.tts.GradTTS.voice_conversion` gives a preliminary try.
@@ -94,7 +135,7 @@ This repository also contains some experimental functionalities. ⚠️Warning:
In practice, the integral is replaced by summation, and the divergence is replaced by the Skilling-Hutchinson trace estimator (a minimal sketch is given after this section). See Appendix D.2 in [Song et al.](https://arxiv.org/abs/2011.13456) for theoretical details. I implemented this in `model.tts.GradTTS.compute_likelihood`.
* **Optimal transport**. The conditional flow matching used in this paper is not a **marginally** optimal transport path but only a **conditionally** optimal path. For marginal optimal transport, [Tong et al.](https://arxiv.org/abs/2302.00482) introduce sampling $x_0, x_1$ together from the joint optimal transport distribution $\pi(x_0, x_1)$. I tried this in `model.cfm.OTCFM`, though it does not work very well for now.
* **Different estimator architectures**. You can specify an estimator other than `GradLogPEstimator2d` via the `model.fm_net_type` configuration. Currently the [DiffSinger](https://ojs.aaai.org/index.php/AAAI/article/view/21350) estimator architecture is also supported. You can add more, e.g. the one introduced in [Matcha-TTS](https://github.com/shivammehta25/Matcha-TTS).
-* 💡**Better alignment learning**. This repo supports supervised duration modeling together with monotonic alignment search as that in GradTTS. However, there might be a better way for MAS in flow-matching TTS. `model.tts.GradTTS.forward` now supports beta binomial prior for alignment maps; and if you want, you can change the variable `MAS_target` to something else, e.g. flow-transformed noise!
+* **Better alignment learning**. This repo supports supervised duration modeling together with monotonic alignment search as in GradTTS. However, there might be a better way to do MAS in flow-matching TTS. `model.tts.GradTTS.forward` now supports a beta-binomial prior for alignment maps; and if you want, you can change the variable `MAS_target` to something else, e.g. flow-transformed noise!

Feel free to cite this work if it helps 😄
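As referenced in the likelihood bullet above, here is a minimal, self-contained PyTorch sketch of the Skilling-Hutchinson trace estimator. It is a generic illustration, not the exact code in `model.tts.GradTTS.compute_likelihood`; the function name `divergence_hutchinson` and the toy linear field are made up for this example:
```python
import torch

def divergence_hutchinson(f, x, n_samples=1):
    # Skilling-Hutchinson: tr(df/dx) = E_eps[eps^T (df/dx) eps], eps ~ N(0, I).
    x = x.detach().requires_grad_(True)
    div = torch.zeros(x.shape[0], device=x.device)
    for _ in range(n_samples):
        eps = torch.randn_like(x)
        # eps^T (df/dx) is a vector-Jacobian product: one backward pass,
        # so the full Jacobian is never materialized.
        vjp = torch.autograd.grad(f(x), x, grad_outputs=eps)[0]
        div += (vjp * eps).flatten(1).sum(dim=1)
    return div / n_samples

# Sanity check on a linear field f(x) = A x, whose true divergence is tr(A).
A = torch.randn(4, 4)
x = torch.randn(8, 4)
est = divergence_hutchinson(lambda v: v @ A.T, x, n_samples=512)
print(est.mean().item(), torch.trace(A).item())  # the two should be close
```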
diff --git a/configs/lj_16k_gt_dur_reflow.yaml b/configs/lj_16k_gt_dur_reflow.yaml
new file mode 100644
index 0000000..e3c7051
--- /dev/null
+++ b/configs/lj_16k_gt_dur_reflow.yaml
@@ -0,0 +1,57 @@
+xvector: false  # whether to use xvector for speaker modeling
+
+perform_reflow: true  # if true, noise_scp must be specified
+
+train:
+  test_size: 4
+  n_epochs: 10000
+  batch_size: 24
+  learning_rate: !!float 5e-5
+  seed: 37
+  save_every: 10
+  use_gt_dur: true  # whether to supervise duration modeling
+
+data:
+  sampling_rate: 16000
+  n_mel_channels: 80
+  add_blank: false  # whether to add blank tokens between input phones
+  hop_length: 200  # in sampling points
+
+  phn2id: "data/ljspeech/phones.txt"
+
+  train:
+    utts: "data/ljspeech/train/utts.list"
+    utt2phns: "data/ljspeech/train/text"
+    utt2phn_duration: "data/ljspeech/train/phn_duration"
+    feats_scp: "synthetic_wav/lj_16k_gt_dur/train/feats.scp"
+    noise_scp: "synthetic_wav/lj_16k_gt_dur/train/noise.scp"
+    utt2num_frames: "feats/normed_fbank/ljspeech/train/utt2num_frames"
+    utt2spk: "data/ljspeech/train/utt2spk_id.json"
+
+  val:
+    utts: "data/ljspeech/val/utts.list"
+    utt2phns: "data/ljspeech/val/text"
+    utt2phn_duration: "data/ljspeech/val/phn_duration"
+    feats_scp: "synthetic_wav/lj_16k_gt_dur/val/feats.scp"
+    noise_scp: "synthetic_wav/lj_16k_gt_dur/val/noise.scp"
+    utt2num_frames: "feats/normed_fbank/ljspeech/val/utt2num_frames"
+    utt2spk: "data/ljspeech/val/utt2spk_id.json"
+
+model:
+  n_vocab: 148
+  n_spks: 1
+  spk_emb_dim: 64
+  n_enc_channels: 192
+  filter_channels: 768
+  filter_channels_dp: 256
+  n_enc_layers: 6
+  enc_kernel: 3
+  enc_dropout: 0.1
+  n_heads: 2
+  window_size: 4
+  dec_dim: 128
+  pe_scale: 1000
+  fm_type: "CFM"  # FM, CFM
+  fm_net_type: "unet"  # unet or diffsinger
+  shift_by_mu: false  # whether to shift the prior distribution by mu. True means GradTTS-style.
+  condition_by_mu: true  # whether to condition the flow matching decoder by mu. False supports text-agnostic voice conversion like GlowTTS.
diff --git a/generate_for_reflow.py b/generate_for_reflow.py
index 1f59e35..5152dcd 100644
--- a/generate_for_reflow.py
+++ b/generate_for_reflow.py
@@ -42,7 +42,7 @@ def run(rank, n_gpus, hps, args, ckpt, feats_dir, temp_dir):
     model = model(**hps.model).to(device)
     tools.load_checkpoint(ckpt, model, None)
     print(f"Loaded checkpoint from {ckpt}")
-    model.to(device).eval()
+    model.eval()
     print(f"Number of parameters: {model.nparams}")
     print(f"Number of encoder parameters: {model.encoder.nparams}")
     print(f"Number of decoder parameters: {model.decoder.nparams}")
@@ -88,38 +88,25 @@ def run(rank, n_gpus, hps, args, ckpt, feats_dir, temp_dir):
         if hps.xvector:
             if args.use_control_spk:
                 xvector = which_set.spk2xvector[args.control_spk_name]
-                xvector = (
-                    torch.FloatTensor(xvector).squeeze().unsqueeze(0).to(device)
-                )
+                spk = torch.FloatTensor(xvector).squeeze().unsqueeze(0).to(device)
             else:
-                xvector = batch["xvector"].to(device)
-
-            y_enc, y_dec, attn, z, pred_dur = model.inference(
-                x,
-                x_lengths,
-                n_timesteps=args.timesteps,
-                temperature=1.5,
-                spk=xvector,
-                length_scale=1.0,
-                solver=args.solver,
-                gt_dur=dur,
-            )
+                spk = batch["xvector"].to(device)
         else:
             if args.use_control_spk:
-                sid = torch.LongTensor([args.control_spk_id]).to(device)
+                spk = torch.LongTensor([args.control_spk_id]).to(device)
             else:
-                sid = batch["spk_ids"].to(device)
+                spk = batch["spk_ids"].to(device)
 
-            y_enc, y_dec, attn, z, pred_dur = model.inference(
-                x,
-                x_lengths,
-                n_timesteps=args.timesteps,
-                temperature=1.5,
-                spk=sid,
-                length_scale=1.0,
-                solver=args.solver,
-                gt_dur=dur,
-            )
+        y_enc, y_dec, attn, z, pred_dur = model.inference(
+            x,
+            x_lengths,
+            n_timesteps=args.timesteps,
+            temperature=1.5,
+            spk=spk,
+            length_scale=1.0,
+            solver=args.solver,
+            gt_dur=dur,
+        )
         # =================================================
         if args.use_control_spk:
diff --git a/inference_dataset.py b/inference_dataset.py
index 6079eda..8ae43ef 100644
--- a/inference_dataset.py
+++ b/inference_dataset.py
@@ -8,7 +8,6 @@
 import tools
 
 
-# @profile
 def evaluate(hps, args, ckpt, feats_dir):
     logger = tools.get_logger(hps.model_dir, "inference.log")
     device = torch.device('cpu' if not torch.cuda.is_available() else "cuda")
@@ -25,7 +24,7 @@ def evaluate(hps, args, ckpt, feats_dir):
     model = model(**hps.model).to(device)
     tools.load_checkpoint(ckpt, model, None)
     print(f"Loaded checkpoint from {ckpt}")
-    _ = model.cuda().eval()
+    model.eval()
     print(f'Number of parameters: {model.nparams}')
     print(f"Number of encoder parameters: {model.encoder.nparams}")
     print(f"Number of decoder parameters: {model.decoder.nparams}")
@@ -64,23 +63,20 @@ def evaluate(hps, args, ckpt, feats_dir):
         if hps.xvector:
             if args.use_control_spk:
                 xvector = which_set.spk2xvector[args.control_spk_name]
-                xvector = torch.FloatTensor(xvector).squeeze().unsqueeze(0).to(device)
+                spk = torch.FloatTensor(xvector).squeeze().unsqueeze(0).to(device)
             else:
-                xvector = batch['xvector'].to(device)
-            s = time.time()
-            y_enc, y_dec, attn, z, pred_dur = model.inference(x, x_lengths, n_timesteps=args.timesteps, temperature=args.temperature,
-                                                              spk=xvector, length_scale=args.duration_scale, solver=args.solver, gt_dur=dur)
-            t = time.time()
+                spk = batch['xvector'].to(device)
+
         else:
             if args.use_control_spk:
-                sid = torch.LongTensor([args.control_spk_id]).to(device)
+                spk = torch.LongTensor([args.control_spk_id]).to(device)
             else:
-                sid = batch['spk_ids'].to(device)
-            s = time.time()
-            y_enc, y_dec, attn, z, pred_dur = model.inference(x, x_lengths, n_timesteps=args.timesteps, temperature=args.temperature,
-                                                              spk=sid, length_scale=args.duration_scale, solver=args.solver, gt_dur=dur)
-            t = time.time()
-            total_inference_time += t-s
+                spk = batch['spk_ids'].to(device)
+        s = time.time()
+        y_enc, y_dec, attn, z, pred_dur = model.inference(x, x_lengths, n_timesteps=args.timesteps, temperature=args.temperature,
+                                                          spk=spk, length_scale=args.duration_scale, solver=args.solver, gt_dur=dur)
+        t = time.time()
+        total_inference_time += t - s
         total_inference_frames += y_dec.squeeze().shape[1]
 
         # =================================================
@@ -91,7 +87,7 @@ def evaluate(hps, args, ckpt, feats_dir):
             feats(save_utt_name, y_dec.squeeze().cpu().numpy().T)  # save to ark and scp, mel: (L, 80)
 
     print(f"Inference finished. Total time: {total_inference_time}, total frames: {total_inference_frames} "
-          f"==> {total_inference_frames/total_inference_time} frame/s")
+          f"==> {total_inference_frames / total_inference_time} frame/s")
 
 
 if __name__ == '__main__':
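A side note on the frame/s figure printed above: with the example config (16000 Hz sampling rate, hop length 200), one second of audio corresponds to 16000 / 200 = 80 mel frames, so decoding throughput converts directly to a real-time factor. A small illustrative helper (`real_time_factor` is made up for this note, not part of the repo):
```python
def real_time_factor(decoded_frames_per_sec: float,
                     sampling_rate: int = 16000,
                     hop_length: int = 200) -> float:
    """RTF = synthesis time / audio duration; below 1 means faster than real time."""
    audio_frames_per_sec = sampling_rate / hop_length  # 80 mel frames per audio second
    return audio_frames_per_sec / decoded_frames_per_sec

# e.g. decoding at 800 mel frames/s gives an RTF of 0.1 (10x faster than real time)
print(real_time_factor(800.0))
```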
diff --git a/tools.py b/tools.py
index 980900f..4d1be65 100644
--- a/tools.py
+++ b/tools.py
@@ -162,38 +162,6 @@ def get_correct_class(hps, train=True):
     return dataset, collate(), model
 
 
-def get_hparams(init=True):
-    parser = argparse.ArgumentParser()
-    parser.add_argument('-c', '--config', type=str, default="./configs/base.yaml",
-                        help='YAML file for configuration')
-    parser.add_argument('-m', '--model', type=str, required=True,
-                        help='Model name')
-    parser.add_argument('-s', '--seed', type=int, default=1234)
-
-    args = parser.parse_args()
-    model_dir = os.path.join("./logs", args.model)
-
-    if not os.path.exists(model_dir):
-        os.makedirs(model_dir)
-
-    config_path = args.config
-    config_save_path = os.path.join(model_dir, "config.yaml")
-    if init:
-        with open(config_path, "r") as f:
-            data = f.read()
-        with open(config_save_path, "w") as f:
-            f.write(data)
-    else:
-        with open(config_save_path, "r") as f:
-            data = f.read()
-    config = yaml.load(data, Loader=yaml.FullLoader)
-
-    hparams = HParams(**config)
-    hparams.model_dir = model_dir
-    hparams.train.seed = args.seed
-    return hparams
-
-
 class HParams():
     def __init__(self, **kwargs):
         for k, v in kwargs.items():
@@ -254,6 +222,38 @@ def save_checkpoint(model, optimizer, learning_rate, iteration, checkpoint_path)
                 'learning_rate': learning_rate}, checkpoint_path)
 
 
+def get_hparams(init=True):
+    parser = argparse.ArgumentParser()
+    parser.add_argument('-c', '--config', type=str, default="./configs/base.yaml",
+                        help='YAML file for configuration')
+    parser.add_argument('-m', '--model', type=str, required=True,
+                        help='Model name')
+    parser.add_argument('-s', '--seed', type=int, default=1234)
+
+    args = parser.parse_args()
+    model_dir = os.path.join("./logs", args.model)
+
+    if not os.path.exists(model_dir):
+        os.makedirs(model_dir)
+
+    config_path = args.config
+    config_save_path = os.path.join(model_dir, "config.yaml")
+    if init:
+        with open(config_path, "r") as f:
+            data = f.read()
+        with open(config_save_path, "w") as f:
+            f.write(data)
+    else:
+        with open(config_save_path, "r") as f:
+            data = f.read()
+    config = yaml.load(data, Loader=yaml.FullLoader)
+
+    hparams = HParams(**config)
+    hparams.model_dir = model_dir
+    hparams.train.seed = args.seed
+    return hparams
+
+
 def get_hparams_decode():
     parser = argparse.ArgumentParser()
     parser.add_argument('-c', '--config', type=str, default="./configs/base.yaml",