diff --git a/draco/analysis/delay.py b/draco/analysis/delay.py index b6dcfb8f6..f6acb021c 100644 --- a/draco/analysis/delay.py +++ b/draco/analysis/delay.py @@ -398,9 +398,7 @@ def process(self, ss): out_cont : `containers.DelayTransform` or `containers.DelaySpectrum` Output delay spectrum or delay power spectrum. """ - ss.redistribute("freq") - - delays, channel_ind = self._calculate_delays(ss.freq[:]) + delays, channel_ind = self._calculate_delays(ss) # Get views of data and weights appropriate for the type of processing we're # doing. @@ -409,10 +407,6 @@ def process(self, ss): # Create the right output container out_cont = self._create_output(ss, delays, coord_axes) - # Save the frequency axis of the input data as an attribute in the output - # container - out_cont.attrs["freq"] = ss.freq - # Evaluate frequency->delay transform. (self._evaluate take the empty output # container, fills it, and returns it) return self._evaluate(data_view, weight_view, out_cont, delays, channel_ind) @@ -492,13 +486,15 @@ def _create_output( """ raise NotImplementedError() - def _calculate_delays(self, freq: np.ndarray) -> tuple[np.ndarray, np.ndarray]: + def _calculate_delays( + self, ss: Union[FreqContainer, list[FreqContainer]] + ) -> tuple[np.ndarray, np.ndarray]: """Calculate the grid of delays. Parameters ---------- - freq - The frequencies in the data. + ss + A FreqContainer to determine the delays from. Returns ------- @@ -507,6 +503,13 @@ def _calculate_delays(self, freq: np.ndarray) -> tuple[np.ndarray, np.ndarray]: channel_ind The effective channel indices of the data. """ + if isinstance(ss, FreqContainer): + freq = ss.freq + elif len(ss) > 0: + freq = ss[0].freq + else: + raise TypeError("Could not find a frequency axis in the input.") + freq_zero = freq[0] if self.freq_zero is None else self.freq_zero freq_spacing = self.freq_spacing @@ -653,6 +656,10 @@ def _create_output( if self.save_samples: delay_spec.add_dataset("spectrum_samples") + # Save the frequency axis of the input data as an attribute in the output + # container + delay_spec.attrs["freq"] = ss.freq + return delay_spec def _evaluate(self, data_view, weight_view, out_cont, delays, channel_ind): @@ -761,6 +768,8 @@ def _process_data(self, ss): out_cont : `containers.DelayTransform` or `containers.DelaySpectrum` Container for output delay spectrum or power spectrum. """ + ss.redistribute("freq") + if self.dataset is not None: if self.dataset not in ss.datasets: raise ValueError( @@ -819,6 +828,8 @@ def _process_data(self, ss): out_cont : `containers.DelayTransform` or `containers.DelaySpectrum` Container for output delay spectrum or power spectrum. """ + ss.redistribute("freq") + tel = self.telescope # Construct the Stokes I vis, and transpose from [baseline, freq, ra] to @@ -878,6 +889,10 @@ def _create_output( delay_spec.create_index_map(ax, ss.index_map[ax]) delay_spec.attrs["baseline_axes"] = coord_axes + # Save the frequency axis of the input data as an attribute in the output + # container + delay_spec.attrs["freq"] = ss.freq + return delay_spec def _evaluate(self, data_view, weight_view, out_cont, delays, channel_ind):