
Commit

training code done
wl-zhao committed Mar 10, 2024
1 parent c9c57a1 commit 7ade7b7
Showing 16 changed files with 1,533 additions and 47 deletions.
5 changes: 1 addition & 4 deletions README.md
@@ -26,6 +26,7 @@ Some other features include:
## Usage
- [Use without Installation](docs/quick_use.md)
- [Install and Use Locally](docs/install.md)
- [Training on Custom Dataset](docs/training.md)

The Python API and model cards can be found in [this repo](https://github.com/myshell-ai/MeloTTS/blob/main/docs/install.md#python-api) or on [HuggingFace](https://huggingface.co/myshell-ai).

@@ -57,10 +58,6 @@ If you find this work useful, please consider contributing to this repo.
}
```

## TODO

- Training code release.

## License

This library is under MIT License, which means it is free for both commercial and non-commercial use.
37 changes: 37 additions & 0 deletions docs/training.md
@@ -0,0 +1,37 @@
## Training

Before training, please install MeloTTS in dev mode and go to the `melo` folder.
```
pip install -e .
cd melo
```

### Data Preparation
To train a TTS model, we need to prepare the audio files and a metadata file. We recommend using 44100 Hz audio files. The metadata file should have the following format:

```
path/to/audio_001.wav |<speaker_name>|<language_code>|<text_001>
path/to/audio_002.wav |<speaker_name>|<language_code>|<text_002>
```
The transcribed text can be obtained with an ASR model (e.g., [whisper](https://github.com/openai/whisper)). An example metadata file can be found in `data/example/metadata.list`.
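
For illustration, here is a minimal sketch of generating such a metadata file with the `openai-whisper` package; the folder layout, speaker name, and language code are placeholder assumptions, and the field separator should be matched to `data/example/metadata.list`:
```
import glob
import whisper

# Placeholder assumptions: a folder of 44100 Hz WAV files, one speaker, English text.
model = whisper.load_model("medium")  # choose a model size that fits your hardware
speaker, language = "MySpeaker", "EN"

with open("data/example/metadata.list", "w", encoding="utf-8") as f:
    for wav in sorted(glob.glob("data/example/wavs/*.wav")):
        text = model.transcribe(wav)["text"].strip()
        # One line per utterance: path|speaker|language|text
        f.write(f"{wav}|{speaker}|{language}|{text}\n")
```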

We can then run the preprocessing code:
```
python preprocess_text.py --metadata data/example/metadata.list
```
A config file `data/example/config.json` will be generated. Feel free to edit the hyper-parameters in that config file (for example, you may decrease the batch size if you encounter a CUDA out-of-memory error).
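
For instance, a small sketch of lowering the batch size programmatically (the path and the new value are illustrative; editing the JSON by hand works just as well):
```
import json

cfg_path = "data/example/config.json"  # generated by preprocess_text.py
with open(cfg_path, encoding="utf-8") as f:
    cfg = json.load(f)

cfg["train"]["batch_size"] = 4  # reduce if you hit CUDA out-of-memory

with open(cfg_path, "w", encoding="utf-8") as f:
    json.dump(cfg, f, indent=2, ensure_ascii=False)
```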

### Training
Training can be launched with:
```
bash train.sh <path/to/config.json> <num_of_gpus>
```

We have found that on some machines training will sometimes crash due to a gloo [issue](https://github.com/pytorch/pytorch/issues/2530). Therefore, we add an auto-resume wrapper in `train.sh`.
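
Conceptually, the wrapper just relaunches training until it exits cleanly, so a gloo crash resumes from the latest checkpoint. A rough Python sketch of the same idea (the actual logic lives in `train.sh`, and the launch command below is a placeholder):
```
import subprocess

# Placeholder launch command; train.sh builds the real one from its arguments.
cmd = ["torchrun", "--nproc_per_node=2", "train.py", "-c", "data/example/config.json"]

while True:
    ret = subprocess.run(cmd).returncode
    if ret == 0:
        break  # training exited normally
    print(f"Training crashed with exit code {ret}; restarting from the latest checkpoint...")
```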

### Inference
Simply run:
```
python infer.py --text "<some text here>" -m /path/to/checkpoint/G_<iter>.pth -o <output_dir>
```
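
Alternatively, the Python API (see the `melo/api.py` changes below) now accepts explicit config and checkpoint paths, so a custom model can be loaded directly. A sketch, with placeholder paths and speaker key:
```
from melo.api import TTS

# Placeholder paths and speaker key; spk2id comes from the config generated during preprocessing.
model = TTS(language="EN",
            config_path="path/to/config.json",
            ckpt_path="path/to/checkpoint/G_<iter>.pth")

speaker_ids = model.hps.data.spk2id
model.tts_to_file("<some text here>", speaker_ids["MySpeaker"], "output.wav")
```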

8 changes: 5 additions & 3 deletions melo/api.py
@@ -21,7 +21,9 @@ class TTS(nn.Module):
def __init__(self,
language,
device='auto',
- use_hf=True):
+ use_hf=True,
+ config_path=None,
+ ckpt_path=None):
super().__init__()
if device == 'auto':
device = 'cpu'
@@ -31,7 +33,7 @@ def __init__(self,
assert torch.cuda.is_available()

# config_path =
- hps = load_or_download_config(language, use_hf=use_hf)
+ hps = load_or_download_config(language, use_hf=use_hf, config_path=config_path)

num_languages = hps.num_languages
num_tones = hps.num_tones
@@ -54,7 +56,7 @@ def __init__(self,
self.device = device

# load state_dict
- checkpoint_dict = load_or_download_model(language, device, use_hf=use_hf)
+ checkpoint_dict = load_or_download_model(language, device, use_hf=use_hf, ckpt_path=ckpt_path)
self.model.load_state_dict(checkpoint_dict['model'], strict=True)

language = language.split('_')[0]
94 changes: 94 additions & 0 deletions melo/configs/config.json
@@ -0,0 +1,94 @@
{
"train": {
"log_interval": 200,
"eval_interval": 1000,
"seed": 52,
"epochs": 10000,
"learning_rate": 0.0003,
"betas": [
0.8,
0.99
],
"eps": 1e-09,
"batch_size": 6,
"fp16_run": false,
"lr_decay": 0.999875,
"segment_size": 16384,
"init_lr_ratio": 1,
"warmup_epochs": 0,
"c_mel": 45,
"c_kl": 1.0,
"skip_optimizer": true
},
"data": {
"training_files": "",
"validation_files": "",
"max_wav_value": 32768.0,
"sampling_rate": 44100,
"filter_length": 2048,
"hop_length": 512,
"win_length": 2048,
"n_mel_channels": 128,
"mel_fmin": 0.0,
"mel_fmax": null,
"add_blank": true,
"n_speakers": 256,
"cleaned_text": true,
"spk2id": {}
},
"model": {
"use_spk_conditioned_encoder": true,
"use_noise_scaled_mas": true,
"use_mel_posterior_encoder": false,
"use_duration_discriminator": true,
"inter_channels": 192,
"hidden_channels": 192,
"filter_channels": 768,
"n_heads": 2,
"n_layers": 6,
"n_layers_trans_flow": 3,
"kernel_size": 3,
"p_dropout": 0.1,
"resblock": "1",
"resblock_kernel_sizes": [
3,
7,
11
],
"resblock_dilation_sizes": [
[
1,
3,
5
],
[
1,
3,
5
],
[
1,
3,
5
]
],
"upsample_rates": [
8,
8,
2,
2,
2
],
"upsample_initial_channel": 512,
"upsample_kernel_sizes": [
16,
16,
8,
2,
2
],
"n_layers_q": 3,
"use_spectral_norm": false,
"gin_channels": 256
}
}
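
For orientation, a quick illustrative calculation of what the default values above imply per training segment:
```
# Derived from the config above (illustrative arithmetic only).
sampling_rate = 44100   # Hz
hop_length = 512        # samples per mel frame
segment_size = 16384    # samples per training segment

print(segment_size // hop_length)              # 32 mel frames per segment
print(round(segment_size / sampling_rate, 3))  # ~0.372 seconds of audio per segment
```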
