Skip to content

Commit

Permalink
feat(flagging): more flexible chi-squared masking (#294)
Browse files Browse the repository at this point in the history
* Enable masking of the chi-squared test statistic derived from
  HybridVisStream and RingMap containers.
* Option to estimate the variance of the test statistic using a
  local median absolute deviation.
* Option to only mask positive excursions in the test statistic.
  • Loading branch information
ssiegelx authored Sep 24, 2024
1 parent 533da75 commit 4f72ab4
Showing 1 changed file with 44 additions and 10 deletions.
54 changes: 44 additions & 10 deletions draco/analysis/flagging.py
Original file line number Diff line number Diff line change
Expand Up @@ -1367,6 +1367,13 @@ class RFIMaskChisqHighDelay(task.SingleTask):
from the median filtered version is greater than this number
of expected standard deviations given the number of degrees
of freedom (i.e., number of baselines).
estimate_var : bool
Estimate the variance in the test statistic using the median
absolute deviation over a region defined by the win_t and
win_f parameters.
only_positive : bool
Only mask large postive excursions in the test statistic,
leaving large negative excursions unmasked.
"""

reg_arpls = config.Property(proptype=float, default=1e5)
Expand All @@ -1375,6 +1382,8 @@ class RFIMaskChisqHighDelay(task.SingleTask):
win_t = config.Property(proptype=int, default=601)
win_f = config.Property(proptype=int, default=1)
nsigma_2d = config.Property(proptype=float, default=5.0)
estimate_var = config.Property(proptype=bool, default=False)
only_positive = config.Property(proptype=bool, default=False)

def setup(self, telescope=None):
"""Save telescope object for time calculations.
Expand All @@ -1394,7 +1403,8 @@ def process(self, stream):
Parameters
----------
stream : dcontainers.TimeStream | dcontainers.SiderealStream
stream : dcontainers.TimeStream | dcontainers.SiderealStream |
dcontainers.HybridVisStream | dcontainers.RingMap
Container holding a chi-squared test statistic in the visibility dataset.
A weighted average will be taken over any axis that is not time/ra or frequency.
Expand Down Expand Up @@ -1430,19 +1440,29 @@ def process(self, stream):
else:
timestamp = stream.time

# Expand the weight dataset so that it can broadcast against the data dataset.
# Assumes that weight contains a subset of the axes in data, with the shared
# axes in the same order. This is true for all of the supported input
# containers listed in the docstring.
dax = list(stream.data.attrs["axis"])
wax = list(stream.weight.attrs["axis"])
wshp = [stream.weight.shape[wax.index(ax)] if ax in wax else 1 for ax in dax]
wshp[dax.index("freq")] = None

# Extract the shape of the axes that are missing from the weights dataset,
# so that we can scale the denominator by this factor.
wshp_missing = [sz for sz, ax in zip(stream.data.shape, dax) if ax not in wax]
wfactor = np.prod(wshp_missing) if len(wshp_missing) > 0 else 1.0

# Sum over any axis that is neither time nor frequency
axsum = tuple(
[
ii
for ii, ax in enumerate(stream.vis.attrs["axis"])
if ax not in ["freq", "time", "ra"]
]
[ii for ii, ax in enumerate(dax) if ax not in ["freq", "time", "ra"]]
)

chisq = stream.vis[:].real
weight = stream.weight[:]
chisq = stream.data[:].real
weight = stream.weight[:].reshape(*wshp)

wsum = np.sum(weight, axis=axsum)
wsum = wfactor * np.sum(weight, axis=axsum)
chisq = np.sum(weight * chisq, axis=axsum) * tools.invert_no_zero(wsum)

# Gather all frequencies on all nodes
Expand Down Expand Up @@ -1553,7 +1573,21 @@ def mask_2d(self, y, w):

# Calculate the deviation from the median, normalized by the
# expected standard deviation
dy = np.abs(y - med_y) * np.sqrt(w)
dy = (y - med_y) * np.sqrt(w)

# If requested, estimate the variance in the test statistic
# using the median absolute deviation.
if self.estimate_var:
f = np.ascontiguousarray((w > 0.0).astype(np.float64))
mad_y = 1.48625 * weighted_median.moving_weighted_median(
np.abs(dy), f, win_size
)
dy *= tools.invert_no_zero(mad_y)

# Take the absolute value of the relative excursion unless
# explicitely requested to only flag positive excursions.
if not self.only_positive:
dy = np.abs(dy)

# Flag times and frequencies that deviate by more than some threshold
return dy > self.nsigma_2d
Expand Down

0 comments on commit 4f72ab4

Please sign in to comment.