Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[tests] Update python and black versions #17

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions .github/workflows/test_formatting.yml
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,13 @@ jobs:
steps:
- name: Checkout
uses: actions/checkout@v2
- name: Set up Python 3.6
- name: Set up Python 3.8
uses: actions/setup-python@v2
with:
python-version: 3.6
python-version: 3.8

- name: Install Black and flake8
run: pip install black==20.8b1 flake8
run: pip install black==22.3.0 flake8
- name: Run Black
run: python -m black --config=pyproject.toml --check asteroid_filterbanks tests

Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/test_package.yml
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ jobs:
runs-on: ubuntu-latest
strategy:
matrix:
python-version: [3.6] #, 3.7, 3.8]
python-version: [3.8]
pytorch-version: ["1.8.0", "nightly"]

# Timeout: https://stackoverflow.com/a/59076067/4521646
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/test_torch_stft.yml
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ jobs:
runs-on: ubuntu-latest
strategy:
matrix:
python-version: [3.6] #, 3.7, 3.8]
python-version: [3.8]

# Timeout: https://stackoverflow.com/a/59076067/4521646
timeout-minutes: 10
Expand Down
1 change: 0 additions & 1 deletion asteroid_filterbanks/analytic_free_fb.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
def conj(filt):
return torch.stack([filt[:, :, :, 1], -filt[:, :, :, 0]], dim=-1)


except ImportError:
from torch import fft

Expand Down
10 changes: 5 additions & 5 deletions asteroid_filterbanks/enc_dec.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ def __init__(self, n_filters, kernel_size, stride=None, sample_rate=8000.0):
self.sample_rate = sample_rate

def filters(self):
""" Abstract method for filters. """
"""Abstract method for filters."""
raise NotImplementedError

def pre_analysis(self, wav: torch.Tensor):
Expand Down Expand Up @@ -97,22 +97,22 @@ def filters(self):
return self.filterbank.filters()

def compute_filter_pinv(self, filters):
""" Computes pseudo inverse filterbank of given filters."""
"""Computes pseudo inverse filterbank of given filters."""
scale = self.filterbank.stride / self.filterbank.kernel_size
shape = filters.shape
ifilt = torch.pinverse(filters.squeeze()).transpose(-1, -2).view(shape)
# Compensate for the overlap-add.
return ifilt * scale

def get_filters(self):
""" Returns filters or pinv filters depending on `is_pinv` attribute """
"""Returns filters or pinv filters depending on `is_pinv` attribute"""
if self.is_pinv:
return self.compute_filter_pinv(self.filters())
else:
return self.filters()

def get_config(self):
""" Returns dictionary of arguments to re-instantiate the class."""
"""Returns dictionary of arguments to re-instantiate the class."""
config = {"is_pinv": self.is_pinv}
base_config = self.filterbank.get_config()
return dict(list(base_config.items()) + list(config.items()))
Expand Down Expand Up @@ -261,7 +261,7 @@ def __init__(self, filterbank, is_pinv=False, padding=0, output_padding=0):

