diff --git a/tonic/audio_augmentations.py b/tonic/audio_augmentations.py
index ab490cd..a77a9ea 100644
--- a/tonic/audio_augmentations.py
+++ b/tonic/audio_augmentations.py
@@ -6,13 +6,12 @@
 import librosa
 import numpy as np
 import torch
+import torchaudio
 
-from tonic.audio_transforms import FixLength
-
-# import torchaudio
 # from qut_noise import QUTNoise
-# from torchaudio.utils import download_asset
+from torchaudio.utils import download_asset
+from tonic.audio_transforms import FixLength
 
 
 __all__ = [
     "RandomTimeStretch",
@@ -366,38 +365,37 @@ def __call__(self, audio: np.ndarray):
 
 #         return noise_then_audio
 
-# @dataclass
-# class RIR:
-#     """Convolves a RIR (room impluse response, sound of hand clapping in an empty room) to the data
-#     sample.
+@dataclass
+class RIR:
+    """Convolves a RIR (room impulse response) with the data sample.
 
-#     Parameters:
-#         samplerate (float): sample rate of the sample
-#         caching (bool): if we are caching the DiskCached dataset will dynamically pass copy index of data item to the transform (to set aug_index). Otherwise the aug_index will be chosen randomly in every call of transform
+    Parameters:
+        samplerate (float): sample rate of the sample
+        rir_audio (str): path to a sample room impulse response in .wav format
+        caching (bool): if we are caching, the DiskCachedDataset will dynamically pass the copy index of the data item to the transform (to set aug_index). Otherwise the aug_index will be chosen randomly in every call of the transform
 
-#     Args:
-#         audio (np.ndarray): data sample
-#     Returns:
-#         np.ndarray: data sample convolved with RIR
-#     """
+    Args:
+        audio (np.ndarray): data sample
+    Returns:
+        np.ndarray: data sample convolved with RIR
+    """
 
-#     samplerate: float
-#     caching: bool = False
+    samplerate: float
+    rir_audio: str
+    caching: bool = False
 
-#     def __call__(self, audio):
-#         SAMPLE_RIR = download_asset(
-#             "tutorial-assets/Lab41-SRI-VOiCES-rm1-impulse-mc01-stu-clo-8000hz.wav"
-#         )
-#         rir_raw, rir_sample_rate = torchaudio.load(SAMPLE_RIR)
-#         rir = rir_raw[:, int(rir_sample_rate * 1.01) : int(rir_sample_rate * 1.3)]
-#         rir = rir / torch.norm(rir, p=2)
-#         RIR = torch.flip(rir, [1])
+    def __call__(self, audio):
+        SAMPLE_RIR = download_asset(self.rir_audio)
+        rir_raw, rir_sample_rate = torchaudio.load(SAMPLE_RIR)
+        rir = rir_raw[:, int(rir_sample_rate * 1.01) : int(rir_sample_rate * 1.3)]
+        rir = rir / torch.norm(rir, p=2)
+        RIR = torch.flip(rir, [1])
 
-#         t_audio = torch.nn.functional.pad(
-#             torch.from_numpy(audio), (RIR.shape[1] - 1, 0)
-#         )
-#         rir_augmented = torch.nn.functional.conv1d(t_audio[None, ...], RIR[None, ...])[
-#             0
-#         ].numpy()
+        t_audio = torch.nn.functional.pad(
+            torch.from_numpy(audio), (RIR.shape[1] - 1, 0)
+        )
+        rir_augmented = torch.nn.functional.conv1d(t_audio[None, ...], RIR[None, ...])[
+            0
+        ].numpy()
 
-#         return rir_augmented
+        return rir_augmented
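
For reviewers who want to exercise the change locally, here is a minimal usage sketch (not part of the diff). It assumes a mono input of shape `(1, num_samples)` in `float32` (torchaudio loads the RIR as `float32`, and `conv1d` requires matching dtypes); the 16 kHz duration is an illustrative placeholder, and the asset key is reused from the previously commented-out code:

```python
import numpy as np

from tonic.audio_augmentations import RIR

# Illustrative input: 1 s of white noise at 16 kHz, shaped (channels, samples).
audio = np.random.randn(1, 16000).astype(np.float32)

# Asset key taken from the old commented-out code; it is downloaded on first use.
transform = RIR(
    samplerate=16000,
    rir_audio="tutorial-assets/Lab41-SRI-VOiCES-rm1-impulse-mc01-stu-clo-8000hz.wav",
)

# Left-pads the sample, then convolves it with the flipped, L2-normalized RIR.
augmented = transform(audio)
assert augmented.shape == audio.shape  # the convolution preserves input length
```

Because the input is left-padded by `RIR.shape[1] - 1` samples before the `conv1d`, the output has exactly the input's length, so the transform composes cleanly with length-sensitive transforms such as `FixLength`.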