feat(SelectionsMixin): allow selection from index map values
ljgray committed Oct 23, 2024
1 parent e0fc262 commit 7bdd270
Showing 2 changed files with 47 additions and 16 deletions.
27 changes: 20 additions & 7 deletions draco/analysis/transform.py
@@ -1354,22 +1354,35 @@ def process(self, data: containers.ContainerBase) -> containers.ContainerBase:
Container of same type as the input with specific axis selections.
Any datasets not included in the selections will not be initialized.
"""
# Re-format selections to only use axis name
for ax_sel in list(self._sel):
ax = ax_sel.replace("_sel", "")
self._sel[ax] = self._sel.pop(ax_sel)
sel = {}

# Parse axes with selections and reformat to use only
# the axis name
for k in self.selections:
*axis, type_ = k.split("_")
axis_name = "_".join(axis)

ax_sel = self._sel.get(f"{axis_name}_sel")

if type_ == "map":
# Use index map to get the correct axis indices
imap = list(data.index_map[axis_name])
ax_sel = [imap.index(x) for x in ax_sel]

if ax_sel is not None:
sel[axis_name] = ax_sel

# Figure out the axes for the new container and
# apply the downselections to each axis index_map
output_axes = {
ax: mpiarray._apply_sel(data.index_map[ax], sel, 0)
for ax, sel in self._sel.items()
ax: mpiarray._apply_sel(data.index_map[ax], ax_sel, 0)
for ax, ax_sel in sel.items()
}
# Create the output container without initializing any datasets.
out = data.__class__(
axes_from=data, attrs_from=data, skip_datasets=True, **output_axes
)
containers.copy_datasets_filter(data, out, selection=self._sel)
containers.copy_datasets_filter(data, out, selection=sel)

return out

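For illustration only, and not part of the commit: a minimal sketch of the map-based lookup performed in `process` above, using a hypothetical `pol` axis whose index map holds polarisation labels.

# Standalone sketch (axis name and labels are assumed, not taken from the diff):
imap = ["XX", "XY", "YX", "YY"]                  # stand-in for list(data.index_map["pol"])
ax_sel = [imap.index(x) for x in ["XX", "YY"]]   # a `pol_map` selection resolved to indices
print(ax_sel)                                    # -> [0, 3]
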
36 changes: 27 additions & 9 deletions draco/core/io.py
@@ -26,6 +26,7 @@
import os.path
import shutil
import subprocess
from functools import partial
from typing import ClassVar, Optional, Union

import numpy as np
@@ -351,6 +352,11 @@ class SelectionsMixin:
selections : dict, optional
A dictionary of axis selections. See below for details.
allow_index_map : bool, optional
If True, selections can be made based on an index_map dataset.
This cannot be used when reading from disk. See below for
details. Default is False.
Selections
----------
Selections can be given to limit the data read to specified subsets. They can be
@@ -359,11 +365,13 @@ class SelectionsMixin:
Selections can be given as a slice with an `<axis name>_range` key with either
`[start, stop]` or `[start, stop, step]` as the value. Alternatively a list of
explicit indices to extract can be given with the `<axis name>_index` key, and
the value is a list of the indices. If both `<axis name>_range` and `<axis
name>_index` keys are given the former will take precedence, but you should
clearly avoid doing this.
the value is a list of the indices. Finally, a selection based on `index_map`
entries can be given with the `<axis name>_map` key, whose values will be
converted to axis indices. `<axis name>_range` will take precedence over
`<axis name>_index`, which will in turn take precedence over `<axis name>_map`,
but you should avoid mixing selection types for the same axis.
Additionally index based selections currently don't work for distributed reads.
Additionally, index-based selections currently don't work for distributed reads.
Here's an example in the YAML format that the pipeline uses:
@@ -373,9 +381,11 @@ def _resolve_sel(self):
freq_range: [256, 512, 4] # A strided slice
stack_index: [1, 2, 4, 9, 16, 25, 36, 49, 64] # A sparse selection
stack_range: [1, 14] # Will override the selection above
pol_map: ["XX", "YY"] # Select the indices corresponding to these entries
"""

selections = config.Property(proptype=dict, default=None)
allow_index_map = config.Property(proptype=bool, default=False)

def setup(self):
"""Resolve the selections."""
@@ -386,7 +396,14 @@ def _resolve_sel(self):

sel = {}

sel_parsers = {"range": self._parse_range, "index": self._parse_index}
sel_parsers = {
"range": self._parse_range,
"index": partial(self._parse_index, type_=int),
"map": self._parse_index,
}

if not self.allow_index_map:
del sel_parsers["map"]

# To enforce the precedence of range vs index selections, we rely on the fact
# that a sort will place the axis_range keys after axis_index keys
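
Purely illustrative, not from the commit: the precedence mentioned in the comment above comes from plain lexicographic sorting of the selection keys, so for a given axis the key sorted last is applied last and overwrites earlier entries in `sel`.

# For one axis, sorted() puts the range key after the index key:
sorted(["stack_range", "stack_index", "freq_range"])
# -> ['freq_range', 'stack_index', 'stack_range']
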
@@ -398,7 +415,8 @@ def _resolve_sel(self):

if type_ not in sel_parsers:
raise ValueError(
f'Unsupported selection type "{type_}", or invalid key "{k}"'
f'Unsupported selection type "{type_}", or invalid key "{k}". '
"Note that map-type selections require `allow_index_map=True`."
)

sel[f"{axis_name}_sel"] = sel_parsers[type_](self.selections[k])
@@ -419,15 +437,15 @@ def _parse_range(self, x):

return slice(*x)

def _parse_index(self, x):
def _parse_index(self, x, type_=object):
# Parse and validate an index type selection

if not isinstance(x, (list, tuple)) or len(x) == 0:
raise ValueError(f"Index spec must be a non-empty list or tuple. Got {x}.")

for v in x:
if not isinstance(v, int):
raise ValueError(f"All elements of index spec must be ints. Got {x}")
if not isinstance(v, type_):
raise ValueError(f"All elements of index spec must be {type_}. Got {x}")

return list(x)

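For illustration only, assuming the helper behaves exactly as shown in this diff: a standalone sketch of how the two parser variants built in `_resolve_sel` differ, with `functools.partial` pinning `type_=int` for `<axis>_index` keys while `<axis>_map` keys accept arbitrary index-map values.

from functools import partial

def parse_index(x, type_=object):
    # Mirrors _parse_index above: validate a non-empty list of `type_` elements.
    if not isinstance(x, (list, tuple)) or len(x) == 0:
        raise ValueError(f"Index spec must be a non-empty list or tuple. Got {x}.")
    for v in x:
        if not isinstance(v, type_):
            raise ValueError(f"All elements of index spec must be {type_}. Got {x}")
    return list(x)

parse_ints = partial(parse_index, type_=int)  # how `<axis>_index` selections are parsed
print(parse_ints([1, 2, 4]))                  # -> [1, 2, 4]
print(parse_index(["XX", "YY"]))              # `<axis>_map` selections allow any values
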
