Skip to content

Commit

Permalink
feat(interpolate): add an optional mask to select different inpaintin…
Browse files Browse the repository at this point in the history
…g samples
  • Loading branch information
ljgray committed Jan 22, 2025
1 parent f744c5b commit 46c94db
Showing 1 changed file with 33 additions and 5 deletions.
38 changes: 33 additions & 5 deletions draco/analysis/interpolate.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,18 @@ class DPSSInpaint(task.SingleTask):
flag_above_cutoff = config.Property(proptype=bool, default=False)
copy = config.Property(proptype=bool, default=True)

def setup(self, mask=None):
"""Use an optional mask dataset.
Parameters
----------
mask : containers.RFIMask, optional
Container used to select samples to inpaint. If
not provided, inpaint samples where the data
weights are zero.
"""
self.mask = mask

def process(self, data):
"""Inpaint visibility data.
Expand Down Expand Up @@ -97,6 +109,11 @@ def inpaint(self, vis, weight, samples):
vobs, vaxind = _flatten_axes(vis, (*self.iter_axes, self.axis))
wobs, waxind = _flatten_axes(weight, (*self.iter_axes, self.axis))

if self.mask is not None:
mobs, _ = _flatten_axes(self.mask.mask, (*self.iter_axes, self.axis))
# Invert the mask to avoid doing it every loop
mobs = ~mobs

# Pre-allocate the full output array
vinp = np.zeros_like(vobs)
winp = np.zeros_like(wobs)
Expand All @@ -112,13 +129,16 @@ def inpaint(self, vis, weight, samples):
for ii in range(vobs.shape[0]):
# Get the correct basis for each slice
A = modes[amap[ii]]
# Write to the preallocated output array
W = wobs[ii] > 0

# Get a selection for data to keep
M = wobs[ii] > 0
W = mobs if self.mask is not None else M

vinp[ii], winp[ii] = dpss.inpaint(vobs[ii], wobs[ii], A, W, self.snr_cov)

# Re-flag gaps above the cutoff width
if self.flag_above_cutoff:
winp[ii] *= dpss.flag_above_cutoff(W, cutoff)
winp[ii] *= dpss.flag_above_cutoff(M, cutoff)

# Reshape and move the interpolation axis back
vinp = _inv_move_front(vinp, vaxind, vis.local_shape)
Expand Down Expand Up @@ -170,7 +190,7 @@ class DPSSInpaintDelay(DPSSInpaint):
extra_cut = config.Property(proptype=float, default=0.0)
telescope_orientation = config.enum(["NS", "EW", "none"], default="NS")

def setup(self, telescope):
def setup(self, telescope, mask=None):
"""Load a telescope object.
This is required to establish baseline-dependent
Expand All @@ -180,8 +200,14 @@ def setup(self, telescope):
----------
telescope : TransitTelescope
Telescope object with baseline information.
mask : containers.RFIMask, optional
Container used to select samples to inpaint. If
not provided, inpaint samples where the data
weights are zero.
"""
self.telescope = io.get_telescope(telescope)
# Pass the mask to the parent class
super().setup(mask)

def _set_sel(self, data):
"""Set the local baselines."""
Expand Down Expand Up @@ -253,7 +279,9 @@ def _flatten_axes(data, axes):
f"but axes {axes} were requested."
)

return _move_front(data[:].local_array, axind, data.local_shape), axind
ds = data[:].view(np.ndarray)

return _move_front(ds, axind, ds.shape), axind


def _move_front(arr: np.ndarray, axis: int | list, shape: tuple) -> np.ndarray:
Expand Down

0 comments on commit 46c94db

Please sign in to comment.