Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Migrate linting workflow to ruff check/format #26

Open
wants to merge 3 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 0 additions & 13 deletions .flake8

This file was deleted.

33 changes: 22 additions & 11 deletions .github/workflows/test_formatting.yml
Original file line number Diff line number Diff line change
Expand Up @@ -3,23 +3,34 @@ on: [push, pull_request]

jobs:
code-black:
name: CI
name: Linting & Formatting
runs-on: ubuntu-latest
steps:
- name: Checkout
uses: actions/checkout@v2
- name: Set up Python 3.6
uses: actions/setup-python@v2

- name: Install a specific version of uv
uses: astral-sh/setup-uv@v5
with:
python-version: 3.6
python-version: 3.13
enable-cache: false

- name: Install ruff
# Keep version in sync with pre-commit and requirements.txt
run: pip install ruff==0.9.1

- name: Install Black and flake8
run: pip install black==20.8b1 flake8
- name: Run Black
run: python -m black --config=pyproject.toml --check asteroid_filterbanks tests
- name: Run ruff format
run: |
ruff format --diff \
asteroid_filterbanks \
tests

- name: Link with flake8
- name: Run ruff check
# Exit on important linting errors and warn about others.
run: |
python -m flake8 asteroid tests --show-source --statistics --select=F6,F7,F82,F52
python -m flake8 --config .flake8 --exit-zero asteroid tests --statistics
ruff check --diff --select=F6,F7,F82,F52 \
asteroid_filterbanks \
tests
ruff check --statistics --exit-zero \
asteroid_filterbanks \
tests
10 changes: 6 additions & 4 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,10 @@ repos:
- id: trailing-whitespace
- id: check-yaml

- repo: https://github.com/psf/black
rev: 20.8b1
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.9.1
hooks:
- id: black
types: [python]
- id: ruff
types_or: [ python, pyi ]
- id: ruff-format
types_or: [ python, pyi ]
12 changes: 6 additions & 6 deletions CONTRIBUTING.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ installing asteroid-filterbanks everytime you change something.
To do that, install asteroid-filterbanks in develop mode either with pip
```pip install -e .[tests]``` or with python ```python setup.py develop```.

