Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Delay cross 1 #254

Merged
merged 7 commits into from
Dec 21, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
231 changes: 227 additions & 4 deletions draco/analysis/delay.py
Original file line number Diff line number Diff line change
Expand Up @@ -520,7 +520,7 @@ def _evaluate(self, data_view, weight_view, out_cont):

Returns
-------
out_cont : `contaiers.DelayTransform` or `containers.DelaySpectrum`
out_cont : `containers.DelayTransform` or `containers.DelaySpectrum`
Output delay spectrum or delay power spectrum.
"""
nbase = out_cont.spectrum.global_shape[0]
Expand Down Expand Up @@ -1024,6 +1024,34 @@ def fourier_matrix_c2c(N, fsel=None):
return F


def fourier_matrix(N: int, fsel: Optional[np.ndarray] = None) -> np.ndarray:
"""Generate a Fourier matrix to represent a real to complex FFT.

Parameters
----------
N : integer
Length of timestream that we are transforming to. Must be even.
fsel : array_like, optional
Indexes of the frequency channels to include in the transformation
matrix. By default, assume all channels.

Returns
-------
Fr : np.ndarray
An array performing the Fourier transform from a real time series to
frequencies packed as alternating real and imaginary elements,
"""
if fsel is None:
fa = np.arange(N)
else:
fa = np.array(fsel)

fa = fa[:, np.newaxis]
ta = np.arange(N)[np.newaxis, :]

return np.exp(-2.0j * np.pi * ta * fa / N)


def _complex_to_alternating_real(array):
"""View complex numbers as an array with alternating real and imaginary components.

Expand Down Expand Up @@ -1211,7 +1239,7 @@ def _draw_signal_sample_f(S):
# then doing a matrix solve
y = np.dot(FTNih, data + w2) + Si[:, np.newaxis] ** 0.5 * w1

return la.solve(Ci, y, sym_pos=True)
return la.solve(Ci, y, assume_a="pos")

def _draw_signal_sample_t(S):
# This method is fastest if the number of delays is larger than the number of
Expand Down Expand Up @@ -1240,7 +1268,7 @@ def _draw_signal_sample_t(S):
# Perform the solve step (rather than explicitly using the inverse)
y = data + w2 - np.dot(R, w1)
Ci = np.identity(2 * Ni.shape[0]) + np.dot(R, Rt)
x = la.solve(Ci, y, sym_pos=True)
x = la.solve(Ci, y, assume_a="pos")

return Sh[:, np.newaxis] * (np.dot(Rt, x) + w1)

Expand Down Expand Up @@ -1277,6 +1305,201 @@ def _draw_ps_sample(d):
return spec


def delay_spectrum_gibbs_cross(
data: np.ndarray,
N: int,
Ni: np.ndarray,
initial_S: np.ndarray,
window: str = "nuttall",
fsel: Optional[np.ndarray] = None,
niter: int = 20,
rng: Optional[np.random.Generator] = None,
) -> list[np.ndarray]:
"""Estimate the delay power spectrum by Gibbs sampling.

This routine estimates the spectrum at the `N` delay samples conjugate to
an input frequency spectrum with ``N/2 + 1`` channels (if the delay spectrum is
assumed real) or `N` channels (if the delay spectrum is assumed complex).
A subset of these channels can be specified using the `fsel` argument.

Parameters
----------
data
A 3D array of [dataset, sample, freq]. The delay cross-power spectrum of these
will be calculated.
N
The length of the output delay spectrum. There are assumed to be `N/2 + 1`
total frequency channels if assuming a real delay spectrum, or `N` channels
for a complex delay spectrum.
Ni
Inverse noise variance as a 3D [dataset, sample, freq] array.
initial_S
The initial delay cross-power spectrum guess. A 3D array of [data1, data2,
delay].
window : one of {'nuttall', 'blackman_nuttall', 'blackman_harris', None}, optional
Apply an apodisation function. Default: 'nuttall'.
fsel
Indices of channels that we have data at. By default assume all channels.
niter
Number of Gibbs samples to generate.
rng
A generator to use to produce the random samples.

Returns
-------
spec : list
List of cross-power spectrum samples.
"""
# Get reference to RNG

if rng is None:
rng = random.default_rng()

spec = []

nd, nsamp, Nf = data.shape

if fsel is None:
fsel = np.arange(Nf)
elif len(fsel) != Nf:
raise ValueError(
"Length of frequency selection must match frequencies passed. "
f"{len(fsel)} != {data.shape[-1]}"
)

# Construct the Fourier matrix
F = fourier_matrix(N, fsel)

if nd == 0:
raise ValueError("Need at least one set of data")

# We want the sample axis to be last
data = data.transpose(0, 2, 1)

# Window the frequency data
if window is not None:
# Construct the window function
x = fsel * 1.0 / N
w = tools.window_generalised(x, window=window)

# Apply to the projection matrix and the data
F *= w[:, np.newaxis]
data *= w[:, np.newaxis]

# Create the transpose of the Fourier matrix weighted by the noise
# (this is used multiple times)
# This is packed as a single freq -> delay projection per dataset
FTNih = F.T[np.newaxis, :, :] * Ni[:, np.newaxis, :] ** 0.5

# This should be an array for each dataset i of F_i^H N_i^{-1} F_i
FTNiF = np.zeros((nd, N, nd, N), dtype=np.complex128)
for ii in range(nd):
FTNiF[ii, :, ii] = FTNih[ii] @ FTNih[ii].T.conj()

