diff --git a/azure-pipelines.yml b/azure-pipelines.yml index 1699a0d88..3da97175e 100644 --- a/azure-pipelines.yml +++ b/azure-pipelines.yml @@ -35,6 +35,7 @@ jobs: PYTHON_VERSION: '3.8' COVERAGE: 'true' PANDAS_VERSION: '*' + DASK_VERSION: '*' TEST_DOCSTRINGS: 'true' JOBLIB_VERSION: '*' CHECK_WARNINGS: 'true' @@ -43,6 +44,7 @@ jobs: PYTHON_VERSION: '3.7' INSTALL_MKL: 'true' PANDAS_VERSION: '*' + DASK_VERSION: '*' KERAS_VERSION: '*' COVERAGE: 'true' JOBLIB_VERSION: '*' @@ -51,6 +53,7 @@ jobs: DISTRIB: 'conda' PYTHON_VERSION: '3.8' PANDAS_VERSION: '*' + DASK_VERSION: '*' JOBLIB_VERSION: '*' INSTALL_MKL: 'true' TENSORFLOW_VERSION: '*' diff --git a/build_tools/azure/install.sh b/build_tools/azure/install.sh index 79c5d5814..dba7754a6 100755 --- a/build_tools/azure/install.sh +++ b/build_tools/azure/install.sh @@ -40,6 +40,10 @@ if [[ "$DISTRIB" == "conda" ]]; then TO_INSTALL="$TO_INSTALL pandas=$PANDAS_VERSION" fi + if [[ -n "$DASK_VERSION" ]]; then + TO_INSTALL="$TO_INSTALL dask=$DASK_VERSION" + fi + if [[ -n "$KERAS_VERSION" ]]; then TO_INSTALL="$TO_INSTALL keras=$KERAS_VERSION tensorflow=1" KERAS_BACKEND=tensorflow @@ -90,9 +94,10 @@ elif [[ "$DISTRIB" == "conda-pip-latest" ]]; then make_conda "python=$PYTHON_VERSION" python -m pip install -U pip python -m pip install numpy scipy joblib cython + python -m pip install pandas + python -m pip install "dask[complete]" python -m pip install scikit-learn python -m pip install pytest==$PYTEST_VERSION pytest-cov pytest-xdist - python -m pip install pandas fi if [[ "$COVERAGE" == "true" ]]; then diff --git a/conftest.py b/conftest.py index d3ff91025..72e6a23da 100644 --- a/conftest.py +++ b/conftest.py @@ -22,12 +22,12 @@ def pytest_runtest_setup(item): if (fname.endswith(os.path.join('keras', '_generator.py')) or fname.endswith('miscellaneous.rst')): try: - import keras + import keras # noqa except ImportError: pytest.skip('The keras package is not installed.') elif (fname.endswith(os.path.join('tensorflow', '_generator.py')) or fname.endswith('miscellaneous.rst')): try: - import tensorflow + import tensorflow # noqa except ImportError: pytest.skip('The tensorflow package is not installed.') diff --git a/doc/api.rst b/doc/api.rst index 07ac6413c..65bfd1b06 100644 --- a/doc/api.rst +++ b/doc/api.rst @@ -248,6 +248,6 @@ Imbalance-learn provides some fast-prototyping tools. :toctree: generated/ :template: function.rst - utils.estimator_checks.parametrize_with_checks utils.check_neighbors_object utils.check_sampling_strategy + utils.get_classes_counts diff --git a/imblearn/base.py b/imblearn/base.py index 86bb53778..4d69f5461 100644 --- a/imblearn/base.py +++ b/imblearn/base.py @@ -9,12 +9,18 @@ import numpy as np from sklearn.base import BaseEstimator -from sklearn.preprocessing import label_binarize -from sklearn.utils.multiclass import check_classification_targets +from .dask._support import is_dask_collection from .utils import check_sampling_strategy, check_target_type -from .utils._validation import ArraysTransformer -from .utils._validation import _deprecate_positional_args +from .utils._validation import ( + ArraysTransformer, + _deprecate_positional_args, + get_classes_counts, +) +from .utils.wrapper import ( + check_classification_targets, + label_binarize, +) class SamplerMixin(BaseEstimator, metaclass=ABCMeta): @@ -45,9 +51,18 @@ def fit(self, X, y): self : object Return the instance itself. """ - X, y, _ = self._check_X_y(X, y) + arrays_transformer = ArraysTransformer(X, y) + dask_collection = any([is_dask_collection(arr) for arr in (X, y)]) + if dask_collection: + X, y = arrays_transformer.to_dask_array(X, y) + + if (not dask_collection or + (dask_collection and self.validate_if_dask_collection)): + X, y, _ = self._check_X_y(X, y) + + self._classes_counts = get_classes_counts(y) self.sampling_strategy_ = check_sampling_strategy( - self.sampling_strategy, y, self._sampling_type + self.sampling_strategy, self._classes_counts, self._sampling_type ) return self @@ -72,18 +87,31 @@ def fit_resample(self, X, y): y_resampled : array-like of shape (n_samples_new,) The corresponding label of `X_resampled`. """ - check_classification_targets(y) arrays_transformer = ArraysTransformer(X, y) - X, y, binarize_y = self._check_X_y(X, y) + dask_collection = any([is_dask_collection(arr) for arr in (X, y)]) + if dask_collection: + X, y = arrays_transformer.to_dask_array(X, y) + if (not dask_collection or + (dask_collection and self.validate_if_dask_collection)): + check_classification_targets(y) + X, y, binarize_y = self._check_X_y(X, y) + else: + binarize_y = False + + self._classes_counts = get_classes_counts(y) self.sampling_strategy_ = check_sampling_strategy( - self.sampling_strategy, y, self._sampling_type + self.sampling_strategy, self._classes_counts, self._sampling_type ) output = self._fit_resample(X, y) - y_ = (label_binarize(output[1], np.unique(y)) - if binarize_y else output[1]) + if binarize_y: + y_ = label_binarize( + output[1], classes=list(self._classes_counts.keys()) + ) + else: + y_ = output[1] X_, y_ = arrays_transformer.transform(output[0], y_) return (X_, y_) if len(output) == 2 else (X_, y_, output[2]) @@ -124,8 +152,13 @@ class BaseSampler(SamplerMixin): instead. """ - def __init__(self, sampling_strategy="auto"): + def __init__( + self, + sampling_strategy="auto", + validate_if_dask_collection=False, + ): self.sampling_strategy = sampling_strategy + self.validate_if_dask_collection = validate_if_dask_collection def _check_X_y(self, X, y, accept_sparse=None): if accept_sparse is None: @@ -251,16 +284,20 @@ def fit_resample(self, X, y): X, y, accept_sparse=self.accept_sparse ) + self._classes_counts = get_classes_counts(y) self.sampling_strategy_ = check_sampling_strategy( - self.sampling_strategy, y, self._sampling_type + self.sampling_strategy, self._classes_counts, self._sampling_type ) output = self._fit_resample(X, y) if self.validate: - - y_ = (label_binarize(output[1], np.unique(y)) - if binarize_y else output[1]) + if binarize_y: + y_ = label_binarize( + output[1], classes=list(self._classes_counts.keys()) + ) + else: + y_ = output[1] X_, y_ = arrays_transformer.transform(output[0], y_) return (X_, y_) if len(output) == 2 else (X_, y_, output[2]) diff --git a/imblearn/dask/__init__.py b/imblearn/dask/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/imblearn/dask/_support.py b/imblearn/dask/_support.py new file mode 100644 index 000000000..b5239ccac --- /dev/null +++ b/imblearn/dask/_support.py @@ -0,0 +1,9 @@ +def is_dask_collection(container): + try: + # to keep dask as an optional depency, keep the statement in a + # try/except statement + from dask import is_dask_collection + + return is_dask_collection(container) + except ImportError: + return False diff --git a/imblearn/dask/preprocessing.py b/imblearn/dask/preprocessing.py new file mode 100644 index 000000000..3a79fe576 --- /dev/null +++ b/imblearn/dask/preprocessing.py @@ -0,0 +1,7 @@ +def label_binarize(y, *, classes): + import pandas as pd + from dask import dataframe + + cat_dtype = pd.CategoricalDtype(categories=classes) + y = dataframe.from_array(y).astype(cat_dtype) + return dataframe.get_dummies(y).to_dask_array() diff --git a/imblearn/dask/tests/__init__.py b/imblearn/dask/tests/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/imblearn/dask/tests/test_utils.py b/imblearn/dask/tests/test_utils.py new file mode 100644 index 000000000..0a262a435 --- /dev/null +++ b/imblearn/dask/tests/test_utils.py @@ -0,0 +1,40 @@ +import numpy as np +import pytest + +dask = pytest.importorskip("dask") +from dask import array + +from imblearn.dask.utils import is_multilabel +from imblearn.dask.utils import type_of_target + + +@pytest.mark.parametrize( + "y, expected_result", + [ + (array.from_array(np.array([0, 1, 0, 1])), False), + (array.from_array(np.array([[1, 0], [0, 0]])), True), + (array.from_array(np.array([[1], [0], [0]])), False), + (array.from_array(np.array([[1, 0, 0]])), True), + ] +) +def test_is_multilabel(y, expected_result): + assert is_multilabel(y) is expected_result + + +@pytest.mark.parametrize( + "y, expected_type_of_target", + [ + (array.from_array(np.array([[1, 0], [0, 0]])), "multilabel-indicator"), + (array.from_array(np.array([[1, 0, 0]])), "multilabel-indicator"), + (array.from_array(np.array([[[1, 2]]])), "unknown"), + (array.from_array(np.array([[]])), "unknown"), + (array.from_array(np.array([.1, .2, 3])), "continuous"), + (array.from_array(np.array([[.1, .2, 3]])), "continuous-multioutput"), + (array.from_array(np.array([[1., .2]])), "continuous-multioutput"), + (array.from_array(np.array([1, 2])), "binary"), + (array.from_array(np.array(["a", "b"])), "binary"), + ] +) +def test_type_of_target(y, expected_type_of_target): + target_type = type_of_target(y) + assert target_type == expected_type_of_target diff --git a/imblearn/dask/utils.py b/imblearn/dask/utils.py new file mode 100644 index 000000000..814f9ce81 --- /dev/null +++ b/imblearn/dask/utils.py @@ -0,0 +1,78 @@ +import warnings + +import numpy as np +from sklearn.exceptions import DataConversionWarning +from sklearn.utils.multiclass import _is_integral_float + + +def is_multilabel(y): + if not (y.ndim == 2 and y.shape[1] > 1): + return False + + if hasattr(y, "unique"): + labels = np.asarray(y.unique()) + else: + labels = np.unique(y).compute() + + return len(labels) < 3 and ( + y.dtype.kind in 'biu' or _is_integral_float(labels) + ) + + +def type_of_target(y): + if is_multilabel(y): + return 'multilabel-indicator' + + if y.ndim > 2: + return 'unknown' + + if y.ndim == 2 and y.shape[1] == 0: + return 'unknown' # [[]] + + if y.ndim == 2 and y.shape[1] > 1: + # [[1, 2], [1, 2]] + suffix = "-multioutput" + else: + # [1, 2, 3] or [[1], [2], [3]] + suffix = "" + + # check float and contains non-integer float values + if y.dtype.kind == 'f' and np.any(y != y.astype(int)): + # [.1, .2, 3] or [[.1, .2, 3]] or [[1., .2]] and not [1., 2., 3.] + # NOTE: we don't check for infinite values + return 'continuous' + suffix + + if hasattr(y, "unique"): + labels = np.asarray(y.unique()) + else: + labels = np.unique(y).compute() + if (len((labels)) > 2) or (y.ndim >= 2 and len(y[0]) > 1): + # [1, 2, 3] or [[1., 2., 3]] or [[1, 2]] + return 'multiclass' + suffix + # [1, 2] or [["a"], ["b"]] + return 'binary' + + +def column_or_1d(y, *, warn=False): + shape = y.shape + if len(shape) == 1: + return y.ravel() + if len(shape) == 2 and shape[1] == 1: + if warn: + warnings.warn( + "A column-vector y was passed when a 1d array was expected. " + "Please change the shape of y to (n_samples, ), for example " + "using ravel().", DataConversionWarning, stacklevel=2 + ) + return y.ravel() + + raise ValueError( + f"y should be a 1d array. Got an array of shape {shape} instead." + ) + + +def check_classification_targets(y): + y_type = type_of_target(y) + if y_type not in ['binary', 'multiclass', 'multiclass-multioutput', + 'multilabel-indicator', 'multilabel-sequences']: + raise ValueError("Unknown label type: %r" % y_type) diff --git a/imblearn/datasets/_imbalance.py b/imblearn/datasets/_imbalance.py index b35d00ed2..77a2f64d3 100644 --- a/imblearn/datasets/_imbalance.py +++ b/imblearn/datasets/_imbalance.py @@ -9,7 +9,10 @@ from ..under_sampling import RandomUnderSampler from ..utils import check_sampling_strategy -from ..utils._validation import _deprecate_positional_args +from ..utils._validation import ( + _deprecate_positional_args, + get_classes_counts, +) @_deprecate_positional_args @@ -87,11 +90,11 @@ def make_imbalance( >>> print('Distribution after imbalancing: {}'.format(Counter(y_res))) Distribution after imbalancing: Counter({2: 30, 1: 20, 0: 10}) """ - target_stats = Counter(y) + target_stats = get_classes_counts(y) # restrict ratio to be a dict or a callable if isinstance(sampling_strategy, dict) or callable(sampling_strategy): sampling_strategy_ = check_sampling_strategy( - sampling_strategy, y, "under-sampling", **kwargs + sampling_strategy, target_stats, "under-sampling", **kwargs ) else: raise ValueError( diff --git a/imblearn/ensemble/_bagging.py b/imblearn/ensemble/_bagging.py index d7c509194..c7107661e 100644 --- a/imblearn/ensemble/_bagging.py +++ b/imblearn/ensemble/_bagging.py @@ -18,7 +18,10 @@ from ..utils import Substitution, check_target_type, check_sampling_strategy from ..utils._docstring import _n_jobs_docstring from ..utils._docstring import _random_state_docstring -from ..utils._validation import _deprecate_positional_args +from ..utils._validation import ( + _deprecate_positional_args, + get_classes_counts, +) @Substitution( @@ -216,11 +219,12 @@ def __init__( def _validate_y(self, y): y_encoded = super()._validate_y(y) + classes_counts = get_classes_counts(y) if isinstance(self.sampling_strategy, dict): self._sampling_strategy = { np.where(self.classes_ == key)[0][0]: value for key, value in check_sampling_strategy( - self.sampling_strategy, y, 'under-sampling', + self.sampling_strategy, classes_counts, 'under-sampling', ).items() } else: diff --git a/imblearn/ensemble/_easy_ensemble.py b/imblearn/ensemble/_easy_ensemble.py index f140120aa..4db266134 100644 --- a/imblearn/ensemble/_easy_ensemble.py +++ b/imblearn/ensemble/_easy_ensemble.py @@ -17,7 +17,10 @@ from ..utils import Substitution, check_target_type, check_sampling_strategy from ..utils._docstring import _n_jobs_docstring from ..utils._docstring import _random_state_docstring -from ..utils._validation import _deprecate_positional_args +from ..utils._validation import ( + _deprecate_positional_args, + get_classes_counts, +) from ..pipeline import Pipeline MAX_INT = np.iinfo(np.int32).max @@ -156,11 +159,14 @@ def __init__( def _validate_y(self, y): y_encoded = super()._validate_y(y) + classes_counts = get_classes_counts(y) if isinstance(self.sampling_strategy, dict): self._sampling_strategy = { np.where(self.classes_ == key)[0][0]: value for key, value in check_sampling_strategy( - self.sampling_strategy, y, 'under-sampling', + self.sampling_strategy, + classes_counts, + "under-sampling", ).items() } else: diff --git a/imblearn/ensemble/_forest.py b/imblearn/ensemble/_forest.py index 42ae9b255..5832628c8 100644 --- a/imblearn/ensemble/_forest.py +++ b/imblearn/ensemble/_forest.py @@ -33,8 +33,11 @@ from ..utils import Substitution from ..utils._docstring import _n_jobs_docstring from ..utils._docstring import _random_state_docstring -from ..utils._validation import check_sampling_strategy -from ..utils._validation import _deprecate_positional_args +from ..utils._validation import ( + check_sampling_strategy, + _deprecate_positional_args, + get_classes_counts, +) MAX_INT = np.iinfo(np.int32).max @@ -457,10 +460,11 @@ def fit(self, X, y, sample_weight=None): y_encoded = np.ascontiguousarray(y_encoded, dtype=DOUBLE) if isinstance(self.sampling_strategy, dict): + classes_counts = get_classes_counts(y) self._sampling_strategy = { np.where(self.classes_[0] == key)[0][0]: value for key, value in check_sampling_strategy( - self.sampling_strategy, y, 'under-sampling', + self.sampling_strategy, classes_counts, 'under-sampling', ).items() } else: diff --git a/imblearn/ensemble/tests/test_weight_boosting.py b/imblearn/ensemble/tests/test_weight_boosting.py index 26facce90..517f61f40 100644 --- a/imblearn/ensemble/tests/test_weight_boosting.py +++ b/imblearn/ensemble/tests/test_weight_boosting.py @@ -77,7 +77,7 @@ def test_rusboost(imbalanced_dataset, algorithm): assert rusboost.decision_function(X_test).shape[1] == len(classes) score = rusboost.score(X_test, y_test) - assert score > 0.7, "Failed with algorithm {} and score {}".format( + assert score > 0.65, "Failed with algorithm {} and score {}".format( algorithm, score ) diff --git a/imblearn/under_sampling/_prototype_selection/_random_under_sampler.py b/imblearn/under_sampling/_prototype_selection/_random_under_sampler.py index e34d4e73d..28ef02d88 100644 --- a/imblearn/under_sampling/_prototype_selection/_random_under_sampler.py +++ b/imblearn/under_sampling/_prototype_selection/_random_under_sampler.py @@ -10,15 +10,20 @@ from sklearn.utils import _safe_indexing from ..base import BaseUnderSampler +from ...dask._support import is_dask_collection from ...utils import check_target_type from ...utils import Substitution -from ...utils._docstring import _random_state_docstring +from ...utils._docstring import ( + _random_state_docstring, + _validate_if_dask_collection_docstring +) from ...utils._validation import _deprecate_positional_args @Substitution( sampling_strategy=BaseUnderSampler._sampling_strategy_docstring, random_state=_random_state_docstring, + validate_if_dask_collection=_validate_if_dask_collection_docstring, ) class RandomUnderSampler(BaseUnderSampler): """Class to perform random under-sampling. @@ -37,6 +42,8 @@ class RandomUnderSampler(BaseUnderSampler): replacement : bool, default=False Whether the sample is with or without replacement. + {validate_if_dask_collection} + Attributes ---------- sample_indices_ : ndarray of shape (n_new_samples,) @@ -73,51 +80,80 @@ class RandomUnderSampler(BaseUnderSampler): @_deprecate_positional_args def __init__( - self, *, sampling_strategy="auto", random_state=None, replacement=False + self, + *, + sampling_strategy="auto", + random_state=None, + replacement=False, + validate_if_dask_collection=False, ): - super().__init__(sampling_strategy=sampling_strategy) + super().__init__( + sampling_strategy=sampling_strategy, + validate_if_dask_collection=validate_if_dask_collection, + ) self.random_state = random_state self.replacement = replacement def _check_X_y(self, X, y): y, binarize_y = check_target_type(y, indicate_one_vs_all=True) - X, y = self._validate_data( - X, y, reset=True, accept_sparse=["csr", "csc"], dtype=None, - force_all_finite=False, - ) + if not any([is_dask_collection(arr) for arr in (X, y)]): + X, y = self._validate_data( + X, + y, + reset=True, + accept_sparse=["csr", "csc"], + dtype=None, + force_all_finite=False, + ) return X, y, binarize_y + @staticmethod + def _find_target_class_indices(y, target_class): + target_class_indices = np.flatnonzero(y == target_class) + if is_dask_collection(y): + from dask import compute + + return compute(target_class_indices)[0] + return target_class_indices + def _fit_resample(self, X, y): random_state = check_random_state(self.random_state) - idx_under = np.empty((0,), dtype=int) + idx_under = [] - for target_class in np.unique(y): + for target_class in self._classes_counts: + target_class_indices = self._find_target_class_indices( + y, target_class + ) if target_class in self.sampling_strategy_.keys(): n_samples = self.sampling_strategy_[target_class] index_target_class = random_state.choice( - range(np.count_nonzero(y == target_class)), + target_class_indices.size, size=n_samples, replace=self.replacement, ) else: index_target_class = slice(None) - idx_under = np.concatenate( - ( - idx_under, - np.flatnonzero(y == target_class)[index_target_class], - ), - axis=0, - ) + selected_indices = target_class_indices[index_target_class] + idx_under.append(selected_indices) - self.sample_indices_ = idx_under + self.sample_indices_ = np.hstack(idx_under) + self.sample_indices_.sort() - return _safe_indexing(X, idx_under), _safe_indexing(y, idx_under) + return ( + _safe_indexing(X, self.sample_indices_), + _safe_indexing(y, self.sample_indices_) + ) def _more_tags(self): return { - "X_types": ["2darray", "string"], + "X_types": [ + "2darray", + "string", + "dask-array", + "dask-dataframe" + ], "sample_indices": True, "allow_nan": True, } diff --git a/imblearn/under_sampling/_prototype_selection/tests/test_random_under_sampler.py b/imblearn/under_sampling/_prototype_selection/tests/test_random_under_sampler.py index 945d31fec..355273dc1 100644 --- a/imblearn/under_sampling/_prototype_selection/tests/test_random_under_sampler.py +++ b/imblearn/under_sampling/_prototype_selection/tests/test_random_under_sampler.py @@ -30,61 +30,27 @@ Y = np.array([1, 0, 1, 0, 1, 1, 1, 1, 0, 1]) -@pytest.mark.parametrize("as_frame", [True, False], ids=['dataframe', 'array']) -def test_rus_fit_resample(as_frame): - if as_frame: - pd = pytest.importorskip("pandas") - X_ = pd.DataFrame(X) - else: - X_ = X - rus = RandomUnderSampler(random_state=RND_SEED, replacement=True) - X_resampled, y_resampled = rus.fit_resample(X_, Y) - - X_gt = np.array( - [ - [0.92923648, 0.76103773], - [0.47104475, 0.44386323], - [0.13347175, 0.12167502], - [0.09125309, -0.85409574], - [0.12372842, 0.6536186], - [0.04352327, -0.20515826], - ] - ) - y_gt = np.array([0, 0, 0, 1, 1, 1]) - - if as_frame: - assert hasattr(X_resampled, "loc") - X_resampled = X_resampled.to_numpy() - - assert_array_equal(X_resampled, X_gt) - assert_array_equal(y_resampled, y_gt) - - -def test_rus_fit_resample_half(): - sampling_strategy = {0: 3, 1: 6} - rus = RandomUnderSampler( - sampling_strategy=sampling_strategy, - random_state=RND_SEED, - replacement=True, - ) - X_resampled, y_resampled = rus.fit_resample(X, Y) - - X_gt = np.array( - [ - [0.92923648, 0.76103773], - [0.47104475, 0.44386323], - [0.92923648, 0.76103773], - [0.15490546, 0.3130677], - [0.15490546, 0.3130677], - [0.15490546, 0.3130677], - [0.20792588, 1.49407907], - [0.15490546, 0.3130677], - [0.12372842, 0.6536186], - ] - ) - y_gt = np.array([0, 0, 0, 1, 1, 1, 1, 1, 1]) - assert_array_equal(X_resampled, X_gt) - assert_array_equal(y_resampled, y_gt) +@pytest.mark.parametrize( + "sampling_strategy, expected_counts", + [ + ("auto", {0: 3, 1: 3}), + ({0: 3, 1: 6}, {0: 3, 1: 6}), + ] +) +def test_rus_fit_resample(sampling_strategy, expected_counts): + rus = RandomUnderSampler(sampling_strategy=sampling_strategy) + X_res, y_res = rus.fit_resample(X, Y) + + # check that there is not samples from class 0 resampled as class 1 and + # vice-versa + classes = [0, 1] + for c0, c1 in (classes, classes[::-1]): + X_c0 = X[Y == c0] + X_c1 = X_res[y_res == c1] + for s0 in X_c0: + assert not np.isclose(s0, X_c1).all(axis=1).any() + + assert Counter(y_res) == expected_counts def test_multiclass_fit_resample(): diff --git a/imblearn/utils/__init__.py b/imblearn/utils/__init__.py index 4e74d2ee3..130d9f0c9 100644 --- a/imblearn/utils/__init__.py +++ b/imblearn/utils/__init__.py @@ -7,10 +7,12 @@ from ._validation import check_neighbors_object from ._validation import check_target_type from ._validation import check_sampling_strategy +from ._validation import get_classes_counts __all__ = [ "check_neighbors_object", "check_sampling_strategy", "check_target_type", + "get_classes_counts", "Substitution", ] diff --git a/imblearn/utils/_docstring.py b/imblearn/utils/_docstring.py index d03be3740..be94b1aac 100644 --- a/imblearn/utils/_docstring.py +++ b/imblearn/utils/_docstring.py @@ -41,3 +41,10 @@ def __call__(self, obj): `Glossary `_ for more details. """.rstrip() + +_validate_if_dask_collection_docstring = \ + """validate_if_dask_collection : bool, default=False + Whether or not `X` and `y` should be validated. This parameter applies + only when `X` and `y` are Dask collections where validation might be + potentially costly. + """.rstrip() diff --git a/imblearn/utils/_validation.py b/imblearn/utils/_validation.py index fdc67619e..8538d7718 100644 --- a/imblearn/utils/_validation.py +++ b/imblearn/utils/_validation.py @@ -14,10 +14,12 @@ from sklearn.base import clone from sklearn.neighbors._base import KNeighborsMixin from sklearn.neighbors import NearestNeighbors -from sklearn.utils import column_or_1d -from sklearn.utils.multiclass import type_of_target +from ..dask._support import is_dask_collection from ..exceptions import raise_isinstance_error +from .wrapper import _is_multiclass_encoded +from .wrapper import column_or_1d +from .wrapper import type_of_target SAMPLING_KIND = ( "over-sampling", @@ -36,6 +38,16 @@ def __init__(self, X, y): self.x_props = self._gets_props(X) self.y_props = self._gets_props(y) + @staticmethod + def to_dask_array(X, y): + if hasattr(X, "to_dask_array"): + X = X.to_dask_array() + X.compute_chunk_sizes() + if hasattr(y, "to_dask_array"): + y = y.to_dask_array() + y.compute_chunk_sizes() + return X, y + def transform(self, X, y): X = self._transfrom_one(X, self.x_props) y = self._transfrom_one(y, self.y_props) @@ -44,6 +56,9 @@ def transform(self, X, y): def _gets_props(self, array): props = {} props["type"] = array.__class__.__name__ + if props["type"].lower() in ("series", "dataframe"): + suffix = "dask-" if is_dask_collection(array) else "pandas-" + props["type"] = suffix + props["type"] props["columns"] = getattr(array, "columns", None) props["name"] = getattr(array, "name", None) props["dtypes"] = getattr(array, "dtypes", None) @@ -53,13 +68,34 @@ def _transfrom_one(self, array, props): type_ = props["type"].lower() if type_ == "list": ret = array.tolist() - elif type_ == "dataframe": + elif type_ == "pandas-dataframe": import pandas as pd + ret = pd.DataFrame(array, columns=props["columns"]) ret = ret.astype(props["dtypes"]) - elif type_ == "series": + elif type_ == "pandas-series": import pandas as pd + ret = pd.Series(array, dtype=props["dtypes"], name=props["name"]) + elif type_ == "dask-dataframe": + from dask import dataframe + + if is_dask_collection(array): + ret = dataframe.from_dask_array( + array, columns=props["columns"] + ) + else: + ret = dataframe.from_array(array, columns=props["columns"]) + ret = ret.astype(props["dtypes"]) + elif type_ == "dask-series": + from dask import dataframe + + if is_dask_collection(array): + ret = dataframe.from_dask_array(array) + else: + ret = dataframe.from_array(array) + ret = ret.astype(props["dtypes"]) + ret = ret.rename(props["name"]) else: ret = array return ret @@ -97,8 +133,25 @@ def check_neighbors_object(nn_name, nn_object, additional_neighbor=0): raise_isinstance_error(nn_name, [int, KNeighborsMixin], nn_object) -def _count_class_sample(y): +def get_classes_counts(y): + """Compute the counts of each class present in `y`. + + Parameters + ---------- + y : ndarray of shape (n_samples,) + The target array. + + Returns + ------- + classes_counts : dict + A dictionary where the keys are the class labels and the values are the + counts for each class. + """ unique, counts = np.unique(y, return_counts=True) + if is_dask_collection(unique): + from dask import compute + + unique, counts = compute(unique, counts) return dict(zip(unique, counts)) @@ -124,10 +177,13 @@ def check_target_type(y, indicate_one_vs_all=False): is_one_vs_all : bool, optional Indicate if the target was originally encoded in a one-vs-all fashion. Only returned if ``indicate_multilabel=True``. + + y_unique : ndarray + The unique values in `y`. """ type_y = type_of_target(y) if type_y == "multilabel-indicator": - if np.any(y.sum(axis=1) > 1): + if not _is_multiclass_encoded(y): raise ValueError( "Imbalanced-learn currently supports binary, multiclass and " "binarized encoded multiclasss targets. Multilabel and " @@ -137,24 +193,27 @@ def check_target_type(y, indicate_one_vs_all=False): else: y = column_or_1d(y) - return (y, type_y == "multilabel-indicator") if indicate_one_vs_all else y + output = [y] + if indicate_one_vs_all: + output += [type_y == "multilabel-indicator"] + + return output[0] if len(output) == 1 else tuple(output) -def _sampling_strategy_all(y, sampling_type): +def _sampling_strategy_all(classes_counts, sampling_type): """Returns sampling target by targeting all classes.""" - target_stats = _count_class_sample(y) if sampling_type == "over-sampling": - n_sample_majority = max(target_stats.values()) + n_sample_majority = max(classes_counts.values()) sampling_strategy = { key: n_sample_majority - value - for (key, value) in target_stats.items() + for (key, value) in classes_counts.items() } elif ( sampling_type == "under-sampling" or sampling_type == "clean-sampling" ): - n_sample_minority = min(target_stats.values()) + n_sample_minority = min(classes_counts.values()) sampling_strategy = { - key: n_sample_minority for key in target_stats.keys() + key: n_sample_minority for key in classes_counts.keys() } else: raise NotImplementedError @@ -162,7 +221,7 @@ def _sampling_strategy_all(y, sampling_type): return sampling_strategy -def _sampling_strategy_majority(y, sampling_type): +def _sampling_strategy_majority(classes_counts, sampling_type): """Returns sampling target by targeting the majority class only.""" if sampling_type == "over-sampling": raise ValueError( @@ -172,12 +231,11 @@ def _sampling_strategy_majority(y, sampling_type): elif ( sampling_type == "under-sampling" or sampling_type == "clean-sampling" ): - target_stats = _count_class_sample(y) - class_majority = max(target_stats, key=target_stats.get) - n_sample_minority = min(target_stats.values()) + class_majority = max(classes_counts, key=classes_counts.get) + n_sample_minority = min(classes_counts.values()) sampling_strategy = { key: n_sample_minority - for key in target_stats.keys() + for key in classes_counts.keys() if key == class_majority } else: @@ -186,26 +244,25 @@ def _sampling_strategy_majority(y, sampling_type): return sampling_strategy -def _sampling_strategy_not_majority(y, sampling_type): +def _sampling_strategy_not_majority(classes_counts, sampling_type): """Returns sampling target by targeting all classes but not the majority.""" - target_stats = _count_class_sample(y) if sampling_type == "over-sampling": - n_sample_majority = max(target_stats.values()) - class_majority = max(target_stats, key=target_stats.get) + n_sample_majority = max(classes_counts.values()) + class_majority = max(classes_counts, key=classes_counts.get) sampling_strategy = { key: n_sample_majority - value - for (key, value) in target_stats.items() + for (key, value) in classes_counts.items() if key != class_majority } elif ( sampling_type == "under-sampling" or sampling_type == "clean-sampling" ): - n_sample_minority = min(target_stats.values()) - class_majority = max(target_stats, key=target_stats.get) + n_sample_minority = min(classes_counts.values()) + class_majority = max(classes_counts, key=classes_counts.get) sampling_strategy = { key: n_sample_minority - for key in target_stats.keys() + for key in classes_counts.keys() if key != class_majority } else: @@ -214,26 +271,25 @@ def _sampling_strategy_not_majority(y, sampling_type): return sampling_strategy -def _sampling_strategy_not_minority(y, sampling_type): +def _sampling_strategy_not_minority(classes_counts, sampling_type): """Returns sampling target by targeting all classes but not the minority.""" - target_stats = _count_class_sample(y) if sampling_type == "over-sampling": - n_sample_majority = max(target_stats.values()) - class_minority = min(target_stats, key=target_stats.get) + n_sample_majority = max(classes_counts.values()) + class_minority = min(classes_counts, key=classes_counts.get) sampling_strategy = { key: n_sample_majority - value - for (key, value) in target_stats.items() + for (key, value) in classes_counts.items() if key != class_minority } elif ( sampling_type == "under-sampling" or sampling_type == "clean-sampling" ): - n_sample_minority = min(target_stats.values()) - class_minority = min(target_stats, key=target_stats.get) + n_sample_minority = min(classes_counts.values()) + class_minority = min(classes_counts, key=classes_counts.get) sampling_strategy = { key: n_sample_minority - for key in target_stats.keys() + for key in classes_counts.keys() if key != class_minority } else: @@ -242,15 +298,14 @@ def _sampling_strategy_not_minority(y, sampling_type): return sampling_strategy -def _sampling_strategy_minority(y, sampling_type): +def _sampling_strategy_minority(classes_counts, sampling_type): """Returns sampling target by targeting the minority class only.""" - target_stats = _count_class_sample(y) if sampling_type == "over-sampling": - n_sample_majority = max(target_stats.values()) - class_minority = min(target_stats, key=target_stats.get) + n_sample_majority = max(classes_counts.values()) + class_minority = min(classes_counts, key=classes_counts.get) sampling_strategy = { key: n_sample_majority - value - for (key, value) in target_stats.items() + for (key, value) in classes_counts.items() if key == class_minority } elif ( @@ -266,24 +321,23 @@ def _sampling_strategy_minority(y, sampling_type): return sampling_strategy -def _sampling_strategy_auto(y, sampling_type): +def _sampling_strategy_auto(classes_counts, sampling_type): """Returns sampling target auto for over-sampling and not-minority for under-sampling.""" if sampling_type == "over-sampling": - return _sampling_strategy_not_majority(y, sampling_type) + return _sampling_strategy_not_majority(classes_counts, sampling_type) elif ( sampling_type == "under-sampling" or sampling_type == "clean-sampling" ): - return _sampling_strategy_not_minority(y, sampling_type) + return _sampling_strategy_not_minority(classes_counts, sampling_type) -def _sampling_strategy_dict(sampling_strategy, y, sampling_type): +def _sampling_strategy_dict(sampling_strategy, classes_counts, sampling_type): """Returns sampling target by converting the dictionary depending of the sampling.""" - target_stats = _count_class_sample(y) # check that all keys in sampling_strategy are also in y set_diff_sampling_strategy_target = set(sampling_strategy.keys()) - set( - target_stats.keys() + classes_counts.keys() ) if len(set_diff_sampling_strategy_target) > 0: raise ValueError( @@ -300,17 +354,17 @@ def _sampling_strategy_dict(sampling_strategy, y, sampling_type): ) sampling_strategy_ = {} if sampling_type == "over-sampling": - n_samples_majority = max(target_stats.values()) - class_majority = max(target_stats, key=target_stats.get) + n_samples_majority = max(classes_counts.values()) + class_majority = max(classes_counts, key=classes_counts.get) for class_sample, n_samples in sampling_strategy.items(): - if n_samples < target_stats[class_sample]: + if n_samples < classes_counts[class_sample]: raise ValueError( "With over-sampling methods, the number" " of samples in a class should be greater" " or equal to the original number of samples." " Originally, there is {} samples and {}" " samples are asked.".format( - target_stats[class_sample], n_samples + classes_counts[class_sample], n_samples ) ) if n_samples > n_samples_majority: @@ -326,18 +380,18 @@ def _sampling_strategy_dict(sampling_strategy, y, sampling_type): ) ) sampling_strategy_[class_sample] = ( - n_samples - target_stats[class_sample] + n_samples - classes_counts[class_sample] ) elif sampling_type == "under-sampling": for class_sample, n_samples in sampling_strategy.items(): - if n_samples > target_stats[class_sample]: + if n_samples > classes_counts[class_sample]: raise ValueError( "With under-sampling methods, the number of" " samples in a class should be less or equal" " to the original number of samples." " Originally, there is {} samples and {}" " samples are asked.".format( - target_stats[class_sample], n_samples + classes_counts[class_sample], n_samples ) ) sampling_strategy_[class_sample] = n_samples @@ -353,19 +407,18 @@ def _sampling_strategy_dict(sampling_strategy, y, sampling_type): return sampling_strategy_ -def _sampling_strategy_list(sampling_strategy, y, sampling_type): +def _sampling_strategy_list(sampling_strategy, classes_counts, sampling_type): """With cleaning methods, sampling_strategy can be a list to target the - class of interest.""" + class of interest.""" if sampling_type != "clean-sampling": raise ValueError( "'sampling_strategy' cannot be a list for samplers " "which are not cleaning methods." ) - target_stats = _count_class_sample(y) # check that all keys in sampling_strategy are also in y set_diff_sampling_strategy_target = set(sampling_strategy) - set( - target_stats.keys() + classes_counts.keys() ) if len(set_diff_sampling_strategy_target) > 0: raise ValueError( @@ -374,27 +427,26 @@ class of interest.""" ) return { - class_sample: min(target_stats.values()) + class_sample: min(classes_counts.values()) for class_sample in sampling_strategy } -def _sampling_strategy_float(sampling_strategy, y, sampling_type): +def _sampling_strategy_float(sampling_strategy, classes_counts, sampling_type): """Take a proportion of the majority (over-sampling) or minority (under-sampling) class in binary classification.""" - type_y = type_of_target(y) - if type_y != "binary": + + if len(classes_counts) != 2: raise ValueError( '"sampling_strategy" can be a float only when the type ' "of target is binary. For multi-class, use a dict." ) - target_stats = _count_class_sample(y) if sampling_type == "over-sampling": - n_sample_majority = max(target_stats.values()) - class_majority = max(target_stats, key=target_stats.get) + n_sample_majority = max(classes_counts.values()) + class_majority = max(classes_counts, key=classes_counts.get) sampling_strategy_ = { key: int(n_sample_majority * sampling_strategy - value) - for (key, value) in target_stats.items() + for (key, value) in classes_counts.items() if key != class_majority } if any([n_samples <= 0 for n_samples in sampling_strategy_.values()]): @@ -405,16 +457,16 @@ def _sampling_strategy_float(sampling_strategy, y, sampling_type): "ratio." ) elif sampling_type == "under-sampling": - n_sample_minority = min(target_stats.values()) - class_minority = min(target_stats, key=target_stats.get) + n_sample_minority = min(classes_counts.values()) + class_minority = min(classes_counts, key=classes_counts.get) sampling_strategy_ = { key: int(n_sample_minority / sampling_strategy) - for (key, value) in target_stats.items() + for (key, value) in classes_counts.items() if key != class_minority } if any( [ - n_samples > target_stats[target] + n_samples > classes_counts[target] for target, n_samples in sampling_strategy_.items() ] ): @@ -431,7 +483,9 @@ def _sampling_strategy_float(sampling_strategy, y, sampling_type): return sampling_strategy_ -def check_sampling_strategy(sampling_strategy, y, sampling_type, **kwargs): +def check_sampling_strategy( + sampling_strategy, classes_counts, sampling_type, **kwargs +): """Sampling target validation for samplers. Checks that ``sampling_strategy`` is of consistent type and return a @@ -501,8 +555,14 @@ def check_sampling_strategy(sampling_strategy, y, sampling_type, **kwargs): correspond to the targeted classes. The values correspond to the desired number of samples for each class. - y : ndarray of shape (n_samples,) - The target array. + classes_counts : dict or ndarray of shape (n_samples,) + A dictionary where the keys are the class present in `y` and the values + are the counts. The function :func:`~imblearn.utils.get_classes_count` + provides such a dictionary, giving `y` as an input. + + .. deprecated:: 0.7 + Passing the array `y` is deprecated from 0.7 and will be removed + in 0.9. sampling_type : {{'over-sampling', 'under-sampling', 'clean-sampling'}} The type of sampling. Can be either ``'over-sampling'``, @@ -526,10 +586,19 @@ def check_sampling_strategy(sampling_strategy, y, sampling_type, **kwargs): " instead.".format(SAMPLING_KIND, sampling_type) ) - if np.unique(y).size <= 1: + if hasattr(classes_counts, "__array__"): + warnings.warn( + "Passing an array of target `y` is deprecated in 0.7 and will " + "raise an error from 0.9. Instead, pass `y` to " + "imblearn.utils.get_classes_counts function to get the " + "dictionary.", FutureWarning + ) + classes_counts = get_classes_counts(classes_counts) + + if len(classes_counts) <= 1: raise ValueError( "The target 'y' needs to have more than 1 class." - " Got {} class instead".format(np.unique(y).size) + " Got {} class instead".format(len(classes_counts)) ) if sampling_type in ("ensemble", "bypass"): @@ -546,7 +615,7 @@ def check_sampling_strategy(sampling_strategy, y, sampling_type, **kwargs): return OrderedDict( sorted( SAMPLING_TARGET_KIND[sampling_strategy]( - y, sampling_type + classes_counts, sampling_type ).items() ) ) @@ -554,7 +623,7 @@ def check_sampling_strategy(sampling_strategy, y, sampling_type, **kwargs): return OrderedDict( sorted( _sampling_strategy_dict( - sampling_strategy, y, sampling_type + sampling_strategy, classes_counts, sampling_type ).items() ) ) @@ -562,7 +631,7 @@ def check_sampling_strategy(sampling_strategy, y, sampling_type, **kwargs): return OrderedDict( sorted( _sampling_strategy_list( - sampling_strategy, y, sampling_type + sampling_strategy, classes_counts, sampling_type ).items() ) ) @@ -577,16 +646,16 @@ def check_sampling_strategy(sampling_strategy, y, sampling_type, **kwargs): return OrderedDict( sorted( _sampling_strategy_float( - sampling_strategy, y, sampling_type + sampling_strategy, classes_counts, sampling_type ).items() ) ) elif callable(sampling_strategy): - sampling_strategy_ = sampling_strategy(y, **kwargs) + sampling_strategy_ = sampling_strategy(classes_counts, **kwargs) return OrderedDict( sorted( _sampling_strategy_dict( - sampling_strategy_, y, sampling_type + sampling_strategy_, classes_counts, sampling_type ).items() ) ) diff --git a/imblearn/utils/estimator_checks.py b/imblearn/utils/estimator_checks.py index 729ceebea..ceb828272 100644 --- a/imblearn/utils/estimator_checks.py +++ b/imblearn/utils/estimator_checks.py @@ -51,6 +51,7 @@ def _set_checking_parameters(estimator): def _yield_sampler_checks(sampler): + tags = sampler._get_tags() yield check_target_type yield check_samplers_one_label yield check_samplers_fit @@ -58,8 +59,16 @@ def _yield_sampler_checks(sampler): yield check_samplers_sampling_strategy_fit_resample yield check_samplers_sparse yield check_samplers_pandas + if "dask-array" in tags["X_types"]: + yield check_samplers_dask_array + if "dask-dataframe" in tags["X_types"]: + yield check_samplers_dask_dataframe yield check_samplers_list yield check_samplers_multiclass_ova + if "dask-array" in tags["X_types"]: + yield check_samplers_multiclass_ova_dask_array + if "dask-dataframe" in tags["X_types"]: + yield check_samplers_multiclass_ova_dask_dataframe yield check_samplers_preserve_dtype yield check_samplers_sample_indices yield check_samplers_2d_target @@ -290,6 +299,72 @@ def check_samplers_pandas(name, sampler): assert_allclose(y_res_s.to_numpy(), y_res) +def check_samplers_dask_array(name, sampler_orig): + pytest.importorskip("dask") + from dask import array + sampler = clone(sampler_orig) + # Check that the samplers handle dask array + X, y = make_classification( + n_samples=1000, + n_classes=3, + n_informative=4, + weights=[0.2, 0.3, 0.5], + random_state=0, + ) + X_dask = array.from_array(X, chunks=100) + y_dask = array.from_array(y, chunks=100) + + for validate_if_dask_collection in (True, False): + sampler.set_params( + validate_if_dask_collection=validate_if_dask_collection + ) + X_res_dask, y_res_dask = sampler.fit_resample(X_dask, y_dask) + X_res, y_res = sampler.fit_resample(X, y) + + # check that we return the same type for dataframes or series types + assert isinstance(X_res_dask, array.Array) + assert isinstance(y_res_dask, array.Array) + + assert_allclose(X_res_dask, X_res) + assert_allclose(y_res_dask, y_res) + + +def check_samplers_dask_dataframe(name, sampler_orig): + pytest.importorskip("dask") + from dask import dataframe + sampler = clone(sampler_orig) + # Check that the samplers handle dask dataframe and dask series + X, y = make_classification( + n_samples=1000, + n_classes=3, + n_informative=4, + weights=[0.2, 0.3, 0.5], + random_state=0, + ) + X_df = dataframe.from_array( + X, columns=[str(i) for i in range(X.shape[1])] + ) + y_s = dataframe.from_array(y) + y_s = y_s.rename("target") + + for validate_if_dask_collection in (True, False): + sampler.set_params( + validate_if_dask_collection=validate_if_dask_collection + ) + X_res_df, y_res_s = sampler.fit_resample(X_df, y_s) + X_res, y_res = sampler.fit_resample(X, y) + + # check that we return the same type for dataframes or series types + assert isinstance(X_res_df, dataframe.DataFrame) + assert isinstance(y_res_s, dataframe.Series) + + assert X_df.columns.to_list() == X_res_df.columns.to_list() + assert y_s.name == y_res_s.name + + assert_allclose(np.array(X_res_df), X_res) + assert_allclose(np.array(y_res_s), y_res) + + def check_samplers_list(name, sampler): # Check that the can samplers handle simple lists X, y = make_classification( @@ -329,6 +404,66 @@ def check_samplers_multiclass_ova(name, sampler): assert_allclose(y_res, y_res_ova.argmax(axis=1)) +def check_samplers_multiclass_ova_dask_array(name, sampler_orig): + pytest.importorskip("dask") + from dask import array + sampler = clone(sampler_orig) + X, y = make_classification( + n_samples=1000, + n_classes=3, + n_informative=4, + weights=[0.2, 0.3, 0.5], + random_state=0, + ) + y_ova = label_binarize(y, np.unique(y)) + + X = array.from_array(X) + y = array.from_array(y) + y_ova = array.from_array(y_ova) + + sampler.set_params(validate_if_dask_collection=True) + X_res, y_res = sampler.fit_resample(X, y) + X_res_ova, y_res_ova = sampler.fit_resample(X, y_ova) + + assert_allclose(X_res, X_res_ova) + assert type_of_target(y_res_ova) == type_of_target(y_ova) + assert_allclose(y_res, y_res_ova.argmax(axis=1)) + + assert isinstance(X_res_ova, array.Array) + assert isinstance(y_res, array.Array) + assert isinstance(y_res_ova, array.Array) + + +def check_samplers_multiclass_ova_dask_dataframe(name, sampler_orig): + pytest.importorskip("dask") + from dask import dataframe + sampler = clone(sampler_orig) + X, y = make_classification( + n_samples=1000, + n_classes=3, + n_informative=4, + weights=[0.2, 0.3, 0.5], + random_state=0, + ) + y_ova = label_binarize(y, np.unique(y)) + + X = dataframe.from_array(X) + y = dataframe.from_array(y) + y_ova = dataframe.from_array(y_ova) + + sampler.set_params(validate_if_dask_collection=True) + X_res, y_res = sampler.fit_resample(X, y) + X_res_ova, y_res_ova = sampler.fit_resample(X, y_ova) + + assert_allclose(X_res, X_res_ova) + assert type_of_target(y_res_ova) == type_of_target(y_ova) + assert_allclose(y_res, y_res_ova.to_dask_array().argmax(axis=1)) + + assert isinstance(X_res_ova, dataframe.DataFrame) + assert isinstance(y_res, dataframe.Series) + assert isinstance(y_res_ova, dataframe.DataFrame) + + def check_samplers_2d_target(name, sampler): X, y = make_classification( n_samples=100, diff --git a/imblearn/utils/testing.py b/imblearn/utils/testing.py index b5dc79828..b779b6cc1 100644 --- a/imblearn/utils/testing.py +++ b/imblearn/utils/testing.py @@ -53,7 +53,7 @@ def is_abstract(c): return True all_classes = [] - modules_to_ignore = {"tests"} + modules_to_ignore = {"tests", "dask"} root = str(Path(__file__).parent.parent) # Ignore deprecation warnings triggered at import time and from walking # packages diff --git a/imblearn/utils/tests/test_validation.py b/imblearn/utils/tests/test_validation.py index e4f9c01c8..b5f06e5b6 100644 --- a/imblearn/utils/tests/test_validation.py +++ b/imblearn/utils/tests/test_validation.py @@ -17,11 +17,14 @@ from imblearn.utils import check_neighbors_object from imblearn.utils import check_sampling_strategy from imblearn.utils import check_target_type +from imblearn.utils import get_classes_counts from imblearn.utils._validation import ArraysTransformer from imblearn.utils._validation import _deprecate_positional_args multiclass_target = np.array([1] * 50 + [2] * 100 + [3] * 25) +multiclass_classes_counts = get_classes_counts(multiclass_target) binary_target = np.array([1] * 25 + [0] * 100) +binary_classes_counts = get_classes_counts(binary_target) def test_check_neighbors_object(): @@ -70,11 +73,11 @@ def test_check_target_type_ova(target, output_target, is_ova): assert binarize_target == is_ova -def test_check_sampling_strategy_warning(): +def test_check_sampling_strategy_error_dict_cleaning_methods(): msg = "dict for cleaning methods is not supported" with pytest.raises(ValueError, match=msg): check_sampling_strategy( - {1: 0, 2: 0, 3: 0}, multiclass_target, "clean-sampling" + {1: 0, 2: 0, 3: 0}, multiclass_classes_counts, "clean-sampling" ) @@ -83,19 +86,19 @@ def test_check_sampling_strategy_warning(): [ ( 0.5, - binary_target, + binary_classes_counts, "clean-sampling", "'clean-sampling' methods do let the user specify the sampling ratio", # noqa ), ( 0.1, - np.array([0] * 10 + [1] * 20), + get_classes_counts(np.array([0] * 10 + [1] * 20)), "over-sampling", "remove samples from the minority class while trying to generate new", # noqa ), ( 0.1, - np.array([0] * 10 + [1] * 20), + get_classes_counts(np.array([0] * 10 + [1] * 20)), "under-sampling", "generate new sample in the majority class while trying to remove", ), @@ -108,15 +111,21 @@ def test_check_sampling_strategy_float_error(ratio, y, type, err_msg): def test_check_sampling_strategy_error(): with pytest.raises(ValueError, match="'sampling_type' should be one of"): - check_sampling_strategy("auto", np.array([1, 2, 3]), "rnd") + check_sampling_strategy( + "auto", get_classes_counts(np.array([1, 2, 3])), "rnd" + ) error_regex = "The target 'y' needs to have more than 1 class." with pytest.raises(ValueError, match=error_regex): - check_sampling_strategy("auto", np.ones((10,)), "over-sampling") + check_sampling_strategy( + "auto", get_classes_counts(np.ones((10,))), "over-sampling" + ) error_regex = "When 'sampling_strategy' is a string, it needs to be one of" with pytest.raises(ValueError, match=error_regex): - check_sampling_strategy("rnd", np.array([1, 2, 3]), "over-sampling") + check_sampling_strategy( + "rnd", get_classes_counts(np.array([1, 2, 3])), "over-sampling" + ) @pytest.mark.parametrize( @@ -136,7 +145,9 @@ def test_check_sampling_strategy_error_wrong_string( ), ): check_sampling_strategy( - sampling_strategy, np.array([1, 2, 3]), sampling_type + sampling_strategy, + get_classes_counts(np.array([1, 2, 3])), + sampling_type, ) @@ -153,14 +164,18 @@ def test_sampling_strategy_class_target_unknown( ): y = np.array([1] * 50 + [2] * 100 + [3] * 25) with pytest.raises(ValueError, match="are not present in the data."): - check_sampling_strategy(sampling_strategy, y, sampling_method) + check_sampling_strategy( + sampling_strategy, get_classes_counts(y), sampling_method + ) def test_sampling_strategy_dict_error(): y = np.array([1] * 50 + [2] * 100 + [3] * 25) sampling_strategy = {1: -100, 2: 50, 3: 25} with pytest.raises(ValueError, match="in a class cannot be negative."): - check_sampling_strategy(sampling_strategy, y, "under-sampling") + check_sampling_strategy( + sampling_strategy, get_classes_counts(y), "under-sampling" + ) sampling_strategy = {1: 45, 2: 100, 3: 70} error_regex = ( "With over-sampling methods, the number of samples in a" @@ -169,7 +184,9 @@ def test_sampling_strategy_dict_error(): " samples are asked." ) with pytest.raises(ValueError, match=error_regex): - check_sampling_strategy(sampling_strategy, y, "over-sampling") + check_sampling_strategy( + sampling_strategy, get_classes_counts(y), "over-sampling" + ) error_regex = ( "With under-sampling methods, the number of samples in a" @@ -178,21 +195,27 @@ def test_sampling_strategy_dict_error(): " are asked." ) with pytest.raises(ValueError, match=error_regex): - check_sampling_strategy(sampling_strategy, y, "under-sampling") + check_sampling_strategy( + sampling_strategy, get_classes_counts(y), "under-sampling" + ) @pytest.mark.parametrize("sampling_strategy", [-10, 10]) def test_sampling_strategy_float_error_not_in_range(sampling_strategy): y = np.array([1] * 50 + [2] * 100) with pytest.raises(ValueError, match="it should be in the range"): - check_sampling_strategy(sampling_strategy, y, "under-sampling") + check_sampling_strategy( + sampling_strategy, get_classes_counts(y), "under-sampling" + ) def test_sampling_strategy_float_error_not_binary(): y = np.array([1] * 50 + [2] * 100 + [3] * 25) with pytest.raises(ValueError, match="the type of target is binary"): sampling_strategy = 0.5 - check_sampling_strategy(sampling_strategy, y, "under-sampling") + check_sampling_strategy( + sampling_strategy, get_classes_counts(y), "under-sampling" + ) @pytest.mark.parametrize( @@ -202,7 +225,9 @@ def test_sampling_strategy_list_error_not_clean_sampling(sampling_method): y = np.array([1] * 50 + [2] * 100 + [3] * 25) with pytest.raises(ValueError, match="cannot be a list for samplers"): sampling_strategy = [1, 2, 3] - check_sampling_strategy(sampling_strategy, y, sampling_method) + check_sampling_strategy( + sampling_strategy, get_classes_counts(y), sampling_method + ) def _sampling_strategy_func(y): @@ -215,42 +240,87 @@ def _sampling_strategy_func(y): @pytest.mark.parametrize( "sampling_strategy, sampling_type, expected_sampling_strategy, target", [ - ("auto", "under-sampling", {1: 25, 2: 25}, multiclass_target), - ("auto", "clean-sampling", {1: 25, 2: 25}, multiclass_target), - ("auto", "over-sampling", {1: 50, 3: 75}, multiclass_target), - ("all", "over-sampling", {1: 50, 2: 0, 3: 75}, multiclass_target), - ("all", "under-sampling", {1: 25, 2: 25, 3: 25}, multiclass_target), - ("all", "clean-sampling", {1: 25, 2: 25, 3: 25}, multiclass_target), - ("majority", "under-sampling", {2: 25}, multiclass_target), - ("majority", "clean-sampling", {2: 25}, multiclass_target), - ("minority", "over-sampling", {3: 75}, multiclass_target), - ("not minority", "over-sampling", {1: 50, 2: 0}, multiclass_target), - ("not minority", "under-sampling", {1: 25, 2: 25}, multiclass_target), - ("not minority", "clean-sampling", {1: 25, 2: 25}, multiclass_target), - ("not majority", "over-sampling", {1: 50, 3: 75}, multiclass_target), - ("not majority", "under-sampling", {1: 25, 3: 25}, multiclass_target), - ("not majority", "clean-sampling", {1: 25, 3: 25}, multiclass_target), + ("auto", "under-sampling", {1: 25, 2: 25}, multiclass_classes_counts), + ("auto", "clean-sampling", {1: 25, 2: 25}, multiclass_classes_counts), + ("auto", "over-sampling", {1: 50, 3: 75}, multiclass_classes_counts), + ( + "all", + "over-sampling", + {1: 50, 2: 0, 3: 75}, + multiclass_classes_counts, + ), + ( + "all", + "under-sampling", + {1: 25, 2: 25, 3: 25}, + multiclass_classes_counts, + ), + ( + "all", + "clean-sampling", + {1: 25, 2: 25, 3: 25}, + multiclass_classes_counts, + ), + ("majority", "under-sampling", {2: 25}, multiclass_classes_counts), + ("majority", "clean-sampling", {2: 25}, multiclass_classes_counts), + ("minority", "over-sampling", {3: 75}, multiclass_classes_counts), + ( + "not minority", + "over-sampling", + {1: 50, 2: 0}, + multiclass_classes_counts, + ), + ( + "not minority", + "under-sampling", + {1: 25, 2: 25}, + multiclass_classes_counts, + ), + ( + "not minority", + "clean-sampling", + {1: 25, 2: 25}, + multiclass_classes_counts, + ), + ( + "not majority", + "over-sampling", + {1: 50, 3: 75}, + multiclass_classes_counts, + ), + ( + "not majority", + "under-sampling", + {1: 25, 3: 25}, + multiclass_classes_counts, + ), + ( + "not majority", + "clean-sampling", + {1: 25, 3: 25}, + multiclass_classes_counts, + ), ( {1: 70, 2: 100, 3: 70}, "over-sampling", {1: 20, 2: 0, 3: 45}, - multiclass_target, + multiclass_classes_counts, ), ( {1: 30, 2: 45, 3: 25}, "under-sampling", {1: 30, 2: 45, 3: 25}, - multiclass_target, + multiclass_classes_counts, ), - ([1], "clean-sampling", {1: 25}, multiclass_target), + ([1], "clean-sampling", {1: 25}, multiclass_classes_counts), ( _sampling_strategy_func, "over-sampling", {1: 50, 2: 0, 3: 75}, - multiclass_target, + multiclass_classes_counts, ), - (0.5, "over-sampling", {1: 25}, binary_target), - (0.5, "under-sampling", {0: 50}, binary_target), + (0.5, "over-sampling", {1: 25}, binary_classes_counts), + (0.5, "under-sampling", {0: 50}, binary_classes_counts), ], ) def test_check_sampling_strategy( @@ -271,23 +341,27 @@ def test_sampling_strategy_dict_over_sampling(): r" the majority class \(class #2 -> 100\)" ) with warns(UserWarning, expected_msg): - check_sampling_strategy(sampling_strategy, y, "over-sampling") + check_sampling_strategy( + sampling_strategy, get_classes_counts(y), "over-sampling" + ) def test_sampling_strategy_callable_args(): y = np.array([1] * 50 + [2] * 100 + [3] * 25) multiplier = {1: 1.5, 2: 1, 3: 3} - def sampling_strategy_func(y, multiplier): + def sampling_strategy_func(classes_counts, multiplier): """samples such that each class will be affected by the multiplier.""" - target_stats = Counter(y) return { key: int(values * multiplier[key]) - for key, values in target_stats.items() + for key, values in classes_counts.items() } sampling_strategy_ = check_sampling_strategy( - sampling_strategy_func, y, "over-sampling", multiplier=multiplier + sampling_strategy_func, + get_classes_counts(y), + "over-sampling", + multiplier=multiplier, ) assert sampling_strategy_ == {1: 25, 2: 0, 3: 50} @@ -314,11 +388,21 @@ def test_sampling_strategy_check_order( # dictionary is sorted. Refer to issue #428. y = np.array([1] * 50 + [2] * 100 + [3] * 25) sampling_strategy_ = check_sampling_strategy( - sampling_strategy, y, sampling_type + sampling_strategy, get_classes_counts(y), sampling_type ) assert sampling_strategy_ == expected_result +# FIXME: remove in 0.9 +def test_sampling_strategy_deprecation_array_target(): + # Check that we raise a FutureWarning when an array of target is passed + with pytest.warns(FutureWarning): + sampling_strategy = "auto" + check_sampling_strategy( + sampling_strategy, binary_target, "under-sampling", + ) + + def test_arrays_transformer_plain_list(): X = np.array([[0, 0], [1, 1]]) y = np.array([[0, 0], [1, 1]]) diff --git a/imblearn/utils/wrapper.py b/imblearn/utils/wrapper.py new file mode 100644 index 000000000..cbc9e1b1d --- /dev/null +++ b/imblearn/utils/wrapper.py @@ -0,0 +1,60 @@ +import numpy as np + +from sklearn.preprocessing import label_binarize as sklearn_label_binarize +from sklearn.utils.multiclass import check_classification_targets as \ + sklearn_check_classification_targets +from sklearn.utils.multiclass import type_of_target as sklearn_type_of_target +from sklearn.utils.validation import column_or_1d as sklearn_column_or_1d + +from ..dask._support import is_dask_collection + + +def type_of_target(y): + if is_dask_collection(y): + from ..dask.utils import type_of_target as dask_type_of_target + + return dask_type_of_target(y) + return sklearn_type_of_target(y) + + +def _is_multiclass_encoded(y): + if is_dask_collection(y): + from dask import array + + return array.all(y.sum(axis=1) == 1).compute() + return np.all(y.sum(axis=1) == 1) + + +def column_or_1d(y, *, warn=False): + if is_dask_collection(y): + from ..dask.utils import column_or_1d as dask_column_or_1d + + return dask_column_or_1d(y, warn=warn) + return sklearn_column_or_1d(y, warn=warn) + + +def unique(arr, **kwargs): + if is_dask_collection(arr): + if hasattr(arr, "unique"): + output = np.asarray(arr.unique(**kwargs)) + else: + output = np.unique(arr).compute() + return output + return np.unique(arr, **kwargs) + + +def check_classification_targets(y): + if is_dask_collection(y): + from ..dask.utils import check_classification_targets as \ + dask_check_classification_targets + + return dask_check_classification_targets(y) + return sklearn_check_classification_targets(y) + + +def label_binarize(y, *, classes): + if is_dask_collection(y): + from ..dask.preprocessing import label_binarize as dask_label_binarize + + return dask_label_binarize(y, classes=classes) + return sklearn_label_binarize(y, classes=classes) diff --git a/requirements.optional.txt b/requirements.optional.txt index 826277d5e..f785df2ff 100644 --- a/requirements.optional.txt +++ b/requirements.optional.txt @@ -1,2 +1,3 @@ +dask[complete] keras tensorflow diff --git a/setup.cfg b/setup.cfg index 1062c584c..0b7b5b1d2 100644 --- a/setup.cfg +++ b/setup.cfg @@ -21,7 +21,7 @@ test = pytest [tool:pytest] doctest_optionflags = NORMALIZE_WHITESPACE ELLIPSIS -addopts = +addopts = --ignore build_tools --ignore benchmarks --ignore doc @@ -29,6 +29,9 @@ addopts = --ignore maint_tools --doctest-modules -rs -filterwarnings = +filterwarnings = ignore:the matrix subclass:PendingDeprecationWarning +[flake8] +# Default flake8 3.5 ignored flags +ignore=E121,E123,E126,E226,E24,E704,W503,W504,E402