Skip to content

Commit

Permalink
refactor(delay): change internal API of calculate_delay
Browse files Browse the repository at this point in the history
This is in preparation for the cross power spectrum code.
  • Loading branch information
jrs65 committed Jan 11, 2024
1 parent 12518e8 commit 140ada9
Showing 1 changed file with 25 additions and 10 deletions.
35 changes: 25 additions & 10 deletions draco/analysis/delay.py
Original file line number Diff line number Diff line change
Expand Up @@ -398,9 +398,7 @@ def process(self, ss):
out_cont : `containers.DelayTransform` or `containers.DelaySpectrum`
Output delay spectrum or delay power spectrum.
"""
ss.redistribute("freq")

delays, channel_ind = self._calculate_delays(ss.freq[:])
delays, channel_ind = self._calculate_delays(ss)

# Get views of data and weights appropriate for the type of processing we're
# doing.
Expand All @@ -409,10 +407,6 @@ def process(self, ss):
# Create the right output container
out_cont = self._create_output(ss, delays, coord_axes)

# Save the frequency axis of the input data as an attribute in the output
# container
out_cont.attrs["freq"] = ss.freq

# Evaluate frequency->delay transform. (self._evaluate take the empty output
# container, fills it, and returns it)
return self._evaluate(data_view, weight_view, out_cont, delays, channel_ind)
Expand Down Expand Up @@ -492,13 +486,15 @@ def _create_output(
"""
raise NotImplementedError()

def _calculate_delays(self, freq: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
def _calculate_delays(
self, ss: Union[FreqContainer, list[FreqContainer]]
) -> tuple[np.ndarray, np.ndarray]:
"""Calculate the grid of delays.
Parameters
----------
freq
The frequencies in the data.
ss
A FreqContainer to determine the delays from.
Returns
-------
Expand All @@ -507,6 +503,13 @@ def _calculate_delays(self, freq: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
channel_ind
The effective channel indices of the data.
"""
if isinstance(ss, FreqContainer):
freq = ss.freq
elif len(ss) > 0:
freq = ss[0].freq
else:
raise TypeError("Could not find a frequency axis in the input.")

freq_zero = freq[0] if self.freq_zero is None else self.freq_zero

freq_spacing = self.freq_spacing
Expand Down Expand Up @@ -653,6 +656,10 @@ def _create_output(
if self.save_samples:
delay_spec.add_dataset("spectrum_samples")

# Save the frequency axis of the input data as an attribute in the output
# container
delay_spec.attrs["freq"] = ss.freq

return delay_spec

def _evaluate(self, data_view, weight_view, out_cont, delays, channel_ind):
Expand Down Expand Up @@ -761,6 +768,8 @@ def _process_data(self, ss):
out_cont : `containers.DelayTransform` or `containers.DelaySpectrum`
Container for output delay spectrum or power spectrum.
"""
ss.redistribute("freq")

if self.dataset is not None:
if self.dataset not in ss.datasets:
raise ValueError(
Expand Down Expand Up @@ -819,6 +828,8 @@ def _process_data(self, ss):
out_cont : `containers.DelayTransform` or `containers.DelaySpectrum`
Container for output delay spectrum or power spectrum.
"""
ss.redistribute("freq")

tel = self.telescope

# Construct the Stokes I vis, and transpose from [baseline, freq, ra] to
Expand Down Expand Up @@ -878,6 +889,10 @@ def _create_output(
delay_spec.create_index_map(ax, ss.index_map[ax])
delay_spec.attrs["baseline_axes"] = coord_axes

# Save the frequency axis of the input data as an attribute in the output
# container
delay_spec.attrs["freq"] = ss.freq

return delay_spec

def _evaluate(self, data_view, weight_view, out_cont, delays, channel_ind):
Expand Down

0 comments on commit 140ada9

Please sign in to comment.