Skip to content

Commit

Permalink
Run ruff format
Browse files Browse the repository at this point in the history
  • Loading branch information
mweinelt committed Jan 12, 2025
1 parent bf73ed7 commit 8d54b1a
Show file tree
Hide file tree
Showing 17 changed files with 70 additions and 132 deletions.
5 changes: 1 addition & 4 deletions asteroid_filterbanks/analytic_free_fb.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,10 +57,7 @@ def __init__(self, n_filters, kernel_size, stride=None, sample_rate=8000.0, **kw
self.cutoff = int(n_filters // 2)
self.n_feats_out = 2 * self.cutoff
if n_filters % 2 != 0:
print(
"If the number of filters `n_filters` is odd, the "
"output size of the layer will be `n_filters - 1`."
)
print("If the number of filters `n_filters` is odd, the output size of the layer will be `n_filters - 1`.")

self._filters = nn.Parameter(torch.ones(n_filters // 2, 1, kernel_size), requires_grad=True)
for p in self.parameters():
Expand Down
4 changes: 1 addition & 3 deletions asteroid_filterbanks/deprecation.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,7 @@ def decorator(func):
def wrapped(*args, **kwargs):
from_what = "a future release" if version is None else f"asteroid v{version}"
warn_message = (
f"{func.__module__}.{func.__name__} has been deprecated "
f"and will be removed from {from_what}. "
f"{message}"
f"{func.__module__}.{func.__name__} has been deprecated and will be removed from {from_what}. {message}"
)
warnings.warn(warn_message, VisibleDeprecationWarning, stacklevel=2)
return func(*args, **kwargs)
Expand Down
14 changes: 6 additions & 8 deletions asteroid_filterbanks/enc_dec.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ def __init__(self, n_filters, kernel_size, stride=None, sample_rate=8000.0):
self.sample_rate = sample_rate

def filters(self):
""" Abstract method for filters. """
"""Abstract method for filters."""
raise NotImplementedError

def pre_analysis(self, wav: torch.Tensor):
Expand Down Expand Up @@ -97,22 +97,22 @@ def filters(self):
return self.filterbank.filters()

def compute_filter_pinv(self, filters):
""" Computes pseudo inverse filterbank of given filters."""
"""Computes pseudo inverse filterbank of given filters."""
scale = self.filterbank.stride / self.filterbank.kernel_size
shape = filters.shape
ifilt = torch.pinverse(filters.squeeze()).transpose(-1, -2).view(shape)
# Compensate for the overlap-add.
return ifilt * scale

def get_filters(self):
""" Returns filters or pinv filters depending on `is_pinv` attribute """
"""Returns filters or pinv filters depending on `is_pinv` attribute"""
if self.is_pinv:
return self.compute_filter_pinv(self.filters())
else:
return self.filters()

def get_config(self):
""" Returns dictionary of arguments to re-instantiate the class."""
"""Returns dictionary of arguments to re-instantiate the class."""
config = {"is_pinv": self.is_pinv}
base_config = self.filterbank.get_config()
return dict(list(base_config.items()) + list(config.items()))
Expand Down Expand Up @@ -226,9 +226,7 @@ def multishape_conv1d(
return batch_packed_1d_conv(waveform, filters, stride=stride, padding=padding)


def batch_packed_1d_conv(
inp: torch.Tensor, filters: torch.Tensor, stride: int = 1, padding: int = 0
):
def batch_packed_1d_conv(inp: torch.Tensor, filters: torch.Tensor, stride: int = 1, padding: int = 0):
# Here we perform multichannel / multi-source convolution.
# Output should be (batch, channels, freq, conv_time)
batched_conv = F.conv1d(inp.view(-1, 1, inp.shape[-1]), filters, stride=stride, padding=padding)
Expand Down Expand Up @@ -261,7 +259,7 @@ def __init__(self, filterbank, is_pinv=False, padding=0, output_padding=0):

@classmethod
def pinv_of(cls, filterbank):
""" Returns an Decoder, pseudo inverse of a filterbank or Encoder."""
"""Returns an Decoder, pseudo inverse of a filterbank or Encoder."""
if isinstance(filterbank, Filterbank):
return cls(filterbank, is_pinv=True)
elif isinstance(filterbank, Encoder):
Expand Down
8 changes: 3 additions & 5 deletions asteroid_filterbanks/griffin_lim.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,9 +154,7 @@ def misi(
complex_specgram = transforms.from_magphase(mag_specgrams, angles)
wavs = istft_dec(complex_specgram)
# Make wavs sum up to the mixture
consistent_wavs = _mixture_consistency(
mixture_wav, wavs, src_weights=src_weights, dim=wav_dim
)
consistent_wavs = _mixture_consistency(mixture_wav, wavs, src_weights=src_weights, dim=wav_dim)
# Back to TF domain
rebuilt = stft_enc(consistent_wavs)
# Update phase estimates (with momentum). Keep the momentum here
Expand Down Expand Up @@ -216,7 +214,7 @@ def _mixture_consistency(
all_dims: List[int] = torch.arange(est_sources.ndim).tolist()
all_dims.pop(dim) # Remove source axis
all_dims.pop(0) # Remove batch axis
src_weights = torch.mean(est_sources ** 2, dim=all_dims, keepdim=True)
src_weights = torch.mean(est_sources**2, dim=all_dims, keepdim=True)
# Make sure that the weights sum up to 1
norm_weights = torch.sum(src_weights, dim=dim, keepdim=True) + 1e-8
src_weights = src_weights / norm_weights
Expand All @@ -233,7 +231,7 @@ def _mixture_consistency(
raise RuntimeError(
f"The size of the mixture tensor should match the "
f"size of the est_sources tensor. Expected mixture"
f"tensor to have {n} or {n-1} dimension, found {m}."
f"tensor to have {n} or {n - 1} dimension, found {m}."
)
# Compute remove
new_sources = est_sources + src_weights * residual
Expand Down
9 changes: 2 additions & 7 deletions asteroid_filterbanks/melgram_fb.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,6 @@ def __init__(
norm="slaney",
**kwargs,
):

self.n_mels = n_mels
self.fmin = fmin
self.fmax = fmax
Expand All @@ -50,9 +49,7 @@ def __init__(
sample_rate=sample_rate,
**kwargs,
)
self.mel_scale = MelScale(
n_filters, sample_rate=sample_rate, n_mels=n_mels, fmin=fmin, fmax=fmax, norm=norm
)
self.mel_scale = MelScale(n_filters, sample_rate=sample_rate, n_mels=n_mels, fmin=fmin, fmax=fmax, norm=norm)
self.n_feats_out = n_mels

def post_analysis(self, spec: torch.Tensor):
Expand Down Expand Up @@ -90,9 +87,7 @@ def __init__(
from librosa.filters import mel

super().__init__()
fb_mat = mel(
sr=sample_rate, n_fft=n_filters, fmin=fmin, fmax=fmax, n_mels=n_mels, norm=norm
)
fb_mat = mel(sr=sample_rate, n_fft=n_filters, fmin=fmin, fmax=fmax, n_mels=n_mels, norm=norm)
self.register_buffer("fb_mat", torch.from_numpy(fb_mat).unsqueeze(0))

def forward(self, spec: torch.Tensor):
Expand Down
27 changes: 8 additions & 19 deletions asteroid_filterbanks/multiphase_gammatone_fb.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,9 +48,7 @@ def generate_mpgtf(samplerate_hz, len_sec, n_filters):
current_center_freq_hz = center_freq_hz_min

# Determine number of phase shifts per center frequency
phase_pair_count = (np.ones(n_center_freqs) * np.floor(n_filters / 2 / n_center_freqs)).astype(
int
)
phase_pair_count = (np.ones(n_center_freqs) * np.floor(n_filters / 2 / n_center_freqs)).astype(int)
remaining_phase_pairs = ((n_filters - np.sum(phase_pair_count) * 2) / 2).astype(int)
if remaining_phase_pairs > 0:
phase_pair_count[:remaining_phase_pairs] = phase_pair_count[:remaining_phase_pairs] + 1
Expand All @@ -70,48 +68,39 @@ def generate_mpgtf(samplerate_hz, len_sec, n_filters):
index = index + 1

# Second half of filters: phase_shifts in [pi, 2*pi)
filterbank[index : index + phase_pair_count[i], :] = -filterbank[
index - phase_pair_count[i] : index, :
]
filterbank[index : index + phase_pair_count[i], :] = -filterbank[index - phase_pair_count[i] : index, :]

# Prepare for next center frequency
index = index + phase_pair_count[i]
current_center_freq_hz = erb_scale_2_freq_hz(
freq_hz_2_erb_scale(current_center_freq_hz) + 1
)
current_center_freq_hz = erb_scale_2_freq_hz(freq_hz_2_erb_scale(current_center_freq_hz) + 1)

filterbank = normalize_filters(filterbank)
return filterbank


def gammatone_impulse_response(samplerate_hz, len_sec, center_freq_hz, phase_shift):
""" Generate single parametrized gammatone filter """
"""Generate single parametrized gammatone filter"""
p = 2 # filter order
erb = 24.7 + 0.108 * center_freq_hz # equivalent rectangular bandwidth
divisor = (np.pi * math.factorial(2 * p - 2) * np.power(2, float(-(2 * p - 2)))) / np.square(
math.factorial(p - 1)
)
divisor = (np.pi * math.factorial(2 * p - 2) * np.power(2, float(-(2 * p - 2)))) / np.square(math.factorial(p - 1))
b = erb / divisor # bandwidth parameter
a = 1.0 # amplitude. This is varied later by the normalization process.
len_sample = int(np.floor(samplerate_hz * len_sec))
t = np.linspace(1.0 / samplerate_hz, len_sec, len_sample)
gammatone_ir = (
a
* np.power(t, p - 1)
* np.exp(-2 * np.pi * b * t)
* np.cos(2 * np.pi * center_freq_hz * t + phase_shift)
a * np.power(t, p - 1) * np.exp(-2 * np.pi * b * t) * np.cos(2 * np.pi * center_freq_hz * t + phase_shift)
)
return gammatone_ir


def erb_scale_2_freq_hz(freq_erb):
""" Convert frequency on ERB scale to frequency in Hertz """
"""Convert frequency on ERB scale to frequency in Hertz"""
freq_hz = (np.exp(freq_erb / 9.265) - 1) * 24.7 * 9.265
return freq_hz


def freq_hz_2_erb_scale(freq_hz):
""" Convert frequency in Hertz to frequency on ERB scale """
"""Convert frequency in Hertz to frequency on ERB scale"""
freq_erb = 9.265 * np.log(1 + freq_hz / (24.7 * 9.265))
return freq_erb

Expand Down
23 changes: 7 additions & 16 deletions asteroid_filterbanks/param_sinc_fb.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,36 +55,27 @@ def __init__(
self.n_feats_out = 2 * self.cutoff
self._initialize_filters()
if n_filters % 2 != 0:
print(
"If the number of filters `n_filters` is odd, the "
"output size of the layer will be `n_filters - 1`."
)
print("If the number of filters `n_filters` is odd, the output size of the layer will be `n_filters - 1`.")

window_ = np.hamming(self.kernel_size)[: self.half_kernel] # Half window
n_ = (
2 * np.pi * (torch.arange(-self.half_kernel, 0.0).view(1, -1) / self.sample_rate)
) # Half time vector
n_ = 2 * np.pi * (torch.arange(-self.half_kernel, 0.0).view(1, -1) / self.sample_rate) # Half time vector
self.register_buffer("window_", torch.from_numpy(window_).float())
self.register_buffer("n_", n_)

def _initialize_filters(self):
""" Filter Initialization along the Mel scale"""
"""Filter Initialization along the Mel scale"""
low_hz = 30
high_hz = self.sample_rate / 2 - (self.min_low_hz + self.min_band_hz)
mel = np.linspace(
self.to_mel(low_hz), self.to_mel(high_hz), self.n_filters // 2 + 1, dtype="float32"
)
mel = np.linspace(self.to_mel(low_hz), self.to_mel(high_hz), self.n_filters // 2 + 1, dtype="float32")
hz = self.to_hz(mel)
# filters parameters (out_channels // 2, 1)
self.low_hz_ = nn.Parameter(torch.from_numpy(hz[:-1]).view(-1, 1))
self.band_hz_ = nn.Parameter(torch.from_numpy(np.diff(hz)).view(-1, 1))

def filters(self):
""" Compute filters from parameters """
"""Compute filters from parameters"""
low = self.min_low_hz + torch.abs(self.low_hz_)
high = torch.clamp(
low + self.min_band_hz + torch.abs(self.band_hz_), self.min_low_hz, self.sample_rate / 2
)
high = torch.clamp(low + self.min_band_hz + torch.abs(self.band_hz_), self.min_low_hz, self.sample_rate / 2)
cos_filters = self.make_filters(low, high, filt_type="cos")
sin_filters = self.make_filters(low, high, filt_type="sin")
return torch.cat([cos_filters, sin_filters], dim=0)
Expand Down Expand Up @@ -116,7 +107,7 @@ def to_hz(mel):
return 700 * (10 ** (mel / 2595) - 1)

def get_config(self):
""" Returns dictionary of arguments to re-instantiate the class."""
"""Returns dictionary of arguments to re-instantiate the class."""
config = {
"min_low_hz": self.min_low_hz,
"min_band_hz": self.min_band_hz,
Expand Down
10 changes: 3 additions & 7 deletions asteroid_filterbanks/pcen.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,9 +76,7 @@ def __init__(
super().__init__()

if trainable is True or trainable is False:
trainable = TrainableParameters(
alpha=trainable, delta=trainable, root=trainable, smooth=trainable
)
trainable = TrainableParameters(alpha=trainable, delta=trainable, root=trainable, smooth=trainable)

self.trainable = trainable
self.n_channels = n_channels
Expand All @@ -89,9 +87,7 @@ def __init__(
self.delta = nn.Parameter(
torch.full((self.n_channels,), fill_value=delta), requires_grad=self.trainable["delta"]
)
self.root = nn.Parameter(
torch.full((self.n_channels,), fill_value=root), requires_grad=self.trainable["root"]
)
self.root = nn.Parameter(torch.full((self.n_channels,), fill_value=root), requires_grad=self.trainable["root"])

self.floor = floor
self.ema = ExponentialMovingAverage(
Expand Down Expand Up @@ -119,7 +115,7 @@ def forward(
# Equation (1) in [1]
out = (
mag_spec / (self.floor + ema_smoother) ** alpha + self.delta
) ** one_over_root - self.delta ** one_over_root
) ** one_over_root - self.delta**one_over_root
out = out.transpose(1, -1)
if post_squeeze:
out = out.squeeze(1)
Expand Down
14 changes: 4 additions & 10 deletions asteroid_filterbanks/stft_fb.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,7 @@ class STFTFB(Filterbank):
n_feats_out (int): Number of output filters.
"""

def __init__(
self, n_filters, kernel_size, stride=None, window=None, sample_rate=8000.0, **kwargs
):
def __init__(self, n_filters, kernel_size, stride=None, window=None, sample_rate=8000.0, **kwargs):
super().__init__(n_filters, kernel_size, stride=stride, sample_rate=sample_rate)
assert n_filters >= kernel_size
if n_filters % 2 != 0:
Expand All @@ -38,9 +36,7 @@ def __init__(
window = window.data.numpy()
ws = window.size
if not (ws == kernel_size):
raise AssertionError(
f"Expected window of size {kernel_size}. Received {ws} instead."
)
raise AssertionError(f"Expected window of size {kernel_size}. Received {ws} instead.")
self.window = window
# Create and normalize DFT filters (can be overcomplete)
filters = np.fft.fft(np.eye(n_filters))
Expand All @@ -50,9 +46,7 @@ def __init__(
lpad = int((n_filters - kernel_size) // 2)
rpad = int(n_filters - kernel_size - lpad)
indexes = list(range(lpad, n_filters - rpad))
filters = np.vstack(
[np.real(filters[: self.cutoff, indexes]), np.imag(filters[: self.cutoff, indexes])]
)
filters = np.vstack([np.real(filters[: self.cutoff, indexes]), np.imag(filters[: self.cutoff, indexes])])

filters[0, :] /= np.sqrt(2)
filters[n_filters // 2, :] /= np.sqrt(2)
Expand Down Expand Up @@ -92,7 +86,7 @@ def perfect_synthesis_window(analysis_window, hop_size):

loop_on = (win_size - 1) // hop_size
for win_idx in range(-loop_on, loop_on + 1):
shifted = np.roll(analysis_window ** 2, win_idx * hop_size)
shifted = np.roll(analysis_window**2, win_idx * hop_size)
if win_idx < 0:
shifted[win_idx * hop_size :] = 0
elif win_idx > 0:
Expand Down
23 changes: 8 additions & 15 deletions asteroid_filterbanks/torch_stft_fb.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,12 +53,8 @@ def __init__(
**kwargs,
):
if n_filters != kernel_size:
raise NotImplementedError(
"Cannot set `n_filters!=kernel_size` in TorchSTFTFB, untested."
)
super().__init__(
n_filters, kernel_size, stride=stride, window=window, sample_rate=sample_rate, **kwargs
)
raise NotImplementedError("Cannot set `n_filters!=kernel_size` in TorchSTFTFB, untested.")
super().__init__(n_filters, kernel_size, stride=stride, window=window, sample_rate=sample_rate, **kwargs)
self.center = center
self.pad_mode = pad_mode
if normalized:
Expand Down Expand Up @@ -129,15 +125,15 @@ def post_synthesis(self, wav):

@script_if_tracing
def _restore_freqs_an(spec, n_filters: int):
spec[..., 0, :] *= 2 ** 0.5
spec[..., n_filters // 2, :] *= 2 ** 0.5
spec[..., 0, :] *= 2**0.5
spec[..., n_filters // 2, :] *= 2**0.5
return spec


@script_if_tracing
def _restore_freqs_syn(spec, n_filters: int):
spec[..., 0, :] /= 2 ** 0.5
spec[..., n_filters // 2, :] /= 2 ** 0.5
spec[..., 0, :] /= 2**0.5
spec[..., n_filters // 2, :] /= 2**0.5
return spec


Expand All @@ -161,8 +157,7 @@ def ola_with_wdiv(wav, window, kernel_size: int, stride: int, center: bool = Tru
if min_mask.any() and not torch.jit.is_scripting():
# Warning instead of error. Might be trimmed afterward.
warnings.warn(
f"Minimum NOLA should be above 1e-11, Found {wsq_ola.abs().min()}. "
f"Dividind only where possible.",
f"Minimum NOLA should be above 1e-11, Found {wsq_ola.abs().min()}. Dividind only where possible.",
RuntimeWarning,
)
wav[~min_mask] = wav[~min_mask] / wsq_ola[~min_mask]
Expand All @@ -177,9 +172,7 @@ def square_ola(window: torch.Tensor, kernel_size: int, stride: int, n_frame: int


@script_if_tracing
def pad_all_shapes(
x: torch.Tensor, pad_shape: Tuple[int, int], mode: str = "reflect"
) -> torch.Tensor:
def pad_all_shapes(x: torch.Tensor, pad_shape: Tuple[int, int], mode: str = "reflect") -> torch.Tensor:
if x.ndim == 1:
return F.pad(x[None, None], pad=pad_shape, mode=mode).squeeze(0).squeeze(0)
if x.ndim == 2:
Expand Down
3 changes: 1 addition & 2 deletions asteroid_filterbanks/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -287,8 +287,7 @@ def check_torchaudio_complex(tensor):
"""
if not is_torchaudio_complex(tensor):
raise AssertionError(
f"Tensor of shape {tensor.shape} is not Torchaudio-style complex-like"
"(expected last dimension to be == 2)"
f"Tensor of shape {tensor.shape} is not Torchaudio-style complex-like(expected last dimension to be == 2)"
)


Expand Down
Loading

0 comments on commit 8d54b1a

Please sign in to comment.