Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add sklearnex.BasicStatistics API for CSR inputs on GPU and a test for it #2253

Merged
merged 18 commits into from
Jan 27, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion onedal/basic_statistics/basic_statistics.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,13 @@ def fit(self, data, sample_weight=None, queue=None):
sample_weight = _check_array(sample_weight, ensure_2d=False)

is_single_dim = data.ndim == 1
data_table, weights_table = to_table(data, sample_weight, queue=queue)

data_table = to_table(data, queue=queue)
weights_table = (
icfaust marked this conversation as resolved.
Show resolved Hide resolved
to_table(sample_weight, queue=queue)
if sample_weight is not None
else to_table(None)
)

dtype = data_table.dtype
raw_result = self._compute_raw(data_table, weights_table, policy, dtype, is_csr)
Expand Down
25 changes: 21 additions & 4 deletions onedal/basic_statistics/tests/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,18 +16,35 @@

import numpy as np


# Compute unbiased variation for the columns of array-like X
def variation(X):
X_mean = np.mean(X, axis=0)
if np.all(X_mean):
# Avoid division by zero
return np.std(X, axis=0, ddof=1) / X_mean
else:
return np.array(
[
x / y if y != 0 else np.nan
for x, y in zip(np.std(X, axis=0, ddof=1), X_mean)
]
)


options_and_tests = {
"sum": (lambda X: np.sum(X, axis=0), (5e-4, 1e-7)),
"min": (lambda X: np.min(X, axis=0), (1e-7, 1e-7)),
"max": (lambda X: np.max(X, axis=0), (1e-7, 1e-7)),
"mean": (lambda X: np.mean(X, axis=0), (5e-7, 1e-7)),
"variance": (lambda X: np.var(X, axis=0), (2e-3, 2e-3)),
"variation": (lambda X: np.std(X, axis=0) / np.mean(X, axis=0), (5e-2, 5e-2)),
# sklearnex computes unbiased variance and standard deviation that is why ddof=1
"variance": (lambda X: np.var(X, axis=0, ddof=1), (2e-4, 1e-7)),
"variation": (lambda X: variation(X), (1e-3, 1e-6)),
"sum_squares": (lambda X: np.sum(np.square(X), axis=0), (2e-4, 1e-7)),
"sum_squares_centered": (
lambda X: np.sum(np.square(X - np.mean(X, axis=0)), axis=0),
(2e-4, 1e-7),
(1e-3, 1e-7),
),
"standard_deviation": (lambda X: np.std(X, axis=0), (2e-3, 2e-3)),
"standard_deviation": (lambda X: np.std(X, axis=0, ddof=1), (2e-3, 1e-7)),
icfaust marked this conversation as resolved.
Show resolved Hide resolved
"second_order_raw_moment": (lambda X: np.mean(np.square(X), axis=0), (1e-6, 1e-7)),
}
47 changes: 39 additions & 8 deletions sklearnex/basic_statistics/basic_statistics.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,15 @@
import warnings

import numpy as np
from scipy.sparse import issparse
from sklearn.base import BaseEstimator
from sklearn.utils import check_array
from sklearn.utils.validation import _check_sample_weight

from daal4py.sklearn._n_jobs_support import control_n_jobs
from daal4py.sklearn._utils import sklearn_check_version
from daal4py.sklearn._utils import daal_check_version, sklearn_check_version
from onedal.basic_statistics import BasicStatistics as onedal_BasicStatistics
from onedal.utils import _is_csr

from .._device_offload import dispatch
from .._utils import IntelEstimator, PatchingConditionsChain
Expand Down Expand Up @@ -62,13 +64,13 @@ class BasicStatistics(IntelEstimator, BaseEstimator):
mean_ : ndarray of shape (n_features,)
Mean of each feature over all samples.
variance_ : ndarray of shape (n_features,)
Variance of each feature over all samples.
Variance of each feature over all samples. Bessel's correction is used.
variation_ : ndarray of shape (n_features,)
Variation of each feature over all samples.
Variation of each feature over all samples. Bessel's correction is used.
sum_squares_ : ndarray of shape (n_features,)
Sum of squares for each feature over all samples.
standard_deviation_ : ndarray of shape (n_features,)
Standard deviation of each feature over all samples.
Unbiased standard deviation of each feature over all samples. Bessel's correction is used.
sum_squares_centered_ : ndarray of shape (n_features,)
Centered sum of squares for each feature over all samples.
second_order_raw_moment_ : ndarray of shape (n_features,)
Expand Down Expand Up @@ -166,21 +168,50 @@ def __getattr__(self, attr):
f"'{self.__class__.__name__}' object has no attribute '{attr}'"
)

def _onedal_supported(self, method_name, *data):
def _onedal_cpu_supported(self, method_name, *data):
patching_status = PatchingConditionsChain(
f"sklearnex.basic_statistics.{self.__class__.__name__}.{method_name}"
)
return patching_status

_onedal_cpu_supported = _onedal_supported
_onedal_gpu_supported = _onedal_supported
def _onedal_gpu_supported(self, method_name, *data):
patching_status = PatchingConditionsChain(
f"sklearnex.basic_statistics.{self.__class__.__name__}.{method_name}"
)
X, sample_weight = data

is_data_supported = not issparse(X) or (
_is_csr(X) and daal_check_version((2025, "P", 200))
icfaust marked this conversation as resolved.
Show resolved Hide resolved
)

is_sample_weight_supported = sample_weight is None or not issparse(X)
Vika-F marked this conversation as resolved.
Show resolved Hide resolved

patching_status.and_conditions(
[
(
is_sample_weight_supported,
"Sample weights are not supported for CSR data format",
),
(
is_data_supported,
"Supported data formats: Dense, CSR (oneDAL version >= 2025.2.0).",
),
]
)
return patching_status

def _onedal_fit(self, X, sample_weight=None, queue=None):
if sklearn_check_version("1.2"):
self._validate_params()

if sklearn_check_version("1.0"):
X = validate_data(self, X, dtype=[np.float64, np.float32], ensure_2d=False)
X = validate_data(
self,
X,
dtype=[np.float64, np.float32],
ensure_2d=False,
accept_sparse="csr",
)
else:
X = check_array(X, dtype=[np.float64, np.float32])

Expand Down
Loading
Loading