ENH make Random*Sampler accept dask array and dataframe #777
Open: glemaitre wants to merge 32 commits into scikit-learn-contrib:master from glemaitre:dask_base_tools
95247e6 ENH make RandomUnderSampler accept dask array (glemaitre)
ea30287 add dask to the install (glemaitre)
0766964 PEP8 (glemaitre)
d9edb9a iter (glemaitre)
4960724 PEP8 (glemaitre)
2152429 iter (glemaitre)
e5ce7a6 PEP8 (glemaitre)
b537a20 iter (glemaitre)
f781be0 iter (glemaitre)
fb3d6a4 avoid import dask explicitely (glemaitre)
b7d9f3b TST remove redundant test (glemaitre)
d26da3c iter (glemaitre)
c065808 xxx (glemaitre)
f2d0ec0 install complete dask (glemaitre)
20ba934 iter (glemaitre)
0941a5e iter (glemaitre)
7aae9d9 iter (glemaitre)
00c0a26 iter (glemaitre)
8bfa040 requirements (glemaitre)
d4aabf8 iter (glemaitre)
58acdf2 iter (glemaitre)
e54c772 PEP8 (glemaitre)
f2a572f iter (glemaitre)
36a0aa3 iter (glemaitre)
c7bdc74 check raise FutureWarning (glemaitre)
f095221 iter (glemaitre)
20b44c6 iter (glemaitre)
4cd9116 iter (glemaitre)
a6e975b PEP8 (glemaitre)
32eda46 iter (glemaitre)
6c592ff iter (glemaitre)
456c3eb PEP8 (glemaitre)
Empty file.
    @@ -0,0 +1,9 @@
    def is_dask_collection(container):
        try:
            # to keep dask as an optional dependency, keep the import in a
            # try/except statement
            from dask import is_dask_collection

            return is_dask_collection(container)
        except ImportError:
            return False
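A minimal usage sketch of the optional-import guard above, not part of the PR. It restates the helper with the import aliased (an illustrative choice, so the inner name does not shadow the wrapper) and calls it on a plain list, which should be reported as a non-dask container whether or not dask is installed.

```python
def is_dask_collection(container):
    """Return True only when dask is importable and recognises `container`."""
    try:
        # Alias the import so it does not shadow this wrapper's own name.
        from dask import is_dask_collection as _is_dask_collection
    except ImportError:
        # dask is not installed, so nothing can be a dask collection.
        return False
    return _is_dask_collection(container)


# A plain Python list is not a dask collection in either case.
print(is_dask_collection([1, 2, 3]))
```

Callers can branch on this result to choose between an in-memory code path and a dask-aware one without importing dask at module level.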
    @@ -0,0 +1,7 @@
    def label_binarize(y, *, classes):
        import pandas as pd
        from dask import dataframe

        cat_dtype = pd.CategoricalDtype(categories=classes)
        y = dataframe.from_array(y).astype(cat_dtype)
        return dataframe.get_dummies(y).to_dask_array()
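The helper above casts `y` to a categorical dtype with explicit categories before one-hot encoding, which presumably fixes the output columns by `classes` rather than by whichever labels happen to occur in the data. The pure-Python stand-in below illustrates that idea only; `one_hot` is a hypothetical name, it uses neither pandas nor dask, and the all-zero row for an unknown label is an assumption modelled on how pandas treats missing values in `get_dummies`.

```python
def one_hot(y, classes):
    """One-hot encode `y` with exactly one column per entry of `classes`."""
    column_of = {label: col for col, label in enumerate(classes)}
    rows = []
    for label in y:
        row = [0] * len(classes)
        col = column_of.get(label)
        # Assumption: a label outside `classes` yields an all-zero row.
        if col is not None:
            row[col] = 1
        rows.append(row)
    return rows


# "b" never occurs in y, yet it still gets its own (all-zero) column.
encoded = one_hot(["a", "c", "a"], classes=["a", "b", "c"])
print(encoded)  # [[1, 0, 0], [0, 0, 1], [1, 0, 0]]
```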
Empty file.
    @@ -0,0 +1,40 @@
    import numpy as np
    import pytest

    dask = pytest.importorskip("dask")
    from dask import array

    from imblearn.dask.utils import is_multilabel
    from imblearn.dask.utils import type_of_target


    @pytest.mark.parametrize(
        "y, expected_result",
        [
            (array.from_array(np.array([0, 1, 0, 1])), False),
            (array.from_array(np.array([[1, 0], [0, 0]])), True),
            (array.from_array(np.array([[1], [0], [0]])), False),
            (array.from_array(np.array([[1, 0, 0]])), True),
        ]
    )
    def test_is_multilabel(y, expected_result):
        assert is_multilabel(y) is expected_result


    @pytest.mark.parametrize(
        "y, expected_type_of_target",
        [
            (array.from_array(np.array([[1, 0], [0, 0]])), "multilabel-indicator"),
            (array.from_array(np.array([[1, 0, 0]])), "multilabel-indicator"),
            (array.from_array(np.array([[[1, 2]]])), "unknown"),
            (array.from_array(np.array([[]])), "unknown"),
            (array.from_array(np.array([.1, .2, 3])), "continuous"),
            (array.from_array(np.array([[.1, .2, 3]])), "continuous-multioutput"),
            (array.from_array(np.array([[1., .2]])), "continuous-multioutput"),
            (array.from_array(np.array([1, 2])), "binary"),
            (array.from_array(np.array(["a", "b"])), "binary"),
        ]
    )
    def test_type_of_target(y, expected_type_of_target):
        target_type = type_of_target(y)
        assert target_type == expected_type_of_target
    @@ -0,0 +1,78 @@
    import warnings

    import numpy as np
    from sklearn.exceptions import DataConversionWarning
    from sklearn.utils.multiclass import _is_integral_float


    def is_multilabel(y):
        if not (y.ndim == 2 and y.shape[1] > 1):
            return False

        if hasattr(y, "unique"):
            labels = np.asarray(y.unique())
        else:
            labels = np.unique(y).compute()

        return len(labels) < 3 and (
            y.dtype.kind in 'biu' or _is_integral_float(labels)
        )


    def type_of_target(y):
        if is_multilabel(y):
            return 'multilabel-indicator'

        if y.ndim > 2:
            return 'unknown'

        if y.ndim == 2 and y.shape[1] == 0:
            return 'unknown'  # [[]]

        if y.ndim == 2 and y.shape[1] > 1:
            # [[1, 2], [1, 2]]
            suffix = "-multioutput"
        else:
            # [1, 2, 3] or [[1], [2], [3]]
            suffix = ""

        # check float and contains non-integer float values
        if y.dtype.kind == 'f' and np.any(y != y.astype(int)):
            # [.1, .2, 3] or [[.1, .2, 3]] or [[1., .2]] and not [1., 2., 3.]
            # NOTE: we don't check for infinite values
            return 'continuous' + suffix

        if hasattr(y, "unique"):
            labels = np.asarray(y.unique())
        else:
            labels = np.unique(y).compute()
        if (len((labels)) > 2) or (y.ndim >= 2 and len(y[0]) > 1):
            # [1, 2, 3] or [[1., 2., 3]] or [[1, 2]]
            return 'multiclass' + suffix
        # [1, 2] or [["a"], ["b"]]
        return 'binary'


    def column_or_1d(y, *, warn=False):
        shape = y.shape
        if len(shape) == 1:
            return y.ravel()
        if len(shape) == 2 and shape[1] == 1:
            if warn:
                warnings.warn(
                    "A column-vector y was passed when a 1d array was expected. "
                    "Please change the shape of y to (n_samples, ), for example "
                    "using ravel().", DataConversionWarning, stacklevel=2
                )
            return y.ravel()

        raise ValueError(
            f"y should be a 1d array. Got an array of shape {shape} instead."
        )


    def check_classification_targets(y):
        y_type = type_of_target(y)
        if y_type not in ['binary', 'multiclass', 'multiclass-multioutput',
                          'multilabel-indicator', 'multilabel-sequences']:
            raise ValueError("Unknown label type: %r" % y_type)
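The three outcomes of `column_or_1d` above depend only on the shape of `y`, so they can be illustrated without any array library. The sketch below is a simplified stand-in, not the PR's code: `flattened_shape` is a hypothetical name, it works on shape tuples instead of arrays, and it raises a plain `UserWarning` in place of scikit-learn's `DataConversionWarning`.

```python
import warnings


def flattened_shape(shape, *, warn=False):
    """Return the shape y would have after the checks in column_or_1d."""
    if len(shape) == 1:
        # Already 1-d: nothing to do.
        return shape
    if len(shape) == 2 and shape[1] == 1:
        # Column vector: accepted, optionally with a warning, then flattened.
        if warn:
            warnings.warn(
                "A column-vector y was passed when a 1d array was expected.",
                UserWarning,
                stacklevel=2,
            )
        return (shape[0],)
    # Anything else (several columns, 3-d, ...) is rejected.
    raise ValueError(
        f"y should be a 1d array. Got an array of shape {shape} instead."
    )


print(flattened_shape((5,)))    # (5,)
print(flattened_shape((5, 1)))  # (5,)
```

Because only `y.shape` is inspected, this particular check should not need to load any data from a dask array, unlike the label scans in `is_multilabel` and `type_of_target`.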
I've struggled with this check in dask-ml. Depending on where it's called, it's potentially very expensive (you might be loading a ton of data just to check if it's multi-label, and then loading it again to do the training).
Whenever possible, it's helpful to provide an option to skip this check by having the user specify it when creating the estimator, or in a keyword to `fit` (dunno if that applies here).
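To make the suggestion concrete, here is a minimal sketch of an estimator-level option that skips the label scan when the user already knows the target type. Everything in it is hypothetical and not taken from imbalanced-learn or dask-ml: `CountingTargets` stands in for a lazy collection and merely counts scans, and `SketchSampler` with its `target_type` parameter is an invented name.

```python
class CountingTargets:
    """Stand-in for a lazy array: counts how often the labels are scanned."""

    def __init__(self, labels):
        self._labels = list(labels)
        self.n_scans = 0

    def unique(self):
        # On a real dask collection this would trigger a pass over the data.
        self.n_scans += 1
        return sorted(set(self._labels))


def infer_target_type(y):
    """Simplified inference: only distinguishes binary from multiclass."""
    return "binary" if len(y.unique()) <= 2 else "multiclass"


class SketchSampler:
    """Hypothetical sampler with an opt-out for target-type inference."""

    def __init__(self, target_type=None):
        self.target_type = target_type

    def resolve_target_type(self, y):
        # Only scan the labels when the user did not state the type.
        if self.target_type is None:
            return infer_target_type(y)
        return self.target_type


y = CountingTargets([0, 1, 0, 1])
SketchSampler().resolve_target_type(y)                      # scans once
SketchSampler(target_type="binary").resolve_target_type(y)  # no extra scan
print(y.n_scans)  # 1
```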
I thought about it. Do you think that having a context manager outside would make sense:
Though, we might get into trouble with issues related to scikit-learn/scikit-learn#18736
It might just be easier to have an optional class parameter that applies only for dask arrays.
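The snippet that originally accompanied the context-manager question is not preserved here, so the following is only a guess at the general shape of such an API, not the author's proposal. The names `skip_target_check` and `check_targets` are hypothetical, and using `contextvars` to hold the flag is an assumption of this sketch; it does not establish whether the concerns in the linked scikit-learn issue would be avoided.

```python
import contextlib
import contextvars

# Hypothetical flag; False means target checks run as usual.
_SKIP_TARGET_CHECK = contextvars.ContextVar("skip_target_check", default=False)


@contextlib.contextmanager
def skip_target_check():
    """Disable target validation inside the `with` block."""
    token = _SKIP_TARGET_CHECK.set(True)
    try:
        yield
    finally:
        # Restore the previous value even if the block raised.
        _SKIP_TARGET_CHECK.reset(token)


def check_targets(y, checker):
    """Run `checker(y)` unless validation is currently disabled."""
    if _SKIP_TARGET_CHECK.get():
        return None
    return checker(y)


with skip_target_check():
    print(check_targets([0, 1], len))  # None: the check was skipped
print(check_targets([0, 1], len))      # 2: the check ran
```

Compared with a constructor parameter, a context manager keeps the estimator signature unchanged but makes the behaviour depend on ambient state at call time.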