def fourier_matrix(N: int, fsel: Optional[np.ndarray] = None) -> np.ndarray:
    """Generate a Fourier matrix representing a complex to complex DFT.

    Element ``[i, t]`` of the returned matrix is
    ``exp(-2j * pi * t * f_i / N)``, where ``f_i`` is the i-th selected
    frequency channel index and ``t`` runs over the `N` samples of the
    conjugate axis.

    Parameters
    ----------
    N : integer
        Length of the axis that we are transforming to.
    fsel : array_like, optional
        Indexes of the frequency channels to include in the transformation
        matrix. By default, assume all `N` channels.

    Returns
    -------
    F : np.ndarray
        Complex array of shape ``(len(fsel), N)`` (``(N, N)`` if `fsel` is not
        given). Values are ordinary complex numbers; they are *not* packed as
        alternating real and imaginary elements.
    """
    # Selected channel indices as a column, sample indices as a row, so that the
    # product broadcasts to the full (channel, sample) phase matrix.
    fa = np.arange(N) if fsel is None else np.asarray(fsel)
    ta = np.arange(N)

    return np.exp(-2.0j * np.pi * ta[np.newaxis, :] * fa[:, np.newaxis] / N)
def delay_spectrum_gibbs_cross(
    data: np.ndarray,
    N: int,
    Ni: np.ndarray,
    initial_S: np.ndarray,
    window: str = "nuttall",
    fsel: Optional[np.ndarray] = None,
    niter: int = 20,
    rng: Optional[np.random.Generator] = None,
) -> list[np.ndarray]:
    """Estimate the delay cross-power spectrum by Gibbs sampling.

    This routine estimates the cross-power spectrum between several datasets at
    the `N` delay samples conjugate to the frequency axis. The frequency
    channels actually present in `data` are identified by `fsel`.

    The input arrays are not modified.

    Parameters
    ----------
    data
        A 3D array of [dataset, sample, freq]. The delay cross-power spectrum of
        these will be calculated.
    N
        The length of the output delay spectrum.
    Ni
        Inverse noise variance as a 2D [dataset, freq] array (it is broadcast
        over the sample axis).
    initial_S
        The initial delay cross-power spectrum guess. A 3D array of [data1,
        data2, delay].
    window : one of {'nuttall', 'blackman_nuttall', 'blackman_harris', None}, optional
        Apply an apodisation function. Default: 'nuttall'.
    fsel
        Indices of channels that we have data at. By default assume all
        channels.
    niter
        Number of Gibbs samples to generate.
    rng
        A generator to use to produce the random samples.

    Returns
    -------
    spec : list
        List of cross-power spectrum samples, one [data1, data2, delay] array
        per Gibbs iteration.

    Raises
    ------
    ValueError
        If there are no datasets, or `fsel` does not match the frequency axis.
    RuntimeError
        If a linear algebra step fails (e.g. a singular matrix) during the
        sampling.
    """
    # Get reference to RNG
    if rng is None:
        rng = random.default_rng()

    nd, nsamp, Nf = data.shape

    if nd == 0:
        raise ValueError("Need at least one set of data")

    if fsel is None:
        fsel = np.arange(Nf)
    else:
        # Accept any sequence; the window calculation below needs array maths
        fsel = np.asarray(fsel)
        if len(fsel) != Nf:
            raise ValueError(
                "Length of frequency selection must match frequencies passed. "
                f"{len(fsel)} != {data.shape[-1]}"
            )

    # Construct the Fourier matrix
    F = fourier_matrix(N, fsel)

    # We want the sample axis to be last. Take a copy, as the transpose is only a
    # view and the in-place windowing/whitening below would otherwise overwrite
    # the caller's array.
    data = data.transpose(0, 2, 1).copy()

    # Window the frequency data
    if window is not None:
        # Construct the window function
        x = fsel * 1.0 / N
        w = tools.window_generalised(x, window=window)

        # Apply to the projection matrix and the data
        F *= w[:, np.newaxis]
        data *= w[:, np.newaxis]

    # Create the transpose of the Fourier matrix weighted by the noise
    # (this is used multiple times)
    # This is packed as a single freq -> delay projection per dataset
    FTNih = F.T[np.newaxis, :, :] * Ni[:, np.newaxis, :] ** 0.5

    # This should be an array for each dataset i of F_i^H N_i^{-1} F_i
    FTNiF = np.zeros((nd, N, nd, N), dtype=np.complex128)
    for ii in range(nd):
        FTNiF[ii, :, ii] = FTNih[ii] @ FTNih[ii].T.conj()

    # Pre-whiten the data to save doing it repeatedly
    data *= Ni[:, :, np.newaxis] ** 0.5

    # Set the initial guess for the delay power spectrum.
    S_samp = initial_S

    def _draw_signal_sample_f(S):
        # Draw a random sample of the signal (delay spectrum) assuming a Gaussian
        # model with a given delay power spectrum `S`. Do this using the perturbed
        # Wiener filter approach

        # This method is fastest if the number of frequencies is larger than the
        # number of delays we are solving for. Typically this isn't true, so we
        # probably want `_draw_signal_sample_t`

        Si = np.empty_like(S)
        Sh = np.empty((N, nd, nd), dtype=S.dtype)

        # Per-delay inverse and (upper) Cholesky factor of the signal covariance
        for ii in range(N):
            Si[:, :, ii] = la.inv(S[:, :, ii])
            Sh[ii, :, :] = la.cholesky(S[:, :, ii], lower=False)

        # Full inverse covariance: the noise term plus S^-1, which is diagonal in
        # delay for every pair of datasets
        Ci = FTNiF.copy()
        for ii in range(nd):
            for jj in range(nd):
                Ci[ii, :, jj] += np.diag(Si[ii, jj])

        w1 = random.standard_complex_normal((N, nd, nsamp), rng=rng)
        w2 = random.standard_complex_normal(data.shape, rng=rng)

        # Construct the random signal sample by forming a perturbed vector and
        # then doing a matrix solve
        y = FTNih @ (data + w2)

        for ii in range(N):
            w1s = la.solve_triangular(
                Sh[ii],
                w1[ii],
                overwrite_b=True,
                lower=False,
                check_finite=False,
            )
            y[:, ii] += w1s
            # NOTE: Other combinations that you might think would work don't
            # appear to be stable. Don't try these:
            # y[:, ii] += Si[:, :, ii] @ Sh[:, :, ii] @ w1[:, ii]
            # y[:, ii] += Shi[:, :, ii] @ w1[:, ii]

        cf = la.cho_factor(
            Ci.reshape(nd * N, nd * N),
            overwrite_a=True,
            check_finite=False,
        )

        return la.cho_solve(
            cf,
            y.reshape(nd * N, nsamp),
            overwrite_b=True,
            check_finite=False,
        ).reshape(nd, N, nsamp)

    def _draw_signal_sample_t(S):
        # This method is fastest if the number of delays is larger than the number
        # of frequencies. This is usually the regime we are in.
        raise NotImplementedError("Drawing samples in the time basis not yet written.")

    def _draw_ps_sample(d):
        # Draw a random delay power spectrum sample assuming the signal is Gaussian
        # and we have a flat prior on the power spectrum.
        # This means drawing from a inverse chi^2.

        # Estimate the sample covariance
        S = np.empty((nd, nd, N), dtype=np.complex128)
        for ii in range(N):
            S[:, :, ii] = np.cov(d[:, ii], bias=True)

        # Then in place draw a sample of the true covariance from the posterior
        # which is an inverse Wishart
        for ii in range(N):
            Si = la.inv(S[:, :, ii])
            Si_samp = random.complex_wishart(Si, nsamp, rng=rng) / nsamp
            S[:, :, ii] = la.inv(Si_samp)

        return S

    # Select the method to use for the signal sample based on how many frequencies
    # versus delays there are. At the moment only the _f method is implemented.
    _draw_signal_sample = _draw_signal_sample_f

    # Perform the Gibbs sampling iteration for a given number of loops and
    # return the power spectrum output of them.
    spec = []
    try:
        for _ in range(niter):
            d_samp = _draw_signal_sample(S_samp)
            S_samp = _draw_ps_sample(d_samp)

            spec.append(S_samp)
    except la.LinAlgError as e:
        raise RuntimeError("Exiting earlier as singular") from e

    return spec
