feat(sidereal): tasks to split regridder projection and deconvolution
ljgray committed Nov 13, 2024
1 parent 3e0ee1b commit 0139ce9
Showing 1 changed file with 350 additions and 2 deletions: draco/analysis/sidereal.py
@@ -13,7 +13,7 @@

import numpy as np
import scipy.linalg as la
from caput import config, mpiarray, tod
from caput import config, mpiarray, tod, weighted_median
from cora.util import units

from ..core import containers, io, task
@@ -257,6 +257,205 @@ def _get_phase(self, freq, prod, lsd):
)


class SiderealDirtyRegridder(SiderealRegridder):
"""Take a factorized sidereal day and put it on a regular grid.
Output container includes a dirty estimate of visibilities on a regular
grid and an estimate of the inverse convolution matrix Ci based on a limited
number of modes over baselines.
"""

def process(
self, data: containers.FactorizedTimeStream
) -> containers.SiderealDirtyStream:
"""Make a dirty projection of the sidereal day.
Parameters
----------
data
Timestream data with factorized weights. Weights can be factorized
using `draco.analysis.transform.FactorizeWeights`.
Returns
-------
sdata
Sidereal stream with dirty `vis` projection and factorized
inverse signal covariance matrix
"""
self.log.info(f"Making dirty grid LSD:{data.attrs['lsd']}")

# Redistribute if needed
data.redistribute("freq")

# Convert data timestamps into LSD deltas (relative to the LSD of this day)
timestamp_lsd = self.observer.unix_to_lsd(data.time) - data.attrs["lsd"]

# Get view of data
Ni = data.weight[:].local_array
modes = data.modes[:].local_array
vis_data = data.vis[:].local_array

pad = 5 * self.lanczos_width

xh, Ci, mp = self._regrid(vis_data, Ni, modes, timestamp_lsd, pad)

# This will be padded so we have to extend the RA axis accordingly
new_samples = self.samples + 2 * pad
ra_delta = ((new_samples / self.samples) * 360 - 360) / 2
ra = np.linspace(-ra_delta, 360 + ra_delta, new_samples, endpoint=False)
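        # Sanity check on the padding arithmetic (illustrative numbers only):
        # with samples = 4096 and lanczos_width = 5, pad = 25, so
        # new_samples = 4146 and ra_delta = ((4146 / 4096) * 360 - 360) / 2
        # ~= 2.197 deg. The padded axis spans [-2.197, 362.197) with
        # 364.395 / 4146 == 360 / 4096 ~= 0.0879 deg per sample, i.e. the
        # original grid spacing is preserved.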

# Make the new container
sdata = containers.SiderealDirtyStream(
axes_from=data,
ra=ra,
bandwidth=2 * self.lanczos_width,
)
sdata.add_dataset("mask")
sdata.redistribute("freq")

sdata.vis[:].local_array[:] = xh
sdata.noise_cov[:].local_array[:] = Ci
sdata.modes[:].local_array[:] = modes
sdata.mask[:].local_array[:] = mp

sdata.attrs["lsd"] = data.attrs["lsd"]
sdata.attrs["tag"] = f"lsd_{data.attrs['lsd']}"
# Store this so it can be removed later on
sdata.attrs["pad"] = pad

return sdata

def _regrid(self, vis_data, weight, modes, times, pad):
"""Project the visibility data onto a regular grid in RA.
Returns
-------
xh
Dirty weighted visibility projection
Ci
Covariance matrix
nw
Weight projection
mp
Mask projection
"""
        # Create a regular grid, padded at either end to suppress interpolation
        # issues
interp_grid = (
np.arange(-pad, self.samples + pad, dtype=np.float64) / self.samples
)

# Construct regridding matrix for reverse problem
lzf = regrid.lanczos_forward_matrix(
interp_grid, times, self.lanczos_width
).T.copy()

# Make the projected freq-time mask
mp = (weight > 0) @ abs(lzf).T
lzf_thresh = np.mean(np.sum(np.ma.masked_where(lzf == 0, abs(lzf)), axis=1))
mp = mp < self.mask_thresh * lzf_thresh
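        # In other words (a reading of the logic above, not new behaviour): an
        # RA sample is masked when the total kernel weight it receives from
        # unflagged time samples falls below `mask_thresh` times the mean
        # kernel weight, i.e. it would be interpolated mostly from flagged data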
        # Make the banded noise covariance matrix before reshaping, since the
        # stack axis will be reduced
Ci = regrid.wiener_noise_covariance(lzf, weight, 2 * self.lanczos_width - 1)
# Store the final shape of the data and flatten across
# frequency and baselines
shape_ = (*vis_data.shape[:-1], interp_grid.shape[0])
# Reconstruct and flatten the weights
weight = (modes[:, :, np.newaxis] @ weight[:, np.newaxis]).reshape(
-1, weight.shape[-1]
)
# Make the dirty projection into signal space
vis_data = vis_data.reshape(-1, vis_data.shape[-1])
xh = regrid.wiener_projection(lzf, vis_data, weight).reshape(shape_)

return xh, Ci, mp
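
# Illustrative sketch (not used by the tasks above): the Lanczos kernel that
# `regrid.lanczos_forward_matrix` is conventionally built from, assuming the
# standard definition; draco's implementation may differ in detail.
def _lanczos_kernel_sketch(x, a):
    """Standard Lanczos kernel: sinc(x) * sinc(x / a) within |x| < a."""
    return np.where(np.abs(x) < a, np.sinc(x) * np.sinc(x / a), 0.0)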


class SiderealGridDeconvolve(SiderealRegridder):
"""Deconvolve a single dirty sidereal day."""

def process(self, data):
"""Deconvolve a dirty sidereal day.
Parameters
----------
data : containers.SiderealDirtyStream
Dirty sidereal data to deconvolve
Returns
-------
sdata
Deconvolved sidereal day with padding removed.
"""
self.log.info(f"Deconvolving dirty grid LSD:{data.attrs['lsd']}")

# Redistribute if needed
data.redistribute("freq")

Ci = data.noise_cov[:].local_array
xh = data.vis[:].local_array
modes = data.modes[:].local_array

pad = data.attrs["pad"]

# Deconvolve the visibilities
xh, nr, samples = self._deconvolve(xh, Ci, modes, pad)

sdata = containers.SiderealStream(axes_from=data, ra=samples)
sdata.redistribute("freq")

        # Save out the deconvolved visibilities and weights.
        # If a `mask` dataset exists, apply it to the deconvolved data
if "mask" in data.datasets:
mask = ~data.mask[:].local_array[:, np.newaxis][..., pad:-pad]
else:
mask = np.ones_like(nr, dtype=bool)

sdata.vis[:] = xh * mask
sdata.weight[:] = nr * mask

sdata.attrs["lsd"] = data.attrs["lsd"]
sdata.attrs["tag"] = f"lsd_{data.attrs['lsd']}"

return sdata

def _deconvolve(self, xh, Ci, modes, pad):
"""Deconvolve the dirty visibilities."""
nbaseline = xh.shape[1]
# Get number of RA samples and target shape without padding
samples = xh.shape[-1] - 2 * pad
shape_ = xh.shape[:-1] + (samples,)

# Flatten over frequencies and baselines
xh = xh.reshape(-1, xh.shape[-1])
# Broadcast si to the correct shape
self.si = np.broadcast_to(np.atleast_2d(self.si), xh.shape)
nw = np.zeros_like(xh, dtype=np.float32)
# Massage the modes dataset
modes = np.atleast_3d(modes)
modes = modes.reshape(-1, modes.shape[-1])

# If the last axis of `modes` is length one, use simple
# fast elementwise multiplication
_mult = np.multiply if modes.shape[-1] == 1 else np.matmul

# Iterate over frequency-baseline pairs
for ki in range(xh.shape[0]):
# Get the reconstruction of Ci for this frequency-baseline
Ci_ki = _mult(Ci[ki // nbaseline], modes[ki])
            # Extract the weights before adding the signal contribution
nw[ki] = Ci_ki[-1]
# Add the signal covariance and solve
Ci_ki[-1] += self.si[ki]

xh[ki] = la.solveh_banded(Ci_ki, xh[ki])

# Remove padding and reshape
xh = xh[:, pad:-pad].reshape(shape_)
nw = nw[:, pad:-pad].reshape(shape_)

return xh, nw, samples
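
# Illustrative sketch of the banded solve used in `_deconvolve` above.
# `scipy.linalg.solveh_banded` takes a Hermitian banded matrix in upper form,
# in which row -1 holds the main diagonal; this is why the signal term is
# added to `Ci_ki[-1]`. A toy demonstration of the storage convention, not
# part of the pipeline:
def _banded_solve_sketch():
    n = 8
    ab = np.zeros((2, n))
    ab[-1] = 2.0  # main diagonal
    ab[-2, 1:] = -0.5  # first superdiagonal
    b = np.ones(n)
    # Solves A @ x = b for the symmetric tridiagonal A encoded in `ab`
    return la.solveh_banded(ab, b)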


def _search_nearest(x, xeval):
index_next = np.searchsorted(x, xeval, side="left")

@@ -946,7 +1145,9 @@ def process(self, sdata):
if self.weight == "uniform":
coeff = count.astype(np.float32)
self.stack.weight[:] += (coeff**2) * tools.invert_no_zero(weight)
sum_coeff = self.stack.nsample[:]
# Wrap as MPIArray for consistent behaviour since other datasets
# are distributed
sum_coeff = mpiarray.MPIArray.wrap(self.stack.nsample[:], axis=0)

else:
coeff = weight
@@ -1241,6 +1442,153 @@ def get_slice_to_broadcast(weight_axis, dataset_axis):
return tuple([slice(None) if ax in weight_axis else None for ax in dataset_axis])
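
# For example (hypothetical axis names): with weight_axis = ("freq", "ra") and
# dataset_axis = ("freq", "stack", "ra"), `get_slice_to_broadcast` returns
# (slice(None), None, slice(None)), inserting a new axis so that a (freq, ra)
# weight broadcasts against a (freq, stack, ra) dataset.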


class SiderealStackerDeconvolve(SiderealGridDeconvolve):
"""Stack up a set of dirty sidereal days and deconvolve the final product.
Attributes
----------
tag : str (default: "stack")
The tag to give the stack.
subtract_median : bool
Subtract the median of the final visibilities in RA.
min_sample : int
Minimum number of input samples in the stack to keep the
final sample. Any sample with fewer inputs than this is
flagged. Default is 1.
"""

tag = config.Property(proptype=str, default="stack")
subtract_median = config.Property(proptype=bool, default=True)
min_sample = config.Property(proptype=int, default=1)

stack = None

def process(self, data: containers.SiderealDirtyStream):
"""Stack up the dirty sidereal days and noise matrices.
Parameters
----------
data
Individual sidereal day to add to stack.
"""
data.redistribute("freq")

xh = data.vis[:].local_array
Ci = data.noise_cov[:].local_array
modes = data.modes[:].local_array

if self.stack is None:
# Accumulate stuff here. Assume that all stacked days have
# the same padding
self.xh = np.zeros_like(xh)
self.nsample = np.zeros_like(xh[:, 0], dtype=np.uint16)
self.Ci = []
self.modes = []
self.pad = data.attrs["pad"]
# Make the correct RA axis since input is padded
samples = xh.shape[-1] - 2 * self.pad
ra = np.linspace(0, 360, samples, endpoint=False)
# Don't initialize any datasets for now to save memory
self.stack = containers.SiderealStream(
axes_from=data, ra=ra, skip_datasets=True
)
self.lsd_list = []

# Accumulate the number of unflagged samples in each RA bin.
# Use a mask projection if it exists, otherwise extract
# from the weights directly
if "mask" in data.datasets:
ns = (~data.mask[:].local_array).astype(np.uint16)
else:
ns = np.any(data.weight[:].local_array > 0, axis=1).astype(np.uint16)

# Accumulate the weighted visibilities and
# sample count
self.xh += xh
self.nsample += ns

# Accumulate the signal covariance matrix and modes
self.Ci.append(Ci)
self.modes.append(modes)

        # Get the LSD (falling back to the CSD) if available
input_lsd = data.attrs.get("lsd", data.attrs.get("csd"))

self.lsd_list += _ensure_list(input_lsd)

def process_finish(self) -> containers.SiderealStream:
"""Deconvolve and return the final stacked sidereal stream.
Returns
-------
stack
Deconvolved stack of sidereal days.
"""
# Log how much data is missing, excluding bands that are entirely
# masked due to persistent RFI
zeros = mpiarray.MPIArray.wrap(
self.nsample < self.min_sample, axis=0, comm=self.stack.comm
)
zeros.local_array[:] &= ~np.all(zeros.local_array, axis=-1)[..., np.newaxis]
n_zeros = zeros.sum().allreduce()
self.log.info(
f"{100 * n_zeros / np.prod(zeros.global_shape):.3f}% of samples missing."
)

# Stack over the last axis
Ci = np.stack(self.Ci, axis=-1)
modes = np.stack(self.modes, axis=-1)

xh, nw, _ = self._deconvolve(self.xh, Ci, modes, self.pad)

# Trim padding from the `nsample` dataset
ns = self.nsample[:, self.pad : -self.pad][:, np.newaxis]
mask = (ns >= self.min_sample).astype(np.float32)

# Multiply in the sample mask
xh *= mask
nw *= mask

# Delete all the larger datasets to save memory
# before initialising the stack datasets
del self.xh
del self.Ci
del self.modes
del self.nsample

if self.subtract_median:
# Use a weighted median to ignore partially filled bands
ww = np.broadcast_to(mask, xh.shape)
xh_h = (
weighted_median.weighted_median(np.ascontiguousarray(xh.real), ww)
+ weighted_median.weighted_median(np.ascontiguousarray(xh.imag), ww)
* 1.0j
)
# Subtract the median
xh -= xh_h[:, :, np.newaxis]
# Make sure that zeros are still zeros
xh *= mask

# Initialize the datasets now
self.stack.add_dataset("vis")
self.stack.add_dataset("vis_weight")
self.stack.add_dataset("input_flags")
self.stack.add_dataset("nsample")
self.stack.redistribute("freq")

self.stack.vis[:].local_array[:] = xh
self.stack.weight[:].local_array[:] = nw
self.stack.nsample[:] = ns
self.stack.input_flags[:] = 0.0

self.stack.attrs["lsd"] = np.array(self.lsd_list)
self.stack.attrs["count"] = len(self.lsd_list)
self.stack.attrs["min_samples"] = self.min_sample
self.stack.attrs["tag"] = self.tag

return self.stack
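
# Illustrative sketch of the complex-valued weighted median used in
# `process_finish` above: the median is taken separately over the real and
# imaginary parts, with the sample mask as weights. `_wmedian` here is a
# plain-numpy stand-in for `caput.weighted_median.weighted_median`, shown
# only to make the real/imaginary split concrete.
def _complex_weighted_median_sketch(x, w):
    def _wmedian(v, w):
        # Weighted median: sort the values and take the one at which the
        # cumulative weight first reaches half of the total weight
        order = np.argsort(v)
        cw = np.cumsum(w[order])
        return v[order][np.searchsorted(cw, 0.5 * cw[-1])]

    return _wmedian(x.real, w) + 1.0j * _wmedian(x.imag, w)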


def _ensure_list(x):
if hasattr(x, "__iter__"):
y = list(x)
