Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Maximum likelihood delay power spectrum estimator #262

Merged
merged 7 commits into from
Jun 17, 2024
207 changes: 153 additions & 54 deletions draco/analysis/delay.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@

from ..core import containers, io, task
from ..util import random, tools
from .delayopt import delay_power_spectrum_maxpost


class DelayFilter(task.SingleTask):
Expand Down Expand Up @@ -362,6 +363,22 @@ class DelayTransformBase(task.SingleTask):
assume the noise power in the data is `weight_boost` times lower, which is
useful if you want the "true" noise to not be downweighted by the Wiener filter,
or have it included in the Gibbs sampler. Default: 1.0.
freq_frac
The threshold for the fraction of time samples present in a frequency for it
to be retained. Must be strictly greater than this value, so the default
value 0 retains any channel with at least one sample. A value of 0.01 would
retain any frequency that has > 1% of time samples unmasked.
time_frac
The threshold for the fraction of frequency samples required to retain a
time sample. Must be strictly greater than this value. The default value (-1)
means that all time samples are kept. A value of 0.01 would keep any time
sample with >1% of frequencies unmasked.
remove_mean
Subtract the mean in time of each frequency channel. This is done after time
samples are pruned by the `time_frac` threshold.
scale_freq
Scale each frequency by its standard deviation to flatten the fluctuations
across the band. Applied before any apodisation is done.
"""

freq_zero = config.Property(proptype=float, default=None)
Expand All @@ -385,6 +402,12 @@ class DelayTransformBase(task.SingleTask):
complex_timedomain = config.Property(proptype=bool, default=False)
weight_boost = config.Property(proptype=float, default=1.0)

freq_frac = config.Property(proptype=float, default=0.0)
time_frac = config.Property(proptype=float, default=-1.0)

remove_mean = config.Property(proptype=bool, default=False)
scale_freq = config.Property(proptype=bool, default=False)

def process(self, ss):
"""Estimate the delay spectrum or power spectrum.

Expand Down Expand Up @@ -542,8 +565,10 @@ def _calculate_delays(
# NOTE: this not obviously the right level for this, but it's the only baseclass in
# common to where it's used
def _cut_data(
self, data: np.ndarray, weight: np.ndarray, channel_ind: np.ndarray
) -> Optional[tuple[np.ndarray, np.ndarray, np.ndarray]]:
self,
data: np.ndarray,
weight: np.ndarray,
) -> Optional[tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]]:
"""Apply cuts on the data and weights and returned modified versions.

Parameters
Expand All @@ -553,8 +578,6 @@ def _cut_data(
the second last.
weight
A n-d array of the weights. Axes the same as the data.
channel_ind
The indices of the frequency channels.

Returns
-------
Expand All @@ -563,35 +586,51 @@ def _cut_data(
new_weight
The new weights with cuts applied and averaged over the `average_axis` (i.e
second last).
new_channel_ind
The indices of the remaining channels after cuts.
non_zero_freq
The selection of frequencies retained. A boolean array of shape N_freq that
is true at indices of frequencies retained after applying the freq_frac cut.
non_zero_time
The selection of times retained. A boolean array of shape N_time that is
true at indices of time samples retained after applying the time_frac cut.
"""
# Mask out data with completely zero'd weights and generate time
# averaged weights
weight_cut = (
1e-4 * weight.mean()
) # Use approx threshold to ignore small weights
data = data * (weight > weight_cut)
weight = np.mean(weight, axis=-2)
ntime, nfreq = data.shape[-2:]
non_zero_time = (weight > 0).mean(axis=-1).reshape(-1, ntime).mean(
axis=0
) > self.time_frac
non_zero_freq = (weight > 0).mean(axis=-2).reshape(-1, nfreq).mean(
axis=0
) > self.freq_frac

if (data == 0.0).all():
# If there are no non-zero weighted entries skip
if not non_zero_freq.any():
return None

# If there are no non-zero weighted entries skip
non_zero = (weight > 0).reshape(-1, weight.shape[-1]).all(axis=0)
if not non_zero.any():
data = data[..., non_zero_time, :][..., non_zero_freq]
weight = weight[..., non_zero_time, :][..., non_zero_freq]

# Remove the mean from the data before estimating the spectrum
if self.remove_mean:
# Do not apply this in place to make sure we don't modify
# the input data
data = data - data.mean(axis=0, keepdims=True)

