diff --git a/onedal/_device_offload.py b/onedal/_device_offload.py index 1eea282143..43c3da0b9a 100644 --- a/onedal/_device_offload.py +++ b/onedal/_device_offload.py @@ -140,7 +140,7 @@ def _transfer_to_host(queue, *data): raise RuntimeError("Input data shall be located on single target device") host_data.append(item) - return queue, host_data + return has_usm_data, queue, host_data def _get_global_queue(): @@ -157,8 +157,8 @@ def _get_global_queue(): def _get_host_inputs(*args, **kwargs): q = _get_global_queue() - q, hostargs = _transfer_to_host(q, *args) - q, hostvalues = _transfer_to_host(q, *kwargs.values()) + _, q, hostargs = _transfer_to_host(q, *args) + _, q, hostvalues = _transfer_to_host(q, *kwargs.values()) hostkwargs = dict(zip(kwargs.keys(), hostvalues)) return q, hostargs, hostkwargs diff --git a/sklearnex/_device_offload.py b/sklearnex/_device_offload.py index 06f97aa679..8b52b3c395 100644 --- a/sklearnex/_device_offload.py +++ b/sklearnex/_device_offload.py @@ -63,14 +63,16 @@ def _get_backend(obj, queue, method_name, *data): def dispatch(obj, method_name, branches, *args, **kwargs): q = _get_global_queue() - q, hostargs = _transfer_to_host(q, *args) - q, hostvalues = _transfer_to_host(q, *kwargs.values()) + has_usm_data_for_args, q, hostargs = _transfer_to_host(q, *args) + has_usm_data_for_kwargs, q, hostvalues = _transfer_to_host(q, *kwargs.values()) hostkwargs = dict(zip(kwargs.keys(), hostvalues)) backend, q, patching_status = _get_backend(obj, q, method_name, *hostargs) - + has_usm_data = has_usm_data_for_args or has_usm_data_for_kwargs if backend == "onedal": - patching_status.write_log(queue=q) + # Host args only used before onedal backend call. + # Device will be offloaded when onedal backend will be called. + patching_status.write_log(queue=q, transferred_to_host=False) return branches[backend](obj, *hostargs, **hostkwargs, queue=q) if backend == "sklearn": if ( @@ -78,10 +80,17 @@ def dispatch(obj, method_name, branches, *args, **kwargs): and get_config()["array_api_dispatch"] and "array_api_support" in obj._get_tags() and obj._get_tags()["array_api_support"] + and not has_usm_data ): + # USM ndarrays are also excluded for the fallback Array API. Currently, DPNP.ndarray is + # not compliant with the Array API standard, and DPCTL usm_ndarray Array API is compliant, + # except for the linalg module. There is no guarantee that stock scikit-learn will + # work with such input data. The condition will be updated after DPNP.ndarray and + # DPCTL usm_ndarray enabling for conformance testing and these arrays supportance + # of the fallback cases. # If `array_api_dispatch` enabled and array api is supported for the stock scikit-learn, # then raw inputs are used for the fallback. - patching_status.write_log() + patching_status.write_log(transferred_to_host=False) return branches[backend](obj, *args, **kwargs) else: patching_status.write_log() diff --git a/sklearnex/_utils.py b/sklearnex/_utils.py index d4c4ce2ebe..c318a858db 100755 --- a/sklearnex/_utils.py +++ b/sklearnex/_utils.py @@ -29,10 +29,10 @@ class PatchingConditionsChain(daal4py_PatchingConditionsChain): def get_status(self): return self.patching_is_enabled - def write_log(self, queue=None): + def write_log(self, queue=None, transferred_to_host=True): if self.patching_is_enabled: self.logger.info( - f"{self.scope_name}: {get_patch_message('onedal', queue=queue)}" + f"{self.scope_name}: {get_patch_message('onedal', queue=queue, transferred_to_host=transferred_to_host)}" ) else: self.logger.debug( @@ -43,7 +43,9 @@ def write_log(self, queue=None): self.logger.debug( f"{self.scope_name}: patching failed with cause - {message}" ) - self.logger.info(f"{self.scope_name}: {get_patch_message('sklearn')}") + self.logger.info( + f"{self.scope_name}: {get_patch_message('sklearn', transferred_to_host=transferred_to_host)}" + ) def set_sklearn_ex_verbose(): @@ -66,7 +68,7 @@ def set_sklearn_ex_verbose(): ) -def get_patch_message(s, queue=None): +def get_patch_message(s, queue=None, transferred_to_host=True): if s == "onedal": message = "running accelerated version on " if queue is not None: @@ -87,6 +89,10 @@ def get_patch_message(s, queue=None): f"Invalid input - expected one of 'onedal','sklearn'," f" 'sklearn_after_onedal', got {s}" ) + if transferred_to_host: + message += ( + ". All input data transferred to host for further backend computations." + ) return message