diff --git a/asteroid_filterbanks/griffin_lim.py b/asteroid_filterbanks/griffin_lim.py index 198f063..fd0feef 100644 --- a/asteroid_filterbanks/griffin_lim.py +++ b/asteroid_filterbanks/griffin_lim.py @@ -2,7 +2,7 @@ import math from typing import Optional, List -from . import Encoder, Decoder, STFTFB +from . import Decoder, STFTFB from .stft_fb import perfect_synthesis_window from . import transforms diff --git a/asteroid_filterbanks/pcen.py b/asteroid_filterbanks/pcen.py index 78c38cd..3388d31 100644 --- a/asteroid_filterbanks/pcen.py +++ b/asteroid_filterbanks/pcen.py @@ -1,6 +1,5 @@ from torch import nn import torch -from . import transforms from typing import Union, Optional, Tuple try: @@ -140,9 +139,9 @@ class PCEN(_PCEN): Defaults to 0.04 n_channels: Number of channels in the time frequency representation. Defaults to 1 - trainable: If True, the parameters (alpha, delta, root and smooth) are trainable. If False, the parameters are fixed. - Individual parameters can set to be fixed or trainable by passing a dictionary of booleans, with the key - matching the parameter name and the value being either True (trainable) or False (fixed). + trainable: If True, the parameters (alpha, delta, root and smooth) are trainable. If False, the parameters are + fixed. Individual parameters can set to be fixed or trainable by passing a dictionary of booleans, with the + key matching the parameter name and the value being either True (trainable) or False (fixed). i.e. ``{"alpha": False, "delta": True, "root": False, "smooth": True}`` Defaults to False per_channel_smoothing: If True, each channel has it's own smoothing coefficient. @@ -207,9 +206,9 @@ class StatefulPCEN(_PCEN): Defaults to 0.04 n_channels: Number of channels in the time frequency representation. Defaults to 1 - trainable: If True, the parameters (alpha, delta, root and smooth) are trainable. If False, the parameters are fixed. - Individual parameters can set to be fixed or trainable by passing a dictionary of booleans, with the key - matching the parameter name and the value being either True (trainable) or False (fixed). + trainable: If True, the parameters (alpha, delta, root and smooth) are trainable. If False, the parameters are + fixed. Individual parameters can set to be fixed or trainable by passing a dictionary of booleans, with the + key matching the parameter name and the value being either True (trainable) or False (fixed). i.e. ``{"alpha": False, "delta": True, "root": False, "smooth": True}`` Defaults to False per_channel_smoothing: If True, each channel has it's own smoothing coefficient. @@ -250,7 +249,8 @@ def from_pcen(cls, pcen: PCEN): def forward( self, mag_spec: torch.Tensor, initial_state: Optional[torch.Tensor] = None ) -> Tuple[torch.Tensor, torch.Tensor]: - """Computes the PCEN from magnitude spectrum representation, and an optional smoothed version of the filterbank (Equation (2) in [1]). + """Computes the PCEN from magnitude spectrum representation, and an optional smoothed version of the filterbank + (Equation (2) in [1]). Args: mag_spec: tensor containing an magnitude spectrum representation. diff --git a/asteroid_filterbanks/transforms.py b/asteroid_filterbanks/transforms.py index 3d5c25a..1e42a6d 100644 --- a/asteroid_filterbanks/transforms.py +++ b/asteroid_filterbanks/transforms.py @@ -3,7 +3,6 @@ from typing import Tuple from .scripting import script_if_tracing -from .deprecation import mark_deprecated def mul_c(inp, other, dim: int = -2): diff --git a/tests/pcen_test.py b/tests/pcen_test.py index 029e248..5a9532b 100644 --- a/tests/pcen_test.py +++ b/tests/pcen_test.py @@ -1,9 +1,8 @@ import torch -from asteroid_filterbanks import Encoder, Decoder, STFTFB, transforms +from asteroid_filterbanks import Encoder, STFTFB, transforms from asteroid_filterbanks.pcen import PCEN, StatefulPCEN, ExponentialMovingAverage from torch.testing import assert_allclose import pytest -import re @pytest.mark.parametrize("n_channels", [2, 4]) diff --git a/tests/stft_test.py b/tests/stft_test.py index be6e748..985aae8 100644 --- a/tests/stft_test.py +++ b/tests/stft_test.py @@ -153,4 +153,4 @@ def test_melgram_encoder(n_filters, n_mels, ndim): mel_spec = enc(wav) assert wav.shape[:-1] == mel_spec.shape[:-2] assert mel_spec.shape[-2] == n_mels - conf = melgram_fb.get_config() + melgram_fb.get_config()