Skip to content

Commit

Permalink
TODO: vocoder
Browse files Browse the repository at this point in the history
  • Loading branch information
cantabile-kwok committed Oct 8, 2023
1 parent ffe04e2 commit 0532b53
Show file tree
Hide file tree
Showing 6 changed files with 163 additions and 81 deletions.
3 changes: 2 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -152,4 +152,5 @@ cython_debug/
.idea/
synthetic_wav/
exp/
**/*.wav
**/*.wav
.DS_Store
49 changes: 45 additions & 4 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -72,9 +72,50 @@ Several notes:
* You can set `use_gt_dur` to `false` to turn on MAS algorithm. In this setting, it is better to set `add_blank` to `true`.

## Generate Data for ReFlow and Perform Reflow
TO BE DONE
After training the model to some degree, it can be ready for flow rectification process.
Flow rectification requires to generate data using the trained model and use the (noise, data) pair to train the model again.
As this process should always involve the whole training dataset, it is recommended to run on multiple GPUs for parallel decoding.
We provide a script to do this:
```shell
# Set CUDA_VISIBLE_DEVICES, or the program will use all available GPUs.
python generate_for_reflow.py -c configs/${your_yaml} -m ${model_name} \
--EMA --max-utt-num 100000000 \
--dataset train \
--solver euler -t 10 \
--gt-dur
# --EMA specifies to load EMA checkpoint (latest)
# --max-utt-num sets the number of utterances to decode (in this case, arbitrarily high)
# --solver euler -t 10 specifies the solver and timesteps. Could be adaptive solvers like dopri5.
# --gt-dur forces the model to use ground truth duration for decoding.
```
This will create `synthetic_wav/${model_name}/generate_for_reflow/train` for storage. `noise.scp` together with `feats.scp` will be stored.
After decoding the training set, you can also decode validation set by `--dataset val`.

Then, specify the paths to these `feats.scp` and `noise.scp` in a new configuration yaml, like in the `lj_16k_gt_dur_reflow.yaml`:
```yaml
perform_reflow: true
...
data:
train:
feats_scp: "synthetic_wav/lj_16k_gt_dur/train/feats.scp"
noise_scp: "synthetic_wav/lj_16k_gt_dur/train/noise.scp"
...
```

Now it is ready for training again in ReFlow, with the same script in training but new yaml config files.
Feel free to copy a trained model to the new log dir for resuming.
Also, it is possible to change the model structure and train from scratch on the reflow data.

## Inference
TO BE DONE
Similar to "generate data for reflow", model inference can be done by
```shell
python inference_dataset.py -c configs/${your_yaml} -m ${model_name} --EMA \
--solver euler -t 10
```
This will synthesize mel-spectrograms for the validation set in your config, storing them at `synthetic_wav/${model_name}/tts_gt_spk/feats.scp`.
Speaker, speed and temperature can be specified; see `tools.get_hparams_decode()` function for complete set of options.