@classmethod
def pinv_of(cls, filterbank):
""" Returns an Decoder, pseudo inverse of a filterbank or Encoder."""
"""Returns an Decoder, pseudo inverse of a filterbank or Encoder."""
if isinstance(filterbank, Filterbank):
return cls(filterbank, is_pinv=True)
elif isinstance(filterbank, Encoder):
Expand Down
2 changes: 1 addition & 1 deletion asteroid_filterbanks/griffin_lim.py
Original file line number Diff line number Diff line change
Expand Up @@ -216,7 +216,7 @@ def _mixture_consistency(
all_dims: List[int] = torch.arange(est_sources.ndim).tolist()
all_dims.pop(dim) # Remove source axis
all_dims.pop(0) # Remove batch axis
src_weights = torch.mean(est_sources ** 2, dim=all_dims, keepdim=True)
src_weights = torch.mean(est_sources**2, dim=all_dims, keepdim=True)
# Make sure that the weights sum up to 1
norm_weights = torch.sum(src_weights, dim=dim, keepdim=True) + 1e-8
src_weights = src_weights / norm_weights
Expand Down
6 changes: 3 additions & 3 deletions asteroid_filterbanks/multiphase_gammatone_fb.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ def generate_mpgtf(samplerate_hz, len_sec, n_filters):


def gammatone_impulse_response(samplerate_hz, len_sec, center_freq_hz, phase_shift):
""" Generate single parametrized gammatone filter """
"""Generate single parametrized gammatone filter"""
p = 2 # filter order
erb = 24.7 + 0.108 * center_freq_hz # equivalent rectangular bandwidth
divisor = (np.pi * np.math.factorial(2 * p - 2) * np.power(2, float(-(2 * p - 2)))) / np.square(
Expand All @@ -104,13 +104,13 @@ def gammatone_impulse_response(samplerate_hz, len_sec, center_freq_hz, phase_shi


def erb_scale_2_freq_hz(freq_erb):
""" Convert frequency on ERB scale to frequency in Hertz """
"""Convert frequency on ERB scale to frequency in Hertz"""
freq_hz = (np.exp(freq_erb / 9.265) - 1) * 24.7 * 9.265
return freq_hz


def freq_hz_2_erb_scale(freq_hz):
""" Convert frequency in Hertz to frequency on ERB scale """
"""Convert frequency in Hertz to frequency on ERB scale"""
freq_erb = 9.265 * np.log(1 + freq_hz / (24.7 * 9.265))
return freq_erb

Expand Down
6 changes: 3 additions & 3 deletions asteroid_filterbanks/param_sinc_fb.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ def __init__(
self.register_buffer("n_", n_)

def _initialize_filters(self):
""" Filter Initialization along the Mel scale"""
"""Filter Initialization along the Mel scale"""
low_hz = 30
high_hz = self.sample_rate / 2 - (self.min_low_hz + self.min_band_hz)
mel = np.linspace(
Expand All @@ -80,7 +80,7 @@ def _initialize_filters(self):
self.band_hz_ = nn.Parameter(torch.from_numpy(np.diff(hz)).view(-1, 1))

def filters(self):
""" Compute filters from parameters """
"""Compute filters from parameters"""
low = self.min_low_hz + torch.abs(self.low_hz_)
high = torch.clamp(
low + self.min_band_hz + torch.abs(self.band_hz_), self.min_low_hz, self.sample_rate / 2
Expand Down Expand Up @@ -116,7 +116,7 @@ def to_hz(mel):
return 700 * (10 ** (mel / 2595) - 1)

def get_config(self):
""" Returns dictionary of arguments to re-instantiate the class."""
"""Returns dictionary of arguments to re-instantiate the class."""
config = {
"min_low_hz": self.min_low_hz,
"min_band_hz": self.min_band_hz,
Expand Down
2 changes: 1 addition & 1 deletion asteroid_filterbanks/pcen.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,7 @@ def forward(
# Equation (1) in [1]
out = (
mag_spec / (self.floor + ema_smoother) ** alpha + self.delta
) ** one_over_root - self.delta ** one_over_root
) ** one_over_root - self.delta**one_over_root
out = out.transpose(1, -1)
if post_squeeze:
out = out.squeeze(1)
Expand Down
2 changes: 1 addition & 1 deletion asteroid_filterbanks/stft_fb.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ def perfect_synthesis_window(analysis_window, hop_size):

loop_on = (win_size - 1) // hop_size
for win_idx in range(-loop_on, loop_on + 1):
shifted = np.roll(analysis_window ** 2, win_idx * hop_size)
shifted = np.roll(analysis_window**2, win_idx * hop_size)
if win_idx < 0:
shifted[win_idx * hop_size :] = 0
elif win_idx > 0:
Expand Down
8 changes: 4 additions & 4 deletions asteroid_filterbanks/torch_stft_fb.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,15 +129,15 @@ def post_synthesis(self, wav):

@script_if_tracing
def _restore_freqs_an(spec, n_filters: int):
spec[..., 0, :] *= 2 ** 0.5
spec[..., n_filters // 2, :] *= 2 ** 0.5
spec[..., 0, :] *= 2**0.5
spec[..., n_filters // 2, :] *= 2**0.5
return spec


@script_if_tracing
def _restore_freqs_syn(spec, n_filters: int):
spec[..., 0, :] /= 2 ** 0.5
spec[..., n_filters // 2, :] /= 2 ** 0.5
spec[..., 0, :] /= 2**0.5
spec[..., n_filters // 2, :] /= 2**0.5
return spec


Expand Down
8 changes: 4 additions & 4 deletions tests/filterbanks_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ def fb_config_list():
@pytest.mark.parametrize("fb_class", [FreeFB, AnalyticFreeFB, ParamSincFB, MultiphaseGammatoneFB])
@pytest.mark.parametrize("fb_config", fb_config_list())
def test_fb_def_and_forward_lowdim(fb_class, fb_config):
""" Test filterbank definition and encoder/decoder forward."""
"""Test filterbank definition and encoder/decoder forward."""
# Definition
enc = Encoder(fb_class(**fb_config))
dec = Decoder(fb_class(**fb_config))
Expand All @@ -50,7 +50,7 @@ def test_fb_def_and_forward_lowdim(fb_class, fb_config):
@pytest.mark.parametrize("fb_class", [FreeFB, AnalyticFreeFB, ParamSincFB, MultiphaseGammatoneFB])
@pytest.mark.parametrize("fb_config", fb_config_list())
def test_fb_def_and_forward_all_dims(fb_class, fb_config):
""" Test encoder/decoder on other shapes than 3D"""
"""Test encoder/decoder on other shapes than 3D"""
# Definition
enc = Encoder(fb_class(**fb_config))
dec = Decoder(fb_class(**fb_config))
Expand All @@ -67,7 +67,7 @@ def test_fb_def_and_forward_all_dims(fb_class, fb_config):
@pytest.mark.parametrize("fb_config", fb_config_list())
@pytest.mark.parametrize("ndim", [2, 3, 4])
def test_fb_forward_multichannel(fb_class, fb_config, ndim):
""" Test encoder/decoder in multichannel setting"""
"""Test encoder/decoder in multichannel setting"""
# Definition
enc = Encoder(fb_class(**fb_config))
dec = Decoder(fb_class(**fb_config))
Expand All @@ -90,7 +90,7 @@ def test_complexfb_shapes(fb_class, n_filters, kernel_size):

@pytest.mark.parametrize("kernel_size", [256, 257, 128, 129])
def test_paramsinc_shape(kernel_size):
""" ParamSincFB has odd length filters """
"""ParamSincFB has odd length filters"""
fb = ParamSincFB(n_filters=200, kernel_size=kernel_size)
assert fb.filters().shape[-1] == 2 * (kernel_size // 2) + 1

Expand Down
6 changes: 3 additions & 3 deletions tests/stft_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ def fb_config_list():

@pytest.mark.parametrize("fb_config", fb_config_list())
def test_stft_def(fb_config):
""" Check consistency between two calls."""
"""Check consistency between two calls."""
fb = STFTFB(**fb_config)
enc = Encoder(fb)
dec = Decoder(fb)
Expand Down Expand Up @@ -60,7 +60,7 @@ def test_filter_shape(fb_config):

@pytest.mark.parametrize("fb_config", fb_config_list())
def test_perfect_istft_default_parameters(fb_config):
""" Unit test perfect reconstruction with default values. """
"""Unit test perfect reconstruction with default values."""
kernel_size = fb_config["kernel_size"]
enc, dec = make_enc_dec("stft", **fb_config)
inp_wav = torch.randn(2, 1, 32000)
Expand All @@ -75,7 +75,7 @@ def test_perfect_istft_default_parameters(fb_config):
)
@pytest.mark.parametrize("use_torch_window", [True, False])
def test_perfect_resyn_window(fb_config, analysis_window_name, use_torch_window):
""" Unit test perfect reconstruction """
"""Unit test perfect reconstruction"""
kernel_size = fb_config["kernel_size"]
window = get_window(analysis_window_name, kernel_size)
if use_torch_window:
Expand Down
2 changes: 1 addition & 1 deletion tests/torch_stft_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ def test_torch_stft(
window = None if window is None else get_window(window, win_length, fftbins=True)
if window is not None:
# Cannot restore the signal without overlap and near to zero window.
if hop_ratio == 1 and (window ** 2 < 1e-11).any():
if hop_ratio == 1 and (window**2 < 1e-11).any():
pass

fb = TorchSTFTFB.from_torch_args(
Expand Down
16 changes: 8 additions & 8 deletions tests/transforms_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ def make_encoder_from(fb_class, config):


def test_mag_mask(encoder_list):
""" Assert identity mask works. """
"""Assert identity mask works."""
for (enc, fb_dim) in encoder_list:
tf_rep = enc(torch.randn(2, 1, 8000)) # [batch, freq, time]
id_mag_mask = torch.ones((1, fb_dim // 2, 1))
Expand All @@ -48,7 +48,7 @@ def test_mag_mask(encoder_list):


def test_reim_mask(encoder_list):
""" Assert identity mask works. """
"""Assert identity mask works."""
for (enc, fb_dim) in encoder_list:
tf_rep = enc(torch.randn(2, 1, 8000)) # [batch, freq, time]
id_reim_mask = torch.ones((1, fb_dim, 1))
Expand All @@ -57,7 +57,7 @@ def test_reim_mask(encoder_list):


def test_comp_mask(encoder_list):
""" Assert identity mask works. """
"""Assert identity mask works."""
for (enc, fb_dim) in encoder_list:
tf_rep = enc(torch.randn(2, 1, 8000)) # [batch, freq, time]
id_complex_mask = torch.cat(
Expand Down Expand Up @@ -89,7 +89,7 @@ def test_cat(encoder_list):
)
@pytest.mark.parametrize("dim", [0, 1, 2])
def test_to_numpy(np_torch_tuple, dim):
""" Test torch --> np conversion (right angles)"""
"""Test torch --> np conversion (right angles)"""
from_np, from_torch = np_torch_tuple
if dim == 0:
np_array = np.array(from_np)
Expand All @@ -112,7 +112,7 @@ def test_to_numpy(np_torch_tuple, dim):
)
@pytest.mark.parametrize("dim", [0, 1, 2])
def test_from_numpy(np_torch_tuple, dim):
""" Test np --> torch conversion (right angles)"""
"""Test np --> torch conversion (right angles)"""
from_np, from_torch = np_torch_tuple
if dim == 0:
np_array = np.array(from_np)
Expand All @@ -131,7 +131,7 @@ def test_from_numpy(np_torch_tuple, dim):

@pytest.mark.parametrize("dim", [0, 1, 2, 3])
def test_return_ticket_np_torch(dim):
""" Test torch --> np --> torch --> np conversion"""
"""Test torch --> np --> torch --> np conversion"""
max_tested_ndim = 4
# Random tensor shape
tensor_shape = [random.randint(1, 10) for _ in range(max_tested_ndim)]
Expand All @@ -149,7 +149,7 @@ def test_return_ticket_np_torch(dim):

@pytest.mark.parametrize("dim", [0, 1, 2, 3])
def test_angle_mag_recompostion(dim):
""" Test complex --> (mag, angle) --> complex conversions"""
"""Test complex --> (mag, angle) --> complex conversions"""
max_tested_ndim = 4
# Random tensor shape
tensor_shape = [random.randint(1, 10) for _ in range(max_tested_ndim)]
Expand All @@ -164,7 +164,7 @@ def test_angle_mag_recompostion(dim):

@pytest.mark.parametrize("dim", [0, 1, 2, 3])
def test_check_complex_error(dim):
""" Test error in angle """
"""Test error in angle"""
not_complex = torch.randn(3, 5, 7, 9, 15)
with pytest.raises(AssertionError):
transforms.check_complex(not_complex, dim=dim)
Expand Down