diff --git a/draco/analysis/flagging.py b/draco/analysis/flagging.py
index 2b57247a7..2c21b1ecb 100644
--- a/draco/analysis/flagging.py
+++ b/draco/analysis/flagging.py
@@ -1367,6 +1367,13 @@ class RFIMaskChisqHighDelay(task.SingleTask):
         from the median filtered version is greater than this number of
         expected standard deviations given the number of degrees of
         freedom (i.e., number of baselines).
+    estimate_var : bool
+        Estimate the variance in the test statistic using the median
+        absolute deviation over a region defined by the win_t and
+        win_f parameters.
+    only_positive : bool
+        Only mask large positive excursions in the test statistic,
+        leaving large negative excursions unmasked.
     """
 
     reg_arpls = config.Property(proptype=float, default=1e5)
@@ -1375,6 +1382,8 @@ class RFIMaskChisqHighDelay(task.SingleTask):
     win_t = config.Property(proptype=int, default=601)
     win_f = config.Property(proptype=int, default=1)
     nsigma_2d = config.Property(proptype=float, default=5.0)
+    estimate_var = config.Property(proptype=bool, default=False)
+    only_positive = config.Property(proptype=bool, default=False)
 
     def setup(self, telescope=None):
         """Save telescope object for time calculations.
@@ -1394,7 +1403,8 @@ def process(self, stream):
 
         Parameters
         ----------
-        stream : dcontainers.TimeStream | dcontainers.SiderealStream
+        stream : dcontainers.TimeStream | dcontainers.SiderealStream |
+            dcontainers.HybridVisStream | dcontainers.RingMap
             Container holding a chi-squared test statistic in the
             visibility dataset. A weighted average will be taken over
             any axis that is not time/ra or frequency.
@@ -1430,19 +1440,29 @@
         else:
             timestamp = stream.time
 
+        # Expand the weight dataset so that it can broadcast against the data dataset.
+        # Assumes that weight contains a subset of the axes in data, with the shared
+        # axes in the same order. This is true for all of the supported input
+        # containers listed in the docstring.
+        dax = list(stream.data.attrs["axis"])
+        wax = list(stream.weight.attrs["axis"])
+        wshp = [stream.weight.shape[wax.index(ax)] if ax in wax else 1 for ax in dax]
+        wshp[dax.index("freq")] = None
+
+        # Extract the shape of the axes that are missing from the weights dataset,
+        # so that we can scale the denominator by this factor.
+        wshp_missing = [sz for sz, ax in zip(stream.data.shape, dax) if ax not in wax]
+        wfactor = np.prod(wshp_missing) if len(wshp_missing) > 0 else 1.0
+
         # Sum over any axis that is neither time nor frequency
         axsum = tuple(
-            [
-                ii
-                for ii, ax in enumerate(stream.vis.attrs["axis"])
-                if ax not in ["freq", "time", "ra"]
-            ]
+            [ii for ii, ax in enumerate(dax) if ax not in ["freq", "time", "ra"]]
         )
 
-        chisq = stream.vis[:].real
-        weight = stream.weight[:]
+        chisq = stream.data[:].real
+        weight = stream.weight[:].reshape(*wshp)
 
-        wsum = np.sum(weight, axis=axsum)
+        wsum = wfactor * np.sum(weight, axis=axsum)
         chisq = np.sum(weight * chisq, axis=axsum) * tools.invert_no_zero(wsum)
 
         # Gather all frequencies on all nodes
@@ -1553,7 +1573,21 @@
 
         # Calculate the deviation from the median, normalized by the
        # expected standard deviation
-        dy = np.abs(y - med_y) * np.sqrt(w)
+        dy = (y - med_y) * np.sqrt(w)
+
+        # If requested, estimate the variance in the test statistic
+        # using the median absolute deviation.
+        if self.estimate_var:
+            f = np.ascontiguousarray((w > 0.0).astype(np.float64))
+            mad_y = 1.48625 * weighted_median.moving_weighted_median(
+                np.abs(dy), f, win_size
+            )
+            dy *= tools.invert_no_zero(mad_y)
+
+        # Take the absolute value of the relative excursion unless
+        # explicitly requested to only flag positive excursions.
+        if not self.only_positive:
+            dy = np.abs(dy)
 
         # Flag times and frequencies that deviate by more than some threshold
         return dy > self.nsigma_2d
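A minimal sketch of the new weight-expansion logic in `process`, using plain numpy arrays instead of draco containers. The axis names and shapes below are illustrative only. Note that the PR sets the `"freq"` entry of the reshape shape to `None`, which appears to be a convention of the distributed-array reshape for the frequency axis; the plain-numpy version here simply keeps the integer size.

```python
import numpy as np

# Illustrative axes: the weight dataset is missing the "el" axis of the data.
dax = ["pol", "freq", "el", "ra"]
wax = ["pol", "freq", "ra"]
data = np.ones((2, 16, 8, 32))
weight = np.ones((2, 16, 32))

# Insert a length-1 axis wherever the weights lack an axis of the data,
# so the reshaped weights broadcast against the data.
wshp = [weight.shape[wax.index(ax)] if ax in wax else 1 for ax in dax]

# Axes absent from the weights contribute a multiplicity factor to the
# denominator of the weighted average.
wfactor = np.prod([sz for sz, ax in zip(data.shape, dax) if ax not in wax])

# Weighted average over everything that is not frequency or time/ra.
axsum = tuple(ii for ii, ax in enumerate(dax) if ax not in ["freq", "time", "ra"])
w = weight.reshape(wshp)
wsum = wfactor * np.sum(w, axis=axsum)

inv_wsum = np.zeros_like(wsum)
np.divide(1.0, wsum, out=inv_wsum, where=wsum > 0)  # like tools.invert_no_zero
chisq = np.sum(w * data, axis=axsum) * inv_wsum
print(chisq.shape)  # (16, 32): freq by ra
```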
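The `estimate_var` and `only_positive` branches can also be exercised in isolation with a rough stand-in: here `scipy.ndimage.median_filter` replaces draco's `weighted_median.moving_weighted_median`, the weights are uniform, and the window and threshold are toy values rather than the task defaults.

```python
import numpy as np
from scipy.ndimage import median_filter

rng = np.random.default_rng(0)

# Toy test statistic on a (freq, time) grid: unit-variance noise plus a short
# run of strong positive outliers standing in for RFI.
y = rng.standard_normal((64, 512))
y[20, 100:110] += 20.0
w = np.ones_like(y)   # inverse-variance weights (uniform here)
win = (1, 61)         # (win_f, win_t)-style smoothing window
nsigma = 5.0

# Deviation from the smooth background, scaled to expected standard deviations.
med_y = median_filter(y, size=win)
dy = (y - med_y) * np.sqrt(w)

# estimate_var: rescale by a local MAD-based estimate of the scatter
# (1.4826 converts a median absolute deviation to a Gaussian sigma).
estimate_var = True
if estimate_var:
    mad_y = 1.4826 * median_filter(np.abs(dy), size=win)
    inv = np.zeros_like(mad_y)
    np.divide(1.0, mad_y, out=inv, where=mad_y > 0)
    dy *= inv

# only_positive: mask only large positive excursions; otherwise mask
# excursions of either sign.
only_positive = True
if not only_positive:
    dy = np.abs(dy)
mask = dy > nsigma
print(mask.sum(), "samples masked")
```

With uniform weights the MAD rescaling is nearly a no-op for well-behaved noise; its apparent purpose in the task is to keep the threshold meaningful when the nominal weights misstate the true scatter of the test statistic, while `only_positive` avoids masking negative dips that are not RFI-like.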