Skip to content

Commit

Permalink
rewrite tests with pytest
Browse files Browse the repository at this point in the history
  • Loading branch information
r9y9 committed Jun 28, 2024
1 parent 91b0f10 commit 9cc39e4
Show file tree
Hide file tree
Showing 13 changed files with 121 additions and 133 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/ci.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -58,4 +58,4 @@ jobs:
pysen run lint
- name: Test
run: |
nosetests --with-coverage --cover-package=nnmnkwii -v -w tests/ -a '!require_local_data,!modspec'
pytest --cov=nnmnkwii --cov-report xml -v tests/
4 changes: 2 additions & 2 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,10 +159,10 @@ def package_files(directory):
ext_modules=ext_modules,
cmdclass=cmdclass,
install_requires=install_requires,
tests_require=["nose", "coverage"],
tests_require=["pytest", "coverage"],
extras_require={
"docs": ["numpydoc", "sphinx_rtd_theme"],
"test": ["nose", "pyworld", "librosa"],
"test": ["pytest", "pyworld", "librosa"],
"lint": [
"pysen",
"types-setuptools",
Expand Down
40 changes: 18 additions & 22 deletions tests/test_autograd.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,6 @@
from nnmnkwii import autograd as AF
from nnmnkwii import paramgen as G
from nnmnkwii.autograd._impl.mlpg import MLPG, UnitVarianceMLPG
from nnmnkwii.autograd._impl.modspec import ModSpec
from nose.plugins.attrib import attr
from torch import nn
from torch.autograd import gradcheck

Expand Down Expand Up @@ -220,23 +218,21 @@ def test_mlpg_variance_expand():
assert np.allclose(y.data.numpy(), y_hat.data.numpy())


@attr("modspec")
def test_modspec_gradcheck():
static_dim = 12
T = 16
torch.manual_seed(1234)
n = 16
for norm in [None, "ortho"]:
inputs = (torch.rand(T, static_dim, requires_grad=True), n, norm)
assert gradcheck(ModSpec.apply, inputs, eps=1e-4, atol=1e-4)


@attr("modspec")
def test_modspec_gradcheck_large_n():
static_dim = 12
T = 16
torch.manual_seed(1234)
for n in [16, 32]:
for norm in [None, "ortho"]:
inputs = (torch.rand(T, static_dim, requires_grad=True), n, norm)
assert gradcheck(ModSpec.apply, inputs, eps=1e-4, atol=1e-4)
# def test_modspec_gradcheck():
# static_dim = 12
# T = 16
# torch.manual_seed(1234)
# n = 16
# for norm in [None, "ortho"]:
# inputs = (torch.rand(T, static_dim, requires_grad=True), n, norm)
# assert gradcheck(ModSpec.apply, inputs, eps=1e-4, atol=1e-4)


# def test_modspec_gradcheck_large_n():
# static_dim = 12
# T = 16
# torch.manual_seed(1234)
# for n in [16, 32]:
# for norm in [None, "ortho"]:
# inputs = (torch.rand(T, static_dim, requires_grad=True), n, norm)
# assert gradcheck(ModSpec.apply, inputs, eps=1e-4, atol=1e-4)
3 changes: 0 additions & 3 deletions tests/test_baseline.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
from os.path import dirname, join

import numpy as np
from nose.plugins.attrib import attr
from numpy.linalg import norm
from sklearn.mixture import GaussianMixture

Expand Down Expand Up @@ -29,7 +28,6 @@ def _get_windows_set():
return windows_set


@attr("requires_bandmat")
def test_diffvc():
from nnmnkwii.baseline.gmm import MLPG

Expand Down Expand Up @@ -62,7 +60,6 @@ def test_diffvc():
assert norm(tgt_mc - mc_converted1) < norm(src_mc - mc_converted1)


@attr("requires_bandmat")
def test_gmmmap_swap():
from nnmnkwii.baseline.gmm import MLPG

Expand Down
38 changes: 17 additions & 21 deletions tests/test_datasets.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from os.path import dirname, join

import numpy as np
import pytest
from nnmnkwii.datasets import (
FileDataSource,
FileSourceDataset,
Expand All @@ -11,8 +12,6 @@
example_file_data_sources_for_acoustic_model,
example_file_data_sources_for_duration_model,
)
from nose.plugins.attrib import attr
from nose.tools import raises

DATA_DIR = join(dirname(__file__), "data")

Expand Down Expand Up @@ -45,7 +44,8 @@ def __test_outof_range(X):
print(X[0])

# Should raise IndexError
yield raises(IndexError)(__test_outof_range), X
with pytest.raises(IndexError):
__test_outof_range(X)


def test_invalid_dataset():
Expand All @@ -71,19 +71,19 @@ def __test_wrong_num_collected_files():
X = FileSourceDataset(WrongNumberOfCollectedFilesDataSource())
X[0]

yield raises(TypeError)(__test_wrong_num_args)
yield raises(ValueError)(__test_wrong_num_collected_files)
with pytest.raises(TypeError):
__test_wrong_num_args()
with pytest.raises(ValueError):
__test_wrong_num_collected_files()


@attr("pickle")
def test_asarray_tqdm():
# verbose=1 triggers tqdm progress report
for padded in [True, False]:
X, _ = _get_small_datasets(padded=padded, duration=True)
X.asarray(verbose=1)


@attr("pickle")
def test_asarray():
X, Y = _get_small_datasets(padded=False, duration=True)
lengths = [len(x) for x in X]
Expand All @@ -110,17 +110,16 @@ def __test_very_small_padded_length():
X.asarray(padded_length=1)

# Should raise `num frames exceeded`
yield raises(RuntimeError)(__test_very_small_padded_length)
with pytest.raises(RuntimeError):
__test_very_small_padded_length()


@attr("pickle")
def test_duration_sources():
X, Y = _get_small_datasets(padded=False, duration=True)
for idx, (x, y) in enumerate(zip(X, Y)):
print(idx, x.shape, y.shape)


@attr("pickle")
def test_slice():
X, _ = _get_small_datasets(padded=False)
x = X[:2]
Expand All @@ -133,14 +132,12 @@ def test_slice():
assert len(x.shape) == 3 and x.shape[0] == 2


@attr("pickle")
def test_variable_length_sequence_wise_iteration():
X, Y = _get_small_datasets(padded=False)
for idx, (x, y) in enumerate(zip(X, Y)):
print(idx, x.shape, y.shape)


@attr("pickle")
def test_fixed_length_sequence_wise_iteration():
X, Y = _get_small_datasets(padded=True)

Expand All @@ -153,7 +150,6 @@ def test_fixed_length_sequence_wise_iteration():
assert y.shape[0] == Ty


@attr("pickle")
def test_frame_wise_iteration():
X, Y = _get_small_datasets(padded=False)

Expand Down Expand Up @@ -181,7 +177,6 @@ def test_frame_wise_iteration():
pass


@attr("pickle")
def test_sequence_wise_torch_data_loader():
import torch
from torch.utils import data as data_utils
Expand Down Expand Up @@ -210,19 +205,20 @@ def __test(X, Y, batch_size):
print(idx, x.shape, y.shape)

# Test with batch_size = 1
yield __test, X, Y, 1
__test(X, Y, 1)
# Since we have variable length frames, batch size larger than 1 causes
# runtime error.
yield raises(RuntimeError)(__test), X, Y, 2
with pytest.raises(RuntimeError):
__test(X, Y, 2)

# For padded dataset, which can be reprensented by (N, T^max, D), batchsize
# can be any number.
X, Y = _get_small_datasets(padded=True)
yield __test, X, Y, 1
yield __test, X, Y, 2
__test(X, Y, 1)
__test(X, Y, 2)


@attr("pickle")
# @attr("pickle")
def test_frame_wise_torch_data_loader():
import torch
from torch.utils import data as data_utils
Expand Down Expand Up @@ -259,5 +255,5 @@ def __test(X, Y, batch_size):
assert len(x.shape) == 2
assert len(y.shape) == 2

yield __test, X, Y, 128
yield __test, X, Y, 256
__test(X, Y, 128)
__test(X, Y, 256)
17 changes: 10 additions & 7 deletions tests/test_frontend.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
from os.path import dirname, join

import numpy as np
import pytest
from nnmnkwii.frontend import merlin as fe
from nnmnkwii.io import hts
from nnmnkwii.util import example_label_file, example_question_file
from nose.tools import raises

DATA_DIR = join(dirname(__file__), "data")

Expand All @@ -14,7 +14,7 @@ def test_invalid_linguistic_features():
phone_labels = hts.load(example_label_file(phone_level=True))
state_labels = hts.load(example_label_file(phone_level=False))

@raises(ValueError)
# @raises(ValueError)
def __test(labels, subphone_features, add_frame_features):
fe.linguistic_features(
labels,
Expand All @@ -24,19 +24,22 @@ def __test(labels, subphone_features, add_frame_features):
add_frame_features=add_frame_features,
)

yield __test, phone_labels, "full", True
yield __test, phone_labels, "full", False
yield __test, state_labels, "full", False
with pytest.raises(ValueError):
__test(phone_labels, "full", True)
with pytest.raises(ValueError):
__test(phone_labels, "full", False)
with pytest.raises(ValueError):
__test(state_labels, "full", False)


def test_invalid_duration_features():
phone_labels = hts.load(example_label_file(phone_level=True))

@raises(ValueError)
def __test(labels, unit_size, feature_size):
fe.duration_features(labels, unit_size=unit_size, feature_size=feature_size)

yield __test, phone_labels, None, "frame"
with pytest.raises(ValueError):
__test(phone_labels, None, "frame")


def test_silence_frame_removal_given_hts_labels():
Expand Down
18 changes: 9 additions & 9 deletions tests/test_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,10 @@
import re
from os.path import dirname, join

import pytest
from nnmnkwii.frontend import merlin as fe
from nnmnkwii.io import hts
from nnmnkwii.util import example_question_file
from nose.tools import raises

try:
import pyopenjtalk # noqa
Expand Down Expand Up @@ -183,7 +183,6 @@ def test_hts_append():
labels.append(label)
assert str(test_labels) == str(labels)

@raises(ValueError)
def test_invalid_start_time():
labels = hts.HTSLabelFile()
labels.append((100000, 0, "NG"))
Expand All @@ -193,7 +192,6 @@ def test_succeeding_times():
labels.append((0, 1000000, "OK"))
labels.append((1000000, 2000000, "OK"))

@raises(ValueError)
def test_non_succeeding_times():
labels = hts.HTSLabelFile()
labels.append((0, 1000000, "OK"))
Expand All @@ -204,9 +202,11 @@ def test_non_succeeding_times_wo_strict():
labels.append((0, 1000000, "OK"), strict=False)
labels.append((1500000, 2000000, "OK"), strict=False)

test_invalid_start_time()
with pytest.raises(ValueError):
test_invalid_start_time()
test_succeeding_times()
test_non_succeeding_times()
with pytest.raises(ValueError):
test_non_succeeding_times()
test_non_succeeding_times_wo_strict()


Expand All @@ -227,20 +227,20 @@ def test_create_from_contexts():
labels2 = hts.HTSLabelFile.create_from_contexts(contexts)
assert str(labels), str(labels2)

@raises(ValueError)
def test_empty_context():
hts.HTSLabelFile.create_from_contexts("")

@raises(ValueError)
def test_empty_context2():
contexts = pyopenjtalk.extract_fullcontext("")
hts.HTSLabelFile.create_from_contexts(contexts)

test_empty_context()
with pytest.raises(ValueError):
test_empty_context()
try:
import pyopenjtalk # noqa

test_empty_context2()
with pytest.raises(ValueError):
test_empty_context2()
except ImportError:
pass

Expand Down
4 changes: 2 additions & 2 deletions tests/test_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,8 @@ def __test(f):
np.testing.assert_almost_equal(f(x, y, lengths), f(x, y), decimal=5)
assert f(x, y) > 0

yield __test, metrics.melcd
yield __test, metrics.mean_squared_error
__test(metrics.melcd)
__test(metrics.mean_squared_error)


def test_f0_mse():
Expand Down
5 changes: 0 additions & 5 deletions tests/test_paramgen.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import numpy as np
from nose.plugins.attrib import attr


def _get_windows_set():
Expand Down Expand Up @@ -29,7 +28,6 @@ def _get_windows_set():
return windows_set


@attr("requires_bandmat")
def test_mlpg():
from nnmnkwii import paramgen as G

Expand Down Expand Up @@ -61,7 +59,6 @@ def test_mlpg():
assert np.allclose(generated1, generated2)


@attr("requires_bandmat")
def test_mlpg_window_full():
from nnmnkwii import paramgen as G

Expand All @@ -82,7 +79,6 @@ def full_window_mat_native(win_mats, T):
assert np.allclose(full_window_mat_native(win_mats, T), fullwin)


@attr("requires_bandmat")
def test_unit_variance_mlpg():
from nnmnkwii import paramgen as G

Expand All @@ -99,7 +95,6 @@ def test_unit_variance_mlpg():
assert np.allclose(y_hat, y)


@attr("requires_bandmat")
def test_reshape_means():
from nnmnkwii import paramgen as G

Expand Down
Loading

0 comments on commit 9cc39e4

Please sign in to comment.