class MultithreadedRNG(np.random.Generator):
    """A multithreaded random number generator.

    This wraps specific methods to allow generation across multiple threads. See
    `PARALLEL_METHODS` for the specific methods wrapped. All other methods fall
    through to the single threaded `np.random.Generator` implementation.

    Parameters
    ----------
    seed
        The seed to use.
    threads
        The number of threads to use. If not set, this tries to get the number
        from the `OMP_NUM_THREADS` environment variable, or just uses 4 if that is
        also not set.
    bitgen
        The BitGenerator to use, if not set this uses `_default_bitgen`.
    """

    # Requests for fewer elements than this are generated in the calling thread
    _parallel_threshold = 1000

    # The methods to generate parallel versions for. This table is:
    # method name, number of initial parameter arguments, default data type, if
    # there is a dtype argument, and if there is an out argument. See
    # `_build_method` for details.
    PARALLEL_METHODS = {
        "random": (0, np.float64, True, True),
        "integers": (2, np.int64, True, False),
        "uniform": (2, np.float64, False, False),
        "normal": (2, np.float64, False, False),
        "standard_normal": (0, np.float64, True, True),
        # Poisson variates are integers; this must match what the single
        # threaded method returns below the parallel threshold.
        "poisson": (1, np.int64, False, False),
        "power": (1, np.float64, False, False),
    }

    def __init__(
        self,
        seed: Optional[int] = None,
        threads: Optional[int] = None,
        bitgen: Optional[np.random.BitGenerator] = None,
    ):
        if bitgen is None:
            bitgen = _default_bitgen

        # Initialise this object with the given seed. This allows methods that
        # don't have multithreaded support to work
        super().__init__(bitgen(seed))

        if threads is None:
            threads = int(os.environ.get("OMP_NUM_THREADS", 4))

        # Initialise the parallel thread pool, with one independent generator per
        # thread spawned from the same seed
        self._threads = threads
        self._random_generators = [
            np.random.Generator(bitgen(seed=s))
            for s in np.random.SeedSequence(seed).spawn(threads)
        ]
        self._executor = concurrent.futures.ThreadPoolExecutor(threads)

        # Create the methods and attach them to this instance.
        for method, spec in self.PARALLEL_METHODS.items():
            setattr(self, method, self._build_method(method, *spec))

    def _build_method(
        self,
        name: str,
        nparam: int,
        defdtype: np.dtype,
        has_dtype: bool,
        has_out: bool,
    ) -> Callable:
        """Build a method for generating random numbers from a given distribution.

        As the underlying methods are in Cython they can't be adequately
        introspected and so we need to provide information about the signature.

        Parameters
        ----------
        name
            The name of the generation method in `np.random.Generator`.
        nparam
            The number of distribution parameters that come before the `size`
            argument.
        defdtype
            The default datatype used if none is explicitly supplied.
        has_dtype
            Does the underlying method have a dtype argument?
        has_out
            Does the underlying method have an `out` parameter for directly
            filling an array.

        Returns
        -------
        parallel_method
            A method for generating in parallel.
        """
        method = getattr(np.random.Generator, name)

        def _call(*args, **kwargs):
            # Keep pristine copies for the single threaded fallback
            orig_args = list(args)
            orig_kwargs = dict(kwargs)

            # Try and get the size
            if len(args) > nparam:
                size = args[nparam]
            elif "size" in kwargs:
                size = kwargs.pop("size")
            else:
                size = None

            # Try and get an out argument. An explicit `out=None` is treated the
            # same as not passing it at all.
            out = kwargs.pop("out", None) if has_out else None
            if out is not None:
                size = out.shape

            # Try to figure out the dtype so we can pre-allocate the array for
            # filling
            if has_dtype and len(args) > nparam + 1:
                dtype = args[nparam + 1]
            elif has_dtype and "dtype" in kwargs:
                dtype = kwargs.pop("dtype")
            else:
                dtype = defdtype

            # Trim any excess positional arguments
            params = args[:nparam]

            # Check that all the parameters are scalar, including any remaining
            # kwargs (which are assumed to be distribution parameters)
            all_scalar = all(np.isscalar(arg) for arg in params)
            all_scalar &= all(np.isscalar(arg) for arg in kwargs.values())

            # If neither size nor out is set we can't parallelise this so just
            # call directly.
            # Additionally if the distribution arguments are not scalars there may
            # be some complex broadcasting required, so we also drop out if that
            # is true.
            if size is None or not all_scalar:
                return method(self, *orig_args, **orig_kwargs)

            flatsize = np.prod(size)

            # If the total size is too small, then just call directly
            if flatsize < self._parallel_threshold:
                return method(self, *orig_args, **orig_kwargs)

            if out is None:
                # Create the output array if required
                out = np.empty(size, dtype)
            elif not out.flags.c_contiguous:
                # Flattening a non C-contiguous array produces a copy, so the
                # workers would never write into the caller's array. Let the
                # single threaded method deal with it instead.
                return method(self, *orig_args, **orig_kwargs)

            # A flat view onto the output that the workers fill in slices
            flat = out.ravel()

            # Figure out how to split up the array
            step = int(np.ceil(flatsize / self._threads))

            # The keyword arguments passed to every worker call. The dtype must be
            # forwarded explicitly as it was stripped out above; without it the
            # underlying method assumes its default and rejects (or casts into) an
            # output array of any other type.
            call_kwargs = dict(kwargs)
            if has_dtype:
                call_kwargs["dtype"] = dtype

            # A worker method for each thread to fill its part of the array with
            # the random numbers
            def _fill(gen: np.random.Generator, local_array: np.ndarray) -> None:
                if has_out:
                    method(gen, *params, out=local_array, **call_kwargs)
                else:
                    local_array[:] = method(
                        gen,
                        *params,
                        size=len(local_array),
                        **call_kwargs,
                    )

            # Generate the numbers with each worker thread. Each slice is always
            # paired with the same generator so results are reproducible
            # regardless of thread scheduling. Empty trailing slices are skipped.
            futures = []
            for i in range(self._threads):
                chunk = flat[(i * step) : ((i + 1) * step)]
                if chunk.size == 0:
                    continue
                futures.append(
                    (i, self._executor.submit(_fill, self._random_generators[i], chunk))
                )
            concurrent.futures.wait([future for _, future in futures])

            for ii, future in futures:
                if (e := future.exception()) is not None:
                    raise RuntimeError(
                        f"An exception occurred in thread {ii} (and maybe others)."
                    ) from e

            return out

        # Copy over the docstring for the method
        _call.__doc__ = "Multithreaded version.\n" + (method.__doc__ or "")

        return _call

    def __del__(self):
        # The executor doesn't exist if `__init__` failed part way through
        executor = getattr(self, "_executor", None)
        if executor is not None:
            executor.shutdown(wait=False)
