Skip to content

Commit

Permalink
feat(delay): provide an initial guess for the gibbs sampler
Browse files Browse the repository at this point in the history
Provide a file path to delay power spectrum estimator when using
a Gibbs sampling method.
  • Loading branch information
ljgray committed Jan 11, 2024
1 parent a591a6d commit 9df3973
Showing 1 changed file with 50 additions and 9 deletions.
59 changes: 50 additions & 9 deletions draco/analysis/delay.py
Original file line number Diff line number Diff line change
Expand Up @@ -607,11 +607,18 @@ class DelayGibbsSamplerBase(DelayTransformBase, random.RandomTask):
initial_amplitude : float, optional
The Gibbs sampler will be initialized with a flat power spectrum with
this amplitude. Default: 10.
save_samples : bool, optional.
The entire chain of samples will be saved rather than just the final
result. Default: False
initial_sample_path : str, optional
File path to load an initial power spectrum sample. If no file is given,
start with a flat power spectrum. Default: None
"""

nsamp = config.Property(proptype=int, default=20)
initial_amplitude = config.Property(proptype=float, default=10.0)
save_samples = config.Property(proptype=bool, default=False)
initial_sample_path = config.Property(proptype=str, default=None)

def _create_output(
self,
Expand Down Expand Up @@ -662,6 +669,32 @@ def _create_output(

return delay_spec

def _get_initial_S(self, nbase, ndelay, dtype):
"""Load the initial spectrum estimate if a file path exists.
Parameters
----------
nbase : int
Number of baselines
ndelay : int
Number of delay samples
dtype : type | np.dtype | str
Datatype for the sample if no path is given
"""
if self.initial_sample_path is not None:
cont = ContainerBase.from_file(self.initial_sample_path, distributed=True)

cont.redistribute("baseline")

# Extract the spectrum and ove the baseline axis to the front
initial_S = cont.spectrum[:].local_array
bl_ax = cont.spectrum.attrs["axis"].tolist().index("baseline")
initial_S = np.moveaxis(initial_S, bl_ax, 0)
else:
initial_S = np.ones((nbase, ndelay), dtype=dtype) * self.initial_amplitude

return initial_S

def _evaluate(self, data_view, weight_view, out_cont, delays, channel_ind):
"""Estimate the delay spectrum or power spectrum.
Expand All @@ -687,7 +720,7 @@ def _evaluate(self, data_view, weight_view, out_cont, delays, channel_ind):
ndelay = len(delays)

# Set initial conditions for delay power spectrum
initial_S = np.ones_like(delays) * self.initial_amplitude
initial_S = self._get_initial_S(nbase, ndelay, delays.dtype)

# Initialize the random number generator we'll use
rng = self.rng
Expand All @@ -710,7 +743,7 @@ def _evaluate(self, data_view, weight_view, out_cont, delays, channel_ind):
data,
ndelay,
weight,
initial_S,
initial_S[lbi],
window=self.window if self.apply_window else None,
fsel=non_zero_channel,
niter=self.nsamp,
Expand Down Expand Up @@ -1053,16 +1086,24 @@ def _create_output(
def _evaluate(self, data_view, weight_view, out_cont, delays, channel_ind):
ndata = len(data_view)
ndelay = len(delays)
nbase = out_cont.spectrum.shape[-2]

initial_S = (
np.identity(ndata)[:, :, np.newaxis]
* np.ones_like(delays)
* self.initial_amplitude
)
initial_S = self._get_initial_S(nbase, ndelay, delays.dtype)

if initial_S.ndim == 2:
# Expand the sample shape to match the number of datasets
initial_S = (
np.identity(ndata)[np.newaxis, ..., np.newaxis]
* initial_S[:, np.newaxis, np.newaxis]
)
elif (initial_S.ndim != 4) or (initial_S.shape[1] != ndata):
raise ValueError(
f"Expected an initial sample with dimension 4 and {ndata} datasets. "
f"Got sample with dimension {initial_S.ndim} and shape {initial_S.shape}."
)

# Initialize the random number generator we'll use
rng = self.rng
nbase = out_cont.spectrum.shape[-2]

# Iterate over all baselines and use the Gibbs sampler to estimate the spectrum
for lbi, bi in out_cont.spectrum[:].enumerate(axis=-2):
Expand All @@ -1082,7 +1123,7 @@ def _evaluate(self, data_view, weight_view, out_cont, delays, channel_ind):
data,
ndelay,
weight,
initial_S,
initial_S[lbi],
window=self.window if self.apply_window else None,
fsel=non_zero_channel,
niter=self.nsamp,
Expand Down

0 comments on commit 9df3973

Please sign in to comment.