diff --git a/onedal/basic_statistics/basic_statistics.py b/onedal/basic_statistics/basic_statistics.py
index 56904adce2..ca245c87f1 100644
--- a/onedal/basic_statistics/basic_statistics.py
+++ b/onedal/basic_statistics/basic_statistics.py
@@ -82,7 +82,13 @@ def fit(self, data, sample_weight=None, queue=None):
             sample_weight = _check_array(sample_weight, ensure_2d=False)

         is_single_dim = data.ndim == 1
-        data_table, weights_table = to_table(data, sample_weight, queue=queue)
+
+        data_table = to_table(data, queue=queue)
+        weights_table = (
+            to_table(sample_weight, queue=queue)
+            if sample_weight is not None
+            else to_table(None)
+        )

         dtype = data_table.dtype
         raw_result = self._compute_raw(data_table, weights_table, policy, dtype, is_csr)
diff --git a/onedal/basic_statistics/tests/utils.py b/onedal/basic_statistics/tests/utils.py
index da0ed559b6..989113d260 100644
--- a/onedal/basic_statistics/tests/utils.py
+++ b/onedal/basic_statistics/tests/utils.py
@@ -16,18 +16,35 @@

 import numpy as np

+
+# Compute the unbiased variation for the columns of array-like X
+def variation(X):
+    X_mean = np.mean(X, axis=0)
+    if np.all(X_mean):
+        return np.std(X, axis=0, ddof=1) / X_mean
+    else:
+        # Avoid division by zero: columns with a zero mean yield nan
+        return np.array(
+            [
+                x / y if y != 0 else np.nan
+                for x, y in zip(np.std(X, axis=0, ddof=1), X_mean)
+            ]
+        )
+
+
 options_and_tests = {
     "sum": (lambda X: np.sum(X, axis=0), (5e-4, 1e-7)),
     "min": (lambda X: np.min(X, axis=0), (1e-7, 1e-7)),
     "max": (lambda X: np.max(X, axis=0), (1e-7, 1e-7)),
     "mean": (lambda X: np.mean(X, axis=0), (5e-7, 1e-7)),
-    "variance": (lambda X: np.var(X, axis=0), (2e-3, 2e-3)),
-    "variation": (lambda X: np.std(X, axis=0) / np.mean(X, axis=0), (5e-2, 5e-2)),
+    # sklearnex computes unbiased variance and standard deviation, hence ddof=1
+    "variance": (lambda X: np.var(X, axis=0, ddof=1), (2e-4, 1e-7)),
+    "variation": (lambda X: variation(X), (1e-3, 1e-6)),
     "sum_squares": (lambda X: np.sum(np.square(X), axis=0), (2e-4, 1e-7)),
     "sum_squares_centered": (
         lambda X: np.sum(np.square(X - np.mean(X, axis=0)), axis=0),
-        (2e-4, 1e-7),
+        (1e-3, 1e-7),
     ),
-    "standard_deviation": (lambda X: np.std(X, axis=0), (2e-3, 2e-3)),
+    "standard_deviation": (lambda X: np.std(X, axis=0, ddof=1), (2e-3, 1e-7)),
     "second_order_raw_moment": (lambda X: np.mean(np.square(X), axis=0), (1e-6, 1e-7)),
 }
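
Note on the reference implementations above: ddof=1 applies Bessel's correction, dividing the sum of squared deviations by n - 1 rather than n, matching the unbiased estimators that sklearnex now reports. A standalone sketch of the relationship (illustrative data, not part of the patch):

    import numpy as np

    X = np.array([[1.0, 2.0], [3.0, 4.0], [5.0, 9.0]])
    n = X.shape[0]

    biased = np.var(X, axis=0)            # divides by n (ddof=0)
    unbiased = np.var(X, axis=0, ddof=1)  # divides by n - 1 (Bessel's correction)
    assert np.allclose(unbiased, biased * n / (n - 1))

    # variation() above is the coefficient of variation built from the
    # unbiased standard deviation
    assert np.allclose(
        np.std(X, axis=0, ddof=1) / np.mean(X, axis=0),
        np.sqrt(unbiased) / np.mean(X, axis=0),
    )
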
diff --git a/sklearnex/basic_statistics/basic_statistics.py b/sklearnex/basic_statistics/basic_statistics.py
index da82e3bd82..26f78ac16e 100644
--- a/sklearnex/basic_statistics/basic_statistics.py
+++ b/sklearnex/basic_statistics/basic_statistics.py
@@ -17,13 +17,15 @@
 import warnings

 import numpy as np
+from scipy.sparse import issparse
 from sklearn.base import BaseEstimator
 from sklearn.utils import check_array
 from sklearn.utils.validation import _check_sample_weight

 from daal4py.sklearn._n_jobs_support import control_n_jobs
-from daal4py.sklearn._utils import sklearn_check_version
+from daal4py.sklearn._utils import daal_check_version, sklearn_check_version
 from onedal.basic_statistics import BasicStatistics as onedal_BasicStatistics
+from onedal.utils import _is_csr

 from .._device_offload import dispatch
 from .._utils import IntelEstimator, PatchingConditionsChain
@@ -62,13 +64,13 @@ class BasicStatistics(IntelEstimator, BaseEstimator):
     mean_ : ndarray of shape (n_features,)
         Mean of each feature over all samples.
     variance_ : ndarray of shape (n_features,)
-        Variance of each feature over all samples.
+        Variance of each feature over all samples. Bessel's correction is used.
     variation_ : ndarray of shape (n_features,)
-        Variation of each feature over all samples.
+        Variation of each feature over all samples. Bessel's correction is used.
     sum_squares_ : ndarray of shape (n_features,)
         Sum of squares for each feature over all samples.
     standard_deviation_ : ndarray of shape (n_features,)
-        Standard deviation of each feature over all samples.
+        Unbiased standard deviation of each feature over all samples. Bessel's correction is used.
     sum_squares_centered_ : ndarray of shape (n_features,)
         Centered sum of squares for each feature over all samples.
     second_order_raw_moment_ : ndarray of shape (n_features,)
@@ -166,21 +168,50 @@ def __getattr__(self, attr):
             f"'{self.__class__.__name__}' object has no attribute '{attr}'"
         )

-    def _onedal_supported(self, method_name, *data):
+    def _onedal_cpu_supported(self, method_name, *data):
         patching_status = PatchingConditionsChain(
             f"sklearnex.basic_statistics.{self.__class__.__name__}.{method_name}"
         )
         return patching_status

-    _onedal_cpu_supported = _onedal_supported
-    _onedal_gpu_supported = _onedal_supported
+    def _onedal_gpu_supported(self, method_name, *data):
+        patching_status = PatchingConditionsChain(
+            f"sklearnex.basic_statistics.{self.__class__.__name__}.{method_name}"
+        )
+        X, sample_weight = data
+
+        is_data_supported = not issparse(X) or (
+            _is_csr(X) and daal_check_version((2025, "P", 200))
+        )
+
+        is_sample_weight_supported = sample_weight is None or not issparse(X)
+
+        patching_status.and_conditions(
+            [
+                (
+                    is_sample_weight_supported,
+                    "Sample weights are not supported for CSR data format",
+                ),
+                (
+                    is_data_supported,
+                    "Supported data formats: Dense, CSR (oneDAL version >= 2025.2.0).",
+                ),
+            ]
+        )
+        return patching_status

     def _onedal_fit(self, X, sample_weight=None, queue=None):
         if sklearn_check_version("1.2"):
             self._validate_params()

         if sklearn_check_version("1.0"):
-            X = validate_data(self, X, dtype=[np.float64, np.float32], ensure_2d=False)
+            X = validate_data(
+                self,
+                X,
+                dtype=[np.float64, np.float32],
+                ensure_2d=False,
+                accept_sparse="csr",
+            )
         else:
             X = check_array(X, dtype=[np.float64, np.float32])
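
Taken together, the conditions above mean the GPU path accepts sparse input only in CSR format, only without sample weights, and only with oneDAL >= 2025.2.0; dense input is unaffected. A minimal usage sketch, assuming a build that includes this patch (data and option choices are illustrative):

    import numpy as np
    from scipy import sparse as sp

    from sklearnex.basic_statistics import BasicStatistics

    X = sp.random(100, 10, density=0.05, format="csr", dtype=np.float64)

    # CSR is accepted via accept_sparse="csr"; passing sample_weight with
    # CSR data would fail the GPU patching conditions above.
    bs = BasicStatistics(result_options=["mean", "variance"]).fit(X)
    print(bs.mean_, bs.variance_)
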
diff --git a/sklearnex/basic_statistics/tests/test_basic_statistics.py b/sklearnex/basic_statistics/tests/test_basic_statistics.py
index a5515f240d..b2132500be 100644
--- a/sklearnex/basic_statistics/tests/test_basic_statistics.py
+++ b/sklearnex/basic_statistics/tests/test_basic_statistics.py
@@ -17,14 +17,30 @@
 import numpy as np
 import pytest
 from numpy.testing import assert_allclose
+from scipy import sparse as sp

 from daal4py.sklearn._utils import daal_check_version
 from onedal.basic_statistics.tests.utils import options_and_tests
 from onedal.tests.utils._dataframes_support import (
     _convert_to_dataframe,
     get_dataframes_and_queues,
+    get_queues,
 )
+from sklearnex import config_context
 from sklearnex.basic_statistics import BasicStatistics
+from sklearnex.tests.utils import gen_sparse_dataset
+
+
+# Compute basic statistics on sparse data, on CPU or GPU depending on the queue
+def compute_sparse_result(X_sparse, options, queue):
+    if queue is not None and queue.sycl_device.is_gpu:
+        with config_context(target_offload="gpu"):
+            basicstat = BasicStatistics(result_options=options)
+            result = basicstat.fit(X_sparse)
+    else:
+        basicstat = BasicStatistics(result_options=options)
+        result = basicstat.fit(X_sparse)
+    return result


 @pytest.mark.parametrize("dataframe,queue", get_dataframes_and_queues())
@@ -41,9 +57,9 @@ def test_sklearnex_import_basic_statistics(dataframe, queue):
     expected_min = np.array([0, 0])
     expected_max = np.array([1, 1])

-    assert_allclose(expected_mean, result.mean)
-    assert_allclose(expected_max, result.max)
-    assert_allclose(expected_min, result.min)
+    assert_allclose(expected_mean, result.mean_)
+    assert_allclose(expected_max, result.max_)
+    assert_allclose(expected_min, result.min_)

     result = BasicStatistics().fit(X_df, sample_weight=weights_df)

@@ -51,9 +67,9 @@
     expected_weighted_mean = np.array([0.25, 0.25])
     expected_weighted_min = np.array([0, 0])
     expected_weighted_max = np.array([0.5, 0.5])

-    assert_allclose(expected_weighted_mean, result.mean)
-    assert_allclose(expected_weighted_min, result.min)
-    assert_allclose(expected_weighted_max, result.max)
+    assert_allclose(expected_weighted_mean, result.mean_)
+    assert_allclose(expected_weighted_min, result.min_)
+    assert_allclose(expected_weighted_max, result.max_)


 @pytest.mark.parametrize("dataframe,queue", get_dataframes_and_queues())
@@ -78,16 +94,16 @@ def test_multiple_options_on_gold_data(dataframe, queue, weighted, dtype):
         expected_weighted_mean = np.array([0.25, 0.25])
         expected_weighted_min = np.array([0, 0])
         expected_weighted_max = np.array([0.5, 0.5])
-        assert_allclose(expected_weighted_mean, result.mean)
-        assert_allclose(expected_weighted_max, result.max)
-        assert_allclose(expected_weighted_min, result.min)
+        assert_allclose(expected_weighted_mean, result.mean_)
+        assert_allclose(expected_weighted_max, result.max_)
+        assert_allclose(expected_weighted_min, result.min_)
     else:
         expected_mean = np.array([0.5, 0.5])
         expected_min = np.array([0, 0])
         expected_max = np.array([1, 1])
-        assert_allclose(expected_mean, result.mean)
-        assert_allclose(expected_max, result.max)
-        assert_allclose(expected_min, result.min)
+        assert_allclose(expected_mean, result.mean_)
+        assert_allclose(expected_max, result.max_)
+        assert_allclose(expected_min, result.min_)


 @pytest.mark.parametrize("dataframe,queue", get_dataframes_and_queues())
@@ -117,7 +133,7 @@ def test_single_option_on_random_data(
     else:
         result = basicstat.fit(X_df)

-    res = getattr(result, result_option)
+    res = getattr(result, result_option + "_")
     if weighted:
         weighted_data = np.diag(weights) @ X
         gtr = function(weighted_data)
@@ -128,6 +144,49 @@
     assert_allclose(gtr, res, atol=tol)


+@pytest.mark.parametrize("queue", get_queues())
+@pytest.mark.parametrize("result_option", options_and_tests.keys())
+@pytest.mark.parametrize("row_count", [500, 2000])
+@pytest.mark.parametrize("column_count", [10, 100])
+@pytest.mark.parametrize("dtype", [np.float32, np.float64])
+def test_single_option_on_random_sparse_data(
+    queue, result_option, row_count, column_count, dtype
+):
+    if not daal_check_version((2025, "P", 200)) and result_option in [
+        "max",
+        "sum_squares",
+    ]:
+        pytest.skip(
+            "'max' and 'sum_squares' are computed on a subset of the data in oneDAL versions < 2025.2"
+        )
+
+    function, tols = options_and_tests[result_option]
+    fp32tol, fp64tol = tols
+    seed = 77
+
+    gen = np.random.default_rng(seed)
+
+    X_sparse = gen_sparse_dataset(
+        row_count,
+        column_count,
+        density=0.01,
+        format="csr",
+        dtype=dtype,
+        random_state=gen,
+    )
+
+    X_dense = X_sparse.toarray()
+
+    result = compute_sparse_result(X_sparse, result_option, queue)
+
+    res = getattr(result, result_option + "_")
+
+    gtr = function(X_dense)
+
+    tol = fp32tol if res.dtype == np.float32 else fp64tol
+    assert_allclose(gtr, res, atol=tol)
+
+
 @pytest.mark.parametrize("dataframe,queue", get_dataframes_and_queues())
 @pytest.mark.parametrize("row_count", [100, 1000])
 @pytest.mark.parametrize("column_count", [10, 100])
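
The result_option + "_" lookups throughout these tests reflect renaming the fitted attributes to scikit-learn's trailing-underscore convention (result.mean becomes result.mean_). A short sketch of the convention (illustrative data):

    import numpy as np

    from sklearnex.basic_statistics import BasicStatistics

    bs = BasicStatistics(result_options=["mean", "max"]).fit(
        np.array([[0.0, 0.0], [1.0, 1.0]])
    )
    print(bs.mean_, bs.max_)  # fitted attributes end with an underscore
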
@pytest.mark.parametrize("row_count", [100, 1000]) @pytest.mark.parametrize("column_count", [10, 100]) @@ -152,7 +211,7 @@ def test_multiple_options_on_random_data( else: result = basicstat.fit(X_df) - res_mean, res_max, res_sum = result.mean, result.max, result.sum + res_mean, res_max, res_sum = result.mean_, result.max_, result.sum_ if weighted: weighted_data = np.diag(weights) @ X gtr_mean, gtr_max, gtr_sum = ( @@ -173,6 +232,48 @@ def test_multiple_options_on_random_data( assert_allclose(gtr_sum, res_sum, atol=tol) +@pytest.mark.parametrize("queue", get_queues()) +@pytest.mark.parametrize("row_count", [100, 1000]) +@pytest.mark.parametrize("column_count", [10, 100]) +@pytest.mark.parametrize("dtype", [np.float32, np.float64]) +def test_multiple_options_on_random_sparse_data(queue, row_count, column_count, dtype): + seed = 77 + + gen = np.random.default_rng(seed) + + X_sparse = gen_sparse_dataset( + row_count, + column_count, + density=0.05, + format="csr", + dtype=dtype, + random_state=gen, + ) + + X_dense = X_sparse.toarray() + + options = [ + "sum", + "min", + "mean", + "standard_deviation", + "variance", + "second_order_raw_moment", + ] + + result = compute_sparse_result(X_sparse, options, queue) + + for result_option in options_and_tests: + function, tols = options_and_tests[result_option] + if not result_option in options: + continue + fp32tol, fp64tol = tols + res = getattr(result, result_option + "_") + gtr = function(X_dense) + tol = fp32tol if res.dtype == np.float32 else fp64tol + assert_allclose(gtr, res, atol=tol) + + @pytest.mark.parametrize("dataframe,queue", get_dataframes_and_queues()) @pytest.mark.parametrize("row_count", [100, 1000]) @pytest.mark.parametrize("column_count", [10, 100]) @@ -203,7 +304,7 @@ def test_all_option_on_random_data( for result_option in options_and_tests: function, tols = options_and_tests[result_option] fp32tol, fp64tol = tols - res = getattr(result, result_option) + res = getattr(result, result_option + "_") if weighted: gtr = function(weighted_data) else: @@ -212,6 +313,43 @@ def test_all_option_on_random_data( assert_allclose(gtr, res, atol=tol) +@pytest.mark.parametrize("queue", get_queues()) +@pytest.mark.parametrize("row_count", [100, 1000]) +@pytest.mark.parametrize("column_count", [10, 100]) +@pytest.mark.parametrize("dtype", [np.float32, np.float64]) +def test_all_option_on_random_sparse_data(queue, row_count, column_count, dtype): + seed = 77 + + gen = np.random.default_rng(seed) + + X_sparse = gen_sparse_dataset( + row_count, + column_count, + density=0.05, + format="csr", + dtype=dtype, + random_state=gen, + ) + X_dense = X_sparse.toarray() + + result = compute_sparse_result(X_sparse, "all", queue) + + for result_option in options_and_tests: + if not daal_check_version((2025, "P", 200)) and result_option in [ + "max", + "sum_squares", + ]: + continue + function, tols = options_and_tests[result_option] + fp32tol, fp64tol = tols + res = getattr(result, result_option + "_") + + gtr = function(X_dense) + + tol = fp32tol if res.dtype == np.float32 else fp64tol + assert_allclose(gtr, res, atol=tol) + + @pytest.mark.parametrize("dataframe,queue", get_dataframes_and_queues()) @pytest.mark.parametrize("result_option", options_and_tests.keys()) @pytest.mark.parametrize("data_size", [100, 1000]) @@ -238,7 +376,7 @@ def test_1d_input_on_random_data( else: result = basicstat.fit(X_df) - res = getattr(result, result_option) + res = getattr(result, result_option + "_") if weighted: weighted_data = weights * X gtr = function(weighted_data) diff 
diff --git a/sklearnex/tests/test_run_to_run_stability.py b/sklearnex/tests/test_run_to_run_stability.py
index 6555410968..875063fc1a 100755
--- a/sklearnex/tests/test_run_to_run_stability.py
+++ b/sklearnex/tests/test_run_to_run_stability.py
@@ -34,6 +34,7 @@
 import daal4py as d4p
 from daal4py.sklearn._utils import daal_check_version
 from onedal.tests.utils._dataframes_support import _as_numpy, get_dataframes_and_queues
+from sklearnex.basic_statistics import BasicStatistics
 from sklearnex.cluster import DBSCAN, KMeans
 from sklearnex.decomposition import PCA
 from sklearnex.metrics import pairwise_distances, roc_auc_score
@@ -117,6 +118,12 @@ def _run_test(estimator, method, datasets):

 _sparse_instances = [SVC()]
+if daal_check_version((2025, "P", 200)):  # Test for >= 2025.2.0
+    _sparse_instances.extend(
+        [
+            BasicStatistics(result_options=["sum", "min"]),
+        ]
+    )
 if daal_check_version((2024, "P", 700)):  # Test for > 2024.7.0
     _sparse_instances.extend(
         [
diff --git a/sklearnex/tests/utils/__init__.py b/sklearnex/tests/utils/__init__.py
index db728fe913..feed9f1292 100644
--- a/sklearnex/tests/utils/__init__.py
+++ b/sklearnex/tests/utils/__init__.py
@@ -26,6 +26,7 @@
     call_method,
     gen_dataset,
     gen_models_info,
+    gen_sparse_dataset,
     sklearn_clone_dict,
 )

@@ -39,6 +40,7 @@
     "call_method",
     "gen_models_info",
     "gen_dataset",
+    "gen_sparse_dataset",
     "sklearn_clone_dict",
     "DummyEstimator",
 ]
diff --git a/sklearnex/tests/utils/base.py b/sklearnex/tests/utils/base.py
index 33d3804b8f..4ec16c12d7 100755
--- a/sklearnex/tests/utils/base.py
+++ b/sklearnex/tests/utils/base.py
@@ -339,6 +339,27 @@ def gen_dataset(
     return output


+def gen_sparse_dataset(row_count, column_count, **kwargs):
+    """Generate a sparse dataset for pytest testing.
+
+    Parameters
+    ----------
+    row_count : number of rows in the dataset
+
+    column_count : number of columns in the dataset
+
+    kwargs : keyword arguments for scipy.sparse.random_array or scipy.sparse.random
+
+    Returns
+    -------
+    scipy.sparse random matrix or array, depending on the scipy version
+    """
+    if hasattr(sp, "random_array"):
+        return sp.random_array((row_count, column_count), **kwargs)
+    else:
+        return sp.random(row_count, column_count, **kwargs)
+
+
 DTYPES = [
     np.int8,
     np.int16,
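
For completeness, a usage sketch of gen_sparse_dataset as exercised by the new tests (parameters are illustrative): on SciPy builds that provide scipy.sparse.random_array it returns a sparse array, otherwise a sparse matrix from scipy.sparse.random.

    import numpy as np

    from sklearnex.tests.utils import gen_sparse_dataset

    X_sparse = gen_sparse_dataset(
        1000,
        20,
        density=0.05,
        format="csr",
        dtype=np.float64,
        random_state=np.random.default_rng(77),
    )
    print(type(X_sparse), X_sparse.shape)
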