Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

FIX: update functional support fallback logic for a DPNP/DPCTL ndarray inputs #2113

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions onedal/_device_offload.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,7 @@ def _transfer_to_host(queue, *data):
raise RuntimeError("Input data shall be located on single target device")

host_data.append(item)
return queue, host_data
return has_usm_data, queue, host_data


def _get_global_queue():
Expand All @@ -157,8 +157,8 @@ def _get_global_queue():

def _get_host_inputs(*args, **kwargs):
q = _get_global_queue()
q, hostargs = _transfer_to_host(q, *args)
q, hostvalues = _transfer_to_host(q, *kwargs.values())
_, q, hostargs = _transfer_to_host(q, *args)
_, q, hostvalues = _transfer_to_host(q, *kwargs.values())
hostkwargs = dict(zip(kwargs.keys(), hostvalues))
return q, hostargs, hostkwargs

Expand Down
19 changes: 14 additions & 5 deletions sklearnex/_device_offload.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,25 +63,34 @@ def _get_backend(obj, queue, method_name, *data):

def dispatch(obj, method_name, branches, *args, **kwargs):
q = _get_global_queue()
q, hostargs = _transfer_to_host(q, *args)
q, hostvalues = _transfer_to_host(q, *kwargs.values())
has_usm_data_for_args, q, hostargs = _transfer_to_host(q, *args)
has_usm_data_for_kwargs, q, hostvalues = _transfer_to_host(q, *kwargs.values())
hostkwargs = dict(zip(kwargs.keys(), hostvalues))

backend, q, patching_status = _get_backend(obj, q, method_name, *hostargs)

has_usm_data = has_usm_data_for_args or has_usm_data_for_kwargs
if backend == "onedal":
patching_status.write_log(queue=q)
# Host args only used before onedal backend call.
# Device will be offloaded when onedal backend will be called.
patching_status.write_log(queue=q, transferred_to_host=False)
return branches[backend](obj, *hostargs, **hostkwargs, queue=q)
if backend == "sklearn":
if (
"array_api_dispatch" in get_config()
and get_config()["array_api_dispatch"]
and "array_api_support" in obj._get_tags()
and obj._get_tags()["array_api_support"]
and not has_usm_data
):
# USM ndarrays are also excluded for the fallback Array API. Currently, DPNP.ndarray is
# not compliant with the Array API standard, and DPCTL usm_ndarray Array API is compliant,
# except for the linalg module. There is no guarantee that stock scikit-learn will
# work with such input data. The condition will be updated after DPNP.ndarray and
# DPCTL usm_ndarray enabling for conformance testing and these arrays supportance
# of the fallback cases.
# If `array_api_dispatch` enabled and array api is supported for the stock scikit-learn,
# then raw inputs are used for the fallback.
patching_status.write_log()
patching_status.write_log(transferred_to_host=False)
return branches[backend](obj, *args, **kwargs)
else:
patching_status.write_log()
Expand Down
14 changes: 10 additions & 4 deletions sklearnex/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,10 +29,10 @@ class PatchingConditionsChain(daal4py_PatchingConditionsChain):
def get_status(self):
return self.patching_is_enabled

def write_log(self, queue=None):
def write_log(self, queue=None, transferred_to_host=True):
if self.patching_is_enabled:
self.logger.info(
f"{self.scope_name}: {get_patch_message('onedal', queue=queue)}"
f"{self.scope_name}: {get_patch_message('onedal', queue=queue, transferred_to_host=transferred_to_host)}"
)
else:
self.logger.debug(
Expand All @@ -43,7 +43,9 @@ def write_log(self, queue=None):
self.logger.debug(
f"{self.scope_name}: patching failed with cause - {message}"
)
self.logger.info(f"{self.scope_name}: {get_patch_message('sklearn')}")
self.logger.info(
f"{self.scope_name}: {get_patch_message('sklearn', transferred_to_host=transferred_to_host)}"
)


def set_sklearn_ex_verbose():
Expand All @@ -66,7 +68,7 @@ def set_sklearn_ex_verbose():
)


def get_patch_message(s, queue=None):
def get_patch_message(s, queue=None, transferred_to_host=True):
if s == "onedal":
message = "running accelerated version on "
if queue is not None:
Expand All @@ -87,6 +89,10 @@ def get_patch_message(s, queue=None):
f"Invalid input - expected one of 'onedal','sklearn',"
f" 'sklearn_after_onedal', got {s}"
)
if transferred_to_host:
message += (
". All input data transferred to host for further backend computations."
)
return message


Expand Down