Add VC Noro model (#247)
Voice Conversion Model: Noro
kenxxxxx authored Nov 30, 2024
1 parent afc7308 commit f7cb4b4
Showing 19 changed files with 3,396 additions and 0 deletions.
82 changes: 82 additions & 0 deletions bins/vc/Noro/train.py
@@ -0,0 +1,82 @@
# Copyright (c) 2023 Amphion.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import argparse

import torch
from models.vc.Noro.noro_trainer import NoroTrainer
from utils.util import load_config


def build_trainer(args, cfg):
    supported_trainer = {
        "VC": NoroTrainer,
    }
    trainer_class = supported_trainer[cfg.model_type]
    trainer = trainer_class(args, cfg)
    return trainer


def cuda_relevant(deterministic=False):
    torch.cuda.empty_cache()
    # TF32 on Ampere and above
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.enabled = True
    torch.backends.cudnn.allow_tf32 = True
    # Deterministic
    torch.backends.cudnn.deterministic = deterministic
    torch.backends.cudnn.benchmark = not deterministic
    torch.use_deterministic_algorithms(deterministic)


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--config",
        default="config.json",
        help="JSON file for configurations.",
        required=True,
    )
    parser.add_argument(
        "--exp_name",
        type=str,
        default="exp_name",
        help="A specific name to note the experiment",
        required=True,
    )
    parser.add_argument(
        "--resume", action="store_true", help="If set, resume training from a checkpoint"
    )
    parser.add_argument(
        "--log_level", default="warning", help="Logging level (debug, info, warning)"
    )
    parser.add_argument(
        "--resume_type",
        type=str,
        default="resume",
        help="Resume training or finetuning.",
    )
    parser.add_argument(
        "--checkpoint_path",
        type=str,
        default=None,
        help="Checkpoint for resume training or finetuning.",
    )
    args = parser.parse_args()
    cfg = load_config(args.config)
    print("experiment name: ", args.exp_name)
    # CUDA settings
    cuda_relevant()
    # Build trainer
    print(f"Building {cfg.model_type} trainer")
    trainer = build_trainer(args, cfg)
    torch.set_num_threads(1)
    torch.set_num_interop_threads(1)
    print(f"Start training {cfg.model_type} model")
    trainer.train_loop()


if __name__ == "__main__":
    main()
76 changes: 76 additions & 0 deletions config/noro.json
@@ -0,0 +1,76 @@
{
    "base_config": "config/base.json",
    "model_type": "VC",
    "dataset": ["mls"],
    "model": {
        "reference_encoder": {
            "encoder_layer": 6,
            "encoder_hidden": 512,
            "encoder_head": 8,
            "conv_filter_size": 2048,
            "conv_kernel_size": 9,
            "encoder_dropout": 0.2,
            "use_skip_connection": false,
            "use_new_ffn": true,
            "ref_in_dim": 80,
            "ref_out_dim": 512,
            "use_query_emb": true,
            "num_query_emb": 32
        },
        "diffusion": {
            "beta_min": 0.05,
            "beta_max": 20,
            "sigma": 1.0,
            "noise_factor": 1.0,
            "ode_solve_method": "euler",
            "diff_model_type": "WaveNet",
            "diff_wavenet": {
                "input_size": 80,
                "hidden_size": 512,
                "out_size": 80,
                "num_layers": 47,
                "cross_attn_per_layer": 3,
                "dilation_cycle": 2,
                "attn_head": 8,
                "drop_out": 0.2
            }
        },
        "prior_encoder": {
            "encoder_layer": 6,
            "encoder_hidden": 512,
            "encoder_head": 8,
            "conv_filter_size": 2048,
            "conv_kernel_size": 9,
            "encoder_dropout": 0.2,
            "use_skip_connection": false,
            "use_new_ffn": true,
            "vocab_size": 256,
            "cond_dim": 512,
            "duration_predictor": {
                "input_size": 512,
                "filter_size": 512,
                "kernel_size": 3,
                "conv_layers": 30,
                "cross_attn_per_layer": 3,
                "attn_head": 8,
                "drop_out": 0.2
            },
            "pitch_predictor": {
                "input_size": 512,
                "filter_size": 512,
                "kernel_size": 5,
                "conv_layers": 30,
                "cross_attn_per_layer": 3,
                "attn_head": 8,
                "drop_out": 0.5
            },
            "pitch_min": 50,
            "pitch_max": 1100,
            "pitch_bins_num": 512
        },
        "vc_feature": {
            "content_feature_dim": 768,
            "hidden_dim": 512
        }
    }
}
122 changes: 122 additions & 0 deletions egs/vc/Noro/README.md
@@ -0,0 +1,122 @@
# Noro: A Noise-Robust One-shot Voice Conversion System

<br>
<div align="center">
<img src="../../../imgs/vc/NoroVC.png" width="85%">
</div>
<br>

This is the official implementation of the paper: NORO: A Noise-Robust One-Shot Voice Conversion System with Hidden Speaker Representation Capabilities.

- The semantic extractor is from [HuBERT](https://github.com/facebookresearch/fairseq/tree/main/examples/hubert).
- The vocoder uses the [BigVGAN](https://github.com/NVIDIA/BigVGAN) architecture.

## Project Overview
Noro is a noise-robust one-shot voice conversion (VC) system designed to convert the timbre of speech from a source speaker to a target speaker using only a single reference speech sample, while preserving the semantic content of the original speech. Noro introduces components tailored for VC with noisy reference speech, including a dual-branch reference encoding module and a noise-agnostic contrastive speaker loss.

## Features
- **Noise-Robust Voice Conversion**: Utilizes a dual-branch reference encoding module and noise-agnostic contrastive speaker loss to maintain high-quality voice conversion in noisy environments.
- **One-shot Voice Conversion**: Achieves timbre conversion using only one reference speech sample.
- **Speaker Representation Learning**: Explores the potential of the reference encoder as a self-supervised speaker encoder.

## Installation Requirements

Set up your environment as described in the Amphion README (a conda environment is required, and Linux is recommended).

### Prepare Hubert Model

The HuBERT checkpoint and k-means model can be downloaded [here](https://github.com/facebookresearch/fairseq/tree/main/examples/hubert).
Set the downloaded model paths in `egs/vc/Noro/exp_config_base.json`.
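The relevant fields live in the `preprocess` section of `egs/vc/Noro/exp_config_base.json` (excerpt from the config added in this commit):
```json
"preprocess": {
    "kmeans_model_path": "path/to/kmeans_model",
    "hubert_model_path": "path/to/hubert_model"
}
```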


## Usage

### Download pretrained weights
You need to download our pretrained weights from [Google Drive](https://drive.google.com/drive/folders/1NPzSIuSKO8o87g5ImNzpw_BgbhsZaxNg?usp=drive_link).

### Inference
1. Configure inference parameters:
Modify the pretrained checkpoint path, the source audio path, and the reference audio path in the `egs/vc/Noro/noro_inference.sh` file (currently around line 35):
```bash
checkpoint_path="path/to/checkpoint/model.safetensors"
output_dir="path/to/output/directory"
source_path="path/to/source/audio.wav"
reference_path="path/to/reference/audio.wav"
```
2. Start inference:
```bash
bash path/to/Amphion/egs/vc/Noro/noro_inference.sh
```

3. The reconstructed mel spectrogram is saved to the output directory.
Use [BigVGAN](https://github.com/NVIDIA/BigVGAN) to synthesize the waveform from it.
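One possible invocation is the `inference_e2e.py` script in the BigVGAN repository, which generates waveforms from saved mel spectrograms; the flag names below are assumed from that repository and the paths are placeholders, so check the BigVGAN README for the exact interface:
```bash
python inference_e2e.py \
    --checkpoint_file path/to/bigvgan_generator.pt \
    --input_mels_dir path/to/output/directory \
    --output_dir path/to/wav/output
```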

## Training from Scratch

### Data Preparation

We use the LibriLight dataset for training and evaluation. You can download it using the following commands:
```bash
wget https://dl.fbaipublicfiles.com/librilight/data/large.tar
wget https://dl.fbaipublicfiles.com/librilight/data/medium.tar
wget https://dl.fbaipublicfiles.com/librilight/data/small.tar
```
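After downloading, extract the archives into your data directory; a minimal sketch (the target path is only an example):
```bash
mkdir -p data/librilight
tar -xf small.tar -C data/librilight
tar -xf medium.tar -C data/librilight
tar -xf large.tar -C data/librilight
```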

### Training the Model with Clean Reference Voice

Configure training parameters:
The configuration file for training the clean Noro model is `egs/vc/Noro/exp_config_clean.json`, and the one for the noisy Noro model is `egs/vc/Noro/exp_config_noisy.json`.

To train your model, modify the dataset paths in the JSON configuration: set the entries of `directory_list` (currently around line 40) to your training data root directories.
```json
"directory_list": [
    "path/to/your/training_data_directory1",
    "path/to/your/training_data_directory2",
    "path/to/your/training_data_directory3"
],
```
If you want to train the noisy Noro model, you also need to set the directory paths for the noise data in `egs/vc/Noro/exp_config_noisy.json`:
```json
"noise_dir": "path/to/your/noise/train/directory",
"test_noise_dir": "path/to/your/noise/test/directory"
```
You can change other experiment settings in the config files, such as the learning rate, optimizer, and dataset.
**Set a smaller batch size if you run out of memory 😢😢**
A setting of `max_tokens = 3200000` ran successfully on a single GPU; if you run out of memory, try a smaller value.
```json
"max_tokens": 3200000
```
### Resume from an Existing Checkpoint
Our framework supports resuming training from an existing checkpoint.
If this is a new experiment, use the following command:
```bash
CUDA_VISIBLE_DEVICES=$gpu accelerate launch --main_process_port 26667 --mixed_precision fp16 \
    "${work_dir}/bins/vc/train.py" \
    --config $exp_config \
    --exp_name $exp_name \
    --log_level debug
```
To resume training or fine-tune from a checkpoint, make sure the options `--resume`, `--resume_type resume`, and `--checkpoint_path` are set, as shown below.
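A sketch based on the new-experiment command above, with the resume options added (the checkpoint path is a placeholder):
```bash
CUDA_VISIBLE_DEVICES=$gpu accelerate launch --main_process_port 26667 --mixed_precision fp16 \
    "${work_dir}/bins/vc/train.py" \
    --config $exp_config \
    --exp_name $exp_name \
    --log_level debug \
    --resume \
    --resume_type resume \
    --checkpoint_path "path/to/your/checkpoint/directory"
```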

### Run the Training Command
Start clean training:
```bash
bash path/to/Amphion/egs/vc/Noro/noro_train_clean.sh
```


Start noisy training:
```bash
bash path/to/Amphion/egs/vc/Noro/noro_train_noisy.sh
```



61 changes: 61 additions & 0 deletions egs/vc/Noro/exp_config_base.json
@@ -0,0 +1,61 @@
{
    "base_config": "config/noro.json",
    "model_type": "VC",
    "dataset": [
        "mls"
    ],
    "sample_rate": 16000,
    "n_fft": 1024,
    "n_mel": 80,
    "hop_size": 200,
    "win_size": 800,
    "fmin": 0,
    "fmax": 8000,
    "preprocess": {
        "kmeans_model_path": "path/to/kmeans_model",
        "hubert_model_path": "path/to/hubert_model",
        "sample_rate": 16000,
        "hop_size": 200,
        "f0_min": 50,
        "f0_max": 500,
        "frame_period": 12.5
    },
    "model": {
        "reference_encoder": {
            "encoder_layer": 6,
            "encoder_hidden": 512,
            "encoder_head": 8,
            "conv_filter_size": 2048,
            "conv_kernel_size": 9,
            "encoder_dropout": 0.2,
            "use_skip_connection": false,
            "use_new_ffn": true,
            "ref_in_dim": 80,
            "ref_out_dim": 512,
            "use_query_emb": true,
            "num_query_emb": 32
        },
        "diffusion": {
            "beta_min": 0.05,
            "beta_max": 20,
            "sigma": 1.0,
            "noise_factor": 1.0,
            "ode_solve_method": "euler",
            "diff_model_type": "WaveNet",
            "diff_wavenet": {
                "input_size": 80,
                "hidden_size": 512,
                "out_size": 80,
                "num_layers": 47,
                "cross_attn_per_layer": 3,
                "dilation_cycle": 2,
                "attn_head": 8,
                "drop_out": 0.2
            }
        },
        "vc_feature": {
            "content_feature_dim": 768,
            "hidden_dim": 512
        }
    }
}
38 changes: 38 additions & 0 deletions egs/vc/Noro/exp_config_clean.json
@@ -0,0 +1,38 @@
{
    "base_config": "egs/vc/Noro/exp_config_base.json",
    "dataset": [
        "mls"
    ],
    // Specify the output root path to save model checkpoints and logs
    "log_dir": "path/to/your/checkpoint/directory",
    "train": {
        // New trainer and Accelerator
        "gradient_accumulation_step": 1,
        "tracker": ["tensorboard"],
        "max_epoch": 10,
        "save_checkpoint_stride": [1000],
        "keep_last": [20],
        "run_eval": [true],
        "dataloader": {
            "num_worker": 64,
            "pin_memory": true
        },
        "adam": {
            "lr": 5e-5
        },
        "use_dynamic_batchsize": true,
        "max_tokens": 3200000,
        "max_sentences": 64,
        "lr_warmup_steps": 5000,
        "lr_scheduler": "cosine",
        "num_train_steps": 800000
    },
    "trans_exp": {
        "directory_list": [
            "path/to/your/training_data_directory1",
            "path/to/your/training_data_directory2",
            "path/to/your/training_data_directory3"
        ],
        "use_ref_noise": false
    }
}