Skip to content

Commit

Permalink
feat(flagging): add a new Stokes I RFI mask
Browse files Browse the repository at this point in the history
Apply a high-pass m filter to isolate scattered emission, then apply
a sumthreshold mask to the average low-m power over longer baselines.
  • Loading branch information
ljgray committed May 13, 2024
1 parent 7c051b1 commit 1b09ebb
Showing 1 changed file with 303 additions and 0 deletions.
303 changes: 303 additions & 0 deletions draco/analysis/flagging.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,10 @@
import numpy as np
import scipy.signal
from caput import config, mpiarray, weighted_median
from cora.util import units
from skimage.filters import apply_hysteresis_threshold

from ..analysis import delay
from ..core import containers, io, task
from ..util import rfi, tools

Expand Down Expand Up @@ -997,6 +1000,306 @@ def process(
return mask_cont


class RFIStokesIMask(task.SingleTask):
"""Two-stage RFI filter based on Stokes I visibilities.
Tries to independently target transient and persistant RFI.
Stage 1 is applied to each frequency independently. A high-pass
filter is applied in RA to isolate transient RFI. The high-pass
filtered visibilities are beamformed, and a MAD filter is applied
to the resulting map. A time/RA sample is then flagged if some
fraction of beams exceed the MAD threshold for that sample.
Stage 2 is applied across frequencies. A low-pass filter is applied
in RA to reduce transient sky sources. The average visibility power
is taken over 2+ cylinder separation baselines to obtain a single
1D array per frequency. These powers are gathered across all
frequencies and a basic background subtraction is applied. A high-sigma
MAD flag is used during the daytime and bright transits, and the
sumthreshold algorithm is used everywhere else.
Attributes
----------
mad_base_size : list of int, optional
Median absolute deviations base window. Default is [1, 101].
mad_dev_size : list of int, optional
Median absolute deviation median deviation window.
Default is [1, 51].
sigma_high : float, optional
Median absolute deviations sigma threshold. Default is 8.0.
sigma_low : float, optional
Median absolute deviations low sigma threshold. A value above
this threshold is masked only if it is either larger than `sigma_high`
or it is larger than `sigma_low` AND connected to a region larger
than `sigma_high`. Default is 2.0.
frac_samples : float, optional
Fraction of flagged samples in map space above which the entire
time sample will be flagged. Default is 0.01.
st_max_m : int, optional
Maximum size of the SumThreshold window. Default is 32.0.
sigma_day : float, optional
Sigma threshold for the MAD mask applied to bright source transits
and daytime data, which is required to avoid masking out transits.
Generally this should be quite high, as the SumThreshold mask is
applied to the per-frequency median of the data in these regions,
and will catch most bright frequency bands. Default is 10.0.
lowpass_ang : float, optional
Angular cutoff of the ra lowpass filter. Default is 7.5, which
corresponds to about 30 minutes of observation time.
include_multi_channel : bool, optional
If True, include second-stage multi-channel flagging. This should
generally always be included. Default is True.
"""

mad_base_size = config.list_type(int, length=2, default=[1, 101])
mad_dev_size = config.list_type(int, length=2, default=[1, 51])
sigma_high = config.Property(proptype=float, default=8.0)
sigma_low = config.Property(proptype=float, default=2.0)
frac_samples = config.Property(proptype=float, default=0.01)

st_max_m = config.Property(proptype=int, default=32)
sigma_day = config.Property(proptype=float, default=10.0)
lowpass_ang = config.Property(proptype=float, default=7.5)

include_multi_channel = config.Property(proptype=bool, default=True)

def setup(self, telescope):
"""Set up the baseline selections and ordering.
Parameters
----------
telescope : TransitTelescope
The telescope object to use
"""
self.telescope = io.get_telescope(telescope)

def process(self, stream):
"""Make a mask from the data.
Parameters
----------
stream : dcontainers.TimeStream | dcontainers.SiderealStream
Data to use when masking. Axes should be frequency, stack,
and time-like.
Returns
-------
mask : dcontainers.RFIMask | dcontainers.SiderealRFIMask
Time-frequency mask, where values marked `True` are flagged.
"""
stream.redistribute("freq")

csd = stream.attrs.get("lsd", stream.attrs.get("csd"))

if csd is None:
raise ValueError("Dataset does not have a `csd` or `lsd` attribute.")

if "time" in stream.index_map:
times = stream.time
elif "ra" in stream.index_map:
times = self.telescope.lsd_to_unix(csd + stream.ra / 360.0)
else:
raise TypeError(
f"Expected data with `time` or `ra` axis. Got {type(stream)}."
)

ra = 2 * np.pi * (self.telescope.unix_to_lsd(times) - csd)
freq = stream.freq[stream.vis[:].local_bounds]

# Get stokes I and redistribute over frequency. Axes are rearranged
# in order (baseline, freq, time)
vis, weight, baselines = delay.stokes_I(stream, self.telescope)
vis = vis.redistribute(1).local_array
weight = weight.redistribute(1).local_array

# Set up the initial mask
mask = np.all(weight == 0, axis=0)
mask |= self._static_rfi_mask_hook(freq, times[0])[:, np.newaxis]
self.log.debug(f"{100.0 * mask.mean():.2f}% of data initially flagged.")

# Mask scattered transient rfi for each frequency independently
# Also get the average power per frequency after applying a low-pass filter
mask, power = self.mask_single_channel(vis, weight, mask, freq, baselines, ra)

# Gather the entire mask and power arrays
mask = mpiarray.MPIArray.wrap(mask, axis=0).allgather()
power = mpiarray.MPIArray.wrap(power, axis=0).allgather()

if self.include_multi_channel:
# Mask high power across frequencies
mask |= self.mask_multi_channel(power, mask, times)

self.log.debug(f"{100.0 * mask.mean():.2f}% of data flagged.")

if "ra" in stream.index_map:
output = containers.SiderealRFIMask(axes_from=stream, attrs_from=stream)
else:
output = containers.RFIMask(axes_from=stream, attrs_from=stream)

output.mask[:] = mask

return output

def mask_single_channel(self, vis, weight, mask, freq, baselines, ra):
"""Mask scattered rfi."""
# Get the per-frequency high-pass and low-pass cuts
hpf_cut = self._hpf_cut_hook(freq, baselines)
lpf_cut = self._lpf_cut_hook(freq, baselines)
# Select cylinders to include in static power estimation.
# Choose baselines which should not contain much sky structure
bl_sel = baselines[:, 0] > 2.0 * self.telescope.u_width
# Set up an array to store mean power from non-sky sources
power = np.zeros_like(weight[0], dtype=np.float64)

# Iterate over frequencies
for fsel in range(vis.shape[1]):
if np.all(mask[fsel]):
# Frequency is already masked
continue

# Apply a high-pass mmode filter. Scattered emission appears
# similar to an impulse function in time, so its fourier transform
# should extend to high m
v_hpf = self.apply_filter(
vis[:, fsel], weight[:, fsel], ra, hpf_cut[fsel], type_="high"
)

# MAD filter flags scattered emission after beamforming
map_hpf = abs(np.fft.fft(v_hpf, axis=0))
mad_mask = np.zeros_like(v_hpf, dtype=bool) | mask[fsel][np.newaxis]
mad_ = mad(map_hpf, mad_mask, self.mad_base_size, self.mad_dev_size)
# Hysteresis threshold mask flags anything above `sigma_high` or
# anything above `sigma_low` ONLY if it is connected to a region
# above `sigma_high`
mad_mask |= apply_hysteresis_threshold(
mad_, self.sigma_low, self.sigma_high
)
# Collapse over baselines and flag
mean_flagged = np.mean(mad_mask, axis=0)

# Apply a low pass filter
lp_win = (mean_flagged < 0.5)[np.newaxis]
v_lpf = self.apply_filter(
vis[:, fsel], weight[:, fsel] * lp_win, ra, lpf_cut[fsel], type_="low"
)

# Take the average over selected baselines
power[fsel] = np.mean(abs(v_lpf)[bl_sel], axis=0)
# Apply the hp mask
mask[fsel] |= mean_flagged > self.frac_samples

return mask, power

def mask_multi_channel(self, power, mask, times):
"""Mask slow-moving narrow-band RFI."""
# Find times where there are bright sources transiting
source_flag = self._source_flag_hook(times)

# Get a median for each frequency
med = weighted_median.weighted_median(power, (~mask).astype(power.dtype))
# Set power to median when bright sources are in the sky
p = power.copy()
p[:, source_flag] = med[:, np.newaxis]

# Subtract out a background, assuming that the type of
# rfi we're looking for is very localised in frequency
p_med = weighted_median.moving_weighted_median(
p, (~mask).astype(p.dtype), size=(11, 3)
)

# Mask bright data with bright sources removed
summask = rfi.sumthreshold(abs(p - p_med), start_flag=mask, max_m=self.st_max_m)
# Expand the mask in time only. Expanding in frequency generally ends
# up being too aggressive, and the single-channel flagging does a fine
# job at catching broad-spectrum transient rfi
summask |= rfi.sir((summask & ~mask)[:, np.newaxis], only_time=True)[:, 0]

# Extra masking over bright sources
mad_ = mad(power[:, source_flag], summask[:, source_flag])
# Combine with the sumthreshold mask
summask[:, source_flag] |= mad_ > self.sigma_day

# Expand the mask to try to fill small holes in heavily masked areas.
# The values used here are determined experimentally
kf = scipy.signal.windows.gaussian(11, std=5)[:, np.newaxis]
kt = scipy.signal.windows.gaussian(51, std=7)[np.newaxis]
kernel = (kf * kt) ** 0.5

mm = scipy.signal.oaconvolve(summask, kernel, mode="same")
summask |= mm > 0.75 * kernel.sum()

return summask

@staticmethod
def apply_filter(vis, weight, samples, fcut, type_="high"):
"""Apply a high-pass or low-pass mmode filter."""
# Median sampling rate
fs = 1 / np.median(abs(np.diff(samples)))
# Order is sample frequency over cutoff frequency. Ensure order is odd
order = int(np.ceil(fs / fcut) // 2 * 2 + 1)
# Make the window. Flattop seems to work well here
kernel = scipy.signal.firwin(order, fcut, window="flattop", fs=fs)[np.newaxis]

# Low-pass filter the visibilities. `oaconvolve` is significantly
# faster than the standard convolve method
vw_lp = scipy.signal.oaconvolve(vis * weight, kernel, mode="same")
ww_lp = scipy.signal.oaconvolve(weight, kernel, mode="same")
vis_lp = vw_lp * tools.invert_no_zero(ww_lp)

if type_ == "high":
return vis - vis_lp

return vis_lp

def _hpf_cut_hook(self, freq, baselines):
"""Get a high-pass fringe rate cut for each frequency."""
dec = np.deg2rad(self.telescope.latitude)
lambda_inv = freq[:, np.newaxis] * 1e6 / units.c

# Maximum cut per frequency
return lambda_inv * baselines[:, 0].max() / np.cos(dec)

def _lpf_cut_hook(self, freq, baselines):
"""Get a low-pass fringe rate cut for each frequency."""
cut = 1 / np.deg2rad(self.lowpass_ang)

return np.ones(len(freq), dtype=np.float64) * cut

def _static_rfi_mask_hook(self, freq, timestamp=None):
"""Override to mask entire frequency channels.
Parameters
----------
freq : np.ndarray[nfreq]
1D array of frequencies in the data (in MHz).
timestamp : np.array[float]
Start observing time (in unix time)
Returns
-------
mask : np.ndarray[nfreq]
Mask array. True will mask a frequency channel.
"""
return np.zeros_like(freq, dtype=bool)

def _source_flag_hook(self, times):
"""Override to mask out bright point sources.
Parameters
----------
times : np.ndarray[float]
Array of timestamps.
Returns
-------
mask : np.ndarray[float]
Mask array. True will mask out a time sample.
"""
return np.zeros_like(times, dtype=bool)


class RFISensitivityMask(task.SingleTask):
"""Slightly less crappy RFI masking.
Expand Down

0 comments on commit 1b09ebb

Please sign in to comment.