If not set, a random seed is generated and broadcast to all ranks. The seed being used is logged, to repeat a previous run, simply set this as the seed parameter. + threads : int, optional + Set the number of threads to use for the random number generator. If not + explicitly set this will use the value of the `OMP_NUM_THREADS` environment + variable, or fall back to four. """ seed = config.Property(proptype=int, default=None) + threads = config.Property(proptype=int, default=None) _rng = None @property - def rng(self): + def rng(self) -> np.random.Generator: """A random number generator for this task. .. warning:: @@ -291,7 +493,21 @@ def rng(self): MPI jobs. """ if self._rng is None: - # Generate a new base seed for all MPI ranks + self._rng = MultithreadedRNG(self.local_seed, threads=self.threads) + + return self._rng + + _local_seed = None + + @property + def local_seed(self) -> int: + """Get the seed to be used on this rank. + + .. warning:: + Generating the seed is a collective operation if the seed is not set, + and so all ranks must participate in the first access of this property. 
def mock_freq_data(
    freq: np.ndarray,
    ntime: int,
    delaycut: float,
    ndata: Optional[int] = None,
    noise: float = 0.0,
    bad_freq: Optional[np.ndarray] = None,
    rng: Optional[np.random.Generator] = None,
) -> Tuple[np.ndarray, np.ndarray]:
    """Make mock delay data with a constant delay spectrum up to a specified cut.

    Parameters
    ----------
    freq
        Frequencies of each channel (in MHz). At least two channels are needed
        to determine the channel spacing.
    ntime
        Number of independent time samples.
    delaycut
        Cutoff in us.
    ndata
        Number of correlated data sets. If not set (i.e. `None`) then do not add
        a dataset axis.
    noise
        RMS noise level in the data. If this is zero the returned weights are
        infinite for all channels that are not masked.
    bad_freq
        A list of bad frequencies (channel indices) to mask out.
    rng
        The random number generator to use.

    Returns
    -------
    data
        The 2D/3D data array [dataset, freq, time]. If ndata is `None` then the
        dataset axis is dropped.
    weights
        The inverse variance weights, with the same shape as `data`.
    """
    nfreq = len(freq)
    ndelay = nfreq

    df = np.abs(freq[1] - freq[0])

    delays = np.fft.fftfreq(ndelay, df)
    dspec = np.where(np.abs(delays) < delaycut, 1.0, 0.0)

    # Construct a set of delay spectra
    delay_spectra = random.complex_normal(size=(ntime, ndelay), rng=rng)
    delay_spectra *= dspec**0.5

    # Generate the noise realisation
    outshape = (nfreq, ntime)
    if ndata is not None:
        outshape = (ndata, *outshape)
    data = noise * random.complex_normal(size=outshape, rng=rng)

    # Transform to get frequency spectra. The same signal is added to every
    # dataset, only the noise realisations differ.
    data += np.fft.fft(delay_spectra, axis=-1).T

    # Inverse variance weights. Do the division with numpy scalars so that
    # noise-free data gives infinite weights rather than a ZeroDivisionError.
    weights = np.empty(data.shape, dtype=np.float64)
    with np.errstate(divide="ignore"):
        weights[:] = 1.0 / np.float64(noise) ** 2

    # NOTE: test against None and the size explicitly, the truth value of an
    # index array is either ambiguous (raises) or depends on the index values
    if bad_freq is not None and np.size(bad_freq) > 0:
        data[..., bad_freq, :] = 0.0
        weights[..., bad_freq, :] = 0.0

    return data, weights