To avoid formatting roundtrips in PRs, Asteroid relies on [`black`](https://github.com/psf/black)
To avoid formatting roundtrips in PRs, Asteroid relies on [`ruff`](https://docs.astral.sh/ruff/)
and [`pre-commit-hooks`](https://github.com/pre-commit/pre-commit-hooks) to handle formatting
for us. You'll need to install `requirements.txt` and install git hooks with
`pre-commit install`.
Expand All @@ -26,7 +26,7 @@ git clone your_fork_url
cd asteroid-filterbanks
pip install -r requirements.txt
pip install -e .
pre-commit install # To run black before commit
pre-commit install # To run `ruff format` before commit

# Make your changes
# Test them locally
Expand Down Expand Up @@ -62,10 +62,10 @@ docstrings in the codebase for examples.
### Coding style

We use [pre-commit hooks](../.pre-commit-config.yaml) to format the code using
`black`.
The code is checked for `black`- and `flake8`- compliance on every commit with
GitHub actions. Remember, continuous integration is not here to be all green,
be to help us see where to improve !
`ruff format`.
The code is checked for compliance with `ruff check` and `ruff format` on every
commit with GitHub actions. Remember, continuous integration is not here to be
all green, be to help us see where to improve!


If you have any question, [open an issue][issue] or [join the slack][slack],
Expand Down
5 changes: 1 addition & 4 deletions asteroid_filterbanks/analytic_free_fb.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,10 +57,7 @@ def __init__(self, n_filters, kernel_size, stride=None, sample_rate=8000.0, **kw
self.cutoff = int(n_filters // 2)
self.n_feats_out = 2 * self.cutoff
if n_filters % 2 != 0:
print(
"If the number of filters `n_filters` is odd, the "
"output size of the layer will be `n_filters - 1`."
)
print("If the number of filters `n_filters` is odd, the output size of the layer will be `n_filters - 1`.")

self._filters = nn.Parameter(torch.ones(n_filters // 2, 1, kernel_size), requires_grad=True)
for p in self.parameters():
Expand Down
4 changes: 1 addition & 3 deletions asteroid_filterbanks/deprecation.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,7 @@ def decorator(func):
def wrapped(*args, **kwargs):
from_what = "a future release" if version is None else f"asteroid v{version}"
warn_message = (
f"{func.__module__}.{func.__name__} has been deprecated "
f"and will be removed from {from_what}. "
f"{message}"
f"{func.__module__}.{func.__name__} has been deprecated and will be removed from {from_what}. {message}"
)
warnings.warn(warn_message, VisibleDeprecationWarning, stacklevel=2)
return func(*args, **kwargs)
Expand Down
14 changes: 6 additions & 8 deletions asteroid_filterbanks/enc_dec.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ def __init__(self, n_filters, kernel_size, stride=None, sample_rate=8000.0):
self.sample_rate = sample_rate

def filters(self):
""" Abstract method for filters. """
"""Abstract method for filters."""
raise NotImplementedError

def pre_analysis(self, wav: torch.Tensor):
Expand Down Expand Up @@ -97,22 +97,22 @@ def filters(self):
return self.filterbank.filters()

def compute_filter_pinv(self, filters):
""" Computes pseudo inverse filterbank of given filters."""
"""Computes pseudo inverse filterbank of given filters."""
scale = self.filterbank.stride / self.filterbank.kernel_size
shape = filters.shape
ifilt = torch.pinverse(filters.squeeze()).transpose(-1, -2).view(shape)
# Compensate for the overlap-add.
return ifilt * scale

def get_filters(self):
""" Returns filters or pinv filters depending on `is_pinv` attribute """
"""Returns filters or pinv filters depending on `is_pinv` attribute"""
if self.is_pinv:
return self.compute_filter_pinv(self.filters())
else:
return self.filters()

def get_config(self):
""" Returns dictionary of arguments to re-instantiate the class."""
"""Returns dictionary of arguments to re-instantiate the class."""
config = {"is_pinv": self.is_pinv}
base_config = self.filterbank.get_config()
return dict(list(base_config.items()) + list(config.items()))
Expand Down Expand Up @@ -226,9 +226,7 @@ def multishape_conv1d(
return batch_packed_1d_conv(waveform, filters, stride=stride, padding=padding)


def batch_packed_1d_conv(
inp: torch.Tensor, filters: torch.Tensor, stride: int = 1, padding: int = 0
):
def batch_packed_1d_conv(inp: torch.Tensor, filters: torch.Tensor, stride: int = 1, padding: int = 0):
# Here we perform multichannel / multi-source convolution.
# Output should be (batch, channels, freq, conv_time)
batched_conv = F.conv1d(inp.view(-1, 1, inp.shape[-1]), filters, stride=stride, padding=padding)
Expand Down Expand Up @@ -261,7 +259,7 @@ def __init__(self, filterbank, is_pinv=False, padding=0, output_padding=0):

@classmethod
def pinv_of(cls, filterbank):
""" Returns an Decoder, pseudo inverse of a filterbank or Encoder."""
"""Returns an Decoder, pseudo inverse of a filterbank or Encoder."""
if isinstance(filterbank, Filterbank):
return cls(filterbank, is_pinv=True)
elif isinstance(filterbank, Encoder):
Expand Down
10 changes: 4 additions & 6 deletions asteroid_filterbanks/griffin_lim.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import math
from typing import Optional, List

from . import Encoder, Decoder, STFTFB
from . import Decoder, STFTFB
from .stft_fb import perfect_synthesis_window
from . import transforms

Expand Down Expand Up @@ -154,9 +154,7 @@ def misi(
complex_specgram = transforms.from_magphase(mag_specgrams, angles)
wavs = istft_dec(complex_specgram)
# Make wavs sum up to the mixture
consistent_wavs = _mixture_consistency(
mixture_wav, wavs, src_weights=src_weights, dim=wav_dim
)
consistent_wavs = _mixture_consistency(mixture_wav, wavs, src_weights=src_weights, dim=wav_dim)
# Back to TF domain
rebuilt = stft_enc(consistent_wavs)
# Update phase estimates (with momentum). Keep the momentum here
Expand Down Expand Up @@ -216,7 +214,7 @@ def _mixture_consistency(
all_dims: List[int] = torch.arange(est_sources.ndim).tolist()
all_dims.pop(dim) # Remove source axis
all_dims.pop(0) # Remove batch axis
src_weights = torch.mean(est_sources ** 2, dim=all_dims, keepdim=True)
src_weights = torch.mean(est_sources**2, dim=all_dims, keepdim=True)
# Make sure that the weights sum up to 1
norm_weights = torch.sum(src_weights, dim=dim, keepdim=True) + 1e-8
src_weights = src_weights / norm_weights
Expand All @@ -233,7 +231,7 @@ def _mixture_consistency(
raise RuntimeError(
f"The size of the mixture tensor should match the "
f"size of the est_sources tensor. Expected mixture"
f"tensor to have {n} or {n-1} dimension, found {m}."
f"tensor to have {n} or {n - 1} dimension, found {m}."
)
# Compute remove
new_sources = est_sources + src_weights * residual
Expand Down
9 changes: 2 additions & 7 deletions asteroid_filterbanks/melgram_fb.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,6 @@ def __init__(
norm="slaney",
**kwargs,
):

self.n_mels = n_mels
self.fmin = fmin
self.fmax = fmax
Expand All @@ -50,9 +49,7 @@ def __init__(
sample_rate=sample_rate,
**kwargs,
)
self.mel_scale = MelScale(
n_filters, sample_rate=sample_rate, n_mels=n_mels, fmin=fmin, fmax=fmax, norm=norm
)
self.mel_scale = MelScale(n_filters, sample_rate=sample_rate, n_mels=n_mels, fmin=fmin, fmax=fmax, norm=norm)
self.n_feats_out = n_mels

def post_analysis(self, spec: torch.Tensor):
Expand Down Expand Up @@ -90,9 +87,7 @@ def __init__(
from librosa.filters import mel

super().__init__()
fb_mat = mel(
sr=sample_rate, n_fft=n_filters, fmin=fmin, fmax=fmax, n_mels=n_mels, norm=norm
)
fb_mat = mel(sr=sample_rate, n_fft=n_filters, fmin=fmin, fmax=fmax, n_mels=n_mels, norm=norm)
self.register_buffer("fb_mat", torch.from_numpy(fb_mat).unsqueeze(0))

def forward(self, spec: torch.Tensor):
Expand Down
27 changes: 8 additions & 19 deletions asteroid_filterbanks/multiphase_gammatone_fb.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,9 +48,7 @@ def generate_mpgtf(samplerate_hz, len_sec, n_filters):
current_center_freq_hz = center_freq_hz_min

# Determine number of phase shifts per center frequency
phase_pair_count = (np.ones(n_center_freqs) * np.floor(n_filters / 2 / n_center_freqs)).astype(
int
)
phase_pair_count = (np.ones(n_center_freqs) * np.floor(n_filters / 2 / n_center_freqs)).astype(int)
remaining_phase_pairs = ((n_filters - np.sum(phase_pair_count) * 2) / 2).astype(int)
if remaining_phase_pairs > 0:
phase_pair_count[:remaining_phase_pairs] = phase_pair_count[:remaining_phase_pairs] + 1
Expand All @@ -70,48 +68,39 @@ def generate_mpgtf(samplerate_hz, len_sec, n_filters):
index = index + 1

# Second half of filters: phase_shifts in [pi, 2*pi)
filterbank[index : index + phase_pair_count[i], :] = -filterbank[
index - phase_pair_count[i] : index, :
]
filterbank[index : index + phase_pair_count[i], :] = -filterbank[index - phase_pair_count[i] : index, :]

# Prepare for next center frequency
index = index + phase_pair_count[i]
current_center_freq_hz = erb_scale_2_freq_hz(
freq_hz_2_erb_scale(current_center_freq_hz) + 1
)
current_center_freq_hz = erb_scale_2_freq_hz(freq_hz_2_erb_scale(current_center_freq_hz) + 1)

filterbank = normalize_filters(filterbank)
return filterbank


def gammatone_impulse_response(samplerate_hz, len_sec, center_freq_hz, phase_shift):
""" Generate single parametrized gammatone filter """
"""Generate single parametrized gammatone filter"""
p = 2 # filter order
erb = 24.7 + 0.108 * center_freq_hz # equivalent rectangular bandwidth
divisor = (np.pi * math.factorial(2 * p - 2) * np.power(2, float(-(2 * p - 2)))) / np.square(
math.factorial(p - 1)
)
divisor = (np.pi * math.factorial(2 * p - 2) * np.power(2, float(-(2 * p - 2)))) / np.square(math.factorial(p - 1))
b = erb / divisor # bandwidth parameter
a = 1.0 # amplitude. This is varied later by the normalization process.
len_sample = int(np.floor(samplerate_hz * len_sec))
t = np.linspace(1.0 / samplerate_hz, len_sec, len_sample)
gammatone_ir = (
a
* np.power(t, p - 1)
* np.exp(-2 * np.pi * b * t)
* np.cos(2 * np.pi * center_freq_hz * t + phase_shift)
a * np.power(t, p - 1) * np.exp(-2 * np.pi * b * t) * np.cos(2 * np.pi * center_freq_hz * t + phase_shift)
)
return gammatone_ir


def erb_scale_2_freq_hz(freq_erb):
""" Convert frequency on ERB scale to frequency in Hertz """
"""Convert frequency on ERB scale to frequency in Hertz"""
freq_hz = (np.exp(freq_erb / 9.265) - 1) * 24.7 * 9.265
return freq_hz


def freq_hz_2_erb_scale(freq_hz):
""" Convert frequency in Hertz to frequency on ERB scale """
"""Convert frequency in Hertz to frequency on ERB scale"""
freq_erb = 9.265 * np.log(1 + freq_hz / (24.7 * 9.265))
return freq_erb

Expand Down
23 changes: 7 additions & 16 deletions asteroid_filterbanks/param_sinc_fb.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,36 +55,27 @@ def __init__(
self.n_feats_out = 2 * self.cutoff
self._initialize_filters()
if n_filters % 2 != 0:
print(
"If the number of filters `n_filters` is odd, the "
"output size of the layer will be `n_filters - 1`."
)
print("If the number of filters `n_filters` is odd, the output size of the layer will be `n_filters - 1`.")

window_ = np.hamming(self.kernel_size)[: self.half_kernel] # Half window
n_ = (
2 * np.pi * (torch.arange(-self.half_kernel, 0.0).view(1, -1) / self.sample_rate)
) # Half time vector
n_ = 2 * np.pi * (torch.arange(-self.half_kernel, 0.0).view(1, -1) / self.sample_rate) # Half time vector
self.register_buffer("window_", torch.from_numpy(window_).float())
self.register_buffer("n_", n_)

def _initialize_filters(self):
""" Filter Initialization along the Mel scale"""
"""Filter Initialization along the Mel scale"""
low_hz = 30
high_hz = self.sample_rate / 2 - (self.min_low_hz + self.min_band_hz)
mel = np.linspace(
self.to_mel(low_hz), self.to_mel(high_hz), self.n_filters // 2 + 1, dtype="float32"
)
mel = np.linspace(self.to_mel(low_hz), self.to_mel(high_hz), self.n_filters // 2 + 1, dtype="float32")
hz = self.to_hz(mel)
# filters parameters (out_channels // 2, 1)
self.low_hz_ = nn.Parameter(torch.from_numpy(hz[:-1]).view(-1, 1))
self.band_hz_ = nn.Parameter(torch.from_numpy(np.diff(hz)).view(-1, 1))

def filters(self):
""" Compute filters from parameters """
"""Compute filters from parameters"""
low = self.min_low_hz + torch.abs(self.low_hz_)
high = torch.clamp(
low + self.min_band_hz + torch.abs(self.band_hz_), self.min_low_hz, self.sample_rate / 2
)
high = torch.clamp(low + self.min_band_hz + torch.abs(self.band_hz_), self.min_low_hz, self.sample_rate / 2)
cos_filters = self.make_filters(low, high, filt_type="cos")
sin_filters = self.make_filters(low, high, filt_type="sin")
return torch.cat([cos_filters, sin_filters], dim=0)
Expand Down Expand Up @@ -116,7 +107,7 @@ def to_hz(mel):
return 700 * (10 ** (mel / 2595) - 1)

def get_config(self):
""" Returns dictionary of arguments to re-instantiate the class."""
"""Returns dictionary of arguments to re-instantiate the class."""
config = {
"min_low_hz": self.min_low_hz,
"min_band_hz": self.min_band_hz,
Expand Down
Loading
Loading