> TODO: VOCODER
## Acknowledgement
During the development, the following repositories were referred to:
Expand All @@ -83,7 +124,7 @@ During the development, the following repositories were referred to:
* [VITS](https://github.com/jaywalnut310/vits), whose distributed bucket sampler is used.
* [CFM](https://github.com/atong01/conditional-flow-matching), for the ODE samplers.

## Easter Eggs & Citation
## 💡Easter Eggs & Citation
This repository also contains some experimental functionalities. ⚠️Warning: not guaranteed to be correct!
* **Voice conversion**. As GlowTTS can perform voice conversion via the disentangling property of normalizing flows, it is reasonable that flow matching can also perform it. Method `model.tts.GradTTS.voice_conversion` gives a preliminary try.

Expand All @@ -94,7 +135,7 @@ This repository also contains some experimental functionalities. ⚠️Warning:
In practice, integral is replaced by summation, and divergence is replaced by the Skilling-Hutchinson trace estimator. See the Appendix D.2 in [Song, et. al](https://arxiv.org/abs/2011.13456) for theoretical details. I implemented this in `model.tts.GradTTS.compute_likelihood`.
* **Optimal transport**. The conditional flow matching used in this paper is not a **marginally** optimal transport path but only a **conditionally** optimal path. For the marginal optimal transport, [Tong et. al](https://arxiv.org/abs/2302.00482) introduces to sample $x_0,x_1$ together from the joint optimal transport distribution $\pi(x_0,x_1)$. I tried this in `model.cfm.OTCFM`, though it doe not work very well for now.
* **Different estimator architectures**. You can specify an estimator besides the `GradLogPEstimator2d` by the `model.fm_net_type` configuration. Currently the [DiffSinger](https://ojs.aaai.org/index.php/AAAI/article/view/21350)'s estimator architecture is also supported. You can add more, e.g. that introduced in [Matcha-TTS](https://github.com/shivammehta25/Matcha-TTS).
* 💡**Better alignment learning**. This repo supports supervised duration modeling together with monotonic alignment search as that in GradTTS. However, there might be a better way for MAS in flow-matching TTS. `model.tts.GradTTS.forward` now supports beta binomial prior for alignment maps; and if you want, you can change the variable `MAS_target` to something else, e.g. flow-transformed noise!
* **Better alignment learning**. This repo supports supervised duration modeling together with monotonic alignment search as that in GradTTS. However, there might be a better way for MAS in flow-matching TTS. `model.tts.GradTTS.forward` now supports beta binomial prior for alignment maps; and if you want, you can change the variable `MAS_target` to something else, e.g. flow-transformed noise!

Feel free to cite this work if it helps 😄

Expand Down
57 changes: 57 additions & 0 deletions configs/lj_16k_gt_dur_reflow.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
xvector: false # whether to use xvector for speaker modeling.

perform_reflow: true # if true, will need noise_scp be specified

train:
test_size: 4
n_epochs: 10000
batch_size: 24
learning_rate: !!float 5e-5
seed: 37
save_every: 10
use_gt_dur: true # whether to supervise duration modeling

data:
sampling_rate: 16000
n_mel_channels: 80
add_blank: false # whether to add blank tokens between each input phones
hop_length: 200 # in sampling points

phn2id: "data/ljspeech/phones.txt"

train:
utts: "data/ljspeech/train/utts.list"
utt2phns: "data/ljspeech/train/text"
utt2phn_duration: "data/ljspeech/train/phn_duration"
feats_scp: "synthetic_wav/lj_16k_gt_dur/train/feats.scp"
noise_scp: "synthetic_wav/lj_16k_gt_dur/train/noise.scp"
utt2num_frames: "feats/normed_fbank/ljspeech/train/utt2num_frames"
utt2spk: "data/ljspeech/train/utt2spk_id.json"

val:
utts: "data/ljspeech/val/utts.list"
utt2phns: "data/ljspeech/val/text"
utt2phn_duration: "data/ljspeech/val/phn_duration"
feats_scp: "synthetic_wav/lj_16k_gt_dur/val/feats.scp"
noise_scp: "synthetic_wav/lj_16k_gt_dur/val/noise.scp"
utt2num_frames: "feats/normed_fbank/ljspeech/val/utt2num_frames"
utt2spk: "data/ljspeech/val/utt2spk_id.json"

model:
n_vocab: 148
n_spks: 1
spk_emb_dim: 64
n_enc_channels: 192
filter_channels: 768
filter_channels_dp: 256
n_enc_layers: 6
enc_kernel: 3
enc_dropout: 0.1
n_heads: 2
window_size: 4
dec_dim: 128
pe_scale: 1000
fm_type: "CFM" # FM, CFM
fm_net_type: "unet" # unet or diffsinger
shift_by_mu: false # whether to shift the prior distribution by mu. True means GradTTS-style.
condition_by_mu: true # whether to condition the flow matching decoder by mu. False supports text-agnostic voice conversion like GlowTTS.
43 changes: 15 additions & 28 deletions generate_for_reflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ def run(rank, n_gpus, hps, args, ckpt, feats_dir, temp_dir):
model = model(**hps.model).to(device)
tools.load_checkpoint(ckpt, model, None)
print(f"Loaded checkpoint from {ckpt}")
model.to(device).eval()
model.eval()
print(f"Number of parameters: {model.nparams}")
print(f"Number of encoder parameters: {model.encoder.nparams}")
print(f"Number of decoder parameters: {model.decoder.nparams}")
Expand Down Expand Up @@ -88,38 +88,25 @@ def run(rank, n_gpus, hps, args, ckpt, feats_dir, temp_dir):
if hps.xvector:
if args.use_control_spk:
xvector = which_set.spk2xvector[args.control_spk_name]
xvector = (
torch.FloatTensor(xvector).squeeze().unsqueeze(0).to(device)
)
spk = torch.FloatTensor(xvector).squeeze().unsqueeze(0).to(device)
else:
xvector = batch["xvector"].to(device)

y_enc, y_dec, attn, z, pred_dur = model.inference(
x,
x_lengths,
n_timesteps=args.timesteps,
temperature=1.5,
spk=xvector,
length_scale=1.0,
solver=args.solver,
gt_dur=dur,
)
spk = batch["xvector"].to(device)
else:
if args.use_control_spk:
sid = torch.LongTensor([args.control_spk_id]).to(device)
spk = torch.LongTensor([args.control_spk_id]).to(device)
else:
sid = batch["spk_ids"].to(device)
spk = batch["spk_ids"].to(device)

y_enc, y_dec, attn, z, pred_dur = model.inference(
x,
x_lengths,
n_timesteps=args.timesteps,
temperature=1.5,
spk=sid,
length_scale=1.0,
solver=args.solver,
gt_dur=dur,
)
y_enc, y_dec, attn, z, pred_dur = model.inference(
x,
x_lengths,
n_timesteps=args.timesteps,
temperature=1.5,
spk=spk,
length_scale=1.0,
solver=args.solver,
gt_dur=dur,
)
# =================================================

if args.use_control_spk:
Expand Down
28 changes: 12 additions & 16 deletions inference_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
import tools


# @profile
def evaluate(hps, args, ckpt, feats_dir):
logger = tools.get_logger(hps.model_dir, "inference.log")
device = torch.device('cpu' if not torch.cuda.is_available() else "cuda")
Expand All @@ -25,7 +24,7 @@ def evaluate(hps, args, ckpt, feats_dir):
model = model(**hps.model).to(device)
tools.load_checkpoint(ckpt, model, None)
print(f"Loaded checkpoint from {ckpt}")
_ = model.cuda().eval()
model.eval()
print(f'Number of parameters: {model.nparams}')
print(f"Number of encoder parameters: {model.encoder.nparams}")
print(f"Number of decoder parameters: {model.decoder.nparams}")
Expand Down Expand Up @@ -64,23 +63,20 @@ def evaluate(hps, args, ckpt, feats_dir):
if hps.xvector:
if args.use_control_spk:
xvector = which_set.spk2xvector[args.control_spk_name]
xvector = torch.FloatTensor(xvector).squeeze().unsqueeze(0).to(device)
spk = torch.FloatTensor(xvector).squeeze().unsqueeze(0).to(device)
else:
xvector = batch['xvector'].to(device)
s = time.time()
y_enc, y_dec, attn, z, pred_dur = model.inference(x, x_lengths, n_timesteps=args.timesteps, temperature=args.temperature,
spk=xvector, length_scale=args.duration_scale, solver=args.solver, gt_dur=dur)
t = time.time()
spk = batch['xvector'].to(device)

else:
if args.use_control_spk:
sid = torch.LongTensor([args.control_spk_id]).to(device)
spk = torch.LongTensor([args.control_spk_id]).to(device)
else:
sid = batch['spk_ids'].to(device)
s = time.time()
y_enc, y_dec, attn, z, pred_dur = model.inference(x, x_lengths, n_timesteps=args.timesteps, temperature=args.temperature,
spk=sid, length_scale=args.duration_scale, solver=args.solver, gt_dur=dur)
t = time.time()
total_inference_time += t-s
spk = batch['spk_ids'].to(device)
s = time.time()
y_enc, y_dec, attn, z, pred_dur = model.inference(x, x_lengths, n_timesteps=args.timesteps, temperature=args.temperature,
spk=spk, length_scale=args.duration_scale, solver=args.solver, gt_dur=dur)
t = time.time()
total_inference_time += t - s
total_inference_frames += y_dec.squeeze().shape[1]
# =================================================

Expand All @@ -91,7 +87,7 @@ def evaluate(hps, args, ckpt, feats_dir):

feats(save_utt_name, y_dec.squeeze().cpu().numpy().T) # save to ark and scp, mel: (L, 80)
print(f"Inference finished. Total time: {total_inference_time}, total frames: {total_inference_frames} "
f"==> {total_inference_frames/total_inference_time} frame/s")
f"==> {total_inference_frames / total_inference_time} frame/s")


if __name__ == '__main__':
Expand Down
64 changes: 32 additions & 32 deletions tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,38 +162,6 @@ def get_correct_class(hps, train=True):
return dataset, collate(), model


def get_hparams(init=True):
parser = argparse.ArgumentParser()
parser.add_argument('-c', '--config', type=str, default="./configs/base.yaml",
help='YAML file for configuration')
parser.add_argument('-m', '--model', type=str, required=True,
help='Model name')
parser.add_argument('-s', '--seed', type=int, default=1234)

args = parser.parse_args()
model_dir = os.path.join("./logs", args.model)

if not os.path.exists(model_dir):
os.makedirs(model_dir)

config_path = args.config
config_save_path = os.path.join(model_dir, "config.yaml")
if init:
with open(config_path, "r") as f:
data = f.read()
with open(config_save_path, "w") as f:
f.write(data)
else:
with open(config_save_path, "r") as f:
data = f.read()
config = yaml.load(data, Loader=yaml.FullLoader)

hparams = HParams(**config)
hparams.model_dir = model_dir
hparams.train.seed = args.seed
return hparams


class HParams():
def __init__(self, **kwargs):
for k, v in kwargs.items():
Expand Down Expand Up @@ -254,6 +222,38 @@ def save_checkpoint(model, optimizer, learning_rate, iteration, checkpoint_path)
'learning_rate': learning_rate}, checkpoint_path)


def get_hparams(init=True):
parser = argparse.ArgumentParser()
parser.add_argument('-c', '--config', type=str, default="./configs/base.yaml",
help='YAML file for configuration')
parser.add_argument('-m', '--model', type=str, required=True,
help='Model name')
parser.add_argument('-s', '--seed', type=int, default=1234)

args = parser.parse_args()
model_dir = os.path.join("./logs", args.model)

if not os.path.exists(model_dir):
os.makedirs(model_dir)

config_path = args.config
config_save_path = os.path.join(model_dir, "config.yaml")
if init:
with open(config_path, "r") as f:
data = f.read()
with open(config_save_path, "w") as f:
f.write(data)
else:
with open(config_save_path, "r") as f:
data = f.read()
config = yaml.load(data, Loader=yaml.FullLoader)

hparams = HParams(**config)
hparams.model_dir = model_dir
hparams.train.seed = args.seed
return hparams


def get_hparams_decode():
parser = argparse.ArgumentParser()
parser.add_argument('-c', '--config', type=str, default="./configs/base.yaml",
Expand Down

0 comments on commit 0532b53

Please sign in to comment.