# Pre-whiten the data to save doing it repeatedly
data *= Ni[:, :, np.newaxis] ** 0.5

# Set the initial guess for the delay power spectrum.
S_samp = initial_S

def _draw_signal_sample_f(S):
# Draw a random sample of the signal (delay spectrum) assuming a Gaussian model
# with a given delay power spectrum `S`. Do this using the perturbed Wiener
# filter approach

# This method is fastest if the number of frequencies is larger than the number
# of delays we are solving for. Typically this isn't true, so we probably want
# `_draw_signal_sample_t`

Si = np.empty_like(S)
Sh = np.empty((N, nd, nd), dtype=S.dtype)

for ii in range(N):
inv = la.inv(S[:, :, ii])
Si[:, :, ii] = inv
Sh[ii, :, :] = la.cholesky(S[:, :, ii], lower=False)

Ci = FTNiF.copy()
for ii in range(nd):
for jj in range(nd):
Ci[ii, :, jj] += np.diag(Si[ii, jj])

w1 = random.standard_complex_normal((N, nd, nsamp), rng=rng)
w2 = random.standard_complex_normal(data.shape, rng=rng)

# Construct the random signal sample by forming a perturbed vector and
# then doing a matrix solve
y = FTNih @ (data + w2)

for ii in range(N):
w1s = la.solve_triangular(
Sh[ii],
w1[ii],
overwrite_b=True,
lower=False,
check_finite=False,
)
y[:, ii] += w1s
# NOTE: Other combinations that you might think would work don't appear to
# be stable. Don't try these:
# y[:, ii] += Si[:, :, ii] @ Sh[:, :, ii] @ w1[:, ii]
# y[:, ii] += Shi[:, :, ii] @ w1[:, ii]

cf = la.cho_factor(
Ci.reshape(nd * N, nd * N),
overwrite_a=True,
check_finite=False,
)

return la.cho_solve(
cf,
y.reshape(nd * N, nsamp),
overwrite_b=True,
check_finite=False,
).reshape(nd, N, nsamp)

def _draw_signal_sample_t(S):
# This method is fastest if the number of delays is larger than the number of
# frequencies. This is usually the regime we are in.
raise NotImplementedError("Drawing samples in the time basis not yet written.")

def _draw_ps_sample(d):
# Draw a random delay power spectrum sample assuming the signal is Gaussian and
# we have a flat prior on the power spectrum.
# This means drawing from a inverse chi^2.

# Estimate the sample covariance
S = np.empty((nd, nd, N), dtype=np.complex128)
for ii in range(N):
S[:, :, ii] = np.cov(d[:, ii], bias=True)

# Then in place draw a sample of the true covariance from the posterior which
# is an inverse Wishart
for ii in range(N):
Si = la.inv(S[:, :, ii])
Si_samp = random.complex_wishart(Si, nsamp, rng=rng) / nsamp
S[:, :, ii] = la.inv(Si_samp)

return S

# Select the method to use for the signal sample based on how many frequencies
# versus delays there are. At the moment only the _f method is implemented.
_draw_signal_sample = _draw_signal_sample_f

# Perform the Gibbs sampling iteration for a given number of loops and
# return the power spectrum output of them.
try:
for ii in range(niter):
d_samp = _draw_signal_sample(S_samp)
S_samp = _draw_ps_sample(d_samp)

spec.append(S_samp)
except la.LinAlgError as e:
raise RuntimeError("Exiting earlier as singular") from e

return spec


# Alias delay_spectrum_gibbs to delay_power_spectrum_gibbs, for backwards compatibility
delay_spectrum_gibbs = delay_power_spectrum_gibbs

Expand Down Expand Up @@ -1341,7 +1564,7 @@ def delay_spectrum_wiener_filter(

# Solve the linear equation for the Wiener-filtered spectrum, and transpose to
# [average_axis, delay]
y_spec = la.solve(Ci, y, sym_pos=True).T
y_spec = la.solve(Ci, y, assume_a="pos").T

if complex_timedomain:
y_spec = _alternating_real_to_complex(y_spec)
Expand Down
21 changes: 21 additions & 0 deletions draco/core/containers.py
Original file line number Diff line number Diff line change
Expand Up @@ -2432,6 +2432,27 @@ def weight(self):
return self.datasets["weight"]


class DelayCrossSpectrum(DelaySpectrum):
"""Container for a delay cross power spectra."""

_axes = ("dataset",)

_dataset_spec = {
"spectrum": {
"axes": ["dataset", "dataset", "baseline", "delay"],
"dtype": np.float64,
"initialise": True,
"distributed": True,
"distributed_axis": "baseline",
}
}

@property
def spectrum(self):
"""Get the spectrum dataset."""
return self.datasets["spectrum"]


class Powerspectrum2D(ContainerBase):
"""Container for a 2D cartesian power spectrum.

Expand Down
12 changes: 12 additions & 0 deletions draco/core/misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -298,3 +298,15 @@ def setup(self, input_):
def next(self, input_):
"""Immediately forward any input."""
return input_


class PassOn(task.MPILoggedTask):
"""Unconditionally forward a tasks input.

While this seems like a pointless no-op it's useful for connecting tasks in complex
topologies.
"""

def next(self, input_):
"""Immediately forward any input."""
return input_
Loading
Loading