diff --git a/draco/analysis/delay.py b/draco/analysis/delay.py index f6acb021c..7946aa925 100644 --- a/draco/analysis/delay.py +++ b/draco/analysis/delay.py @@ -965,6 +965,141 @@ def _evaluate(self, data_view, weight_view, out_cont, delays, channel_ind): DelaySpectrumWienerBase = DelaySpectrumWienerEstimator +class DelayCrossPowerSpectrumEstimator( + DelayGeneralContainerBase, DelayGibbsSamplerBase +): + """A delay cross power spectrum estimator. + + This takes multiple compatible `FreqContainer`s as inputs and will return a + `DelayCrossSpectrum` container with the full pair-wise cross power spectrum. + """ + + def _process_data( + self, sslist: list[FreqContainer] + ) -> tuple[list[mpiarray.MPIArray], list[mpiarray.MPIArray], list[str]]: + if len(sslist) == 0: + raise ValueError("No datasets passed.") + + freq_ref = sslist[0].freq + + data_views = [] + weight_views = [] + coord_axes = None + + for ss in sslist: + ss.redistribute("freq") + + if (ss.freq != freq_ref).all(): + raise ValueError("Input containers must have the same frequencies.") + dv, wv, ca = DelayGeneralContainerBase._process_data(self, ss) + + if coord_axes is not None and not coord_axes == ca: + raise ValueError("Different axes found for the input containers.") + + data_views.append(dv) + weight_views.append(wv) + coord_axes = ca + + return data_views, weight_views, coord_axes + + def _create_output( + self, ss: list[FreqContainer], delays: np.ndarray, coord_axes: list[str] + ) -> ContainerBase: + """Create the output container for the delay power spectrum. + + If `coord_axes` is a list of strings then it is assumed to be a list of the + names of the folded axes. If it's an array then assume it is the actual axis + definition. + """ + ssref = ss[0] + ndata = len(ss) + + # If only one axis is being collapsed, use that as the baseline axis definition, + # otherwise just use integer indices + if len(coord_axes) == 1: + bl = ssref.index_map[coord_axes[0]] + else: + bl = np.prod([len(ssref.index_map[ax]) for ax in coord_axes]) + + # Initialise the spectrum container + delay_spec = containers.DelayCrossSpectrum( + baseline=bl, + dataset=ndata, + delay=delays, + sample=self.nsamp, + attrs_from=ssref, + ) + + delay_spec.redistribute("baseline") + delay_spec.spectrum[:] = 0.0 + + # Copy the index maps for all the flattened axes into the output container, and + # write out their order into an attribute so we can reconstruct this easily + # when loading in the spectrum + if isinstance(coord_axes, list): + for ax in coord_axes: + delay_spec.create_index_map(ax, ssref.index_map[ax]) + delay_spec.attrs["baseline_axes"] = coord_axes + + if self.save_samples: + delay_spec.add_dataset("spectrum_samples") + + # Save the frequency axis of the input data as an attribute in the output + # container + delay_spec.attrs["freq"] = ssref.freq + + return delay_spec + + def _evaluate(self, data_view, weight_view, out_cont, delays, channel_ind): + ndata = len(data_view) + ndelay = len(delays) + + initial_S = ( + np.identity(ndata)[:, :, np.newaxis] + * np.ones_like(delays) + * self.initial_amplitude + ) + + # Initialize the random number generator we'll use + rng = self.rng + nbase = out_cont.spectrum.shape[-2] + + # Iterate over all baselines and use the Gibbs sampler to estimate the spectrum + for lbi, bi in out_cont.spectrum[:].enumerate(axis=-2): + self.log.debug(f"Delay transforming baseline {bi}/{nbase}") + + # Get the local selections for all datasets and combine into a single array + data = np.array([d.local_array[lbi] for d in data_view]) + weight = np.array([w.local_array[lbi] for w in weight_view]) + + # Apply the cuts to the data + t = self._cut_data(data, weight, channel_ind) + if t is None: + continue + data, weight, non_zero_channel = t + + spec = delay_spectrum_gibbs_cross( + data, + ndelay, + weight, + initial_S, + window=self.window if self.apply_window else None, + fsel=non_zero_channel, + niter=self.nsamp, + rng=rng, + ) + + # Take an average over the last half of the delay spectrum samples + # (presuming that removes the burn-in) + spec_av = np.median(spec[-(self.nsamp // 2) :], axis=0) + out_cont.spectrum[..., bi, :] = np.fft.fftshift(spec_av) + + if self.save_samples: + out_cont.datasets["spectrum_samples"][..., bi, :] = spec + + return out_cont + + def stokes_I(sstream, tel): """Extract instrumental Stokes I from a time/sidereal stream. diff --git a/draco/core/containers.py b/draco/core/containers.py index 78aaa131f..3e6755e63 100644 --- a/draco/core/containers.py +++ b/draco/core/containers.py @@ -2428,7 +2428,14 @@ class DelayCrossSpectrum(DelaySpectrum): "initialise": True, "distributed": True, "distributed_axis": "baseline", - } + }, + "spectrum_samples": { + "axes": ["sample", "dataset", "dataset", "baseline", "delay"], + "dtype": np.float64, + "initialise": False, + "distributed": True, + "distributed_axis": "baseline", + }, } @property