Skip to content

Commit

Permalink
feat(delay): provide an initial guess for the gibbs sampler
Browse files Browse the repository at this point in the history
Provide a file path to delay power spectrum estimator when using
a Gibbs sampling method.
  • Loading branch information
ljgray committed Jan 9, 2024
1 parent 9b18e41 commit a061e07
Showing 1 changed file with 48 additions and 7 deletions.
55 changes: 48 additions & 7 deletions draco/analysis/delay.py
Original file line number Diff line number Diff line change
Expand Up @@ -607,11 +607,18 @@ class DelayGibbsSamplerBase(DelayTransformBase, random.RandomTask):
initial_amplitude : float, optional
The Gibbs sampler will be initialized with a flat power spectrum with
this amplitude. Default: 10.
save_samples : bool, optional.
The entire chain of samples will be saved rather than just the final
result. Default: False
initial_sample_path : str, optional
File path to load an initial power spectrum sample. If no file is given,
start with a flat power spectrum. Default: None
"""

nsamp = config.Property(proptype=int, default=20)
initial_amplitude = config.Property(proptype=float, default=10.0)
save_samples = config.Property(proptype=bool, default=False)
initial_sample_path = config.Property(proptype=str, default=None)

def _create_output(
self,
Expand Down Expand Up @@ -662,6 +669,40 @@ def _create_output(

return delay_spec

def _get_initial_S(self, nbase, ndelay, dtype):
"""Load the initial spectrum estimate if a file path exists.
Parameters
----------
nbase : int
Number of baselines
ndelay : int
Number of delay samples
dtype : int | type | np.dtype | str
Expected datatype for the sample. Must be provided in a way
than can be understood when calling `numpy.dtype`
"""
dtype = np.dtype(dtype)

if self.initial_sample_path is not None:
cont = ContainerBase.from_file(self.initial_sample_path, distributed=True)
# Make sure this is distributed over the correct axis
cont.redistribute("baseline")

initial_S = cont.spectrum[:].local_array
# Ensure the baseline axis is at the front
bl_ax = cont.spectrum.attrs["axis"].tolist().index("baseline")
initial_S = np.moveaxis(initial_S, bl_ax, 0)

if initial_S.dtype != dtype:
raise TypeError(
f"Expected spectrum with type {dtype}. Got {initial_S.dtype}."
)
else:
initial_S = np.ones((nbase, ndelay), dtype=dtype) * self.initial_amplitude

return initial_S

def _evaluate(self, data_view, weight_view, out_cont, delays, channel_ind):
"""Estimate the delay spectrum or power spectrum.
Expand All @@ -687,7 +728,7 @@ def _evaluate(self, data_view, weight_view, out_cont, delays, channel_ind):
ndelay = len(delays)

# Set initial conditions for delay power spectrum
initial_S = np.ones_like(delays) * self.initial_amplitude
initial_S = self._get_initial_S(nbase, ndelay, delays.dtype)

# Initialize the random number generator we'll use
rng = self.rng
Expand All @@ -710,7 +751,7 @@ def _evaluate(self, data_view, weight_view, out_cont, delays, channel_ind):
data,
ndelay,
weight,
initial_S,
initial_S[lbi],
window=self.window if self.apply_window else None,
fsel=non_zero_channel,
niter=self.nsamp,
Expand Down Expand Up @@ -1053,16 +1094,16 @@ def _create_output(
def _evaluate(self, data_view, weight_view, out_cont, delays, channel_ind):
ndata = len(data_view)
ndelay = len(delays)
nbase = out_cont.spectrum.shape[-2]

initial_S = self._get_initial_S(nbase, ndelay, delays.dtype)
initial_S = (
np.identity(ndata)[:, :, np.newaxis]
* np.ones_like(delays)
* self.initial_amplitude
np.identity(ndata)[np.newaxis, ..., np.newaxis]
* initial_S[:, np.newaxis, np.newaxis]
)

# Initialize the random number generator we'll use
rng = self.rng
nbase = out_cont.spectrum.shape[-2]

# Iterate over all baselines and use the Gibbs sampler to estimate the spectrum
for lbi, bi in out_cont.spectrum[:].enumerate(axis=-2):
Expand All @@ -1082,7 +1123,7 @@ def _evaluate(self, data_view, weight_view, out_cont, delays, channel_ind):
data,
ndelay,
weight,
initial_S,
initial_S[lbi],
window=self.window if self.apply_window else None,
fsel=non_zero_channel,
niter=self.nsamp,
Expand Down

0 comments on commit a061e07

Please sign in to comment.