From 9690ecd1eddaa1afb183189c5ea5f138ff23c573 Mon Sep 17 00:00:00 2001 From: Simon Perkins Date: Tue, 16 Apr 2024 19:12:46 +0200 Subject: [PATCH] Optimise broadcast_arrays in katdal import (#326) --- HISTORY.rst | 1 + daskms/experimental/katdal/msv2_facade.py | 37 +++++++++++++++++------ 2 files changed, 29 insertions(+), 9 deletions(-) diff --git a/HISTORY.rst b/HISTORY.rst index a72f9ef4..303c7758 100644 --- a/HISTORY.rst +++ b/HISTORY.rst @@ -4,6 +4,7 @@ History X.Y.Z (YYYY-MM-DD) ------------------ +* Optimise ``broadcast_arrays`` in katdal import (:pr:`326`) * Change `dask-ms katdal import` to `dask-ms import katdal` (:pr:`325`) * Configure dependabot (:pr:`319`) * Add chunk specification to ``dask-ms katdal import`` (:pr:`318`) diff --git a/daskms/experimental/katdal/msv2_facade.py b/daskms/experimental/katdal/msv2_facade.py index b528bce6..dcf880b9 100644 --- a/daskms/experimental/katdal/msv2_facade.py +++ b/daskms/experimental/katdal/msv2_facade.py @@ -19,6 +19,7 @@ ################################################################################ from functools import partial +from operator import getitem import dask.array as da import numpy as np @@ -126,14 +127,7 @@ def _main_xarray_factory(self, field_id, state_id, scan_index, scan_state, targe flags = DaskLazyIndexer(dataset.flags, (), (rechunk, flag_transpose)) weights = DaskLazyIndexer(dataset.weights, (), (rechunk, weight_transpose)) - vis = DaskLazyIndexer( - dataset.vis, - (), - transforms=( - rechunk, - vis_transpose, - ), - ) + vis = DaskLazyIndexer(dataset.vis, (), (rechunk, vis_transpose)) time = da.from_array(time_mjds[:, None], chunks=(t_chunks, 1)) ant1 = da.from_array(cp_info.ant1_index[None, :], chunks=(1, cpi.shape[0])) @@ -147,7 +141,32 @@ def _main_xarray_factory(self, field_id, state_id, scan_index, scan_state, targe row=self._row_view, ) - time, ant1, ant2 = da.broadcast_arrays(time, ant1, ant2) + # Better graph than da.broadcast_arrays + bcast = da.blockwise( + np.broadcast_arrays, 
("time", "bl"), + time, + ("time", "bl"), + ant1, + ("time", "bl"), + ant2, + ("time", "bl"), + align_arrays=False, + adjust_chunks={"time": time.chunks[0], "bl": ant1.chunks[1]}, + meta=np.empty((0,) * 2, dtype=np.int32), + ) + + time = da.blockwise( + getitem, ("time", "bl"), bcast, ("time", "bl"), 0, None, dtype=time.dtype + ) + + ant1 = da.blockwise( + getitem, ("time", "bl"), bcast, ("time", "bl"), 1, None, dtype=ant1.dtype + ) + + ant2 = da.blockwise( + getitem, ("time", "bl"), bcast, ("time", "bl"), 2, None, dtype=ant2.dtype + ) if self._row_view: primary_dims = ("row",)