Skip to content

Commit

Permalink
refactor(delay): use flatten_axes routine
Browse files Browse the repository at this point in the history
  • Loading branch information
jrs65 committed Dec 19, 2023
1 parent 365e22d commit c8bffd8
Showing 1 changed file with 7 additions and 30 deletions.
37 changes: 7 additions & 30 deletions draco/analysis/delay.py
Original file line number Diff line number Diff line change
Expand Up @@ -642,41 +642,19 @@ def _process_data(self, ss):
)

# Find the relevant axis positions
data_axes = ss.datasets[self.dataset].attrs["axis"]
freq_axis_pos = list(data_axes).index("freq")
average_axis_pos = list(data_axes).index(self.average_axis)

# Create a view of the dataset with the relevant axes at the back,
# and all other axes compressed. End result is packed as
# [baseline_axis, average_axis, freq_axis].
data_view = np.moveaxis(
ss.datasets[self.dataset][:].local_array,
[average_axis_pos, freq_axis_pos],
[-2, -1],
data_view, bl_axes = flatten_axes(
ss.datasets[self.dataset], [self.average_axis, "freq"]
)
data_view = data_view.reshape(-1, data_view.shape[-2], data_view.shape[-1])
data_view = mpiarray.MPIArray.wrap(data_view, axis=2, comm=ss.comm)
nbase = int(np.prod(data_view.shape[:-2]))
data_view = data_view.redistribute(axis=0)

# ... do the same for the weights, but we also need to make the weights full
# size
weight_full = np.zeros(
ss.datasets[self.dataset][:].shape, dtype=ss.weight.dtype
weight_view, _ = flatten_axes(
ss.weight,
[self.average_axis, "freq"],
match_dset=ss.datasets[self.dataset],
)
weight_full[:] = match_axes(ss.datasets[self.dataset], ss.weight)
weight_view = np.moveaxis(
weight_full, [average_axis_pos, freq_axis_pos], [-2, -1]
)
weight_view = weight_view.reshape(
-1, weight_view.shape[-2], weight_view.shape[-1]
)
weight_view = mpiarray.MPIArray.wrap(weight_view, axis=2, comm=ss.comm)
weight_view = weight_view.redistribute(axis=0)

# Use the "baselines" axis to generically represent all the other axes

# Initialise the spectrum container
nbase = data_view.global_shape[0]
if self.output_power_spectrum:
delay_spec = containers.DelaySpectrum(
baseline=nbase,
Expand All @@ -693,7 +671,6 @@ def _process_data(self, ss):
)
delay_spec.redistribute("baseline")
delay_spec.spectrum[:] = 0.0
bl_axes = [da for da in data_axes if da not in [self.average_axis, "freq"]]

# Copy the index maps for all the flattened axes into the output container, and
# write out their order into an attribute so we can reconstruct this easily
Expand Down

0 comments on commit c8bffd8

Please sign in to comment.