diff --git a/draco/analysis/transform.py b/draco/analysis/transform.py
index 71c3ea6e..41a861c1 100644
--- a/draco/analysis/transform.py
+++ b/draco/analysis/transform.py
@@ -1354,22 +1354,35 @@ def process(self, data: containers.ContainerBase) -> containers.ContainerBase:
             Container of same type as the input with specific axis selections.
             Any datasets not included in the selections will not be initialized.
         """
-        # Re-format selections to only use axis name
-        for ax_sel in list(self._sel):
-            ax = ax_sel.replace("_sel", "")
-            self._sel[ax] = self._sel.pop(ax_sel)
+        sel = {}
+
+        # Parse axes with selections and reformat to use only
+        # the axis name
+        for k in self.selections:
+            *axis, type_ = k.split("_")
+            axis_name = "_".join(axis)
+
+            ax_sel = self._sel.get(f"{axis_name}_sel")
+
+            if type_ == "map":
+                # Use index map to get the correct axis indices
+                imap = list(data.index_map[axis_name])
+                ax_sel = [imap.index(x) for x in ax_sel]
+
+            if ax_sel is not None:
+                sel[axis_name] = ax_sel
 
         # Figure out the axes for the new container and
         # Apply the downselections to each axis index_map
         output_axes = {
-            ax: mpiarray._apply_sel(data.index_map[ax], sel, 0)
-            for ax, sel in self._sel.items()
+            ax: mpiarray._apply_sel(data.index_map[ax], ax_sel, 0)
+            for ax, ax_sel in sel.items()
         }
 
         # Create the output container without initializing any datasets.
         out = data.__class__(
             axes_from=data, attrs_from=data, skip_datasets=True, **output_axes
         )
-        containers.copy_datasets_filter(data, out, selection=self._sel)
+        containers.copy_datasets_filter(data, out, selection=sel)
 
         return out
diff --git a/draco/core/io.py b/draco/core/io.py
index aa8792f4..e3ca0b89 100644
--- a/draco/core/io.py
+++ b/draco/core/io.py
@@ -26,6 +26,7 @@
 import os.path
 import shutil
 import subprocess
+from functools import partial
 from typing import ClassVar, Optional, Union
 
 import numpy as np
@@ -351,6 +352,11 @@ class SelectionsMixin:
     selections : dict, optional
         A dictionary of axis selections. See below for details.
 
+    allow_index_map : bool, optional
+        If True, selections can be made based on an `index_map` dataset.
+        This is not supported when reading from disk. See below for
+        details. Default is False.
+
     Selections
     ----------
     Selections can be given to limit the data read to specified subsets. They can be
@@ -359,11 +365,13 @@
     Selections can be given as a slice with an `_range` key with either
     `[start, stop]` or `[start, stop, step]` as the value. Alternatively a list of
     explicit indices to extract can be given with the `_index` key, and
-    the value is a list of the indices. If both `_range` and `_index` keys are
-    given the former will take precedence, but you should
-    clearly avoid doing this.
+    the value is a list of the indices. Finally, a selection can be given as a
+    list of `index_map` entries with the `_map` key; these entries are
+    converted to the corresponding axis indices. If several keys are given for
+    the same axis, `_range` takes precedence over `_index`, which in turn
+    takes precedence over `_map`, but you should avoid relying on this.
 
-    Additionally index based selections currently don't work for distributed reads.
+    Additionally, index-based selections currently don't work for distributed reads.
 
     Here's an example in the YAML format that the pipeline uses:
 
@@ -373,9 +381,11 @@
             freq_range: [256, 512, 4]  # A strided slice
             stack_index: [1, 2, 4, 9, 16, 25, 36, 49, 64]  # A sparse selection
             stack_range: [1, 14]  # Will override the selection above
+            pol_map: ["XX", "YY"]  # Select the indices corresponding to these entries
     """
 
     selections = config.Property(proptype=dict, default=None)
+    allow_index_map = config.Property(proptype=bool, default=False)
 
     def setup(self):
         """Resolve the selections."""
@@ -386,7 +396,14 @@ def _resolve_sel(self):
 
         sel = {}
 
-        sel_parsers = {"range": self._parse_range, "index": self._parse_index}
+        sel_parsers = {
+            "range": self._parse_range,
+            "index": partial(self._parse_index, type_=int),
+            "map": self._parse_index,
+        }
+
+        if not self.allow_index_map:
+            del sel_parsers["map"]
 
         # To enforce the precedence of range vs index selections, we rely on the fact
         # that a sort will place the axis_range keys after axis_index keys
@@ -398,7 +415,8 @@ def _resolve_sel(self):
 
             if type_ not in sel_parsers:
                 raise ValueError(
-                    f'Unsupported selection type "{type_}", or invalid key "{k}"'
+                    f'Unsupported selection type "{type_}", or invalid key "{k}". '
+                    "Note that map-type selections require `allow_index_map=True`."
                 )
 
             sel[f"{axis_name}_sel"] = sel_parsers[type_](self.selections[k])
@@ -419,15 +437,15 @@ def _parse_range(self, x):
 
         return slice(*x)
 
-    def _parse_index(self, x):
+    def _parse_index(self, x, type_=object):
         # Parse and validate an index type selection
        if not isinstance(x, (list, tuple)) or len(x) == 0:
             raise ValueError(f"Index spec must be a non-empty list or tuple. Got {x}.")
 
         for v in x:
-            if not isinstance(v, int):
-                raise ValueError(f"All elements of index spec must be ints. Got {x}")
+            if not isinstance(v, type_):
+                raise ValueError(f"All elements of index spec must be {type_}. Got {x}")
 
         return list(x)
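As a quick, self-contained illustration of how the new `_map` selections resolve to axis indices, the sketch below uses plain Python objects in place of a draco container and its `index_map` (the axis names and entries are toy values chosen for the example; the real code path is `SelectionsMixin._resolve_sel` plus the `process` hunk above):

    # Toy stand-ins for data.index_map and the user-supplied selections.
    index_map = {"pol": ["XX", "XY", "YX", "YY"], "freq": [400.0, 500.0, 600.0]}
    selections = {"pol_map": ["XX", "YY"], "freq_index": [0, 2]}

    sel = {}
    for key, value in selections.items():
        # Split off the selection type; the axis name may itself contain underscores.
        *axis, type_ = key.split("_")
        axis_name = "_".join(axis)

        if type_ == "map":
            # A map-type selection lists index_map entries; convert them to indices.
            imap = list(index_map[axis_name])
            value = [imap.index(x) for x in value]

        sel[axis_name] = value

    print(sel)  # {'pol': [0, 3], 'freq': [0, 2]}

With `allow_index_map: true` in the task config, a `pol_map: ["XX", "YY"]` entry therefore ends up as the index selection `[0, 3]` before the usual downselection machinery is applied.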