Skip to content

Commit

Permalink
Fix unpicklable dataset by removing local lambda
Browse files Browse the repository at this point in the history
  • Loading branch information
justinjohn0306 authored Jan 3, 2025
1 parent 5647ce2 commit 151af2a
Showing 1 changed file with 20 additions and 27 deletions.
47 changes: 20 additions & 27 deletions data/ft_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,27 +12,26 @@
"max": 30.0,
}
# assume single speaker
def to_mel_fn(wave, mel_fn_args):
return mel_spectrogram(wave, **mel_fn_args)

class FT_Dataset(torch.utils.data.Dataset):
def __init__(self,
data_path,
spect_params,
sr=22050,
batch_size=1,
):
def __init__(
self,
data_path,
spect_params,
sr=22050,
batch_size=1,
):
self.data_path = data_path
# recursively find all files in data_path
self.data = []
for root, _, files in os.walk(data_path):
for file in files:
if (file.endswith(".wav") or
file.endswith(".mp3") or
file.endswith(".flac") or
file.endswith(".ogg") or
file.endswith(".m4a") or
file.endswith(".opus")):
if file.endswith((".wav", ".mp3", ".flac", ".ogg", ".m4a", ".opus")):
self.data.append(os.path.join(root, file))

mel_fn_args = {
self.sr = sr
self.mel_fn_args = {
"n_fft": spect_params['n_fft'],
"win_size": spect_params['win_length'],
"hop_size": spect_params['hop_length'],
Expand All @@ -42,11 +41,8 @@ def __init__(self,
"fmax": None if spect_params['fmax'] == "None" else spect_params['fmax'],
"center": False
}
self.to_mel = lambda x: mel_spectrogram(x, **mel_fn_args)
self.sr = sr

assert len(self.data) != 0
# if dataset length is less than batch size, repeat the dataset
while len(self.data) < batch_size:
self.data += self.data

Expand All @@ -64,17 +60,14 @@ def __getitem__(self, idx):
if len(speech) < self.sr * duration_setting["min"] or len(speech) > self.sr * duration_setting["max"]:
print(f"Audio {wav_path} is too short or too long, skipping")
return self.__getitem__(random.randint(0, len(self)))
return_dict = {
'audio': speech,
'sr': orig_sr
}
wave, orig_sr = return_dict['audio'], return_dict['sr']
if orig_sr != self.sr:
wave = librosa.resample(wave, orig_sr, self.sr)
wave = torch.from_numpy(wave).float()
mel = self.to_mel(wave.unsqueeze(0)).squeeze(0)
speech = librosa.resample(speech, orig_sr, self.sr)

wave = torch.from_numpy(speech).float().unsqueeze(0)
mel = to_mel_fn(wave, self.mel_fn_args).squeeze(0)

return wave.squeeze(0), mel

return wave, mel

def build_ft_dataloader(data_path, spect_params, sr, batch_size=1, num_workers=0):
dataset = FT_Dataset(data_path, spect_params, sr, batch_size)
Expand Down Expand Up @@ -130,4 +123,4 @@ def collate(batch):
wave, mel, wave_lengths, mel_lengths = batch
print(wave.shape, mel.shape)
if idx == 10:
break
break

0 comments on commit 151af2a

Please sign in to comment.