From a97625c787318cc50623e1519a409300737e0732 Mon Sep 17 00:00:00 2001 From: Tristan Pinsonneault-Marotte Date: Tue, 30 Aug 2022 16:48:38 -0700 Subject: [PATCH] fix(andata): Fix axis selections to work with CorrData. --- ch_util/andata.py | 49 ++++++++++++++++++++++++++--------------------- 1 file changed, 27 insertions(+), 22 deletions(-) diff --git a/ch_util/andata.py b/ch_util/andata.py index 2cfffce1..236f7b3a 100644 --- a/ch_util/andata.py +++ b/ch_util/andata.py @@ -267,17 +267,28 @@ def time(self): return time @classmethod - def _interpret_and_read(cls, acq_files, start, stop, datasets, out_group, sel=None): - """Read and concatenate the list of files. Optionally specify one axis on which to make - selections with a tuple like `("axis", selection)`""" + def _interpret_and_read(cls, acq_files, start, stop, datasets, out_group, **kwargs): + """Read and concatenate the list of files. Keyword args may contain up to one axis selection.""" # Save a reference to the first file to get index map information for # later. f_first = acq_files[0] + # Handle axis selections + sel = [] + for key in kwargs: + if key[-4:] == "_sel": + sel.append((key[:-4], kwargs[key])) + if len(sel) > 1: + raise ValueError("Cannot handle more than one axis selection.") + elif len(sel) == 0: + sel = None + else: + ax, sel = sel[0] + if sel is None: andata_objs = [cls(d) for d in acq_files] else: - andata_objs = [_read_axis_sel(cls, d, sel[0], sel[1]) for d in acq_files] + andata_objs = [_read_axis_sel(cls, d, ax, sel) for d in acq_files] data = concatenate( andata_objs, @@ -354,11 +365,12 @@ def _from_acq_h5_single( stop=None, datasets=None, out_group=None, - sel=None, **kwargs, ): - """Load and concatenate the list of acquisition files into a local array. Optionally - specify one axis on which to make selections with a tuple like `("axis", selection)`""" + """Load and concatenate the list of acquisition files into a local array. + Axis selections may be supplied as keyword args, but the `BaseData` implementation + only supports up to one axis selection. + """ # Make sure the input is a sequence and that we have at least one file. acq_files = tod.ensure_file_list(acq_files) @@ -380,7 +392,6 @@ def _from_acq_h5_single( stop=stop, datasets=datasets, out_group=out_group, - sel=sel, **kwargs, ) @@ -400,12 +411,12 @@ def _from_acq_h5_distributed( stop, datasets, comm, - sel=None, **kwargs, ): - """Load and concatenate the list of acquisition files into a distributed array. Optionally - specify a selection on the distributed axis with a tuple like `("axis", selection)`. Note - that selections are only allowed along the distributed axis.""" + """Load and concatenate the list of acquisition files into a distributed array. + Axis selections may be supplied as keyword args, but the `BaseData` implementation + only supports up to one axis selection, and it must match the distributed axis. + """ if cls.distributed_axis is None: raise RuntimeError( @@ -431,16 +442,10 @@ def _from_acq_h5_distributed( ndist = len(f["index_map/" + ax][:]) ndist = comm.bcast(ndist, root=0) + # Handle selections along the distributed axis + dist_sel = kwargs.get(ax + "_sel", None) + # Calculate the global distributed selection - if sel is not None: - if sel[0] != ax: - raise ValueError( - "For distributed reads, selections are only allowed on the distributed axis. " - f"The distributed axis is {ax} and a selection was passed for {sel[0]}." - ) - dist_sel = sel[1] - else: - dist_sel = None dist_sel = _ensure_1D_selection(dist_sel) if isinstance(dist_sel, slice): dist_sel = list(range(*dist_sel.indices(ndist))) @@ -451,6 +456,7 @@ def _from_acq_h5_distributed( local_dist_sel = _ensure_1D_selection( _convert_to_slice(dist_sel[d_start:d_end]) ) + kwargs.update({ax + "_sel": local_dist_sel}) # Load just the local part of the data. local_data = cls._from_acq_h5_single( @@ -459,7 +465,6 @@ def _from_acq_h5_distributed( stop=stop, datasets=datasets, out_group=None, - sel=(ax, local_dist_sel), **kwargs, )