diff --git a/setup.cfg b/setup.cfg index e374ca8b..ee08407e 100644 --- a/setup.cfg +++ b/setup.cfg @@ -39,7 +39,7 @@ test = scipy awkward >=2.2.2;python_version>"3.7" awkward <2;python_version<="3.7" - dask-awkward;python_version>"3.7" + dask-awkward >=2024.1.1;python_version>"3.7" dev = pytest >=4.6 pre-commit diff --git a/src/correctionlib/highlevel.py b/src/correctionlib/highlevel.py index 52163551..7cc1e13c 100644 --- a/src/correctionlib/highlevel.py +++ b/src/correctionlib/highlevel.py @@ -11,7 +11,8 @@ import correctionlib._core import correctionlib.version -_version_two = version.parse("2") +_min_version_ak = version.parse("2.0.0") +_min_version_dak = version.parse("2024.1.1") def open_auto(filename: str) -> str: @@ -58,9 +59,9 @@ def _call_as_numpy( ) -> Any: import awkward - if version.parse(awkward.__version__) < _version_two: + if version.parse(awkward.__version__) < _min_version_ak: raise RuntimeError( - f"""imported awkward is version {awkward.__version__} < 2.0.0 + f"""imported awkward is version {awkward.__version__} < {str(_min_version_ak)} If you cannot upgrade, try doing: ak.flatten(arrays) -> result = correction(arrays) -> ak.unflatten(result, counts) """ ) @@ -130,6 +131,49 @@ def _wrap_awkward( return awkward.transform(tocall, *array_args) +def _call_dask_correction( + correction: Any, + *args: Union["numpy.ndarray[Any, Any]", str, int, float], +): + return _wrap_awkward(correction._base.evalv, *args) + + +def _wrap_dask_awkward( + correction: Any, + *args: Union["numpy.ndarray[Any, Any]", str, int, float], +) -> Any: + import dask.delayed + import dask_awkward + + if version.parse(dask_awkward.__version__) < _min_version_dak: + raise RuntimeError( + f"""imported dask_awkward is version {dask_awkward.__version__} < {str(_min_version_dak)} + This version of dask_awkward includes several useful bugfixes and functionality extensions. + Please upgrade dask_awkward. + """ + ) + + if not hasattr(correction, "_delayed_correction"): + setattr( # noqa: B010 + correction, + "_delayed_correction", + dask.delayed(correction), + ) + + correction_meta = _wrap_awkward( + correction._base.evalv, + *(arg._meta if isinstance(arg, dask_awkward.Array) else arg for arg in args), + ) + + return dask_awkward.map_partitions( + _call_dask_correction, + correction._delayed_correction, + *args, + meta=correction_meta, + label=correction._name, + ) + + class Correction: """High-level correction evaluator object @@ -174,12 +218,22 @@ def evaluate( self, *args: Union["numpy.ndarray[Any, Any]", str, int, float] ) -> Union[float, "numpy.ndarray[Any, numpy.dtype[numpy.float64]]"]: # TODO: create a ufunc with numpy.vectorize in constructor? + if any(str(type(arg)).startswith(" Union[float, "numpy.ndarray[Any, numpy.dtype[numpy.float64]]"]: # TODO: create a ufunc with numpy.vectorize in constructor? - vargs = [ - numpy.asarray(arg) for arg in args if not isinstance(arg, (str, int, float)) - ] + if any(str(type(arg)).startswith("