diff --git a/draco/analysis/delay.py b/draco/analysis/delay.py index 7946aa925..669257454 100644 --- a/draco/analysis/delay.py +++ b/draco/analysis/delay.py @@ -607,11 +607,18 @@ class DelayGibbsSamplerBase(DelayTransformBase, random.RandomTask): initial_amplitude : float, optional The Gibbs sampler will be initialized with a flat power spectrum with this amplitude. Default: 10. + save_samples : bool, optional + The entire chain of samples will be saved rather than just the final + result. Default: False + initial_sample_path : str, optional + File path to load an initial power spectrum sample. If no file is given, + start with a flat power spectrum. Default: None """ nsamp = config.Property(proptype=int, default=20) initial_amplitude = config.Property(proptype=float, default=10.0) save_samples = config.Property(proptype=bool, default=False) + initial_sample_path = config.Property(proptype=str, default=None) def _create_output( self, @@ -662,6 +669,32 @@ def _create_output( return delay_spec + def _get_initial_S(self, nbase, ndelay, dtype): + """Load the initial spectrum estimate if a file path exists. + + Parameters + ---------- + nbase : int + Number of baselines + ndelay : int + Number of delay samples + dtype : type | np.dtype | str + Datatype for the sample if no path is given + """ + if self.initial_sample_path is not None: + cont = ContainerBase.from_file(self.initial_sample_path, distributed=True) + + cont.redistribute("baseline") + + # Extract the spectrum and move the baseline axis to the front + initial_S = cont.spectrum[:].local_array + bl_ax = cont.spectrum.attrs["axis"].tolist().index("baseline") + initial_S = np.moveaxis(initial_S, bl_ax, 0) + else: + initial_S = np.ones((nbase, ndelay), dtype=dtype) * self.initial_amplitude + + return initial_S + def _evaluate(self, data_view, weight_view, out_cont, delays, channel_ind): """Estimate the delay spectrum or power spectrum. 
@@ -687,7 +720,7 @@ def _evaluate(self, data_view, weight_view, out_cont, delays, channel_ind): ndelay = len(delays) # Set initial conditions for delay power spectrum - initial_S = np.ones_like(delays) * self.initial_amplitude + initial_S = self._get_initial_S(nbase, ndelay, delays.dtype) # Initialize the random number generator we'll use rng = self.rng @@ -710,7 +743,7 @@ def _evaluate(self, data_view, weight_view, out_cont, delays, channel_ind): data, ndelay, weight, - initial_S, + initial_S[lbi], window=self.window if self.apply_window else None, fsel=non_zero_channel, niter=self.nsamp, @@ -1053,16 +1086,24 @@ def _create_output( def _evaluate(self, data_view, weight_view, out_cont, delays, channel_ind): ndata = len(data_view) ndelay = len(delays) + nbase = out_cont.spectrum.shape[-2] - initial_S = ( - np.identity(ndata)[:, :, np.newaxis] - * np.ones_like(delays) - * self.initial_amplitude - ) + initial_S = self._get_initial_S(nbase, ndelay, delays.dtype) + + if initial_S.ndim == 2: + # Expand the sample shape to match the number of datasets + initial_S = ( + np.identity(ndata)[np.newaxis, ..., np.newaxis] + * initial_S[:, np.newaxis, np.newaxis] + ) + elif (initial_S.ndim != 4) or (initial_S.shape[1] != ndata): + raise ValueError( + f"Expected an initial sample with dimension 4 and {ndata} datasets. " + f"Got sample with dimension {initial_S.ndim} and shape {initial_S.shape}." + ) # Initialize the random number generator we'll use rng = self.rng - nbase = out_cont.spectrum.shape[-2] # Iterate over all baselines and use the Gibbs sampler to estimate the spectrum for lbi, bi in out_cont.spectrum[:].enumerate(axis=-2): @@ -1082,7 +1123,7 @@ def _evaluate(self, data_view, weight_view, out_cont, delays, channel_ind): data, ndelay, weight, - initial_S, + initial_S[lbi], window=self.window if self.apply_window else None, fsel=non_zero_channel, niter=self.nsamp,