Support FastSpeech2 24 kHz training; align mel settings with the GAN vocoder #61

Closed · wants to merge 2 commits
4 changes: 2 additions & 2 deletions bins/tts/preprocess.py
@@ -142,8 +142,8 @@ def preprocess(cfg, args):
dataset_types.append((cfg.preprocess.valid_file).split(".")[0])
if "test" not in dataset_types:
dataset_types.append("test")
if "eval" in dataset:
dataset_types = ["test"]
#if "eval" in dataset:
# dataset_types = ["test"]

# Dump metadata of datasets (singers, train/test durations, etc.)
cal_metadata(cfg, dataset_types)
14 changes: 7 additions & 7 deletions config/fs2.json
@@ -7,13 +7,13 @@
// acoustic features
"extract_audio": true,
"extract_mel": true,
"mel_extract_mode": "taco",
"mel_extract_mode": "raw",
"mel_min_max_norm": false,
"extract_pitch": true,
"extract_uv": false,
"pitch_extractor": "dio",
"extract_energy": true,
"energy_extract_mode": "from_tacotron_stft",
"energy_extract_mode": "from_mel",
"extract_duration": true,
"use_phone": false,
"pitch_norm": true,
@@ -22,17 +22,17 @@
"energy_remove_outlier": true,

// Default config
"n_mel": 80,
"n_mel": 100,
"win_size": 1024, // todo
"hop_size": 256,
"sample_rate": 22050,
"sample_rate": 24000,
"n_fft": 1024, // todo
"fmin": 0,
"fmax": 8000, // todo
"fmax": 12000, // todo
"raw_data": "raw_data",
"text_cleaners": ["english_cleaners"],
"f0_min": 71, // ~C2
"f0_max": 800, //1100, // ~C6(1100), ~G5(800)
"f0_min": 50, // ~C2
"f0_max": 1100, //1100, // ~C6(1100), ~G5(800)
"pitch_bin": 256,
"pitch_max": 1100.0,
"pitch_min": 50.0,
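The point of these config changes is that the acoustic model's mel features must be computed exactly the way the 24 kHz GAN vocoder expects (24000 Hz sample rate, 1024-point FFT and window, hop 256, 100 mel bins, fmax 12 kHz); any mismatch degrades synthesis. Below is a minimal sketch of mel extraction under these settings, using librosa for illustration only; Amphion's actual extraction lives in utils/mel.py (extract_mel_features_tts).

```python
# Minimal sketch of mel extraction with the new settings (24 kHz, n_fft 1024,
# hop 256, 100 mel bins, fmax 12 kHz). Illustrative only; not Amphion's API.
import librosa
import numpy as np

def mel_24k(wav: np.ndarray) -> np.ndarray:
    """Log-mel spectrogram matching the updated fs2.json / vocoder config."""
    mel = librosa.feature.melspectrogram(
        y=wav,
        sr=24000,
        n_fft=1024,
        hop_length=256,
        win_length=1024,
        n_mels=100,
        fmin=0,
        fmax=12000,
        power=1.0,  # magnitude spectrogram, as in HiFi-GAN-style extraction
    )
    # log compression with the usual 1e-5 floor
    return np.log(np.clip(mel, a_min=1e-5, a_max=None))
```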
2 changes: 1 addition & 1 deletion egs/tts/FastSpeech2/exp_config.json
@@ -13,7 +13,7 @@
"preprocess": {
// TODO: Fill in the output data path. The default value is "Amphion/data"
"processed_dir": "data",
"sample_rate": 22050,
"sample_rate": 24000,
},
"train": {
"batch_size": 16,
4 changes: 3 additions & 1 deletion egs/tts/FastSpeech2/run.sh
@@ -134,6 +134,8 @@ if [ $running_stage -eq 3 ]; then
fi


+# if you don't have a vocoder, you can download from https://huggingface.co/amphion/hifigan_speech_bigdata,
+# then link the hifigan_speech folder to pretrained/hifigan_speech
CUDA_VISIBLE_DEVICES=$gpu accelerate launch "$work_dir"/bins/tts/inference.py \
--config $exp_config \
--acoustics_dir $infer_expt_dir \
@@ -143,7 +145,7 @@
--testing_set $infer_testing_set \
--text "$infer_text" \
--log_level debug \
---vocoder_dir /mntnfs/lee_data1/chenxi/processed_data/ljspeech/model_ckpt/hifigan/checkpoints
+--vocoder_dir pretrained/hifigan_speech



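The new comment points at https://huggingface.co/amphion/hifigan_speech_bigdata. A hedged sketch of fetching that repo and creating the symlink the updated --vocoder_dir default expects; it assumes huggingface_hub is installed, and the "hifigan_speech" subfolder name is taken from the comment, not verified against the repo layout.

```python
# Hedged sketch: fetch the vocoder mentioned above and expose it at the path
# the script now passes to --vocoder_dir. Assumes huggingface_hub is installed.
import os
from huggingface_hub import snapshot_download

local_repo = snapshot_download(repo_id="amphion/hifigan_speech_bigdata")
os.makedirs("pretrained", exist_ok=True)
if not os.path.exists("pretrained/hifigan_speech"):
    os.symlink(
        os.path.join(local_repo, "hifigan_speech"),  # assumed repo layout
        "pretrained/hifigan_speech",
    )
```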
6 changes: 3 additions & 3 deletions models/tts/fastspeech2/fs2_dataset.py
@@ -161,9 +161,9 @@ def read_duration(self):

mel = np.load(self.utt2mel_path[utt]).transpose(1, 0)
duration = np.load(self.utt2duration_path[utt])
-assert mel.shape[0] == sum(
-duration
-), f"{utt}: mismatch length between mel {mel.shape[0]} and sum(duration) {sum(duration)}"
+#assert mel.shape[0] == sum(
+# duration
+#), f"{utt}: mismatch length between mel {mel.shape[0]} and sum(duration) {sum(duration)}"
utt2dur[utt] = duration
return utt2dur

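Commenting the assert out also hides genuine misalignment, not just the off-by-a-frame rounding that 24 kHz extraction can introduce. If stricter behavior is wanted later, a tolerant check along these lines (a sketch, not part of this PR) would absorb small rounding drift while still failing loudly on real errors:

```python
# Sketch of a tolerant alternative to the deleted assert (not part of this PR).
import numpy as np

def reconcile(mel: np.ndarray, duration: np.ndarray, tol: int = 2) -> np.ndarray:
    diff = mel.shape[0] - int(duration.sum())
    assert abs(diff) <= tol, f"mel/duration mismatch of {diff} frames exceeds tolerance"
    if diff > 0:
        mel = mel[: int(duration.sum())]  # mel slightly long: trim tail frames
    elif diff < 0:
        duration[-1] += diff  # mel slightly short: shrink last duration (diff < 0)
    return mel
```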
9 changes: 1 addition & 8 deletions models/tts/fastspeech2/fs2_inference.py
@@ -49,17 +49,10 @@ def _build_test_dataset(self):
@staticmethod
def _parse_vocoder(vocoder_dir):
r"""Parse vocoder config"""
-vocoder_dir = os.path.abspath(vocoder_dir)
-ckpt_list = [ckpt for ckpt in Path(vocoder_dir).glob("*.pt")]
-# last step (different from the base *int(x.stem)*)
-ckpt_list.sort(
-key=lambda x: int(x.stem.split("_")[-2].split("-")[-1]), reverse=True
-)
-ckpt_path = str(ckpt_list[0])
vocoder_cfg = load_config(
os.path.join(vocoder_dir, "args.json"), lowercase=True
)
-return vocoder_cfg, ckpt_path
+return vocoder_cfg, vocoder_dir

@torch.inference_mode()
def inference_for_batches(self):
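The old _parse_vocoder located the newest *.pt itself; the rewritten version returns the directory and leaves checkpoint selection to the vocoder loading code, so --vocoder_dir can point straight at pretrained/hifigan_speech. For reference, a sketch of the removed selection logic; the filename scheme (a step number in the second-to-last underscore-separated field) is inferred from the old sort key, not confirmed elsewhere in the repo.

```python
# Sketch of the checkpoint selection the new code no longer performs itself.
from pathlib import Path

def latest_ckpt(vocoder_dir: str) -> str:
    ckpts = list(Path(vocoder_dir).glob("*.pt"))
    ckpts.sort(
        key=lambda p: int(p.stem.split("_")[-2].split("-")[-1]), reverse=True
    )
    return str(ckpts[0])  # newest checkpoint by embedded step number
```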
2 changes: 1 addition & 1 deletion models/tts/naturalspeech2/ns2_trainer.py
@@ -13,7 +13,7 @@
from torch.utils.data import ConcatDataset, DataLoader
from models.tts.base.tts_trainer import TTSTrainer
from models.base.base_trainer import BaseTrainer
-from models.base.base_sampler import VariableSampler
+from models.tts.valle.valle_dataset import VariableSampler
from models.tts.naturalspeech2.ns2_dataset import NS2Dataset, NS2Collator, batch_by_size
from models.tts.naturalspeech2.ns2_loss import (
log_pitch_loss,
3 changes: 0 additions & 3 deletions models/vocoders/gan/gan_vocoder_trainer.py
@@ -1048,13 +1048,10 @@ def _valid_step(self, data):
valid_losses.update(discriminator_losses)
valid_losses.update(generator_losses)

for item in valid_losses:
valid_losses[item] = valid_losses[item].item()
-for item in valid_losses:
-valid_losses[item] = valid_losses[item].item()

return total_loss.item(), valid_losses
-return total_loss.item(), valid_losses

def _inference(self, eval_mel, eval_pitch=None, use_pitch=False):
"""Inference during training for test audios."""
5 changes: 3 additions & 2 deletions preprocessors/ljspeech.py
@@ -136,8 +136,9 @@ def prepare_align(dataset, dataset_path, cfg, output_path):
wav_path = os.path.join(in_dir, "wavs", "{}.wav".format(base_name))
if os.path.exists(wav_path):
os.makedirs(os.path.join(out_dir, speaker), exist_ok=True)
-wav, _ = librosa.load(wav_path, sampling_rate)
-wav = wav / max(abs(wav)) * max_wav_value
+wav, _ = librosa.load(wav_path, sr=sampling_rate)
+wav = wav / max(abs(wav)) * max_wav_value * 0.95
+# todo: let's trim silence

wavfile.write(
os.path.join(out_dir, speaker, "{}.wav".format(base_name)),
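The added 0.95 factor leaves roughly 0.4 dB of headroom so peak-normalized samples cannot overflow when cast to 16-bit PCM. A minimal sketch of the same normalization; the max_wav_value of 32768.0 is the usual value in FastSpeech2-style configs and is an assumption here.

```python
# Minimal sketch of the normalization this hunk introduces: peak-normalize,
# keep 5% headroom, write 16-bit PCM.
import numpy as np
from scipy.io import wavfile

def write_norm_wav(path: str, wav: np.ndarray, sr: int = 24000) -> None:
    max_wav_value = 32768.0  # assumed config value
    wav = wav / np.max(np.abs(wav)) * max_wav_value * 0.95  # headroom avoids int16 clipping
    wavfile.write(path, sr, wav.astype(np.int16))
```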
7 changes: 4 additions & 3 deletions processors/acoustic_extractor.py
@@ -286,10 +286,11 @@ def extract_utt_acoustic_features_tts(dataset_output, cfg, utt):
mel = extract_mel_features(
wav_torch.unsqueeze(0), cfg.preprocess, taco=True, _stft=_stft
)
-if cfg.preprocess.extract_duration:
-mel = mel[:, : sum(durations)]
else:
-mel = extract_mel_features(wav_torch.unsqueeze(0), cfg.preprocess)
+mel = extract_mel_features_tts(wav_torch.unsqueeze(0), cfg.preprocess)
+
+if cfg.preprocess.extract_duration:
+mel = mel[:, : sum(durations)]
save_feature(dataset_output, cfg.preprocess.mel_dir, uid, mel.cpu().numpy())

if cfg.preprocess.extract_energy:
20 changes: 10 additions & 10 deletions utils/audio.py
@@ -28,18 +28,18 @@ def load_audio_torch(wave_file, fs):
assert len(audio) > 2

# Check the audio type (for soundfile loading backbone) - float, 8bit or 16bit
-if np.issubdtype(audio.dtype, np.integer):
-max_mag = -np.iinfo(audio.dtype).min
-else:
-max_mag = max(np.amax(audio), -np.amin(audio))
-max_mag = (
-(2**31) + 1
-if max_mag > (2**15)
-else ((2**15) + 1 if max_mag > 1.01 else 1.0)
-)
+# if np.issubdtype(audio.dtype, np.integer):
+# max_mag = -np.iinfo(audio.dtype).min
+# else:
+# max_mag = max(np.amax(audio), -np.amin(audio))
+# max_mag = (
+# (2**31) + 1
+# if max_mag > (2**15)
+# else ((2**15) + 1 if max_mag > 1.01 else 1.0)
+# )

# Normalize the audio
-audio = torch.FloatTensor(audio.astype(np.float32)) / max_mag
+audio = torch.FloatTensor(audio.astype(np.float32))

if (torch.isnan(audio) | torch.isinf(audio)).any():
return [], sample_rate or fs or 48000
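The rationale appears to be that the soundfile loading backend already returns float samples in [-1, 1], which makes the dtype-based rescaling unnecessary. A quick check of that premise, with an illustrative file path:

```python
# Quick check of the premise behind dropping max_mag: soundfile already yields
# normalized float audio. "example.wav" is an illustrative path.
import numpy as np
import soundfile as sf

audio, sr = sf.read("example.wav", dtype="float32")
assert audio.dtype == np.float32 and np.abs(audio).max() <= 1.0
```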
2 changes: 0 additions & 2 deletions utils/mel.py
@@ -232,8 +232,6 @@ def extract_mel_features_tts(
spec = torch.matmul(mel_basis[str(cfg.fmax) + "_" + str(y.device)], spec)
spec = spectral_normalize_torch(spec)
spec = spec.squeeze(0)
-spec = torch.matmul(mel_basis[str(cfg.fmax) + "_" + str(y.device)], spec)
-spec = spectral_normalize_torch(spec)
else:
audio = torch.clip(y, -1, 1)
audio = torch.autograd.Variable(audio, requires_grad=False)
4 changes: 2 additions & 2 deletions utils/stft.py
@@ -66,7 +66,7 @@ def window_sumsquare(
# Compute the squared window at the desired length
win_sq = get_window(window, win_length, fftbins=True)
win_sq = librosa_util.normalize(win_sq, norm=norm) ** 2
-win_sq = librosa_util.pad_center(win_sq, n_fft)
+win_sq = librosa_util.pad_center(win_sq, size=n_fft)

# Fill the envelope
for i in range(n_frames):
@@ -243,7 +243,7 @@ def __init__(
self.sampling_rate = sampling_rate
self.stft_fn = STFT(filter_length, hop_length, win_length)
mel_basis = librosa_mel_fn(
-sampling_rate, filter_length, n_mel_channels, mel_fmin, mel_fmax
+sr=sampling_rate, n_fft=filter_length, n_mels=n_mel_channels, fmin=mel_fmin, fmax=mel_fmax
)
mel_basis = torch.from_numpy(mel_basis).float()
self.register_buffer("mel_basis", mel_basis)
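Both hunks track the keyword-only API introduced in librosa 0.10: librosa.util.pad_center takes size as a keyword argument, and librosa.filters.mel takes all of its parameters as keywords, so the old positional calls raise a TypeError on current librosa. A sketch of the updated calls, with values mirroring the new 24 kHz config:

```python
# Both hunks track the keyword-only API of librosa >= 0.10, where positional
# calls like librosa_mel_fn(sr, n_fft, n_mels, fmin, fmax) raise a TypeError.
import librosa
import numpy as np

win_sq = librosa.util.pad_center(np.hanning(1024), size=1024)  # keyword-only size
mel_basis = librosa.filters.mel(
    sr=24000, n_fft=1024, n_mels=100, fmin=0, fmax=12000  # keyword-only params
)
```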