
Commit

feat(DelayCrossPowerSpectrumEstimator): calculate cross power spectra
jrs65 committed Jan 11, 2024
1 parent 140ada9 commit 58e8723
Showing 2 changed files with 143 additions and 1 deletion.
135 changes: 135 additions & 0 deletions draco/analysis/delay.py
@@ -965,6 +965,141 @@ def _evaluate(self, data_view, weight_view, out_cont, delays, channel_ind):
DelaySpectrumWienerBase = DelaySpectrumWienerEstimator


class DelayCrossPowerSpectrumEstimator(
    DelayGeneralContainerBase, DelayGibbsSamplerBase
):
    """A delay cross power spectrum estimator.

    This takes multiple compatible `FreqContainer`s as inputs and will return a
    `DelayCrossSpectrum` container with the full pair-wise cross power spectrum.
    """

    def _process_data(
        self, sslist: list[FreqContainer]
    ) -> tuple[list[mpiarray.MPIArray], list[mpiarray.MPIArray], list[str]]:
        if len(sslist) == 0:
            raise ValueError("No datasets passed.")

        freq_ref = sslist[0].freq

        data_views = []
        weight_views = []
        coord_axes = None

        for ss in sslist:
            ss.redistribute("freq")

            if (ss.freq != freq_ref).any():
                raise ValueError("Input containers must have the same frequencies.")

            dv, wv, ca = DelayGeneralContainerBase._process_data(self, ss)

            if coord_axes is not None and coord_axes != ca:
                raise ValueError("Different axes found for the input containers.")

            data_views.append(dv)
            weight_views.append(wv)
            coord_axes = ca

        return data_views, weight_views, coord_axes

    def _create_output(
        self, ss: list[FreqContainer], delays: np.ndarray, coord_axes: list[str]
    ) -> ContainerBase:
        """Create the output container for the delay power spectrum.

        If `coord_axes` is a list of strings then it is assumed to be a list of the
        names of the folded axes. If it's an array then assume it is the actual axis
        definition.
        """
        ssref = ss[0]
        ndata = len(ss)

        # If only one axis is being collapsed, use that as the baseline axis definition,
        # otherwise just use integer indices
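        # (the index maps of the flattened axes are copied into the output below so
        # the combined baseline axis can be unpacked later)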
        if len(coord_axes) == 1:
            bl = ssref.index_map[coord_axes[0]]
        else:
            bl = np.prod([len(ssref.index_map[ax]) for ax in coord_axes])

        # Initialise the spectrum container
        delay_spec = containers.DelayCrossSpectrum(
            baseline=bl,
            dataset=ndata,
            delay=delays,
            sample=self.nsamp,
            attrs_from=ssref,
        )

        delay_spec.redistribute("baseline")
        delay_spec.spectrum[:] = 0.0

        # Copy the index maps for all the flattened axes into the output container, and
        # write out their order into an attribute so we can reconstruct this easily
        # when loading in the spectrum
        if isinstance(coord_axes, list):
            for ax in coord_axes:
                delay_spec.create_index_map(ax, ssref.index_map[ax])
            delay_spec.attrs["baseline_axes"] = coord_axes

        if self.save_samples:
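            # The dataset is declared on containers.DelayCrossSpectrum (axes: sample,
            # dataset, dataset, baseline, delay) but not initialised by default, so it
            # is only added here on request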
            delay_spec.add_dataset("spectrum_samples")

        # Save the frequency axis of the input data as an attribute in the output
        # container
        delay_spec.attrs["freq"] = ssref.freq

        return delay_spec

    def _evaluate(self, data_view, weight_view, out_cont, delays, channel_ind):
        ndata = len(data_view)
        ndelay = len(delays)

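        # Starting guess for the Gibbs sampler: an identity cross-power matrix between
        # the datasets at every delay, scaled by the initial amplitude
        # (shape: ndata, ndata, ndelay)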
        initial_S = (
            np.identity(ndata)[:, :, np.newaxis]
            * np.ones_like(delays)
            * self.initial_amplitude
        )

        # Initialize the random number generator we'll use
        rng = self.rng
        nbase = out_cont.spectrum.shape[-2]

        # Iterate over all baselines and use the Gibbs sampler to estimate the spectrum
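        # (`lbi` indexes the local section of the distributed baseline axis, `bi` is
        # the matching global index)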
        for lbi, bi in out_cont.spectrum[:].enumerate(axis=-2):
            self.log.debug(f"Delay transforming baseline {bi}/{nbase}")

            # Get the local selections for all datasets and combine into a single array
            data = np.array([d.local_array[lbi] for d in data_view])
            weight = np.array([w.local_array[lbi] for w in weight_view])

            # Apply the cuts to the data
            t = self._cut_data(data, weight, channel_ind)
            if t is None:
                continue
            data, weight, non_zero_channel = t

            spec = delay_spectrum_gibbs_cross(
                data,
                ndelay,
                weight,
                initial_S,
                window=self.window if self.apply_window else None,
                fsel=non_zero_channel,
                niter=self.nsamp,
                rng=rng,
            )

            # Take the median over the last half of the delay spectrum samples
            # (presuming that removes the burn-in)
            spec_av = np.median(spec[-(self.nsamp // 2) :], axis=0)
            out_cont.spectrum[..., bi, :] = np.fft.fftshift(spec_av, axes=-1)

            if self.save_samples:
                out_cont.datasets["spectrum_samples"][..., bi, :] = spec

        return out_cont


def stokes_I(sstream, tel):
"""Extract instrumental Stokes I from a time/sidereal stream.
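
To make the pair-wise structure of the estimate concrete, here is a minimal NumPy sketch of a naive FFT-based cross power spectrum for a single baseline. The shapes, the normalisation and the absence of weighting or windowing are assumptions for illustration only; the task above instead estimates this quantity with the `delay_spectrum_gibbs_cross` sampler.

# Illustrative sketch only, not the estimator's actual algorithm.
import numpy as np

ndata, nfreq = 3, 64
rng = np.random.default_rng(0)
data = rng.standard_normal((ndata, nfreq)) + 1j * rng.standard_normal((ndata, nfreq))

# Transform frequency -> delay for each dataset, then form every pair-wise product
dspec = np.fft.fft(data, axis=-1)                     # shape (ndata, ndelay)
cross = np.einsum("ad,bd->abd", dspec, dspec.conj())  # shape (ndata, ndata, ndelay)

# Shift only the delay axis so tau = 0 sits in the centre, as in _evaluate above
cross = np.fft.fftshift(cross, axes=-1) / nfreq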
9 changes: 8 additions & 1 deletion draco/core/containers.py
@@ -2428,7 +2428,14 @@ class DelayCrossSpectrum(DelaySpectrum):
"initialise": True,
"distributed": True,
"distributed_axis": "baseline",
}
},
"spectrum_samples": {
"axes": ["sample", "dataset", "dataset", "baseline", "delay"],
"dtype": np.float64,
"initialise": False,
"distributed": True,
"distributed_axis": "baseline",
},
}

    @property
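
As a usage note, here is a short sketch of inspecting the new dataset after a run with `save_samples` enabled. The file name is a placeholder and loading via `from_file` is the generic container interface, so treat this as an assumption-laden example rather than part of the commit.

# Illustrative sketch only: load a saved cross spectrum and pull out the Gibbs samples.
from draco.core import containers

cs = containers.DelayCrossSpectrum.from_file("delay_cross_spectrum.h5")

# spectrum has axes (dataset, dataset, baseline, delay); spectrum_samples adds a
# leading sample axis and only exists if the estimator ran with save_samples.
spec = cs.spectrum[:]
if "spectrum_samples" in cs.datasets:
    samples = cs.datasets["spectrum_samples"][:]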