# If there are no non-zero data entries skip
if (data == 0.0).all():
return None

# Remove any frequency channel which is entirely zero, this is just to
# reduce the computational cost, it should make no difference to the result
data = data[..., non_zero]
weight = weight[..., non_zero]
non_zero_channel = channel_ind[non_zero]
# Scale the frequencies by the typical fluctuation size, with a scaling to
# obtain constant total power
if self.scale_freq:
dscl = (
data.std(axis=-2)[..., np.newaxis, :]
/ data.std(axis=(-1, -2))[..., np.newaxis, np.newaxis]
)
data = data * tools.invert_no_zero(dscl)

# Increase the weights by a specified amount
weight = np.mean(weight, axis=-2)
weight *= self.weight_boost

return data, weight, non_zero_channel
return data, weight, non_zero_freq, non_zero_time


class DelayGibbsSamplerBase(DelayTransformBase, random.RandomTask):
Expand All @@ -603,22 +642,34 @@ class DelayGibbsSamplerBase(DelayTransformBase, random.RandomTask):
Attributes
----------
nsamp : int, optional
The number of Gibbs samples to draw.
If maxpost=False, the number of Gibbs samples to draw. If maxpost=True,
the number of iterations allowed in the call to scipy.optimize.minimize
in the maximum-likelihood estimator.
initial_amplitude : float, optional
The Gibbs sampler will be initialized with a flat power spectrum with
this amplitude. Default: 10.
this amplitude. Unused if maxpost=True (flat spectrum is a bad initial
guess for the max-likelihood estimator). Default: 10.
save_samples : bool, optional.
The entire chain of samples will be saved rather than just the final
result. Default: False
initial_sample_path : str, optional
File path to load an initial power spectrum sample. If no file is given,
start with a flat power spectrum. Default: None
start with a flat power spectrum (Gibbs) or inverse FFT (max-likelihood).
Default: None
maxpost : bool, optional
The NRML maximum-likelihood delay spectrum estimator will be used instead
of the Gibbs sampler.
maxpost_tol : float, optional
Only used if maxpost=True. The convergence tolerance used by
scipy.optimize.minimize in the maximum likelihood estimator.
"""

nsamp = config.Property(proptype=int, default=20)
initial_amplitude = config.Property(proptype=float, default=10.0)
save_samples = config.Property(proptype=bool, default=False)
initial_sample_path = config.Property(proptype=str, default=None)
maxpost = config.Property(proptype=bool, default=False)
maxpost_tol = config.Property(proptype=float, default=1e-3)

def _create_output(
self,
Expand Down Expand Up @@ -663,6 +714,12 @@ def _create_output(
if self.save_samples:
delay_spec.add_dataset("spectrum_samples")

# If estimating delay spectrum w/ max-likelihood, initialize a mask dataset
# to record the baselines for which the estimator did/didn't converge.
if self.maxpost:
delay_spec.add_dataset("spectrum_mask")
delay_spec.datasets["spectrum_mask"][:] = 0

# Save the frequency axis of the input data as an attribute in the output
# container
delay_spec.attrs["freq"] = ss.freq
Expand Down Expand Up @@ -690,8 +747,16 @@ def _get_initial_S(self, nbase, ndelay, dtype):
initial_S = cont.spectrum[:].local_array
bl_ax = cont.spectrum.attrs["axis"].tolist().index("baseline")
initial_S = np.moveaxis(initial_S, bl_ax, 0)
else:
# Gibbs case.
elif not self.maxpost:
initial_S = np.ones((nbase, ndelay), dtype=dtype) * self.initial_amplitude
# Max-likelihood case.
else:
# Flat spectrum is a bad initial guess for max-likelihood.
# Passing None as the initial guess to the max-likelihood
# estimator will cause it to use an inverse FFT as the
# initial guess, which works well in practice.
initial_S = np.full(nbase, None)

return initial_S

Expand Down Expand Up @@ -734,30 +799,64 @@ def _evaluate(self, data_view, weight_view, out_cont, delays, channel_ind):
weight = weight_view.local_array[lbi]

# Apply the cuts to the data
t = self._cut_data(data, weight, channel_ind)
t = self._cut_data(data, weight)
if t is None:
continue
data, weight, non_zero_channel = t
data, weight, nzf, _ = t

if self.maxpost:
spec, success = delay_power_spectrum_maxpost(
data,
ndelay,
weight,
initial_S[lbi],
window=self.window if self.apply_window else None,
fsel=channel_ind[nzf],
maxiter=self.nsamp,
tol=self.maxpost_tol,
)

spec = delay_power_spectrum_gibbs(
data,
ndelay,
weight,
initial_S[lbi],
window=self.window if self.apply_window else None,
fsel=non_zero_channel,
niter=self.nsamp,
rng=rng,
complex_timedomain=self.complex_timedomain,
)
# If max-likelihood didn't converge in allowed number of iters, reflect this in the mask.
if not success:
# Indexing into a MemDatasetDistributed object with the
# global index bi actually ends up (under the hood)
# indexing the underlying MPIArray with the local index.
out_cont.datasets["spectrum_mask"][bi] = 1
jmaceachern marked this conversation as resolved.
Show resolved Hide resolved

# Take an average over the last half of the delay spectrum samples
# (presuming that removes the burn-in)
spec_av = np.median(spec[-(self.nsamp // 2) :], axis=0)
out_cont.spectrum[bi] = np.fft.fftshift(spec_av)
out_cont.spectrum[bi] = np.fft.fftshift(spec[-1])

if self.save_samples:
out_cont.datasets["spectrum_samples"][:, bi] = spec
if self.save_samples:
nsamp = len(spec)
out_cont.datasets["spectrum_samples"][:, bi] = 0.0
out_cont.datasets["spectrum_samples"][-nsamp:, bi] = np.array(spec)

else:
spec = delay_power_spectrum_gibbs(
data,
ndelay,
weight,
initial_S[lbi],
window=self.window if self.apply_window else None,
fsel=channel_ind[nzf],
niter=self.nsamp,
rng=rng,
complex_timedomain=self.complex_timedomain,
)

# Take an average over the last half of the delay spectrum samples
# (presuming that removes the burn-in)
spec_av = np.median(spec[-(self.nsamp // 2) :], axis=0)
out_cont.spectrum[bi] = np.fft.fftshift(spec_av)

if self.save_samples:
out_cont.datasets["spectrum_samples"][:, bi] = spec

if self.maxpost:
# Record number of converged baselines for debugging info.
n_conv = nbase - out_cont.datasets["spectrum_mask"][:].allgather().sum()
self.log.debug(
f"{n_conv}/{nbase} baselines converged in maximum-likelihood estimate of delay power spectrum."
)

return out_cont

Expand Down Expand Up @@ -968,10 +1067,10 @@ def _evaluate(self, data_view, weight_view, out_cont, delays, channel_ind):
weight = weight_view.local_array[lbi]

# Apply the cuts to the data
t = self._cut_data(data, weight, channel_ind)
t = self._cut_data(data, weight)
if t is None:
continue
data, weight, non_zero_channel = t
data, weight, nzf, _ = t

# Pass the delay power spectrum and frequency spectrum for each "baseline"
to the Wiener filtering routine. The delay power spectrum has been
Expand All @@ -983,7 +1082,7 @@ def _evaluate(self, data_view, weight_view, out_cont, delays, channel_ind):
ndelay,
weight,
window=self.window if self.apply_window else None,
fsel=non_zero_channel,
fsel=channel_ind[nzf],
complex_timedomain=self.complex_timedomain,
)
# FFT-shift along the last axis
Expand Down Expand Up @@ -1114,18 +1213,18 @@ def _evaluate(self, data_view, weight_view, out_cont, delays, channel_ind):
weight = np.array([w.local_array[lbi] for w in weight_view])

# Apply the cuts to the data
t = self._cut_data(data, weight, channel_ind)
t = self._cut_data(data, weight)
if t is None:
continue
data, weight, non_zero_channel = t
data, weight, nzf, _ = t

spec = delay_spectrum_gibbs_cross(
data,
ndelay,
weight,
initial_S[lbi],
window=self.window if self.apply_window else None,
fsel=non_zero_channel,
fsel=channel_ind[nzf],
niter=self.nsamp,
rng=rng,
)
Expand Down Expand Up @@ -1512,7 +1611,7 @@ def _draw_signal_sample_f(S):
# power spectrum equal to 0.5 times the power spectrum of the complex
# delay spectrum, if the statistics are circularly symmetric
S = 0.5 * np.repeat(S, 2)
Si = 1.0 / S
Si = 1.0 * tools.invert_no_zero(S)
Ci = np.diag(Si) + FTNiF

# Draw random vectors that form the perturbations
Expand Down
Loading
Loading