diff --git a/data/ft_dataset.py b/data/ft_dataset.py index 4ac4fc0..c6b1f55 100644 --- a/data/ft_dataset.py +++ b/data/ft_dataset.py @@ -12,27 +12,26 @@ "max": 30.0, } # assume single speaker +def to_mel_fn(wave, mel_fn_args): + return mel_spectrogram(wave, **mel_fn_args) + class FT_Dataset(torch.utils.data.Dataset): - def __init__(self, - data_path, - spect_params, - sr=22050, - batch_size=1, - ): + def __init__( + self, + data_path, + spect_params, + sr=22050, + batch_size=1, + ): self.data_path = data_path - # recursively find all files in data_path self.data = [] for root, _, files in os.walk(data_path): for file in files: - if (file.endswith(".wav") or - file.endswith(".mp3") or - file.endswith(".flac") or - file.endswith(".ogg") or - file.endswith(".m4a") or - file.endswith(".opus")): + if file.endswith((".wav", ".mp3", ".flac", ".ogg", ".m4a", ".opus")): self.data.append(os.path.join(root, file)) - mel_fn_args = { + self.sr = sr + self.mel_fn_args = { "n_fft": spect_params['n_fft'], "win_size": spect_params['win_length'], "hop_size": spect_params['hop_length'], @@ -42,11 +41,8 @@ def __init__(self, "fmax": None if spect_params['fmax'] == "None" else spect_params['fmax'], "center": False } - self.to_mel = lambda x: mel_spectrogram(x, **mel_fn_args) - self.sr = sr assert len(self.data) != 0 - # if dataset length is less than batch size, repeat the dataset while len(self.data) < batch_size: self.data += self.data @@ -64,17 +60,14 @@ def __getitem__(self, idx): if len(speech) < self.sr * duration_setting["min"] or len(speech) > self.sr * duration_setting["max"]: print(f"Audio {wav_path} is too short or too long, skipping") return self.__getitem__(random.randint(0, len(self))) - return_dict = { - 'audio': speech, - 'sr': orig_sr - } - wave, orig_sr = return_dict['audio'], return_dict['sr'] if orig_sr != self.sr: - wave = librosa.resample(wave, orig_sr, self.sr) - wave = torch.from_numpy(wave).float() - mel = self.to_mel(wave.unsqueeze(0)).squeeze(0) + speech = librosa.resample(speech, orig_sr, self.sr) + + wave = torch.from_numpy(speech).float().unsqueeze(0) + mel = to_mel_fn(wave, self.mel_fn_args).squeeze(0) + + return wave.squeeze(0), mel - return wave, mel def build_ft_dataloader(data_path, spect_params, sr, batch_size=1, num_workers=0): dataset = FT_Dataset(data_path, spect_params, sr, batch_size) @@ -130,4 +123,4 @@ def collate(batch): wave, mel, wave_lengths, mel_lengths = batch print(wave.shape, mel.shape) if idx == 10: - break \ No newline at end of file + break