Voice Conversion Model: Noro
Showing 19 changed files with 3,396 additions and 0 deletions.
@@ -0,0 +1,82 @@
# Copyright (c) 2023 Amphion.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import argparse

import torch
from models.vc.Noro.noro_trainer import NoroTrainer
from utils.util import load_config


def build_trainer(args, cfg):
    supported_trainer = {
        "VC": NoroTrainer,
    }
    trainer_class = supported_trainer[cfg.model_type]
    trainer = trainer_class(args, cfg)
    return trainer


def cuda_relevant(deterministic=False):
    torch.cuda.empty_cache()
    # TF32 on Ampere and above
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.enabled = True
    torch.backends.cudnn.allow_tf32 = True
    # Deterministic
    torch.backends.cudnn.deterministic = deterministic
    torch.backends.cudnn.benchmark = not deterministic
    torch.use_deterministic_algorithms(deterministic)


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--config",
        default="config.json",
        help="JSON file for configuration.",
        required=True,
    )
    parser.add_argument(
        "--exp_name",
        type=str,
        default="exp_name",
        help="A specific name to note the experiment",
        required=True,
    )
    parser.add_argument(
        "--resume",
        action="store_true",
        help="Resume training from an existing checkpoint",
    )
    parser.add_argument(
        "--log_level", default="warning", help="Logging level (debug, info, warning)"
    )
    parser.add_argument(
        "--resume_type",
        type=str,
        default="resume",
        help="Resume training or finetuning.",
    )
    parser.add_argument(
        "--checkpoint_path",
        type=str,
        default=None,
        help="Checkpoint for resuming training or finetuning.",
    )
    args = parser.parse_args()
    cfg = load_config(args.config)
    print("experiment name: ", args.exp_name)
    # CUDA settings
    cuda_relevant()
    # Build trainer
    print(f"Building {cfg.model_type} trainer")
    trainer = build_trainer(args, cfg)
    torch.set_num_threads(1)
    torch.set_num_interop_threads(1)
    print(f"Start training {cfg.model_type} model")
    trainer.train_loop()


if __name__ == "__main__":
    main()
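
For orientation, a minimal direct invocation of this entry point might look like the sketch below; the `bins/vc/train.py` location and the `accelerate` launcher follow the commands in the Noro README further down, while the config path and experiment name are placeholders.

```bash
# Placeholder config and experiment name; the recipe scripts described in the
# README below wrap this call with additional accelerate options.
accelerate launch bins/vc/train.py \
    --config egs/vc/Noro/exp_config_clean.json \
    --exp_name my_noro_experiment \
    --log_level info
```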
@@ -0,0 +1,76 @@
{
    "base_config": "config/base.json",
    "model_type": "VC",
    "dataset": ["mls"],
    "model": {
        "reference_encoder": {
            "encoder_layer": 6,
            "encoder_hidden": 512,
            "encoder_head": 8,
            "conv_filter_size": 2048,
            "conv_kernel_size": 9,
            "encoder_dropout": 0.2,
            "use_skip_connection": false,
            "use_new_ffn": true,
            "ref_in_dim": 80,
            "ref_out_dim": 512,
            "use_query_emb": true,
            "num_query_emb": 32
        },
        "diffusion": {
            "beta_min": 0.05,
            "beta_max": 20,
            "sigma": 1.0,
            "noise_factor": 1.0,
            "ode_solve_method": "euler",
            "diff_model_type": "WaveNet",
            "diff_wavenet": {
                "input_size": 80,
                "hidden_size": 512,
                "out_size": 80,
                "num_layers": 47,
                "cross_attn_per_layer": 3,
                "dilation_cycle": 2,
                "attn_head": 8,
                "drop_out": 0.2
            }
        },
        "prior_encoder": {
            "encoder_layer": 6,
            "encoder_hidden": 512,
            "encoder_head": 8,
            "conv_filter_size": 2048,
            "conv_kernel_size": 9,
            "encoder_dropout": 0.2,
            "use_skip_connection": false,
            "use_new_ffn": true,
            "vocab_size": 256,
            "cond_dim": 512,
            "duration_predictor": {
                "input_size": 512,
                "filter_size": 512,
                "kernel_size": 3,
                "conv_layers": 30,
                "cross_attn_per_layer": 3,
                "attn_head": 8,
                "drop_out": 0.2
            },
            "pitch_predictor": {
                "input_size": 512,
                "filter_size": 512,
                "kernel_size": 5,
                "conv_layers": 30,
                "cross_attn_per_layer": 3,
                "attn_head": 8,
                "drop_out": 0.5
            },
            "pitch_min": 50,
            "pitch_max": 1100,
            "pitch_bins_num": 512
        },
        "vc_feature": {
            "content_feature_dim": 768,
            "hidden_dim": 512
        }
    }
}
@@ -0,0 +1,122 @@
# Noro: A Noise-Robust One-shot Voice Conversion System

<br>
<div align="center">
<img src="../../../imgs/vc/NoroVC.png" width="85%">
</div>
<br>

This is the official implementation of the paper: NORO: A Noise-Robust One-Shot Voice Conversion System with Hidden Speaker Representation Capabilities.

- The semantic extractor is from [HuBERT](https://github.com/facebookresearch/fairseq/tree/main/examples/hubert).
- The vocoder uses the [BigVGAN](https://github.com/NVIDIA/BigVGAN) architecture.

## Project Overview
Noro is a noise-robust one-shot voice conversion (VC) system designed to convert the timbre of speech from a source speaker to a target speaker using only a single reference speech sample, while preserving the semantic content of the original speech. Noro introduces components tailored for VC with noisy reference speech, including a dual-branch reference encoding module and a noise-agnostic contrastive speaker loss.

## Features
- **Noise-Robust Voice Conversion**: Uses a dual-branch reference encoding module and a noise-agnostic contrastive speaker loss to maintain high-quality voice conversion in noisy environments.
- **One-shot Voice Conversion**: Achieves timbre conversion using only one reference speech sample.
- **Speaker Representation Learning**: Explores the potential of the reference encoder as a self-supervised speaker encoder.

## Installation Requirements

Set up your environment as described in the Amphion README (a conda environment is required, and we recommend Linux).
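
A minimal sketch of that setup, assuming the standard conda workflow; the Python version and the `env.sh` setup script are assumptions here, so consult the Amphion README for the authoritative steps.

```bash
# Sketch only; verify the exact steps against the Amphion README.
git clone https://github.com/open-mmlab/Amphion.git
cd Amphion
conda create --name amphion python=3.9
conda activate amphion
sh env.sh
```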

### Prepare the HuBERT Model

The HuBERT checkpoint and k-means model can be downloaded [here](https://github.com/facebookresearch/fairseq/tree/main/examples/hubert).
Set the downloaded model paths (`hubert_model_path` and `kmeans_model_path` under `preprocess`) in `egs/vc/Noro/exp_config_base.json`.


## Usage

### Download pretrained weights
You need to download our pretrained weights from [Google Drive](https://drive.google.com/drive/folders/1NPzSIuSKO8o87g5ImNzpw_BgbhsZaxNg?usp=drive_link).

### Inference
1. Configure inference parameters:
Modify the pretrained checkpoint path, source voice path, and reference voice path in the `egs/vc/Noro/noro_inference.sh` file.
These variables are currently around line 35.
```
checkpoint_path="path/to/checkpoint/model.safetensors"
output_dir="path/to/output/directory"
source_path="path/to/source/audio.wav"
reference_path="path/to/reference/audio.wav"
```
2. Start inference:
```bash
bash path/to/Amphion/egs/vc/noro_inference.sh
```

3. The reconstructed mel spectrogram is saved to the output directory.
Then use [BigVGAN](https://github.com/NVIDIA/BigVGAN) to synthesize the waveform from it.
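
As a rough, non-authoritative sketch of that vocoding step: the BigVGAN repository ships inference scripts that read mel spectrograms and write waveforms, but the script name, flags, and checkpoint path below are assumptions and may differ in the revision you clone; also make sure the BigVGAN checkpoint's mel configuration matches Noro's features (80 mel bins, 16 kHz, hop size 200 in `exp_config_base.json`).

```bash
# Sketch only: the script name and flags are assumptions about the BigVGAN
# repo; verify them against its README before running.
git clone https://github.com/NVIDIA/BigVGAN.git
cd BigVGAN
# --input_mels_dir should point at the directory Noro wrote its mels to.
python inference_e2e.py \
    --checkpoint_file path/to/bigvgan_checkpoint.pt \
    --input_mels_dir path/to/output/directory \
    --output_dir path/to/wav/output
```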

## Training from Scratch

### Data Preparation

We use the LibriLight dataset for training and evaluation. You can download it using the following commands:
```bash
wget https://dl.fbaipublicfiles.com/librilight/data/large.tar
wget https://dl.fbaipublicfiles.com/librilight/data/medium.tar
wget https://dl.fbaipublicfiles.com/librilight/data/small.tar
```
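
After downloading, unpack the archives; the destination directory below is only an illustration and should match whatever paths you later list in `directory_list`.

```bash
# Example layout; use whichever directories you plan to reference in
# "directory_list" in the training config.
mkdir -p data/librilight
for split in small medium large; do
    tar -xf "${split}.tar" -C data/librilight
done
```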

### Training the Model with a Clean Reference Voice

Configure training parameters:
Our configuration file for training the clean Noro model is `egs/vc/Noro/exp_config_clean.json`, and the one for the noisy Noro model is `egs/vc/Noro/exp_config_noisy.json`.

To train your model, you need to modify the dataset paths in the JSON configurations.
Currently this is around line 40: set `directory_list` to your dataset's root directories.
```
"directory_list": [
    "path/to/your/training_data_directory1",
    "path/to/your/training_data_directory2",
    "path/to/your/training_data_directory3"
],
```
If you want to train the noisy Noro model, you also need to set the directory paths for the noisy data in `egs/vc/Noro/exp_config_noisy.json`.
```
"noise_dir": "path/to/your/noise/train/directory",
"test_noise_dir": "path/to/your/noise/test/directory"
```
You can change other experiment settings in the config files, such as the learning rate, optimizer, and dataset.
**Set a smaller batch_size if you run out of memory 😢😢**
We used `max_tokens = 3200000` to run successfully on a single GPU; if you run out of memory, try a smaller value.
```json
"max_tokens": 3200000
```
### Resume from an Existing Checkpoint
Our framework supports resuming from an existing checkpoint.
If this is a new experiment, use the following command:
```
CUDA_VISIBLE_DEVICES=$gpu accelerate launch --main_process_port 26667 --mixed_precision fp16 \
"${work_dir}/bins/vc/train.py" \
--config $exp_config \
--exp_name $exp_name \
--log_level debug
```
To resume training or fine-tune from a checkpoint, ensure the options `--resume`, `--resume_type resume`, and `--checkpoint_path` are set, as in the sketch below.
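
A hedged example of that resume invocation, extending the launch command above with the flags defined in `bins/vc/train.py`; the port, mixed-precision setting, and checkpoint path are placeholders.

```bash
CUDA_VISIBLE_DEVICES=$gpu accelerate launch --main_process_port 26667 --mixed_precision fp16 \
"${work_dir}/bins/vc/train.py" \
--config $exp_config \
--exp_name $exp_name \
--log_level debug \
--resume \
--resume_type resume \
--checkpoint_path "path/to/your/checkpoint"
```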

### Run the Training Commands
Start clean training:
```bash
bash path/to/Amphion/egs/vc/noro_train_clean.sh
```

Start noisy training:
```bash
bash path/to/Amphion/egs/vc/noro_train_noisy.sh
```
@@ -0,0 +1,61 @@
{
    "base_config": "config/noro.json",
    "model_type": "VC",
    "dataset": [
        "mls"
    ],
    "sample_rate": 16000,
    "n_fft": 1024,
    "n_mel": 80,
    "hop_size": 200,
    "win_size": 800,
    "fmin": 0,
    "fmax": 8000,
    "preprocess": {
        "kmeans_model_path": "path/to/kmeans_model",
        "hubert_model_path": "path/to/hubert_model",
        "sample_rate": 16000,
        "hop_size": 200,
        "f0_min": 50,
        "f0_max": 500,
        "frame_period": 12.5
    },
    "model": {
        "reference_encoder": {
            "encoder_layer": 6,
            "encoder_hidden": 512,
            "encoder_head": 8,
            "conv_filter_size": 2048,
            "conv_kernel_size": 9,
            "encoder_dropout": 0.2,
            "use_skip_connection": false,
            "use_new_ffn": true,
            "ref_in_dim": 80,
            "ref_out_dim": 512,
            "use_query_emb": true,
            "num_query_emb": 32
        },
        "diffusion": {
            "beta_min": 0.05,
            "beta_max": 20,
            "sigma": 1.0,
            "noise_factor": 1.0,
            "ode_solve_method": "euler",
            "diff_model_type": "WaveNet",
            "diff_wavenet": {
                "input_size": 80,
                "hidden_size": 512,
                "out_size": 80,
                "num_layers": 47,
                "cross_attn_per_layer": 3,
                "dilation_cycle": 2,
                "attn_head": 8,
                "drop_out": 0.2
            }
        },
        "vc_feature": {
            "content_feature_dim": 768,
            "hidden_dim": 512
        }
    }
}
@@ -0,0 +1,38 @@
{
    "base_config": "egs/vc/Noro/exp_config_base.json",
    "dataset": [
        "mls"
    ],
    // Specify the output root path to save model checkpoints and logs
    "log_dir": "path/to/your/checkpoint/directory",
    "train": {
        // New trainer and Accelerator
        "gradient_accumulation_step": 1,
        "tracker": ["tensorboard"],
        "max_epoch": 10,
        "save_checkpoint_stride": [1000],
        "keep_last": [20],
        "run_eval": [true],
        "dataloader": {
            "num_worker": 64,
            "pin_memory": true
        },
        "adam": {
            "lr": 5e-5
        },
        "use_dynamic_batchsize": true,
        "max_tokens": 3200000,
        "max_sentences": 64,
        "lr_warmup_steps": 5000,
        "lr_scheduler": "cosine",
        "num_train_steps": 800000
    },
    "trans_exp": {
        "directory_list": [
            "path/to/your/training_data_directory1",
            "path/to/your/training_data_directory2",
            "path/to/your/training_data_directory3"
        ],
        "use_ref_noise": false
    }
}