diff --git a/.flake8 b/.flake8 deleted file mode 100644 index 2590a18..0000000 --- a/.flake8 +++ /dev/null @@ -1,13 +0,0 @@ -[flake8] -max-line-length = 119 -exclude = docs/source,*.egg,build -select = E,W,F -verbose = 2 -# https://pep8.readthedocs.io/en/latest/intro.html#error-codes -format = pylint -ignore = - E731 # E731 - Do not assign a lambda expression, use a def - W605 # W605 - invalid escape sequence '\_'. Needed for docs - W504 # W504 - line break after binary operator - W503 # W503 - line break before binary operator, need for black - E203 # E203 - whitespace before ':'. Opposite convention enforced by black diff --git a/.github/workflows/test_formatting.yml b/.github/workflows/test_formatting.yml index 2ee7ed6..12a596b 100644 --- a/.github/workflows/test_formatting.yml +++ b/.github/workflows/test_formatting.yml @@ -3,23 +3,34 @@ on: [push, pull_request] jobs: code-black: - name: CI + name: Linting & Formatting runs-on: ubuntu-latest steps: - name: Checkout uses: actions/checkout@v2 - - name: Set up Python 3.6 - uses: actions/setup-python@v2 + + - name: Install a specific version of uv + uses: astral-sh/setup-uv@v5 with: - python-version: 3.6 + python-version: 3.13 + enable-cache: false + + - name: Install ruff + # Keep version in sync with pre-commit and requirements.txt + run: pip install ruff==0.9.1 - - name: Install Black and flake8 - run: pip install black==20.8b1 flake8 - - name: Run Black - run: python -m black --config=pyproject.toml --check asteroid_filterbanks tests + - name: Run ruff format + run: | + ruff format --diff \ + asteroid_filterbanks \ + tests - - name: Link with flake8 + - name: Run ruff check # Exit on important linting errors and warn about others. run: | - python -m flake8 asteroid tests --show-source --statistics --select=F6,F7,F82,F52 - python -m flake8 --config .flake8 --exit-zero asteroid tests --statistics + ruff check --diff --select=F6,F7,F82,F52 \ + asteroid_filterbanks \ + tests + ruff check --statistics --exit-zero \ + asteroid_filterbanks \ + tests diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 724f54f..3874632 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -6,8 +6,10 @@ repos: - id: trailing-whitespace - id: check-yaml - - repo: https://github.com/psf/black - rev: 20.8b1 + - repo: https://github.com/astral-sh/ruff-pre-commit + rev: v0.9.1 hooks: - - id: black - types: [python] + - id: ruff + types_or: [ python, pyi ] + - id: ruff-format + types_or: [ python, pyi ] diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index aa772f3..b732396 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -13,7 +13,7 @@ installing asteroid-filterbanks everytime you change something. To do that, install asteroid-filterbanks in develop mode either with pip ```pip install -e .[tests]``` or with python ```python setup.py develop```. -To avoid formatting roundtrips in PRs, Asteroid relies on [`black`](https://github.com/psf/black) +To avoid formatting roundtrips in PRs, Asteroid relies on [`ruff`](https://docs.astral.sh/ruff/) and [`pre-commit-hooks`](https://github.com/pre-commit/pre-commit-hooks) to handle formatting for us. You'll need to install `requirements.txt` and install git hooks with `pre-commit install`. @@ -26,7 +26,7 @@ git clone your_fork_url cd asteroid-filterbanks pip install -r requirements.txt pip install -e . -pre-commit install # To run black before commit +pre-commit install # To run `ruff format` before commit # Make your changes # Test them locally @@ -62,10 +62,10 @@ docstrings in the codebase for examples. ### Coding style We use [pre-commit hooks](../.pre-commit-config.yaml) to format the code using -`black`. -The code is checked for `black`- and `flake8`- compliance on every commit with -GitHub actions. Remember, continuous integration is not here to be all green, -be to help us see where to improve ! +`ruff format`. +The code is checked for compliance with `ruff check` and `ruff format` on every +commit with GitHub actions. Remember, continuous integration is not here to be +all green, be to help us see where to improve! If you have any question, [open an issue][issue] or [join the slack][slack], diff --git a/asteroid_filterbanks/analytic_free_fb.py b/asteroid_filterbanks/analytic_free_fb.py index fb8b530..38dfb6d 100644 --- a/asteroid_filterbanks/analytic_free_fb.py +++ b/asteroid_filterbanks/analytic_free_fb.py @@ -57,10 +57,7 @@ def __init__(self, n_filters, kernel_size, stride=None, sample_rate=8000.0, **kw self.cutoff = int(n_filters // 2) self.n_feats_out = 2 * self.cutoff if n_filters % 2 != 0: - print( - "If the number of filters `n_filters` is odd, the " - "output size of the layer will be `n_filters - 1`." - ) + print("If the number of filters `n_filters` is odd, the output size of the layer will be `n_filters - 1`.") self._filters = nn.Parameter(torch.ones(n_filters // 2, 1, kernel_size), requires_grad=True) for p in self.parameters(): diff --git a/asteroid_filterbanks/deprecation.py b/asteroid_filterbanks/deprecation.py index f824c37..4cbe463 100644 --- a/asteroid_filterbanks/deprecation.py +++ b/asteroid_filterbanks/deprecation.py @@ -26,9 +26,7 @@ def decorator(func): def wrapped(*args, **kwargs): from_what = "a future release" if version is None else f"asteroid v{version}" warn_message = ( - f"{func.__module__}.{func.__name__} has been deprecated " - f"and will be removed from {from_what}. " - f"{message}" + f"{func.__module__}.{func.__name__} has been deprecated and will be removed from {from_what}. {message}" ) warnings.warn(warn_message, VisibleDeprecationWarning, stacklevel=2) return func(*args, **kwargs) diff --git a/asteroid_filterbanks/enc_dec.py b/asteroid_filterbanks/enc_dec.py index 59e26da..8d91cd3 100644 --- a/asteroid_filterbanks/enc_dec.py +++ b/asteroid_filterbanks/enc_dec.py @@ -34,7 +34,7 @@ def __init__(self, n_filters, kernel_size, stride=None, sample_rate=8000.0): self.sample_rate = sample_rate def filters(self): - """ Abstract method for filters. """ + """Abstract method for filters.""" raise NotImplementedError def pre_analysis(self, wav: torch.Tensor): @@ -97,7 +97,7 @@ def filters(self): return self.filterbank.filters() def compute_filter_pinv(self, filters): - """ Computes pseudo inverse filterbank of given filters.""" + """Computes pseudo inverse filterbank of given filters.""" scale = self.filterbank.stride / self.filterbank.kernel_size shape = filters.shape ifilt = torch.pinverse(filters.squeeze()).transpose(-1, -2).view(shape) @@ -105,14 +105,14 @@ def compute_filter_pinv(self, filters): return ifilt * scale def get_filters(self): - """ Returns filters or pinv filters depending on `is_pinv` attribute """ + """Returns filters or pinv filters depending on `is_pinv` attribute""" if self.is_pinv: return self.compute_filter_pinv(self.filters()) else: return self.filters() def get_config(self): - """ Returns dictionary of arguments to re-instantiate the class.""" + """Returns dictionary of arguments to re-instantiate the class.""" config = {"is_pinv": self.is_pinv} base_config = self.filterbank.get_config() return dict(list(base_config.items()) + list(config.items())) @@ -226,9 +226,7 @@ def multishape_conv1d( return batch_packed_1d_conv(waveform, filters, stride=stride, padding=padding) -def batch_packed_1d_conv( - inp: torch.Tensor, filters: torch.Tensor, stride: int = 1, padding: int = 0 -): +def batch_packed_1d_conv(inp: torch.Tensor, filters: torch.Tensor, stride: int = 1, padding: int = 0): # Here we perform multichannel / multi-source convolution. # Output should be (batch, channels, freq, conv_time) batched_conv = F.conv1d(inp.view(-1, 1, inp.shape[-1]), filters, stride=stride, padding=padding) @@ -261,7 +259,7 @@ def __init__(self, filterbank, is_pinv=False, padding=0, output_padding=0): @classmethod def pinv_of(cls, filterbank): - """ Returns an Decoder, pseudo inverse of a filterbank or Encoder.""" + """Returns an Decoder, pseudo inverse of a filterbank or Encoder.""" if isinstance(filterbank, Filterbank): return cls(filterbank, is_pinv=True) elif isinstance(filterbank, Encoder): diff --git a/asteroid_filterbanks/griffin_lim.py b/asteroid_filterbanks/griffin_lim.py index ad7ba12..fd0feef 100644 --- a/asteroid_filterbanks/griffin_lim.py +++ b/asteroid_filterbanks/griffin_lim.py @@ -2,7 +2,7 @@ import math from typing import Optional, List -from . import Encoder, Decoder, STFTFB +from . import Decoder, STFTFB from .stft_fb import perfect_synthesis_window from . import transforms @@ -154,9 +154,7 @@ def misi( complex_specgram = transforms.from_magphase(mag_specgrams, angles) wavs = istft_dec(complex_specgram) # Make wavs sum up to the mixture - consistent_wavs = _mixture_consistency( - mixture_wav, wavs, src_weights=src_weights, dim=wav_dim - ) + consistent_wavs = _mixture_consistency(mixture_wav, wavs, src_weights=src_weights, dim=wav_dim) # Back to TF domain rebuilt = stft_enc(consistent_wavs) # Update phase estimates (with momentum). Keep the momentum here @@ -216,7 +214,7 @@ def _mixture_consistency( all_dims: List[int] = torch.arange(est_sources.ndim).tolist() all_dims.pop(dim) # Remove source axis all_dims.pop(0) # Remove batch axis - src_weights = torch.mean(est_sources ** 2, dim=all_dims, keepdim=True) + src_weights = torch.mean(est_sources**2, dim=all_dims, keepdim=True) # Make sure that the weights sum up to 1 norm_weights = torch.sum(src_weights, dim=dim, keepdim=True) + 1e-8 src_weights = src_weights / norm_weights @@ -233,7 +231,7 @@ def _mixture_consistency( raise RuntimeError( f"The size of the mixture tensor should match the " f"size of the est_sources tensor. Expected mixture" - f"tensor to have {n} or {n-1} dimension, found {m}." + f"tensor to have {n} or {n - 1} dimension, found {m}." ) # Compute remove new_sources = est_sources + src_weights * residual diff --git a/asteroid_filterbanks/melgram_fb.py b/asteroid_filterbanks/melgram_fb.py index 1178e7d..120fdb9 100644 --- a/asteroid_filterbanks/melgram_fb.py +++ b/asteroid_filterbanks/melgram_fb.py @@ -37,7 +37,6 @@ def __init__( norm="slaney", **kwargs, ): - self.n_mels = n_mels self.fmin = fmin self.fmax = fmax @@ -50,9 +49,7 @@ def __init__( sample_rate=sample_rate, **kwargs, ) - self.mel_scale = MelScale( - n_filters, sample_rate=sample_rate, n_mels=n_mels, fmin=fmin, fmax=fmax, norm=norm - ) + self.mel_scale = MelScale(n_filters, sample_rate=sample_rate, n_mels=n_mels, fmin=fmin, fmax=fmax, norm=norm) self.n_feats_out = n_mels def post_analysis(self, spec: torch.Tensor): @@ -90,9 +87,7 @@ def __init__( from librosa.filters import mel super().__init__() - fb_mat = mel( - sr=sample_rate, n_fft=n_filters, fmin=fmin, fmax=fmax, n_mels=n_mels, norm=norm - ) + fb_mat = mel(sr=sample_rate, n_fft=n_filters, fmin=fmin, fmax=fmax, n_mels=n_mels, norm=norm) self.register_buffer("fb_mat", torch.from_numpy(fb_mat).unsqueeze(0)) def forward(self, spec: torch.Tensor): diff --git a/asteroid_filterbanks/multiphase_gammatone_fb.py b/asteroid_filterbanks/multiphase_gammatone_fb.py index 81812cf..d9ce72b 100644 --- a/asteroid_filterbanks/multiphase_gammatone_fb.py +++ b/asteroid_filterbanks/multiphase_gammatone_fb.py @@ -48,9 +48,7 @@ def generate_mpgtf(samplerate_hz, len_sec, n_filters): current_center_freq_hz = center_freq_hz_min # Determine number of phase shifts per center frequency - phase_pair_count = (np.ones(n_center_freqs) * np.floor(n_filters / 2 / n_center_freqs)).astype( - int - ) + phase_pair_count = (np.ones(n_center_freqs) * np.floor(n_filters / 2 / n_center_freqs)).astype(int) remaining_phase_pairs = ((n_filters - np.sum(phase_pair_count) * 2) / 2).astype(int) if remaining_phase_pairs > 0: phase_pair_count[:remaining_phase_pairs] = phase_pair_count[:remaining_phase_pairs] + 1 @@ -70,48 +68,39 @@ def generate_mpgtf(samplerate_hz, len_sec, n_filters): index = index + 1 # Second half of filters: phase_shifts in [pi, 2*pi) - filterbank[index : index + phase_pair_count[i], :] = -filterbank[ - index - phase_pair_count[i] : index, : - ] + filterbank[index : index + phase_pair_count[i], :] = -filterbank[index - phase_pair_count[i] : index, :] # Prepare for next center frequency index = index + phase_pair_count[i] - current_center_freq_hz = erb_scale_2_freq_hz( - freq_hz_2_erb_scale(current_center_freq_hz) + 1 - ) + current_center_freq_hz = erb_scale_2_freq_hz(freq_hz_2_erb_scale(current_center_freq_hz) + 1) filterbank = normalize_filters(filterbank) return filterbank def gammatone_impulse_response(samplerate_hz, len_sec, center_freq_hz, phase_shift): - """ Generate single parametrized gammatone filter """ + """Generate single parametrized gammatone filter""" p = 2 # filter order erb = 24.7 + 0.108 * center_freq_hz # equivalent rectangular bandwidth - divisor = (np.pi * math.factorial(2 * p - 2) * np.power(2, float(-(2 * p - 2)))) / np.square( - math.factorial(p - 1) - ) + divisor = (np.pi * math.factorial(2 * p - 2) * np.power(2, float(-(2 * p - 2)))) / np.square(math.factorial(p - 1)) b = erb / divisor # bandwidth parameter a = 1.0 # amplitude. This is varied later by the normalization process. len_sample = int(np.floor(samplerate_hz * len_sec)) t = np.linspace(1.0 / samplerate_hz, len_sec, len_sample) gammatone_ir = ( - a - * np.power(t, p - 1) - * np.exp(-2 * np.pi * b * t) - * np.cos(2 * np.pi * center_freq_hz * t + phase_shift) + a * np.power(t, p - 1) * np.exp(-2 * np.pi * b * t) * np.cos(2 * np.pi * center_freq_hz * t + phase_shift) ) return gammatone_ir def erb_scale_2_freq_hz(freq_erb): - """ Convert frequency on ERB scale to frequency in Hertz """ + """Convert frequency on ERB scale to frequency in Hertz""" freq_hz = (np.exp(freq_erb / 9.265) - 1) * 24.7 * 9.265 return freq_hz def freq_hz_2_erb_scale(freq_hz): - """ Convert frequency in Hertz to frequency on ERB scale """ + """Convert frequency in Hertz to frequency on ERB scale""" freq_erb = 9.265 * np.log(1 + freq_hz / (24.7 * 9.265)) return freq_erb diff --git a/asteroid_filterbanks/param_sinc_fb.py b/asteroid_filterbanks/param_sinc_fb.py index 4374527..c9e26ed 100644 --- a/asteroid_filterbanks/param_sinc_fb.py +++ b/asteroid_filterbanks/param_sinc_fb.py @@ -55,36 +55,27 @@ def __init__( self.n_feats_out = 2 * self.cutoff self._initialize_filters() if n_filters % 2 != 0: - print( - "If the number of filters `n_filters` is odd, the " - "output size of the layer will be `n_filters - 1`." - ) + print("If the number of filters `n_filters` is odd, the output size of the layer will be `n_filters - 1`.") window_ = np.hamming(self.kernel_size)[: self.half_kernel] # Half window - n_ = ( - 2 * np.pi * (torch.arange(-self.half_kernel, 0.0).view(1, -1) / self.sample_rate) - ) # Half time vector + n_ = 2 * np.pi * (torch.arange(-self.half_kernel, 0.0).view(1, -1) / self.sample_rate) # Half time vector self.register_buffer("window_", torch.from_numpy(window_).float()) self.register_buffer("n_", n_) def _initialize_filters(self): - """ Filter Initialization along the Mel scale""" + """Filter Initialization along the Mel scale""" low_hz = 30 high_hz = self.sample_rate / 2 - (self.min_low_hz + self.min_band_hz) - mel = np.linspace( - self.to_mel(low_hz), self.to_mel(high_hz), self.n_filters // 2 + 1, dtype="float32" - ) + mel = np.linspace(self.to_mel(low_hz), self.to_mel(high_hz), self.n_filters // 2 + 1, dtype="float32") hz = self.to_hz(mel) # filters parameters (out_channels // 2, 1) self.low_hz_ = nn.Parameter(torch.from_numpy(hz[:-1]).view(-1, 1)) self.band_hz_ = nn.Parameter(torch.from_numpy(np.diff(hz)).view(-1, 1)) def filters(self): - """ Compute filters from parameters """ + """Compute filters from parameters""" low = self.min_low_hz + torch.abs(self.low_hz_) - high = torch.clamp( - low + self.min_band_hz + torch.abs(self.band_hz_), self.min_low_hz, self.sample_rate / 2 - ) + high = torch.clamp(low + self.min_band_hz + torch.abs(self.band_hz_), self.min_low_hz, self.sample_rate / 2) cos_filters = self.make_filters(low, high, filt_type="cos") sin_filters = self.make_filters(low, high, filt_type="sin") return torch.cat([cos_filters, sin_filters], dim=0) @@ -116,7 +107,7 @@ def to_hz(mel): return 700 * (10 ** (mel / 2595) - 1) def get_config(self): - """ Returns dictionary of arguments to re-instantiate the class.""" + """Returns dictionary of arguments to re-instantiate the class.""" config = { "min_low_hz": self.min_low_hz, "min_band_hz": self.min_band_hz, diff --git a/asteroid_filterbanks/pcen.py b/asteroid_filterbanks/pcen.py index e87b4d3..3388d31 100644 --- a/asteroid_filterbanks/pcen.py +++ b/asteroid_filterbanks/pcen.py @@ -1,6 +1,5 @@ from torch import nn import torch -from . import transforms from typing import Union, Optional, Tuple try: @@ -76,9 +75,7 @@ def __init__( super().__init__() if trainable is True or trainable is False: - trainable = TrainableParameters( - alpha=trainable, delta=trainable, root=trainable, smooth=trainable - ) + trainable = TrainableParameters(alpha=trainable, delta=trainable, root=trainable, smooth=trainable) self.trainable = trainable self.n_channels = n_channels @@ -89,9 +86,7 @@ def __init__( self.delta = nn.Parameter( torch.full((self.n_channels,), fill_value=delta), requires_grad=self.trainable["delta"] ) - self.root = nn.Parameter( - torch.full((self.n_channels,), fill_value=root), requires_grad=self.trainable["root"] - ) + self.root = nn.Parameter(torch.full((self.n_channels,), fill_value=root), requires_grad=self.trainable["root"]) self.floor = floor self.ema = ExponentialMovingAverage( @@ -119,7 +114,7 @@ def forward( # Equation (1) in [1] out = ( mag_spec / (self.floor + ema_smoother) ** alpha + self.delta - ) ** one_over_root - self.delta ** one_over_root + ) ** one_over_root - self.delta**one_over_root out = out.transpose(1, -1) if post_squeeze: out = out.squeeze(1) @@ -144,9 +139,9 @@ class PCEN(_PCEN): Defaults to 0.04 n_channels: Number of channels in the time frequency representation. Defaults to 1 - trainable: If True, the parameters (alpha, delta, root and smooth) are trainable. If False, the parameters are fixed. - Individual parameters can set to be fixed or trainable by passing a dictionary of booleans, with the key - matching the parameter name and the value being either True (trainable) or False (fixed). + trainable: If True, the parameters (alpha, delta, root and smooth) are trainable. If False, the parameters are + fixed. Individual parameters can set to be fixed or trainable by passing a dictionary of booleans, with the + key matching the parameter name and the value being either True (trainable) or False (fixed). i.e. ``{"alpha": False, "delta": True, "root": False, "smooth": True}`` Defaults to False per_channel_smoothing: If True, each channel has it's own smoothing coefficient. @@ -211,9 +206,9 @@ class StatefulPCEN(_PCEN): Defaults to 0.04 n_channels: Number of channels in the time frequency representation. Defaults to 1 - trainable: If True, the parameters (alpha, delta, root and smooth) are trainable. If False, the parameters are fixed. - Individual parameters can set to be fixed or trainable by passing a dictionary of booleans, with the key - matching the parameter name and the value being either True (trainable) or False (fixed). + trainable: If True, the parameters (alpha, delta, root and smooth) are trainable. If False, the parameters are + fixed. Individual parameters can set to be fixed or trainable by passing a dictionary of booleans, with the + key matching the parameter name and the value being either True (trainable) or False (fixed). i.e. ``{"alpha": False, "delta": True, "root": False, "smooth": True}`` Defaults to False per_channel_smoothing: If True, each channel has it's own smoothing coefficient. @@ -254,7 +249,8 @@ def from_pcen(cls, pcen: PCEN): def forward( self, mag_spec: torch.Tensor, initial_state: Optional[torch.Tensor] = None ) -> Tuple[torch.Tensor, torch.Tensor]: - """Computes the PCEN from magnitude spectrum representation, and an optional smoothed version of the filterbank (Equation (2) in [1]). + """Computes the PCEN from magnitude spectrum representation, and an optional smoothed version of the filterbank + (Equation (2) in [1]). Args: mag_spec: tensor containing an magnitude spectrum representation. diff --git a/asteroid_filterbanks/stft_fb.py b/asteroid_filterbanks/stft_fb.py index 0f61128..98d4baf 100644 --- a/asteroid_filterbanks/stft_fb.py +++ b/asteroid_filterbanks/stft_fb.py @@ -21,9 +21,7 @@ class STFTFB(Filterbank): n_feats_out (int): Number of output filters. """ - def __init__( - self, n_filters, kernel_size, stride=None, window=None, sample_rate=8000.0, **kwargs - ): + def __init__(self, n_filters, kernel_size, stride=None, window=None, sample_rate=8000.0, **kwargs): super().__init__(n_filters, kernel_size, stride=stride, sample_rate=sample_rate) assert n_filters >= kernel_size if n_filters % 2 != 0: @@ -38,9 +36,7 @@ def __init__( window = window.data.numpy() ws = window.size if not (ws == kernel_size): - raise AssertionError( - f"Expected window of size {kernel_size}. Received {ws} instead." - ) + raise AssertionError(f"Expected window of size {kernel_size}. Received {ws} instead.") self.window = window # Create and normalize DFT filters (can be overcomplete) filters = np.fft.fft(np.eye(n_filters)) @@ -50,9 +46,7 @@ def __init__( lpad = int((n_filters - kernel_size) // 2) rpad = int(n_filters - kernel_size - lpad) indexes = list(range(lpad, n_filters - rpad)) - filters = np.vstack( - [np.real(filters[: self.cutoff, indexes]), np.imag(filters[: self.cutoff, indexes])] - ) + filters = np.vstack([np.real(filters[: self.cutoff, indexes]), np.imag(filters[: self.cutoff, indexes])]) filters[0, :] /= np.sqrt(2) filters[n_filters // 2, :] /= np.sqrt(2) @@ -92,7 +86,7 @@ def perfect_synthesis_window(analysis_window, hop_size): loop_on = (win_size - 1) // hop_size for win_idx in range(-loop_on, loop_on + 1): - shifted = np.roll(analysis_window ** 2, win_idx * hop_size) + shifted = np.roll(analysis_window**2, win_idx * hop_size) if win_idx < 0: shifted[win_idx * hop_size :] = 0 elif win_idx > 0: diff --git a/asteroid_filterbanks/torch_stft_fb.py b/asteroid_filterbanks/torch_stft_fb.py index bbd8b4b..24d9f8d 100644 --- a/asteroid_filterbanks/torch_stft_fb.py +++ b/asteroid_filterbanks/torch_stft_fb.py @@ -53,12 +53,8 @@ def __init__( **kwargs, ): if n_filters != kernel_size: - raise NotImplementedError( - "Cannot set `n_filters!=kernel_size` in TorchSTFTFB, untested." - ) - super().__init__( - n_filters, kernel_size, stride=stride, window=window, sample_rate=sample_rate, **kwargs - ) + raise NotImplementedError("Cannot set `n_filters!=kernel_size` in TorchSTFTFB, untested.") + super().__init__(n_filters, kernel_size, stride=stride, window=window, sample_rate=sample_rate, **kwargs) self.center = center self.pad_mode = pad_mode if normalized: @@ -129,15 +125,15 @@ def post_synthesis(self, wav): @script_if_tracing def _restore_freqs_an(spec, n_filters: int): - spec[..., 0, :] *= 2 ** 0.5 - spec[..., n_filters // 2, :] *= 2 ** 0.5 + spec[..., 0, :] *= 2**0.5 + spec[..., n_filters // 2, :] *= 2**0.5 return spec @script_if_tracing def _restore_freqs_syn(spec, n_filters: int): - spec[..., 0, :] /= 2 ** 0.5 - spec[..., n_filters // 2, :] /= 2 ** 0.5 + spec[..., 0, :] /= 2**0.5 + spec[..., n_filters // 2, :] /= 2**0.5 return spec @@ -161,8 +157,7 @@ def ola_with_wdiv(wav, window, kernel_size: int, stride: int, center: bool = Tru if min_mask.any() and not torch.jit.is_scripting(): # Warning instead of error. Might be trimmed afterward. warnings.warn( - f"Minimum NOLA should be above 1e-11, Found {wsq_ola.abs().min()}. " - f"Dividind only where possible.", + f"Minimum NOLA should be above 1e-11, Found {wsq_ola.abs().min()}. Dividind only where possible.", RuntimeWarning, ) wav[~min_mask] = wav[~min_mask] / wsq_ola[~min_mask] @@ -177,9 +172,7 @@ def square_ola(window: torch.Tensor, kernel_size: int, stride: int, n_frame: int @script_if_tracing -def pad_all_shapes( - x: torch.Tensor, pad_shape: Tuple[int, int], mode: str = "reflect" -) -> torch.Tensor: +def pad_all_shapes(x: torch.Tensor, pad_shape: Tuple[int, int], mode: str = "reflect") -> torch.Tensor: if x.ndim == 1: return F.pad(x[None, None], pad=pad_shape, mode=mode).squeeze(0).squeeze(0) if x.ndim == 2: diff --git a/asteroid_filterbanks/transforms.py b/asteroid_filterbanks/transforms.py index 119cf04..1e42a6d 100644 --- a/asteroid_filterbanks/transforms.py +++ b/asteroid_filterbanks/transforms.py @@ -3,7 +3,6 @@ from typing import Tuple from .scripting import script_if_tracing -from .deprecation import mark_deprecated def mul_c(inp, other, dim: int = -2): @@ -287,8 +286,7 @@ def check_torchaudio_complex(tensor): """ if not is_torchaudio_complex(tensor): raise AssertionError( - f"Tensor of shape {tensor.shape} is not Torchaudio-style complex-like" - "(expected last dimension to be == 2)" + f"Tensor of shape {tensor.shape} is not Torchaudio-style complex-like(expected last dimension to be == 2)" ) diff --git a/pyproject.toml b/pyproject.toml index 1a1b551..7e8d64b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,8 +4,30 @@ requires = [ "wheel", ] -[tool.black] -# https://github.com/psf/black -line-length = 100 -target-version = ["py36"] -exclude = "(.eggs|.git|.hg|.mypy_cache|.nox|.tox|.venv|.svn|_build|buck-out|build|dist)" +# https://docs.astral.sh/ruff/configuration/ +[tool.ruff] +line-length = 120 +target-version = "py39" # https://endoflife.date/python +exclude = [ + ".eggs", + ".git", + ".hg", + ".mypy_cache", + ".nox", + ".tox", + ".venv", + ".svn", + "_build", + "*.ipynb", + "buck-out", + "build", + "dist", +] + +[tool.ruff.lint] +select = ["E", "W", "F"] +ignore = [ + "E203", # whitespace before ':'. Opposite convention enforced by black + "E731", # Do not assign a lambda expression, use a def + "W605", # invalid escape sequence '\_'. Needed for docs +] diff --git a/requirements.txt b/requirements.txt index 718d46d..795a27e 100644 --- a/requirements.txt +++ b/requirements.txt @@ -2,7 +2,7 @@ numpy>=1.16.4 scipy>=1.1.0 torch>=1.8.0 pre-commit -black==20.8b1 +ruff==0.9.1 librosa>=0.8.0 pytest coverage diff --git a/tests/consistency_test.py b/tests/consistency_test.py index f123e20..4d410af 100644 --- a/tests/consistency_test.py +++ b/tests/consistency_test.py @@ -28,9 +28,7 @@ def test_consistency_withweight(mix_shape, dim, n_src): src_weights_shape = mix_shape[:1] + ones[: dim - 1] + [n_src] + ones[dim - 1 :] src_weights = torch.softmax(torch.randn(src_weights_shape), dim=dim) # Apply mixture consitency - consistent_est_sources = _mixture_consistency( - mix, est_sources, src_weights=src_weights, dim=dim - ) + consistent_est_sources = _mixture_consistency(mix, est_sources, src_weights=src_weights, dim=dim) assert_allclose(mix, consistent_est_sources.sum(dim)) diff --git a/tests/filterbanks_test.py b/tests/filterbanks_test.py index 3b27d6f..534f2ca 100644 --- a/tests/filterbanks_test.py +++ b/tests/filterbanks_test.py @@ -23,7 +23,7 @@ def fb_config_list(): @pytest.mark.parametrize("fb_class", [FreeFB, AnalyticFreeFB, ParamSincFB, MultiphaseGammatoneFB]) @pytest.mark.parametrize("fb_config", fb_config_list()) def test_fb_def_and_forward_lowdim(fb_class, fb_config): - """ Test filterbank definition and encoder/decoder forward.""" + """Test filterbank definition and encoder/decoder forward.""" # Definition enc = Encoder(fb_class(**fb_config)) dec = Decoder(fb_class(**fb_config)) @@ -50,7 +50,7 @@ def test_fb_def_and_forward_lowdim(fb_class, fb_config): @pytest.mark.parametrize("fb_class", [FreeFB, AnalyticFreeFB, ParamSincFB, MultiphaseGammatoneFB]) @pytest.mark.parametrize("fb_config", fb_config_list()) def test_fb_def_and_forward_all_dims(fb_class, fb_config): - """ Test encoder/decoder on other shapes than 3D""" + """Test encoder/decoder on other shapes than 3D""" # Definition enc = Encoder(fb_class(**fb_config)) dec = Decoder(fb_class(**fb_config)) @@ -67,7 +67,7 @@ def test_fb_def_and_forward_all_dims(fb_class, fb_config): @pytest.mark.parametrize("fb_config", fb_config_list()) @pytest.mark.parametrize("ndim", [2, 3, 4]) def test_fb_forward_multichannel(fb_class, fb_config, ndim): - """ Test encoder/decoder in multichannel setting""" + """Test encoder/decoder in multichannel setting""" # Definition enc = Encoder(fb_class(**fb_config)) dec = Decoder(fb_class(**fb_config)) @@ -90,7 +90,7 @@ def test_complexfb_shapes(fb_class, n_filters, kernel_size): @pytest.mark.parametrize("kernel_size", [256, 257, 128, 129]) def test_paramsinc_shape(kernel_size): - """ ParamSincFB has odd length filters """ + """ParamSincFB has odd length filters""" fb = ParamSincFB(n_filters=200, kernel_size=kernel_size) assert fb.filters().shape[-1] == 2 * (kernel_size // 2) + 1 diff --git a/tests/pcen_test.py b/tests/pcen_test.py index dfb0c05..5a9532b 100644 --- a/tests/pcen_test.py +++ b/tests/pcen_test.py @@ -1,9 +1,8 @@ import torch -from asteroid_filterbanks import Encoder, Decoder, STFTFB, transforms +from asteroid_filterbanks import Encoder, STFTFB, transforms from asteroid_filterbanks.pcen import PCEN, StatefulPCEN, ExponentialMovingAverage from torch.testing import assert_allclose import pytest -import re @pytest.mark.parametrize("n_channels", [2, 4]) @@ -163,9 +162,7 @@ def test_stateful_pcen_jit(n_channels, batch_size, n_filters, timesteps, trainab @pytest.mark.parametrize("timesteps", [3, 10]) @pytest.mark.parametrize("per_channel_smoothing", [True, False]) @pytest.mark.parametrize("trainable", [True, False]) -def test_stateful_pcen_from_pcen( - n_channels, batch_size, n_filters, timesteps, per_channel_smoothing, trainable -): +def test_stateful_pcen_from_pcen(n_channels, batch_size, n_filters, timesteps, per_channel_smoothing, trainable): mag_spec = torch.randn(batch_size, n_channels, n_filters, timesteps) pcen = PCEN( diff --git a/tests/stft_test.py b/tests/stft_test.py index db1cf31..985aae8 100644 --- a/tests/stft_test.py +++ b/tests/stft_test.py @@ -24,7 +24,7 @@ def fb_config_list(): @pytest.mark.parametrize("fb_config", fb_config_list()) def test_stft_def(fb_config): - """ Check consistency between two calls.""" + """Check consistency between two calls.""" fb = STFTFB(**fb_config) enc = Encoder(fb) dec = Decoder(fb) @@ -60,7 +60,7 @@ def test_filter_shape(fb_config): @pytest.mark.parametrize("fb_config", fb_config_list()) def test_perfect_istft_default_parameters(fb_config): - """ Unit test perfect reconstruction with default values. """ + """Unit test perfect reconstruction with default values.""" kernel_size = fb_config["kernel_size"] enc, dec = make_enc_dec("stft", **fb_config) inp_wav = torch.randn(2, 1, 32000) @@ -70,12 +70,10 @@ def test_perfect_istft_default_parameters(fb_config): @pytest.mark.parametrize("fb_config", fb_config_list()) -@pytest.mark.parametrize( - "analysis_window_name", ["blackman", "hamming", "hann", "bartlett", "boxcar"] -) +@pytest.mark.parametrize("analysis_window_name", ["blackman", "hamming", "hann", "bartlett", "boxcar"]) @pytest.mark.parametrize("use_torch_window", [True, False]) def test_perfect_resyn_window(fb_config, analysis_window_name, use_torch_window): - """ Unit test perfect reconstruction """ + """Unit test perfect reconstruction""" kernel_size = fb_config["kernel_size"] window = get_window(analysis_window_name, kernel_size) if use_torch_window: @@ -155,4 +153,4 @@ def test_melgram_encoder(n_filters, n_mels, ndim): mel_spec = enc(wav) assert wav.shape[:-1] == mel_spec.shape[:-2] assert mel_spec.shape[-2] == n_mels - conf = melgram_fb.get_config() + melgram_fb.get_config() diff --git a/tests/torch_stft_test.py b/tests/torch_stft_test.py index 0c7bb30..29c92e9 100644 --- a/tests/torch_stft_test.py +++ b/tests/torch_stft_test.py @@ -100,7 +100,7 @@ def test_torch_stft( window = None if window is None else get_window(window, win_length, fftbins=True) if window is not None: # Cannot restore the signal without overlap and near to zero window. - if hop_ratio == 1 and (window ** 2 < 1e-11).any(): + if hop_ratio == 1 and (window**2 < 1e-11).any(): pass fb = TorchSTFTFB.from_torch_args( @@ -160,9 +160,7 @@ def test_torch_stft( # Asteroid always returns a longer signal. assert wav_back_asteroid.shape[-1] >= wav_back.shape[-1] # The unit test is done on the left part of the signal. - assert_allclose( - wav_back_asteroid[: wav_back.shape[-1]], wav_back.float(), rtol=RTOL, atol=ATOL - ) + assert_allclose(wav_back_asteroid[: wav_back.shape[-1]], wav_back.float(), rtol=RTOL, atol=ATOL) def test_raises_if_onesided_is_false(): diff --git a/tests/transforms_test.py b/tests/transforms_test.py index 1e94ea2..e8b114f 100644 --- a/tests/transforms_test.py +++ b/tests/transforms_test.py @@ -39,8 +39,8 @@ def make_encoder_from(fb_class, config): def test_mag_mask(encoder_list): - """ Assert identity mask works. """ - for (enc, fb_dim) in encoder_list: + """Assert identity mask works.""" + for enc, fb_dim in encoder_list: tf_rep = enc(torch.randn(2, 1, 8000)) # [batch, freq, time] id_mag_mask = torch.ones((1, fb_dim // 2, 1)) masked = transforms.apply_mag_mask(tf_rep, id_mag_mask, dim=1) @@ -48,8 +48,8 @@ def test_mag_mask(encoder_list): def test_reim_mask(encoder_list): - """ Assert identity mask works. """ - for (enc, fb_dim) in encoder_list: + """Assert identity mask works.""" + for enc, fb_dim in encoder_list: tf_rep = enc(torch.randn(2, 1, 8000)) # [batch, freq, time] id_reim_mask = torch.ones((1, fb_dim, 1)) masked = transforms.apply_real_mask(tf_rep, id_reim_mask, dim=1) @@ -57,18 +57,16 @@ def test_reim_mask(encoder_list): def test_comp_mask(encoder_list): - """ Assert identity mask works. """ - for (enc, fb_dim) in encoder_list: + """Assert identity mask works.""" + for enc, fb_dim in encoder_list: tf_rep = enc(torch.randn(2, 1, 8000)) # [batch, freq, time] - id_complex_mask = torch.cat( - (torch.ones((1, fb_dim // 2, 1)), torch.zeros((1, fb_dim // 2, 1))), dim=1 - ) + id_complex_mask = torch.cat((torch.ones((1, fb_dim // 2, 1)), torch.zeros((1, fb_dim // 2, 1))), dim=1) masked = transforms.apply_complex_mask(tf_rep, id_complex_mask, dim=1) assert_allclose(masked, tf_rep) def test_mag(encoder_list): - for (enc, fb_dim) in encoder_list: + for enc, fb_dim in encoder_list: tf_rep = enc(torch.randn(2, 1, 16000)) # [batch, freq, time] batch, freq, time = tf_rep.shape mag = transforms.mag(tf_rep, dim=1) @@ -76,7 +74,7 @@ def test_mag(encoder_list): def test_cat(encoder_list): - for (enc, fb_dim) in encoder_list: + for enc, fb_dim in encoder_list: tf_rep = enc(torch.randn(2, 1, 16000)) # [batch, freq, time] batch, freq, time = tf_rep.shape mag = transforms.magreim(tf_rep, dim=1) @@ -89,7 +87,7 @@ def test_cat(encoder_list): ) @pytest.mark.parametrize("dim", [0, 1, 2]) def test_to_numpy(np_torch_tuple, dim): - """ Test torch --> np conversion (right angles)""" + """Test torch --> np conversion (right angles)""" from_np, from_torch = np_torch_tuple if dim == 0: np_array = np.array(from_np) @@ -112,7 +110,7 @@ def test_to_numpy(np_torch_tuple, dim): ) @pytest.mark.parametrize("dim", [0, 1, 2]) def test_from_numpy(np_torch_tuple, dim): - """ Test np --> torch conversion (right angles)""" + """Test np --> torch conversion (right angles)""" from_np, from_torch = np_torch_tuple if dim == 0: np_array = np.array(from_np) @@ -131,7 +129,7 @@ def test_from_numpy(np_torch_tuple, dim): @pytest.mark.parametrize("dim", [0, 1, 2, 3]) def test_return_ticket_np_torch(dim): - """ Test torch --> np --> torch --> np conversion""" + """Test torch --> np --> torch --> np conversion""" max_tested_ndim = 4 # Random tensor shape tensor_shape = [random.randint(1, 10) for _ in range(max_tested_ndim)] @@ -149,7 +147,7 @@ def test_return_ticket_np_torch(dim): @pytest.mark.parametrize("dim", [0, 1, 2, 3]) def test_angle_mag_recompostion(dim): - """ Test complex --> (mag, angle) --> complex conversions""" + """Test complex --> (mag, angle) --> complex conversions""" max_tested_ndim = 4 # Random tensor shape tensor_shape = [random.randint(1, 10) for _ in range(max_tested_ndim)] @@ -164,7 +162,7 @@ def test_angle_mag_recompostion(dim): @pytest.mark.parametrize("dim", [0, 1, 2, 3]) def test_check_complex_error(dim): - """ Test error in angle """ + """Test error in angle""" not_complex = torch.randn(3, 5, 7, 9, 15) with pytest.raises(AssertionError): transforms.check_complex(not_complex, dim=dim)