Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Optionally load an initial delay spectrum sample #256

Merged
merged 1 commit into from
Jan 11, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
59 changes: 50 additions & 9 deletions draco/analysis/delay.py
Original file line number Diff line number Diff line change
Expand Up @@ -607,11 +607,18 @@ class DelayGibbsSamplerBase(DelayTransformBase, random.RandomTask):
initial_amplitude : float, optional
The Gibbs sampler will be initialized with a flat power spectrum with
this amplitude. Default: 10.
save_samples : bool, optional.
The entire chain of samples will be saved rather than just the final
result. Default: False
initial_sample_path : str, optional
File path to load an initial power spectrum sample. If no file is given,
start with a flat power spectrum. Default: None
"""

nsamp = config.Property(proptype=int, default=20)
initial_amplitude = config.Property(proptype=float, default=10.0)
save_samples = config.Property(proptype=bool, default=False)
initial_sample_path = config.Property(proptype=str, default=None)

def _create_output(
self,
Expand Down Expand Up @@ -662,6 +669,32 @@ def _create_output(

return delay_spec

def _get_initial_S(self, nbase, ndelay, dtype):
"""Load the initial spectrum estimate if a file path exists.

Parameters
----------
nbase : int
Number of baselines
ndelay : int
Number of delay samples
dtype : type | np.dtype | str
Datatype for the sample if no path is given
"""
if self.initial_sample_path is not None:
cont = ContainerBase.from_file(self.initial_sample_path, distributed=True)

cont.redistribute("baseline")

# Extract the spectrum and ove the baseline axis to the front
initial_S = cont.spectrum[:].local_array
bl_ax = cont.spectrum.attrs["axis"].tolist().index("baseline")
initial_S = np.moveaxis(initial_S, bl_ax, 0)
else:
initial_S = np.ones((nbase, ndelay), dtype=dtype) * self.initial_amplitude

return initial_S

def _evaluate(self, data_view, weight_view, out_cont, delays, channel_ind):
"""Estimate the delay spectrum or power spectrum.

Expand All @@ -687,7 +720,7 @@ def _evaluate(self, data_view, weight_view, out_cont, delays, channel_ind):
ndelay = len(delays)

# Set initial conditions for delay power spectrum
initial_S = np.ones_like(delays) * self.initial_amplitude
initial_S = self._get_initial_S(nbase, ndelay, delays.dtype)

# Initialize the random number generator we'll use
rng = self.rng
Expand All @@ -710,7 +743,7 @@ def _evaluate(self, data_view, weight_view, out_cont, delays, channel_ind):
data,
ndelay,
weight,
initial_S,
initial_S[lbi],
window=self.window if self.apply_window else None,
fsel=non_zero_channel,
niter=self.nsamp,
Expand Down Expand Up @@ -1053,16 +1086,24 @@ def _create_output(
def _evaluate(self, data_view, weight_view, out_cont, delays, channel_ind):
ndata = len(data_view)
ndelay = len(delays)
nbase = out_cont.spectrum.shape[-2]

initial_S = (
np.identity(ndata)[:, :, np.newaxis]
* np.ones_like(delays)
* self.initial_amplitude
)
initial_S = self._get_initial_S(nbase, ndelay, delays.dtype)

if initial_S.ndim == 2:
# Expand the sample shape to match the number of datasets
initial_S = (
np.identity(ndata)[np.newaxis, ..., np.newaxis]
* initial_S[:, np.newaxis, np.newaxis]
)
elif (initial_S.ndim != 4) or (initial_S.shape[1] != ndata):
raise ValueError(
f"Expected an initial sample with dimension 4 and {ndata} datasets. "
f"Got sample with dimension {initial_S.ndim} and shape {initial_S.shape}."
)

# Initialize the random number generator we'll use
rng = self.rng
nbase = out_cont.spectrum.shape[-2]

# Iterate over all baselines and use the Gibbs sampler to estimate the spectrum
for lbi, bi in out_cont.spectrum[:].enumerate(axis=-2):
Expand All @@ -1082,7 +1123,7 @@ def _evaluate(self, data_view, weight_view, out_cont, delays, channel_ind):
data,
ndelay,
weight,
initial_S,
initial_S[lbi],
window=self.window if self.apply_window else None,
fsel=non_zero_channel,
niter=self.nsamp,
Expand Down
Loading