From 1b09ebb10872e2d4e844e54dfe01a3dac0566ef7 Mon Sep 17 00:00:00 2001 From: Liam Gray Date: Fri, 12 Apr 2024 12:26:57 -0700 Subject: [PATCH] feat(flagging): add a new Stokes I RFI mask Apply a high-pass m filter to isolate scattered emission, then apply a sumthreshold mask to the average low-m power over longer baselines. --- draco/analysis/flagging.py | 303 +++++++++++++++++++++++++++++++++++++ 1 file changed, 303 insertions(+) diff --git a/draco/analysis/flagging.py b/draco/analysis/flagging.py index a2fbb346..bd5ed044 100644 --- a/draco/analysis/flagging.py +++ b/draco/analysis/flagging.py @@ -13,7 +13,10 @@ import numpy as np import scipy.signal from caput import config, mpiarray, weighted_median +from cora.util import units +from skimage.filters import apply_hysteresis_threshold +from ..analysis import delay from ..core import containers, io, task from ..util import rfi, tools @@ -997,6 +1000,306 @@ def process( return mask_cont +class RFIStokesIMask(task.SingleTask): + """Two-stage RFI filter based on Stokes I visibilities. + + Tries to independently target transient and persistent RFI. + + Stage 1 is applied to each frequency independently. A high-pass + filter is applied in RA to isolate transient RFI. The high-pass + filtered visibilities are beamformed, and a MAD filter is applied + to the resulting map. A time/RA sample is then flagged if some + fraction of beams exceed the MAD threshold for that sample. + + Stage 2 is applied across frequencies. A low-pass filter is applied + in RA to reduce transient sky sources. The average visibility power + is taken over 2+ cylinder separation baselines to obtain a single + 1D array per frequency. These powers are gathered across all + frequencies and a basic background subtraction is applied. A high-sigma + MAD flag is used during the daytime and bright transits, and the + sumthreshold algorithm is used everywhere else. 
+ + Attributes + ---------- + mad_base_size : list of int, optional + Median absolute deviations base window. Default is [1, 101]. + mad_dev_size : list of int, optional + Median absolute deviation median deviation window. + Default is [1, 51]. + sigma_high : float, optional + Median absolute deviations sigma threshold. Default is 8.0. + sigma_low : float, optional + Median absolute deviations low sigma threshold. A value above + this threshold is masked only if it is either larger than `sigma_high` + or it is larger than `sigma_low` AND connected to a region larger + than `sigma_high`. Default is 2.0. + frac_samples : float, optional + Fraction of flagged samples in map space above which the entire + time sample will be flagged. Default is 0.01. + st_max_m : int, optional + Maximum size of the SumThreshold window. Default is 32. + sigma_day : float, optional + Sigma threshold for the MAD mask applied to bright source transits + and daytime data, which is required to avoid masking out transits. + Generally this should be quite high, as the SumThreshold mask is + applied to the per-frequency median of the data in these regions, + and will catch most bright frequency bands. Default is 10.0. + lowpass_ang : float, optional + Angular cutoff of the RA lowpass filter, in degrees. Default is 7.5, which + corresponds to about 30 minutes of observation time. + include_multi_channel : bool, optional + If True, include second-stage multi-channel flagging. This should + generally always be included. Default is True. 
+ """ + + mad_base_size = config.list_type(int, length=2, default=[1, 101]) + mad_dev_size = config.list_type(int, length=2, default=[1, 51]) + sigma_high = config.Property(proptype=float, default=8.0) + sigma_low = config.Property(proptype=float, default=2.0) + frac_samples = config.Property(proptype=float, default=0.01) + + st_max_m = config.Property(proptype=int, default=32) + sigma_day = config.Property(proptype=float, default=10.0) + lowpass_ang = config.Property(proptype=float, default=7.5) + + include_multi_channel = config.Property(proptype=bool, default=True) + + def setup(self, telescope): + """Load and store the telescope object. + + Parameters + ---------- + telescope : TransitTelescope + The telescope object to use. + """ + self.telescope = io.get_telescope(telescope) + + def process(self, stream): + """Make a mask from the data. + + Parameters + ---------- + stream : dcontainers.TimeStream | dcontainers.SiderealStream + Data to use when masking. Axes should be frequency, stack, + and time-like. + + Returns + ------- + mask : dcontainers.RFIMask | dcontainers.SiderealRFIMask + Time-frequency mask, where values marked `True` are flagged. + """ + stream.redistribute("freq") + + csd = stream.attrs.get("lsd", stream.attrs.get("csd")) + + if csd is None: + raise ValueError("Dataset does not have a `csd` or `lsd` attribute.") + + if "time" in stream.index_map: + times = stream.time + elif "ra" in stream.index_map: + times = self.telescope.lsd_to_unix(csd + stream.ra / 360.0) + else: + raise TypeError( + f"Expected data with `time` or `ra` axis. Got {type(stream)}." + ) + + ra = 2 * np.pi * (self.telescope.unix_to_lsd(times) - csd) + freq = stream.freq[stream.vis[:].local_bounds] + + # Get stokes I and redistribute over frequency. 
Axes are rearranged + # in order (baseline, freq, time) + vis, weight, baselines = delay.stokes_I(stream, self.telescope) + vis = vis.redistribute(1).local_array + weight = weight.redistribute(1).local_array + + # Set up the initial mask + mask = np.all(weight == 0, axis=0) + mask |= self._static_rfi_mask_hook(freq, times[0])[:, np.newaxis] + self.log.debug(f"{100.0 * mask.mean():.2f}% of data initially flagged.") + + # Mask scattered transient rfi for each frequency independently + # Also get the average power per frequency after applying a low-pass filter + mask, power = self.mask_single_channel(vis, weight, mask, freq, baselines, ra) + + # Gather the entire mask and power arrays + mask = mpiarray.MPIArray.wrap(mask, axis=0).allgather() + power = mpiarray.MPIArray.wrap(power, axis=0).allgather() + + if self.include_multi_channel: + # Mask high power across frequencies + mask |= self.mask_multi_channel(power, mask, times) + + self.log.debug(f"{100.0 * mask.mean():.2f}% of data flagged.") + + if "ra" in stream.index_map: + output = containers.SiderealRFIMask(axes_from=stream, attrs_from=stream) + else: + output = containers.RFIMask(axes_from=stream, attrs_from=stream) + + output.mask[:] = mask + + return output + + def mask_single_channel(self, vis, weight, mask, freq, baselines, ra): + """Mask scattered rfi.""" + # Get the per-frequency high-pass and low-pass cuts + hpf_cut = self._hpf_cut_hook(freq, baselines) + lpf_cut = self._lpf_cut_hook(freq, baselines) + # Select cylinders to include in static power estimation. + # Choose baselines which should not contain much sky structure + bl_sel = baselines[:, 0] > 2.0 * self.telescope.u_width + # Set up an array to store mean power from non-sky sources + power = np.zeros_like(weight[0], dtype=np.float64) + + # Iterate over frequencies + for fsel in range(vis.shape[1]): + if np.all(mask[fsel]): + # Frequency is already masked + continue + + # Apply a high-pass mmode filter. 
Scattered emission appears + # similar to an impulse function in time, so its fourier transform + # should extend to high m + v_hpf = self.apply_filter( + vis[:, fsel], weight[:, fsel], ra, hpf_cut[fsel], type_="high" + ) + + # MAD filter flags scattered emission after beamforming + map_hpf = abs(np.fft.fft(v_hpf, axis=0)) + mad_mask = np.zeros_like(v_hpf, dtype=bool) | mask[fsel][np.newaxis] + mad_ = mad(map_hpf, mad_mask, self.mad_base_size, self.mad_dev_size) + # Hysteresis threshold mask flags anything above `sigma_high` or + # anything above `sigma_low` ONLY if it is connected to a region + # above `sigma_high` + mad_mask |= apply_hysteresis_threshold( + mad_, self.sigma_low, self.sigma_high + ) + # Collapse over baselines and flag + mean_flagged = np.mean(mad_mask, axis=0) + + # Apply a low pass filter + lp_win = (mean_flagged < 0.5)[np.newaxis] + v_lpf = self.apply_filter( + vis[:, fsel], weight[:, fsel] * lp_win, ra, lpf_cut[fsel], type_="low" + ) + + # Take the average over selected baselines + power[fsel] = np.mean(abs(v_lpf)[bl_sel], axis=0) + # Apply the hp mask + mask[fsel] |= mean_flagged > self.frac_samples + + return mask, power + + def mask_multi_channel(self, power, mask, times): + """Mask slow-moving narrow-band RFI.""" + # Find times where there are bright sources transiting + source_flag = self._source_flag_hook(times) + + # Get a median for each frequency + med = weighted_median.weighted_median(power, (~mask).astype(power.dtype)) + # Set power to median when bright sources are in the sky + p = power.copy() + p[:, source_flag] = med[:, np.newaxis] + + # Subtract out a background, assuming that the type of + # rfi we're looking for is very localised in frequency + p_med = weighted_median.moving_weighted_median( + p, (~mask).astype(p.dtype), size=(11, 3) + ) + + # Mask bright data with bright sources removed + summask = rfi.sumthreshold(abs(p - p_med), start_flag=mask, max_m=self.st_max_m) + # Expand the mask in time only. 
Expanding in frequency generally ends + # up being too aggressive, and the single-channel flagging does a fine + # job at catching broad-spectrum transient rfi + summask |= rfi.sir((summask & ~mask)[:, np.newaxis], only_time=True)[:, 0] + + # Extra masking over bright sources + mad_ = mad(power[:, source_flag], summask[:, source_flag]) + # Combine with the sumthreshold mask + summask[:, source_flag] |= mad_ > self.sigma_day + + # Expand the mask to try to fill small holes in heavily masked areas. + # The values used here are determined experimentally + kf = scipy.signal.windows.gaussian(11, std=5)[:, np.newaxis] + kt = scipy.signal.windows.gaussian(51, std=7)[np.newaxis] + kernel = (kf * kt) ** 0.5 + + mm = scipy.signal.oaconvolve(summask, kernel, mode="same") + summask |= mm > 0.75 * kernel.sum() + + return summask + + @staticmethod + def apply_filter(vis, weight, samples, fcut, type_="high"): + """Apply a high-pass or low-pass mmode filter.""" + # Median sampling rate + fs = 1 / np.median(abs(np.diff(samples))) + # Order is sample frequency over cutoff frequency. Ensure order is odd + order = int(np.ceil(fs / fcut) // 2 * 2 + 1) + # Make the window. Flattop seems to work well here + kernel = scipy.signal.firwin(order, fcut, window="flattop", fs=fs)[np.newaxis] + + # Low-pass filter the visibilities. 
`oaconvolve` is significantly + # faster than the standard convolve method + vw_lp = scipy.signal.oaconvolve(vis * weight, kernel, mode="same") + ww_lp = scipy.signal.oaconvolve(weight, kernel, mode="same") + vis_lp = vw_lp * tools.invert_no_zero(ww_lp) + + if type_ == "high": + return vis - vis_lp + + return vis_lp + + def _hpf_cut_hook(self, freq, baselines): + """Get a high-pass fringe rate cut for each frequency.""" + dec = np.deg2rad(self.telescope.latitude) + lambda_inv = freq[:, np.newaxis] * 1e6 / units.c + + # Maximum cut per frequency + return lambda_inv * baselines[:, 0].max() / np.cos(dec) + + def _lpf_cut_hook(self, freq, baselines): + """Get a low-pass fringe rate cut for each frequency.""" + cut = 1 / np.deg2rad(self.lowpass_ang) + + return np.ones(len(freq), dtype=np.float64) * cut + + def _static_rfi_mask_hook(self, freq, timestamp=None): + """Override to mask entire frequency channels. + + Parameters + ---------- + freq : np.ndarray[nfreq] + 1D array of frequencies in the data (in MHz). + + timestamp : np.array[float] + Start observing time (in unix time) + + Returns + ------- + mask : np.ndarray[nfreq] + Mask array. True will mask a frequency channel. + """ + return np.zeros_like(freq, dtype=bool) + + def _source_flag_hook(self, times): + """Override to mask out bright point sources. + + Parameters + ---------- + times : np.ndarray[float] + Array of timestamps. + + Returns + ------- + mask : np.ndarray[float] + Mask array. True will mask out a time sample. + """ + return np.zeros_like(times, dtype=bool) + + class RFISensitivityMask(task.SingleTask): """Slightly less crappy RFI masking.