From cb15e382cb8670439070562a3879ef415b32d209 Mon Sep 17 00:00:00 2001 From: Ryan Date: Mon, 13 Jan 2025 20:12:08 +0800 Subject: [PATCH] [Hackathon 7th No.55] Add `audiotools` to `PaddleSpeech` (#3900) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * add AudioSignal && util * fix codestyle * add basemodel && decorator * add util && add quality * add acc && data && transforms * add utils * fix dir * add *.py; wo unitest * add unitest * fix codestyle * fix cuda error * add readme && __all__ * add 2 file test * change download dir * fix CI download path * add tar -zxvf * change requirements path * add audiotools path * fix place error * fix paddle2.5 verion Q * FFTConv1d -> FFTConv1D * FFTConv1d -> FFTConv1D * mv unfold * add _unfold1d 2 loudness * fix stupid device variable * bias -> bias_attr * () -> [] * fix .to() * rm ✅ * fix exp * deepcopy -> clone * fix dim error * fix slice && tensor.to * fix paddle2.5 index bug * git rm std * rm comment && ✅ * rm some useless comment * add __all__ * fix codestyle * fix soundfile.info error * fix sth * add License * fix cycle import * Adapt to paddle3.0 && update readme * fix License * fix License * rm duplicate requirements * fix trasform problems * rm disp * Update test_transforms.py * change path * rm notebook && add audio path * rm import * add comment * fix cycle import && rm TYPE_CHECKING * rm IPython * rm sth useless * rm uesless deps * Update requirements.txt --- audio/audiotools/README.md | 68 + audio/audiotools/__init__.py | 25 + audio/audiotools/core/__init__.py | 28 + audio/audiotools/core/_julius.py | 666 +++++++ audio/audiotools/core/audio_signal.py | 1750 +++++++++++++++++ audio/audiotools/core/display.py | 195 ++ audio/audiotools/core/dsp.py | 467 +++++ audio/audiotools/core/effects.py | 539 +++++ audio/audiotools/core/ffmpeg.py | 119 ++ audio/audiotools/core/loudness.py | 387 ++++ audio/audiotools/core/util.py | 921 +++++++++ audio/audiotools/data/__init__.py | 16 + audio/audiotools/data/datasets.py | 548 ++++++ audio/audiotools/data/preprocess.py | 87 + audio/audiotools/data/transforms.py | 1182 +++++++++++ audio/audiotools/metrics/__init__.py | 17 + audio/audiotools/metrics/quality.py | 74 + audio/audiotools/ml/__init__.py | 16 + audio/audiotools/ml/accelerator.py | 199 ++ audio/audiotools/ml/basemodel.py | 272 +++ audio/audiotools/ml/decorators.py | 446 +++++ audio/audiotools/post.py | 88 + audio/audiotools/requirements.txt | 6 + .../audiotools/core/test_audio_signal.py | 615 ++++++ audio/tests/audiotools/core/test_bands.py | 54 + audio/tests/audiotools/core/test_display.py | 51 + audio/tests/audiotools/core/test_dsp.py | 181 ++ audio/tests/audiotools/core/test_effects.py | 321 +++ audio/tests/audiotools/core/test_fftconv.py | 85 + audio/tests/audiotools/core/test_grad.py | 172 ++ audio/tests/audiotools/core/test_highpass.py | 104 + audio/tests/audiotools/core/test_loudness.py | 274 +++ audio/tests/audiotools/core/test_lowpass.py | 109 + audio/tests/audiotools/core/test_util.py | 157 ++ audio/tests/audiotools/data/test_datasets.py | 208 ++ .../tests/audiotools/data/test_preprocess.py | 33 + .../tests/audiotools/data/test_transforms.py | 453 +++++ audio/tests/audiotools/ml/test_decorators.py | 110 ++ audio/tests/audiotools/ml/test_model.py | 89 + audio/tests/audiotools/test_audiotools.sh | 7 + audio/tests/audiotools/test_post.py | 30 + tests/unit/ci.sh | 7 + 42 files changed, 11176 insertions(+) create mode 100644 audio/audiotools/README.md create mode 100644 audio/audiotools/__init__.py create mode 100644 audio/audiotools/core/__init__.py create mode 100644 audio/audiotools/core/_julius.py create mode 100644 audio/audiotools/core/audio_signal.py create mode 100644 audio/audiotools/core/display.py create mode 100644 audio/audiotools/core/dsp.py create mode 100644 audio/audiotools/core/effects.py create mode 100644 audio/audiotools/core/ffmpeg.py create mode 100644 audio/audiotools/core/loudness.py create mode 100644 audio/audiotools/core/util.py create mode 100644 audio/audiotools/data/__init__.py create mode 100644 audio/audiotools/data/datasets.py create mode 100644 audio/audiotools/data/preprocess.py create mode 100644 audio/audiotools/data/transforms.py create mode 100644 audio/audiotools/metrics/__init__.py create mode 100644 audio/audiotools/metrics/quality.py create mode 100644 audio/audiotools/ml/__init__.py create mode 100644 audio/audiotools/ml/accelerator.py create mode 100644 audio/audiotools/ml/basemodel.py create mode 100644 audio/audiotools/ml/decorators.py create mode 100644 audio/audiotools/post.py create mode 100644 audio/audiotools/requirements.txt create mode 100644 audio/tests/audiotools/core/test_audio_signal.py create mode 100644 audio/tests/audiotools/core/test_bands.py create mode 100644 audio/tests/audiotools/core/test_display.py create mode 100644 audio/tests/audiotools/core/test_dsp.py create mode 100644 audio/tests/audiotools/core/test_effects.py create mode 100644 audio/tests/audiotools/core/test_fftconv.py create mode 100644 audio/tests/audiotools/core/test_grad.py create mode 100644 audio/tests/audiotools/core/test_highpass.py create mode 100644 audio/tests/audiotools/core/test_loudness.py create mode 100644 audio/tests/audiotools/core/test_lowpass.py create mode 100644 audio/tests/audiotools/core/test_util.py create mode 100644 audio/tests/audiotools/data/test_datasets.py create mode 100644 audio/tests/audiotools/data/test_preprocess.py create mode 100644 audio/tests/audiotools/data/test_transforms.py create mode 100644 audio/tests/audiotools/ml/test_decorators.py create mode 100644 audio/tests/audiotools/ml/test_model.py create mode 100644 audio/tests/audiotools/test_audiotools.sh create mode 100644 audio/tests/audiotools/test_post.py diff --git a/audio/audiotools/README.md b/audio/audiotools/README.md new file mode 100644 index 00000000000..a0eac367585 --- /dev/null +++ b/audio/audiotools/README.md @@ -0,0 +1,68 @@ +Audiotools is a comprehensive toolkit designed for audio processing and analysis, providing robust solutions for audio signal processing, data management, model training, and evaluation. + +### Directory Structure + +``` +. +├── audiotools +│ ├── README.md +│ ├── __init__.py +│ ├── core +│ │ ├── __init__.py +│ │ ├── _julius.py +│ │ ├── audio_signal.py +│ │ ├── display.py +│ │ ├── dsp.py +│ │ ├── effects.py +│ │ ├── ffmpeg.py +│ │ ├── loudness.py +│ │ └── util.py +│ ├── data +│ │ ├── __init__.py +│ │ ├── datasets.py +│ │ ├── preprocess.py +│ │ └── transforms.py +│ ├── metrics +│ │ ├── __init__.py +│ │ └── quality.py +│ ├── ml +│ │ ├── __init__.py +│ │ ├── accelerator.py +│ │ ├── basemodel.py +│ │ └── decorators.py +│ ├── requirements.txt +│ └── post.py +├── tests +│ └── audiotools +│ ├── core +│ │ ├── test_audio_signal.py +│ │ ├── test_bands.py +│ │ ├── test_display.py +│ │ ├── test_dsp.py +│ │ ├── test_effects.py +│ │ ├── test_fftconv.py +│ │ ├── test_grad.py +│ │ ├── test_highpass.py +│ │ ├── test_loudness.py +│ │ ├── test_lowpass.py +│ │ └── test_util.py +│ ├── data +│ │ ├── test_datasets.py +│ │ ├── test_preprocess.py +│ │ └── test_transforms.py +│ ├── ml +│ │ ├── test_decorators.py +│ │ └── test_model.py +│ └── test_post.py + +``` + +- **core**: Contains the core class AudioSignal, which is responsible for the fundamental representation and manipulation of audio signals. + +- **data**: Primarily dedicated to storing and processing datasets, including classes and functions for data preprocessing, ensuring efficient loading and transformation of audio data. + +- **metrics**: Implements functions for various audio evaluation metrics, enabling precise assessment of the performance of audio models and processing algorithms. + +- **ml**: Comprises classes and methods related to model training, supporting the construction, training, and optimization of machine learning models in the context of audio. + +This project aims to provide developers and researchers with an efficient and flexible framework to foster innovation and exploration across various domains of audio technology. diff --git a/audio/audiotools/__init__.py b/audio/audiotools/__init__.py new file mode 100644 index 00000000000..e8a201e85f0 --- /dev/null +++ b/audio/audiotools/__init__.py @@ -0,0 +1,25 @@ +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from . import metrics +from . import ml +from . import post +from .core import AudioSignal +from .core import highpass_filter +from .core import highpass_filters +from .core import Meter +from .core import STFTParams +from .core import util +from .data import datasets +from .data import preprocess +from .data import transforms diff --git a/audio/audiotools/core/__init__.py b/audio/audiotools/core/__init__.py new file mode 100644 index 00000000000..609d6a34a5d --- /dev/null +++ b/audio/audiotools/core/__init__.py @@ -0,0 +1,28 @@ +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from . import util +from ._julius import fft_conv1d +from ._julius import FFTConv1D +from ._julius import highpass_filter +from ._julius import highpass_filters +from ._julius import lowpass_filter +from ._julius import LowPassFilter +from ._julius import LowPassFilters +from ._julius import pure_tone +from ._julius import resample_frac +from ._julius import split_bands +from ._julius import SplitBands +from .audio_signal import AudioSignal +from .audio_signal import STFTParams +from .loudness import Meter diff --git a/audio/audiotools/core/_julius.py b/audio/audiotools/core/_julius.py new file mode 100644 index 00000000000..aef51f98ff8 --- /dev/null +++ b/audio/audiotools/core/_julius.py @@ -0,0 +1,666 @@ +# MIT License, Copyright (c) 2020 Alexandre Défossez. +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Modified from julius(https://github.com/adefossez/julius/tree/main/julius) +""" +Implementation of a FFT based 1D convolution in PaddlePaddle. +While FFT is used in some cases for small kernel sizes, it is not the default for long ones, e.g. 512. +This module implements efficient FFT based convolutions for such cases. A typical +application is for evaluating FIR filters with a long receptive field, typically +evaluated with a stride of 1. +""" +import inspect +import math +import sys +import typing +from typing import Optional +from typing import Sequence + +import paddle +import paddle.nn as nn +import paddle.nn.functional as F + +from paddlespeech.t2s.modules import fft_conv1d +from paddlespeech.t2s.modules import FFTConv1D +from paddlespeech.utils import satisfy_paddle_version + +__all__ = [ + 'highpass_filter', 'highpass_filters', 'lowpass_filter', 'LowPassFilter', + 'LowPassFilters', 'pure_tone', 'resample_frac', 'split_bands', 'SplitBands' +] + + +def simple_repr(obj, attrs: Optional[Sequence[str]]=None, overrides: dict={}): + """ + Return a simple representation string for `obj`. + If `attrs` is not None, it should be a list of attributes to include. + """ + params = inspect.signature(obj.__class__).parameters + attrs_repr = [] + if attrs is None: + attrs = list(params.keys()) + for attr in attrs: + display = False + if attr in overrides: + value = overrides[attr] + elif hasattr(obj, attr): + value = getattr(obj, attr) + else: + continue + if attr in params: + param = params[attr] + if param.default is inspect._empty or value != param.default: # type: ignore + display = True + else: + display = True + + if display: + attrs_repr.append(f"{attr}={value}") + return f"{obj.__class__.__name__}({','.join(attrs_repr)})" + + +def sinc(x: paddle.Tensor): + """ + Implementation of sinc, i.e. sin(x) / x + + __Warning__: the input is not multiplied by `pi`! + """ + if satisfy_paddle_version("3.0"): + return paddle.sinc(x) + + return paddle.where( + x == 0, + paddle.to_tensor(1.0, dtype=x.dtype, place=x.place), + paddle.sin(x) / x, ) + + +class ResampleFrac(paddle.nn.Layer): + """ + Resampling from the sample rate `old_sr` to `new_sr`. + """ + + def __init__(self, + old_sr: int, + new_sr: int, + zeros: int=24, + rolloff: float=0.945): + """ + Args: + old_sr (int): sample rate of the input signal x. + new_sr (int): sample rate of the output. + zeros (int): number of zero crossing to keep in the sinc filter. + rolloff (float): use a lowpass filter that is `rolloff * new_sr / 2`, + to ensure sufficient margin due to the imperfection of the FIR filter used. + Lowering this value will reduce anti-aliasing, but will reduce some of the + highest frequencies. + + Shape: + + - Input: `[*, T]` + - Output: `[*, T']` with `T' = int(new_sr * T / old_sr)` + + + .. caution:: + After dividing `old_sr` and `new_sr` by their GCD, both should be small + for this implementation to be fast. + + >>> import paddle + >>> resample = ResampleFrac(4, 5) + >>> x = paddle.randn([1000]) + >>> print(len(resample(x))) + 1250 + """ + super().__init__() + if not isinstance(old_sr, int) or not isinstance(new_sr, int): + raise ValueError("old_sr and new_sr should be integers") + gcd = math.gcd(old_sr, new_sr) + self.old_sr = old_sr // gcd + self.new_sr = new_sr // gcd + self.zeros = zeros + self.rolloff = rolloff + + self._init_kernels() + + def _init_kernels(self): + if self.old_sr == self.new_sr: + return + + kernels = [] + sr = min(self.new_sr, self.old_sr) + sr *= self.rolloff + + self._width = math.ceil(self.zeros * self.old_sr / sr) + idx = paddle.arange( + -self._width, self._width + self.old_sr, dtype="float32") + for i in range(self.new_sr): + t = (-i / self.new_sr + idx / self.old_sr) * sr + t = paddle.clip(t, -self.zeros, self.zeros) + t *= math.pi + window = paddle.cos(t / self.zeros / 2)**2 + kernel = sinc(t) * window + # Renormalize kernel to ensure a constant signal is preserved. + kernel = kernel / kernel.sum() + kernels.append(kernel) + + _kernel = paddle.stack(kernels).reshape([self.new_sr, 1, -1]) + self.kernel = self.create_parameter( + shape=_kernel.shape, + dtype=_kernel.dtype, ) + self.kernel.set_value(_kernel) + + def forward( + self, + x: paddle.Tensor, + output_length: Optional[int]=None, + full: bool=False, ): + """ + Resample x. + Args: + x (Tensor): signal to resample, time should be the last dimension + output_length (None or int): This can be set to the desired output length + (last dimension). Allowed values are between 0 and + ceil(length * new_sr / old_sr). When None (default) is specified, the + floored output length will be used. In order to select the largest possible + size, use the `full` argument. + full (bool): return the longest possible output from the input. This can be useful + if you chain resampling operations, and want to give the `output_length` only + for the last one, while passing `full=True` to all the other ones. + """ + if self.old_sr == self.new_sr: + return x + shape = x.shape + _dtype = x.dtype + length = x.shape[-1] + x = x.reshape([-1, length]) + x = F.pad( + x.unsqueeze(1), + [self._width, self._width + self.old_sr], + mode="replicate", + data_format="NCL", ).astype(self.kernel.dtype) + ys = F.conv1d(x, self.kernel, stride=self.old_sr, data_format="NCL") + y = ys.transpose( + [0, 2, 1]).reshape(list(shape[:-1]) + [-1]).astype(_dtype) + + float_output_length = paddle.to_tensor( + self.new_sr * length / self.old_sr, dtype="float32") + max_output_length = paddle.ceil(float_output_length).astype("int64") + default_output_length = paddle.floor(float_output_length).astype( + "int64") + + if output_length is None: + applied_output_length = (max_output_length + if full else default_output_length) + elif output_length < 0 or output_length > max_output_length: + raise ValueError( + f"output_length must be between 0 and {max_output_length.numpy()}" + ) + else: + applied_output_length = paddle.to_tensor( + output_length, dtype="int64") + if full: + raise ValueError( + "You cannot pass both full=True and output_length") + return y[..., :applied_output_length] + + def __repr__(self): + return simple_repr(self) + + +def resample_frac( + x: paddle.Tensor, + old_sr: int, + new_sr: int, + zeros: int=24, + rolloff: float=0.945, + output_length: Optional[int]=None, + full: bool=False, ): + """ + Functional version of `ResampleFrac`, refer to its documentation for more information. + + ..warning:: + If you call repeatidly this functions with the same sample rates, then the + resampling kernel will be recomputed everytime. For best performance, you should use + and cache an instance of `ResampleFrac`. + """ + return ResampleFrac(old_sr, new_sr, zeros, rolloff)(x, output_length, full) + + +def pad_to(tensor: paddle.Tensor, + target_length: int, + mode: str="constant", + value: float=0.0): + """ + Pad the given tensor to the given length, with 0s on the right. + """ + return F.pad( + tensor, (0, target_length - tensor.shape[-1]), + mode=mode, + value=value, + data_format="NCL") + + +def pure_tone(freq: float, sr: float=128, dur: float=4, device=None): + """ + Return a pure tone, i.e. cosine. + + Args: + freq (float): frequency (in Hz) + sr (float): sample rate (in Hz) + dur (float): duration (in seconds) + """ + time = paddle.arange(int(sr * dur), dtype="float32") / sr + return paddle.cos(2 * math.pi * freq * time) + + +class LowPassFilters(nn.Layer): + """ + Bank of low pass filters. + """ + + def __init__(self, + cutoffs: Sequence[float], + stride: int=1, + pad: bool=True, + zeros: float=8, + fft: Optional[bool]=None, + dtype="float32"): + super().__init__() + self.cutoffs = list(cutoffs) + if min(self.cutoffs) < 0: + raise ValueError("Minimum cutoff must be larger than zero.") + if max(self.cutoffs) > 0.5: + raise ValueError("A cutoff above 0.5 does not make sense.") + self.stride = stride + self.pad = pad + self.zeros = zeros + self.half_size = int(zeros / min([c for c in self.cutoffs if c > 0]) / + 2) + if fft is None: + fft = self.half_size > 32 + self.fft = fft + + # Create filters + window = paddle.audio.functional.get_window( + "hann", 2 * self.half_size + 1, fftbins=False, dtype=dtype) + time = paddle.arange( + -self.half_size, self.half_size + 1, dtype="float32") + filters = [] + for cutoff in cutoffs: + if cutoff == 0: + filter_ = paddle.zeros_like(time) + else: + filter_ = 2 * cutoff * window * sinc(2 * cutoff * math.pi * + time) + # Normalize filter + filter_ /= paddle.sum(filter_) + filters.append(filter_) + filters = paddle.stack(filters)[:, None] + self.filters = self.create_parameter( + shape=filters.shape, + default_initializer=nn.initializer.Constant(value=0.0), + dtype="float32", + is_bias=False, + attr=paddle.ParamAttr(trainable=False), ) + self.filters.set_value(filters) + + def forward(self, _input): + shape = list(_input.shape) + _input = _input.reshape([-1, 1, shape[-1]]) + if self.pad: + _input = F.pad( + _input, (self.half_size, self.half_size), + mode="replicate", + data_format="NCL") + if self.fft: + out = fft_conv1d(_input, self.filters, stride=self.stride) + else: + out = F.conv1d(_input, self.filters, stride=self.stride) + + shape.insert(0, len(self.cutoffs)) + shape[-1] = out.shape[-1] + return out.transpose([1, 0, 2]).reshape(shape) + + +class LowPassFilter(nn.Layer): + """ + Same as `LowPassFilters` but applies a single low pass filter. + """ + + def __init__(self, + cutoff: float, + stride: int=1, + pad: bool=True, + zeros: float=8, + fft: Optional[bool]=None): + super().__init__() + self._lowpasses = LowPassFilters([cutoff], stride, pad, zeros, fft) + + @property + def cutoff(self): + return self._lowpasses.cutoffs[0] + + @property + def stride(self): + return self._lowpasses.stride + + @property + def pad(self): + return self._lowpasses.pad + + @property + def zeros(self): + return self._lowpasses.zeros + + @property + def fft(self): + return self._lowpasses.fft + + def forward(self, _input): + return self._lowpasses(_input)[0] + + +def lowpass_filters( + _input: paddle.Tensor, + cutoffs: Sequence[float], + stride: int=1, + pad: bool=True, + zeros: float=8, + fft: Optional[bool]=None, ): + """ + Functional version of `LowPassFilters`, refer to this class for more information. + """ + return LowPassFilters(cutoffs, stride, pad, zeros, fft)(_input) + + +def lowpass_filter(_input: paddle.Tensor, + cutoff: float, + stride: int=1, + pad: bool=True, + zeros: float=8, + fft: Optional[bool]=None): + """ + Same as `lowpass_filters` but with a single cutoff frequency. + Output will not have a dimension inserted in the front. + """ + return lowpass_filters(_input, [cutoff], stride, pad, zeros, fft)[0] + + +class HighPassFilters(paddle.nn.Layer): + """ + Bank of high pass filters. See `julius.lowpass.LowPassFilters` for more + details on the implementation. + + Args: + cutoffs (list[float]): list of cutoff frequencies, in [0, 0.5] expressed as `f/f_s` where + f_s is the samplerate and `f` is the cutoff frequency. + The upper limit is 0.5, because a signal sampled at `f_s` contains only + frequencies under `f_s / 2`. + stride (int): how much to decimate the output. Probably not a good idea + to do so with a high pass filters though... + pad (bool): if True, appropriately pad the _input with zero over the edge. If `stride=1`, + the output will have the same length as the _input. + zeros (float): Number of zero crossings to keep. + Controls the receptive field of the Finite Impulse Response filter. + For filters with low cutoff frequency, e.g. 40Hz at 44.1kHz, + it is a bad idea to set this to a high value. + This is likely appropriate for most use. Lower values + will result in a faster filter, but with a slower attenuation around the + cutoff frequency. + fft (bool or None): if True, uses `julius.fftconv` rather than PyTorch convolutions. + If False, uses PyTorch convolutions. If None, either one will be chosen automatically + depending on the effective filter size. + + + ..warning:: + All the filters will use the same filter size, aligned on the lowest + frequency provided. If you combine a lot of filters with very diverse frequencies, it might + be more efficient to split them over multiple modules with similar frequencies. + + Shape: + + - Input: `[*, T]` + - Output: `[F, *, T']`, with `T'=T` if `pad` is True and `stride` is 1, and + `F` is the numer of cutoff frequencies. + + >>> highpass = HighPassFilters([1/4]) + >>> x = paddle.randn([4, 12, 21, 1024]) + >>> list(highpass(x).shape) + [1, 4, 12, 21, 1024] + """ + + def __init__(self, + cutoffs: Sequence[float], + stride: int=1, + pad: bool=True, + zeros: float=8, + fft: Optional[bool]=None): + super().__init__() + self._lowpasses = LowPassFilters(cutoffs, stride, pad, zeros, fft) + + @property + def cutoffs(self): + return self._lowpasses.cutoffs + + @property + def stride(self): + return self._lowpasses.stride + + @property + def pad(self): + return self._lowpasses.pad + + @property + def zeros(self): + return self._lowpasses.zeros + + @property + def fft(self): + return self._lowpasses.fft + + def forward(self, _input): + lows = self._lowpasses(_input) + + # We need to extract the right portion of the _input in case + # pad is False or stride > 1 + if self.pad: + start, end = 0, _input.shape[-1] + else: + start = self._lowpasses.half_size + end = -start + _input = _input[..., start:end:self.stride] + highs = _input - lows + return highs + + +class HighPassFilter(paddle.nn.Layer): + """ + Same as `HighPassFilters` but applies a single high pass filter. + + Shape: + + - Input: `[*, T]` + - Output: `[*, T']`, with `T'=T` if `pad` is True and `stride` is 1. + + >>> highpass = HighPassFilter(1/4, stride=1) + >>> x = paddle.randn([4, 124]) + >>> list(highpass(x).shape) + [4, 124] + """ + + def __init__(self, + cutoff: float, + stride: int=1, + pad: bool=True, + zeros: float=8, + fft: Optional[bool]=None): + super().__init__() + self._highpasses = HighPassFilters([cutoff], stride, pad, zeros, fft) + + @property + def cutoff(self): + return self._highpasses.cutoffs[0] + + @property + def stride(self): + return self._highpasses.stride + + @property + def pad(self): + return self._highpasses.pad + + @property + def zeros(self): + return self._highpasses.zeros + + @property + def fft(self): + return self._highpasses.fft + + def forward(self, _input): + return self._highpasses(_input)[0] + + +def highpass_filters( + _input: paddle.Tensor, + cutoffs: Sequence[float], + stride: int=1, + pad: bool=True, + zeros: float=8, + fft: Optional[bool]=None, ): + """ + Functional version of `HighPassFilters`, refer to this class for more information. + """ + return HighPassFilters(cutoffs, stride, pad, zeros, fft)(_input) + + +def highpass_filter(_input: paddle.Tensor, + cutoff: float, + stride: int=1, + pad: bool=True, + zeros: float=8, + fft: Optional[bool]=None): + """ + Functional version of `HighPassFilter`, refer to this class for more information. + Output will not have a dimension inserted in the front. + """ + return highpass_filters(_input, [cutoff], stride, pad, zeros, fft)[0] + + +class SplitBands(paddle.nn.Layer): + """ + Decomposes a signal over the given frequency bands in the waveform domain using + a cascade of low pass filters as implemented by `julius.lowpass.LowPassFilters`. + You can either specify explicitly the frequency cutoffs, or just the number of bands, + in which case the frequency cutoffs will be spread out evenly in mel scale. + + Args: + sample_rate (float): Sample rate of the input signal in Hz. + n_bands (int or None): number of bands, when not giving them explicitly with `cutoffs`. + In that case, the cutoff frequencies will be evenly spaced in mel-space. + cutoffs (list[float] or None): list of frequency cutoffs in Hz. + pad (bool): if True, appropriately pad the input with zero over the edge. If `stride=1`, + the output will have the same length as the input. + zeros (float): Number of zero crossings to keep. See `LowPassFilters` for more informations. + fft (bool or None): See `LowPassFilters` for more info. + + ..note:: + The sum of all the bands will always be the input signal. + + ..warning:: + Unlike `julius.lowpass.LowPassFilters`, the cutoffs frequencies must be provided in Hz along + with the sample rate. + + Shape: + + - Input: `[*, T]` + - Output: `[B, *, T']`, with `T'=T` if `pad` is True. + If `n_bands` was provided, `B = n_bands` otherwise `B = len(cutoffs) + 1` + + >>> bands = SplitBands(sample_rate=128, n_bands=10) + >>> x = paddle.randn(shape=[6, 4, 1024]) + >>> list(bands(x).shape) + [10, 6, 4, 1024] + """ + + def __init__( + self, + sample_rate: float, + n_bands: Optional[int]=None, + cutoffs: Optional[Sequence[float]]=None, + pad: bool=True, + zeros: float=8, + fft: Optional[bool]=None, ): + super().__init__() + if (cutoffs is None) + (n_bands is None) != 1: + raise ValueError( + "You must provide either n_bands, or cutoffs, but not both.") + + self.sample_rate = sample_rate + self.n_bands = n_bands + self._cutoffs = list(cutoffs) if cutoffs is not None else None + self.pad = pad + self.zeros = zeros + self.fft = fft + + if cutoffs is None: + if n_bands is None: + raise ValueError("You must provide one of n_bands or cutoffs.") + if not n_bands >= 1: + raise ValueError( + f"n_bands must be greater than one (got {n_bands})") + cutoffs = paddle.audio.functional.mel_frequencies( + n_bands + 1, 0, sample_rate / 2)[1:-1] + else: + if max(cutoffs) > 0.5 * sample_rate: + raise ValueError( + "A cutoff above sample_rate/2 does not make sense.") + if len(cutoffs) > 0: + self.lowpass = LowPassFilters( + [c / sample_rate for c in cutoffs], + pad=pad, + zeros=zeros, + fft=fft) + else: + self.lowpass = None # type: ignore + + def forward(self, input): + if self.lowpass is None: + return input[None] + lows = self.lowpass(input) + low = lows[0] + bands = [low] + for low_and_band in lows[1:]: + # Get a bandpass filter by subtracting lowpasses + band = low_and_band - low + bands.append(band) + low = low_and_band + # Last band is whatever is left in the signal + bands.append(input - low) + return paddle.stack(bands) + + @property + def cutoffs(self): + if self._cutoffs is not None: + return self._cutoffs + elif self.lowpass is not None: + return [c * self.sample_rate for c in self.lowpass.cutoffs] + else: + return [] + + +def split_bands( + signal: paddle.Tensor, + sample_rate: float, + n_bands: Optional[int]=None, + cutoffs: Optional[Sequence[float]]=None, + pad: bool=True, + zeros: float=8, + fft: Optional[bool]=None, ): + """ + Functional version of `SplitBands`, refer to this class for more information. + + >>> x = paddle.randn(shape=[6, 4, 1024]) + >>> list(split_bands(x, sample_rate=64, cutoffs=[12, 24]).shape) + [3, 6, 4, 1024] + """ + return SplitBands(sample_rate, n_bands, cutoffs, pad, zeros, fft)(signal) diff --git a/audio/audiotools/core/audio_signal.py b/audio/audiotools/core/audio_signal.py new file mode 100644 index 00000000000..74e8cac67c6 --- /dev/null +++ b/audio/audiotools/core/audio_signal.py @@ -0,0 +1,1750 @@ +# MIT License, Copyright (c) 2023-Present, Descript. +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Modified from audiotools(https://github.com/descriptinc/audiotools/blob/master/audiotools/core/audio_signal.py) +import copy +import functools +import hashlib +import math +import pathlib +import tempfile +import typing +import warnings +from collections import namedtuple +from pathlib import Path +from typing import Optional + +import librosa +import numpy as np +import paddle +import soundfile + +from . import util +from ._julius import resample_frac +from .display import DisplayMixin +from .dsp import DSPMixin +from .effects import EffectMixin +from .effects import ImpulseResponseMixin +from .ffmpeg import FFMPEGMixin +from .loudness import LoudnessMixin + +__all__ = ['STFTParams', 'AudioSignal'] + + +def create_dct(n_mfcc: int, n_mels: int, norm: Optional[str]) -> paddle.Tensor: + r"""Create a DCT transformation matrix with shape (``n_mels``, ``n_mfcc``), + normalized depending on norm. + + Args: + n_mfcc (int): Number of mfc coefficients to retain + n_mels (int): Number of mel filterbanks + norm (str or None): Norm to use (either "ortho" or None) + + Returns: + paddle.Tensor: The transformation matrix, to be right-multiplied to + row-wise data of size (``n_mels``, ``n_mfcc``). + """ + + if norm is not None and norm != "ortho": + raise ValueError('norm must be either "ortho" or None') + + # http://en.wikipedia.org/wiki/Discrete_cosine_transform#DCT-II + n = paddle.arange(float(n_mels)) + k = paddle.arange(float(n_mfcc)).unsqueeze([1]) + dct = paddle.cos(math.pi / float(n_mels) * (n + 0.5) * + k) # size (n_mfcc, n_mels) + + if norm is None: + dct *= 2.0 + else: + dct[0] *= 1.0 / math.sqrt(2.0) + dct *= math.sqrt(2.0 / float(n_mels)) + return dct.transpose([1, 0]) + + +STFTParams = namedtuple( + "STFTParams", + [ + "window_length", + "hop_length", + "window_type", + "match_stride", + "padding_type", + ], ) +""" +STFTParams object is a container that holds STFT parameters - window_length, +hop_length, and window_type. Not all parameters need to be specified. Ones that +are not specified will be inferred by the AudioSignal parameters. + +Parameters +---------- +window_length : int, optional + Window length of STFT, by default ``0.032 * self.sample_rate``. +hop_length : int, optional + Hop length of STFT, by default ``window_length // 4``. +window_type : str, optional + Type of window to use, by default ``sqrt\_hann``. +match_stride : bool, optional + Whether to match the stride of convolutional layers, by default False +padding_type : str, optional + Type of padding to use, by default 'reflect' +""" +STFTParams.__new__.__defaults__ = (None, None, None, None, None) + + +class AudioSignal( + EffectMixin, + LoudnessMixin, + ImpulseResponseMixin, + DSPMixin, + DisplayMixin, + FFMPEGMixin, ): + """This is the core object of this library. Audio is always + loaded into an AudioSignal, which then enables all the features + of this library, including audio augmentations, I/O, playback, + and more. + + The structure of this object is that the base functionality + is defined in ``core/audio_signal.py``, while extensions to + that functionality are defined in the other ``core/*.py`` + files. For example, all the display-based functionality + (e.g. plot spectrograms, waveforms, write to tensorboard) + are in ``core/display.py``. + + Parameters + ---------- + audio_path_or_array : typing.Union[paddle.Tensor, str, Path, np.ndarray] + Object to create AudioSignal from. Can be a tensor, numpy array, + or a path to a file. The file is always reshaped to + sample_rate : int, optional + Sample rate of the audio. If different from underlying file, resampling is + performed. If passing in an array or tensor, this must be defined, + by default None + stft_params : STFTParams, optional + Parameters of STFT to use. , by default None + offset : float, optional + Offset in seconds to read from file, by default 0 + duration : float, optional + Duration in seconds to read from file, by default None + device : str, optional + Device to load audio onto, by default None + + Examples + -------- + Loading an AudioSignal from an array, at a sample rate of + 44100. + + >>> signal = AudioSignal(paddle.randn([5*44100]), 44100) + + Note, the signal is reshaped to have a batch size, and one + audio channel: + + >>> print(signal.shape) + (1, 1, 44100) + + You can treat AudioSignals like tensors, and many of the same + functions you might use on tensors are defined for AudioSignals + as well: + + >>> signal.to("cuda") + >>> signal.cuda() + >>> signal.clone() + >>> signal.detach() + + Indexing AudioSignals returns an AudioSignal: + + >>> signal[..., 3*44100:4*44100] + + The above signal is 1 second long, and is also an AudioSignal. + """ + + def __init__( + self, + audio_path_or_array: typing.Union[paddle.Tensor, str, Path, + np.ndarray], + sample_rate: int=None, + stft_params: STFTParams=None, + offset: float=0, + duration: float=None, + device: str=None, ): + # + audio_path = None + audio_array = None + + if isinstance(audio_path_or_array, str): + audio_path = audio_path_or_array + elif isinstance(audio_path_or_array, pathlib.Path): + audio_path = audio_path_or_array + elif isinstance(audio_path_or_array, np.ndarray): + audio_array = audio_path_or_array + elif paddle.is_tensor(audio_path_or_array): + audio_array = audio_path_or_array + else: + raise ValueError("audio_path_or_array must be either a Path, " + "string, numpy array, or paddle Tensor!") + + self.path_to_file = None + + self.audio_data = None + self.sources = None # List of AudioSignal objects. + self.stft_data = None + if audio_path is not None: + self.load_from_file( + audio_path, offset=offset, duration=duration, device=device) + elif audio_array is not None: + assert sample_rate is not None, "Must set sample rate!" + self.load_from_array(audio_array, sample_rate, device=device) + + self.window = None + self.stft_params = stft_params + + self.metadata = { + "offset": offset, + "duration": duration, + } + + @property + def path_to_input_file( + self, ): + """ + Path to input file, if it exists. + Alias to ``path_to_file`` for backwards compatibility + """ + return self.path_to_file + + @classmethod + def excerpt( + cls, + audio_path: typing.Union[str, Path], + offset: float=None, + duration: float=None, + state: typing.Union[np.random.RandomState, int]=None, + **kwargs, ): + """Randomly draw an excerpt of ``duration`` seconds from an + audio file specified at ``audio_path``, between ``offset`` seconds + and end of file. ``state`` can be used to seed the random draw. + + Parameters + ---------- + audio_path : typing.Union[str, Path] + Path to audio file to grab excerpt from. + offset : float, optional + Lower bound for the start time, in seconds drawn from + the file, by default None. + duration : float, optional + Duration of excerpt, in seconds, by default None + state : typing.Union[np.random.RandomState, int], optional + RandomState or seed of random state, by default None + + Returns + ------- + AudioSignal + AudioSignal containing excerpt. + + Examples + -------- + >>> signal = AudioSignal.excerpt("path/to/audio", duration=5) + """ + info = util.info(audio_path) + total_duration = info.duration + + state = util.random_state(state) + lower_bound = 0 if offset is None else offset + upper_bound = max(total_duration - duration, 0) + offset = state.uniform(lower_bound, upper_bound) + + signal = cls(audio_path, offset=offset, duration=duration, **kwargs) + signal.metadata["offset"] = offset + signal.metadata["duration"] = duration + + return signal + + @classmethod + def salient_excerpt( + cls, + audio_path: typing.Union[str, Path], + loudness_cutoff: float=None, + num_tries: int=8, + state: typing.Union[np.random.RandomState, int]=None, + **kwargs, ): + """Similar to AudioSignal.excerpt, except it extracts excerpts only + if they are above a specified loudness threshold, which is computed via + a fast LUFS routine. + + Parameters + ---------- + audio_path : typing.Union[str, Path] + Path to audio file to grab excerpt from. + loudness_cutoff : float, optional + Loudness threshold in dB. Typical values are ``-40, -60``, + etc, by default None + num_tries : int, optional + Number of tries to grab an excerpt above the threshold + before giving up, by default 8. + state : typing.Union[np.random.RandomState, int], optional + RandomState or seed of random state, by default None + kwargs : dict + Keyword arguments to AudioSignal.excerpt + + Returns + ------- + AudioSignal + AudioSignal containing excerpt. + + + .. warning:: + if ``num_tries`` is set to None, ``salient_excerpt`` may try forever, which can + result in an infinite loop if ``audio_path`` does not have + any loud enough excerpts. + + Examples + -------- + >>> signal = AudioSignal.salient_excerpt( + "path/to/audio", + loudness_cutoff=-40, + duration=5 + ) + """ + state = util.random_state(state) + if loudness_cutoff is None: + excerpt = cls.excerpt(audio_path, state=state, **kwargs) + else: + loudness = -np.inf + num_try = 0 + while loudness <= loudness_cutoff: + excerpt = cls.excerpt(audio_path, state=state, **kwargs) + loudness = excerpt.loudness() + num_try += 1 + if num_tries is not None and num_try >= num_tries: + break + return excerpt + + @classmethod + def zeros( + cls, + duration: float, + sample_rate: int, + num_channels: int=1, + batch_size: int=1, + **kwargs, ): + """Helper function create an AudioSignal of all zeros. + + Parameters + ---------- + duration : float + Duration of AudioSignal + sample_rate : int + Sample rate of AudioSignal + num_channels : int, optional + Number of channels, by default 1 + batch_size : int, optional + Batch size, by default 1 + + Returns + ------- + AudioSignal + AudioSignal containing all zeros. + + Examples + -------- + Generate 5 seconds of all zeros at a sample rate of 44100. + + >>> signal = AudioSignal.zeros(5.0, 44100) + """ + n_samples = int(duration * sample_rate) + return cls( + paddle.zeros([batch_size, num_channels, n_samples]), + sample_rate, + **kwargs, ) + + @classmethod + def wave( + cls, + frequency: float, + duration: float, + sample_rate: int, + num_channels: int=1, + shape: str="sine", + **kwargs, ): + """ + Generate a waveform of a given frequency and shape. + + Parameters + ---------- + frequency : float + Frequency of the waveform + duration : float + Duration of the waveform + sample_rate : int + Sample rate of the waveform + num_channels : int, optional + Number of channels, by default 1 + shape : str, optional + Shape of the waveform, by default "saw" + One of "sawtooth", "square", "sine", "triangle" + kwargs : dict + Keyword arguments to AudioSignal + """ + n_samples = int(duration * sample_rate) + t = np.linspace(0, duration, n_samples) + if shape == "sawtooth": + from scipy.signal import sawtooth + + wave_data = sawtooth(2 * np.pi * frequency * t, 0.5) + elif shape == "square": + from scipy.signal import square + + wave_data = square(2 * np.pi * frequency * t) + elif shape == "sine": + wave_data = np.sin(2 * np.pi * frequency * t) + elif shape == "triangle": + from scipy.signal import sawtooth + + # frequency is doubled by the abs call, so omit the 2 in 2pi + wave_data = sawtooth(np.pi * frequency * t, 0.5) + wave_data = -np.abs(wave_data) * 2 + 1 + else: + raise ValueError(f"Invalid shape {shape}") + + wave_data = paddle.to_tensor(wave_data, dtype=paddle.float32) + wave_data = wave_data[None, None].expand([1, num_channels, -1]) + return cls(wave_data, sample_rate, **kwargs) + + @classmethod + def batch( + cls, + audio_signals: list, + pad_signals: bool=False, + truncate_signals: bool=False, + resample: bool=False, + dim: int=0, ): + """Creates a batched AudioSignal from a list of AudioSignals. + + Parameters + ---------- + audio_signals : list[AudioSignal] + List of AudioSignal objects + pad_signals : bool, optional + Whether to pad signals to length of the maximum length + AudioSignal in the list, by default False + truncate_signals : bool, optional + Whether to truncate signals to length of shortest length + AudioSignal in the list, by default False + resample : bool, optional + Whether to resample AudioSignal to the sample rate of + the first AudioSignal in the list, by default False + dim : int, optional + Dimension along which to batch the signals. + + Returns + ------- + AudioSignal + Batched AudioSignal. + + Raises + ------ + RuntimeError + If not all AudioSignals are the same sample rate, and + ``resample=False``, an error is raised. + RuntimeError + If not all AudioSignals are the same the length, and + both ``pad_signals=False`` and ``truncate_signals=False``, + an error is raised. + + Examples + -------- + Batching a bunch of random signals: + + >>> signal_list = [AudioSignal(paddle.randn([44100]), 44100) for _ in range(10)] + >>> signal = AudioSignal.batch(signal_list) + >>> print(signal.shape) + (10, 1, 44100) + + """ + signal_lengths = [x.signal_length for x in audio_signals] + sample_rates = [x.sample_rate for x in audio_signals] + + if len(set(sample_rates)) != 1: + if resample: + for x in audio_signals: + x.resample(sample_rates[0]) + else: + raise RuntimeError( + f"Not all signals had the same sample rate! Got {sample_rates}. " + f"All signals must have the same sample rate, or resample must be True. " + ) + + if len(set(signal_lengths)) != 1: + if pad_signals: + max_length = max(signal_lengths) + for x in audio_signals: + pad_len = max_length - x.signal_length + x.zero_pad(0, pad_len) + elif truncate_signals: + min_length = min(signal_lengths) + for x in audio_signals: + x.truncate_samples(min_length) + else: + raise RuntimeError( + f"Not all signals had the same length! Got {signal_lengths}. " + f"All signals must be the same length, or pad_signals/truncate_signals " + f"must be True. ") + # Concatenate along the specified dimension (default 0) + audio_data = paddle.concat( + [x.audio_data for x in audio_signals], axis=dim) + audio_paths = [x.path_to_file for x in audio_signals] + + batched_signal = cls( + audio_data, + sample_rate=audio_signals[0].sample_rate, ) + batched_signal.path_to_file = audio_paths + return batched_signal + + # I/O + def load_from_file( + self, + audio_path: typing.Union[str, Path], + offset: float, + duration: float, + device: str="cpu", ): + """Loads data from file. Used internally when AudioSignal + is instantiated with a path to a file. + + Parameters + ---------- + audio_path : typing.Union[str, Path] + Path to file + offset : float + Offset in seconds + duration : float + Duration in seconds + device : str, optional + Device to put AudioSignal on, by default "cpu" + + Returns + ------- + AudioSignal + AudioSignal loaded from file + """ + # need `ffmpeg` + data, sample_rate = librosa.load( + audio_path, + offset=offset, + duration=duration, + sr=None, + mono=False, ) + data = util.ensure_tensor(data) + if data.shape[-1] == 0: + raise RuntimeError( + f"Audio file {audio_path} with offset {offset} and duration {duration} is empty!" + ) + + if data.ndim < 2: + data = data.unsqueeze(0) + if data.ndim < 3: + data = data.unsqueeze(0) + self.audio_data = data + + self.original_signal_length = self.signal_length + + self.sample_rate = sample_rate + self.path_to_file = audio_path + return self.to(device) + + def load_from_array( + self, + audio_array: typing.Union[paddle.Tensor, np.ndarray], + sample_rate: int, + device: str="cpu", ): + """Loads data from array, reshaping it to be exactly 3 + dimensions. Used internally when AudioSignal is called + with a tensor or an array. + + Parameters + ---------- + audio_array : typing.Union[paddle.Tensor, np.ndarray] + Array/tensor of audio of samples. + sample_rate : int + Sample rate of audio + device : str, optional + Device to move audio onto, by default "cpu" + + Returns + ------- + AudioSignal + AudioSignal loaded from array + """ + audio_data = util.ensure_tensor(audio_array) + + if str(audio_data.dtype) == paddle.float64: + audio_data = audio_data.astype("float32") + + if audio_data.ndim < 2: + audio_data = audio_data.unsqueeze(0) + if audio_data.ndim < 3: + audio_data = audio_data.unsqueeze(0) + self.audio_data = audio_data + + self.original_signal_length = self.signal_length + + self.sample_rate = sample_rate + + return self + + def write(self, audio_path: typing.Union[str, Path]): + """Writes audio to a file. Only writes the audio + that is in the very first item of the batch. To write other items + in the batch, index the signal along the batch dimension + before writing. After writing, the signal's ``path_to_file`` + attribute is updated to the new path. + + Parameters + ---------- + audio_path : typing.Union[str, Path] + Path to write audio to. + + Returns + ------- + AudioSignal + Returns original AudioSignal, so you can use this in a fluent + interface. + + Examples + -------- + Creating and writing a signal to disk: + + >>> signal = AudioSignal(paddle.randn([10, 1, 44100]), 44100) + >>> signal.write("/tmp/out.wav") + + Writing a different element of the batch: + + >>> signal[5].write("/tmp/out.wav") + + Using this in a fluent interface: + + >>> signal.write("/tmp/original.wav").low_pass(4000).write("/tmp/lowpass.wav") + + """ + if self.audio_data[0].abs().max() > 1: + warnings.warn("Audio amplitude > 1 clipped when saving") + soundfile.write( + str(audio_path), self.audio_data[0].numpy().T, self.sample_rate) + + self.path_to_file = audio_path + return self + + def deepcopy(self): + """Copies the signal and all of its attributes. + + Returns + ------- + AudioSignal + Deep copy of the audio signal. + """ + return copy.deepcopy(self) + + def copy(self): + """Shallow copy of signal. + + Returns + ------- + AudioSignal + Shallow copy of the audio signal. + """ + return copy.copy(self) + + def clone(self): + """Clones all tensors contained in the AudioSignal, + and returns a copy of the signal with everything + cloned. Useful when using AudioSignal within autograd + computation graphs. + + Relevant attributes are the stft data, the audio data, + and the loudness of the file. + + Returns + ------- + AudioSignal + Clone of AudioSignal. + """ + clone = type(self)( + self.audio_data.clone(), + self.sample_rate, + stft_params=self.stft_params, ) + if self.stft_data is not None: + clone.stft_data = self.stft_data.clone() + if self._loudness is not None: + clone._loudness = self._loudness.clone() + clone.path_to_file = copy.deepcopy(self.path_to_file) + clone.metadata = copy.deepcopy(self.metadata) + return clone + + def detach(self): + """Detaches tensors contained in AudioSignal. + + Relevant attributes are the stft data, the audio data, + and the loudness of the file. + + Returns + ------- + AudioSignal + Same signal, but with all tensors detached. + """ + if self._loudness is not None: + self._loudness = self._loudness.detach() + if self.stft_data is not None: + self.stft_data = self.stft_data.detach() + + self.audio_data = self.audio_data.detach() + return self + + def hash(self): + """Writes the audio data to a temporary file, and then + hashes it using hashlib. Useful for creating a file + name based on the audio content. + + Returns + ------- + str + Hash of audio data. + + Examples + -------- + Creating a signal, and writing it to a unique file name: + + >>> signal = AudioSignal(paddle.randn([44100]), 44100) + >>> hash = signal.hash() + >>> signal.write(f"{hash}.wav") + + """ + with tempfile.NamedTemporaryFile(suffix=".wav") as f: + self.write(f.name) + h = hashlib.sha256() + b = bytearray(128 * 1024) + mv = memoryview(b) + with open(f.name, "rb", buffering=0) as f: + for n in iter(lambda: f.readinto(mv), 0): + h.update(mv[:n]) + file_hash = h.hexdigest() + return file_hash + + # Signal operations + def to_mono(self): + """Converts audio data to mono audio, by taking the mean + along the channels dimension. + + Returns + ------- + AudioSignal + AudioSignal with mean of channels. + """ + self.audio_data = self.audio_data.mean(1, keepdim=True) + return self + + def resample(self, sample_rate: int): + """Resamples the audio, using sinc interpolation. This works on both + cpu and gpu, and is much faster on gpu. + + Parameters + ---------- + sample_rate : int + Sample rate to resample to. + + Returns + ------- + AudioSignal + Resampled AudioSignal + """ + if sample_rate == self.sample_rate: + return self + self.audio_data = resample_frac(self.audio_data, self.sample_rate, + sample_rate) + self.sample_rate = sample_rate + return self + + # Tensor operations + def to(self, device: str): + """Moves all tensors contained in signal to the specified device. + + Parameters + ---------- + device : str + Device to move AudioSignal onto. Typical values are + "gpu", "cpu", or "gpu:x" to specify the nth gpu. + + Returns + ------- + AudioSignal + AudioSignal with all tensors moved to specified device. + """ + if self._loudness is not None: + self._loudness = util.move_to_device(self._loudness, device) + if self.stft_data is not None: + self.stft_data = util.move_to_device(self.stft_data, device) + if self.audio_data is not None: + self.audio_data = util.move_to_device(self.audio_data, device) + return self + + def float(self): + """Calls ``.float()`` on ``self.audio_data``. + + Returns + ------- + AudioSignal + """ + self.audio_data = self.audio_data.astype("float32") + return self + + def cpu(self): + """Moves AudioSignal to cpu. + + Returns + ------- + AudioSignal + """ + return self.to("cpu") + + def cuda(self): + """Moves AudioSignal to cuda. + + Returns + ------- + AudioSignal + """ + return self.to("gpu") + + def numpy(self): + """Detaches ``self.audio_data``, moves to cpu, and converts to numpy. + + Returns + ------- + np.ndarray + Audio data as a numpy array. + """ + return self.audio_data.detach().cpu().numpy() + + def zero_pad(self, before: int, after: int): + """Zero pads the audio_data tensor before and after. + + Parameters + ---------- + before : int + How many zeros to prepend to audio. + after : int + How many zeros to append to audio. + + Returns + ------- + AudioSignal + AudioSignal with padding applied. + """ + self.audio_data = paddle.nn.functional.pad( + self.audio_data, (before, after), data_format="NCL") + return self + + def zero_pad_to(self, length: int, mode: str="after"): + """Pad with zeros to a specified length, either before or after + the audio data. + + Parameters + ---------- + length : int + Length to pad to + mode : str, optional + Whether to prepend or append zeros to signal, by default "after" + + Returns + ------- + AudioSignal + AudioSignal with padding applied. + """ + if mode == "before": + self.zero_pad(max(length - self.signal_length, 0), 0) + elif mode == "after": + self.zero_pad(0, max(length - self.signal_length, 0)) + return self + + def trim(self, before: int, after: int): + """Trims the audio_data tensor before and after. + + Parameters + ---------- + before : int + How many samples to trim from beginning. + after : int + How many samples to trim from end. + + Returns + ------- + AudioSignal + AudioSignal with trimming applied. + """ + if after == 0: + self.audio_data = self.audio_data[..., before:] + else: + self.audio_data = self.audio_data[..., before:-after] + return self + + def truncate_samples(self, length_in_samples: int): + """Truncate signal to specified length. + + Parameters + ---------- + length_in_samples : int + Truncate to this many samples. + + Returns + ------- + AudioSignal + AudioSignal with truncation applied. + """ + self.audio_data = self.audio_data[..., :length_in_samples] + return self + + @property + def device(self): + """Get device that AudioSignal is on. + + Returns + ------- + paddle.device + Device that AudioSignal is on. + """ + if self.audio_data is not None: + device = self.audio_data.place + elif self.stft_data is not None: + device = self.stft_data.place + return device + + # Properties + @property + def audio_data(self): + """Returns the audio data tensor in the object. + + Audio data is always of the shape + (batch_size, num_channels, num_samples). If value has less + than 3 dims (e.g. is (num_channels, num_samples)), then it will + be reshaped to (1, num_channels, num_samples) - a batch size of 1. + + Parameters + ---------- + data : typing.Union[paddle.Tensor, np.ndarray] + Audio data to set. + + Returns + ------- + paddle.Tensor + Audio samples. + """ + return self._audio_data + + @audio_data.setter + def audio_data(self, data: typing.Union[paddle.Tensor, np.ndarray]): + if data is not None: + assert paddle.is_tensor(data), "audio_data should be paddle.Tensor" + assert data.ndim == 3, "audio_data should be 3-dim (B, C, T)" + self._audio_data = data + # Old loudness value not guaranteed to be right, reset it. + self._loudness = None + return + + # alias for audio_data + samples = audio_data + + @property + def stft_data(self): + """Returns the STFT data inside the signal. Shape is + (batch, channels, frequencies, time). + + Returns + ------- + paddle.Tensor + Complex spectrogram data. + """ + return self._stft_data + + @stft_data.setter + def stft_data(self, data: typing.Union[paddle.Tensor, np.ndarray]): + if data is not None: + assert paddle.is_tensor(data) and paddle.is_complex(data) + if self.stft_data is not None and self.stft_data.shape != data.shape: + warnings.warn("stft_data changed shape") + self._stft_data = data + return + + @property + def batch_size(self): + """Batch size of audio signal. + + Returns + ------- + int + Batch size of signal. + """ + return self.audio_data.shape[0] + + @property + def signal_length(self): + """Length of audio signal. + + Returns + ------- + int + Length of signal in samples. + """ + return self.audio_data.shape[-1] + + # alias for signal_length + length = signal_length + + @property + def shape(self): + """Shape of audio data. + + Returns + ------- + tuple + Shape of audio data. + """ + return self.audio_data.shape + + @property + def signal_duration(self): + """Length of audio signal in seconds. + + Returns + ------- + float + Length of signal in seconds. + """ + return self.signal_length / self.sample_rate + + # alias for signal_duration + duration = signal_duration + + @property + def num_channels(self): + """Number of audio channels. + + Returns + ------- + int + Number of audio channels. + """ + return self.audio_data.shape[1] + + # STFT + @staticmethod + @functools.lru_cache(None) + def get_window(window_type: str, window_length: int, device: str=None): + """Wrapper around scipy.signal.get_window so one can also get the + popular sqrt-hann window. This function caches for efficiency + using functools.lru\_cache. + + Parameters + ---------- + window_type : str + Type of window to get + window_length : int + Length of the window + device : str + Device to put window onto. + + Returns + ------- + paddle.Tensor + Window returned by scipy.signal.get_window, as a tensor. + """ + from scipy import signal + + if window_type == "average": + window = np.ones(window_length) / window_length + elif window_type == "sqrt_hann": + window = np.sqrt(signal.get_window("hann", window_length)) + else: + window = signal.get_window(window_type, window_length) + window = paddle.to_tensor(window).astype("float32") + return window + + @property + def stft_params(self): + """Returns STFTParams object, which can be re-used to other + AudioSignals. + + This property can be set as well. If values are not defined in STFTParams, + they are inferred automatically from the signal properties. The default is to use + 32ms windows, with 8ms hop length, and the square root of the hann window. + + Returns + ------- + STFTParams + STFT parameters for the AudioSignal. + + Examples + -------- + >>> stft_params = STFTParams(128, 32) + >>> signal1 = AudioSignal(paddle.randn([44100]), 44100, stft_params=stft_params) + >>> signal2 = AudioSignal(paddle.randn([44100]), 44100, stft_params=signal1.stft_params) + >>> signal1.stft_params = STFTParams() # Defaults + """ + return self._stft_params + + @stft_params.setter + def stft_params(self, value: STFTParams): + # + default_win_len = int(2**(np.ceil(np.log2(0.032 * self.sample_rate)))) + default_hop_len = default_win_len // 4 + default_win_type = "hann" + default_match_stride = False + default_padding_type = "reflect" + + default_stft_params = STFTParams( + window_length=default_win_len, + hop_length=default_hop_len, + window_type=default_win_type, + match_stride=default_match_stride, + padding_type=default_padding_type, )._asdict() + + value = value._asdict() if value else default_stft_params + + for key in default_stft_params: + if value[key] is None: + value[key] = default_stft_params[key] + + self._stft_params = STFTParams(**value) + self.stft_data = None + + def compute_stft_padding(self, + window_length: int, + hop_length: int, + match_stride: bool): + """Compute how the STFT should be padded, based on match\_stride. + + Parameters + ---------- + window_length : int + Window length of STFT. + hop_length : int + Hop length of STFT. + match_stride : bool + Whether or not to match stride, making the STFT have the same alignment as + convolutional layers. + + Returns + ------- + tuple + Amount to pad on either side of audio. + """ + length = self.signal_length + + if match_stride: + assert hop_length == window_length // 4, "For match_stride, hop must equal n_fft // 4" + right_pad = math.ceil(length / hop_length) * hop_length - length + pad = (window_length - hop_length) // 2 + else: + right_pad = 0 + pad = 0 + + return right_pad, pad + + def stft( + self, + window_length: int=None, + hop_length: int=None, + window_type: str=None, + match_stride: bool=None, + padding_type: str=None, ): + """Computes the short-time Fourier transform of the audio data, + with specified STFT parameters. + + Parameters + ---------- + window_length : int, optional + Window length of STFT, by default ``0.032 * self.sample_rate``. + hop_length : int, optional + Hop length of STFT, by default ``window_length // 4``. + window_type : str, optional + Type of window to use, by default ``sqrt\_hann``. + match_stride : bool, optional + Whether to match the stride of convolutional layers, by default False + padding_type : str, optional + Type of padding to use, by default 'reflect' + + Returns + ------- + paddle.Tensor + STFT of audio data. + + Examples + -------- + Compute the STFT of an AudioSignal: + + >>> signal = AudioSignal(paddle.randn([44100]), 44100) + >>> signal.stft() + + Vary the window and hop length: + + >>> stft_params_list = [STFTParams(128, 32), STFTParams(512, 128)] + >>> for stft_params in stft_params_list: + >>> signal.stft_params = stft_params + >>> signal.stft() + + """ + window_length = self.stft_params.window_length if window_length is None else int( + window_length) + hop_length = self.stft_params.hop_length if hop_length is None else int( + hop_length) + window_type = self.stft_params.window_type if window_type is None else window_type + match_stride = self.stft_params.match_stride if match_stride is None else match_stride + padding_type = self.stft_params.padding_type if padding_type is None else padding_type + + window = self.get_window(window_type, window_length) + + audio_data = self.audio_data + right_pad, pad = self.compute_stft_padding(window_length, hop_length, + match_stride) + audio_data = paddle.nn.functional.pad( + x=audio_data, + pad=[pad, pad + right_pad], + mode="reflect", + data_format="NCL", ) + stft_data = paddle.signal.stft( + audio_data.reshape([-1, audio_data.shape[-1]]).astype("float32"), + n_fft=window_length, + hop_length=hop_length, + window=window, + # return_complex=True, + center=True, ) + _, nf, nt = stft_data.shape + stft_data = stft_data.reshape( + [self.batch_size, self.num_channels, nf, nt]) + + if match_stride: + # Drop first two and last two frames, which are added + # because of padding. Now num_frames * hop_length = num_samples. + stft_data = stft_data[..., 2:-2] + self.stft_data = stft_data + + return stft_data + + def istft( + self, + window_length: int=None, + hop_length: int=None, + window_type: str=None, + match_stride: bool=None, + length: int=None, ): + """Computes inverse STFT and sets it to audio\_data. + + Parameters + ---------- + window_length : int, optional + Window length of STFT, by default ``0.032 * self.sample_rate``. + hop_length : int, optional + Hop length of STFT, by default ``window_length // 4``. + window_type : str, optional + Type of window to use, by default ``sqrt\_hann``. + match_stride : bool, optional + Whether to match the stride of convolutional layers, by default False + length : int, optional + Original length of signal, by default None + + Returns + ------- + AudioSignal + AudioSignal with istft applied. + + Raises + ------ + RuntimeError + Raises an error if stft was not called prior to istft on the signal, + or if stft_data is not set. + """ + if self.stft_data is None: + raise RuntimeError("Cannot do inverse STFT without self.stft_data!") + + window_length = self.stft_params.window_length if window_length is None else int( + window_length) + hop_length = self.stft_params.hop_length if hop_length is None else int( + hop_length) + window_type = self.stft_params.window_type if window_type is None else window_type + match_stride = self.stft_params.match_stride if match_stride is None else match_stride + + window = self.get_window(window_type, window_length, + self.stft_data.place) + + nb, nch, nf, nt = self.stft_data.shape + stft_data = self.stft_data.reshape([nb * nch, nf, nt]) + right_pad, pad = self.compute_stft_padding(window_length, hop_length, + match_stride) + + if length is None: + length = self.original_signal_length + length = length + 2 * pad + right_pad + + if match_stride: + # Zero-pad the STFT on either side, putting back the frames that were + # dropped in stft(). + stft_data = paddle.nn.functional.pad( + stft_data, pad=(2, 2), data_format="NCL") + + audio_data = paddle.signal.istft( + stft_data, + n_fft=window_length, + hop_length=hop_length, + window=window, + length=length, + center=True, ) + audio_data = audio_data.reshape([nb, nch, -1]) + if match_stride: + audio_data = audio_data[..., pad:-(pad + right_pad)] + self.audio_data = audio_data + + return self + + @staticmethod + @functools.lru_cache(None) + def get_mel_filters(sr: int, + n_fft: int, + n_mels: int, + fmin: float=0.0, + fmax: float=None): + """Create a Filterbank matrix to combine FFT bins into Mel-frequency bins. + + Parameters + ---------- + sr : int + Sample rate of audio + n_fft : int + Number of FFT bins + n_mels : int + Number of mels + fmin : float, optional + Lowest frequency, in Hz, by default 0.0 + fmax : float, optional + Highest frequency, by default None + + Returns + ------- + np.ndarray [shape=(n_mels, 1 + n_fft/2)] + Mel transform matrix + """ + from librosa.filters import mel as librosa_mel_fn + + return librosa_mel_fn( + sr=sr, + n_fft=n_fft, + n_mels=n_mels, + fmin=fmin, + fmax=fmax, ) + + def mel_spectrogram( + self, + n_mels: int=80, + mel_fmin: float=0.0, + mel_fmax: float=None, + **kwargs, ): + """Computes a Mel spectrogram. + + Parameters + ---------- + n_mels : int, optional + Number of mels, by default 80 + mel_fmin : float, optional + Lowest frequency, in Hz, by default 0.0 + mel_fmax : float, optional + Highest frequency, by default None + kwargs : dict, optional + Keyword arguments to self.stft(). + + Returns + ------- + paddle.Tensor [shape=(batch, channels, mels, time)] + Mel spectrogram. + """ + # from paddle.audio.compliance.librosa import melspectrogram + # # from ..compliance.librosa import melspectrogram + # return melspectrogram( + # x=self.audio_data, + # sr=self.sample_rate, + # window_size: int=512, + # hop_length: int=320, + # n_mels: int=64, + # fmin: float=50.0, + # fmax: Optional[float]=None, + # window: str='hann', + # center: bool=True, + # pad_mode: str='reflect', + # power: float=2.0, + # to_db: bool=True, + # ref: float=1.0, + # amin: float=1e-10, + # top_db: Optional[float]=None + # ) + + stft = self.stft(**kwargs) + magnitude = paddle.abs(stft) + + nf = magnitude.shape[2] + mel_basis = self.get_mel_filters( + sr=self.sample_rate, + n_fft=2 * (nf - 1), + n_mels=n_mels, + fmin=mel_fmin, + fmax=mel_fmax, ) + mel_basis = paddle.to_tensor(mel_basis) + + mel_spectrogram = magnitude.transpose([0, 1, 3, 2]) @ mel_basis.T + mel_spectrogram = mel_spectrogram.transpose([0, 1, 3, 2]) + return mel_spectrogram + + @staticmethod + @functools.lru_cache(None) + def get_dct(n_mfcc: int, n_mels: int, norm: str="ortho", device: str=None): + """Create a discrete cosine transform (DCT) transformation matrix with shape (``n_mels``, ``n_mfcc``), + it can be normalized depending on norm. For more information about dct: + http://en.wikipedia.org/wiki/Discrete_cosine_transform#DCT-II + + Parameters + ---------- + n_mfcc : int + Number of mfccs + n_mels : int + Number of mels + norm : str + Use "ortho" to get a orthogonal matrix or None, by default "ortho" + device : str, optional + Device to load the transformation matrix on, by default None + + Returns + ------- + paddle.Tensor [shape=(n_mels, n_mfcc)] T + The dct transformation matrix. + """ + + return create_dct(n_mfcc, n_mels, norm) + + def mfcc( + self, + n_mfcc: int=40, + n_mels: int=80, + log_offset: float=1e-6, + **kwargs, ): + """Computes mel-frequency cepstral coefficients (MFCCs). + + Parameters + ---------- + n_mfcc : int, optional + Number of mels, by default 40 + n_mels : int, optional + Number of mels, by default 80 + log_offset: float, optional + Small value to prevent numerical issues when trying to compute log(0), by default 1e-6 + kwargs : dict, optional + Keyword arguments to self.mel_spectrogram(), note that some of them will be used for self.stft() + + Returns + ------- + paddle.Tensor [shape=(batch, channels, mfccs, time)] + MFCCs. + """ + + # from paddle.audio.compliance.librosa import mfcc + # return mfcc(self.audio_data, self.sample_rate, n_mfcc=n_mfcc, n_mels=n_mels) + + mel_spectrogram = self.mel_spectrogram(n_mels, **kwargs) + mel_spectrogram = paddle.log(mel_spectrogram + log_offset) + dct_mat = self.get_dct(n_mfcc, n_mels, "ortho", self.device) + + mfcc = mel_spectrogram.transpose([0, 1, 3, 2]) @ dct_mat + mfcc = mfcc.transpose([0, 1, 3, 2]) + return mfcc + + @property + def magnitude(self): + """Computes and returns the absolute value of the STFT, which + is the magnitude. This value can also be set to some tensor. + When set, ``self.stft_data`` is manipulated so that its magnitude + matches what this is set to, and modulated by the phase. + + Returns + ------- + paddle.Tensor + Magnitude of STFT. + + Examples + -------- + >>> signal = AudioSignal(paddle.randn([44100]), 44100) + >>> magnitude = signal.magnitude # Computes stft if not computed + >>> magnitude[magnitude < magnitude.mean()] = 0 + >>> signal.magnitude = magnitude + >>> signal.istft() + """ + if self.stft_data is None: + self.stft() + return paddle.abs(self.stft_data) + + @magnitude.setter + def magnitude(self, value): + self.stft_data = value * util.exp_compat(1j * self.phase) + return + + def log_magnitude(self, + ref_value: float=1.0, + amin: float=1e-5, + top_db: float=80.0): + """Computes the log-magnitude of the spectrogram. + + Parameters + ---------- + ref_value : float, optional + The magnitude is scaled relative to ``ref``: ``20 * log10(S / ref)``. + Zeros in the output correspond to positions where ``S == ref``, + by default 1.0 + amin : float, optional + Minimum threshold for ``S`` and ``ref``, by default 1e-5 + top_db : float, optional + Threshold the output at ``top_db`` below the peak: + ``max(10 * log10(S/ref)) - top_db``, by default -80.0 + + Returns + ------- + paddle.Tensor + Log-magnitude spectrogram + """ + magnitude = self.magnitude + + amin = amin**2 + log_spec = 10.0 * paddle.log10(magnitude.pow(2).clip(min=amin)) + if paddle.is_tensor(ref_value): + ref_value = ref_value.item() + log_spec -= 10.0 * np.log10(np.maximum(amin, ref_value)) + + if top_db is not None: + log_spec = paddle.maximum(log_spec, log_spec.max() - top_db) + return log_spec + + @property + def phase(self): + """Computes and returns the phase of the STFT. + This value can also be set to some tensor. + When set, ``self.stft_data`` is manipulated so that its phase + matches what this is set to, we original magnitudeith th. + + Returns + ------- + paddle.Tensor + Phase of STFT. + + Examples + -------- + >>> signal = AudioSignal(paddle.randn([44100]), 44100) + >>> phase = signal.phase # Computes stft if not computed + >>> phase[phase < phase.mean()] = 0 + >>> signal.phase = phase + >>> signal.istft() + """ + if self.stft_data is None: + self.stft() + return paddle.angle(self.stft_data) + + @phase.setter + def phase(self, value): + # + self.stft_data = self.magnitude * util.exp_compat(1j * value) + return + + # Operator overloading + def __add__(self, other): + new_signal = self.clone() + new_signal.audio_data += util._get_value(other) + return new_signal + + def __iadd__(self, other): + self.audio_data += util._get_value(other) + return self + + def __radd__(self, other): + return self + other + + def __sub__(self, other): + new_signal = self.clone() + new_signal.audio_data -= util._get_value(other) + return new_signal + + def __isub__(self, other): + self.audio_data -= util._get_value(other) + return self + + def __mul__(self, other): + new_signal = self.clone() + new_signal.audio_data *= util._get_value(other) + return new_signal + + def __imul__(self, other): + self.audio_data *= util._get_value(other) + return self + + def __rmul__(self, other): + return self * other + + # Representation + def _info(self): + # + dur = f"{self.signal_duration:0.3f}" if self.signal_duration else "[unknown]" + info = { + "duration": + f"{dur} seconds", + "batch_size": + self.batch_size, + "path": + self.path_to_file if self.path_to_file else "path unknown", + "sample_rate": + self.sample_rate, + "num_channels": (self.num_channels + if self.num_channels else "[unknown]"), + "audio_data.shape": + self.audio_data.shape, + "stft_params": + self.stft_params, + "device": + self.device, + } + + return info + + def markdown(self): + """Produces a markdown representation of AudioSignal, in a markdown table. + + Returns + ------- + str + Markdown representation of AudioSignal. + + Examples + -------- + >>> signal = AudioSignal(paddle.randn([44100]), 44100) + >>> print(signal.markdown()) + | Key | Value + |---|--- + | duration | 1.000 seconds | + | batch_size | 1 | + | path | path unknown | + | sample_rate | 44100 | + | num_channels | 1 | + | audio_data.shape | paddle.Size([1, 1, 44100]) | + | stft_params | STFTParams(window_length=2048, hop_length=512, window_type='sqrt_hann', match_stride=False) | + | device | cpu | + """ + info = self._info() + + FORMAT = "| Key | Value \n" "|---|--- \n" + for k, v in info.items(): + row = f"| {k} | {v} |\n" + FORMAT += row + return FORMAT + + def __str__(self): + info = self._info() + + desc = "" + for k, v in info.items(): + desc += f"{k}: {v}\n" + return desc + + def __rich__(self): + from rich.table import Table + + info = self._info() + + table = Table(title=f"{self.__class__.__name__}") + table.add_column("Key", style="green") + table.add_column("Value", style="cyan") + + for k, v in info.items(): + table.add_row(k, str(v)) + return table + + # Comparison + def __eq__(self, other): + for k, v in list(self.__dict__.items()): + if paddle.is_tensor(v): + + if paddle.is_complex(v): + if not np.allclose( + v.cpu().numpy(), + other.__dict__[k].cpu().numpy(), + atol=1e-6): + max_error = (v - other.__dict__[k]).abs().max() + print(f"Max abs error for {k}: {max_error}") + return False + else: + if not paddle.allclose(v, other.__dict__[k], atol=1e-6): + max_error = (v - other.__dict__[k]).abs().max() + print(f"Max abs error for {k}: {max_error}") + return False + return True + + # Indexing + def __getitem__(self, key): + if paddle.is_tensor(key) and key.ndim == 0 and key.item() is True: + assert self.batch_size == 1 + audio_data = self.audio_data + _loudness = self._loudness + stft_data = self.stft_data + + elif isinstance(key, (bool, int, list, slice, tuple)) or ( + paddle.is_tensor(key) and key.ndim <= 1): + # Indexing only on the batch dimension. + # Then let's copy over relevant stuff. + # Future work: make this work for time-indexing + # as well, using the hop length. + audio_data = self.audio_data[key] + _loudness = self._loudness[ + key] if self._loudness is not None else None + # stft_data = self.stft_data[ + # key] if self.stft_data is not None else None + stft_data = util.bool_index_compat( + self.stft_data, key) if self.stft_data is not None else None + + sources = None + + copy = type(self)( + audio_data, self.sample_rate, stft_params=self.stft_params) + copy._loudness = _loudness + copy._stft_data = stft_data + copy.sources = sources + + return copy + + def __setitem__(self, key, value): + if not isinstance(value, type(self)): + self.audio_data[key] = value + return + + if paddle.is_tensor(key) and key.ndim == 0 and key.item() is True: + assert self.batch_size == 1 + self.audio_data = value.audio_data + self._loudness = value._loudness + self.stft_data = value.stft_data + return + + elif isinstance(key, (bool, int, list, slice, tuple)) or ( + paddle.is_tensor(key) and key.ndim <= 1): + if self.audio_data is not None and value.audio_data is not None: + self.audio_data[key] = value.audio_data + if self._loudness is not None and value._loudness is not None: + if paddle.is_tensor(key) and key.dtype == paddle.bool: + # FOR Paddle BOOL Index + _key_no_bool = paddle.nonzero(key).flatten() + self._loudness[_key_no_bool] = value._loudness + else: + self._loudness[key] = value._loudness + if self.stft_data is not None and value.stft_data is not None: + # self.stft_data[key] = value.stft_data + self.stft_data = util.bool_setitem_compat(self.stft_data, key, + value.stft_data) + return + + def __ne__(self, other): + return not self == other diff --git a/audio/audiotools/core/display.py b/audio/audiotools/core/display.py new file mode 100644 index 00000000000..66f0c641e0f --- /dev/null +++ b/audio/audiotools/core/display.py @@ -0,0 +1,195 @@ +# MIT License, Copyright (c) 2023-Present, Descript. +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Modified from audiotools(https://github.com/descriptinc/audiotools/blob/master/audiotools/core/display.py) +import inspect +import typing +from functools import wraps + +from . import util + + +def format_figure(func): + """Decorator for formatting figures produced by the code below. + See :py:func:`audiotools.core.util.format_figure` for more. + + Parameters + ---------- + func : Callable + Plotting function that is decorated by this function. + + """ + + @wraps(func) + def wrapper(*args, **kwargs): + f_keys = inspect.signature(util.format_figure).parameters.keys() + f_kwargs = {} + for k, v in list(kwargs.items()): + if k in f_keys: + kwargs.pop(k) + f_kwargs[k] = v + func(*args, **kwargs) + util.format_figure(**f_kwargs) + + return wrapper + + +class DisplayMixin: + @format_figure + def specshow( + self, + preemphasis: bool=False, + x_axis: str="time", + y_axis: str="linear", + n_mels: int=128, + **kwargs, ): + """Displays a spectrogram, using ``librosa.display.specshow``. + + Parameters + ---------- + preemphasis : bool, optional + Whether or not to apply preemphasis, which makes high + frequency detail easier to see, by default False + x_axis : str, optional + How to label the x axis, by default "time" + y_axis : str, optional + How to label the y axis, by default "linear" + n_mels : int, optional + If displaying a mel spectrogram with ``y_axis = "mel"``, + this controls the number of mels, by default 128. + kwargs : dict, optional + Keyword arguments to :py:func:`audiotools.core.util.format_figure`. + """ + import librosa + import librosa.display + + # Always re-compute the STFT data before showing it, in case + # it changed. + signal = self.clone() + signal.stft_data = None + + if preemphasis: + signal.preemphasis() + + ref = signal.magnitude.max() + log_mag = signal.log_magnitude(ref_value=ref) + + if y_axis == "mel": + log_mag = 20 * signal.mel_spectrogram(n_mels).clip(1e-5).log10() + log_mag -= log_mag.max() + + librosa.display.specshow( + log_mag.numpy()[0].mean(axis=0), + x_axis=x_axis, + y_axis=y_axis, + sr=signal.sample_rate, + **kwargs, ) + + @format_figure + def waveplot(self, x_axis: str="time", **kwargs): + """Displays a waveform plot, using ``librosa.display.waveshow``. + + Parameters + ---------- + x_axis : str, optional + How to label the x axis, by default "time" + kwargs : dict, optional + Keyword arguments to :py:func:`audiotools.core.util.format_figure`. + """ + import librosa + import librosa.display + + audio_data = self.audio_data[0].mean(axis=0) + audio_data = audio_data.cpu().numpy() + + plot_fn = "waveshow" if hasattr(librosa.display, + "waveshow") else "waveplot" + wave_plot_fn = getattr(librosa.display, plot_fn) + wave_plot_fn(audio_data, x_axis=x_axis, sr=self.sample_rate, **kwargs) + + @format_figure + def wavespec(self, x_axis: str="time", **kwargs): + """Displays a waveform plot, using ``librosa.display.waveshow``. + + Parameters + ---------- + x_axis : str, optional + How to label the x axis, by default "time" + kwargs : dict, optional + Keyword arguments to :py:func:`audiotools.core.display.DisplayMixin.specshow`. + """ + import matplotlib.pyplot as plt + from matplotlib.gridspec import GridSpec + + gs = GridSpec(6, 1) + plt.subplot(gs[0, :]) + self.waveplot(x_axis=x_axis) + plt.subplot(gs[1:, :]) + self.specshow(x_axis=x_axis, **kwargs) + + def write_audio_to_tb( + self, + tag: str, + writer, + step: int=None, + plot_fn: typing.Union[typing.Callable, str]="specshow", + **kwargs, ): + """Writes a signal and its spectrogram to Tensorboard. Will show up + under the Audio and Images tab in Tensorboard. + + Parameters + ---------- + tag : str + Tag to write signal to (e.g. ``clean/sample_0.wav``). The image will be + written to the corresponding ``.png`` file (e.g. ``clean/sample_0.png``). + writer : SummaryWriter + A SummaryWriter object from PyTorch library. + step : int, optional + The step to write the signal to, by default None + plot_fn : typing.Union[typing.Callable, str], optional + How to create the image. Set to ``None`` to avoid plotting, by default "specshow" + kwargs : dict, optional + Keyword arguments to :py:func:`audiotools.core.display.DisplayMixin.specshow` or + whatever ``plot_fn`` is set to. + """ + import matplotlib.pyplot as plt + + audio_data = self.audio_data[0, 0].detach().cpu().numpy() + sample_rate = self.sample_rate + writer.add_audio(tag, audio_data, step, sample_rate) + + if plot_fn is not None: + if isinstance(plot_fn, str): + plot_fn = getattr(self, plot_fn) + fig = plt.figure() + plt.clf() + plot_fn(**kwargs) + writer.add_figure(tag.replace("wav", "png"), fig, step) + + def save_image( + self, + image_path: str, + plot_fn: typing.Union[typing.Callable, str]="specshow", + **kwargs, ): + """Save AudioSignal spectrogram (or whatever ``plot_fn`` is set to) to + a specified file. + + Parameters + ---------- + image_path : str + Where to save the file to. + plot_fn : typing.Union[typing.Callable, str], optional + How to create the image. Set to ``None`` to avoid plotting, by default "specshow" + kwargs : dict, optional + Keyword arguments to :py:func:`audiotools.core.display.DisplayMixin.specshow` or + whatever ``plot_fn`` is set to. + """ + import matplotlib.pyplot as plt + + if isinstance(plot_fn, str): + plot_fn = getattr(self, plot_fn) + + plt.clf() + plot_fn(**kwargs) + plt.savefig(image_path, bbox_inches="tight", pad_inches=0) + plt.close() diff --git a/audio/audiotools/core/dsp.py b/audio/audiotools/core/dsp.py new file mode 100644 index 00000000000..d62e1e6dc39 --- /dev/null +++ b/audio/audiotools/core/dsp.py @@ -0,0 +1,467 @@ +# MIT License, Copyright (c) 2023-Present, Descript. +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Modified from audiotools(https://github.com/descriptinc/audiotools/blob/master/audiotools/core/dsp.py) +import typing + +import numpy as np +import paddle + +from . import _julius +from . import util + + +def _unfold(x, kernel_sizes, strides): + # https://github.com/PaddlePaddle/Paddle/pull/70102 + + if 1 == kernel_sizes[0]: + x_zeros = paddle.zeros_like(x) + x = paddle.concat([x, x_zeros], axis=2) + + kernel_sizes = [2, kernel_sizes[1]] + strides = list(strides) + + unfolded = paddle.nn.functional.unfold( + x, + kernel_sizes=kernel_sizes, + strides=strides, ) + if 2 == kernel_sizes[0]: + unfolded = unfolded[:, :kernel_sizes[1]] + return unfolded + + +def _fold(x, output_sizes, kernel_sizes, strides): + # https://github.com/PaddlePaddle/Paddle/pull/70102 + + if 1 == output_sizes[0] and 1 == kernel_sizes[0]: + x_zeros = paddle.zeros_like(x) + x = paddle.concat([x, x_zeros], axis=1) + + output_sizes = (2, output_sizes[1]) + kernel_sizes = (2, kernel_sizes[1]) + + fold = paddle.nn.functional.fold( + x, + output_sizes=output_sizes, + kernel_sizes=kernel_sizes, + strides=strides, ) + if 2 == kernel_sizes[0]: + fold = fold[:, :, :1] + return fold + + +class DSPMixin: + _original_batch_size = None + _original_num_channels = None + _padded_signal_length = None + + def _preprocess_signal_for_windowing(self, window_duration, hop_duration): + self._original_batch_size = self.batch_size + self._original_num_channels = self.num_channels + + window_length = int(window_duration * self.sample_rate) + hop_length = int(hop_duration * self.sample_rate) + + if window_length % hop_length != 0: + factor = window_length // hop_length + window_length = factor * hop_length + + self.zero_pad(hop_length, hop_length) + self._padded_signal_length = self.signal_length + + return window_length, hop_length + + def windows(self, + window_duration: float, + hop_duration: float, + preprocess: bool=True): + """Generator which yields windows of specified duration from signal with a specified + hop length. + + Parameters + ---------- + window_duration : float + Duration of every window in seconds. + hop_duration : float + Hop between windows in seconds. + preprocess : bool, optional + Whether to preprocess the signal, so that the first sample is in + the middle of the first window, by default True + + Yields + ------ + AudioSignal + Each window is returned as an AudioSignal. + """ + if preprocess: + window_length, hop_length = self._preprocess_signal_for_windowing( + window_duration, hop_duration) + + self.audio_data = self.audio_data.reshape([-1, 1, self.signal_length]) + + for b in range(self.batch_size): + i = 0 + start_idx = i * hop_length + while True: + start_idx = i * hop_length + i += 1 + end_idx = start_idx + window_length + if end_idx > self.signal_length: + break + yield self[b, ..., start_idx:end_idx] + + def collect_windows(self, + window_duration: float, + hop_duration: float, + preprocess: bool=True): + """Reshapes signal into windows of specified duration from signal with a specified + hop length. Window are placed along the batch dimension. Use with + :py:func:`audiotools.core.dsp.DSPMixin.overlap_and_add` to reconstruct the + original signal. + + Parameters + ---------- + window_duration : float + Duration of every window in seconds. + hop_duration : float + Hop between windows in seconds. + preprocess : bool, optional + Whether to preprocess the signal, so that the first sample is in + the middle of the first window, by default True + + Returns + ------- + AudioSignal + AudioSignal unfolded with shape ``(nb * nch * num_windows, 1, window_length)`` + """ + if preprocess: + window_length, hop_length = self._preprocess_signal_for_windowing( + window_duration, hop_duration) + + # self.audio_data: (nb, nch, nt). + # unfolded = paddle.nn.functional.unfold( + # self.audio_data.reshape([-1, 1, 1, self.signal_length]), + # kernel_sizes=(1, window_length), + # strides=(1, hop_length), + # ) + unfolded = _unfold( + self.audio_data.reshape([-1, 1, 1, self.signal_length]), + kernel_sizes=(1, window_length), + strides=(1, hop_length), ) + # unfolded: (nb * nch, window_length, num_windows). + # -> (nb * nch * num_windows, 1, window_length) + unfolded = unfolded.transpose([0, 2, 1]).reshape([-1, 1, window_length]) + self.audio_data = unfolded + return self + + def overlap_and_add(self, hop_duration: float): + """Function which takes a list of windows and overlap adds them into a + signal the same length as ``audio_signal``. + + Parameters + ---------- + hop_duration : float + How much to shift for each window + (overlap is window_duration - hop_duration) in seconds. + + Returns + ------- + AudioSignal + overlap-and-added signal. + """ + hop_length = int(hop_duration * self.sample_rate) + window_length = self.signal_length + + nb, nch = self._original_batch_size, self._original_num_channels + + unfolded = self.audio_data.reshape( + [nb * nch, -1, window_length]).transpose([0, 2, 1]) + # folded = paddle.nn.functional.fold( + # unfolded, + # output_sizes=(1, self._padded_signal_length), + # kernel_sizes=(1, window_length), + # strides=(1, hop_length), + # ) + folded = _fold( + unfolded, + output_sizes=(1, self._padded_signal_length), + kernel_sizes=(1, window_length), + strides=(1, hop_length), ) + + norm = paddle.ones_like(unfolded) + # norm = paddle.nn.functional.fold( + # norm, + # output_sizes=(1, self._padded_signal_length), + # kernel_sizes=(1, window_length), + # strides=(1, hop_length), + # ) + norm = _fold( + norm, + output_sizes=(1, self._padded_signal_length), + kernel_sizes=(1, window_length), + strides=(1, hop_length), ) + + folded = folded / norm + + folded = folded.reshape([nb, nch, -1]) + self.audio_data = folded + self.trim(hop_length, hop_length) + return self + + def low_pass(self, + cutoffs: typing.Union[paddle.Tensor, np.ndarray, float], + zeros: int=51): + """Low-passes the signal in-place. Each item in the batch + can have a different low-pass cutoff, if the input + to this signal is an array or tensor. If a float, all + items are given the same low-pass filter. + + Parameters + ---------- + cutoffs : typing.Union[paddle.Tensor, np.ndarray, float] + Cutoff in Hz of low-pass filter. + zeros : int, optional + Number of taps to use in low-pass filter, by default 51 + + Returns + ------- + AudioSignal + Low-passed AudioSignal. + """ + cutoffs = util.ensure_tensor(cutoffs, 2, self.batch_size) + cutoffs = cutoffs / self.sample_rate + filtered = paddle.empty_like(self.audio_data) + + for i, cutoff in enumerate(cutoffs): + lp_filter = _julius.LowPassFilter(cutoff.cpu(), zeros=zeros) + filtered[i] = lp_filter(self.audio_data[i]) + + self.audio_data = filtered + self.stft_data = None + return self + + def high_pass(self, + cutoffs: typing.Union[paddle.Tensor, np.ndarray, float], + zeros: int=51): + """High-passes the signal in-place. Each item in the batch + can have a different high-pass cutoff, if the input + to this signal is an array or tensor. If a float, all + items are given the same high-pass filter. + + Parameters + ---------- + cutoffs : typing.Union[paddle.Tensor, np.ndarray, float] + Cutoff in Hz of high-pass filter. + zeros : int, optional + Number of taps to use in high-pass filter, by default 51 + + Returns + ------- + AudioSignal + High-passed AudioSignal. + """ + cutoffs = util.ensure_tensor(cutoffs, 2, self.batch_size) + cutoffs = cutoffs / self.sample_rate + filtered = paddle.empty_like(self.audio_data) + + for i, cutoff in enumerate(cutoffs): + hp_filter = _julius.HighPassFilter(cutoff.cpu(), zeros=zeros) + filtered[i] = hp_filter(self.audio_data[i]) + + self.audio_data = filtered + self.stft_data = None + return self + + def mask_frequencies( + self, + fmin_hz: typing.Union[paddle.Tensor, np.ndarray, float], + fmax_hz: typing.Union[paddle.Tensor, np.ndarray, float], + val: float=0.0, ): + """Masks frequencies between ``fmin_hz`` and ``fmax_hz``, and fills them + with the value specified by ``val``. Useful for implementing SpecAug. + The min and max can be different for every item in the batch. + + Parameters + ---------- + fmin_hz : typing.Union[paddle.Tensor, np.ndarray, float] + Lower end of band to mask out. + fmax_hz : typing.Union[paddle.Tensor, np.ndarray, float] + Upper end of band to mask out. + val : float, optional + Value to fill in, by default 0.0 + + Returns + ------- + AudioSignal + Signal with ``stft_data`` manipulated. Apply ``.istft()`` to get the + masked audio data. + """ + # SpecAug + mag, phase = self.magnitude, self.phase + fmin_hz = util.ensure_tensor( + fmin_hz, + ndim=mag.ndim, ) + fmax_hz = util.ensure_tensor( + fmax_hz, + ndim=mag.ndim, ) + assert paddle.all(fmin_hz < fmax_hz) + + # build mask + nbins = mag.shape[-2] + bins_hz = paddle.linspace( + 0, + self.sample_rate / 2, + nbins, ) + bins_hz = bins_hz[None, None, :, None].tile( + [self.batch_size, 1, 1, mag.shape[-1]]) + + fmin_hz, fmax_hz = fmin_hz.astype(bins_hz.dtype), fmax_hz.astype( + bins_hz.dtype) + mask = (fmin_hz <= bins_hz) & (bins_hz < fmax_hz) + + mag = paddle.where(mask, paddle.full_like(mag, val), mag) + phase = paddle.where(mask, paddle.full_like(phase, val), phase) + self.stft_data = mag * util.exp_compat(1j * phase) + return self + + def mask_timesteps( + self, + tmin_s: typing.Union[paddle.Tensor, np.ndarray, float], + tmax_s: typing.Union[paddle.Tensor, np.ndarray, float], + val: float=0.0, ): + """Masks timesteps between ``tmin_s`` and ``tmax_s``, and fills them + with the value specified by ``val``. Useful for implementing SpecAug. + The min and max can be different for every item in the batch. + + Parameters + ---------- + tmin_s : typing.Union[paddle.Tensor, np.ndarray, float] + Lower end of timesteps to mask out. + tmax_s : typing.Union[paddle.Tensor, np.ndarray, float] + Upper end of timesteps to mask out. + val : float, optional + Value to fill in, by default 0.0 + + Returns + ------- + AudioSignal + Signal with ``stft_data`` manipulated. Apply ``.istft()`` to get the + masked audio data. + """ + # SpecAug + mag, phase = self.magnitude, self.phase + tmin_s = util.ensure_tensor(tmin_s, ndim=mag.ndim) + tmax_s = util.ensure_tensor(tmax_s, ndim=mag.ndim) + + assert paddle.all(tmin_s < tmax_s) + + # build mask + nt = mag.shape[-1] + bins_t = paddle.linspace( + 0, + self.signal_duration, + nt, ) + bins_t = bins_t[None, None, None, :].tile( + [self.batch_size, 1, mag.shape[-2], 1]) + mask = (tmin_s <= bins_t) & (bins_t < tmax_s) + + # mag = mag.masked_fill(mask, val) + # phase = phase.masked_fill(mask, val) + mag = paddle.where(mask, paddle.full_like(mag, val), mag) + phase = paddle.where(mask, paddle.full_like(phase, val), phase) + + self.stft_data = mag * util.exp_compat(1j * phase) + return self + + def mask_low_magnitudes( + self, + db_cutoff: typing.Union[paddle.Tensor, np.ndarray, float], + val: float=0.0): + """Mask away magnitudes below a specified threshold, which + can be different for every item in the batch. + + Parameters + ---------- + db_cutoff : typing.Union[paddle.Tensor, np.ndarray, float] + Decibel value for which things below it will be masked away. + val : float, optional + Value to fill in for masked portions, by default 0.0 + + Returns + ------- + AudioSignal + Signal with ``stft_data`` manipulated. Apply ``.istft()`` to get the + masked audio data. + """ + mag = self.magnitude + log_mag = self.log_magnitude() + + db_cutoff = util.ensure_tensor(db_cutoff, ndim=mag.ndim) + db_cutoff = db_cutoff.astype(log_mag.dtype) + mask = log_mag < db_cutoff + # mag = mag.masked_fill(mask, val) + mag = paddle.where(mask, mag, val * paddle.ones_like(mag)) + + self.magnitude = mag + return self + + def shift_phase(self, + shift: typing.Union[paddle.Tensor, np.ndarray, float]): + """Shifts the phase by a constant value. + + Parameters + ---------- + shift : typing.Union[paddle.Tensor, np.ndarray, float] + What to shift the phase by. + + Returns + ------- + AudioSignal + Signal with ``stft_data`` manipulated. Apply ``.istft()`` to get the + masked audio data. + """ + shift = util.ensure_tensor(shift, ndim=self.phase.ndim) + shift = shift.astype(self.phase.dtype) + self.phase = self.phase + shift + return self + + def corrupt_phase(self, + scale: typing.Union[paddle.Tensor, np.ndarray, float]): + """Corrupts the phase randomly by some scaled value. + + Parameters + ---------- + scale : typing.Union[paddle.Tensor, np.ndarray, float] + Standard deviation of noise to add to the phase. + + Returns + ------- + AudioSignal + Signal with ``stft_data`` manipulated. Apply ``.istft()`` to get the + masked audio data. + """ + scale = util.ensure_tensor(scale, ndim=self.phase.ndim) + self.phase = self.phase + scale * paddle.randn( + shape=self.phase.shape, dtype=self.phase.dtype) + return self + + def preemphasis(self, coef: float=0.85): + """Applies pre-emphasis to audio signal. + + Parameters + ---------- + coef : float, optional + How much pre-emphasis to apply, lower values do less. 0 does nothing. + by default 0.85 + + Returns + ------- + AudioSignal + Pre-emphasized signal. + """ + kernel = paddle.to_tensor([1, -coef, 0]).reshape([1, 1, -1]) + x = self.audio_data.reshape([-1, 1, self.signal_length]) + x = paddle.nn.functional.conv1d( + x.astype(kernel.dtype), kernel, padding=1) + self.audio_data = x.reshape(self.audio_data.shape) + return self diff --git a/audio/audiotools/core/effects.py b/audio/audiotools/core/effects.py new file mode 100644 index 00000000000..2938e0c32f2 --- /dev/null +++ b/audio/audiotools/core/effects.py @@ -0,0 +1,539 @@ +# MIT License, Copyright (c) 2023-Present, Descript. +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Modified from audiotools(https://github.com/descriptinc/audiotools/blob/master/audiotools/core/effects.py) +import typing + +import numpy as np +import paddle + +from . import util +from ._julius import SplitBands + + +class EffectMixin: + GAIN_FACTOR = np.log(10) / 20 + """Gain factor for converting between amplitude and decibels.""" + CODEC_PRESETS = { + "8-bit": { + "format": "wav", + "encoding": "ULAW", + "bits_per_sample": 8 + }, + "GSM-FR": { + "format": "gsm" + }, + "MP3": { + "format": "mp3", + "compression": -9 + }, + "Vorbis": { + "format": "vorbis", + "compression": -1 + }, + "Ogg": { + "format": "ogg", + "compression": -1, + }, + "Amr-nb": { + "format": "amr-nb" + }, + } + """Presets for applying codecs via torchaudio.""" + + def mix( + self, + other, + snr: typing.Union[paddle.Tensor, np.ndarray, float]=10, + other_eq: typing.Union[paddle.Tensor, np.ndarray]=None, ): + """Mixes noise with signal at specified + signal-to-noise ratio. Optionally, the + other signal can be equalized in-place. + + + Parameters + ---------- + other : AudioSignal + AudioSignal object to mix with. + snr : typing.Union[paddle.Tensor, np.ndarray, float], optional + Signal to noise ratio, by default 10 + other_eq : typing.Union[paddle.Tensor, np.ndarray], optional + EQ curve to apply to other signal, if any, by default None + + Returns + ------- + AudioSignal + In-place modification of AudioSignal. + """ + snr = util.ensure_tensor(snr) + + pad_len = max(0, self.signal_length - other.signal_length) + other.zero_pad(0, pad_len) + other.truncate_samples(self.signal_length) + if other_eq is not None: + other = other.equalizer(other_eq) + + tgt_loudness = self.loudness() - snr + other = other.normalize(tgt_loudness) + + self.audio_data = self.audio_data + other.audio_data + return self + + def convolve(self, other, start_at_max: bool=True): + """Convolves self with other. + This function uses FFTs to do the convolution. + + Parameters + ---------- + other : AudioSignal + Signal to convolve with. + start_at_max : bool, optional + Whether to start at the max value of other signal, to + avoid inducing delays, by default True + + Returns + ------- + AudioSignal + Convolved signal, in-place. + """ + from . import AudioSignal + + pad_len = self.signal_length - other.signal_length + + if pad_len > 0: + other.zero_pad(0, pad_len) + else: + other.truncate_samples(self.signal_length) + + if start_at_max: + # Use roll to rotate over the max for every item + # so that the impulse responses don't induce any + # delay. + idx = paddle.argmax(paddle.abs(other.audio_data), axis=-1) + irs = paddle.zeros_like(other.audio_data) + for i in range(other.batch_size): + irs[i] = paddle.roll( + other.audio_data[i], shifts=-idx[i].item(), axis=-1) + other = AudioSignal(irs, other.sample_rate) + + delta = paddle.zeros_like(other.audio_data) + delta[..., 0] = 1 + + length = self.signal_length + delta_fft = paddle.fft.rfft(delta, n=length) + other_fft = paddle.fft.rfft(other.audio_data, n=length) + self_fft = paddle.fft.rfft(self.audio_data, n=length) + + convolved_fft = other_fft * self_fft + convolved_audio = paddle.fft.irfft(convolved_fft, n=length) + + delta_convolved_fft = other_fft * delta_fft + delta_audio = paddle.fft.irfft(delta_convolved_fft, n=length) + + # Use the delta to rescale the audio exactly as needed. + delta_max = paddle.max(paddle.abs(delta_audio), axis=-1, keepdim=True) + scale = 1 / paddle.clip(delta_max, min=1e-5) + convolved_audio = convolved_audio * scale + + self.audio_data = convolved_audio + + return self + + def apply_ir( + self, + ir, + drr: typing.Union[paddle.Tensor, np.ndarray, float]=None, + ir_eq: typing.Union[paddle.Tensor, np.ndarray]=None, + use_original_phase: bool=False, ): + """Applies an impulse response to the signal. If ` is`ir_eq`` + is specified, the impulse response is equalized before + it is applied, using the given curve. + + Parameters + ---------- + ir : AudioSignal + Impulse response to convolve with. + drr : typing.Union[paddle.Tensor, np.ndarray, float], optional + Direct-to-reverberant ratio that impulse response will be + altered to, if specified, by default None + ir_eq : typing.Union[paddle.Tensor, np.ndarray], optional + Equalization that will be applied to impulse response + if specified, by default None + use_original_phase : bool, optional + Whether to use the original phase, instead of the convolved + phase, by default False + + Returns + ------- + AudioSignal + Signal with impulse response applied to it + """ + if ir_eq is not None: + ir = ir.equalizer(ir_eq) + if drr is not None: + ir = ir.alter_drr(drr) + + # Save the peak before + max_spk = self.audio_data.abs().max(axis=-1, keepdim=True) + + # Augment the impulse response to simulate microphone effects + # and with varying direct-to-reverberant ratio. + phase = self.phase + self.convolve(ir) + + # Use the input phase + if use_original_phase: + self.stft() + self.stft_data = self.magnitude * util.exp_compat(1j * phase) + self.istft() + + # Rescale to the input's amplitude + max_transformed = self.audio_data.abs().max(axis=-1, keepdim=True) + scale_factor = max_spk.clip(1e-8) / max_transformed.clip(1e-8) + self = self * scale_factor + + return self + + def ensure_max_of_audio(self, _max: float=1.0): + """Ensures that ``abs(audio_data) <= max``. + + Parameters + ---------- + max : float, optional + Max absolute value of signal, by default 1.0 + + Returns + ------- + AudioSignal + Signal with values scaled between -max and max. + """ + peak = self.audio_data.abs().max(axis=-1, keepdim=True) + peak_gain = paddle.ones_like(peak) + # peak_gain[peak > _max] = _max / peak[peak > _max] + peak_gain = paddle.where(peak > _max, _max / peak, peak_gain) + self.audio_data = self.audio_data * peak_gain + return self + + def normalize(self, + db: typing.Union[paddle.Tensor, np.ndarray, float]=-24.0): + """Normalizes the signal's volume to the specified db, in LUFS. + This is GPU-compatible, making for very fast loudness normalization. + + Parameters + ---------- + db : typing.Union[paddle.Tensor, np.ndarray, float], optional + Loudness to normalize to, by default -24.0 + + Returns + ------- + AudioSignal + Normalized audio signal. + """ + db = util.ensure_tensor(db) + ref_db = self.loudness() + gain = db.astype(ref_db.dtype) - ref_db + gain = util.exp_compat(gain * self.GAIN_FACTOR) + + self.audio_data = self.audio_data * gain[:, None, None] + return self + + def volume_change(self, db: typing.Union[paddle.Tensor, np.ndarray, float]): + """Change volume of signal by some amount, in dB. + + Parameters + ---------- + db : typing.Union[paddle.Tensor, np.ndarray, float] + Amount to change volume by. + + Returns + ------- + AudioSignal + Signal at new volume. + """ + db = util.ensure_tensor(db, ndim=1) + gain = util.exp_compat(db * self.GAIN_FACTOR) + self.audio_data = self.audio_data * gain[:, None, None] + return self + + def mel_filterbank(self, n_bands: int): + """Breaks signal into mel bands. + + Parameters + ---------- + n_bands : int + Number of mel bands to use. + + Returns + ------- + paddle.Tensor + Mel-filtered bands, with last axis being the band index. + """ + filterbank = SplitBands(self.sample_rate, n_bands) + filtered = filterbank(self.audio_data) + return filtered.transpose([1, 2, 3, 0]) + + def equalizer(self, db: typing.Union[paddle.Tensor, np.ndarray]): + """Applies a mel-spaced equalizer to the audio signal. + + Parameters + ---------- + db : typing.Union[paddle.Tensor, np.ndarray] + EQ curve to apply. + + Returns + ------- + AudioSignal + AudioSignal with equalization applied. + """ + db = util.ensure_tensor(db) + n_bands = db.shape[-1] + fbank = self.mel_filterbank(n_bands) + + # If there's a batch dimension, make sure it's the same. + if db.ndim == 2: + if db.shape[0] != 1: + assert db.shape[0] == fbank.shape[0] + else: + db = db.unsqueeze(0) + + weights = (10**db).astype("float32") + fbank = fbank * weights[:, None, None, :] + eq_audio_data = fbank.sum(-1) + self.audio_data = eq_audio_data + return self + + def clip_distortion( + self, + clip_percentile: typing.Union[paddle.Tensor, np.ndarray, float]): + """Clips the signal at a given percentile. The higher it is, + the lower the threshold for clipping. + + Parameters + ---------- + clip_percentile : typing.Union[paddle.Tensor, np.ndarray, float] + Values are between 0.0 to 1.0. Typical values are 0.1 or below. + + Returns + ------- + AudioSignal + Audio signal with clipped audio data. + """ + clip_percentile = util.ensure_tensor(clip_percentile, ndim=1) + clip_percentile = clip_percentile.cpu().numpy() + min_thresh = paddle.quantile( + self.audio_data, (clip_percentile / 2).tolist(), axis=-1)[None] + max_thresh = paddle.quantile( + self.audio_data, (1 - clip_percentile / 2).tolist(), axis=-1)[None] + + nc = self.audio_data.shape[1] + min_thresh = min_thresh[:, :nc, :] + max_thresh = max_thresh[:, :nc, :] + + self.audio_data = self.audio_data.clip(min_thresh, max_thresh) + + return self + + def quantization(self, + quantization_channels: typing.Union[paddle.Tensor, + np.ndarray, int]): + """Applies quantization to the input waveform. + + Parameters + ---------- + quantization_channels : typing.Union[paddle.Tensor, np.ndarray, int] + Number of evenly spaced quantization channels to quantize + to. + + Returns + ------- + AudioSignal + Quantized AudioSignal. + """ + quantization_channels = util.ensure_tensor( + quantization_channels, ndim=3) + + x = self.audio_data + quantization_channels = quantization_channels.astype(x.dtype) + x = (x + 1) / 2 + x = x * quantization_channels + x = x.floor() + x = x / quantization_channels + x = 2 * x - 1 + + residual = (self.audio_data - x).detach() + self.audio_data = self.audio_data - residual + return self + + def mulaw_quantization(self, + quantization_channels: typing.Union[ + paddle.Tensor, np.ndarray, int]): + """Applies mu-law quantization to the input waveform. + + Parameters + ---------- + quantization_channels : typing.Union[paddle.Tensor, np.ndarray, int] + Number of mu-law spaced quantization channels to quantize + to. + + Returns + ------- + AudioSignal + Quantized AudioSignal. + """ + mu = quantization_channels - 1.0 + mu = util.ensure_tensor(mu, ndim=3) + + x = self.audio_data + + # quantize + x = paddle.sign(x) * paddle.log1p(mu * paddle.abs(x)) / paddle.log1p(mu) + x = ((x + 1) / 2 * mu + 0.5).astype("int64") + + # unquantize + x = (x.astype(mu.dtype) / mu) * 2 - 1.0 + x = paddle.sign(x) * ( + util.exp_compat(paddle.abs(x) * paddle.log1p(mu)) - 1.0) / mu + + residual = (self.audio_data - x).detach() + self.audio_data = self.audio_data - residual + return self + + def __matmul__(self, other): + return self.convolve(other) + + +class ImpulseResponseMixin: + """These functions are generally only used with AudioSignals that are derived + from impulse responses, not other sources like music or speech. These methods + are used to replicate the data augmentation described in [1]. + + 1. Bryan, Nicholas J. "Impulse response data augmentation and deep + neural networks for blind room acoustic parameter estimation." + ICASSP 2020-2020 IEEE International Conference on Acoustics, + Speech and Signal Processing (ICASSP). IEEE, 2020. + """ + + def decompose_ir(self): + """Decomposes an impulse response into early and late + field responses. + """ + # Equations 1 and 2 + # ----------------- + # Breaking up into early + # response + late field response. + + td = paddle.argmax(self.audio_data, axis=-1, keepdim=True) + t0 = int(self.sample_rate * 0.0025) + + idx = paddle.arange(self.audio_data.shape[-1])[None, None, :] + idx = idx.expand([self.batch_size, -1, -1]) + early_idx = (idx >= td - t0) * (idx <= td + t0) + + early_response = paddle.zeros_like(self.audio_data) + + # early_response[early_idx] = self.audio_data[early_idx] + early_response = paddle.where(early_idx, self.audio_data, + early_response) + + late_idx = ~early_idx + late_field = paddle.zeros_like(self.audio_data) + # late_field[late_idx] = self.audio_data[late_idx] + late_field = paddle.where(late_idx, self.audio_data, late_field) + + # Equation 4 + # ---------- + # Decompose early response into windowed + # direct path and windowed residual. + + window = paddle.zeros_like(self.audio_data) + window_idx = paddle.nonzero(early_idx) + for idx in range(self.batch_size): + # window_idx = early_idx[idx, 0] + + # ----- Just for this ----- + # window[idx, ..., window_idx] = self.get_window("hann", window_idx.sum().item()) + # indices = paddle.nonzero(window_idx).reshape( + # [-1]) # shape: [num_true], dtype: int64 + indices = window_idx[window_idx[:, 0] == idx][:, -1] + + temp_window = self.get_window("hann", indices.shape[0]) + + window_slice = window[idx, 0] + updated_window_slice = paddle.scatter( + window_slice, index=indices, updates=temp_window) + + window[idx, 0] = updated_window_slice + # ----- Just for that ----- + + return early_response, late_field, window + + def measure_drr(self): + """Measures the direct-to-reverberant ratio of the impulse + response. + + Returns + ------- + float + Direct-to-reverberant ratio + """ + early_response, late_field, _ = self.decompose_ir() + num = (early_response**2).sum(axis=-1) + den = (late_field**2).sum(axis=-1) + drr = 10 * paddle.log10(num / den) + return drr + + @staticmethod + def solve_alpha(early_response, late_field, wd, target_drr): + """Used to solve for the alpha value, which is used + to alter the drr. + """ + # Equation 5 + # ---------- + # Apply the good ol' quadratic formula. + + wd_sq = wd**2 + wd_sq_1 = (1 - wd)**2 + e_sq = early_response**2 + l_sq = late_field**2 + a = (wd_sq * e_sq).sum(axis=-1) + b = (2 * (1 - wd) * wd * e_sq).sum(axis=-1) + c = (wd_sq_1 * e_sq).sum(axis=-1) - paddle.pow(10 * paddle.ones_like( + target_drr, dtype="float32"), target_drr.cast("float32") / + 10) * l_sq.sum(axis=-1) + + expr = ((b**2) - 4 * a * c).sqrt() + alpha = paddle.maximum( + (-b - expr) / (2 * a), + (-b + expr) / (2 * a), ) + return alpha + + def alter_drr(self, drr: typing.Union[paddle.Tensor, np.ndarray, float]): + """Alters the direct-to-reverberant ratio of the impulse response. + + Parameters + ---------- + drr : typing.Union[paddle.Tensor, np.ndarray, float] + Direct-to-reverberant ratio that impulse response will be + altered to, if specified, by default None + + Returns + ------- + AudioSignal + Altered impulse response. + """ + drr = util.ensure_tensor( + drr, 2, self.batch_size + ) # Assuming util.ensure_tensor is adapted or equivalent exists + + early_response, late_field, window = self.decompose_ir() + alpha = self.solve_alpha(early_response, late_field, window, drr) + min_alpha = late_field.abs().max(axis=-1)[0] / early_response.abs().max( + axis=-1)[0] + alpha = paddle.maximum(alpha, min_alpha)[..., None] + + aug_ir_data = alpha * window * early_response + ( + (1 - window) * early_response) + late_field + self.audio_data = aug_ir_data + self.ensure_max_of_audio( + ) # Assuming ensure_max_of_audio is a method defined elsewhere + return self diff --git a/audio/audiotools/core/ffmpeg.py b/audio/audiotools/core/ffmpeg.py new file mode 100644 index 00000000000..64a74e51d5c --- /dev/null +++ b/audio/audiotools/core/ffmpeg.py @@ -0,0 +1,119 @@ +# MIT License, Copyright (c) 2023-Present, Descript. +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Modified from audiotools(https://github.com/descriptinc/audiotools/blob/master/audiotools/core/ffmpeg.py) +import json +import shlex +import subprocess +import tempfile +from pathlib import Path +from typing import Tuple + +import ffmpy +import numpy as np +import paddle + + +def r128stats(filepath: str, quiet: bool): + """Takes a path to an audio file, returns a dict with the loudness + stats computed by the ffmpeg ebur128 filter. + + Parameters + ---------- + filepath : str + Path to compute loudness stats on. + quiet : bool + Whether to show FFMPEG output during computation. + + Returns + ------- + dict + Dictionary containing loudness stats. + """ + ffargs = [ + "ffmpeg", + "-nostats", + "-i", + filepath, + "-filter_complex", + "ebur128", + "-f", + "null", + "-", + ] + if quiet: + ffargs += ["-hide_banner"] + proc = subprocess.Popen( + ffargs, stderr=subprocess.PIPE, universal_newlines=True) + stats = proc.communicate()[1] + summary_index = stats.rfind("Summary:") + + summary_list = stats[summary_index:].split() + i_lufs = float(summary_list[summary_list.index("I:") + 1]) + i_thresh = float(summary_list[summary_list.index("I:") + 4]) + lra = float(summary_list[summary_list.index("LRA:") + 1]) + lra_thresh = float(summary_list[summary_list.index("LRA:") + 4]) + lra_low = float(summary_list[summary_list.index("low:") + 1]) + lra_high = float(summary_list[summary_list.index("high:") + 1]) + stats_dict = { + "I": i_lufs, + "I Threshold": i_thresh, + "LRA": lra, + "LRA Threshold": lra_thresh, + "LRA Low": lra_low, + "LRA High": lra_high, + } + + return stats_dict + + +def ffprobe_offset_and_codec(path: str) -> Tuple[float, str]: + """Given a path to a file, returns the start time offset and codec of + the first audio stream. + """ + ff = ffmpy.FFprobe( + inputs={path: None}, + global_options="-show_entries format=start_time:stream=duration,start_time,codec_type,codec_name,start_pts,time_base -of json -v quiet", + ) + streams = json.loads(ff.run(stdout=subprocess.PIPE)[0])["streams"] + seconds_offset = 0.0 + codec = None + + # Get the offset and codec of the first audio stream we find + # and return its start time, if it has one. + for stream in streams: + if stream["codec_type"] == "audio": + seconds_offset = stream.get("start_time", 0.0) + codec = stream.get("codec_name") + break + return float(seconds_offset), codec + + +class FFMPEGMixin: + _loudness = None + + def ffmpeg_loudness(self, quiet: bool=True): + """Computes loudness of audio file using FFMPEG. + + Parameters + ---------- + quiet : bool, optional + Whether to show FFMPEG output during computation, + by default True + + Returns + ------- + paddle.Tensor + Loudness of every item in the batch, computed via + FFMPEG. + """ + loudness = [] + + with tempfile.NamedTemporaryFile(suffix=".wav") as f: + for i in range(self.batch_size): + self[i].write(f.name) + loudness_stats = r128stats(f.name, quiet=quiet) + loudness.append(loudness_stats["I"]) + + self._loudness = paddle.to_tensor(np.array(loudness)).astype("float32") + return self.loudness() diff --git a/audio/audiotools/core/loudness.py b/audio/audiotools/core/loudness.py new file mode 100644 index 00000000000..cde5e81fe3a --- /dev/null +++ b/audio/audiotools/core/loudness.py @@ -0,0 +1,387 @@ +# MIT License, Copyright (c) 2023-Present, Descript. +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Modified from audiotools(https://github.com/descriptinc/audiotools/blob/master/audiotools/core/loudness.py) +import copy +import math +import typing + +import numpy as np +import paddle +import paddle.nn.functional as F +import scipy + +from . import _julius + + +def _unfold1d(x, kernel_size, stride): + # https://github.com/PaddlePaddle/Paddle/pull/70102 + """1D only unfolding similar to the one from Paddlepaddle. + + Given an _input tensor of size `[*, T]` this will return + a tensor `[*, F, K]` with `K` the kernel size, and `F` the number + of frames. The i-th frame is a view onto `i * stride: i * stride + kernel_size`. + This will automatically pad the _input to cover at least once all entries in `_input`. + + Args: + _input (Tensor): tensor for which to return the frames. + kernel_size (int): size of each frame. + stride (int): stride between each frame. + + Shape: + + - Inputs: `_input` is `[*, T]` + - Output: `[*, F, kernel_size]` with `F = 1 + ceil((T - kernel_size) / stride)` + """ + + if 3 != x.dim(): + raise NotImplementedError + + N, C, length = x.shape + x = x.reshape([N * C, 1, length]) + + n_frames = math.ceil((max(length, kernel_size) - kernel_size) / stride) + 1 + tgt_length = (n_frames - 1) * stride + kernel_size + x = F.pad(x, (0, tgt_length - length), data_format="NCL") + + x = x.unsqueeze(-1) + + unfolded = paddle.nn.functional.unfold( + x, + kernel_sizes=[kernel_size, 1], + strides=[stride, 1], ) + + unfolded = unfolded.transpose([0, 2, 1]) + unfolded = unfolded.reshape([N, C, *unfolded.shape[1:]]) + return unfolded + + +class Meter(paddle.nn.Layer): + """Tensorized version of pyloudnorm.Meter. Works with batched audio tensors. + + Parameters + ---------- + rate : int + Sample rate of audio. + filter_class : str, optional + Class of weighting filter used. + K-weighting' (default), 'Fenton/Lee 1' + 'Fenton/Lee 2', 'Dash et al.' + by default "K-weighting" + block_size : float, optional + Gating block size in seconds, by default 0.400 + zeros : int, optional + Number of zeros to use in FIR approximation of + IIR filters, by default 512 + use_fir : bool, optional + Whether to use FIR approximation or exact IIR formulation. + If computing on GPU, ``use_fir=True`` will be used, as its + much faster, by default False + """ + + def __init__( + self, + rate: int, + filter_class: str="K-weighting", + block_size: float=0.400, + zeros: int=512, + use_fir: bool=False, ): + super().__init__() + + self.rate = rate + self.filter_class = filter_class + self.block_size = block_size + self.use_fir = use_fir + + G = paddle.to_tensor( + np.array([1.0, 1.0, 1.0, 1.41, 1.41]), stop_gradient=True) + self.register_buffer("G", G) + + # Compute impulse responses so that filtering is fast via + # a convolution at runtime, on GPU, unlike lfilter. + impulse = np.zeros((zeros, )) + impulse[..., 0] = 1.0 + + firs = np.zeros((len(self._filters), 1, zeros)) + # passband_gain = torch.zeros(len(self._filters)) + passband_gain = paddle.zeros([len(self._filters)], dtype="float32") + + for i, (_, filter_stage) in enumerate(self._filters.items()): + firs[i] = scipy.signal.lfilter(filter_stage.b, filter_stage.a, + impulse) + passband_gain[i] = filter_stage.passband_gain + + firs = paddle.to_tensor( + firs[..., ::-1].copy(), dtype="float32", stop_gradient=True) + + self.register_buffer("firs", firs) + self.register_buffer("passband_gain", passband_gain) + + def apply_filter_gpu(self, data: paddle.Tensor): + """Performs FIR approximation of loudness computation. + + Parameters + ---------- + data : paddle.Tensor + Audio data of shape (nb, nch, nt). + + Returns + ------- + paddle.Tensor + Filtered audio data. + """ + # Data is of shape (nb, nch, nt) + # Reshape to (nb*nch, 1, nt) + nb, nt, nch = data.shape + data = data.transpose([0, 2, 1]) + data = data.reshape([nb * nch, 1, nt]) + + # Apply padding + pad_length = self.firs.shape[-1] + + # Apply filtering in sequence + for i in range(self.firs.shape[0]): + data = F.pad(data, (pad_length, pad_length), data_format="NCL") + data = _julius.fft_conv1d(data, self.firs[i, None, ...]) + data = self.passband_gain[i] * data + data = data[..., 1:nt + 1] + + data = data.transpose([0, 2, 1]) + data = data[:, :nt, :] + return data + + @staticmethod + def scipy_lfilter(waveform, a_coeffs, b_coeffs, clamp: bool=True): + # 使用 scipy.signal.lfilter 进行滤波(处理三维数据) + output = np.zeros_like(waveform) + for batch_idx in range(waveform.shape[0]): + for channel_idx in range(waveform.shape[2]): + output[batch_idx, :, channel_idx] = scipy.signal.lfilter( + b_coeffs, a_coeffs, waveform[batch_idx, :, channel_idx]) + return output + + def apply_filter_cpu(self, data: paddle.Tensor): + """Performs IIR formulation of loudness computation. + + Parameters + ---------- + data : paddle.Tensor + Audio data of shape (nb, nch, nt). + + Returns + ------- + paddle.Tensor + Filtered audio data. + """ + _data = data.cpu().numpy().copy() + for _, filter_stage in self._filters.items(): + passband_gain = filter_stage.passband_gain + + a_coeffs = filter_stage.a + b_coeffs = filter_stage.b + + filtered = self.scipy_lfilter(_data, a_coeffs, b_coeffs) + _data[:] = passband_gain * filtered + data = paddle.to_tensor(_data) + return data + + def apply_filter(self, data: paddle.Tensor): + """Applies filter on either CPU or GPU, depending + on if the audio is on GPU or is on CPU, or if + ``self.use_fir`` is True. + + Parameters + ---------- + data : paddle.Tensor + Audio data of shape (nb, nch, nt). + + Returns + ------- + paddle.Tensor + Filtered audio data. + """ + # if data.place.is_gpu_place() or self.use_fir: + # data = self.apply_filter_gpu(data) + # else: + # data = self.apply_filter_cpu(data) + data = self.apply_filter_cpu(data) + return data + + def forward(self, data: paddle.Tensor): + """Computes integrated loudness of data. + + Parameters + ---------- + data : paddle.Tensor + Audio data of shape (nb, nch, nt). + + Returns + ------- + paddle.Tensor + Filtered audio data. + """ + return self.integrated_loudness(data) + + def _unfold(self, input_data): + T_g = self.block_size + overlap = 0.75 # overlap of 75% of the block duration + step = 1.0 - overlap # step size by percentage + + kernel_size = int(T_g * self.rate) + stride = int(T_g * self.rate * step) + unfolded = _unfold1d( + input_data.transpose([0, 2, 1]), kernel_size, stride) + unfolded = unfolded.transpose([0, 1, 3, 2]) + + return unfolded + + def integrated_loudness(self, data: paddle.Tensor): + """Computes integrated loudness of data. + + Parameters + ---------- + data : paddle.Tensor + Audio data of shape (nb, nch, nt). + + Returns + ------- + paddle.Tensor + Filtered audio data. + """ + if not paddle.is_tensor(data): + data = paddle.to_tensor(data, dtype="float32") + else: + data = data.astype("float32") + + input_data = data.clone() + # Data always has a batch and channel dimension. + # Is of shape (nb, nt, nch) + if input_data.ndim < 2: + input_data = input_data.unsqueeze(-1) + if input_data.ndim < 3: + input_data = input_data.unsqueeze(0) + + nb, nt, nch = input_data.shape + + # Apply frequency weighting filters - account + # for the acoustic respose of the head and auditory system + input_data = self.apply_filter(input_data) + + G = self.G # channel gains + T_g = self.block_size # 400 ms gating block standard + Gamma_a = -70.0 # -70 LKFS = absolute loudness threshold + + unfolded = self._unfold(input_data) + + z = (1.0 / (T_g * self.rate)) * unfolded.square().sum(2) + l = -0.691 + 10.0 * paddle.log10( + (G[None, :nch, None] * z).sum(1, keepdim=True)) + l = l.expand_as(z) + + # find gating block indices above absolute threshold + z_avg_gated = z + z_avg_gated[l <= Gamma_a] = 0 + masked = l > Gamma_a + z_avg_gated = z_avg_gated.sum(2) / masked.sum(2).astype("float32") + + # calculate the relative threshold value (see eq. 6) + Gamma_r = -0.691 + 10.0 * paddle.log10( + (z_avg_gated * G[None, :nch]).sum(-1)) - 10.0 + Gamma_r = Gamma_r[:, None, None] + Gamma_r = Gamma_r.expand([nb, nch, l.shape[-1]]) + + # find gating block indices above relative and absolute thresholds (end of eq. 7) + z_avg_gated = z + z_avg_gated[l <= Gamma_a] = 0 + z_avg_gated[l <= Gamma_r] = 0 + masked = (l > Gamma_a) * (l > Gamma_r) + z_avg_gated = z_avg_gated.sum(2) / (masked.sum(2) + 10e-6) + + # TODO Currently, paddle has a segmentation fault bug in this section of the code + # z_avg_gated = paddle.nan_to_num(z_avg_gated) + # z_avg_gated = paddle.where( + # paddle.isnan(z_avg_gated), + # paddle.zeros_like(z_avg_gated), z_avg_gated) + z_avg_gated[z_avg_gated == float("inf")] = float( + np.finfo(np.float32).max) + z_avg_gated[z_avg_gated == -float("inf")] = float( + np.finfo(np.float32).min) + + LUFS = -0.691 + 10.0 * paddle.log10( + (G[None, :nch] * z_avg_gated).sum(1)) + return LUFS.astype("float32") + + @property + def filter_class(self): + return self._filter_class + + @filter_class.setter + def filter_class(self, value): + from pyloudnorm import Meter + + meter = Meter(self.rate) + meter.filter_class = value + self._filter_class = value + self._filters = meter._filters + + +class LoudnessMixin: + _loudness = None + MIN_LOUDNESS = -70 + """Minimum loudness possible.""" + + def loudness(self, + filter_class: str="K-weighting", + block_size: float=0.400, + **kwargs): + """Calculates loudness using an implementation of ITU-R BS.1770-4. + Allows control over gating block size and frequency weighting filters for + additional control. Measure the integrated gated loudness of a signal. + + API is derived from PyLoudnorm, but this implementation is ported to PyTorch + and is tensorized across batches. When on GPU, an FIR approximation of the IIR + filters is used to compute loudness for speed. + + Uses the weighting filters and block size defined by the meter + the integrated loudness is measured based upon the gating algorithm + defined in the ITU-R BS.1770-4 specification. + + Parameters + ---------- + filter_class : str, optional + Class of weighting filter used. + K-weighting' (default), 'Fenton/Lee 1' + 'Fenton/Lee 2', 'Dash et al.' + by default "K-weighting" + block_size : float, optional + Gating block size in seconds, by default 0.400 + kwargs : dict, optional + Keyword arguments to :py:func:`audiotools.core.loudness.Meter`. + + Returns + ------- + paddle.Tensor + Loudness of audio data. + """ + if self._loudness is not None: + return self._loudness # .to(self.device) + original_length = self.signal_length + if self.signal_duration < 0.5: + pad_len = int((0.5 - self.signal_duration) * self.sample_rate) + self.zero_pad(0, pad_len) + + # create BS.1770 meter + meter = Meter( + self.sample_rate, + filter_class=filter_class, + block_size=block_size, + **kwargs) + # meter = meter.to(self.device) + # measure loudness + loudness = meter.integrated_loudness( + self.audio_data.transpose([0, 2, 1])) + self.truncate_samples(original_length) + min_loudness = paddle.ones_like(loudness) * self.MIN_LOUDNESS + self._loudness = paddle.maximum(loudness, min_loudness) + + return self._loudness # .to(self.device) diff --git a/audio/audiotools/core/util.py b/audio/audiotools/core/util.py new file mode 100644 index 00000000000..0bbcf46a504 --- /dev/null +++ b/audio/audiotools/core/util.py @@ -0,0 +1,921 @@ +# MIT License, Copyright (c) 2023-Present, Descript. +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Modified from audiotools(https://github.com/descriptinc/audiotools/blob/master/audiotools/core/util.py) +import collections +import csv +import glob +import math +import numbers +import os +import random +import typing +from contextlib import contextmanager +from dataclasses import dataclass +from pathlib import Path +from typing import Any +from typing import Callable +from typing import Dict +from typing import Iterable +from typing import List +from typing import NamedTuple +from typing import Optional +from typing import Tuple +from typing import Type +from typing import Union + +import ffmpeg +import librosa +import numpy as np +import paddle +import soundfile +from flatten_dict import flatten +from flatten_dict import unflatten + +from .audio_signal import AudioSignal +from paddlespeech.utils import satisfy_paddle_version +from paddlespeech.vector.training.seeding import seed_everything + +__all__ = [ + "exp_compat", + "bool_index_compat", + "bool_setitem_compat", + "Info", + "info", + "ensure_tensor", + "random_state", + "find_audio", + "read_sources", + "choose_from_list_of_lists", + "chdir", + "move_to_device", + "prepare_batch", + "sample_from_dist", + "format_figure", + "default_collate", + "collate", + "hz_to_bin", + "generate_chord_dataset", +] + + +def exp_compat(x): + """ + Compute the exponential of the input tensor `x`. + + This function is designed to handle compatibility issues with PaddlePaddle versions below 2.6, + which do not support the `exp` operation for complex tensors. In such cases, the computation + is offloaded to NumPy. + + Args: + x (paddle.Tensor): The input tensor for which to compute the exponential. + + Returns: + paddle.Tensor: The result of the exponential operation, as a PaddlePaddle tensor. + + Notes: + - If the PaddlePaddle version is 2.6 or above, the function uses `paddle.exp` directly. + - For versions below 2.6, the tensor is first converted to a NumPy array, the exponential + is computed using `np.exp`, and the result is then converted back to a PaddlePaddle tensor. + """ + if satisfy_paddle_version("2.6"): + return paddle.exp(x) + else: + x_np = x.cpu().numpy() + return paddle.to_tensor(np.exp(x_np)) + + +def bool_index_compat(x, mask): + """ + Perform boolean indexing on the input tensor `x` using the provided `mask`. + + This function ensures compatibility with PaddlePaddle versions below 2.6, where boolean indexing + may not be fully supported. For older versions, the operation is performed using NumPy. + + Args: + x (paddle.Tensor): The input tensor to be indexed. + mask (paddle.Tensor or int): The boolean mask or integer index used for indexing. + + Returns: + paddle.Tensor: The result of the boolean indexing operation, as a PaddlePaddle tensor. + + Notes: + - If the PaddlePaddle version is 2.6 or above, or if `mask` is an integer, the function uses + Paddle's native indexing directly. + - For versions below 2.6, the tensor and mask are converted to NumPy arrays, the indexing + operation is performed using NumPy, and the result is converted back to a PaddlePaddle tensor. + """ + if satisfy_paddle_version("2.6") or isinstance(mask, (int, list, slice)): + return x[mask] + else: + x_np = x.cpu().numpy()[mask.cpu().numpy()] + return paddle.to_tensor(x_np) + + +def bool_setitem_compat(x, mask, y): + """ + Perform boolean assignment on the input tensor `x` using the provided `mask` and values `y`. + + This function ensures compatibility with PaddlePaddle versions below 2.6, where boolean assignment + may not be fully supported. For older versions, the operation is performed using NumPy. + + Args: + x (paddle.Tensor): The input tensor to be modified. + mask (paddle.Tensor): The boolean mask used for assignment. + y (paddle.Tensor): The values to assign to the selected elements of `x`. + + Returns: + paddle.Tensor: The modified tensor after the assignment operation. + + Notes: + - If the PaddlePaddle version is 2.6 or above, the function uses Paddle's native assignment directly. + - For versions below 2.6, the tensor, mask, and values are converted to NumPy arrays, the assignment + operation is performed using NumPy, and the result is converted back to a PaddlePaddle tensor. + """ + if satisfy_paddle_version("2.6"): + + x[mask] = y + return x + else: + x_np = x.cpu().numpy() + x_np[mask.cpu().numpy()] = y.cpu().numpy() + + return paddle.to_tensor(x_np) + + +@dataclass +class Info: + + sample_rate: float + num_frames: int + + @property + def duration(self) -> float: + return self.num_frames / self.sample_rate + + +def info_ffmpeg(audio_path: str): + """ + Parameters + ---------- + audio_path : str + Path to audio file. + """ + probe = ffmpeg.probe(audio_path) + audio_streams = [ + stream for stream in probe['streams'] if stream['codec_type'] == 'audio' + ] + if not audio_streams: + raise ValueError("No audio stream found in the file.") + audio_stream = audio_streams[0] + + sample_rate = int(audio_stream['sample_rate']) + duration = float(audio_stream['duration']) + + num_frames = int(duration * sample_rate) + + info = Info(sample_rate=sample_rate, num_frames=num_frames) + return info + + +def info(audio_path: str): + """ + + Parameters + ---------- + audio_path : str + Path to audio file. + """ + try: + info = soundfile.info(str(audio_path)) + info = Info(sample_rate=info.samplerate, num_frames=info.frames) + except: + info = info_ffmpeg(str(audio_path)) + + return info + + +def ensure_tensor( + x: typing.Union[np.ndarray, paddle.Tensor, float, int], + ndim: int=None, + batch_size: int=None, ): + """Ensures that the input ``x`` is a tensor of specified + dimensions and batch size. + + Parameters + ---------- + x : typing.Union[np.ndarray, paddle.Tensor, float, int] + Data that will become a tensor on its way out. + ndim : int, optional + How many dimensions should be in the output, by default None + batch_size : int, optional + The batch size of the output, by default None + + Returns + ------- + paddle.Tensor + Modified version of ``x`` as a tensor. + """ + if not paddle.is_tensor(x): + x = paddle.to_tensor(x) + if ndim is not None: + assert x.ndim <= ndim + while x.ndim < ndim: + x = x.unsqueeze(-1) + if batch_size is not None: + if x.shape[0] != batch_size: + shape = list(x.shape) + shape[0] = batch_size + x = paddle.expand(x, shape) + return x + + +def _get_value(other): + # + from . import AudioSignal + + if isinstance(other, AudioSignal): + return other.audio_data + return other + + +def random_state(seed: typing.Union[int, np.random.RandomState]): + """ + Turn seed into a np.random.RandomState instance. + + Parameters + ---------- + seed : typing.Union[int, np.random.RandomState] or None + If seed is None, return the RandomState singleton used by np.random. + If seed is an int, return a new RandomState instance seeded with seed. + If seed is already a RandomState instance, return it. + Otherwise raise ValueError. + + Returns + ------- + np.random.RandomState + Random state object. + + Raises + ------ + ValueError + If seed is not valid, an error is thrown. + """ + if seed is None or seed is np.random: + return np.random.mtrand._rand + elif isinstance(seed, (numbers.Integral, np.integer, int)): + return np.random.RandomState(seed) + elif isinstance(seed, np.random.RandomState): + return seed + else: + raise ValueError("%r cannot be used to seed a numpy.random.RandomState" + " instance" % seed) + + +@contextmanager +def _close_temp_files(tmpfiles: list): + """Utility function for creating a context and closing all temporary files + once the context is exited. For correct functionality, all temporary file + handles created inside the context must be appended to the ```tmpfiles``` + list. + + This function is taken wholesale from Scaper. + + Parameters + ---------- + tmpfiles : list + List of temporary file handles + """ + + def _close(): + for t in tmpfiles: + try: + t.close() + os.unlink(t.name) + except: + pass + + try: + yield + except: + _close() + raise + _close() + + +AUDIO_EXTENSIONS = [".wav", ".flac", ".mp3"] + + +def find_audio(folder: str, ext: List[str]=AUDIO_EXTENSIONS): + """Finds all audio files in a directory recursively. + Returns a list. + + Parameters + ---------- + folder : str + Folder to look for audio files in, recursively. + ext : List[str], optional + Extensions to look for without the ., by default + ``['.wav', '.flac', '.mp3', '.mp4']``. + """ + folder = Path(folder) + # Take care of case where user has passed in an audio file directly + # into one of the calling functions. + if str(folder).endswith(tuple(ext)): + # if, however, there's a glob in the path, we need to + # return the glob, not the file. + if "*" in str(folder): + return glob.glob(str(folder), recursive=("**" in str(folder))) + else: + return [folder] + + files = [] + for x in ext: + files += folder.glob(f"**/*{x}") + return files + + +def read_sources( + sources: List[str], + remove_empty: bool=True, + relative_path: str="", + ext: List[str]=AUDIO_EXTENSIONS, ): + """Reads audio sources that can either be folders + full of audio files, or CSV files that contain paths + to audio files. CSV files that adhere to the expected + format can be generated by + :py:func:`audiotools.data.preprocess.create_csv`. + + Parameters + ---------- + sources : List[str] + List of audio sources to be converted into a + list of lists of audio files. + remove_empty : bool, optional + Whether or not to remove rows with an empty "path" + from each CSV file, by default True. + + Returns + ------- + list + List of lists of rows of CSV files. + """ + files = [] + relative_path = Path(relative_path) + for source in sources: + source = str(source) + _files = [] + if source.endswith(".csv"): + with open(source, "r") as f: + reader = csv.DictReader(f) + for x in reader: + if remove_empty and x["path"] == "": + continue + if x["path"] != "": + x["path"] = str(relative_path / x["path"]) + _files.append(x) + else: + for x in find_audio(source, ext=ext): + x = str(relative_path / x) + _files.append({"path": x}) + files.append(sorted(_files, key=lambda x: x["path"])) + return files + + +def choose_from_list_of_lists(state: np.random.RandomState, + list_of_lists: list, + p: float=None): + """Choose a single item from a list of lists. + + Parameters + ---------- + state : np.random.RandomState + Random state to use when choosing an item. + list_of_lists : list + A list of lists from which items will be drawn. + p : float, optional + Probabilities of each list, by default None + + Returns + ------- + typing.Any + An item from the list of lists. + """ + source_idx = state.choice(list(range(len(list_of_lists))), p=p) + item_idx = state.randint(len(list_of_lists[source_idx])) + return list_of_lists[source_idx][item_idx], source_idx, item_idx + + +@contextmanager +def chdir(newdir: typing.Union[Path, str]): + """ + Context manager for switching directories to run a + function. Useful for when you want to use relative + paths to different runs. + + Parameters + ---------- + newdir : typing.Union[Path, str] + Directory to switch to. + """ + curdir = os.getcwd() + try: + os.chdir(newdir) + yield + finally: + os.chdir(curdir) + + +def move_to_device(data, device): + if device is None or device == "": + return data + elif device == 'cpu': + return paddle.to_tensor(data, place=paddle.CPUPlace()) + elif device in ('gpu', 'cuda'): + return paddle.to_tensor(data, place=paddle.CUDAPlace()) + else: + device = device.replace("cuda", "gpu") if "cuda" in device else device + return data.to(device) + + +def prepare_batch(batch: typing.Union[dict, list, paddle.Tensor], + device: str="cpu"): + """Moves items in a batch (typically generated by a DataLoader as a list + or a dict) to the specified device. This works even if dictionaries + are nested. + + Parameters + ---------- + batch : typing.Union[dict, list, paddle.Tensor] + Batch, typically generated by a dataloader, that will be moved to + the device. + device : str, optional + Device to move batch to, by default "cpu" + + Returns + ------- + typing.Union[dict, list, paddle.Tensor] + Batch with all values moved to the specified device. + """ + device = device.replace("cuda", "gpu") + if isinstance(batch, dict): + batch = flatten(batch) + for key, val in batch.items(): + try: + # batch[key] = val.to(device) + batch[key] = move_to_device(val, device) + except: + pass + batch = unflatten(batch) + elif paddle.is_tensor(batch): + # batch = batch.to(device) + batch = move_to_device(batch, device) + elif isinstance(batch, list): + for i in range(len(batch)): + try: + batch[i] = batch[i].to(device) + except: + pass + return batch + + +def sample_from_dist(dist_tuple: tuple, state: np.random.RandomState=None): + """Samples from a distribution defined by a tuple. The first + item in the tuple is the distribution type, and the rest of the + items are arguments to that distribution. The distribution function + is gotten from the ``np.random.RandomState`` object. + + Parameters + ---------- + dist_tuple : tuple + Distribution tuple + state : np.random.RandomState, optional + Random state, or seed to use, by default None + + Returns + ------- + typing.Union[float, int, str] + Draw from the distribution. + + Examples + -------- + Sample from a uniform distribution: + + >>> dist_tuple = ("uniform", 0, 1) + >>> sample_from_dist(dist_tuple) + + Sample from a constant distribution: + + >>> dist_tuple = ("const", 0) + >>> sample_from_dist(dist_tuple) + + Sample from a normal distribution: + + >>> dist_tuple = ("normal", 0, 0.5) + >>> sample_from_dist(dist_tuple) + + """ + if dist_tuple[0] == "const": + return dist_tuple[1] + state = random_state(state) + dist_fn = getattr(state, dist_tuple[0]) + return dist_fn(*dist_tuple[1:]) + + +BASE_SIZE = 864 +DEFAULT_FIG_SIZE = (9, 3) + + +def format_figure( + fig_size: tuple=None, + title: str=None, + fig=None, + format_axes: bool=True, + format: bool=True, + font_color: str="white", ): + """Prettifies the spectrogram and waveform plots. A title + can be inset into the top right corner, and the axes can be + inset into the figure, allowing the data to take up the entire + image. Used in + + - :py:func:`audiotools.core.display.DisplayMixin.specshow` + - :py:func:`audiotools.core.display.DisplayMixin.waveplot` + - :py:func:`audiotools.core.display.DisplayMixin.wavespec` + + Parameters + ---------- + fig_size : tuple, optional + Size of figure, by default (9, 3) + title : str, optional + Title to inset in top right, by default None + fig : matplotlib.figure.Figure, optional + Figure object, if None ``plt.gcf()`` will be used, by default None + format_axes : bool, optional + Format the axes to be inside the figure, by default True + format : bool, optional + This formatting can be skipped entirely by passing ``format=False`` + to any of the plotting functions that use this formater, by default True + font_color : str, optional + Color of font of axes, by default "white" + """ + import matplotlib + import matplotlib.pyplot as plt + + if fig_size is None: + fig_size = DEFAULT_FIG_SIZE + if not format: + return + if fig is None: + fig = plt.gcf() + fig.set_size_inches(*fig_size) + axs = fig.axes + + pixels = (fig.get_size_inches() * fig.dpi)[0] + font_scale = pixels / BASE_SIZE + + if format_axes: + axs = fig.axes + + for ax in axs: + ymin, _ = ax.get_ylim() + xmin, _ = ax.get_xlim() + + ticks = ax.get_yticks() + for t in ticks[2:-1]: + t = axs[0].annotate( + f"{(t / 1000):2.1f}k", + xy=(xmin, t), + xycoords="data", + xytext=(5, -5), + textcoords="offset points", + ha="left", + va="top", + color=font_color, + fontsize=12 * font_scale, + alpha=0.75, ) + + ticks = ax.get_xticks()[2:] + for t in ticks[:-1]: + t = axs[0].annotate( + f"{t:2.1f}s", + xy=(t, ymin), + xycoords="data", + xytext=(5, 5), + textcoords="offset points", + ha="center", + va="bottom", + color=font_color, + fontsize=12 * font_scale, + alpha=0.75, ) + + ax.margins(0, 0) + ax.set_axis_off() + ax.xaxis.set_major_locator(plt.NullLocator()) + ax.yaxis.set_major_locator(plt.NullLocator()) + + plt.subplots_adjust( + top=1, bottom=0, right=1, left=0, hspace=0, wspace=0) + + if title is not None: + t = axs[0].annotate( + title, + xy=(1, 1), + xycoords="axes fraction", + fontsize=20 * font_scale, + xytext=(-5, -5), + textcoords="offset points", + ha="right", + va="top", + color="white", ) + t.set_bbox(dict(facecolor="black", alpha=0.5, edgecolor="black")) + + +_default_collate_err_msg_format = ( + "default_collate: batch must contain tensors, numpy arrays, numbers, " + "dicts or lists; found {}") + + +def collate_tensor_fn( + batch, + *, + collate_fn_map: Optional[Dict[Union[type, Tuple[type, ...]], + Callable]]=None, ): + out = paddle.stack(batch, axis=0) + return out + + +def collate_float_fn( + batch, + *, + collate_fn_map: Optional[Dict[Union[Type, Tuple[Type, ...]], + Callable]]=None, ): + return paddle.to_tensor(batch, dtype=paddle.float64) + + +def collate_int_fn( + batch, + *, + collate_fn_map: Optional[Dict[Union[Type, Tuple[Type, ...]], + Callable]]=None, ): + return paddle.to_tensor(batch) + + +def collate_str_fn( + batch, + *, + collate_fn_map: Optional[Dict[Union[Type, Tuple[Type, ...]], + Callable]]=None, ): + return batch + + +default_collate_fn_map: Dict[Union[Type, Tuple[Type, ...]], Callable] = { + paddle.Tensor: collate_tensor_fn +} +default_collate_fn_map[float] = collate_float_fn +default_collate_fn_map[int] = collate_int_fn +default_collate_fn_map[str] = collate_str_fn +default_collate_fn_map[bytes] = collate_str_fn + + +def default_collate(batch, + *, + collate_fn_map: Optional[Dict[Union[Type, Tuple[Type, ...]], + Callable]]=None): + r""" + General collate function that handles collection type of element within each batch. + + The function also opens function registry to deal with specific element types. `default_collate_fn_map` + provides default collate functions for tensors, numpy arrays, numbers and strings. + + Args: + batch: a single batch to be collated + collate_fn_map: Optional dictionary mapping from element type to the corresponding collate function. + If the element type isn't present in this dictionary, + this function will go through each key of the dictionary in the insertion order to + invoke the corresponding collate function if the element type is a subclass of the key. + Note: + Each collate function requires a positional argument for batch and a keyword argument + for the dictionary of collate functions as `collate_fn_map`. + """ + elem = batch[0] + elem_type = type(elem) + + if collate_fn_map is not None: + if elem_type in collate_fn_map: + return collate_fn_map[elem_type]( + batch, collate_fn_map=collate_fn_map) + + for collate_type in collate_fn_map: + if isinstance(elem, collate_type): + return collate_fn_map[collate_type]( + batch, collate_fn_map=collate_fn_map) + + if isinstance(elem, collections.abc.Mapping): + try: + return elem_type({ + key: default_collate( + [d[key] for d in batch], collate_fn_map=collate_fn_map) + for key in elem + }) + except TypeError: + # The mapping type may not support `__init__(iterable)`. + return { + key: default_collate( + [d[key] for d in batch], collate_fn_map=collate_fn_map) + for key in elem + } + elif isinstance(elem, tuple) and hasattr(elem, "_fields"): # namedtuple + return elem_type(*(default_collate( + samples, collate_fn_map=collate_fn_map) for samples in zip(*batch))) + elif isinstance(elem, collections.abc.Sequence): + # check to make sure that the elements in batch have consistent size + it = iter(batch) + elem_size = len(next(it)) + if not all(len(elem) == elem_size for elem in it): + raise RuntimeError( + "each element in list of batch should be of equal size") + transposed = list(zip( + *batch)) # It may be accessed twice, so we use a list. + + if isinstance(elem, tuple): + return [ + default_collate(samples, collate_fn_map=collate_fn_map) + for samples in transposed + ] # Backwards compatibility. + else: + try: + return elem_type([ + default_collate(samples, collate_fn_map=collate_fn_map) + for samples in transposed + ]) + except TypeError: + # The sequence type may not support `__init__(iterable)` (e.g., `range`). + return [ + default_collate(samples, collate_fn_map=collate_fn_map) + for samples in transposed + ] + + raise TypeError(_default_collate_err_msg_format.format(elem_type)) + + +def collate(list_of_dicts: list, n_splits: int=None): + """Collates a list of dictionaries (e.g. as returned by a + dataloader) into a dictionary with batched values. This routine + uses the default torch collate function for everything + except AudioSignal objects, which are handled by the + :py:func:`audiotools.core.audio_signal.AudioSignal.batch` + function. + + This function takes n_splits to enable splitting a batch + into multiple sub-batches for the purposes of gradient accumulation, + etc. + + Parameters + ---------- + list_of_dicts : list + List of dictionaries to be collated. + n_splits : int + Number of splits to make when creating the batches (split into + sub-batches). Useful for things like gradient accumulation. + + Returns + ------- + dict + Dictionary containing batched data. + """ + + batches = [] + list_len = len(list_of_dicts) + + return_list = False if n_splits is None else True + n_splits = 1 if n_splits is None else n_splits + n_items = int(math.ceil(list_len / n_splits)) + + for i in range(0, list_len, n_items): + # Flatten the dictionaries to avoid recursion. + list_of_dicts_ = [flatten(d) for d in list_of_dicts[i:i + n_items]] + dict_of_lists = { + k: [dic[k] for dic in list_of_dicts_] + for k in list_of_dicts_[0] + } + + batch = {} + for k, v in dict_of_lists.items(): + if isinstance(v, list): + if all(isinstance(s, AudioSignal) for s in v): + batch[k] = AudioSignal.batch(v, pad_signals=True) + else: + batch[k] = default_collate( + v, collate_fn_map=default_collate_fn_map) + batches.append(unflatten(batch)) + + batches = batches[0] if not return_list else batches + return batches + + +def hz_to_bin(hz: paddle.Tensor, n_fft: int, sample_rate: int): + """Closest frequency bin given a frequency, number + of bins, and a sampling rate. + + Parameters + ---------- + hz : paddle.Tensor + Tensor of frequencies in Hz. + n_fft : int + Number of FFT bins. + sample_rate : int + Sample rate of audio. + + Returns + ------- + paddle.Tensor + Closest bins to the data. + """ + shape = hz.shape + hz = hz.reshape([-1]) + freqs = paddle.linspace(0, sample_rate / 2, 2 + n_fft // 2) + hz = paddle.clip(hz, max=sample_rate / 2).astype(freqs.dtype) + + closest = (hz[None, :] - freqs[:, None]).abs() + closest_bins = closest.argmin(axis=0) + + return closest_bins.reshape(shape) + + +def generate_chord_dataset( + max_voices: int=8, + sample_rate: int=44100, + num_items: int=5, + duration: float=1.0, + min_note: str="C2", + max_note: str="C6", + output_dir: Path="chords", ): + """ + Generates a toy multitrack dataset of chords, synthesized from sine waves. + + + Parameters + ---------- + max_voices : int, optional + Maximum number of voices in a chord, by default 8 + sample_rate : int, optional + Sample rate of audio, by default 44100 + num_items : int, optional + Number of items to generate, by default 5 + duration : float, optional + Duration of each item, by default 1.0 + min_note : str, optional + Minimum note in the dataset, by default "C2" + max_note : str, optional + Maximum note in the dataset, by default "C6" + output_dir : Path, optional + Directory to save the dataset, by default "chords" + + """ + import librosa + from . import AudioSignal + from ..data.preprocess import create_csv + + min_midi = librosa.note_to_midi(min_note) + max_midi = librosa.note_to_midi(max_note) + + tracks = [] + for idx in range(num_items): + track = {} + # figure out how many voices to put in this track + num_voices = random.randint(1, max_voices) + for voice_idx in range(num_voices): + # choose some random params + midinote = random.randint(min_midi, max_midi) + dur = random.uniform(0.85 * duration, duration) + + sig = AudioSignal.wave( + frequency=librosa.midi_to_hz(midinote), + duration=dur, + sample_rate=sample_rate, + shape="sine", ) + track[f"voice_{voice_idx}"] = sig + tracks.append(track) + + # save the tracks to disk + output_dir = Path(output_dir) + output_dir.mkdir(exist_ok=True) + for idx, track in enumerate(tracks): + track_dir = output_dir / f"track_{idx}" + track_dir.mkdir(exist_ok=True) + for voice_name, sig in track.items(): + sig.write(track_dir / f"{voice_name}.wav") + + all_voices = list(set([k for track in tracks for k in track.keys()])) + voice_lists = {voice: [] for voice in all_voices} + for track in tracks: + for voice_name in all_voices: + if voice_name in track: + voice_lists[voice_name].append(track[voice_name].path_to_file) + else: + voice_lists[voice_name].append("") + + for voice_name, paths in voice_lists.items(): + create_csv(paths, output_dir / f"{voice_name}.csv", loudness=True) + + return output_dir diff --git a/audio/audiotools/data/__init__.py b/audio/audiotools/data/__init__.py new file mode 100644 index 00000000000..3f170c64a12 --- /dev/null +++ b/audio/audiotools/data/__init__.py @@ -0,0 +1,16 @@ +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from . import datasets +from . import preprocess +from . import transforms diff --git a/audio/audiotools/data/datasets.py b/audio/audiotools/data/datasets.py new file mode 100644 index 00000000000..37daaef053f --- /dev/null +++ b/audio/audiotools/data/datasets.py @@ -0,0 +1,548 @@ +# MIT License, Copyright (c) 2023-Present, Descript. +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Modified from audiotools(https://github.com/descriptinc/audiotools/blob/master/audiotools/data/datasets.py) +from pathlib import Path +from typing import Callable +from typing import Dict +from typing import List +from typing import Union + +import numpy as np +import paddle +from paddle.io import DistributedBatchSampler +from paddle.io import SequenceSampler + +from ..core import AudioSignal +from ..core import util + +__all__ = [ + "AudioLoader", "AudioDataset", "ConcatDataset", + "ResumableDistributedSampler", "ResumableSequentialSampler" +] + + +class AudioLoader: + """Loads audio endlessly from a list of audio sources + containing paths to audio files. Audio sources can be + folders full of audio files (which are found via file + extension) or by providing a CSV file which contains paths + to audio files. + + Parameters + ---------- + sources : List[str], optional + Sources containing folders, or CSVs with + paths to audio files, by default None + weights : List[float], optional + Weights to sample audio files from each source, by default None + relative_path : str, optional + Path audio should be loaded relative to, by default "" + transform : Callable, optional + Transform to instantiate alongside audio sample, + by default None + ext : List[str] + List of extensions to find audio within each source by. Can + also be a file name (e.g. "vocals.wav"). by default + ``['.wav', '.flac', '.mp3', '.mp4']``. + shuffle: bool + Whether to shuffle the files within the dataloader. Defaults to True. + shuffle_state: int + State to use to seed the shuffle of the files. + """ + + def __init__( + self, + sources: List[str]=None, + weights: List[float]=None, + transform: Callable=None, + relative_path: str="", + ext: List[str]=util.AUDIO_EXTENSIONS, + shuffle: bool=True, + shuffle_state: int=0, ): + self.audio_lists = util.read_sources( + sources, relative_path=relative_path, ext=ext) + + self.audio_indices = [(src_idx, item_idx) + for src_idx, src in enumerate(self.audio_lists) + for item_idx in range(len(src))] + if shuffle: + state = util.random_state(shuffle_state) + state.shuffle(self.audio_indices) + + self.sources = sources + self.weights = weights + self.transform = transform + + def __call__( + self, + state, + sample_rate: int, + duration: float, + loudness_cutoff: float=-40, + num_channels: int=1, + offset: float=None, + source_idx: int=None, + item_idx: int=None, + global_idx: int=None, ): + if source_idx is not None and item_idx is not None: + try: + audio_info = self.audio_lists[source_idx][item_idx] + except: + audio_info = {"path": "none"} + elif global_idx is not None: + source_idx, item_idx = self.audio_indices[global_idx % + len(self.audio_indices)] + audio_info = self.audio_lists[source_idx][item_idx] + else: + audio_info, source_idx, item_idx = util.choose_from_list_of_lists( + state, self.audio_lists, p=self.weights) + + path = audio_info["path"] + signal = AudioSignal.zeros(duration, sample_rate, num_channels) + + if path != "none": + if offset is None: + signal = AudioSignal.salient_excerpt( + path, + duration=duration, + state=state, + loudness_cutoff=loudness_cutoff, ) + else: + signal = AudioSignal( + path, + offset=offset, + duration=duration, ) + + if num_channels == 1: + signal = signal.to_mono() + signal = signal.resample(sample_rate) + + if signal.duration < duration: + signal = signal.zero_pad_to(int(duration * sample_rate)) + + for k, v in audio_info.items(): + signal.metadata[k] = v + + item = { + "signal": signal, + "source_idx": source_idx, + "item_idx": item_idx, + "source": str(self.sources[source_idx]), + "path": str(path), + } + if self.transform is not None: + item["transform_args"] = self.transform.instantiate( + state, signal=signal) + return item + + +def default_matcher(x, y): + return Path(x).parent == Path(y).parent + + +def align_lists(lists, matcher: Callable=default_matcher): + longest_list = lists[np.argmax([len(l) for l in lists])] + for i, x in enumerate(longest_list): + for l in lists: + if i >= len(l): + l.append({"path": "none"}) + elif not matcher(l[i]["path"], x["path"]): + l.insert(i, {"path": "none"}) + return lists + + +class AudioDataset: + """Loads audio from multiple loaders (with associated transforms) + for a specified number of samples. Excerpts are drawn randomly + of the specified duration, above a specified loudness threshold + and are resampled on the fly to the desired sample rate + (if it is different from the audio source sample rate). + + This takes either a single AudioLoader object, + a dictionary of AudioLoader objects, or a dictionary of AudioLoader + objects. Each AudioLoader is called by the dataset, and the + result is placed in the output dictionary. A transform can also be + specified for the entire dataset, rather than for each specific + loader. This transform can be applied to the output of all the + loaders if desired. + + AudioLoader objects can be specified as aligned, which means the + loaders correspond to multitrack audio (e.g. a vocals, bass, + drums, and other loader for multitrack music mixtures). + + + Parameters + ---------- + loaders : Union[AudioLoader, List[AudioLoader], Dict[str, AudioLoader]] + AudioLoaders to sample audio from. + sample_rate : int + Desired sample rate. + n_examples : int, optional + Number of examples (length of dataset), by default 1000 + duration : float, optional + Duration of audio samples, by default 0.5 + loudness_cutoff : float, optional + Loudness cutoff threshold for audio samples, by default -40 + num_channels : int, optional + Number of channels in output audio, by default 1 + transform : Callable, optional + Transform to instantiate alongside each dataset item, by default None + aligned : bool, optional + Whether the loaders should be sampled in an aligned manner (e.g. same + offset, duration, and matched file name), by default False + shuffle_loaders : bool, optional + Whether to shuffle the loaders before sampling from them, by default False + matcher : Callable + How to match files from adjacent audio lists (e.g. for a multitrack audio loader), + by default uses the parent directory of each file. + without_replacement : bool + Whether to choose files with or without replacement, by default True. + + + Examples + -------- + >>> from audio.audiotools.data.datasets import AudioLoader + >>> from audio.audiotools.data.datasets import AudioDataset + >>> from audio.audiotools import transforms as tfm + >>> import numpy as np + >>> + >>> loaders = [ + >>> AudioLoader( + >>> sources=[f"tests/audiotools/audio/spk"], + >>> transform=tfm.Equalizer(), + >>> ext=["wav"], + >>> ) + >>> for i in range(5) + >>> ] + >>> + >>> dataset = AudioDataset( + >>> loaders = loaders, + >>> sample_rate = 44100, + >>> duration = 1.0, + >>> transform = tfm.RescaleAudio(), + >>> ) + >>> + >>> item = dataset[np.random.randint(len(dataset))] + >>> + >>> for i in range(len(loaders)): + >>> item[i]["signal"] = loaders[i].transform( + >>> item[i]["signal"], **item[i]["transform_args"] + >>> ) + >>> item[i]["signal"].widget(i) + >>> + >>> mix = sum([item[i]["signal"] for i in range(len(loaders))]) + >>> mix = dataset.transform(mix, **item["transform_args"]) + >>> mix.widget("mix") + + Below is an example of how one could load MUSDB multitrack data: + + >>> from audio import audiotools as at + >>> from pathlib import Path + >>> from audio.audiotools import transforms as tfm + >>> import numpy as np + >>> import torch + >>> + >>> def build_dataset( + >>> sample_rate: int = 44100, + >>> duration: float = 5.0, + >>> musdb_path: str = "~/.data/musdb/", + >>> ): + >>> musdb_path = Path(musdb_path).expanduser() + >>> loaders = { + >>> src: at.datasets.AudioLoader( + >>> sources=[musdb_path], + >>> transform=tfm.Compose( + >>> tfm.VolumeNorm(("uniform", -20, -10)), + >>> tfm.Silence(prob=0.1), + >>> ), + >>> ext=[f"{src}.wav"], + >>> ) + >>> for src in ["vocals", "bass", "drums", "other"] + >>> } + >>> + >>> dataset = at.datasets.AudioDataset( + >>> loaders=loaders, + >>> sample_rate=sample_rate, + >>> duration=duration, + >>> num_channels=1, + >>> aligned=True, + >>> transform=tfm.RescaleAudio(), + >>> shuffle_loaders=True, + >>> ) + >>> return dataset, list(loaders.keys()) + >>> + >>> train_data, sources = build_dataset() + >>> dataloader = torch.utils.data.DataLoader( + >>> train_data, + >>> batch_size=16, + >>> num_workers=0, + >>> collate_fn=train_data.collate, + >>> ) + >>> batch = next(iter(dataloader)) + >>> + >>> for k in sources: + >>> src = batch[k] + >>> src["transformed"] = train_data.loaders[k].transform( + >>> src["signal"].clone(), **src["transform_args"] + >>> ) + >>> + >>> mixture = sum(batch[k]["transformed"] for k in sources) + >>> mixture = train_data.transform(mixture, **batch["transform_args"]) + >>> + >>> # Say a model takes the mix and gives back (n_batch, n_src, n_time). + >>> # Construct the targets: + >>> targets = at.AudioSignal.batch([batch[k]["transformed"] for k in sources], dim=1) + + Similarly, here's example code for loading Slakh data: + + >>> from audio import audiotools as at + >>> from pathlib import Path + >>> from audio.audiotools import transforms as tfm + >>> import numpy as np + >>> import torch + >>> import glob + >>> + >>> def build_dataset( + >>> sample_rate: int = 16000, + >>> duration: float = 10.0, + >>> slakh_path: str = "~/.data/slakh/", + >>> ): + >>> slakh_path = Path(slakh_path).expanduser() + >>> + >>> # Find the max number of sources in Slakh + >>> src_names = [x.name for x in list(slakh_path.glob("**/*.wav")) if "S" in str(x.name)] + >>> n_sources = len(list(set(src_names))) + >>> + >>> loaders = { + >>> f"S{i:02d}": at.datasets.AudioLoader( + >>> sources=[slakh_path], + >>> transform=tfm.Compose( + >>> tfm.VolumeNorm(("uniform", -20, -10)), + >>> tfm.Silence(prob=0.1), + >>> ), + >>> ext=[f"S{i:02d}.wav"], + >>> ) + >>> for i in range(n_sources) + >>> } + >>> dataset = at.datasets.AudioDataset( + >>> loaders=loaders, + >>> sample_rate=sample_rate, + >>> duration=duration, + >>> num_channels=1, + >>> aligned=True, + >>> transform=tfm.RescaleAudio(), + >>> shuffle_loaders=False, + >>> ) + >>> + >>> return dataset, list(loaders.keys()) + >>> + >>> train_data, sources = build_dataset() + >>> dataloader = torch.utils.data.DataLoader( + >>> train_data, + >>> batch_size=16, + >>> num_workers=0, + >>> collate_fn=train_data.collate, + >>> ) + >>> batch = next(iter(dataloader)) + >>> + >>> for k in sources: + >>> src = batch[k] + >>> src["transformed"] = train_data.loaders[k].transform( + >>> src["signal"].clone(), **src["transform_args"] + >>> ) + >>> + >>> mixture = sum(batch[k]["transformed"] for k in sources) + >>> mixture = train_data.transform(mixture, **batch["transform_args"]) + + """ + + def __init__( + self, + loaders: Union[AudioLoader, List[AudioLoader], Dict[str, + AudioLoader]], + sample_rate: int, + n_examples: int=1000, + duration: float=0.5, + offset: float=None, + loudness_cutoff: float=-40, + num_channels: int=1, + transform: Callable=None, + aligned: bool=False, + shuffle_loaders: bool=False, + matcher: Callable=default_matcher, + without_replacement: bool=True, ): + # Internally we convert loaders to a dictionary + if isinstance(loaders, list): + loaders = {i: l for i, l in enumerate(loaders)} + elif isinstance(loaders, AudioLoader): + loaders = {0: loaders} + + self.loaders = loaders + self.loudness_cutoff = loudness_cutoff + self.num_channels = num_channels + + self.length = n_examples + self.transform = transform + self.sample_rate = sample_rate + self.duration = duration + self.offset = offset + self.aligned = aligned + self.shuffle_loaders = shuffle_loaders + self.without_replacement = without_replacement + + if aligned: + loaders_list = list(loaders.values()) + for i in range(len(loaders_list[0].audio_lists)): + input_lists = [l.audio_lists[i] for l in loaders_list] + # Alignment happens in-place + align_lists(input_lists, matcher) + + def __getitem__(self, idx): + state = util.random_state(idx) + offset = None if self.offset is None else self.offset + item = {} + + keys = list(self.loaders.keys()) + if self.shuffle_loaders: + state.shuffle(keys) + + loader_kwargs = { + "state": state, + "sample_rate": self.sample_rate, + "duration": self.duration, + "loudness_cutoff": self.loudness_cutoff, + "num_channels": self.num_channels, + "global_idx": idx if self.without_replacement else None, + } + + # Draw item from first loader + loader = self.loaders[keys[0]] + item[keys[0]] = loader(**loader_kwargs) + + for key in keys[1:]: + loader = self.loaders[key] + if self.aligned: + # Path mapper takes the current loader + everything + # returned by the first loader. + offset = item[keys[0]]["signal"].metadata["offset"] + loader_kwargs.update({ + "offset": offset, + "source_idx": item[keys[0]]["source_idx"], + "item_idx": item[keys[0]]["item_idx"], + }) + item[key] = loader(**loader_kwargs) + + # Sort dictionary back into original order + keys = list(self.loaders.keys()) + item = {k: item[k] for k in keys} + + item["idx"] = idx + if self.transform is not None: + item["transform_args"] = self.transform.instantiate( + state=state, signal=item[keys[0]]["signal"]) + + # If there's only one loader, pop it up + # to the main dictionary, instead of keeping it + # nested. + if len(keys) == 1: + item.update(item.pop(keys[0])) + + return item + + def __len__(self): + return self.length + + @staticmethod + def collate(list_of_dicts: Union[list, dict], n_splits: int=None): + """Collates items drawn from this dataset. Uses + :py:func:`audiotools.core.util.collate`. + + Parameters + ---------- + list_of_dicts : typing.Union[list, dict] + Data drawn from each item. + n_splits : int + Number of splits to make when creating the batches (split into + sub-batches). Useful for things like gradient accumulation. + + Returns + ------- + dict + Dictionary of batched data. + """ + return util.collate(list_of_dicts, n_splits=n_splits) + + +class ConcatDataset(AudioDataset): + # + def __init__(self, datasets: list): + self.datasets = datasets + + def __len__(self): + return sum([len(d) for d in self.datasets]) + + def __getitem__(self, idx): + dataset = self.datasets[idx % len(self.datasets)] + return dataset[idx // len(self.datasets)] + + +class ResumableDistributedSampler(DistributedBatchSampler): + """Distributed sampler that can be resumed from a given start index.""" + + def __init__(self, + dataset, + batch_size, + start_idx: int=None, + num_replicas=None, + rank=None, + shuffle=False, + drop_last=False): + super().__init__( + dataset=dataset, + batch_size=batch_size, + num_replicas=num_replicas, + rank=rank, + shuffle=shuffle, + drop_last=drop_last, ) + # Start index, allows to resume an experiment at the index it was + if start_idx is not None: + self.start_idx = start_idx // self.num_replicas + else: + self.start_idx = 0 + # 重新计算样本总数,因为 DistributedBatchSampler 的 __len__ 方法是基于 shuffle 后的样本总数计算的 + self.total_size = len(self.dataset) if not shuffle else len( + self.indices) + + def __iter__(self): + # 由于 Paddle 的 DistributedBatchSampler 直接返回 batch,我们需要将其展开为单个索引 + indices_iter = iter(super().__iter__()) + # 跳过前面的 start_idx 个 batch + for _ in range(self.start_idx): + next(indices_iter) + + current_idx = 0 + while True: + batch_indices = next(indices_iter, None) + if batch_indices is None: + break + for idx in batch_indices: + if current_idx >= self.start_idx * self.batch_size: # 调整判断条件,确保从 start_idx 开始 + yield idx + current_idx += 1 + self.start_idx = 0 # set the index back to 0 so for the next epoch + + +class ResumableSequentialSampler(SequenceSampler): + """Sequential sampler that can be resumed from a given start index.""" + + def __init__(self, dataset, start_idx: int=None, **kwargs): + super().__init__(dataset, **kwargs) + # Start index, allows to resume an experiment at the index it was + self.start_idx = start_idx if start_idx is not None else 0 + + def __iter__(self): + for i, idx in enumerate(super().__iter__()): + if i >= self.start_idx: + yield idx + self.start_idx = 0 # set the index back to 0 so for the next epoch diff --git a/audio/audiotools/data/preprocess.py b/audio/audiotools/data/preprocess.py new file mode 100644 index 00000000000..1f609c00b36 --- /dev/null +++ b/audio/audiotools/data/preprocess.py @@ -0,0 +1,87 @@ +# MIT License, Copyright (c) 2023-Present, Descript. +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Modified from audiotools(https://github.com/descriptinc/audiotools/blob/master/audiotools/data/preprocess.py) +import csv +import os +from pathlib import Path + +from tqdm import tqdm + +from ..core import AudioSignal + + +def create_csv(audio_files: list, + output_csv: Path, + loudness: bool=False, + data_path: str=None): + """Converts a folder of audio files to a CSV file. If ``loudness = True``, + the output of this function will create a CSV file that looks something + like: + + .. csv-table:: + :header: path,loudness + + daps/produced/f1_script1_produced.wav,-16.299999237060547 + daps/produced/f1_script2_produced.wav,-16.600000381469727 + daps/produced/f1_script3_produced.wav,-17.299999237060547 + daps/produced/f1_script4_produced.wav,-16.100000381469727 + daps/produced/f1_script5_produced.wav,-16.700000762939453 + daps/produced/f3_script1_produced.wav,-16.5 + + .. note:: + The paths above are written relative to the ``data_path`` argument + which defaults to the environment variable ``PATH_TO_DATA`` if + it isn't passed to this function, and defaults to the empty string + if that environment variable is not set. + + You can produce a CSV file from a directory of audio files via: + + >>> from audio import audiotools + >>> directory = ... + >>> audio_files = audiotools.util.find_audio(directory) + >>> output_path = "train.csv" + >>> audiotools.data.preprocess.create_csv( + >>> audio_files, output_csv, loudness=True + >>> ) + + Note that you can create empty rows in the CSV file by passing an empty + string or None in the ``audio_files`` list. This is useful if you want to + sync multiple CSV files in a multitrack setting. The loudness of these + empty rows will be set to -inf. + + Parameters + ---------- + audio_files : list + List of audio files. + output_csv : Path + Output CSV, with each row containing the relative path of every file + to ``data_path``, if specified (defaults to None). + loudness : bool + Compute loudness of entire file and store alongside path. + """ + + info = [] + pbar = tqdm(audio_files) + for af in pbar: + af = Path(af) + pbar.set_description(f"Processing {af.name}") + _info = {} + if af.name == "": + _info["path"] = "" + if loudness: + _info["loudness"] = -float("inf") + else: + _info["path"] = af.relative_to( + data_path) if data_path is not None else af + if loudness: + _info["loudness"] = AudioSignal(af).ffmpeg_loudness().item() + + info.append(_info) + + with open(output_csv, "w") as f: + writer = csv.DictWriter(f, fieldnames=list(info[0].keys())) + writer.writeheader() + + for item in info: + writer.writerow(item) diff --git a/audio/audiotools/data/transforms.py b/audio/audiotools/data/transforms.py new file mode 100644 index 00000000000..fd742a78cc9 --- /dev/null +++ b/audio/audiotools/data/transforms.py @@ -0,0 +1,1182 @@ +# MIT License, Copyright (c) 2023-Present, Descript. +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Modified from audiotools(https://github.com/descriptinc/audiotools/blob/master/audiotools/data/transforms.py) +import copy +from contextlib import contextmanager +from inspect import signature +from typing import List + +import numpy as np +import paddle +from flatten_dict import flatten +from flatten_dict import unflatten +from numpy.random import RandomState + +from .. import ml +from ..core import AudioSignal +from ..core import util +from .datasets import AudioLoader +from paddlespeech.utils import satisfy_paddle_version + +__all__ = [ + "Identity", + "SpectralTransform", + "Compose", + "Choose", + "Repeat", + "RepeatUpTo", + "ClippingDistortion", + "Equalizer", + "BackgroundNoise", + "RoomImpulseResponse", + "VolumeNorm", + "GlobalVolumeNorm", + "Silence", + "LowPass", + "HighPass", + "FrequencyMask", + "TimeMask", + "Smoothing", + "FrequencyNoise", +] + + +class BaseTransform: + """This is the base class for all transforms that are implemented + in this library. Transforms have two main operations: ``transform`` + and ``instantiate``. + + ``instantiate`` sets the parameters randomly + from distribution tuples for each parameter. For example, for the + ``BackgroundNoise`` transform, the signal-to-noise ratio (``snr``) + is chosen randomly by instantiate. By default, it chosen uniformly + between 10.0 and 30.0 (the tuple is set to ``("uniform", 10.0, 30.0)``). + + ``transform`` applies the transform using the instantiated parameters. + A simple example is as follows: + + >>> seed = 0 + >>> signal = ... + >>> transform = transforms.NoiseFloor(db = ("uniform", -50.0, -30.0)) + >>> kwargs = transform.instantiate() + >>> output = transform(signal.clone(), **kwargs) + + By breaking apart the instantiation of parameters from the actual audio + processing of the transform, we can make things more reproducible, while + also applying the transform on batches of data efficiently on GPU, + rather than on individual audio samples. + + .. note:: + We call ``signal.clone()`` for the input to the ``transform`` function + because signals are modified in-place! If you don't clone the signal, + you will lose the original data. + + Parameters + ---------- + keys : list, optional + Keys that the transform looks for when + calling ``self.transform``, by default []. In general this is + set automatically, and you won't need to manipulate this argument. + name : str, optional + Name of this transform, used to identify it in the dictionary + produced by ``self.instantiate``, by default None + prob : float, optional + Probability of applying this transform, by default 1.0 + + Examples + -------- + + >>> seed = 0 + >>> + >>> audio_path = "tests/audio/spk/f10_script4_produced.wav" + >>> signal = AudioSignal(audio_path, offset=10, duration=2) + >>> transform = tfm.Compose( + >>> [ + >>> tfm.RoomImpulseResponse(sources=["tests/audio/irs.csv"]), + >>> tfm.BackgroundNoise(sources=["tests/audio/noises.csv"]), + >>> ], + >>> ) + >>> + >>> kwargs = transform.instantiate(seed, signal) + >>> output = transform(signal, **kwargs) + + """ + + def __init__(self, keys: list=[], name: str=None, prob: float=1.0): + # Get keys from the _transform signature. + tfm_keys = list(signature(self._transform).parameters.keys()) + + # Filter out signal and kwargs keys. + ignore_keys = ["signal", "kwargs"] + tfm_keys = [k for k in tfm_keys if k not in ignore_keys] + + # Combine keys specified by the child class, the keys found in + # _transform signature, and the mask key. + self.keys = keys + tfm_keys + ["mask"] + + self.prob = prob + + if name is None: + name = self.__class__.__name__ + self.name = name + + def _prepare(self, batch: dict): + sub_batch = batch[self.name] + + for k in self.keys: + assert k in sub_batch.keys(), f"{k} not in batch" + + return sub_batch + + def _transform(self, signal): + return signal + + def _instantiate(self, state: RandomState, signal: AudioSignal=None): + return {} + + @staticmethod + def apply_mask(batch: dict, mask: paddle.Tensor): + """Applies a mask to the batch. + + Parameters + ---------- + batch : dict + Batch whose values will be masked in the ``transform`` pass. + mask : paddle.Tensor + Mask to apply to batch. + + Returns + ------- + dict + A dictionary that contains values only where ``mask = True``. + """ + # masked_batch = {k: v[mask] for k, v in flatten(batch).items()} + masked_batch = {} + for k, v in flatten(batch).items(): + # `v` may be `Tensor` or `AudioSignal` + if 0 == len(v.shape) and 0 == mask.dim(): + if mask: # 0d 的 True + masked_batch[k] = v.unsqueeze(0) + else: + masked_batch[k] = paddle.to_tensor([], dtype=v.dtype) + else: + if not satisfy_paddle_version('2.6'): + if 0 == mask.dim() and bool(mask) and paddle.is_tensor(v): + masked_batch[k] = v.unsqueeze(0) + else: + masked_batch[k] = v[mask] + else: + masked_batch[k] = v[mask] + return unflatten(masked_batch) + + def transform(self, signal: AudioSignal, **kwargs): + """Apply the transform to the audio signal, + with given keyword arguments. + + Parameters + ---------- + signal : AudioSignal + Signal that will be modified by the transforms in-place. + kwargs: dict + Keyword arguments to the specific transforms ``self._transform`` + function. + + Returns + ------- + AudioSignal + Transformed AudioSignal. + + Examples + -------- + + >>> for seed in range(10): + >>> kwargs = transform.instantiate(seed, signal) + >>> output = transform(signal.clone(), **kwargs) + + """ + tfm_kwargs = self._prepare(kwargs) + mask = tfm_kwargs["mask"] + + if paddle.any(mask): + tfm_kwargs = self.apply_mask(tfm_kwargs, mask) + tfm_kwargs = {k: v for k, v in tfm_kwargs.items() if k != "mask"} + signal[mask] = self._transform(signal[mask], **tfm_kwargs) + + return signal + + def __call__(self, *args, **kwargs): + return self.transform(*args, **kwargs) + + def instantiate( + self, + state: RandomState=None, + signal: AudioSignal=None, ): + """Instantiates parameters for the transform. + + Parameters + ---------- + state : RandomState, optional + _description_, by default None + signal : AudioSignal, optional + _description_, by default None + + Returns + ------- + dict + Dictionary containing instantiated arguments for every keyword + argument to ``self._transform``. + + Examples + -------- + + >>> for seed in range(10): + >>> kwargs = transform.instantiate(seed, signal) + >>> output = transform(signal.clone(), **kwargs) + + """ + state = util.random_state(state) + + # Not all instantiates need the signal. Check if signal + # is needed before passing it in, so that the end-user + # doesn't need to have variables they're not using flowing + # into their function. + needs_signal = "signal" in set( + signature(self._instantiate).parameters.keys()) + kwargs = {} + if needs_signal: + kwargs = {"signal": signal} + + # Instantiate the parameters for the transform. + params = self._instantiate(state, **kwargs) + for k in list(params.keys()): + v = params[k] + if isinstance(v, (AudioSignal, paddle.Tensor, dict)): + params[k] = v + else: + params[k] = paddle.to_tensor(v) + mask = state.rand() <= self.prob + params[f"mask"] = paddle.to_tensor(mask) + + # Put the params into a nested dictionary that will be + # used later when calling the transform. This is to avoid + # collisions in the dictionary. + params = {self.name: params} + + return params + + def batch_instantiate( + self, + states: list=None, + signal: AudioSignal=None, ): + """Instantiates arguments for every item in a batch, + given a list of states. Each state in the list + corresponds to one item in the batch. + + Parameters + ---------- + states : list, optional + List of states, by default None + signal : AudioSignal, optional + AudioSignal to pass to the ``self.instantiate`` section + if it is needed for this transform, by default None + + Returns + ------- + dict + Collated dictionary of arguments. + + Examples + -------- + + >>> batch_size = 4 + >>> signal = AudioSignal(audio_path, offset=10, duration=2) + >>> signal_batch = AudioSignal.batch([signal.clone() for _ in range(batch_size)]) + >>> + >>> states = [seed + idx for idx in list(range(batch_size))] + >>> kwargs = transform.batch_instantiate(states, signal_batch) + >>> batch_output = transform(signal_batch, **kwargs) + """ + kwargs = [] + for state in states: + kwargs.append(self.instantiate(state, signal)) + kwargs = util.collate(kwargs) + return kwargs + + +class Identity(BaseTransform): + """This transform just returns the original signal.""" + + pass + + +class SpectralTransform(BaseTransform): + """Spectral transforms require STFT data to exist, since manipulations + of the STFT require the spectrogram. This just calls ``stft`` before + the transform is called, and calls ``istft`` after the transform is + called so that the audio data is written to after the spectral + manipulation. + """ + + def transform(self, signal, **kwargs): + signal.stft() + super().transform(signal, **kwargs) + signal.istft() + return signal + + +class Compose(BaseTransform): + """Compose applies transforms in sequence, one after the other. The + transforms are passed in as positional arguments or as a list like so: + + >>> transform = tfm.Compose( + >>> [ + >>> tfm.RoomImpulseResponse(sources=["tests/audio/irs.csv"]), + >>> tfm.BackgroundNoise(sources=["tests/audio/noises.csv"]), + >>> ], + >>> ) + + This will convolve the signal with a room impulse response, and then + add background noise to the signal. Instantiate instantiates + all the parameters for every transform in the transform list so the + interface for using the Compose transform is the same as everything + else: + + >>> kwargs = transform.instantiate() + >>> output = transform(signal.clone(), **kwargs) + + Under the hood, the transform maps each transform to a unique name + under the hood of the form ``{position}.{name}``, where ``position`` + is the index of the transform in the list. ``Compose`` can nest + within other ``Compose`` transforms, like so: + + >>> preprocess = transforms.Compose( + >>> tfm.GlobalVolumeNorm(), + >>> tfm.CrossTalk(), + >>> name="preprocess", + >>> ) + >>> augment = transforms.Compose( + >>> tfm.RoomImpulseResponse(), + >>> tfm.BackgroundNoise(), + >>> name="augment", + >>> ) + >>> postprocess = transforms.Compose( + >>> tfm.VolumeChange(), + >>> tfm.RescaleAudio(), + >>> tfm.ShiftPhase(), + >>> name="postprocess", + >>> ) + >>> transform = transforms.Compose(preprocess, augment, postprocess), + + This defines 3 composed transforms, and then composes them in sequence + with one another. + + Parameters + ---------- + *transforms : list + List of transforms to apply + name : str, optional + Name of this transform, used to identify it in the dictionary + produced by ``self.instantiate``, by default None + prob : float, optional + Probability of applying this transform, by default 1.0 + """ + + def __init__(self, *transforms: list, name: str=None, prob: float=1.0): + if isinstance(transforms[0], list): + transforms = transforms[0] + + for i, tfm in enumerate(transforms): + tfm.name = f"{i}.{tfm.name}" + + keys = [tfm.name for tfm in transforms] + super().__init__(keys=keys, name=name, prob=prob) + + self.transforms = transforms + self.transforms_to_apply = keys + + @contextmanager + def filter(self, *names: list): + """This can be used to skip transforms entirely when applying + the sequence of transforms to a signal. For example, take + the following transforms with the names ``preprocess, augment, postprocess``. + + >>> preprocess = transforms.Compose( + >>> tfm.GlobalVolumeNorm(), + >>> tfm.CrossTalk(), + >>> name="preprocess", + >>> ) + >>> augment = transforms.Compose( + >>> tfm.RoomImpulseResponse(), + >>> tfm.BackgroundNoise(), + >>> name="augment", + >>> ) + >>> postprocess = transforms.Compose( + >>> tfm.VolumeChange(), + >>> tfm.RescaleAudio(), + >>> tfm.ShiftPhase(), + >>> name="postprocess", + >>> ) + >>> transform = transforms.Compose(preprocess, augment, postprocess) + + If we wanted to apply all 3 to a signal, we do: + + >>> kwargs = transform.instantiate() + >>> output = transform(signal.clone(), **kwargs) + + But if we only wanted to apply the ``preprocess`` and ``postprocess`` + transforms to the signal, we do: + + >>> with transform_fn.filter("preprocess", "postprocess"): + >>> output = transform(signal.clone(), **kwargs) + + Parameters + ---------- + *names : list + List of transforms, identified by name, to apply to signal. + """ + old_transforms = self.transforms_to_apply + self.transforms_to_apply = names + yield + self.transforms_to_apply = old_transforms + + def _transform(self, signal, **kwargs): + for transform in self.transforms: + if any([x in transform.name for x in self.transforms_to_apply]): + signal = transform(signal, **kwargs) + return signal + + def _instantiate(self, state: RandomState, signal: AudioSignal=None): + parameters = {} + for transform in self.transforms: + parameters.update(transform.instantiate(state, signal=signal)) + return parameters + + def __getitem__(self, idx): + return self.transforms[idx] + + def __len__(self): + return len(self.transforms) + + def __iter__(self): + for transform in self.transforms: + yield transform + + +class Choose(Compose): + """Choose logic is the same as :py:func:`audiotools.data.transforms.Compose`, + but instead of applying all the transforms in sequence, it applies just a single transform, + which is chosen for each item in the batch. + + Parameters + ---------- + *transforms : list + List of transforms to apply + weights : list + Probability of choosing any specific transform. + name : str, optional + Name of this transform, used to identify it in the dictionary + produced by ``self.instantiate``, by default None + prob : float, optional + Probability of applying this transform, by default 1.0 + + Examples + -------- + + >>> transforms.Choose(tfm.LowPass(), tfm.HighPass()) + """ + + def __init__( + self, + *transforms: list, + weights: list=None, + name: str=None, + prob: float=1.0, ): + super().__init__(*transforms, name=name, prob=prob) + + if weights is None: + _len = len(self.transforms) + weights = [1 / _len for _ in range(_len)] + self.weights = np.array(weights) + + def _instantiate(self, state: RandomState, signal: AudioSignal=None): + kwargs = super()._instantiate(state, signal) + tfm_idx = list(range(len(self.transforms))) + tfm_idx = state.choice(tfm_idx, p=self.weights) + one_hot = [] + for i, t in enumerate(self.transforms): + mask = kwargs[t.name]["mask"] + if mask.item(): + kwargs[t.name]["mask"] = paddle.to_tensor(i == tfm_idx) + one_hot.append(kwargs[t.name]["mask"]) + kwargs["one_hot"] = one_hot + return kwargs + + +class Repeat(Compose): + """Repeatedly applies a given transform ``n_repeat`` times." + + Parameters + ---------- + transform : BaseTransform + Transform to repeat. + n_repeat : int, optional + Number of times to repeat transform, by default 1 + """ + + def __init__( + self, + transform, + n_repeat: int=1, + name: str=None, + prob: float=1.0, ): + transforms = [copy.copy(transform) for _ in range(n_repeat)] + super().__init__(transforms, name=name, prob=prob) + + self.n_repeat = n_repeat + + +class RepeatUpTo(Choose): + """Repeatedly applies a given transform up to ``max_repeat`` times." + + Parameters + ---------- + transform : BaseTransform + Transform to repeat. + max_repeat : int, optional + Max number of times to repeat transform, by default 1 + weights : list + Probability of choosing any specific number up to ``max_repeat``. + """ + + def __init__( + self, + transform, + max_repeat: int=5, + weights: list=None, + name: str=None, + prob: float=1.0, ): + transforms = [] + for n in range(1, max_repeat): + transforms.append(Repeat(transform, n_repeat=n)) + super().__init__(transforms, name=name, prob=prob, weights=weights) + + self.max_repeat = max_repeat + + +class ClippingDistortion(BaseTransform): + """Adds clipping distortion to signal. Corresponds + to :py:func:`audiotools.core.effects.EffectMixin.clip_distortion`. + + Parameters + ---------- + perc : tuple, optional + Clipping percentile. Values are between 0.0 to 1.0. + Typical values are 0.1 or below, by default ("uniform", 0.0, 0.1) + name : str, optional + Name of this transform, used to identify it in the dictionary + produced by ``self.instantiate``, by default None + prob : float, optional + Probability of applying this transform, by default 1.0 + """ + + def __init__( + self, + perc: tuple=("uniform", 0.0, 0.1), + name: str=None, + prob: float=1.0, ): + super().__init__(name=name, prob=prob) + + self.perc = perc + + def _instantiate(self, state: RandomState): + return {"perc": util.sample_from_dist(self.perc, state)} + + def _transform(self, signal, perc): + return signal.clip_distortion(perc) + + +class Equalizer(BaseTransform): + """Applies an equalization curve to the audio signal. Corresponds + to :py:func:`audiotools.core.effects.EffectMixin.equalizer`. + + Parameters + ---------- + eq_amount : tuple, optional + The maximum dB cut to apply to the audio in any band, + by default ("const", 1.0 dB) + n_bands : int, optional + Number of bands in EQ, by default 6 + name : str, optional + Name of this transform, used to identify it in the dictionary + produced by ``self.instantiate``, by default None + prob : float, optional + Probability of applying this transform, by default 1.0 + """ + + def __init__( + self, + eq_amount: tuple=("const", 1.0), + n_bands: int=6, + name: str=None, + prob: float=1.0, ): + super().__init__(name=name, prob=prob) + + self.eq_amount = eq_amount + self.n_bands = n_bands + + def _instantiate(self, state: RandomState): + eq_amount = util.sample_from_dist(self.eq_amount, state) + eq = -eq_amount * state.rand(self.n_bands) + return {"eq": eq} + + def _transform(self, signal, eq): + return signal.equalizer(eq) + + +class BackgroundNoise(BaseTransform): + """Adds background noise from audio specified by a set of CSV files. + A valid CSV file looks like, and is typically generated by + :py:func:`audiotools.data.preprocess.create_csv`: + + .. csv-table:: + :header: path + + room_tone/m6_script2_clean.wav + room_tone/m6_script2_cleanraw.wav + room_tone/m6_script2_ipad_balcony1.wav + room_tone/m6_script2_ipad_bedroom1.wav + room_tone/m6_script2_ipad_confroom1.wav + room_tone/m6_script2_ipad_confroom2.wav + room_tone/m6_script2_ipad_livingroom1.wav + room_tone/m6_script2_ipad_office1.wav + + .. note:: + All paths are relative to an environment variable called ``PATH_TO_DATA``, + so that CSV files are portable across machines where data may be + located in different places. + + This transform calls :py:func:`audiotools.core.effects.EffectMixin.mix` + and :py:func:`audiotools.core.effects.EffectMixin.equalizer` under the + hood. + + Parameters + ---------- + snr : tuple, optional + Signal-to-noise ratio, by default ("uniform", 10.0, 30.0) + sources : List[str], optional + Sources containing folders, or CSVs with paths to audio files, + by default None + weights : List[float], optional + Weights to sample audio files from each source, by default None + eq_amount : tuple, optional + Amount of equalization to apply, by default ("const", 1.0) + n_bands : int, optional + Number of bands in equalizer, by default 3 + name : str, optional + Name of this transform, used to identify it in the dictionary + produced by ``self.instantiate``, by default None + prob : float, optional + Probability of applying this transform, by default 1.0 + loudness_cutoff : float, optional + Loudness cutoff when loading from audio files, by default None + """ + + def __init__( + self, + snr: tuple=("uniform", 10.0, 30.0), + sources: List[str]=None, + weights: List[float]=None, + eq_amount: tuple=("const", 1.0), + n_bands: int=3, + name: str=None, + prob: float=1.0, + loudness_cutoff: float=None, ): + super().__init__(name=name, prob=prob) + + self.snr = snr + self.eq_amount = eq_amount + self.n_bands = n_bands + self.loader = AudioLoader(sources, weights) + self.loudness_cutoff = loudness_cutoff + + def _instantiate(self, state: RandomState, signal: AudioSignal): + eq_amount = util.sample_from_dist(self.eq_amount, state) + eq = -eq_amount * state.rand(self.n_bands) + snr = util.sample_from_dist(self.snr, state) + + bg_signal = self.loader( + state, + signal.sample_rate, + duration=signal.signal_duration, + loudness_cutoff=self.loudness_cutoff, + num_channels=signal.num_channels, )["signal"] + + return {"eq": eq, "bg_signal": bg_signal, "snr": snr} + + def _transform(self, signal, bg_signal, snr, eq): + # Clone bg_signal so that transform can be repeatedly applied + # to different signals with the same effect. + return signal.mix(bg_signal.clone(), snr, eq) + + +class RoomImpulseResponse(BaseTransform): + """Convolves signal with a room impulse response, at a specified + direct-to-reverberant ratio, with equalization applied. Room impulse + response data is drawn from a CSV file that was produced via + :py:func:`audiotools.data.preprocess.create_csv`. + + This transform calls :py:func:`audiotools.core.effects.EffectMixin.apply_ir` + under the hood. + + Parameters + ---------- + drr : tuple, optional + _description_, by default ("uniform", 0.0, 30.0) + sources : List[str], optional + Sources containing folders, or CSVs with paths to audio files, + by default None + weights : List[float], optional + Weights to sample audio files from each source, by default None + eq_amount : tuple, optional + Amount of equalization to apply, by default ("const", 1.0) + n_bands : int, optional + Number of bands in equalizer, by default 6 + name : str, optional + Name of this transform, used to identify it in the dictionary + produced by ``self.instantiate``, by default None + prob : float, optional + Probability of applying this transform, by default 1.0 + use_original_phase : bool, optional + Whether or not to use the original phase, by default False + offset : float, optional + Offset from each impulse response file to use, by default 0.0 + duration : float, optional + Duration of each impulse response, by default 1.0 + """ + + def __init__( + self, + drr: tuple=("uniform", 0.0, 30.0), + sources: List[str]=None, + weights: List[float]=None, + eq_amount: tuple=("const", 1.0), + n_bands: int=6, + name: str=None, + prob: float=1.0, + use_original_phase: bool=False, + offset: float=0.0, + duration: float=1.0, ): + super().__init__(name=name, prob=prob) + + self.drr = drr + self.eq_amount = eq_amount + self.n_bands = n_bands + self.use_original_phase = use_original_phase + + self.loader = AudioLoader(sources, weights) + self.offset = offset + self.duration = duration + + def _instantiate(self, state: RandomState, signal: AudioSignal=None): + eq_amount = util.sample_from_dist(self.eq_amount, state) + eq = -eq_amount * state.rand(self.n_bands) + drr = util.sample_from_dist(self.drr, state) + + ir_signal = self.loader( + state, + signal.sample_rate, + offset=self.offset, + duration=self.duration, + loudness_cutoff=None, + num_channels=signal.num_channels, )["signal"] + ir_signal.zero_pad_to(signal.sample_rate) + + return {"eq": eq, "ir_signal": ir_signal, "drr": drr} + + def _transform(self, signal, ir_signal, drr, eq): + # Clone ir_signal so that transform can be repeatedly applied + # to different signals with the same effect. + return signal.apply_ir( + ir_signal.clone(), + drr, + eq, + use_original_phase=self.use_original_phase) + + +class VolumeNorm(BaseTransform): + """Normalizes the volume of the excerpt to a specified decibel. + + Uses :py:func:`audiotools.core.effects.EffectMixin.normalize`. + + Parameters + ---------- + db : tuple, optional + dB to normalize signal to, by default ("const", -24) + name : str, optional + Name of this transform, used to identify it in the dictionary + produced by ``self.instantiate``, by default None + prob : float, optional + Probability of applying this transform, by default 1.0 + """ + + def __init__( + self, + db: tuple=("const", -24), + name: str=None, + prob: float=1.0, ): + super().__init__(name=name, prob=prob) + + self.db = db + + def _instantiate(self, state: RandomState): + return {"db": util.sample_from_dist(self.db, state)} + + def _transform(self, signal, db): + return signal.normalize(db) + + +class GlobalVolumeNorm(BaseTransform): + """Similar to :py:func:`audiotools.data.transforms.VolumeNorm`, this + transform also normalizes the volume of a signal, but it uses + the volume of the entire audio file the loaded excerpt comes from, + rather than the volume of just the excerpt. The volume of the + entire audio file is expected in ``signal.metadata["loudness"]``. + If loading audio from a CSV generated by :py:func:`audiotools.data.preprocess.create_csv` + with ``loudness = True``, like the following: + + .. csv-table:: + :header: path,loudness + + daps/produced/f1_script1_produced.wav,-16.299999237060547 + daps/produced/f1_script2_produced.wav,-16.600000381469727 + daps/produced/f1_script3_produced.wav,-17.299999237060547 + daps/produced/f1_script4_produced.wav,-16.100000381469727 + daps/produced/f1_script5_produced.wav,-16.700000762939453 + daps/produced/f3_script1_produced.wav,-16.5 + + The ``AudioLoader`` will automatically load the loudness column into + the metadata of the signal. + + Uses :py:func:`audiotools.core.effects.EffectMixin.volume_change`. + + Parameters + ---------- + db : tuple, optional + dB to normalize signal to, by default ("const", -24) + name : str, optional + Name of this transform, used to identify it in the dictionary + produced by ``self.instantiate``, by default None + prob : float, optional + Probability of applying this transform, by default 1.0 + """ + + def __init__( + self, + db: tuple=("const", -24), + name: str=None, + prob: float=1.0, ): + super().__init__(name=name, prob=prob) + + self.db = db + + def _instantiate(self, state: RandomState, signal: AudioSignal): + if "loudness" not in signal.metadata: + db_change = 0.0 + elif float(signal.metadata["loudness"]) == float("-inf"): + db_change = 0.0 + else: + db = util.sample_from_dist(self.db, state) + db_change = db - float(signal.metadata["loudness"]) + + return {"db": db_change} + + def _transform(self, signal, db): + return signal.volume_change(db) + + +class Silence(BaseTransform): + """Zeros out the signal with some probability. + + Parameters + ---------- + name : str, optional + Name of this transform, used to identify it in the dictionary + produced by ``self.instantiate``, by default None + prob : float, optional + Probability of applying this transform, by default 0.1 + """ + + def __init__(self, name: str=None, prob: float=0.1): + super().__init__(name=name, prob=prob) + + def _transform(self, signal): + _loudness = signal._loudness + signal = AudioSignal( + paddle.zeros_like(signal.audio_data), + sample_rate=signal.sample_rate, + stft_params=signal.stft_params, ) + # So that the amound of noise added is as if it wasn't silenced. + # TODO: improve this hack + signal._loudness = _loudness + + return signal + + +class LowPass(BaseTransform): + """Applies a LowPass filter. + + Uses :py:func:`audiotools.core.dsp.DSPMixin.low_pass`. + + Parameters + ---------- + cutoff : tuple, optional + Cutoff frequency distribution, + by default ``("choice", [4000, 8000, 16000])`` + zeros : int, optional + Number of zero-crossings in filter, argument to + ``julius.LowPassFilters``, by default 51 + name : str, optional + Name of this transform, used to identify it in the dictionary + produced by ``self.instantiate``, by default None + prob : float, optional + Probability of applying this transform, by default 1.0 + """ + + def __init__( + self, + cutoff: tuple=("choice", [4000, 8000, 16000]), + zeros: int=51, + name: str=None, + prob: float=1, ): + super().__init__(name=name, prob=prob) + + self.cutoff = cutoff + self.zeros = zeros + + def _instantiate(self, state: RandomState): + return {"cutoff": util.sample_from_dist(self.cutoff, state)} + + def _transform(self, signal, cutoff): + return signal.low_pass(cutoff, zeros=self.zeros) + + +class HighPass(BaseTransform): + """Applies a HighPass filter. + + Uses :py:func:`audiotools.core.dsp.DSPMixin.high_pass`. + + Parameters + ---------- + cutoff : tuple, optional + Cutoff frequency distribution, + by default ``("choice", [50, 100, 250, 500, 1000])`` + zeros : int, optional + Number of zero-crossings in filter, argument to + ``julius.LowPassFilters``, by default 51 + name : str, optional + Name of this transform, used to identify it in the dictionary + produced by ``self.instantiate``, by default None + prob : float, optional + Probability of applying this transform, by default 1.0 + """ + + def __init__( + self, + cutoff: tuple=("choice", [50, 100, 250, 500, 1000]), + zeros: int=51, + name: str=None, + prob: float=1, ): + super().__init__(name=name, prob=prob) + + self.cutoff = cutoff + self.zeros = zeros + + def _instantiate(self, state: RandomState): + return {"cutoff": util.sample_from_dist(self.cutoff, state)} + + def _transform(self, signal, cutoff): + return signal.high_pass(cutoff, zeros=self.zeros) + + +class FrequencyMask(SpectralTransform): + """Masks a band of frequencies at a center frequency + from the audio. + + Uses :py:func:`audiotools.core.dsp.DSPMixin.mask_frequencies`. + + Parameters + ---------- + f_center : tuple, optional + Center frequency between 0.0 and 1.0 (Nyquist), by default ("uniform", 0.0, 1.0) + f_width : tuple, optional + Width of zero'd out band, by default ("const", 0.1) + name : str, optional + Name of this transform, used to identify it in the dictionary + produced by ``self.instantiate``, by default None + prob : float, optional + Probability of applying this transform, by default 1.0 + """ + + def __init__( + self, + f_center: tuple=("uniform", 0.0, 1.0), + f_width: tuple=("const", 0.1), + name: str=None, + prob: float=1, ): + super().__init__(name=name, prob=prob) + self.f_center = f_center + self.f_width = f_width + + def _instantiate(self, state: RandomState, signal: AudioSignal): + f_center = util.sample_from_dist(self.f_center, state) + f_width = util.sample_from_dist(self.f_width, state) + + fmin = max(f_center - (f_width / 2), 0.0) + fmax = min(f_center + (f_width / 2), 1.0) + + fmin_hz = (signal.sample_rate / 2) * fmin + fmax_hz = (signal.sample_rate / 2) * fmax + + return {"fmin_hz": fmin_hz, "fmax_hz": fmax_hz} + + def _transform(self, signal, fmin_hz: float, fmax_hz: float): + return signal.mask_frequencies(fmin_hz=fmin_hz, fmax_hz=fmax_hz) + + +class TimeMask(SpectralTransform): + """Masks out contiguous time-steps from signal. + + Uses :py:func:`audiotools.core.dsp.DSPMixin.mask_timesteps`. + + Parameters + ---------- + t_center : tuple, optional + Center time in terms of 0.0 and 1.0 (duration of signal), + by default ("uniform", 0.0, 1.0) + t_width : tuple, optional + Width of dropped out portion, by default ("const", 0.025) + name : str, optional + Name of this transform, used to identify it in the dictionary + produced by ``self.instantiate``, by default None + prob : float, optional + Probability of applying this transform, by default 1.0 + """ + + def __init__( + self, + t_center: tuple=("uniform", 0.0, 1.0), + t_width: tuple=("const", 0.025), + name: str=None, + prob: float=1, ): + super().__init__(name=name, prob=prob) + self.t_center = t_center + self.t_width = t_width + + def _instantiate(self, state: RandomState, signal: AudioSignal): + t_center = util.sample_from_dist(self.t_center, state) + t_width = util.sample_from_dist(self.t_width, state) + + tmin = max(t_center - (t_width / 2), 0.0) + tmax = min(t_center + (t_width / 2), 1.0) + + tmin_s = signal.signal_duration * tmin + tmax_s = signal.signal_duration * tmax + return {"tmin_s": tmin_s, "tmax_s": tmax_s} + + def _transform(self, signal, tmin_s: float, tmax_s: float): + return signal.mask_timesteps(tmin_s=tmin_s, tmax_s=tmax_s) + + +class Smoothing(BaseTransform): + """Convolves the signal with a smoothing window. + + Uses :py:func:`audiotools.core.effects.EffectMixin.convolve`. + + Parameters + ---------- + window_type : tuple, optional + Type of window to use, by default ("const", "average") + window_length : tuple, optional + Length of smoothing window, by + default ("choice", [8, 16, 32, 64, 128, 256, 512]) + name : str, optional + Name of this transform, used to identify it in the dictionary + produced by ``self.instantiate``, by default None + prob : float, optional + Probability of applying this transform, by default 1.0 + """ + + def __init__( + self, + window_type: tuple=("const", "average"), + window_length: tuple=("choice", [8, 16, 32, 64, 128, 256, 512]), + name: str=None, + prob: float=1, ): + super().__init__(name=name, prob=prob) + self.window_type = window_type + self.window_length = window_length + + def _instantiate(self, state: RandomState, signal: AudioSignal=None): + window_type = util.sample_from_dist(self.window_type, state) + window_length = util.sample_from_dist(self.window_length, state) + window = signal.get_window( + window_type=window_type, window_length=window_length, device="cpu") + return {"window": AudioSignal(window, signal.sample_rate)} + + def _transform(self, signal, window): + sscale = signal.audio_data.abs().max(axis=-1, keepdim=True) + sscale[sscale == 0.0] = 1.0 + + out = signal.convolve(window) + + oscale = out.audio_data.abs().max(axis=-1, keepdim=True) + oscale[oscale == 0.0] = 1.0 + + out = out * (sscale / oscale) + return out + + +class FrequencyNoise(FrequencyMask): + """Similar to :py:func:`audiotools.data.transforms.FrequencyMask`, but + replaces with noise instead of zeros. + + Parameters + ---------- + f_center : tuple, optional + Center frequency between 0.0 and 1.0 (Nyquist), by default ("uniform", 0.0, 1.0) + f_width : tuple, optional + Width of zero'd out band, by default ("const", 0.1) + name : str, optional + Name of this transform, used to identify it in the dictionary + produced by ``self.instantiate``, by default None + prob : float, optional + Probability of applying this transform, by default 1.0 + """ + + def __init__( + self, + f_center: tuple=("uniform", 0.0, 1.0), + f_width: tuple=("const", 0.1), + name: str=None, + prob: float=1, ): + super().__init__( + f_center=f_center, f_width=f_width, name=name, prob=prob) + + def _transform(self, signal, fmin_hz: float, fmax_hz: float): + signal = signal.mask_frequencies(fmin_hz=fmin_hz, fmax_hz=fmax_hz) + mag, phase = signal.magnitude, signal.phase + + mag_r, phase_r = paddle.randn( + shape=mag.shape, dtype=mag.dtype), paddle.randn( + shape=phase.shape, dtype=phase.dtype) + mask = (mag == 0.0) * (phase == 0.0) + + # mag[mask] = mag_r[mask] + # phase[mask] = phase_r[mask] + mag = paddle.where(mask, mag_r, mag) + phase = paddle.where(mask, phase_r, phase) + + signal.magnitude = mag + signal.phase = phase + return signal diff --git a/audio/audiotools/metrics/__init__.py b/audio/audiotools/metrics/__init__.py new file mode 100644 index 00000000000..f14ee082276 --- /dev/null +++ b/audio/audiotools/metrics/__init__.py @@ -0,0 +1,17 @@ +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Functions for comparing AudioSignal objects to one another. +""" +from . import quality diff --git a/audio/audiotools/metrics/quality.py b/audio/audiotools/metrics/quality.py new file mode 100644 index 00000000000..63f72709910 --- /dev/null +++ b/audio/audiotools/metrics/quality.py @@ -0,0 +1,74 @@ +# MIT License, Copyright (c) 2023-Present, Descript. +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Modified from audiotools(https://github.com/descriptinc/audiotools/blob/master/audiotools/metrics/quality.py) +import os + +import numpy as np +import paddle + +from ..core import AudioSignal + + +def visqol( + estimates: AudioSignal, + references: AudioSignal, + mode: str="audio", ): + """ViSQOL score. + + Parameters + ---------- + estimates : AudioSignal + Degraded AudioSignal + references : AudioSignal + Reference AudioSignal + mode : str, optional + 'audio' or 'speech', by default 'audio' + + Returns + ------- + Tensor[float] + ViSQOL score (MOS-LQO) + """ + try: + from pyvisqol import visqol_lib_py + from pyvisqol.pb2 import visqol_config_pb2 + from pyvisqol.pb2 import similarity_result_pb2 + except ImportError: + from visqol import visqol_lib_py + from visqol.pb2 import visqol_config_pb2 + from visqol.pb2 import similarity_result_pb2 + + config = visqol_config_pb2.VisqolConfig() + if mode == "audio": + target_sr = 48000 + config.options.use_speech_scoring = False + svr_model_path = "libsvm_nu_svr_model.txt" + elif mode == "speech": + target_sr = 16000 + config.options.use_speech_scoring = True + svr_model_path = "lattice_tcditugenmeetpackhref_ls2_nl60_lr12_bs2048_learn.005_ep2400_train1_7_raw.tflite" + else: + raise ValueError(f"Unrecognized mode: {mode}") + config.audio.sample_rate = target_sr + config.options.svr_model_path = os.path.join( + os.path.dirname(visqol_lib_py.__file__), "model", svr_model_path) + + api = visqol_lib_py.VisqolApi() + api.Create(config) + + estimates = estimates.clone().to_mono().resample(target_sr) + references = references.clone().to_mono().resample(target_sr) + + visqols = [] + for i in range(estimates.batch_size): + _visqol = api.Measure( + references.audio_data[i, 0].detach().cpu().numpy().astype(float), + estimates.audio_data[i, 0].detach().cpu().numpy().astype(float), ) + visqols.append(_visqol.moslqo) + return paddle.to_tensor(np.array(visqols)) + + +if __name__ == "__main__": + signal = AudioSignal(paddle.randn([44100]), 44100) + print(visqol(signal, signal)) diff --git a/audio/audiotools/ml/__init__.py b/audio/audiotools/ml/__init__.py new file mode 100644 index 00000000000..4a3b29fac30 --- /dev/null +++ b/audio/audiotools/ml/__init__.py @@ -0,0 +1,16 @@ +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from . import decorators +from .accelerator import Accelerator +from .basemodel import BaseModel diff --git a/audio/audiotools/ml/accelerator.py b/audio/audiotools/ml/accelerator.py new file mode 100644 index 00000000000..74cb3331b5e --- /dev/null +++ b/audio/audiotools/ml/accelerator.py @@ -0,0 +1,199 @@ +# MIT License, Copyright (c) 2023-Present, Descript. +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Modified from audiotools(https://github.com/descriptinc/audiotools/blob/master/audiotools/ml/accelerator.py) +import os +import typing + +import paddle +import paddle.distributed as dist +from paddle.io import DataLoader +from paddle.io import DistributedBatchSampler +from paddle.io import SequenceSampler + + +class ResumableDistributedSampler(DistributedBatchSampler): + """Distributed sampler that can be resumed from a given start index.""" + + def __init__(self, dataset, start_idx: int=None, **kwargs): + super().__init__(dataset, **kwargs) + # Start index, allows to resume an experiment at the index it was + self.start_idx = start_idx // self.num_replicas if start_idx is not None else 0 + + def __iter__(self): + for i, idx in enumerate(super().__iter__()): + if i >= self.start_idx: + yield idx + self.start_idx = 0 # set the index back to 0 so for the next epoch + + +class ResumableSequentialSampler(SequenceSampler): + """Sequential sampler that can be resumed from a given start index.""" + + def __init__(self, dataset, start_idx: int=None, **kwargs): + super().__init__(dataset, **kwargs) + # Start index, allows to resume an experiment at the index it was + self.start_idx = start_idx if start_idx is not None else 0 + + def __iter__(self): + for i, idx in enumerate(super().__iter__()): + if i >= self.start_idx: + yield idx + self.start_idx = 0 # set the index back to 0 so for the next epoch + + +class Accelerator: + """This class is used to prepare models and dataloaders for + usage with DDP or DP. Use the functions prepare_model, prepare_dataloader to + prepare the respective objects. In the case of models, they are moved to + the appropriate GPU. In the case of + dataloaders, a sampler is created and the dataloader is initialized with + that sampler. + + If the world size is 1, prepare_model and prepare_dataloader are + no-ops. If the environment variable ``PADDLE_TRAINER_ID`` is not set, then the + script was launched without ``paddle.distributed.launch``, and ``DataParallel`` + will be used instead of ``DistributedDataParallel`` (not recommended), if + the world size (number of GPUs) is greater than 1. + + Parameters + ---------- + amp : bool, optional + Whether or not to enable automatic mixed precision, by default False + (Note: This is a placeholder as PaddlePaddle doesn't have native support for AMP as of now) + """ + + def __init__(self, amp: bool=False): + trainer_id = os.getenv("PADDLE_TRAINER_ID", None) + self.world_size = paddle.distributed.get_world_size() + + self.use_ddp = self.world_size > 1 and trainer_id is not None + self.use_dp = self.world_size > 1 and trainer_id is None + self.device = "cpu" if self.world_size == 0 else "cuda" + + if self.use_ddp: + trainer_id = int(trainer_id) + dist.init_parallel_env() + + self.local_rank = 0 if trainer_id is None else int(trainer_id) + self.amp = amp + + class DummyScaler: + def __init__(self): + pass + + def step(self, optimizer): + optimizer.step() + + def scale(self, loss): + return loss + + def unscale_(self, optimizer): + return optimizer + + def update(self): + pass + + self.scaler = paddle.amp.GradScaler() if self.amp else DummyScaler() + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_value, traceback): + pass + + def prepare_model(self, model: paddle.nn.Layer, **kwargs): + """Prepares model for DDP or DP. The model is moved to + the device of the correct rank. + + Parameters + ---------- + model : paddle.nn.Layer + Model that is converted for DDP or DP. + + Returns + ------- + paddle.nn.Layer + Wrapped model, or original model if DDP and DP are turned off. + """ + if self.use_ddp: + model = paddle.nn.SyncBatchNorm.convert_sync_batchnorm(model) + model = paddle.DataParallel(model, **kwargs) + elif self.use_dp: + model = paddle.DataParallel(model, **kwargs) + return model + + def autocast(self, *args, **kwargs): + return paddle.amp.auto_cast(self.amp, *args, **kwargs) + + def backward(self, loss: paddle.Tensor): + """Backwards pass. + + Parameters + ---------- + loss : paddle.Tensor + Loss value. + """ + scaled = self.scaler.scale(loss) # scale the loss + scaled.backward() + + def step(self, optimizer: paddle.optimizer.Optimizer): + """Steps the optimizer. + + Parameters + ---------- + optimizer : paddle.optimizer.Optimizer + Optimizer to step forward. + """ + self.scaler.step(optimizer) + + def update(self): + # https://www.paddlepaddle.org.cn/documentation/docs/zh/2.6/api/paddle/amp/GradScaler_cn.html#step-optimizer + self.scaler.update() + + def prepare_dataloader(self, + dataset: typing.Iterable, + start_idx: int=None, + **kwargs): + """Wraps a dataset with a DataLoader, using the correct sampler if DDP is + enabled. + + Parameters + ---------- + dataset : typing.Iterable + Dataset to build Dataloader around. + start_idx : int, optional + Start index of sampler, useful if resuming from some epoch, + by default None + + Returns + ------- + DataLoader + Wrapped DataLoader. + """ + + if self.use_ddp: + sampler = ResumableDistributedSampler( + dataset, + start_idx, + batch_size=kwargs.get("batch_size", 1), + shuffle=kwargs.get("shuffle", True), + drop_last=kwargs.get("drop_last", False), + num_replicas=self.world_size, + rank=self.local_rank, ) + if "num_workers" in kwargs: + kwargs["num_workers"] = max(kwargs["num_workers"] // + self.world_size, 1) + else: + sampler = ResumableSequentialSampler(dataset, start_idx) + + dataloader = DataLoader( + dataset, + batch_sampler=sampler if self.use_ddp else None, + sampler=sampler if not self.use_ddp else None, + **kwargs, ) + return dataloader + + @staticmethod + def unwrap(model): + return model diff --git a/audio/audiotools/ml/basemodel.py b/audio/audiotools/ml/basemodel.py new file mode 100644 index 00000000000..97c31ff7a7c --- /dev/null +++ b/audio/audiotools/ml/basemodel.py @@ -0,0 +1,272 @@ +# MIT License, Copyright (c) 2023-Present, Descript. +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Modified from audiotools(https://github.com/descriptinc/audiotools/blob/master/audiotools/ml/layers/base.py) +import inspect +import shutil +import tempfile +import typing +from pathlib import Path + +import paddle +from paddle import nn + + +class BaseModel(nn.Layer): + """This is a class that adds useful save/load functionality to a + ``paddle.nn.Layer`` object. ``BaseModel`` objects can be saved + as ``package`` easily, making them super easy to port between + machines without requiring a ton of dependencies. Files can also be + saved as just weights, in the standard way. + + >>> class Model(ml.BaseModel): + >>> def __init__(self, arg1: float = 1.0): + >>> super().__init__() + >>> self.arg1 = arg1 + >>> self.linear = nn.Linear(1, 1) + >>> + >>> def forward(self, x): + >>> return self.linear(x) + >>> + >>> model1 = Model() + >>> + >>> with tempfile.NamedTemporaryFile(suffix=".pth") as f: + >>> model1.save( + >>> f.name, + >>> ) + >>> model2 = Model.load(f.name) + >>> out2 = seed_and_run(model2, x) + >>> assert paddle.allclose(out1, out2) + >>> + >>> model1.save(f.name, package=True) + >>> model2 = Model.load(f.name) + >>> model2.save(f.name, package=False) + >>> model3 = Model.load(f.name) + >>> out3 = seed_and_run(model3, x) + >>> + >>> with tempfile.TemporaryDirectory() as d: + >>> model1.save_to_folder(d, {"data": 1.0}) + >>> Model.load_from_folder(d) + + """ + + def save( + self, + path: str, + metadata: dict=None, + package: bool=False, + intern: list=[], + extern: list=[], + mock: list=[], ): + """Saves the model, either as a package, or just as + weights, alongside some specified metadata. + + Parameters + ---------- + path : str + Path to save model to. + metadata : dict, optional + Any metadata to save alongside the model, + by default None + package : bool, optional + Whether to use ``package`` to save the model in + a format that is portable, by default True + intern : list, optional + List of additional libraries that are internal + to the model, used with package, by default [] + extern : list, optional + List of additional libraries that are external to + the model, used with package, by default [] + mock : list, optional + List of libraries to mock, used with package, + by default [] + + Returns + ------- + str + Path to saved model. + """ + sig = inspect.signature(self.__class__) + args = {} + + for key, val in sig.parameters.items(): + arg_val = val.default + if arg_val is not inspect.Parameter.empty: + args[key] = arg_val + + # Look up attibutes in self, and if any of them are in args, + # overwrite them in args. + for attribute in dir(self): + if attribute in args: + args[attribute] = getattr(self, attribute) + + metadata = {} if metadata is None else metadata + metadata["kwargs"] = args + if not hasattr(self, "metadata"): + self.metadata = {} + self.metadata.update(metadata) + + if not package: + state_dict = {"state_dict": self.state_dict(), "metadata": metadata} + paddle.save(state_dict, str(path)) + else: + self._save_package(path, intern=intern, extern=extern, mock=mock) + + return path + + @property + def device(self): + """Gets the device the model is on by looking at the device of + the first parameter. May not be valid if model is split across + multiple devices. + """ + return list(self.parameters())[0].place + + @classmethod + def load( + cls, + location: str, + *args, + package_name: str=None, + strict: bool=False, + **kwargs, ): + """Load model from a path. Tries first to load as a package, and if + that fails, tries to load as weights. The arguments to the class are + specified inside the model weights file. + + Parameters + ---------- + location : str + Path to file. + package_name : str, optional + Name of package, by default ``cls.__name__``. + strict : bool, optional + Ignore unmatched keys, by default False + kwargs : dict + Additional keyword arguments to the model instantiation, if + not loading from package. + + Returns + ------- + BaseModel + A model that inherits from BaseModel. + """ + try: + model = cls._load_package(location, package_name=package_name) + except: + model_dict = paddle.load(location) + metadata = model_dict["metadata"] + metadata["kwargs"].update(kwargs) + + sig = inspect.signature(cls) + class_keys = list(sig.parameters.keys()) + for k in list(metadata["kwargs"].keys()): + if k not in class_keys: + metadata["kwargs"].pop(k) + + model = cls(*args, **metadata["kwargs"]) + model.set_state_dict(model_dict["state_dict"]) + model.metadata = metadata + + return model + + def _save_package(self, path, intern=[], extern=[], mock=[], **kwargs): + raise NotImplementedError("Currently Paddle does not support packaging") + + @classmethod + def _load_package(cls, path, package_name=None): + raise NotImplementedError("Currently Paddle does not support packaging") + + def save_to_folder( + self, + folder: typing.Union[str, Path], + extra_data: dict=None, + package: bool=False, ): + """Dumps a model into a folder, as both a package + and as weights, as well as anything specified in + ``extra_data``. ``extra_data`` is a dictionary of other + pickleable files, with the keys being the paths + to save them in. The model is saved under a subfolder + specified by the name of the class (e.g. ``folder/generator/[package, weights].pth`` + if the model name was ``Generator``). + + >>> with tempfile.TemporaryDirectory() as d: + >>> extra_data = { + >>> "optimizer.pth": optimizer.state_dict() + >>> } + >>> model.save_to_folder(d, extra_data) + >>> Model.load_from_folder(d) + + Parameters + ---------- + folder : typing.Union[str, Path] + _description_ + extra_data : dict, optional + _description_, by default None + + Returns + ------- + str + Path to folder + """ + extra_data = {} if extra_data is None else extra_data + model_name = type(self).__name__.lower() + target_base = Path(f"{folder}/{model_name}/") + target_base.mkdir(exist_ok=True, parents=True) + + if package: + package_path = target_base / f"package.pth" + self.save(package_path) + + weights_path = target_base / f"weights.pth" + self.save(weights_path, package=False) + + for path, obj in extra_data.items(): + paddle.save(obj, str(target_base / path)) + + return target_base + + @classmethod + def load_from_folder( + cls, + folder: typing.Union[str, Path], + package: bool=False, + strict: bool=False, + **kwargs, ): + """Loads the model from a folder generated by + :py:func:`audiotools.ml.layers.base.BaseModel.save_to_folder`. + Like that function, this one looks for a subfolder that has + the name of the class (e.g. ``folder/generator/[package, weights].pth`` if the + model name was ``Generator``). + + Parameters + ---------- + folder : typing.Union[str, Path] + _description_ + package : bool, optional + Whether to use ``package`` to load the model, + loading the model from ``package.pth``. + strict : bool, optional + Ignore unmatched keys, by default False + + Returns + ------- + tuple + tuple of model and extra data as saved by + :py:func:`audiotools.ml.layers.base.BaseModel.save_to_folder`. + """ + folder = Path(folder) / cls.__name__.lower() + model_pth = "package.pth" if package else "weights.pth" + model_pth = folder / model_pth + + model = cls.load(str(model_pth)) + extra_data = {} + excluded = ["package.pth", "weights.pth"] + files = [ + x for x in folder.glob("*") + if x.is_file() and x.name not in excluded + ] + for f in files: + extra_data[f.name] = paddle.load(str(f), **kwargs) + + return model, extra_data diff --git a/audio/audiotools/ml/decorators.py b/audio/audiotools/ml/decorators.py new file mode 100644 index 00000000000..787be87c577 --- /dev/null +++ b/audio/audiotools/ml/decorators.py @@ -0,0 +1,446 @@ +# MIT License, Copyright (c) 2023-Present, Descript. +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Modified from audiotools(https://github.com/descriptinc/audiotools/blob/master/audiotools/ml/decorators.py) +import math +import os +import time +from collections import defaultdict +from functools import wraps + +import paddle +import paddle.distributed as dist +from rich import box +from rich.console import Console +from rich.console import Group +from rich.live import Live +from rich.markdown import Markdown +from rich.padding import Padding +from rich.panel import Panel +from rich.progress import BarColumn +from rich.progress import Progress +from rich.progress import SpinnerColumn +from rich.progress import TimeElapsedColumn +from rich.progress import TimeRemainingColumn +from rich.rule import Rule +from rich.table import Table +from visualdl import LogWriter + + +# This is here so that the history can be pickled. +def default_list(): + return [] + + +class Mean: + """Keeps track of the running mean, along with the latest + value. + """ + + def __init__(self): + self.reset() + + def __call__(self): + mean = self.total / max(self.count, 1) + return mean + + def reset(self): + self.count = 0 + self.total = 0 + + def update(self, val): + if math.isfinite(val): + self.count += 1 + self.total += val + + +def when(condition): + """Runs a function only when the condition is met. The condition is + a function that is run. + + Parameters + ---------- + condition : Callable + Function to run to check whether or not to run the decorated + function. + + Example + ------- + Checkpoint only runs every 100 iterations, and only if the + local rank is 0. + + >>> i = 0 + >>> rank = 0 + >>> + >>> @when(lambda: i % 100 == 0 and rank == 0) + >>> def checkpoint(): + >>> print("Saving to /runs/exp1") + >>> + >>> for i in range(1000): + >>> checkpoint() + + """ + + def decorator(fn): + @wraps(fn) + def decorated(*args, **kwargs): + if condition(): + return fn(*args, **kwargs) + + return decorated + + return decorator + + +def timer(prefix: str="time"): + """Adds execution time to the output dictionary of the decorated + function. The function decorated by this must output a dictionary. + The key added will follow the form "[prefix]/[name_of_function]" + + Parameters + ---------- + prefix : str, optional + The key added will follow the form "[prefix]/[name_of_function]", + by default "time". + """ + + def decorator(fn): + @wraps(fn) + def decorated(*args, **kwargs): + s = time.perf_counter() + output = fn(*args, **kwargs) + assert isinstance(output, dict) + e = time.perf_counter() + output[f"{prefix}/{fn.__name__}"] = e - s + return output + + return decorated + + return decorator + + +class Tracker: + """ + A tracker class that helps to monitor the progress of training and logging the metrics. + + Attributes + ---------- + metrics : dict + A dictionary containing the metrics for each label. + history : dict + A dictionary containing the history of metrics for each label. + writer : LogWriter + A LogWriter object for logging the metrics. + rank : int + The rank of the current process. + step : int + The current step of the training. + tasks : dict + A dictionary containing the progress bars and tables for each label. + pbar : Progress + A progress bar object for displaying the progress. + consoles : list + A list of console objects for logging. + live : Live + A Live object for updating the display live. + + Methods + ------- + print(msg: str) + Prints the given message to all consoles. + update(label: str, fn_name: str) + Updates the progress bar and table for the given label. + done(label: str, title: str) + Resets the progress bar and table for the given label and prints the final result. + track(label: str, length: int, completed: int = 0, op: dist.ReduceOp = dist.ReduceOp.AVG, ddp_active: bool = "LOCAL_RANK" in os.environ) + A decorator for tracking the progress and metrics of a function. + log(label: str, value_type: str = "value", history: bool = True) + A decorator for logging the metrics of a function. + is_best(label: str, key: str) -> bool + Checks if the latest value of the given key in the label is the best so far. + state_dict() -> dict + Returns a dictionary containing the state of the tracker. + load_state_dict(state_dict: dict) -> Tracker + Loads the state of the tracker from the given state dictionary. + """ + + def __init__( + self, + writer: LogWriter=None, + log_file: str=None, + rank: int=0, + console_width: int=100, + step: int=0, ): + """ + Initializes the Tracker object. + + Parameters + ---------- + writer : LogWriter, optional + A LogWriter object for logging the metrics, by default None. + log_file : str, optional + The path to the log file, by default None. + rank : int, optional + The rank of the current process, by default 0. + console_width : int, optional + The width of the console, by default 100. + step : int, optional + The current step of the training, by default 0. + """ + self.metrics = {} + self.history = {} + self.writer = writer + self.rank = rank + self.step = step + + # Create progress bars etc. + self.tasks = {} + self.pbar = Progress( + SpinnerColumn(), + "[progress.description]{task.description}", + "{task.completed}/{task.total}", + BarColumn(), + TimeElapsedColumn(), + "/", + TimeRemainingColumn(), ) + self.consoles = [Console(width=console_width)] + self.live = Live(console=self.consoles[0], refresh_per_second=10) + if log_file is not None: + self.consoles.append( + Console(width=console_width, file=open(log_file, "a"))) + + def print(self, msg): + """ + Prints the given message to all consoles. + + Parameters + ---------- + msg : str + The message to be printed. + """ + if self.rank == 0: + for c in self.consoles: + c.log(msg) + + def update(self, label, fn_name): + """ + Updates the progress bar and table for the given label. + + Parameters + ---------- + label : str + The label of the progress bar and table to be updated. + fn_name : str + The name of the function associated with the label. + """ + if self.rank == 0: + self.pbar.advance(self.tasks[label]["pbar"]) + + # Create table + table = Table(title=label, expand=True, box=box.MINIMAL) + table.add_column("key", style="cyan") + table.add_column("value", style="bright_blue") + table.add_column("mean", style="bright_green") + + keys = self.metrics[label]["value"].keys() + for k in keys: + value = self.metrics[label]["value"][k] + mean = self.metrics[label]["mean"][k]() + table.add_row(k, f"{value:10.6f}", f"{mean:10.6f}") + + self.tasks[label]["table"] = table + tables = [t["table"] for t in self.tasks.values()] + group = Group(*tables, self.pbar) + self.live.update( + Group( + Padding("", (0, 0)), + Rule(f"[italic]{fn_name}()", style="white"), + Padding("", (0, 0)), + Panel.fit( + group, + padding=(0, 5), + title="[b]Progress", + border_style="blue", ), )) + + def done(self, label: str, title: str): + """ + Resets the progress bar and table for the given label and prints the final result. + + Parameters + ---------- + label : str + The label of the progress bar and table to be reset. + title : str + The title to be displayed when printing the final result. + """ + for label in self.metrics: + for v in self.metrics[label]["mean"].values(): + v.reset() + + if self.rank == 0: + self.pbar.reset(self.tasks[label]["pbar"]) + tables = [t["table"] for t in self.tasks.values()] + group = Group(Markdown(f"# {title}"), *tables, self.pbar) + self.print(group) + + def track( + self, + label: str, + length: int, + completed: int=0, + op: dist.ReduceOp=dist.ReduceOp.AVG, + ddp_active: bool="LOCAL_RANK" in os.environ, ): + """ + A decorator for tracking the progress and metrics of a function. + + Parameters + ---------- + label : str + The label to be associated with the progress and metrics. + length : int + The total number of iterations to be completed. + completed : int, optional + The number of iterations already completed, by default 0. + op : dist.ReduceOp, optional + The reduce operation to be used, by default dist.ReduceOp.AVG. + ddp_active : bool, optional + Whether the DistributedDataParallel is active, by default "LOCAL_RANK" in os.environ. + """ + self.tasks[label] = { + "pbar": + self.pbar.add_task( + f"[white]Iteration ({label})", + total=length, + completed=completed), + "table": + Table(), + } + self.metrics[label] = { + "value": defaultdict(), + "mean": defaultdict(lambda: Mean()), + } + + def decorator(fn): + @wraps(fn) + def decorated(*args, **kwargs): + output = fn(*args, **kwargs) + if not isinstance(output, dict): + self.update(label, fn.__name__) + return output + # Collect across all DDP processes + scalar_keys = [] + for k, v in output.items(): + if isinstance(v, (int, float)): + v = paddle.to_tensor([v]) + if not paddle.is_tensor(v): + continue + if ddp_active and v.is_cuda: + dist.all_reduce(v, op=op) + output[k] = v.detach() + if paddle.numel(v) == 1: + scalar_keys.append(k) + output[k] = v.item() + + # Save the outputs to tracker + for k, v in output.items(): + if k not in scalar_keys: + continue + self.metrics[label]["value"][k] = v + # Update the running mean + self.metrics[label]["mean"][k].update(v) + + self.update(label, fn.__name__) + return output + + return decorated + + return decorator + + def log(self, label: str, value_type: str="value", history: bool=True): + """ + A decorator for logging the metrics of a function. + + Parameters + ---------- + label : str + The label to be associated with the logging. + value_type : str, optional + The type of value to be logged, by default "value". + history : bool, optional + Whether to save the history of the metrics, by default True. + """ + assert value_type in ["mean", "value"] + if history: + if label not in self.history: + self.history[label] = defaultdict(default_list) + + def decorator(fn): + @wraps(fn) + def decorated(*args, **kwargs): + output = fn(*args, **kwargs) + if self.rank == 0: + nonlocal value_type, label + metrics = self.metrics[label][value_type] + for k, v in metrics.items(): + v = v() if isinstance(v, Mean) else v + if self.writer is not None: + self.writer.add_scalar( + tag=f"{k}/{label}", value=v, step=self.step) + if label in self.history: + self.history[label][k].append(v) + + if label in self.history: + self.history[label]["step"].append(self.step) + + return output + + return decorated + + return decorator + + def is_best(self, label, key): + """ + Checks if the latest value of the given key in the label is the best so far. + + Parameters + ---------- + label : str + The label of the metrics to be checked. + key : str + The key of the metric to be checked. + + Returns + ------- + bool + True if the latest value is the best so far, otherwise False. + """ + return self.history[label][key][-1] == min(self.history[label][key]) + + def state_dict(self): + """ + Returns a dictionary containing the state of the tracker. + + Returns + ------- + dict + A dictionary containing the history and step of the tracker. + """ + return {"history": self.history, "step": self.step} + + def load_state_dict(self, state_dict): + """ + Loads the state of the tracker from the given state dictionary. + + Parameters + ---------- + state_dict : dict + A dictionary containing the history and step of the tracker. + + Returns + ------- + Tracker + The tracker object with the loaded state. + """ + self.history = state_dict["history"] + self.step = state_dict["step"] + return self diff --git a/audio/audiotools/post.py b/audio/audiotools/post.py new file mode 100644 index 00000000000..f5ec208ed85 --- /dev/null +++ b/audio/audiotools/post.py @@ -0,0 +1,88 @@ +# MIT License, Copyright (c) 2023-Present, Descript. +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Modified from audiotools(https://github.com/descriptinc/audiotools/blob/master/audiotools/post.py) +import typing + +import paddle + +from audio.audiotools.core import AudioSignal + + +def audio_table( + audio_dict: dict, + first_column: str=None, + format_fn: typing.Callable=None, + **kwargs, ): + """Embeds an audio table into HTML, or as the output cell + in a notebook. + + Parameters + ---------- + audio_dict : dict + Dictionary of data to embed. + first_column : str, optional + The label for the first column of the table, by default None + format_fn : typing.Callable, optional + How to format the data, by default None + + Returns + ------- + str + Table as a string + + Examples + -------- + + >>> audio_dict = {} + >>> for i in range(signal_batch.batch_size): + >>> audio_dict[i] = { + >>> "input": signal_batch[i], + >>> "output": output_batch[i] + >>> } + >>> audiotools.post.audio_zip(audio_dict) + + """ + + output = [] + columns = None + + def _default_format_fn(label, x, **kwargs): + if paddle.is_tensor(x): + x = x.tolist() + + if x is None: + return "." + elif isinstance(x, AudioSignal): + return x.embed(display=False, return_html=True, **kwargs) + else: + return str(x) + + if format_fn is None: + format_fn = _default_format_fn + + if first_column is None: + first_column = "." + + for k, v in audio_dict.items(): + if not isinstance(v, dict): + v = {"Audio": v} + + v_keys = list(v.keys()) + if columns is None: + columns = [first_column] + v_keys + output.append(" | ".join(columns)) + + layout = "|---" + len(v_keys) * "|:-:" + output.append(layout) + + formatted_audio = [] + for col in columns[1:]: + formatted_audio.append(format_fn(col, v[col], **kwargs)) + + row = f"| {k} | " + row += " | ".join(formatted_audio) + output.append(row) + + output = "\n" + "\n".join(output) + return output diff --git a/audio/audiotools/requirements.txt b/audio/audiotools/requirements.txt new file mode 100644 index 00000000000..57e22855999 --- /dev/null +++ b/audio/audiotools/requirements.txt @@ -0,0 +1,6 @@ +ffmpeg-python +ffmpy +flatten_dict +pyloudnorm +pytest +rich diff --git a/audio/tests/audiotools/core/test_audio_signal.py b/audio/tests/audiotools/core/test_audio_signal.py new file mode 100644 index 00000000000..ede3d9ec76c --- /dev/null +++ b/audio/tests/audiotools/core/test_audio_signal.py @@ -0,0 +1,615 @@ +# MIT License, Copyright (c) 2023-Present, Descript. +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Modified from audiotools(https://github.com/descriptinc/audiotools/blob/master/tests/core/test_audio_signal.py) +import pathlib +import sys +import tempfile + +import librosa +import numpy as np +import paddle +import pytest +import rich + +from audio import audiotools +from audio.audiotools import AudioSignal +from audio.audiotools import util + + +def test_io(): + audio_path = "./audio/spk/f10_script4_produced.wav" + signal = AudioSignal(pathlib.Path(audio_path)) + + with tempfile.NamedTemporaryFile(suffix=".wav") as f: + signal.write(f.name) + signal_from_file = AudioSignal(f.name) + + mp3_signal = AudioSignal(audio_path.replace("wav", "mp3")) + print(mp3_signal) + + assert signal == signal_from_file + print(signal) + print(signal.markdown()) + + mp3_signal = AudioSignal.excerpt( + audio_path.replace("wav", "mp3"), offset=5, duration=5) + assert mp3_signal.signal_duration == 5.0 + assert mp3_signal.duration == 5.0 + assert mp3_signal.length == mp3_signal.signal_length + + rich.print(signal) + + array = np.random.randn(2, 16000) + signal = AudioSignal(array, sample_rate=16000) + assert np.allclose(signal.numpy(), array) + + signal = AudioSignal(array, 44100) + assert signal.sample_rate == 44100 + signal.shape + + with pytest.raises(ValueError): + signal = AudioSignal(5, sample_rate=16000) + + signal = AudioSignal(audio_path, offset=10, duration=10) + assert np.allclose(signal.signal_duration, 10.0) + assert np.allclose(signal.duration, 10.0) + + signal = AudioSignal.excerpt(audio_path, offset=5, duration=5) + assert signal.signal_duration == 5.0 + assert signal.duration == 5.0 + + assert "offset" in signal.metadata + assert "duration" in signal.metadata + + signal = AudioSignal(paddle.randn([1000]), 44100) + assert signal.audio_data.ndim == 3 + assert paddle.all(signal.samples == signal.audio_data) + + audio_path = "./audio/spk/f10_script4_produced.wav" + assert AudioSignal(audio_path).hash() == AudioSignal(audio_path).hash() + assert AudioSignal(audio_path).hash() != AudioSignal(audio_path).normalize( + -20).hash() + + with pytest.raises(RuntimeError): + AudioSignal(audio_path, offset=100000, duration=3) + + +def test_copy_and_clone(): + audio_path = "./audio/spk/f10_script4_produced.wav" + signal = AudioSignal(audio_path) + signal.stft() + signal.loudness() + + copied = signal.copy() + deep_copied = signal.deepcopy() + cloned = signal.clone() + + for a in ["audio_data", "stft_data", "_loudness"]: + a1 = getattr(signal, a) + a2 = getattr(cloned, a) + a3 = getattr(copied, a) + a4 = getattr(deep_copied, a) + + assert id(a1) != id(a2) + assert id(a1) == id(a3) + assert id(a1) != id(a4) + + assert np.allclose(a1, a2) + assert np.allclose(a1, a3) + assert np.allclose(a1, a4) + + for a in ["path_to_file", "metadata"]: + a1 = getattr(signal, a) + a2 = getattr(cloned, a) + a3 = getattr(copied, a) + a4 = getattr(deep_copied, a) + + assert id(a1) == id(a2) if isinstance(a1, str) else id(a1) != id(a2) + assert id(a1) == id(a3) + assert id(a1) == id(a4) if isinstance(a1, str) else id(a1) != id(a2) + + # for clone, id should differ if path is list, and should differ always for metadata + # if path is string, id should remain same... + + assert signal.original_signal_length == copied.original_signal_length + assert signal.original_signal_length == deep_copied.original_signal_length + assert signal.original_signal_length == cloned.original_signal_length + + signal = signal.detach() + + +@pytest.mark.parametrize("loudness_cutoff", [-np.inf, -160, -80, -40, -20]) +def test_salient_excerpt(loudness_cutoff): + MAP = {-np.inf: 0.0, -160: 0.0, -80: 0.001, -40: 0.01, -20: 0.1} + with tempfile.NamedTemporaryFile(suffix=".wav") as f: + sr = 44100 + signal = AudioSignal(paddle.zeros([sr * 60]), sr) + + signal[..., sr * 20:sr * 21] = MAP[loudness_cutoff] * paddle.randn( + [44100]) + + signal.write(f.name) + signal = AudioSignal.salient_excerpt( + f.name, loudness_cutoff=loudness_cutoff, duration=1, num_tries=None) + + assert "offset" in signal.metadata + assert "duration" in signal.metadata + assert signal.loudness() >= loudness_cutoff + + signal = AudioSignal.salient_excerpt( + f.name, loudness_cutoff=np.inf, duration=1, num_tries=10) + signal = AudioSignal.salient_excerpt( + f.name, + loudness_cutoff=None, + duration=1, ) + + +def test_arithmetic(): + def _make_signals(): + array = np.random.randn(2, 16000) + sig1 = AudioSignal(array, sample_rate=16000) + + array = np.random.randn(2, 16000) + sig2 = AudioSignal(array, sample_rate=16000) + return sig1, sig2 + + # Addition (with a copy) + sig1, sig2 = _make_signals() + sig3 = sig1 + sig2 + assert paddle.allclose(sig3.audio_data, sig1.audio_data + sig2.audio_data) + + # Addition (rmul) + sig1, _ = _make_signals() + sig3 = 5.0 + sig1 + assert paddle.allclose(sig3.audio_data, sig1.audio_data + 5.0) + + # In place addition + sig3, sig2 = _make_signals() + sig1 = sig3.deepcopy() + sig3 += sig2 + assert paddle.allclose(sig3.audio_data, sig1.audio_data + sig2.audio_data) + + # Subtraction (with a copy) + sig1, sig2 = _make_signals() + sig3 = sig1 - sig2 + assert paddle.allclose(sig3.audio_data, sig1.audio_data - sig2.audio_data) + + # In place subtraction + sig3, sig2 = _make_signals() + sig1 = sig3.deepcopy() + sig3 -= sig2 + assert paddle.allclose(sig3.audio_data, sig1.audio_data - sig2.audio_data) + + # Multiplication (element-wise) + sig1, sig2 = _make_signals() + sig3 = sig1 * sig2 + assert paddle.allclose(sig3.audio_data, sig1.audio_data * sig2.audio_data) + + # Multiplication (gain) + sig1, _ = _make_signals() + sig3 = sig1 * 5.0 + assert paddle.allclose(sig3.audio_data, sig1.audio_data * 5.0) + + # Multiplication (rmul) + sig1, _ = _make_signals() + sig3 = 5.0 * sig1 + assert paddle.allclose(sig3.audio_data, sig1.audio_data * 5.0) + + # Multiplication (in-place) + sig3, sig2 = _make_signals() + sig1 = sig3.deepcopy() + sig3 *= sig2 + assert paddle.allclose(sig3.audio_data, sig1.audio_data * sig2.audio_data) + + +def test_equality(): + array = np.random.randn(2, 16000) + sig1 = AudioSignal(array, sample_rate=16000) + sig2 = AudioSignal(array, sample_rate=16000) + + assert sig1 == sig2 + + array = np.random.randn(2, 16000) + sig3 = AudioSignal(array, sample_rate=16000) + + assert sig1 != sig3 + + assert not np.allclose(sig1.numpy(), sig3.numpy()) + + +def test_indexing(): + array = np.random.randn(4, 2, 16000) + sig1 = AudioSignal(array, sample_rate=16000) + + assert np.allclose(sig1[0].audio_data, array[0]) + assert np.allclose(sig1[0, :, 8000].audio_data, array[0, :, 8000]) + + # Test with the associated STFT data. + array = np.random.randn(4, 2, 16000) + sig1 = AudioSignal(array, sample_rate=16000) + sig1.loudness() + sig1.stft() + + indexed = sig1[0] + + assert np.allclose(indexed.audio_data, array[0]) + assert np.allclose(indexed.stft_data, sig1.stft_data[0]) + assert np.allclose(indexed._loudness, sig1._loudness[0]) + + indexed = sig1[0:2] + + assert np.allclose(indexed.audio_data, array[0:2]) + assert np.allclose(indexed.stft_data, sig1.stft_data[0:2]) + assert np.allclose(indexed._loudness, sig1._loudness[0:2]) + + # Test using a boolean tensor to index batch + mask = paddle.to_tensor([True, False, True, False]) + indexed = sig1[mask] + + assert np.allclose(indexed.audio_data, sig1.audio_data[mask]) + # assert np.allclose(indexed.stft_data, sig1.stft_data[mask]) + assert np.allclose(indexed.stft_data, + util.bool_index_compat(sig1.stft_data, mask)) + assert np.allclose(indexed._loudness, sig1._loudness[mask]) + + # Set parts of signal using tensor + other_array = paddle.to_tensor(np.random.randn(4, 2, 16000)) + sig1 = AudioSignal(array, sample_rate=16000) + sig1[0, :, 6000:8000] = other_array[0, :, 6000:8000] + + assert np.allclose(sig1[0, :, 6000:8000].audio_data, + other_array[0, :, 6000:8000]) + + # Set parts of signal using AudioSignal + sig2 = AudioSignal(other_array, sample_rate=16000) + + sig1 = AudioSignal(array, sample_rate=16000) + sig1[0, :, 6000:8000] = sig2[0, :, 6000:8000] + + assert np.allclose(sig1[0, :, 6000:8000].audio_data, + sig2[0, :, 6000:8000].audio_data) + + # Check that loudnesses and stft_data get set as well, if only the batch + # dim is indexed. + sig2 = AudioSignal(other_array, sample_rate=16000) + sig2.stft() + sig2.loudness() + + sig1 = AudioSignal(array, sample_rate=16000) + sig1.stft() + sig1.loudness() + + # Test using a boolean tensor to index batch + mask = paddle.to_tensor([True, False, True, False]) + sig1[mask] = sig2[mask] + + for k in ["stft_data", "audio_data", "_loudness"]: + a1 = getattr(sig1, k) + a2 = getattr(sig2, k) + + # assert np.allclose(a1[mask], a2[mask]) + assert np.allclose( + util.bool_index_compat(a1, mask), util.bool_index_compat(a2, mask)) + + +def test_zeros(): + x = AudioSignal.zeros(0.5, 44100) + assert x.signal_duration == 0.5 + assert x.duration == 0.5 + assert x.sample_rate == 44100 + + +@pytest.mark.parametrize("shape", + ["sine", "square", "sawtooth", "triangle", "beep"]) +def test_waves(shape: str): + # error case + if shape == "beep": + with pytest.raises(ValueError): + AudioSignal.wave(440, 0.5, 44100, shape=shape) + + return + + x = AudioSignal.wave(440, 0.5, 44100, shape=shape) + assert x.duration == 0.5 + assert x.sample_rate == 44100 + + # test the default shape arg + x = AudioSignal.wave(440, 0.5, 44100) + assert x.duration == 0.5 + assert x.sample_rate == 44100 + + +def test_zero_pad(): + array = np.random.randn(4, 2, 16000) + sig1 = AudioSignal(array, sample_rate=16000) + + sig1.zero_pad(100, 100) + zeros = paddle.zeros([4, 2, 100], dtype="float64") + assert paddle.allclose(sig1.audio_data[..., :100], zeros) + assert paddle.allclose(sig1.audio_data[..., -100:], zeros) + + +def test_zero_pad_to(): + array = np.random.randn(4, 2, 16000) + sig1 = AudioSignal(array, sample_rate=16000) + + sig1.zero_pad_to(16100) + zeros = paddle.zeros([4, 2, 100], dtype="float64") + assert paddle.allclose(sig1.audio_data[..., -100:], zeros) + assert sig1.signal_length == 16100 + + sig1 = AudioSignal(array, sample_rate=16000) + sig1.zero_pad_to(15000) + assert sig1.signal_length == 16000 + + sig1 = AudioSignal(array, sample_rate=16000) + sig1.zero_pad_to(16100, mode="before") + zeros = paddle.zeros([4, 2, 100], dtype="float64") + assert paddle.allclose(sig1.audio_data[..., :100], zeros) + assert sig1.signal_length == 16100 + + sig1 = AudioSignal(array, sample_rate=16000) + sig1.zero_pad_to(15000, mode="before") + assert sig1.signal_length == 16000 + + +def test_truncate(): + array = np.random.randn(4, 2, 16000) + sig1 = AudioSignal(array, sample_rate=16000) + + sig1.truncate_samples(100) + assert sig1.signal_length == 100 + assert np.allclose(sig1.audio_data, array[..., :100]) + + +def test_trim(): + array = np.random.randn(4, 2, 16000) + sig1 = AudioSignal(array, sample_rate=16000) + + sig1.trim(100, 100) + assert sig1.signal_length == 16000 - 200 + assert np.allclose(sig1.audio_data, array[..., 100:-100]) + + array = np.random.randn(4, 2, 16000) + sig1 = AudioSignal(array, sample_rate=16000) + sig1.trim(0, 0) + assert np.allclose(sig1.audio_data, array) + + +def test_to_from_ops(): + audio_path = "./audio/spk/f10_script4_produced.wav" + signal = AudioSignal(audio_path) + signal.stft() + signal.loudness() + signal = signal.to("cpu") + + assert str(signal.audio_data.place) == "Place(cpu)" + assert isinstance(signal.numpy(), np.ndarray) + + signal.cpu() + # signal.cuda() + signal.float() + + +def test_device(): + audio_path = "./audio/spk/f10_script4_produced.wav" + signal = AudioSignal(audio_path) + signal.to("cpu") + + assert str(signal.device) == "Place(cpu)" + + +@pytest.mark.parametrize("window_length", [2048, 512]) +@pytest.mark.parametrize("hop_length", [512, 128]) +@pytest.mark.parametrize("window_type", ["sqrt_hann", "hann", None]) +def test_stft(window_length, hop_length, window_type): + if hop_length >= window_length: + hop_length = window_length // 2 + audio_path = "./audio/spk/f10_script4_produced.wav" + stft_params = audiotools.STFTParams( + window_length=window_length, + hop_length=hop_length, + window_type=window_type) + for _stft_params in [None, stft_params]: + signal = AudioSignal(audio_path, duration=10, stft_params=_stft_params) + with pytest.raises(RuntimeError): + signal.istft() + + stft_data = signal.stft() + + # assert paddle.allclose(signal.stft_data, stft_data) + assert np.allclose(signal.stft_data.cpu().numpy(), + stft_data.cpu().numpy()) + copied_signal = signal.deepcopy() + copied_signal.stft() + copied_signal = copied_signal.istft() + + assert copied_signal == signal + + mag = signal.magnitude + phase = signal.phase + + recon_stft = mag * util.exp_compat(1j * phase) + # assert paddle.allclose(recon_stft, signal.stft_data) + assert np.allclose(recon_stft.cpu().numpy(), + signal.stft_data.cpu().numpy()) + + signal.stft_data = None + mag = signal.magnitude + signal.stft_data = None + phase = signal.phase + + recon_stft = mag * util.exp_compat(1j * phase) + # assert paddle.allclose(recon_stft, signal.stft_data) + assert np.allclose(recon_stft.cpu().numpy(), + signal.stft_data.cpu().numpy()) + + # Test with match_stride=True, ignoring the beginning and end. + s = signal.stft_params + if s.hop_length == s.window_length // 4: + og_signal = signal.clone() + stft_data = signal.stft(match_stride=True) + recon_data = signal.istft(match_stride=True) + discard = window_length * 2 + + right_pad, _ = signal.compute_stft_padding( + s.window_length, s.hop_length, match_stride=True) + length = signal.signal_length + right_pad + assert stft_data.shape[-1] == length // s.hop_length + + assert paddle.allclose( + recon_data.audio_data[..., discard:-discard], + og_signal.audio_data[..., discard:-discard], + atol=1e-6, ) + + +def test_log_magnitude(): + audio_path = "./audio/spk/f10_script4_produced.wav" + for _ in range(10): + signal = AudioSignal.excerpt(audio_path, duration=5.0) + magnitude = signal.magnitude.numpy()[0, 0] + librosa_log_mag = librosa.amplitude_to_db(magnitude) + log_mag = signal.log_magnitude().numpy()[0, 0] + + # print(abs((log_mag - librosa_log_mag)).max()) + assert np.allclose(log_mag, librosa_log_mag, atol=10e-7) + + +@pytest.mark.parametrize("n_mels", [40, 80, 128]) +@pytest.mark.parametrize("window_length", [2048, 512]) +@pytest.mark.parametrize("hop_length", [512, 128]) +@pytest.mark.parametrize("window_type", ["sqrt_hann", "hann", None]) +def test_mel_spectrogram(n_mels, window_length, hop_length, window_type): + if hop_length >= window_length: + hop_length = window_length // 2 + audio_path = "./audio/spk/f10_script4_produced.wav" + stft_params = audiotools.STFTParams( + window_length=window_length, + hop_length=hop_length, + window_type=window_type) + for _stft_params in [None, stft_params]: + signal = AudioSignal(audio_path, duration=10, stft_params=_stft_params) + mel_spec = signal.mel_spectrogram(n_mels=n_mels) + assert mel_spec.shape[2] == n_mels + + +@pytest.mark.parametrize("n_mfcc", [20, 40]) +@pytest.mark.parametrize("n_mels", [40, 80, 128]) +@pytest.mark.parametrize("window_length", [2048, 512]) +@pytest.mark.parametrize("hop_length", [512, 128]) +def test_mfcc(n_mfcc, n_mels, window_length, hop_length): + if hop_length >= window_length: + hop_length = window_length // 2 + audio_path = "./audio/spk/f10_script4_produced.wav" + stft_params = audiotools.STFTParams( + window_length=window_length, hop_length=hop_length) + for _stft_params in [None, stft_params]: + signal = AudioSignal(audio_path, duration=10, stft_params=_stft_params) + mfcc = signal.mfcc(n_mfcc=n_mfcc, n_mels=n_mels) + assert mfcc.shape[2] == n_mfcc + + +def test_to_mono(): + array = np.random.randn(4, 2, 16000) + sr = 16000 + + signal = AudioSignal(array, sample_rate=sr) + assert signal.num_channels == 2 + + signal = signal.to_mono() + assert signal.num_channels == 1 + + +def test_float(): + array = np.random.randn(4, 1, 16000).astype("float64") + sr = 1600 + signal = AudioSignal(array, sample_rate=sr) + + signal = signal.float() + assert signal.audio_data.dtype == paddle.float32 + + +@pytest.mark.parametrize("sample_rate", [8000, 16000, 22050, 44100, 48000]) +def test_resample(sample_rate): + array = np.random.randn(4, 2, 16000) + sr = 16000 + + signal = AudioSignal(array, sample_rate=sr) + + signal = signal.resample(sample_rate) + assert signal.sample_rate == sample_rate + assert signal.signal_length == sample_rate + + +def test_batching(): + signals = [] + batch_size = 16 + + # All same length, same sample rate. + for _ in range(batch_size): + array = np.random.randn(2, 16000) + signal = AudioSignal(array, sample_rate=16000) + signals.append(signal) + + batched_signal = AudioSignal.batch(signals) + assert batched_signal.batch_size == batch_size + + signals = [] + # All different lengths, same sample rate, pad signals + for _ in range(batch_size): + L = np.random.randint(8000, 32000) + array = np.random.randn(2, L) + signal = AudioSignal(array, sample_rate=16000) + signals.append(signal) + + with pytest.raises(RuntimeError): + batched_signal = AudioSignal.batch(signals) + + signal_lengths = [x.signal_length for x in signals] + max_length = max(signal_lengths) + batched_signal = AudioSignal.batch(signals, pad_signals=True) + + assert batched_signal.signal_length == max_length + assert batched_signal.batch_size == batch_size + + signals = [] + # All different lengths, same sample rate, truncate signals + for _ in range(batch_size): + L = np.random.randint(8000, 32000) + array = np.random.randn(2, L) + signal = AudioSignal(array, sample_rate=16000) + signals.append(signal) + + with pytest.raises(RuntimeError): + batched_signal = AudioSignal.batch(signals) + + signal_lengths = [x.signal_length for x in signals] + min_length = min(signal_lengths) + batched_signal = AudioSignal.batch(signals, truncate_signals=True) + + assert batched_signal.signal_length == min_length + assert batched_signal.batch_size == batch_size + + signals = [] + # All different lengths, different sample rate, pad signals + for _ in range(batch_size): + L = np.random.randint(8000, 32000) + sr = np.random.choice([8000, 16000, 32000]) + array = np.random.randn(2, L) + signal = AudioSignal(array, sample_rate=int(sr)) + signals.append(signal) + + with pytest.raises(RuntimeError): + batched_signal = AudioSignal.batch(signals) + + signal_lengths = [x.signal_length for x in signals] + max_length = max(signal_lengths) + for i, x in enumerate(signals): + x.path_to_file = i + batched_signal = AudioSignal.batch(signals, resample=True, pad_signals=True) + + assert batched_signal.signal_length == max_length + assert batched_signal.batch_size == batch_size + assert batched_signal.path_to_file == list(range(len(signals))) + assert batched_signal.path_to_input_file == batched_signal.path_to_file diff --git a/audio/tests/audiotools/core/test_bands.py b/audio/tests/audiotools/core/test_bands.py new file mode 100644 index 00000000000..0e7a399dac9 --- /dev/null +++ b/audio/tests/audiotools/core/test_bands.py @@ -0,0 +1,54 @@ +# MIT License, Copyright (c) 2020 Alexandre Défossez. +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Modified from julius(https://github.com/adefossez/julius/blob/main/tests/test_bands.py) +import random +import sys +import unittest + +import paddle + +from audio.audiotools.core import pure_tone +from audio.audiotools.core import split_bands +from audio.audiotools.core import SplitBands + + +def delta(a, b, ref, fraction=0.9): + length = a.shape[-1] + compare_length = int(length * fraction) + offset = (length - compare_length) // 2 + a = a[..., offset:offset + length] + b = b[..., offset:offset + length] + return 100 * paddle.abs(a - b).mean() / ref.std() + + +TOLERANCE = 0.5 # Tolerance to errors as percentage of the std of the input signal + + +class _BaseTest(unittest.TestCase): + def assertSimilar(self, a, b, ref, msg=None, tol=TOLERANCE): + self.assertLessEqual(delta(a, b, ref), tol, msg) + + +class TestLowPassFilters(_BaseTest): + def setUp(self): + paddle.seed(1234) + random.seed(1234) + + def test_keep_or_kill(self): + sr = 256 + low = pure_tone(10, sr) + mid = pure_tone(40, sr) + high = pure_tone(100, sr) + + x = low + mid + high + + decomp = split_bands(x, sr, cutoffs=[20, 70]) + self.assertEqual(len(decomp), 3) + for est, gt, name in zip(decomp, [low, mid, high], + ["low", "mid", "high"]): + self.assertSimilar(est, gt, gt, name) + + +if __name__ == "__main__": + unittest.main() diff --git a/audio/tests/audiotools/core/test_display.py b/audio/tests/audiotools/core/test_display.py new file mode 100644 index 00000000000..a73b72b4292 --- /dev/null +++ b/audio/tests/audiotools/core/test_display.py @@ -0,0 +1,51 @@ +# MIT License, Copyright (c) 2023-Present, Descript. +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Modified from audiotools(https://github.com/descriptinc/audiotools/blob/master/tests/core/test_display.py) +import sys +from pathlib import Path + +import numpy as np +from visualdl import LogWriter + +from audio.audiotools import AudioSignal + + +def test_specshow(): + array = np.zeros((1, 16000)) + AudioSignal(array, sample_rate=16000).specshow() + AudioSignal(array, sample_rate=16000).specshow(preemphasis=True) + AudioSignal( + array, sample_rate=16000).specshow( + title="test", preemphasis=True) + AudioSignal( + array, sample_rate=16000).specshow( + format=False, preemphasis=True) + AudioSignal( + array, sample_rate=16000).specshow( + format=False, preemphasis=False, y_axis="mel") + + +def test_waveplot(): + array = np.zeros((1, 16000)) + AudioSignal(array, sample_rate=16000).waveplot() + + +def test_wavespec(): + array = np.zeros((1, 16000)) + AudioSignal(array, sample_rate=16000).wavespec() + + +def test_write_audio_to_tb(): + signal = AudioSignal("./audio/spk/f10_script4_produced.mp3", duration=5) + + Path("./scratch").mkdir(parents=True, exist_ok=True) + writer = LogWriter("./scratch/") + signal.write_audio_to_tb("tag", writer) + + +def test_save_image(): + signal = AudioSignal( + "./audio/spk/f10_script4_produced.wav", duration=10, offset=10) + Path("./scratch").mkdir(parents=True, exist_ok=True) + signal.save_image("./scratch/image.png") diff --git a/audio/tests/audiotools/core/test_dsp.py b/audio/tests/audiotools/core/test_dsp.py new file mode 100644 index 00000000000..b6db1baf70c --- /dev/null +++ b/audio/tests/audiotools/core/test_dsp.py @@ -0,0 +1,181 @@ +# MIT License, Copyright (c) 2023-Present, Descript. +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Modified from audiotools(https://github.com/descriptinc/audiotools/blob/master/tests/core/test_dsp.py) +import sys + +import numpy as np +import paddle +import pytest + +from audio.audiotools import AudioSignal +from audio.audiotools.core.util import sample_from_dist + + +@pytest.mark.parametrize("window_duration", [0.1, 0.25, 0.5, 1.0]) +@pytest.mark.parametrize("sample_rate", [8000, 16000, 22050, 44100]) +@pytest.mark.parametrize("duration", [0.5, 1.0, 2.0, 10.0]) +def test_overlap_add(duration, sample_rate, window_duration): + np.random.seed(0) + if duration > window_duration: + spk_signal = AudioSignal.batch([ + AudioSignal.excerpt( + "./audio/spk/f10_script4_produced.wav", duration=duration) + for _ in range(16) + ]) + spk_signal.resample(sample_rate) + + noise = paddle.randn([16, 1, int(duration * sample_rate)]) + nz_signal = AudioSignal(noise, sample_rate=sample_rate) + + def _test(signal): + hop_duration = window_duration / 2 + windowed_signal = signal.clone().collect_windows(window_duration, + hop_duration) + recombined = windowed_signal.overlap_and_add(hop_duration) + + assert recombined == signal + assert np.allclose(recombined.audio_data, signal.audio_data, 1e-3) + + _test(nz_signal) + _test(spk_signal) + + +@pytest.mark.parametrize("window_duration", [0.1, 0.25, 0.5, 1.0]) +@pytest.mark.parametrize("sample_rate", [8000, 16000, 22050, 44100]) +@pytest.mark.parametrize("duration", [0.5, 1.0, 2.0, 10.0]) +def test_inplace_overlap_add(duration, sample_rate, window_duration): + np.random.seed(0) + if duration > window_duration: + spk_signal = AudioSignal.batch([ + AudioSignal.excerpt( + "./audio/spk/f10_script4_produced.wav", duration=duration) + for _ in range(16) + ]) + spk_signal.resample(sample_rate) + + noise = paddle.randn([16, 1, int(duration * sample_rate)]) + nz_signal = AudioSignal(noise, sample_rate=sample_rate) + + def _test(signal): + hop_duration = window_duration / 2 + windowed_signal = signal.clone().collect_windows(window_duration, + hop_duration) + # Compare in-place with unfold results + for i, window in enumerate( + signal.clone().windows(window_duration, hop_duration)): + assert np.allclose(window.audio_data, + windowed_signal.audio_data[i]) + + _test(nz_signal) + _test(spk_signal) + + +def test_low_pass(): + sample_rate = 44100 + f = 440 + t = paddle.arange(0, 1, 1 / sample_rate) + sine_wave = paddle.sin(2 * np.pi * f * t) + window = AudioSignal.get_window("hann", sine_wave.shape[-1]) + sine_wave = sine_wave * window + signal = AudioSignal(sine_wave.unsqueeze(0), sample_rate=sample_rate) + out = signal.clone().low_pass(220) + assert out.audio_data.abs().max() < 1e-4 + + out = signal.clone().low_pass(880) + assert (out - signal).audio_data.abs().max() < 1e-3 + + batch = AudioSignal.batch([signal.clone(), signal.clone(), signal.clone()]) + + cutoffs = [220, 880, 220] + out = batch.clone().low_pass(cutoffs) + + assert out.audio_data[0].abs().max() < 1e-4 + assert out.audio_data[2].abs().max() < 1e-4 + assert (out - batch).audio_data[1].abs().max() < 1e-3 + + +def test_high_pass(): + sample_rate = 44100 + f = 440 + t = paddle.arange(0, 1, 1 / sample_rate) + sine_wave = paddle.sin(2 * np.pi * f * t) + window = AudioSignal.get_window("hann", sine_wave.shape[-1]) + sine_wave = sine_wave * window + signal = AudioSignal(sine_wave.unsqueeze(0), sample_rate=sample_rate) + out = signal.clone().high_pass(220) + assert (signal - out).audio_data.abs().max() < 1e-4 + + +def test_mask_frequencies(): + sample_rate = 44100 + fs = paddle.to_tensor([500.0, 2000.0, 8000.0, 32000.0])[None] + t = paddle.arange(0, 1, 1 / sample_rate)[:, None] + sine_wave = paddle.sin(2 * np.pi * t @ fs).sum(axis=-1) + sine_wave = AudioSignal(sine_wave, sample_rate) + masked_sine_wave = sine_wave.mask_frequencies(fmin_hz=1500, fmax_hz=10000) + + fs2 = paddle.to_tensor([500.0, 32000.0])[None] + sine_wave2 = paddle.sin(2 * np.pi * t @ fs).sum(axis=-1) + sine_wave2 = AudioSignal(sine_wave2, sample_rate) + + assert paddle.allclose(masked_sine_wave.audio_data, sine_wave2.audio_data) + + +def test_mask_timesteps(): + sample_rate = 44100 + f = 440 + t = paddle.linspace(0, 1, sample_rate) + sine_wave = paddle.sin(2 * np.pi * f * t) + sine_wave = AudioSignal(sine_wave, sample_rate) + + masked_sine_wave = sine_wave.mask_timesteps(tmin_s=0.25, tmax_s=0.75) + masked_sine_wave.istft() + + mask = ((0.3 < t) & (t < 0.7))[None, None] + assert paddle.allclose( + masked_sine_wave.audio_data[mask], + paddle.zeros_like(masked_sine_wave.audio_data[mask]), ) + + +def test_shift_phase(): + sample_rate = 44100 + f = 440 + t = paddle.linspace(0, 1, sample_rate) + sine_wave = paddle.sin(2 * np.pi * f * t) + sine_wave = AudioSignal(sine_wave, sample_rate) + sine_wave2 = sine_wave.clone() + + shifted_sine_wave = sine_wave.shift_phase(np.pi) + shifted_sine_wave.istft() + + sine_wave2.phase = sine_wave2.phase + np.pi + sine_wave2.istft() + + assert paddle.allclose(shifted_sine_wave.audio_data, sine_wave2.audio_data) + + +def test_corrupt_phase(): + sample_rate = 44100 + f = 440 + t = paddle.linspace(0, 1, sample_rate) + sine_wave = paddle.sin(2 * np.pi * f * t) + sine_wave = AudioSignal(sine_wave, sample_rate) + sine_wave2 = sine_wave.clone() + + shifted_sine_wave = sine_wave.corrupt_phase(scale=np.pi) + shifted_sine_wave.istft() + + assert (sine_wave2.phase - shifted_sine_wave.phase).abs().mean() > 0.0 + assert ((sine_wave2.phase - shifted_sine_wave.phase).std() / np.pi) < 1.0 + + +def test_preemphasis(): + x = AudioSignal.excerpt("./audio/spk/f10_script4_produced.wav", duration=5) + import matplotlib.pyplot as plt + + x.specshow(preemphasis=False) + + x.specshow(preemphasis=True) + + x.preemphasis() diff --git a/audio/tests/audiotools/core/test_effects.py b/audio/tests/audiotools/core/test_effects.py new file mode 100644 index 00000000000..9dba9948163 --- /dev/null +++ b/audio/tests/audiotools/core/test_effects.py @@ -0,0 +1,321 @@ +# MIT License, Copyright (c) 2023-Present, Descript. +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Modified from audiotools(https://github.com/descriptinc/audiotools/blob/master/tests/core/test_effects.py) +import sys + +import numpy as np +import paddle +import pytest + +from audio.audiotools import AudioSignal + + +def test_normalize(): + audio_path = "./audio/spk/f10_script4_produced.wav" + signal = AudioSignal(audio_path, offset=10, duration=10) + signal = signal.normalize() + assert np.allclose(signal.loudness(), -24, atol=1e-1) + + array = np.random.randn(1, 2, 32000) + array = array / np.abs(array).max() + + signal = AudioSignal(array, sample_rate=16000) + for db_incr in np.arange(10, 75, 5): + db = -80 + db_incr + signal = signal.normalize(db) + loudness = signal.loudness() + assert np.allclose(loudness, db, atol=1) # TODO, atol=1e-1 + + batch_size = 16 + db = -60 + paddle.linspace(10, 30, batch_size) + + array = np.random.randn(batch_size, 2, 32000) + array = array / np.abs(array).max() + signal = AudioSignal(array, sample_rate=16000) + + signal = signal.normalize(db) + assert np.allclose(signal.loudness(), db, 1e-1) + + +def test_volume_change(): + audio_path = "./audio/spk/f10_script4_produced.wav" + signal = AudioSignal(audio_path, offset=10, duration=10) + + boost = 3 + before_db = signal.loudness().clone() + signal = signal.volume_change(boost) + after_db = signal.loudness() + assert np.allclose(before_db + boost, after_db) + + signal._loudness = None + after_db = signal.loudness() + assert np.allclose(before_db + boost, after_db, 1e-1) + + +def test_mix(): + audio_path = "./audio/spk/f10_script4_produced.wav" + spk = AudioSignal(audio_path, offset=10, duration=10) + + audio_path = "./audio/nz/f5_script2_ipad_balcony1_room_tone.wav" + nz = AudioSignal(audio_path, offset=10, duration=10) + + spk.deepcopy().mix(nz, snr=-10) + snr = spk.loudness() - nz.loudness() + assert np.allclose(snr, -10, atol=1) + + # Test in batch + audio_path = "./audio/spk/f10_script4_produced.wav" + spk = AudioSignal(audio_path, offset=10, duration=10) + + audio_path = "./audio/nz/f5_script2_ipad_balcony1_room_tone.wav" + nz = AudioSignal(audio_path, offset=10, duration=10) + + batch_size = 4 + tgt_snr = paddle.linspace(-10, 10, batch_size) + + spk_batch = AudioSignal.batch([spk.deepcopy() for _ in range(batch_size)]) + nz_batch = AudioSignal.batch([nz.deepcopy() for _ in range(batch_size)]) + + spk_batch.deepcopy().mix(nz_batch, snr=tgt_snr) + snr = spk_batch.loudness() - nz_batch.loudness() + assert np.allclose(snr, tgt_snr, atol=1) + + # Test with "EQing" the other signal + db = 0 + 0 * paddle.rand([10]) + spk_batch.deepcopy().mix(nz_batch, snr=tgt_snr, other_eq=db) + snr = spk_batch.loudness() - nz_batch.loudness() + assert np.allclose(snr, tgt_snr, atol=1) + + +def test_convolve(): + np.random.seed(6) # Found a failing seed + audio_path = "./audio/spk/f10_script4_produced.wav" + spk = AudioSignal(audio_path, offset=10, duration=10) + + impulse = np.zeros((1, 16000), dtype="float32") + impulse[..., 0] = 1 + ir = AudioSignal(impulse, 16000) + batch_size = 4 + + spk_batch = AudioSignal.batch([spk.deepcopy() for _ in range(batch_size)]) + ir_batch = AudioSignal.batch( + [ + ir.deepcopy().zero_pad(np.random.randint(1000), 0) + for _ in range(batch_size) + ], + pad_signals=True, ) + + convolved = spk_batch.deepcopy().convolve(ir_batch) + assert convolved == spk_batch + + # Short duration + audio_path = "./audio/spk/f10_script4_produced.wav" + spk = AudioSignal(audio_path, offset=10, duration=0.1) + + impulse = np.zeros((1, 16000), dtype="float32") + impulse[..., 0] = 1 + ir = AudioSignal(impulse, 16000) + batch_size = 4 + + spk_batch = AudioSignal.batch([spk.deepcopy() for _ in range(batch_size)]) + ir_batch = AudioSignal.batch( + [ + ir.deepcopy().zero_pad(np.random.randint(1000), 0) + for _ in range(batch_size) + ], + pad_signals=True, ) + + convolved = spk_batch.deepcopy().convolve(ir_batch) + assert convolved == spk_batch + + +def test_pipeline(): + # An actual IR, no batching + audio_path = "./audio/spk/f10_script4_produced.wav" + spk = AudioSignal(audio_path, offset=10, duration=5) + + audio_path = "./audio/ir/h179_Bar_1txts.wav" + ir = AudioSignal(audio_path) + spk.deepcopy().convolve(ir) + + audio_path = "./audio/nz/f5_script2_ipad_balcony1_room_tone.wav" + nz = AudioSignal(audio_path, offset=10, duration=5) + + batch_size = 16 + tgt_snr = paddle.linspace(20, 30, batch_size) + + (spk @ ir).mix(nz, snr=tgt_snr) + + +@pytest.mark.parametrize("n_bands", [1, 2, 4, 8, 12, 16]) +def test_mel_filterbank(n_bands): + audio_path = "./audio/spk/f10_script4_produced.wav" + spk = AudioSignal(audio_path, offset=10, duration=1) + fbank = spk.deepcopy().mel_filterbank(n_bands) + + assert paddle.allclose(fbank.sum(-1), spk.audio_data, atol=1e-6) + + # Check if it works in batches. + spk_batch = AudioSignal.batch([ + AudioSignal.excerpt("./audio/spk/f10_script4_produced.wav", duration=2) + for _ in range(16) + ]) + fbank = spk_batch.deepcopy().mel_filterbank(n_bands) + summed = fbank.sum(-1) + assert paddle.allclose(summed, spk_batch.audio_data, atol=1e-6) + + +@pytest.mark.parametrize("n_bands", [1, 2, 4, 8, 12, 16]) +def test_equalizer(n_bands): + audio_path = "./audio/spk/f10_script4_produced.wav" + spk = AudioSignal(audio_path, offset=10, duration=10) + + db = -3 + 1 * paddle.rand([n_bands]) + spk.deepcopy().equalizer(db) + + db = -3 + 1 * np.random.rand(n_bands) + spk.deepcopy().equalizer(db) + + audio_path = "./audio/ir/h179_Bar_1txts.wav" + ir = AudioSignal(audio_path) + db = -3 + 1 * paddle.rand([n_bands]) + + spk.deepcopy().convolve(ir.equalizer(db)) + + spk_batch = AudioSignal.batch([ + AudioSignal.excerpt("./audio/spk/f10_script4_produced.wav", duration=2) + for _ in range(16) + ]) + + db = paddle.zeros([spk_batch.batch_size, n_bands]) + output = spk_batch.deepcopy().equalizer(db) + + assert output == spk_batch + + +def test_clip_distortion(): + audio_path = "./audio/spk/f10_script4_produced.wav" + spk = AudioSignal(audio_path, offset=10, duration=2) + clipped = spk.deepcopy().clip_distortion(0.05) + + spk_batch = AudioSignal.batch([ + AudioSignal.excerpt("./audio/spk/f10_script4_produced.wav", duration=2) + for _ in range(16) + ]) + percs = paddle.to_tensor(np.random.uniform(size=(16, ))).astype("float32") + clipped_batch = spk_batch.deepcopy().clip_distortion(percs) + + assert clipped.audio_data.abs().max() < 1.0 + assert clipped_batch.audio_data.abs().max() < 1.0 + + +@pytest.mark.parametrize("quant_ch", [2, 4, 8, 16, 32, 64, 128]) +def test_quantization(quant_ch): + audio_path = "./audio/spk/f10_script4_produced.wav" + spk = AudioSignal(audio_path, offset=10, duration=2) + + quantized = spk.deepcopy().quantization(quant_ch) + + # Need to round audio_data off because torch ops with straight + # through estimator are sometimes a bit off past 3 decimal places. + found_quant_ch = len(np.unique(np.around(quantized.audio_data, decimals=3))) + assert found_quant_ch <= quant_ch + + spk_batch = AudioSignal.batch([ + AudioSignal.excerpt("./audio/spk/f10_script4_produced.wav", duration=2) + for _ in range(16) + ]) + + quant_ch = np.random.choice( + [2, 4, 8, 16, 32, 64, 128], size=(16, ), replace=True) + quantized = spk_batch.deepcopy().quantization(quant_ch) + + for i, q_ch in enumerate(quant_ch): + found_quant_ch = len( + np.unique(np.around(quantized.audio_data[i], decimals=3))) + assert found_quant_ch <= q_ch + + +@pytest.mark.parametrize("quant_ch", [2, 4, 8, 16, 32, 64, 128]) +def test_mulaw_quantization(quant_ch): + audio_path = "./audio/spk/f10_script4_produced.wav" + spk = AudioSignal(audio_path, offset=10, duration=2) + + quantized = spk.deepcopy().mulaw_quantization(quant_ch) + + # Need to round audio_data off because torch ops with straight + # through estimator are sometimes a bit off past 3 decimal places. + found_quant_ch = len(np.unique(np.around(quantized.audio_data, decimals=3))) + assert found_quant_ch <= quant_ch + + spk_batch = AudioSignal.batch([ + AudioSignal.excerpt("./audio/spk/f10_script4_produced.wav", duration=2) + for _ in range(16) + ]) + + quant_ch = np.random.choice( + [2, 4, 8, 16, 32, 64, 128], size=(16, ), replace=True) + quantized = spk_batch.deepcopy().mulaw_quantization(quant_ch) + + for i, q_ch in enumerate(quant_ch): + found_quant_ch = len( + np.unique(np.around(quantized.audio_data[i], decimals=3))) + assert found_quant_ch <= q_ch + + +def test_impulse_response_augmentation(): + audio_path = "./audio/ir/h179_Bar_1txts.wav" + batch_size = 16 + ir = AudioSignal(audio_path) + ir_batch = AudioSignal.batch([ir for _ in range(batch_size)]) + early_response, late_field, window = ir_batch.decompose_ir() + + assert early_response.shape == late_field.shape + assert late_field.shape == window.shape + + drr = ir_batch.measure_drr() + + alpha = AudioSignal.solve_alpha(early_response, late_field, window, drr) + assert np.allclose(alpha, np.ones_like(alpha), 1e-5) + + target_drr = 5 + out = ir_batch.deepcopy().alter_drr(target_drr) + drr = out.measure_drr() + assert np.allclose(drr, np.ones_like(drr) * target_drr) + + target_drr = np.random.rand(batch_size).astype("float32") * 50 + altered_ir = ir_batch.deepcopy().alter_drr(target_drr) + drr = altered_ir.measure_drr() + assert np.allclose(drr.flatten(), target_drr.flatten()) + + +def test_apply_ir(): + audio_path = "./audio/spk/f10_script4_produced.wav" + ir_path = "./audio/ir/h179_Bar_1txts.wav" + + spk = AudioSignal(audio_path, offset=10, duration=2) + ir = AudioSignal(ir_path) + db = 0 + 0 * paddle.rand([10]) + output = spk.deepcopy().apply_ir(ir, drr=10, ir_eq=db) + + assert np.allclose(ir.measure_drr().flatten(), 10) + + output = spk.deepcopy().apply_ir( + ir, drr=10, ir_eq=db, use_original_phase=True) + + +def test_ensure_max_of_audio(): + spk = AudioSignal(paddle.randn([1, 1, 44100]), 44100) + + max_vals = [1.0] + [np.random.rand() for _ in range(10)] + for val in max_vals: + after = spk.deepcopy().ensure_max_of_audio(val) + assert after.audio_data.abs().max() <= val + 1e-3 + + # Make sure it does nothing to a tiny signal + spk = AudioSignal(paddle.rand([1, 1, 44100]), 44100) + spk.audio_data = spk.audio_data * 0.5 + after = spk.deepcopy().ensure_max_of_audio() + + assert paddle.allclose(after.audio_data, spk.audio_data) diff --git a/audio/tests/audiotools/core/test_fftconv.py b/audio/tests/audiotools/core/test_fftconv.py new file mode 100644 index 00000000000..c3430dae598 --- /dev/null +++ b/audio/tests/audiotools/core/test_fftconv.py @@ -0,0 +1,85 @@ +# MIT License, Copyright (c) 2020 Alexandre Défossez. +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Modified from julius(https://github.com/adefossez/julius/blob/main/tests/test_fftconv.py) +import random +import sys +import unittest + +import paddle +import paddle.nn.functional as F + +from audio.audiotools.core import fft_conv1d +from audio.audiotools.core import FFTConv1D + +TOLERANCE = 1e-4 # as relative delta in percentage + + +class _BaseTest(unittest.TestCase): + def setUp(self): + paddle.seed(1234) + random.seed(1234) + + def assertSimilar(self, a, b, msg=None, tol=TOLERANCE): + delta = 100 * paddle.norm(a - b, p=2) / paddle.norm(b, p=2) + self.assertLessEqual(delta.numpy(), tol, msg) + + def compare_paddle(self, *args, msg=None, tol=TOLERANCE, **kwargs): + y_ref = F.conv1d(*args, **kwargs) + y = fft_conv1d(*args, **kwargs) + self.assertEqual(list(y.shape), list(y_ref.shape), msg) + self.assertSimilar(y, y_ref, msg, tol) + + +class TestFFTConv1d(_BaseTest): + def test_same_as_paddle(self): + for _ in range(5): + kernel_size = random.randrange(4, 128) + batch_size = random.randrange(1, 6) + length = random.randrange(kernel_size, 1024) + chin = random.randrange(1, 12) + chout = random.randrange(1, 12) + bias = random.random() < 0.5 + if random.random() < 0.5: + padding = 0 + else: + padding = random.randrange(kernel_size // 2, 2 * kernel_size) + x = paddle.randn([batch_size, chin, length]) + w = paddle.randn([chout, chin, kernel_size]) + keys = ["length", "kernel_size", "chin", "chout", "bias"] + loc = locals() + state = {key: loc[key] for key in keys} + if bias: + bias = paddle.randn([chout]) + else: + bias = None + for stride in [1, 2, 5]: + state["stride"] = stride + self.compare_paddle( + x, w, bias, stride, padding, msg=repr(state)) + + def test_small_input(self): + x = paddle.randn([1, 5, 19]) + w = paddle.randn([10, 5, 32]) + with self.assertRaises(RuntimeError): + fft_conv1d(x, w) + + x = paddle.randn([1, 5, 19]) + w = paddle.randn([10, 5, 19]) + self.assertEqual(list(fft_conv1d(x, w).shape), [1, 10, 1]) + + def test_module(self): + x = paddle.randn([16, 4, 1024]) + mod = FFTConv1D(4, 5, 8, bias_attr=True) + mod(x) + mod = FFTConv1D(4, 5, 8, bias_attr=False) + mod(x) + + def test_dynamic_graph(self): + x = paddle.randn([16, 4, 1024]) + mod = FFTConv1D(4, 5, 8, bias_attr=True) + self.assertEqual(list(mod(x).shape), [16, 5, 1024 - 8 + 1]) + + +if __name__ == "__main__": + unittest.main() diff --git a/audio/tests/audiotools/core/test_grad.py b/audio/tests/audiotools/core/test_grad.py new file mode 100644 index 00000000000..e90320b68ae --- /dev/null +++ b/audio/tests/audiotools/core/test_grad.py @@ -0,0 +1,172 @@ +# MIT License, Copyright (c) 2023-Present, Descript. +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Modified from audiotools(https://github.com/descriptinc/audiotools/blob/master/tests/core/test_grad.py) +import sys +from typing import Callable + +import numpy as np +import paddle +import pytest + +from audio.audiotools import AudioSignal + + +def test_audio_grad(): + audio_path = "./audio/spk/f10_script4_produced.wav" + ir_path = "./audio/ir/h179_Bar_1txts.wav" + + def _test_audio_grad(attr: str, target=True, kwargs: dict={}): + signal = AudioSignal(audio_path) + signal.audio_data.stop_gradient = False + + assert signal.audio_data.grad is None + + # Avoid overwriting leaf tensor by cloning signal + attr = getattr(signal.clone(), attr) + result = attr(**kwargs) if isinstance(attr, Callable) else attr + + try: + if isinstance(result, AudioSignal): + # If necessary, propagate spectrogram changes to waveform + if result.stft_data is not None: + result.istft() + # if result.audio_data.dtype.is_complex: + if paddle.is_complex(result.audio_data): + result.audio_data.real.sum().backward() + else: + result.audio_data.sum().backward() + else: + # if result.dtype.is_complex: + if paddle.is_complex(result): + result.real().sum().backward() + else: + result.sum().backward() + + assert signal.audio_data.grad is not None or not target + except RuntimeError: + assert not target + + for a in [ + ["mix", True, { + "other": AudioSignal(audio_path), + "snr": 0 + }], + ["convolve", True, { + "other": AudioSignal(ir_path) + }], + [ + "apply_ir", + True, + { + "ir": AudioSignal(ir_path), + "drr": 0.1, + "ir_eq": paddle.randn([6]) + }, + ], + ["ensure_max_of_audio", True], + ["normalize", True], + ["volume_change", True, { + "db": 1 + }], + # ["pitch_shift", False, {"n_semitones": 1}], + # ["time_stretch", False, {"factor": 2}], + # ["apply_codec", False], + ["equalizer", True, { + "db": paddle.randn([6]) + }], + ["clip_distortion", True, { + "clip_percentile": 0.5 + }], + ["quantization", True, { + "quantization_channels": 8 + }], + ["mulaw_quantization", True, { + "quantization_channels": 8 + }], + ["resample", True, { + "sample_rate": 16000 + }], + ["low_pass", True, { + "cutoffs": 1000 + }], + ["high_pass", True, { + "cutoffs": 1000 + }], + ["to_mono", True], + ["zero_pad", True, { + "before": 10, + "after": 10 + }], + ["magnitude", True], + ["phase", True], + ["log_magnitude", True], + ["loudness", False], + ["stft", True], + ["clone", True], + ["mel_spectrogram", True], + ["zero_pad_to", True, { + "length": 100000 + }], + ["truncate_samples", True, { + "length_in_samples": 1000 + }], + ["corrupt_phase", True, { + "scale": 0.5 + }], + ["shift_phase", True, { + "shift": 1 + }], + ["mask_low_magnitudes", True, { + "db_cutoff": 0 + }], + ["mask_frequencies", True, { + "fmin_hz": 100, + "fmax_hz": 1000 + }], + ["mask_timesteps", True, { + "tmin_s": 0.1, + "tmax_s": 0.5 + }], + ["__add__", True, { + "other": AudioSignal(audio_path) + }], + ["__iadd__", True, { + "other": AudioSignal(audio_path) + }], + ["__radd__", True, { + "other": AudioSignal(audio_path) + }], + ["__sub__", True, { + "other": AudioSignal(audio_path) + }], + ["__isub__", True, { + "other": AudioSignal(audio_path) + }], + ["__mul__", True, { + "other": AudioSignal(audio_path) + }], + ["__imul__", True, { + "other": AudioSignal(audio_path) + }], + ["__rmul__", True, { + "other": AudioSignal(audio_path) + }], + ]: + _test_audio_grad(*a) + + +def test_batch_grad(): + audio_path = "./audio/spk/f10_script4_produced.wav" + + signal = AudioSignal(audio_path) + signal.audio_data.stop_gradient = False + + assert signal.audio_data.grad is None + + batch_size = 16 + batch = AudioSignal.batch([signal.clone() for _ in range(batch_size)]) + + batch.audio_data.sum().backward() + + assert signal.audio_data.grad is not None diff --git a/audio/tests/audiotools/core/test_highpass.py b/audio/tests/audiotools/core/test_highpass.py new file mode 100644 index 00000000000..0959474b5a7 --- /dev/null +++ b/audio/tests/audiotools/core/test_highpass.py @@ -0,0 +1,104 @@ +# MIT License, Copyright (c) 2020 Alexandre Défossez. +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Modified from julius(https://github.com/adefossez/julius/blob/main/tests/test_filters.py) +import math +import random +import sys +import unittest + +import paddle + +from audio.audiotools.core import highpass_filter +from audio.audiotools.core import highpass_filters + + +def pure_tone(freq: float, sr: float=128, dur: float=4, device=None): + """ + Return a pure tone, i.e. cosine. + + Args: + freq (float): frequency (in Hz) + sr (float): sample rate (in Hz) + dur (float): duration (in seconds) + """ + time = paddle.arange(int(sr * dur), dtype="float32") / sr + return paddle.cos(2 * math.pi * freq * time) + + +def delta(a, b, ref, fraction=0.9): + length = a.shape[-1] + compare_length = int(length * fraction) + offset = (length - compare_length) // 2 + a = a[..., offset:offset + length] + b = b[..., offset:offset + length] + # 计算绝对差值,均值,然后除以ref的标准差,乘以100 + return 100 * paddle.mean(paddle.abs(a - b)) / paddle.std(ref) + + +TOLERANCE = 1 # Tolerance to errors as percentage of the std of the input signal + + +class _BaseTest(unittest.TestCase): + def assertSimilar(self, a, b, ref, msg=None, tol=TOLERANCE): + self.assertLessEqual(delta(a, b, ref), tol, msg) + + +class TestHighPassFilters(_BaseTest): + def setUp(self): + paddle.seed(1234) + random.seed(1234) + + def test_keep_or_kill(self): + for _ in range(10): + freq = random.uniform(0.01, 0.4) + sr = 1024 + tone = pure_tone(freq * sr, sr=sr, dur=10) + + # For this test we accept 5% tolerance in amplitude, or -26dB in power. + tol = 5 + zeros = 16 + + # If cutoff frequency is under freq, output should be input + y_pass = highpass_filter(tone, 0.9 * freq, zeros=zeros) + self.assertSimilar( + y_pass, tone, tone, f"freq={freq}, pass", tol=tol) + + # If cutoff frequency is over freq, output should be zero + y_killed = highpass_filter(tone, 1.1 * freq, zeros=zeros) + self.assertSimilar( + y_killed, 0 * tone, tone, f"freq={freq}, kill", tol=tol) + + def test_fft_nofft(self): + for _ in range(10): + x = paddle.randn([1024]) + freq = random.uniform(0.01, 0.5) + y_fft = highpass_filter(x, freq, fft=True) + y_ref = highpass_filter(x, freq, fft=False) + self.assertSimilar(y_fft, y_ref, x, f"freq={freq}", tol=0.01) + + def test_constant(self): + x = paddle.ones([2048]) + for zeros in [4, 10]: + for freq in [0.01, 0.1]: + y_high = highpass_filter(x, freq, zeros=zeros) + self.assertLessEqual(y_high.abs().mean(), 1e-6, (zeros, freq)) + + def test_stride(self): + x = paddle.randn([1024]) + + y = highpass_filters(x, [0.1, 0.2], stride=1)[:, ::3] + y2 = highpass_filters(x, [0.1, 0.2], stride=3) + + self.assertEqual(y.shape, y2.shape) + self.assertSimilar(y, y2, x) + + y = highpass_filters(x, [0.1, 0.2], stride=1, pad=False)[:, ::3] + y2 = highpass_filters(x, [0.1, 0.2], stride=3, pad=False) + + self.assertEqual(y.shape, y2.shape) + self.assertSimilar(y, y2, x) + + +if __name__ == "__main__": + unittest.main() diff --git a/audio/tests/audiotools/core/test_loudness.py b/audio/tests/audiotools/core/test_loudness.py new file mode 100644 index 00000000000..a4f7cc4f3b5 --- /dev/null +++ b/audio/tests/audiotools/core/test_loudness.py @@ -0,0 +1,274 @@ +# MIT License, Copyright (c) 2023-Present, Descript. +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Modified from audiotools(https://github.com/descriptinc/audiotools/blob/master/tests/core/test_loudness.py) +import sys + +import numpy as np +import pyloudnorm +import soundfile as sf + +from audio.audiotools import AudioSignal +from audio.audiotools import datasets +from audio.audiotools import Meter +from audio.audiotools import transforms + +ATOL = 1e-1 + + +def test_loudness_against_pyln(): + audio_path = "./audio/spk/f10_script4_produced.wav" + signal = AudioSignal(audio_path, offset=5, duration=10) + signal_loudness = signal.loudness() + + meter = pyloudnorm.Meter( + signal.sample_rate, filter_class="K-weighting", block_size=0.4) + py_loudness = meter.integrated_loudness(signal.numpy()[0].T) + assert np.allclose(signal_loudness, py_loudness) + + +def test_loudness_short(): + audio_path = "./audio/spk/f10_script4_produced.wav" + signal = AudioSignal(audio_path, offset=10, duration=0.25) + signal_loudness = signal.loudness() + + +def test_batch_loudness(): + np.random.seed(0) + array = np.random.randn(16, 2, 16000) + array /= np.abs(array).max() + + gains = np.random.rand(array.shape[0])[:, None, None] + array = array * gains + + meter = pyloudnorm.Meter(16000) + py_loudness = [ + meter.integrated_loudness(array[i].T) for i in range(array.shape[0]) + ] + + meter = Meter(16000) + meter.filter_class + at_loudness_iso = [ + meter.integrated_loudness(array[i].T).item() + for i in range(array.shape[0]) + ] + + assert np.allclose(py_loudness, at_loudness_iso, atol=1e-1) + + signal = AudioSignal(array, sample_rate=16000) + at_loudness_batch = signal.loudness() + assert np.allclose(py_loudness, at_loudness_batch, atol=1e-1) + + +# Tests below are copied from pyloudnorm +def test_integrated_loudness(): + data, rate = sf.read("./audio/loudness/sine_1000.wav") + meter = Meter(rate) + loudness = meter(data) + + targetLoudness = -3.0523438444331137 + assert np.allclose(loudness, targetLoudness) + + +def test_rel_gate_test(): + data, rate = sf.read("./audio/loudness/1770-2_Comp_RelGateTest.wav") + meter = Meter(rate) + loudness = meter.integrated_loudness(data) + + targetLoudness = -10.0 + assert np.allclose(loudness, targetLoudness, atol=ATOL) + + +def test_abs_gate_test(): + data, rate = sf.read("./audio/loudness/1770-2_Comp_AbsGateTest.wav") + meter = Meter(rate) + loudness = meter.integrated_loudness(data) + + targetLoudness = -69.5 + assert np.allclose(loudness, targetLoudness, atol=ATOL) + + +def test_24LKFS_25Hz_2ch(): + data, rate = sf.read("./audio/loudness/1770-2_Comp_24LKFS_25Hz_2ch.wav") + meter = Meter(rate) + loudness = meter.integrated_loudness(data) + + targetLoudness = -24.0 + assert np.allclose(loudness, targetLoudness, atol=ATOL) + + +def test_24LKFS_100Hz_2ch(): + data, rate = sf.read("./audio/loudness/1770-2_Comp_24LKFS_100Hz_2ch.wav") + meter = Meter(rate) + loudness = meter.integrated_loudness(data) + + targetLoudness = -24.0 + assert np.allclose(loudness, targetLoudness, atol=ATOL) + + +def test_24LKFS_500Hz_2ch(): + data, rate = sf.read("./audio/loudness/1770-2_Comp_24LKFS_500Hz_2ch.wav") + meter = Meter(rate) + loudness = meter.integrated_loudness(data) + + targetLoudness = -24.0 + assert np.allclose(loudness, targetLoudness, atol=ATOL) + + +def test_24LKFS_1000Hz_2ch(): + data, rate = sf.read("./audio/loudness/1770-2_Comp_24LKFS_1000Hz_2ch.wav") + meter = Meter(rate) + loudness = meter.integrated_loudness(data) + + targetLoudness = -24.0 + assert np.allclose(loudness, targetLoudness, atol=ATOL) + + +def test_24LKFS_2000Hz_2ch(): + data, rate = sf.read("./audio/loudness/1770-2_Comp_24LKFS_2000Hz_2ch.wav") + meter = Meter(rate) + loudness = meter.integrated_loudness(data) + + targetLoudness = -24.0 + assert np.allclose(loudness, targetLoudness, atol=ATOL) + + +def test_24LKFS_10000Hz_2ch(): + data, rate = sf.read("./audio/loudness/1770-2_Comp_24LKFS_10000Hz_2ch.wav") + meter = Meter(rate) + loudness = meter.integrated_loudness(data) + + targetLoudness = -24.0 + assert np.allclose(loudness, targetLoudness, atol=ATOL) + + +def test_23LKFS_25Hz_2ch(): + data, rate = sf.read("./audio/loudness/1770-2_Comp_23LKFS_25Hz_2ch.wav") + meter = Meter(rate) + loudness = meter.integrated_loudness(data) + + targetLoudness = -23.0 + assert np.allclose(loudness, targetLoudness, atol=ATOL) + + +def test_23LKFS_100Hz_2ch(): + data, rate = sf.read("./audio/loudness/1770-2_Comp_23LKFS_100Hz_2ch.wav") + meter = Meter(rate) + loudness = meter.integrated_loudness(data) + + targetLoudness = -23.0 + assert np.allclose(loudness, targetLoudness, atol=ATOL) + + +def test_23LKFS_500Hz_2ch(): + data, rate = sf.read("./audio/loudness/1770-2_Comp_23LKFS_500Hz_2ch.wav") + meter = Meter(rate) + loudness = meter.integrated_loudness(data) + + targetLoudness = -23.0 + assert np.allclose(loudness, targetLoudness, atol=ATOL) + + +def test_23LKFS_1000Hz_2ch(): + data, rate = sf.read("./audio/loudness/1770-2_Comp_23LKFS_1000Hz_2ch.wav") + meter = Meter(rate) + loudness = meter.integrated_loudness(data) + + targetLoudness = -23.0 + assert np.allclose(loudness, targetLoudness, atol=ATOL) + + +def test_23LKFS_2000Hz_2ch(): + data, rate = sf.read("./audio/loudness/1770-2_Comp_23LKFS_2000Hz_2ch.wav") + meter = Meter(rate) + loudness = meter.integrated_loudness(data) + + targetLoudness = -23.0 + assert np.allclose(loudness, targetLoudness, atol=ATOL) + + +def test_23LKFS_10000Hz_2ch(): + data, rate = sf.read("./audio/loudness/1770-2_Comp_23LKFS_10000Hz_2ch.wav") + meter = Meter(rate) + loudness = meter.integrated_loudness(data) + + targetLoudness = -23.0 + assert np.allclose(loudness, targetLoudness, atol=ATOL) + + +def test_18LKFS_frequency_sweep(): + data, rate = sf.read( + "./audio/loudness/1770-2_Comp_18LKFS_FrequencySweep.wav") + meter = Meter(rate) + loudness = meter.integrated_loudness(data) + + targetLoudness = -18.0 + assert np.allclose(loudness, targetLoudness, atol=ATOL) + + +def test_conf_stereo_vinL_R_23LKFS(): + data, rate = sf.read( + "./audio/loudness/1770-2_Conf_Stereo_VinL+R-23LKFS.wav") + meter = Meter(rate) + loudness = meter.integrated_loudness(data) + + targetLoudness = -23.0 + assert np.allclose(loudness, targetLoudness, atol=ATOL) + + +def test_conf_monovoice_music_24LKFS(): + data, rate = sf.read( + "./audio/loudness/1770-2_Conf_Mono_Voice+Music-24LKFS.wav") + meter = Meter(rate) + loudness = meter.integrated_loudness(data) + + targetLoudness = -24.0 + assert np.allclose(loudness, targetLoudness, atol=ATOL) + + +def conf_monovoice_music_24LKFS(): + data, rate = sf.read( + "./audio/loudness/1770-2_Conf_Mono_Voice+Music-24LKFS.wav") + meter = Meter(rate) + loudness = meter.integrated_loudness(data) + + targetLoudness = -24.0 + assert np.allclose(loudness, targetLoudness, atol=ATOL) + + +def test_conf_monovoice_music_23LKFS(): + data, rate = sf.read( + "./audio/loudness/1770-2_Conf_Mono_Voice+Music-23LKFS.wav") + meter = Meter(rate) + loudness = meter.integrated_loudness(data) + + targetLoudness = -23.0 + assert np.allclose(loudness, targetLoudness, atol=ATOL) + + +def test_fir_accuracy(): + transform = transforms.Compose( + transforms.ClippingDistortion(prob=0.5), + transforms.LowPass(prob=0.5), + transforms.HighPass(prob=0.5), + transforms.Equalizer(prob=0.5), + prob=0.5, ) + loader = datasets.AudioLoader(sources=["./audio/spk.csv"]) + dataset = datasets.AudioDataset( + loader, + 44100, + 10, + 5.0, + transform=transform, ) + + for i in range(20): + item = dataset[i] + kwargs = item["transform_args"] + signal = item["signal"] + signal = transform(signal, **kwargs) + + signal._loudness = None + iir_db = signal.clone().loudness() + fir_db = signal.clone().loudness(use_fir=True) + + assert np.allclose(iir_db, fir_db, atol=1e-2) diff --git a/audio/tests/audiotools/core/test_lowpass.py b/audio/tests/audiotools/core/test_lowpass.py new file mode 100644 index 00000000000..5b00e757fa1 --- /dev/null +++ b/audio/tests/audiotools/core/test_lowpass.py @@ -0,0 +1,109 @@ +# MIT License, Copyright (c) 2020 Alexandre Défossez. +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Modified from julius(https://github.com/adefossez/julius/blob/main/tests/test_lowpass.py) +import math +import random +import sys +import unittest + +import numpy as np +import paddle + +from audio.audiotools.core import lowpass_filter +from audio.audiotools.core import LowPassFilter +from audio.audiotools.core import LowPassFilters +from audio.audiotools.core import resample_frac + + +def pure_tone(freq: float, sr: float=128, dur: float=4, device=None): + """ + Return a pure tone, i.e. cosine. + + Args: + freq (float): frequency (in Hz) + sr (float): sample rate (in Hz) + dur (float): duration (in seconds) + """ + time = paddle.arange(int(sr * dur), dtype="float32") / sr + return paddle.cos(2 * math.pi * freq * time) + + +def delta(a, b, ref, fraction=0.9): + length = a.shape[-1] + compare_length = int(length * fraction) + offset = (length - compare_length) // 2 + a = a[..., offset:offset + length] + b = b[..., offset:offset + length] + # 计算绝对差值,均值,然后除以ref的标准差,乘以100 + return 100 * paddle.mean(paddle.abs(a - b)) / paddle.std(ref) + + +TOLERANCE = 1 # Tolerance to errors as percentage of the std of the input signal + + +class _BaseTest(unittest.TestCase): + def assertSimilar(self, a, b, ref, msg=None, tol=TOLERANCE): + self.assertLessEqual(delta(a, b, ref), tol, msg) + + +class TestLowPassFilters(_BaseTest): + def setUp(self): + paddle.seed(1234) + random.seed(1234) + + def test_keep_or_kill(self): + for _ in range(10): + freq = random.uniform(0.01, 0.4) + sr = 1024 + tone = pure_tone(freq * sr, sr=sr, dur=10) + + # For this test we accept 5% tolerance in amplitude, or -26dB in power. + tol = 5 + zeros = 16 + + # If cutoff frequency is under freq, output should be zero + y_killed = lowpass_filter(tone, 0.9 * freq, zeros=zeros) + self.assertSimilar( + y_killed, 0 * y_killed, tone, f"freq={freq}, kill", tol=tol) + + # If cutoff frequency is under freq, output should be input + y_pass = lowpass_filter(tone, 1.1 * freq, zeros=zeros) + self.assertSimilar( + y_pass, tone, tone, f"freq={freq}, pass", tol=tol) + + def test_same_as_downsample(self): + for _ in range(10): + x = paddle.randn([2 * 3 * 4 * 100]) + x = paddle.ones_like(x) + np.random.seed(1234) + x = paddle.to_tensor( + np.random.randn(2 * 3 * 4 * 100), dtype="float32") + rolloff = 0.945 + for old_sr in [2, 3, 4]: + y_resampled = resample_frac( + x, old_sr, 1, rolloff=rolloff, zeros=16) + y_lowpass = lowpass_filter( + x, rolloff / old_sr / 2, stride=old_sr, zeros=16) + self.assertSimilar(y_resampled, y_lowpass, x, + f"old_sr={old_sr}") + + def test_fft_nofft(self): + for _ in range(10): + x = paddle.randn([1024]) + freq = random.uniform(0.01, 0.5) + y_fft = lowpass_filter(x, freq, fft=True) + y_ref = lowpass_filter(x, freq, fft=False) + self.assertSimilar(y_fft, y_ref, x, f"freq={freq}", tol=0.01) + + def test_constant(self): + x = paddle.ones([2048]) + for zeros in [4, 10]: + for freq in [0.01, 0.1]: + y_low = lowpass_filter(x, freq, zeros=zeros) + self.assertLessEqual((y_low - 1).abs().mean(), 1e-6, + (zeros, freq)) + + +if __name__ == "__main__": + unittest.main() diff --git a/audio/tests/audiotools/core/test_util.py b/audio/tests/audiotools/core/test_util.py new file mode 100644 index 00000000000..7516dce47e9 --- /dev/null +++ b/audio/tests/audiotools/core/test_util.py @@ -0,0 +1,157 @@ +# MIT License, Copyright (c) 2023-Present, Descript. +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Modified from audiotools(https://github.com/descriptinc/audiotools/blob/master/tests/core/test_util.py) +import os +import random +import sys +import tempfile + +import numpy as np +import paddle +import pytest + +from audio.audiotools import util +from audio.audiotools.core.audio_signal import AudioSignal +from paddlespeech.vector.training.seeding import seed_everything + + +def test_check_random_state(): + # seed is None + rng_type = type(np.random.RandomState(10)) + rng = util.random_state(None) + assert type(rng) == rng_type + + # seed is int + rng = util.random_state(10) + assert type(rng) == rng_type + + # seed is RandomState + rng_test = np.random.RandomState(10) + rng = util.random_state(rng_test) + assert type(rng) == rng_type + + # seed is none of the above : error + pytest.raises(ValueError, util.random_state, "random") + + +def test_seed(): + seed_everything(0) + paddle_result_a = paddle.randn([1]) + np_result_a = np.random.randn(1) + py_result_a = random.random() + + seed_everything(0) + paddle_result_b = paddle.randn([1]) + np_result_b = np.random.randn(1) + py_result_b = random.random() + + assert paddle_result_a == paddle_result_b + assert np_result_a == np_result_b + assert py_result_a == py_result_b + + +def test_hz_to_bin(): + hz = paddle.to_tensor(np.array([100, 200, 300]), dtype="float32") + sr = 1000 + n_fft = 2048 + + bins = util.hz_to_bin(hz, n_fft, sr) + + assert (((bins / n_fft) * sr) - hz).abs().max() < 1 + + +def test_find_audio(): + wav_files = util.find_audio("tests/", ["wav"]) + for a in wav_files: + assert "wav" in str(a) + + audio_files = util.find_audio("tests/", ["flac"]) + assert not audio_files + + # Make sure it works with single audio files + audio_files = util.find_audio("./audio/spk//f10_script4_produced.wav") + + # Make sure it works with globs + audio_files = util.find_audio("tests/**/*.wav") + assert len(audio_files) == len(wav_files) + + +def test_chdir(): + with tempfile.TemporaryDirectory(suffix="tmp") as d: + with util.chdir(d): + assert os.path.samefile(d, os.path.realpath(".")) + + +def test_prepare_batch(): + batch = {"tensor": paddle.randn([1]), "non_tensor": np.random.randn(1)} + util.prepare_batch(batch) + + batch = paddle.randn([1]) + util.prepare_batch(batch) + + batch = [paddle.randn([1]), np.random.randn(1)] + util.prepare_batch(batch) + + +def test_sample_dist(): + state = util.random_state(0) + v1 = state.uniform(0.0, 1.0) + v2 = util.sample_from_dist(("uniform", 0.0, 1.0), 0) + assert v1 == v2 + + assert util.sample_from_dist(("const", 1.0)) == 1.0 + + dist_tuple = ("choice", [8, 16, 32]) + assert util.sample_from_dist(dist_tuple) in [8, 16, 32] + + +def test_collate(): + batch_size = 16 + + def _one_item(): + return { + "signal": AudioSignal(paddle.randn([1, 1, 44100]), 44100), + "tensor": paddle.randn([1]), + "string": "Testing", + "dict": { + "nested_signal": + AudioSignal(paddle.randn([1, 1, 44100]), 44100), + }, + } + + items = [_one_item() for _ in range(batch_size)] + collated = util.collate(items) + + assert collated["signal"].batch_size == batch_size + assert collated["tensor"].shape[0] == batch_size + assert len(collated["string"]) == batch_size + assert collated["dict"]["nested_signal"].batch_size == batch_size + + # test collate with splitting (evenly) + batch_size = 16 + n_splits = 4 + + items = [_one_item() for _ in range(batch_size)] + collated = util.collate(items, n_splits=n_splits) + + for x in collated: + assert x["signal"].batch_size == batch_size // n_splits + assert x["tensor"].shape[0] == batch_size // n_splits + assert len(x["string"]) == batch_size // n_splits + assert x["dict"]["nested_signal"].batch_size == batch_size // n_splits + + # test collate with splitting (unevenly) + batch_size = 15 + n_splits = 4 + + items = [_one_item() for _ in range(batch_size)] + collated = util.collate(items, n_splits=n_splits) + + tlen = [4, 4, 4, 3] + + for x, t in zip(collated, tlen): + assert x["signal"].batch_size == t + assert x["tensor"].shape[0] == t + assert len(x["string"]) == t + assert x["dict"]["nested_signal"].batch_size == t diff --git a/audio/tests/audiotools/data/test_datasets.py b/audio/tests/audiotools/data/test_datasets.py new file mode 100644 index 00000000000..f26267ca0cd --- /dev/null +++ b/audio/tests/audiotools/data/test_datasets.py @@ -0,0 +1,208 @@ +# MIT License, Copyright (c) 2023-Present, Descript. +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Modified from audiotools(https://github.com/descriptinc/audiotools/blob/master/tests/data/test_datasets.py) +import sys +import tempfile +from pathlib import Path + +import numpy as np +import paddle +import pytest + +from audio import audiotools +from audio.audiotools.data import transforms as tfm + + +def test_align_lists(): + input_lists = [ + ["a/1.wav", "b/1.wav", "c/1.wav", "d/1.wav"], + ["a/2.wav", "c/2.wav"], + ["c/3.wav"], + ] + target_lists = [ + ["a/1.wav", "b/1.wav", "c/1.wav", "d/1.wav"], + ["a/2.wav", "none", "c/2.wav", "none"], + ["none", "none", "c/3.wav", "none"], + ] + + def _preprocess(lists): + output = [] + for x in lists: + output.append([]) + for y in x: + output[-1].append({"path": y}) + return output + + input_lists = _preprocess(input_lists) + target_lists = _preprocess(target_lists) + + aligned_lists = audiotools.datasets.align_lists(input_lists) + assert target_lists == aligned_lists + + +def test_audio_dataset(): + transform = tfm.Compose( + [ + tfm.VolumeNorm(), + tfm.Silence(prob=0.5), + ], ) + loader = audiotools.data.datasets.AudioLoader( + sources=["./audio/spk.csv"], + transform=transform, ) + dataset = audiotools.data.datasets.AudioDataset( + loader, + 44100, + n_examples=100, + transform=transform, ) + dataloader = paddle.io.DataLoader( + dataset, + batch_size=16, + num_workers=0, + collate_fn=dataset.collate, ) + for batch in dataloader: + kwargs = batch["transform_args"] + signal = batch["signal"] + original = signal.clone() + + signal = dataset.transform(signal, **kwargs) + original = dataset.transform(original, **kwargs) + + mask = kwargs["Compose"]["1.Silence"]["mask"] + + zeros_ = paddle.zeros_like(signal[mask].audio_data) + original_ = original[~mask].audio_data + + assert paddle.allclose(signal[mask].audio_data, zeros_) + assert paddle.allclose(signal[~mask].audio_data, original_) + + +def test_aligned_audio_dataset(): + with tempfile.TemporaryDirectory() as d: + dataset_dir = Path(d) + audiotools.util.generate_chord_dataset( + max_voices=8, num_items=3, output_dir=dataset_dir) + loaders = [ + audiotools.data.datasets.AudioLoader([dataset_dir / f"track_{i}"]) + for i in range(3) + ] + dataset = audiotools.data.datasets.AudioDataset( + loaders, 44100, n_examples=1000, aligned=True, shuffle_loaders=True) + dataloader = paddle.io.DataLoader( + dataset, + batch_size=16, + num_workers=0, + collate_fn=dataset.collate, ) + + # Make sure the voice tracks are aligned. + for batch in dataloader: + paths = [] + for i in range(len(loaders)): + _paths = [p.split("/")[-1] for p in batch[i]["path"]] + paths.append(_paths) + paths = np.array(paths) + for i in range(paths.shape[1]): + col = paths[:, i] + col = col[col != "none"] + assert np.all(col == col[0]) + + +def test_loader_without_replacement(): + with tempfile.TemporaryDirectory() as d: + dataset_dir = Path(d) + num_items = 100 + audiotools.util.generate_chord_dataset( + max_voices=1, + num_items=num_items, + output_dir=dataset_dir, + duration=0.01, ) + loader = audiotools.data.datasets.AudioLoader( + [dataset_dir], shuffle=False) + dataset = audiotools.data.datasets.AudioDataset(loader, 44100) + + for idx in range(num_items): + item = dataset[idx] + assert item["item_idx"] == idx + + +def test_loader_with_replacement(): + with tempfile.TemporaryDirectory() as d: + dataset_dir = Path(d) + num_items = 100 + audiotools.util.generate_chord_dataset( + max_voices=1, + num_items=num_items, + output_dir=dataset_dir, + duration=0.01, ) + loader = audiotools.data.datasets.AudioLoader([dataset_dir]) + dataset = audiotools.data.datasets.AudioDataset( + loader, 44100, without_replacement=False) + + for idx in range(num_items): + item = dataset[idx] + + +def test_loader_out_of_range(): + with tempfile.TemporaryDirectory() as d: + dataset_dir = Path(d) + num_items = 100 + audiotools.util.generate_chord_dataset( + max_voices=1, + num_items=num_items, + output_dir=dataset_dir, + duration=0.01, ) + loader = audiotools.data.datasets.AudioLoader([dataset_dir]) + + item = loader( + sample_rate=44100, + duration=0.01, + state=audiotools.util.random_state(0), + source_idx=0, + item_idx=101, ) + assert item["path"] == "none" + + +def test_dataset_pipeline(): + transform = tfm.Compose([ + tfm.RoomImpulseResponse(sources=["./audio/irs.csv"]), + tfm.BackgroundNoise(sources=["./audio/noises.csv"]), + ]) + loader = audiotools.data.datasets.AudioLoader(sources=["./audio/spk.csv"]) + dataset = audiotools.data.datasets.AudioDataset( + loader, + 44100, + n_examples=10, + transform=transform, ) + dataloader = paddle.io.DataLoader( + dataset, num_workers=0, batch_size=1, collate_fn=dataset.collate) + for batch in dataloader: + batch = audiotools.core.util.prepare_batch(batch, device="cpu") + kwargs = batch["transform_args"] + signal = batch["signal"] + batch = dataset.transform(signal, **kwargs) + + +class NumberDataset: + def __init__(self): + pass + + def __len__(self): + return 10 + + def __getitem__(self, idx): + return {"idx": idx} + + +def test_concat_dataset(): + d1 = NumberDataset() + d2 = NumberDataset() + d3 = NumberDataset() + + d = audiotools.datasets.ConcatDataset([d1, d2, d3]) + x = d.collate([d[i] for i in range(len(d))])["idx"].tolist() + + t = [] + for i in range(10): + t += [i, i, i] + + assert x == t diff --git a/audio/tests/audiotools/data/test_preprocess.py b/audio/tests/audiotools/data/test_preprocess.py new file mode 100644 index 00000000000..5dbb0daa4a9 --- /dev/null +++ b/audio/tests/audiotools/data/test_preprocess.py @@ -0,0 +1,33 @@ +# MIT License, Copyright (c) 2023-Present, Descript. +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Modified from audiotools(https://github.com/descriptinc/audiotools/blob/master/tests/data/test_preprocess.py) +import sys +import tempfile +from pathlib import Path + +import paddle + +from audio.audiotools.core.util import find_audio +from audio.audiotools.core.util import read_sources +from audio.audiotools.data import preprocess + + +def test_create_csv(): + with tempfile.NamedTemporaryFile(suffix=".csv") as f: + preprocess.create_csv( + find_audio("././audio/spk", ext=["wav"]), f.name, loudness=True) + + +def test_create_csv_with_empty_rows(): + audio_files = find_audio("././audio/spk", ext=["wav"]) + audio_files.insert(0, "") + audio_files.insert(2, "") + + with tempfile.NamedTemporaryFile(suffix=".csv") as f: + preprocess.create_csv(audio_files, f.name, loudness=True) + + audio_files = read_sources([f.name], remove_empty=True) + assert len(audio_files[0]) == 1 + audio_files = read_sources([f.name], remove_empty=False) + assert len(audio_files[0]) == 3 diff --git a/audio/tests/audiotools/data/test_transforms.py b/audio/tests/audiotools/data/test_transforms.py new file mode 100644 index 00000000000..0175f8ff36f --- /dev/null +++ b/audio/tests/audiotools/data/test_transforms.py @@ -0,0 +1,453 @@ +# MIT License, Copyright (c) 2023-Present, Descript. +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Modified from audiotools(https://github.com/descriptinc/audiotools/blob/master/tests/data/test_transforms.py) +import inspect +import sys +import warnings +from pathlib import Path + +import numpy as np +import paddle +import pytest + +from audio import audiotools +from audio.audiotools import AudioSignal +from audio.audiotools import util +from audio.audiotools.data import transforms as tfm +from audio.audiotools.data.datasets import AudioDataset +from paddlespeech.vector.training.seeding import seed_everything + +non_deterministic_transforms = ["TimeNoise", "FrequencyNoise"] +transforms_to_test = [] +for x in dir(tfm): + if hasattr(getattr(tfm, x), "transform"): + if x not in [ + "Compose", + "Choose", + "Repeat", + "RepeatUpTo", + # The above 4 transforms are currently excluded from testing at 1e-4 precision due to potential accuracy issues + "BackgroundNoise", + "Equalizer", + "FrequencyNoise", + "RoomImpulseResponse" + ]: + transforms_to_test.append(x) + + +def _compare_transform(transform_name, signal): + regression_data = Path(f"regression/transforms/{transform_name}.wav") + regression_data.parent.mkdir(exist_ok=True, parents=True) + + if regression_data.exists(): + regression_signal = AudioSignal(regression_data) + + assert paddle.allclose( + signal.audio_data, regression_signal.audio_data, atol=1e-4) + else: + signal.write(regression_data) + + +@pytest.mark.parametrize("transform_name", transforms_to_test) +def test_transform(transform_name): + seed = 0 + seed_everything(seed) + transform_cls = getattr(tfm, transform_name) + + kwargs = {} + if transform_name == "BackgroundNoise": + kwargs["sources"] = ["./audio/noises.csv"] + if transform_name == "RoomImpulseResponse": + kwargs["sources"] = ["./audio/irs.csv"] + if transform_name == "CrossTalk": + kwargs["sources"] = ["./audio/spk.csv"] + + audio_path = "./audio/spk/f10_script4_produced.wav" + signal = AudioSignal(audio_path, offset=10, duration=2) + signal.metadata["loudness"] = AudioSignal( + audio_path).ffmpeg_loudness().item() + transform = transform_cls(prob=1.0, **kwargs) + + kwargs = transform.instantiate(seed, signal) + for k in kwargs[transform_name]: + assert k in transform.keys + + output = transform(signal, **kwargs) + assert isinstance(output, AudioSignal) + + _compare_transform(transform_name, output) + + if transform_name in non_deterministic_transforms: + return + + # Test that if you make a batch of signals and call it, + # the first item in the batch is still the same as above. + batch_size = 4 + signal = AudioSignal(audio_path, offset=10, duration=2) + signal_batch = AudioSignal.batch( + [signal.clone() for _ in range(batch_size)]) + signal_batch.metadata["loudness"] = AudioSignal( + audio_path).ffmpeg_loudness().item() + + states = [seed + idx for idx in list(range(batch_size))] + kwargs = transform.batch_instantiate(states, signal_batch) + batch_output = transform(signal_batch, **kwargs) + + assert batch_output[0] == output + + ## Test that you can apply transform with the same args twice. + signal = AudioSignal(audio_path, offset=10, duration=2) + signal.metadata["loudness"] = AudioSignal( + audio_path).ffmpeg_loudness().item() + kwargs = transform.instantiate(seed, signal) + output_a = transform(signal.clone(), **kwargs) + output_b = transform(signal.clone(), **kwargs) + + assert output_a == output_b + + +def test_compose_basic(): + seed = 0 + + audio_path = "./audio/spk/f10_script4_produced.wav" + signal = AudioSignal(audio_path, offset=10, duration=2) + transform = tfm.Compose( + [ + tfm.RoomImpulseResponse(sources=["./audio/irs.csv"]), + tfm.BackgroundNoise(sources=["./audio/noises.csv"]), + ], ) + + kwargs = transform.instantiate(seed, signal) + output = transform(signal, **kwargs) + + # Due to precision issues with RoomImpulseResponse and BackgroundNoise used in the Compose test, + # we only perform logical testing for Compose and skip precision testing of the final output + # _compare_transform("Compose", output) + + assert isinstance(transform[0], tfm.RoomImpulseResponse) + assert isinstance(transform[1], tfm.BackgroundNoise) + assert len(transform) == 2 + + # Make sure __iter__ works + for _tfm in transform: + pass + + +class MulTransform(tfm.BaseTransform): + def __init__(self, num, name=None): + self.num = num + super().__init__(name=name, keys=["num"]) + + def _transform(self, signal, num): + + if not num.dim(): + num = num.unsqueeze(axis=0) + + signal.audio_data = signal.audio_data * num[:, None, None] + return signal + + def _instantiate(self, state): + return {"num": self.num} + + +def test_compose_with_duplicate_transforms(): + muls = [0.5, 0.25, 0.125] + transform = tfm.Compose([MulTransform(x) for x in muls]) + full_mul = np.prod(muls) + + kwargs = transform.instantiate(0) + audio_path = "./audio/spk/f10_script4_produced.wav" + signal = AudioSignal(audio_path, offset=10, duration=2) + + output = transform(signal.clone(), **kwargs) + expected_output = signal.audio_data * full_mul + + assert paddle.allclose(output.audio_data, expected_output) + + +def test_nested_compose(): + muls = [0.5, 0.25, 0.125] + transform = tfm.Compose([ + MulTransform(muls[0]), + tfm.Compose( + [MulTransform(muls[1]), tfm.Compose([MulTransform(muls[2])])]), + ]) + full_mul = np.prod(muls) + + kwargs = transform.instantiate(0) + audio_path = "./audio/spk/f10_script4_produced.wav" + signal = AudioSignal(audio_path, offset=10, duration=2) + + output = transform(signal.clone(), **kwargs) + expected_output = signal.audio_data * full_mul + + assert paddle.allclose(output.audio_data, expected_output) + + +def test_compose_filtering(): + muls = [0.5, 0.25, 0.125] + transform = tfm.Compose([MulTransform(x, name=str(x)) for x in muls]) + + kwargs = transform.instantiate(0) + audio_path = "./audio/spk/f10_script4_produced.wav" + signal = AudioSignal(audio_path, offset=10, duration=2) + + for s in range(len(muls)): + for _ in range(10): + _muls = np.random.choice(muls, size=s, replace=False).tolist() + full_mul = np.prod(_muls) + with transform.filter(*[str(x) for x in _muls]): + output = transform(signal.clone(), **kwargs) + + expected_output = signal.audio_data * full_mul + assert paddle.allclose(output.audio_data, expected_output) + + +def test_sequential_compose(): + muls = [0.5, 0.25, 0.125] + transform = tfm.Compose([ + tfm.Compose([MulTransform(muls[0])]), + tfm.Compose([MulTransform(muls[1]), MulTransform(muls[2])]), + ]) + full_mul = np.prod(muls) + + kwargs = transform.instantiate(0) + audio_path = "./audio/spk/f10_script4_produced.wav" + signal = AudioSignal(audio_path, offset=10, duration=2) + + output = transform(signal.clone(), **kwargs) + expected_output = signal.audio_data * full_mul + + assert paddle.allclose(output.audio_data, expected_output) + + +def test_choose_basic(): + seed = 0 + audio_path = "./audio/spk/f10_script4_produced.wav" + signal = AudioSignal(audio_path, offset=10, duration=2) + transform = tfm.Choose([ + tfm.RoomImpulseResponse(sources=["./audio/irs.csv"]), + tfm.BackgroundNoise(sources=["./audio/noises.csv"]), + ]) + + kwargs = transform.instantiate(seed, signal) + output = transform(signal.clone(), **kwargs) + + # Due to precision issues with RoomImpulseResponse and BackgroundNoise used in the Choose test, + # we only perform logical testing for Choose and skip precision testing of the final output + # _compare_transform("Choose", output) + + transform = tfm.Choose([ + MulTransform(0.0), + MulTransform(2.0), + ]) + targets = [signal.clone() * 0.0, signal.clone() * 2.0] + + for seed in range(10): + kwargs = transform.instantiate(seed, signal) + output = transform(signal.clone(), **kwargs) + + assert any([output == target for target in targets]) + + # Test that if you make a batch of signals and call it, + # the first item in the batch is still the same as above. + batch_size = 4 + signal = AudioSignal(audio_path, offset=10, duration=2) + signal_batch = AudioSignal.batch( + [signal.clone() for _ in range(batch_size)]) + + states = [seed + idx for idx in list(range(batch_size))] + kwargs = transform.batch_instantiate(states, signal_batch) + batch_output = transform(signal_batch, **kwargs) + + for nb in range(batch_size): + assert batch_output[nb] in targets + + +def test_choose_weighted(): + seed = 0 + audio_path = "./audio/spk/f10_script4_produced.wav" + transform = tfm.Choose( + [ + MulTransform(0.0), + MulTransform(2.0), + ], + weights=[0.0, 1.0], ) + + # Test that if you make a batch of signals and call it, + # the first item in the batch is still the same as above. + batch_size = 4 + signal = AudioSignal(audio_path, offset=10, duration=2) + signal_batch = AudioSignal.batch( + [signal.clone() for _ in range(batch_size)]) + + targets = [signal.clone() * 0.0, signal.clone() * 2.0] + + states = [seed + idx for idx in list(range(batch_size))] + kwargs = transform.batch_instantiate(states, signal_batch) + batch_output = transform(signal_batch, **kwargs) + + for nb in range(batch_size): + assert batch_output[nb] == targets[1] + + +def test_choose_with_compose(): + audio_path = "./audio/spk/f10_script4_produced.wav" + signal = AudioSignal(audio_path, offset=10, duration=2) + + transform = tfm.Choose([ + tfm.Compose([MulTransform(0.0)]), + tfm.Compose([MulTransform(2.0)]), + ]) + + targets = [signal.clone() * 0.0, signal.clone() * 2.0] + + for seed in range(10): + kwargs = transform.instantiate(seed, signal) + output = transform(signal, **kwargs) + + assert output in targets + + +def test_repeat(): + seed = 0 + audio_path = "./audio/spk/f10_script4_produced.wav" + signal = AudioSignal(audio_path, offset=10, duration=2) + + kwargs = {} + kwargs["transform"] = tfm.Compose( + tfm.FrequencyMask(), + tfm.TimeMask(), ) + kwargs["n_repeat"] = 5 + + transform = tfm.Repeat(**kwargs) + kwargs = transform.instantiate(seed, signal) + output = transform(signal.clone(), **kwargs) + + _compare_transform("Repeat", output) + + kwargs = {} + kwargs["transform"] = tfm.Compose( + tfm.FrequencyMask(), + tfm.TimeMask(), ) + kwargs["max_repeat"] = 10 + + transform = tfm.RepeatUpTo(**kwargs) + kwargs = transform.instantiate(seed, signal) + output = transform(signal.clone(), **kwargs) + + _compare_transform("RepeatUpTo", output) + + # Make sure repeat does what it says + transform = tfm.Repeat(MulTransform(0.5), n_repeat=3) + kwargs = transform.instantiate(seed, signal) + signal = AudioSignal(paddle.randn([1, 1, 100]).clip(1e-5), 44100) + output = transform(signal.clone(), **kwargs) + + scale = (output.audio_data / signal.audio_data).mean() + assert scale == (0.5**3) + + +class DummyData(paddle.io.Dataset): + def __init__(self, audio_path): + super().__init__() + + self.audio_path = audio_path + self.length = 100 + self.transform = tfm.Silence(prob=0.5) + + def __getitem__(self, idx): + state = util.random_state(idx) + signal = AudioSignal.salient_excerpt( + self.audio_path, state=state, duration=1.0).resample(44100) + + item = self.transform.instantiate(state, signal=signal) + item["signal"] = signal + + return item + + def __len__(self): + return self.length + + +def test_masking(): + dataset = DummyData("./audio/spk/f10_script4_produced.wav") + dataloader = paddle.io.DataLoader( + dataset, + batch_size=16, + num_workers=0, + collate_fn=util.collate, ) + for batch in dataloader: + signal = batch.pop("signal") + original = signal.clone() + + signal = dataset.transform(signal, **batch) + original = dataset.transform(original, **batch) + mask = batch["Silence"]["mask"] + + zeros_ = paddle.zeros_like(signal[mask].audio_data) + original_ = original[~mask].audio_data + + assert paddle.allclose(signal[mask].audio_data, zeros_) + assert paddle.allclose(original[~mask].audio_data, original_) + + +def test_nested_masking(): + transform = tfm.Compose( + [ + tfm.VolumeNorm(prob=0.5), + tfm.Silence(prob=0.9), + ], + prob=0.9, ) + + loader = audiotools.data.datasets.AudioLoader(sources=["./audio/spk.csv"]) + dataset = audiotools.data.datasets.AudioDataset( + loader, + 44100, + n_examples=100, + transform=transform, ) + dataloader = paddle.io.DataLoader( + dataset, num_workers=0, batch_size=10, collate_fn=dataset.collate) + + for batch in dataloader: + batch = util.prepare_batch(batch, device="cpu") + signal = batch["signal"] + kwargs = batch["transform_args"] + with paddle.no_grad(): + output = dataset.transform(signal, **kwargs) + + +def test_smoothing_edge_case(): + transform = tfm.Smoothing() + zeros = paddle.zeros([1, 1, 44100]) + signal = AudioSignal(zeros, 44100) + kwargs = transform.instantiate(0, signal) + output = transform(signal, **kwargs) + + assert paddle.allclose(output.audio_data, zeros) + + +def test_global_volume_norm(): + signal = AudioSignal.wave(440, 1, 44100, 1) + + # signal with -inf loudness should be unchanged + signal.metadata["loudness"] = float("-inf") + + transform = tfm.GlobalVolumeNorm(db=("const", -100)) + kwargs = transform.instantiate(0, signal) + + output = transform(signal.clone(), **kwargs) + assert paddle.allclose(output.samples, signal.samples) + + # signal without a loudness key should be unchanged + signal.metadata.pop("loudness") + kwargs = transform.instantiate(0, signal) + output = transform(signal.clone(), **kwargs) + assert paddle.allclose(output.samples, signal.samples) + + # signal with the actual loudness should be normalized + signal.metadata["loudness"] = signal.ffmpeg_loudness() + kwargs = transform.instantiate(0, signal) + output = transform(signal.clone(), **kwargs) + assert not paddle.allclose(output.samples, signal.samples) diff --git a/audio/tests/audiotools/ml/test_decorators.py b/audio/tests/audiotools/ml/test_decorators.py new file mode 100644 index 00000000000..555f3b345df --- /dev/null +++ b/audio/tests/audiotools/ml/test_decorators.py @@ -0,0 +1,110 @@ +# MIT License, Copyright (c) 2023-Present, Descript. +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Modified from audiotools(https://github.com/descriptinc/audiotools/blob/master/tests/ml/test_decorators.py) +import sys +import time + +import paddle +from visualdl import LogWriter + +from audio.audiotools import util +from audio.audiotools.ml.decorators import timer +from audio.audiotools.ml.decorators import Tracker +from audio.audiotools.ml.decorators import when + + +def test_all_decorators(): + rank = 0 + max_iters = 100 + + writer = LogWriter("/tmp/logs") + tracker = Tracker(writer, log_file="/tmp/log.txt") + + train_data = range(100) + val_data = range(100) + + @tracker.log("train", "value", history=False) + @tracker.track("train", max_iters, tracker.step) + @timer() + def train_loop(): + i = tracker.step + time.sleep(0.01) + return { + "loss": + util.exp_compat(paddle.to_tensor([-i / 100], dtype="float32")), + "mel": + util.exp_compat(paddle.to_tensor([-i / 100], dtype="float32")), + "stft": + util.exp_compat(paddle.to_tensor([-i / 100], dtype="float32")), + "waveform": + util.exp_compat(paddle.to_tensor([-i / 100], dtype="float32")), + "not_scalar": + paddle.arange(start=0, end=10, step=1, dtype="int64"), + } + + @tracker.track("val", len(val_data)) + @timer() + def val_loop(): + i = tracker.step + time.sleep(0.01) + return { + "loss": + util.exp_compat(paddle.to_tensor([-i / 100], dtype="float32")), + "mel": + util.exp_compat(paddle.to_tensor([-i / 100], dtype="float32")), + "stft": + util.exp_compat(paddle.to_tensor([-i / 100], dtype="float32")), + "waveform": + util.exp_compat(paddle.to_tensor([-i / 100], dtype="float32")), + "not_scalar": + paddle.arange(10, dtype="int64"), + "string": + "string", + } + + @when(lambda: tracker.step % 1000 == 0 and rank == 0) + @paddle.no_grad() + def save_samples(): + tracker.print("Saving samples to TensorBoard.") + + @when(lambda: tracker.step % 100 == 0 and rank == 0) + def checkpoint(): + save_samples() + if tracker.is_best("val", "mel"): + tracker.print("Best model so far.") + tracker.print("Saving to /runs/exp1") + tracker.done("val", f"Iteration {tracker.step}") + + @when(lambda: tracker.step % 100 == 0) + @tracker.log("val", "mean") + @paddle.no_grad() + def validate(): + for _ in range(len(val_data)): + output = val_loop() + return output + + with tracker.live: + for tracker.step in range(max_iters): + validate() + checkpoint() + train_loop() + + state_dict = tracker.state_dict() + tracker.load_state_dict(state_dict) + + # If train loop returned not a dict + @tracker.track("train", max_iters, tracker.step) + def train_loop_2(): + i = tracker.step + time.sleep(0.01) + + with tracker.live: + for tracker.step in range(max_iters): + validate() + checkpoint() + train_loop_2() + + +if __name__ == "__main__": + test_all_decorators() diff --git a/audio/tests/audiotools/ml/test_model.py b/audio/tests/audiotools/ml/test_model.py new file mode 100644 index 00000000000..5b1ac7f9dd4 --- /dev/null +++ b/audio/tests/audiotools/ml/test_model.py @@ -0,0 +1,89 @@ +# MIT License, Copyright (c) 2023-Present, Descript. +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Modified from audiotools(https://github.com/descriptinc/audiotools/blob/master/tests/ml/test_model.py) +import sys +import tempfile + +import paddle +from paddle import nn + +from audio.audiotools import ml +from audio.audiotools import util +from paddlespeech.vector.training.seeding import seed_everything +SEED = 0 + + +def seed_and_run(model, *args, **kwargs): + seed_everything(SEED) + return model(*args, **kwargs) + + +class Model(ml.BaseModel): + def __init__(self, arg1: float=1.0): + super().__init__() + self.arg1 = arg1 + self.linear = nn.Linear(1, 1) + + def forward(self, x): + return self.linear(x) + + +class OtherModel(ml.BaseModel): + def __init__(self): + super().__init__() + self.linear = nn.Linear(1, 1) + + def forward(self, x): + return self.linear(x) + + +def test_base_model(): + # Save and load + # ml.BaseModel.EXTERN += ["test_model"] + + x = paddle.randn([10, 1]) + model1 = Model() + + # assert str(model1.device) == 'Place(cpu)' + + out1 = seed_and_run(model1, x) + + with tempfile.NamedTemporaryFile(suffix=".pdparams") as f: + model1.save( + f.name, ) + model2 = Model.load(f.name) + out2 = seed_and_run(model2, x) + assert paddle.allclose(out1, out2) + + # test re-export + model2.save(f.name) + model3 = Model.load(f.name) + out3 = seed_and_run(model3, x) + assert paddle.allclose(out1, out3) + + # make sure legacy/save load works + model1.save(f.name, package=False) + model2 = Model.load(f.name) + out2 = seed_and_run(model2, x) + assert paddle.allclose(out1, out2) + + # make sure new way -> legacy save -> legacy load works + model1.save(f.name, package=False) + model2 = Model.load(f.name) + model2.save(f.name, package=False) + model3 = Model.load(f.name) + out3 = seed_and_run(model3, x) + + # save/load without package, but with model2 being a model + # without an argument of arg1 to its instantiation. + model1.save(f.name, package=False) + model2 = OtherModel.load(f.name) + out2 = seed_and_run(model2, x) + assert paddle.allclose(out1, out2) + + assert paddle.allclose(out1, out3) + + with tempfile.TemporaryDirectory() as d: + model1.save_to_folder(d, {"data": 1.0}) + Model.load_from_folder(d) diff --git a/audio/tests/audiotools/test_audiotools.sh b/audio/tests/audiotools/test_audiotools.sh new file mode 100644 index 00000000000..387059d5136 --- /dev/null +++ b/audio/tests/audiotools/test_audiotools.sh @@ -0,0 +1,7 @@ +python -m pip install -r ../../audiotools/requirements.txt +export PYTHONPATH=$PYTHONPATH:$(realpath ../../..) # this is root path of `PaddleSpeech` +wget https://paddlespeech.bj.bcebos.com/PaddleAudio/audio_tools/audio.tar.gz +wget https://paddlespeech.bj.bcebos.com/PaddleAudio/audio_tools/regression.tar.gz +tar -zxvf audio.tar.gz +tar -zxvf regression.tar.gz +python -m pytest \ No newline at end of file diff --git a/audio/tests/audiotools/test_post.py b/audio/tests/audiotools/test_post.py new file mode 100644 index 00000000000..def831ec26e --- /dev/null +++ b/audio/tests/audiotools/test_post.py @@ -0,0 +1,30 @@ +# MIT License, Copyright (c) 2023-Present, Descript. +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Modified from audiotools(https://github.com/descriptinc/audiotools/blob/master/tests/test_post.py) +import sys +from pathlib import Path + +from audio.audiotools import AudioSignal +from audio.audiotools import post +from audio.audiotools import transforms + + +def test_audio_table(): + tfm = transforms.LowPass() + + audio_dict = {} + + audio_dict["inputs"] = [ + AudioSignal.excerpt("./audio/spk/f10_script4_produced.wav", duration=5) + for _ in range(3) + ] + audio_dict["outputs"] = [] + for i in range(3): + x = audio_dict["inputs"][i] + + kwargs = tfm.instantiate() + output = tfm(x.clone(), **kwargs) + audio_dict["outputs"].append(output) + + post.audio_table(audio_dict) diff --git a/tests/unit/ci.sh b/tests/unit/ci.sh index daf40f721ee..ef21645b2a8 100644 --- a/tests/unit/ci.sh +++ b/tests/unit/ci.sh @@ -32,6 +32,13 @@ function main(){ cd ${speech_ci_path}/server/offline bash test_server_client.sh echo "End server" + + echo "Start testing audiotools" + cd ${speech_ci_path}/../../audio/tests/audiotools + bash test_audiotools.sh + echo "End testing audiotools" + + } main