Skip to content

Commit

Permalink
Apply fixes from ruff check
Browse files Browse the repository at this point in the history
  • Loading branch information
mweinelt committed Jan 12, 2025
1 parent 849da9e commit 0e46dee
Show file tree
Hide file tree
Showing 5 changed files with 11 additions and 13 deletions.
2 changes: 1 addition & 1 deletion asteroid_filterbanks/griffin_lim.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import math
from typing import Optional, List

from . import Encoder, Decoder, STFTFB
from . import Decoder, STFTFB
from .stft_fb import perfect_synthesis_window
from . import transforms

Expand Down
16 changes: 8 additions & 8 deletions asteroid_filterbanks/pcen.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
from torch import nn
import torch
from . import transforms
from typing import Union, Optional, Tuple

try:
Expand Down Expand Up @@ -140,9 +139,9 @@ class PCEN(_PCEN):
Defaults to 0.04
n_channels: Number of channels in the time frequency representation.
Defaults to 1
trainable: If True, the parameters (alpha, delta, root and smooth) are trainable. If False, the parameters are fixed.
Individual parameters can set to be fixed or trainable by passing a dictionary of booleans, with the key
matching the parameter name and the value being either True (trainable) or False (fixed).
trainable: If True, the parameters (alpha, delta, root and smooth) are trainable. If False, the parameters are
fixed. Individual parameters can set to be fixed or trainable by passing a dictionary of booleans, with the
key matching the parameter name and the value being either True (trainable) or False (fixed).
i.e. ``{"alpha": False, "delta": True, "root": False, "smooth": True}``
Defaults to False
per_channel_smoothing: If True, each channel has it's own smoothing coefficient.
Expand Down Expand Up @@ -207,9 +206,9 @@ class StatefulPCEN(_PCEN):
Defaults to 0.04
n_channels: Number of channels in the time frequency representation.
Defaults to 1
trainable: If True, the parameters (alpha, delta, root and smooth) are trainable. If False, the parameters are fixed.
Individual parameters can set to be fixed or trainable by passing a dictionary of booleans, with the key
matching the parameter name and the value being either True (trainable) or False (fixed).
trainable: If True, the parameters (alpha, delta, root and smooth) are trainable. If False, the parameters are
fixed. Individual parameters can set to be fixed or trainable by passing a dictionary of booleans, with the
key matching the parameter name and the value being either True (trainable) or False (fixed).
i.e. ``{"alpha": False, "delta": True, "root": False, "smooth": True}``
Defaults to False
per_channel_smoothing: If True, each channel has it's own smoothing coefficient.
Expand Down Expand Up @@ -250,7 +249,8 @@ def from_pcen(cls, pcen: PCEN):
def forward(
self, mag_spec: torch.Tensor, initial_state: Optional[torch.Tensor] = None
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Computes the PCEN from magnitude spectrum representation, and an optional smoothed version of the filterbank (Equation (2) in [1]).
"""Computes the PCEN from magnitude spectrum representation, and an optional smoothed version of the filterbank
(Equation (2) in [1]).
Args:
mag_spec: tensor containing an magnitude spectrum representation.
Expand Down
1 change: 0 additions & 1 deletion asteroid_filterbanks/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
from typing import Tuple

from .scripting import script_if_tracing
from .deprecation import mark_deprecated


def mul_c(inp, other, dim: int = -2):
Expand Down
3 changes: 1 addition & 2 deletions tests/pcen_test.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,8 @@
import torch
from asteroid_filterbanks import Encoder, Decoder, STFTFB, transforms
from asteroid_filterbanks import Encoder, STFTFB, transforms
from asteroid_filterbanks.pcen import PCEN, StatefulPCEN, ExponentialMovingAverage
from torch.testing import assert_allclose
import pytest
import re


@pytest.mark.parametrize("n_channels", [2, 4])
Expand Down
2 changes: 1 addition & 1 deletion tests/stft_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,4 +153,4 @@ def test_melgram_encoder(n_filters, n_mels, ndim):
mel_spec = enc(wav)
assert wav.shape[:-1] == mel_spec.shape[:-2]
assert mel_spec.shape[-2] == n_mels
conf = melgram_fb.get_config()
melgram_fb.get_config()

0 comments on commit 0e46dee

Please sign in to comment.