From 36cafcd5ad81168d19b528ccafb7c3783008b488 Mon Sep 17 00:00:00 2001 From: Samir Nasibli Date: Tue, 15 Oct 2024 05:12:32 -0700 Subject: [PATCH 1/5] FIX: update functional support fallback logic a little bit host numpy copies of the inputs data will be used for the fallback cases, since stock scikit-learn doesn't support DPCTL usm_ndarray and DPNP ndarray --- onedal/_device_offload.py | 6 +++--- sklearnex/_device_offload.py | 7 ++++--- 2 files changed, 7 insertions(+), 6 deletions(-) diff --git a/onedal/_device_offload.py b/onedal/_device_offload.py index 1eea282143..43c3da0b9a 100644 --- a/onedal/_device_offload.py +++ b/onedal/_device_offload.py @@ -140,7 +140,7 @@ def _transfer_to_host(queue, *data): raise RuntimeError("Input data shall be located on single target device") host_data.append(item) - return queue, host_data + return has_usm_data, queue, host_data def _get_global_queue(): @@ -157,8 +157,8 @@ def _get_global_queue(): def _get_host_inputs(*args, **kwargs): q = _get_global_queue() - q, hostargs = _transfer_to_host(q, *args) - q, hostvalues = _transfer_to_host(q, *kwargs.values()) + _, q, hostargs = _transfer_to_host(q, *args) + _, q, hostvalues = _transfer_to_host(q, *kwargs.values()) hostkwargs = dict(zip(kwargs.keys(), hostvalues)) return q, hostargs, hostkwargs diff --git a/sklearnex/_device_offload.py b/sklearnex/_device_offload.py index 06f97aa679..fd65be9c27 100644 --- a/sklearnex/_device_offload.py +++ b/sklearnex/_device_offload.py @@ -63,12 +63,12 @@ def _get_backend(obj, queue, method_name, *data): def dispatch(obj, method_name, branches, *args, **kwargs): q = _get_global_queue() - q, hostargs = _transfer_to_host(q, *args) - q, hostvalues = _transfer_to_host(q, *kwargs.values()) + has_usm_data_for_args, q, hostargs = _transfer_to_host(q, *args) + has_usm_data_for_kwargs, q, hostvalues = _transfer_to_host(q, *kwargs.values()) hostkwargs = dict(zip(kwargs.keys(), hostvalues)) backend, q, patching_status = _get_backend(obj, q, method_name, *hostargs) - + has_usm_data = has_usm_data_for_args or has_usm_data_for_kwargs if backend == "onedal": patching_status.write_log(queue=q) return branches[backend](obj, *hostargs, **hostkwargs, queue=q) @@ -78,6 +78,7 @@ def dispatch(obj, method_name, branches, *args, **kwargs): and get_config()["array_api_dispatch"] and "array_api_support" in obj._get_tags() and obj._get_tags()["array_api_support"] + and not has_usm_data ): # If `array_api_dispatch` enabled and array api is supported for the stock scikit-learn, # then raw inputs are used for the fallback. From b6012c11b4d63532388b964f7a21c901842e555d Mon Sep 17 00:00:00 2001 From: Samir Nasibli Date: Wed, 16 Oct 2024 15:09:23 -0700 Subject: [PATCH 2/5] Added a clarifying comment --- sklearnex/_device_offload.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/sklearnex/_device_offload.py b/sklearnex/_device_offload.py index fd65be9c27..a3ece9b498 100644 --- a/sklearnex/_device_offload.py +++ b/sklearnex/_device_offload.py @@ -80,6 +80,12 @@ def dispatch(obj, method_name, branches, *args, **kwargs): and obj._get_tags()["array_api_support"] and not has_usm_data ): + # USM ndarrays are also excluded for the fallback Array API. Currently, DPNP.ndarray is + # not compliant with the Array API standard, and DPCTL usm_ndarray Array API is compliant, + # except for the linalg module. There is no guarantee that stock scikit-learn will + # work with such input data. The condition will be updated after DPNP.ndarray and + # DPCTL usm_ndarray enabling for conformance testing and these arrays supportance + # of the fallback cases. # If `array_api_dispatch` enabled and array api is supported for the stock scikit-learn, # then raw inputs are used for the fallback. patching_status.write_log() From 4c742f926925e258f2f0c1b495f2d97cc1738cf0 Mon Sep 17 00:00:00 2001 From: Samir Nasibli Date: Fri, 18 Oct 2024 06:52:31 -0700 Subject: [PATCH 3/5] Enhanced patch message for data transfer --- sklearnex/_device_offload.py | 6 ++++-- sklearnex/_utils.py | 21 +++++++++++++++------ 2 files changed, 19 insertions(+), 8 deletions(-) diff --git a/sklearnex/_device_offload.py b/sklearnex/_device_offload.py index a3ece9b498..8b52b3c395 100644 --- a/sklearnex/_device_offload.py +++ b/sklearnex/_device_offload.py @@ -70,7 +70,9 @@ def dispatch(obj, method_name, branches, *args, **kwargs): backend, q, patching_status = _get_backend(obj, q, method_name, *hostargs) has_usm_data = has_usm_data_for_args or has_usm_data_for_kwargs if backend == "onedal": - patching_status.write_log(queue=q) + # Host args only used before onedal backend call. + # Device will be offloaded when onedal backend will be called. + patching_status.write_log(queue=q, transferred_to_host=False) return branches[backend](obj, *hostargs, **hostkwargs, queue=q) if backend == "sklearn": if ( @@ -88,7 +90,7 @@ def dispatch(obj, method_name, branches, *args, **kwargs): # of the fallback cases. # If `array_api_dispatch` enabled and array api is supported for the stock scikit-learn, # then raw inputs are used for the fallback. - patching_status.write_log() + patching_status.write_log(transferred_to_host=False) return branches[backend](obj, *args, **kwargs) else: patching_status.write_log() diff --git a/sklearnex/_utils.py b/sklearnex/_utils.py index d4c4ce2ebe..c302b958e5 100755 --- a/sklearnex/_utils.py +++ b/sklearnex/_utils.py @@ -29,10 +29,10 @@ class PatchingConditionsChain(daal4py_PatchingConditionsChain): def get_status(self): return self.patching_is_enabled - def write_log(self, queue=None): + def write_log(self, queue=None, transferred_to_host=True): if self.patching_is_enabled: self.logger.info( - f"{self.scope_name}: {get_patch_message('onedal', queue=queue)}" + f"{self.scope_name}: {get_patch_message('onedal', queue=queue, transferred_to_host=transferred_to_host)}" ) else: self.logger.debug( @@ -43,7 +43,9 @@ def write_log(self, queue=None): self.logger.debug( f"{self.scope_name}: patching failed with cause - {message}" ) - self.logger.info(f"{self.scope_name}: {get_patch_message('sklearn')}") + self.logger.info( + f"{self.scope_name}: {get_patch_message('sklearn', transferred_to_host=transferred_to_host)}" + ) def set_sklearn_ex_verbose(): @@ -66,7 +68,7 @@ def set_sklearn_ex_verbose(): ) -def get_patch_message(s, queue=None): +def get_patch_message(s, queue=None, transferred_to_host=True): if s == "onedal": message = "running accelerated version on " if queue is not None: @@ -87,11 +89,18 @@ def get_patch_message(s, queue=None): f"Invalid input - expected one of 'onedal','sklearn'," f" 'sklearn_after_onedal', got {s}" ) + if transferred_to_host: + message += ( + ". All input data transferred to host for further backend computations." + ) return message -def get_sklearnex_version(rule): - return daal_check_version(rule) +def get_usm_data_message(is_copied=False): + message = "" + if is_copied: + message = ". All USM inputs are being copied to HOST for further computations." + return message def register_hyperparameters(hyperparameters_map): From e48a2db9947d229bf049fe1e14f473683ef50ac8 Mon Sep 17 00:00:00 2001 From: Samir Nasibli Date: Fri, 18 Oct 2024 07:27:02 -0700 Subject: [PATCH 4/5] removed debug code --- sklearnex/_utils.py | 7 ------- 1 file changed, 7 deletions(-) diff --git a/sklearnex/_utils.py b/sklearnex/_utils.py index c302b958e5..f0bc485250 100755 --- a/sklearnex/_utils.py +++ b/sklearnex/_utils.py @@ -96,13 +96,6 @@ def get_patch_message(s, queue=None, transferred_to_host=True): return message -def get_usm_data_message(is_copied=False): - message = "" - if is_copied: - message = ". All USM inputs are being copied to HOST for further computations." - return message - - def register_hyperparameters(hyperparameters_map): """Decorator for hyperparameters support in estimator class. Adds `get_hyperparameters` method to class. From cff9056f12ca437798f72d3ae8a1c5d4d9311a16 Mon Sep 17 00:00:00 2001 From: Samir Nasibli Date: Fri, 18 Oct 2024 08:19:13 -0700 Subject: [PATCH 5/5] return back accidently removed primitive --- sklearnex/_utils.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/sklearnex/_utils.py b/sklearnex/_utils.py index f0bc485250..c318a858db 100755 --- a/sklearnex/_utils.py +++ b/sklearnex/_utils.py @@ -96,6 +96,10 @@ def get_patch_message(s, queue=None, transferred_to_host=True): return message +def get_sklearnex_version(rule): + return daal_check_version(rule) + + def register_hyperparameters(hyperparameters_map): """Decorator for hyperparameters support in estimator class. Adds `get_hyperparameters` method to class.