diff --git a/onedal/common/_backend.py b/onedal/common/_backend.py index 00df438b93..70c6042eed 100644 --- a/onedal/common/_backend.py +++ b/onedal/common/_backend.py @@ -15,6 +15,8 @@ # ============================================================================== import logging +from contextlib import contextmanager +from types import MethodType from typing import Any, Callable, Literal, Optional from onedal import Backend, _default_backend, _spmd_backend @@ -59,6 +61,23 @@ def _get_policy(self, queue: Any, *data: Any) -> Any: return _get_policy +@contextmanager +def DefaultPolicyOverride(instance: Any): + original_method = getattr(instance, "_get_policy", None) + try: + # Inject the new _get_policy method from _default_backend + new_policy_method = inject_policy_manager(_default_backend) + bound_method = MethodType(new_policy_method, instance) + setattr(instance, "_get_policy", bound_method) + yield + finally: + # Restore the original _get_policy method + if original_method is not None: + setattr(instance, "_get_policy", original_method) + else: + delattr(instance, "_get_policy") + + def bind_default_backend(module_name: str, lookup_name: Optional[str] = None): def decorator(method: Callable[..., Any]): # grab the lookup_name from outer scope diff --git a/onedal/spmd/covariance/incremental_covariance.py b/onedal/spmd/covariance/incremental_covariance.py index dc0edcfd8e..7558edfc9c 100644 --- a/onedal/spmd/covariance/incremental_covariance.py +++ b/onedal/spmd/covariance/incremental_covariance.py @@ -15,7 +15,7 @@ # ============================================================================== -from ...common._backend import bind_spmd_backend +from ...common._backend import DefaultPolicyOverride, bind_spmd_backend from ...covariance import ( IncrementalEmpiricalCovariance as base_IncrementalEmpiricalCovariance, ) @@ -27,3 +27,8 @@ def _get_policy(self, queue, *data): ... @bind_spmd_backend("covariance") def finalize_compute(self, policy, params, partial_result): ... + + def partial_fit(self, X, y=None, queue=None): + # partial fit performed by parent backend, therefore default policy required + with DefaultPolicyOverride(self): + return super().partial_fit(X, y, queue) diff --git a/onedal/spmd/decomposition/incremental_pca.py b/onedal/spmd/decomposition/incremental_pca.py index b57be9c9bf..ce269631ed 100644 --- a/onedal/spmd/decomposition/incremental_pca.py +++ b/onedal/spmd/decomposition/incremental_pca.py @@ -14,7 +14,11 @@ # limitations under the License. # ============================================================================== -from ...common._backend import bind_spmd_backend +from ...common._backend import ( + DefaultPolicyOverride, + bind_default_backend, + bind_spmd_backend, +) from ...decomposition import IncrementalPCA as base_IncrementalPCA @@ -26,8 +30,13 @@ class IncrementalPCA(base_IncrementalPCA): API is the same as for `onedal.decomposition.IncrementalPCA` """ - @bind_spmd_backend("decomposition") + @bind_spmd_backend("decomposition", lookup_name="_get_policy") def _get_policy(self, queue, *data): ... @bind_spmd_backend("decomposition.dim_reduction") def finalize_train(self, policy, params, partial_result): ... + + def partial_fit(self, X, queue): + # partial fit performed by parent backend, therefore default policy required + with DefaultPolicyOverride(self): + return super().partial_fit(X, queue) diff --git a/onedal/spmd/linear_model/incremental_linear_model.py b/onedal/spmd/linear_model/incremental_linear_model.py index 640bd4b619..6470173a9c 100644 --- a/onedal/spmd/linear_model/incremental_linear_model.py +++ b/onedal/spmd/linear_model/incremental_linear_model.py @@ -15,7 +15,7 @@ # ============================================================================== -from ...common._backend import bind_spmd_backend +from ...common._backend import DefaultPolicyOverride, bind_spmd_backend from ...linear_model import ( IncrementalLinearRegression as base_IncrementalLinearRegression, ) @@ -33,3 +33,8 @@ def _get_policy(self): ... @bind_spmd_backend("linear_model.regression") def finalize_train(self, *args, **kwargs): ... + + def partial_fit(self, X, y, queue): + # partial fit performed by parent backend, therefore default policy required + with DefaultPolicyOverride(self): + return super().partial_fit(X, y, queue)