class RandomFreqData(random.RandomTask):
    """Generate a random sidereal stream with structure in delay.

    Attributes
    ----------
    num_realisation
        How many to generate in subsequent process calls.
    num_correlated
        The number of correlated realisations output per cycle.
    num_ra
        The number of RA samples in the output.
    num_base
        The number of baselines in the output.
    freq_start, freq_end
        The start and end frequencies.
    num_freq
        The number of frequency channels.
    delay_cut
        The maximum delay in the data in us.
    noise
        The RMS noise level.
    """

    num_realisation = config.Property(proptype=int, default=1)
    num_correlated = config.Property(proptype=int, default=None)

    num_ra = config.Property(proptype=int)
    num_base = config.Property(proptype=int)

    freq_start = config.Property(proptype=float, default=800.0)
    freq_end = config.Property(proptype=float, default=400.0)
    num_freq = config.Property(proptype=int, default=1024)

    delay_cut = config.Property(proptype=float, default=0.2)
    noise = config.Property(proptype=float, default=1e-5)

    def next(self) -> Union[SiderealStream, List[SiderealStream]]:
        """Generate correlated sidereal streams.

        Returns
        -------
        streams
            Either a single stream (if num_correlated=None), or a list of
            correlated streams.
        """
        if self.num_realisation == 0:
            raise pipeline.PipelineStopIteration()

        ncorr = self.num_correlated or 1

        # Fetch the task RNG up front. Its first access is documented as a
        # collective operation, so it must happen on every rank, including ones
        # that end up with no local baselines to fill.
        rng = self.rng

        # Construct the frequency axis
        freq = np.linspace(
            self.freq_start,
            self.freq_end,
            self.num_freq,
            endpoint=False,
        )

        streams = []

        # Construct all the sidereal streams
        for _ in range(ncorr):
            stream = SiderealStream(
                input=5,  # Probably should be something smarter
                freq=freq,
                ra=self.num_ra,
                stack=self.num_base,
            )
            stream.redistribute("stack")
            ssv = stream.vis[:].local_array
            ssw = stream.weight[:].local_array

            streams.append((stream, ssv, ssw))

        # All streams share the same layout, so take the number of local
        # baselines from the first one
        nbase_local = streams[0][1].shape[1]

        # Iterate over baselines and construct correlated realisations for each,
        # and then insert them into each of the sidereal streams
        for ii in range(nbase_local):
            d, w = mock_freq_data(
                freq,
                self.num_ra,
                self.delay_cut,
                ndata=ncorr,
                noise=self.noise,
                rng=rng,
            )

            for jj, (_, vis, weight) in enumerate(streams):
                vis[:, ii] = d[jj]
                weight[:, ii] = w[jj]

        self.num_realisation -= 1

        # Don't return a list of streams if num_correlated is None
        if self.num_correlated is None:
            return streams[0][0]

        return [s for s, *_ in streams]