Update code for fine-tuning

Plachtaa committed Nov 26, 2024
1 parent 4d54e69 commit 85df0ea

Showing 35 changed files with 5,926 additions and 232 deletions.
441 changes: 441 additions & 0 deletions app_svc.py

Large diffs are not rendered by default.

384 changes: 384 additions & 0 deletions app_vc.py

Large diffs are not rendered by default.

@@ -1,32 +1,41 @@
-log_dir: "./runs/run_dit_mel_seed"
+log_dir: "./runs/run_dit_mel_seed_uvit_whisper_base_f0_44k"
 save_freq: 1
 log_interval: 10
 save_interval: 1000
 device: "cuda"
 epochs: 1000 # number of epochs for first stage training (pre-training)
-batch_size: 4
+batch_size: 1
 batch_length: 100 # maximum duration of audio in a batch (in seconds)
 max_len: 80 # maximum number of frames
 pretrained_model: ""
 pretrained_encoder: ""
 load_only_params: False # set to true if do not want to load epoch numbers and optimizer parameters
 
+F0_path: "modules/JDC/bst.t7"
+
 preprocess_params:
-  sr: 22050
+  sr: 44100
   spect_params:
-    n_fft: 1024
-    win_length: 1024
-    hop_length: 256
-    n_mels: 80
+    n_fft: 2048
+    win_length: 2048
+    hop_length: 512
+    n_mels: 128
     fmin: 0
     fmax: "None"
 
 model_params:
   dit_type: "DiT" # uDiT or DiT
-  reg_loss_type: "l2" # l1 or l2
+  reg_loss_type: "l1" # l1 or l2
 
+  timbre_shifter:
+    se_db_path: "./modules/openvoice/checkpoints_v2/converter/se_db.pt"
+    ckpt_path: './modules/openvoice/checkpoints_v2/converter'
+
+  vocoder:
+    type: "bigvgan"
+    name: "nvidia/bigvgan_v2_44khz_128band_512x"
+
   speech_tokenizer:
-    path: "checkpoints/speech_tokenizer_v1.onnx"
+    type: 'whisper'
+    name: "openai/whisper-small"
 
   style_encoder:
     dim: 192

@@ -41,31 +50,39 @@ model_params:
   length_regulator:
     channels: 768
-    is_discrete: true
-    content_codebook_size: 4096
-    in_frame_rate: 50
-    out_frame_rate: 80
+    is_discrete: false
+    in_channels: 768
+    content_codebook_size: 2048
+    sampling_ratios: [1, 1, 1, 1]
+    vector_quantize: false
+    n_codebooks: 1
+    quantizer_dropout: 0.0
+    f0_condition: true
+    n_f0_bins: 256
 
   DiT:
     hidden_dim: 768
     num_heads: 12
-    depth: 12
+    depth: 17
     class_dropout_prob: 0.1
     block_size: 8192
-    in_channels: 80
+    in_channels: 128
     style_condition: true
-    final_layer_type: 'wavenet'
+    final_layer_type: 'mlp'
     target: 'mel' # mel or codec
     content_dim: 768
     content_codebook_size: 1024
     content_type: 'discrete'
-    f0_condition: false
-    n_f0_bins: 512
+    f0_condition: true
+    n_f0_bins: 256
     content_codebooks: 1
     is_causal: false
-    long_skip_connection: true
+    long_skip_connection: false
     zero_prompt_speech_token: false # for prompt component, do not input corresponding speech token
+    time_as_token: false
+    style_as_token: false
+    uvit_skip_connection: true
+    add_resblock_in_transformer: false
 
   wavenet:
     hidden_dim: 768

@@ -77,3 +94,5 @@ model_params:
 
 loss_params:
   base_lr: 0.0001
+  lambda_mel: 45
+  lambda_kl: 1.0
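
Note on the spectrogram change above: the hop length doubles along with the sample rate, so the mel frame rate is unchanged and frame-based settings such as max_len keep their meaning; only n_mels grows to match the 128-band 44.1 kHz BigVGAN checkpoint. A quick sanity check (illustrative only, not part of the commit):

    # Old and new settings yield the same mel frame rate (~86.1 frames/s).
    old_fps = 22050 / 256
    new_fps = 44100 / 512
    assert abs(old_fps - new_fps) < 1e-6
    print(old_fps)  # 86.1328125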
@@ -1,4 +1,4 @@
-log_dir: "./runs"
+log_dir: "./runs/run_dit_mel_seed_uvit_whisper_small_wavenet"
 save_freq: 1
 log_interval: 10
 save_interval: 1000

@@ -25,24 +25,21 @@ model_params:
   dit_type: "DiT" # uDiT or DiT
   reg_loss_type: "l1" # l1 or l2
 
+  timbre_shifter:
+    se_db_path: "./modules/openvoice/checkpoints_v2/converter/se_db.pt"
+    ckpt_path: './modules/openvoice/checkpoints_v2/converter'
+
   speech_tokenizer:
     type: 'whisper'
-    whisper_name: "openai/whisper-small"
-    path: "speech_tokenizer_v1.onnx"
-
-  cosyvoice:
-    path: "../CosyVoice/pretrained_models/CosyVoice-300M"
+    name: "openai/whisper-small"
 
   style_encoder:
     dim: 192
     campplus_path: "campplus_cn_common.bin"
 
-  DAC:
-    encoder_dim: 64
-    encoder_rates: [2, 5, 5, 6]
-    decoder_dim: 1536
-    decoder_rates: [ 6, 5, 5, 2 ]
-    sr: 24000
+  vocoder:
+    type: "bigvgan"
+    name: "nvidia/bigvgan_v2_22khz_80band_256x"
 
   length_regulator:
     channels: 512
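Both updated configs replace the DAC decoder with BigVGAN v2 vocoders referenced by their Hugging Face Hub names. A minimal loading sketch, assuming NVIDIA's `bigvgan` package is installed (an illustration, not code from this commit):

    import torch
    import bigvgan

    # Load the 22.05 kHz, 80-band checkpoint named in the config above.
    vocoder = bigvgan.BigVGAN.from_pretrained(
        "nvidia/bigvgan_v2_22khz_80band_256x", use_cuda_kernel=False
    )
    vocoder.remove_weight_norm()  # usual inference-time step for GAN vocoders
    vocoder.eval()

    mel = torch.randn(1, 80, 100)  # (batch, n_mels, frames)
    with torch.no_grad():
        wav = vocoder(mel)  # waveform tensor at 22.05 kHz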
82 changes: 82 additions & 0 deletions configs/presets/config_dit_mel_seed_uvit_xlsr_tiny.yml
@@ -0,0 +1,82 @@
log_dir: "runs/run_mel_seed_uvit_xlsr_tiny"
save_freq: 1
log_interval: 10
save_interval: 500
device: "cuda"
epochs: 1000 # number of epochs for first stage training (pre-training)
batch_size: 2
batch_length: 100 # maximum duration of audio in a batch (in seconds)
max_len: 80 # maximum number of frames
pretrained_model: "DiT_uvit_tat_xlsr_ema.pth"
pretrained_encoder: ""
load_only_params: False # set to true if do not want to load epoch numbers and optimizer parameters

preprocess_params:
sr: 22050
spect_params:
n_fft: 1024
win_length: 1024
hop_length: 256
n_mels: 80
fmin: 0
fmax: 8000

model_params:
dit_type: "DiT" # uDiT or DiT
reg_loss_type: "l1" # l1 or l2
diffusion_type: "flow"

timbre_shifter:
se_db_path: "./modules/openvoice/checkpoints_v2/converter/se_db.pt"
ckpt_path: './modules/openvoice/checkpoints_v2/converter'

vocoder:
type: "hifigan"

speech_tokenizer:
type: 'xlsr'
output_layer: 12
name: 'facebook/wav2vec2-xls-r-300m'

style_encoder:
dim: 192
campplus_path: "campplus_cn_common.bin"

length_regulator:
channels: 384
is_discrete: false
in_channels: 1024
content_codebook_size: 1024
sampling_ratios: [1, 1, 1, 1]
vector_quantize: false
n_codebooks: 2
quantizer_dropout: 0.0
f0_condition: false
n_f0_bins: 512

DiT:
hidden_dim: 384
num_heads: 6
depth: 9
class_dropout_prob: 0.1
block_size: 8192
in_channels: 80
style_condition: true
final_layer_type: 'mlp'
target: 'mel' # mel or betavae
content_dim: 384
content_codebook_size: 1024
content_type: 'discrete'
f0_condition: false
n_f0_bins: 512
content_codebooks: 1
is_causal: false
long_skip_connection: false
zero_prompt_speech_token: false # for prompt component, do not input corresponding speech token
time_as_token: true
style_as_token: true
uvit_skip_connection: true
add_resblock_in_transformer: false

loss_params:
base_lr: 0.0001
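
The 'xlsr' speech tokenizer entry names a wav2vec 2.0 XLS-R checkpoint plus the hidden layer to tap (output_layer: 12). A rough sketch of how such features can be extracted with Hugging Face transformers — an assumption about how these fields are consumed, not this repository's code:

    import torch
    from transformers import Wav2Vec2FeatureExtractor, Wav2Vec2Model

    name = "facebook/wav2vec2-xls-r-300m"
    extractor = Wav2Vec2FeatureExtractor.from_pretrained(name)
    model = Wav2Vec2Model.from_pretrained(name, output_hidden_states=True).eval()

    wav = torch.randn(16000)  # 1 s of dummy audio at 16 kHz, XLS-R's input rate
    inputs = extractor(wav.numpy(), sampling_rate=16000, return_tensors="pt")
    with torch.no_grad():
        out = model(**inputs)
    feats = out.hidden_states[12]  # (1, frames, 1024); matches in_channels: 1024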
133 changes: 133 additions & 0 deletions data/ft_dataset.py
@@ -0,0 +1,133 @@
import torch
import librosa
import numpy as np
import random
import os
from torch.utils.data import DataLoader
from modules.audio import mel_spectrogram


duration_setting = {
    "min": 1.0,
    "max": 30.0,
}


# assume a single speaker in the fine-tuning data
class FT_Dataset(torch.utils.data.Dataset):
    def __init__(
        self,
        data_path,
        spect_params,
        sr=22050,
        batch_size=1,
    ):
        self.data_path = data_path
        # recursively find all audio files under data_path
        self.data = []
        for root, _, files in os.walk(data_path):
            for file in files:
                if (file.endswith(".wav") or
                        file.endswith(".mp3") or
                        file.endswith(".flac") or
                        file.endswith(".ogg") or
                        file.endswith(".m4a") or
                        file.endswith(".opus")):
                    self.data.append(os.path.join(root, file))

        mel_fn_args = {
            "n_fft": spect_params['n_fft'],
            "win_size": spect_params['win_length'],
            "hop_size": spect_params['hop_length'],
            "num_mels": spect_params['n_mels'],
            "sampling_rate": sr,
            "fmin": spect_params['fmin'],
            "fmax": None if spect_params['fmax'] == "None" else spect_params['fmax'],
            "center": False,
        }
        self.to_mel = lambda x: mel_spectrogram(x, **mel_fn_args)
        self.sr = sr

        assert len(self.data) != 0
        # if the dataset is smaller than the batch size, repeat it
        while len(self.data) < batch_size:
            self.data += self.data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        idx = idx % len(self.data)
        wav_path = self.data[idx]
        try:
            # librosa.load already resamples to self.sr when sr is given
            speech, orig_sr = librosa.load(wav_path, sr=self.sr)
        except Exception as e:
            print(f"Failed to load wav file with error {e}")
            return self.__getitem__(random.randint(0, len(self) - 1))
        if len(speech) < self.sr * duration_setting["min"] or len(speech) > self.sr * duration_setting["max"]:
            print(f"Audio {wav_path} is too short or too long, skipping")
            return self.__getitem__(random.randint(0, len(self) - 1))
        wave = speech
        if orig_sr != self.sr:
            wave = librosa.resample(wave, orig_sr=orig_sr, target_sr=self.sr)
        wave = torch.from_numpy(wave).float()
        mel = self.to_mel(wave.unsqueeze(0)).squeeze(0)

        return wave, mel


def build_ft_dataloader(data_path, spect_params, sr, batch_size=1, num_workers=0):
    dataset = FT_Dataset(data_path, spect_params, sr, batch_size)
    dataloader = DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers,
        collate_fn=collate,
    )
    return dataloader


def collate(batch):
    batch_size = len(batch)

    # sort by mel length, longest first
    lengths = [b[1].shape[1] for b in batch]
    batch_indexes = np.argsort(lengths)[::-1]
    batch = [batch[bid] for bid in batch_indexes]

    nmels = batch[0][1].size(0)
    max_mel_length = max([b[1].shape[1] for b in batch])
    max_wave_length = max([b[0].size(0) for b in batch])

    # pad mels with -10 (roughly the log-mel floor for silence), waves with 0
    mels = torch.zeros((batch_size, nmels, max_mel_length)).float() - 10
    waves = torch.zeros((batch_size, max_wave_length)).float()

    mel_lengths = torch.zeros(batch_size).long()
    wave_lengths = torch.zeros(batch_size).long()

    for bid, (wave, mel) in enumerate(batch):
        mel_size = mel.size(1)
        mels[bid, :, :mel_size] = mel
        waves[bid, : wave.size(0)] = wave
        mel_lengths[bid] = mel_size
        wave_lengths[bid] = wave.size(0)

    return waves, mels, wave_lengths, mel_lengths


if __name__ == "__main__":
    data_path = "./example/reference"
    sr = 22050
    spect_params = {
        "n_fft": 1024,
        "win_length": 1024,
        "hop_length": 256,
        "n_mels": 80,
        "fmin": 0,
        "fmax": 8000,
    }
    dataloader = build_ft_dataloader(data_path, spect_params, sr, batch_size=2, num_workers=0)
    for idx, batch in enumerate(dataloader):
        wave, mel, wave_lengths, mel_lengths = batch
        print(wave.shape, mel.shape)
        if idx == 10:
            break
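
For orientation, the collate function above returns padded (waves, mels, wave_lengths, mel_lengths) tuples, with mel padding at -10. A hedged sketch of masking out the padded frames downstream (placeholder code, not this repository's trainer):

    import torch

    def length_to_mask(lengths, max_len):
        # True for real frames, False for padding
        return torch.arange(max_len)[None, :] < lengths[:, None]

    # `dataloader` as built in the __main__ demo above
    waves, mels, wave_lengths, mel_lengths = next(iter(dataloader))
    mask = length_to_mask(mel_lengths, mels.size(2))  # (batch, max_mel_length)
    mels_masked = mels * mask.unsqueeze(1)  # zero out padded frames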