diff --git a/draco/analysis/delay.py b/draco/analysis/delay.py index 4b3ed3bf..7a9b8b7a 100644 --- a/draco/analysis/delay.py +++ b/draco/analysis/delay.py @@ -12,6 +12,7 @@ from ..core import containers, io, task from ..util import random, tools +from .delayopt import delay_power_spectrum_maxpost class DelayFilter(task.SingleTask): @@ -642,22 +643,34 @@ class DelayGibbsSamplerBase(DelayTransformBase, random.RandomTask): Attributes ---------- nsamp : int, optional - The number of Gibbs samples to draw. + If maxpost=False, the number of Gibbs samples to draw. If maxpost=True, + the number of iterations allowed in the call to scipy.optimize.minimize + in the maximum-likelihood estimator. initial_amplitude : float, optional The Gibbs sampler will be initialized with a flat power spectrum with - this amplitude. Default: 10. + this amplitude. Unused if maxpost=True (flat spectrum is a bad initial + guess for the max-likelihood estimator). Default: 10. save_samples : bool, optional. The entire chain of samples will be saved rather than just the final result. Default: False initial_sample_path : str, optional File path to load an initial power spectrum sample. If no file is given, - start with a flat power spectrum. Default: None + start with a flat power spectrum (Gibbs) or inverse FFT (max-likelihood). + Default: None + maxpost : bool, optional + The NRML maximum-likelihood delay spectrum estimator will be used instead + of the Gibbs sampler. + maxpost_tol : float, optional + Only used if maxpost=True. The convergence tolerance used by + scipy.optimize.minimize in the maximum likelihood estimator. """ nsamp = config.Property(proptype=int, default=20) initial_amplitude = config.Property(proptype=float, default=10.0) save_samples = config.Property(proptype=bool, default=False) initial_sample_path = config.Property(proptype=str, default=None) + maxpost = config.Property(proptype=bool, default=False) + maxpost_tol = config.Property(proptype=float, default=1e-3) def _create_output( self, @@ -702,6 +715,12 @@ def _create_output( if self.save_samples: delay_spec.add_dataset("spectrum_samples") + # If estimating delay spectrum w/ max-likelihood, initialize a mask dataset + # to record the baselines for which the estimator did/didn't converge. + if self.maxpost: + delay_spec.add_dataset("spectrum_mask") + delay_spec.datasets["spectrum_mask"][:] = 0 + # Save the frequency axis of the input data as an attribute in the output # container delay_spec.attrs["freq"] = ss.freq @@ -729,8 +748,16 @@ def _get_initial_S(self, nbase, ndelay, dtype): initial_S = cont.spectrum[:].local_array bl_ax = cont.spectrum.attrs["axis"].tolist().index("baseline") initial_S = np.moveaxis(initial_S, bl_ax, 0) - else: + # Gibbs case. + elif not self.maxpost: initial_S = np.ones((nbase, ndelay), dtype=dtype) * self.initial_amplitude + # Max-likelihood case. + else: + # Flat spectrum is a bad initial guess for max-likelihood. + # Passing None as the initial guess to the max-likelihood + # estimator will cause it to use an inverse FFT as the + # initial guess, which works well in practice. + initial_S = np.full(nbase, None) return initial_S @@ -773,30 +800,65 @@ def _evaluate(self, data_view, weight_view, out_cont, delays, channel_ind): weight = weight_view.local_array[lbi] # Apply the cuts to the data - t = self._cut_data(data, weight, channel_ind) + t = self._cut_data(data, weight) if t is None: continue - data, weight, non_zero_channel = t + data, weight, nzf, _ = t - spec = delay_power_spectrum_gibbs( - data, - ndelay, - weight, - initial_S[lbi], - window=self.window if self.apply_window else None, - fsel=non_zero_channel, - niter=self.nsamp, - rng=rng, - complex_timedomain=self.complex_timedomain, - ) + if self.maxpost: + spec, success = delay_power_spectrum_maxpost( + data, + ndelay, + weight, + initial_S[lbi], + window=self.window if self.apply_window else None, + fsel=channel_ind[nzf], + maxiter=self.nsamp, + tol=self.maxpost_tol, + ) - # Take an average over the last half of the delay spectrum samples - # (presuming that removes the burn-in) - spec_av = np.median(spec[-(self.nsamp // 2) :], axis=0) - out_cont.spectrum[bi] = np.fft.fftshift(spec_av) + # If max-likelihood didn't converge in allowed number of iters, reflect this in the mask. + if not success: + out_cont.datasets["spectrum_mask"][ + bi + ] = 1 # Indexing into a MemDatasetDistributed object with the + # global index bi actually ends up (under the hood) + # indexing the underlying MPIArray with the local index. - if self.save_samples: - out_cont.datasets["spectrum_samples"][:, bi] = spec + out_cont.spectrum[bi] = np.fft.fftshift(spec[-1]) + + if self.save_samples: + nsamp = len(spec) + out_cont.datasets["spectrum_samples"][:, bi] = 0.0 + out_cont.datasets["spectrum_samples"][-nsamp:, bi] = np.array(spec) + + else: + spec = delay_power_spectrum_gibbs( + data, + ndelay, + weight, + initial_S[lbi], + window=self.window if self.apply_window else None, + fsel=channel_ind[nzf], + niter=self.nsamp, + rng=rng, + complex_timedomain=self.complex_timedomain, + ) + + # Take an average over the last half of the delay spectrum samples + # (presuming that removes the burn-in) + spec_av = np.median(spec[-(self.nsamp // 2) :], axis=0) + out_cont.spectrum[bi] = np.fft.fftshift(spec_av) + + if self.save_samples: + out_cont.datasets["spectrum_samples"][:, bi] = spec + + if self.maxpost: + # Record number of converged baselines for debugging info. + n_conv = nbase - out_cont.datasets["spectrum_mask"][:].allgather().sum() + self.log.debug( + f"{n_conv}/{nbase} baselines converged in maximum-likelihood estimate of delay power spectrum." + ) return out_cont