refactor(flagging): updates to ThresholdVisWeightBaseline
ljgray committed Jan 17, 2024
1 parent fc3ee24 commit 6584b9e
Showing 1 changed file with 120 additions and 55 deletions.
175 changes: 120 additions & 55 deletions draco/analysis/flagging.py
@@ -8,16 +8,14 @@
"""

from typing import Union, overload

import numpy as np
import scipy.signal
from scipy import stats

-from caput import config, weighted_median, mpiarray
+from caput import config, mpiarray, weighted_median
from caput.tools import invert_no_zero

-from ..core import task, containers, io
-from ..util import tools
-from ..util import rfi
+from ..core import containers, io, task
+from ..util import rfi, tools


class DayMask(task.SingleTask):
@@ -838,19 +836,23 @@ class ThresholdVisWeightBaseline(task.SingleTask):
relative_threshold : float, optional
Any weights with values less than this number times the average weight
will be set to zero. Default: 1e-6.
-ignore_absolute_threshold : float, optional
-Any weights with values less than this number will be ignored when
-taking averages and constructing the mask. Default: 0.0.
+exclude_zeros : bool, optional
+Any weights which are already zero will be ignored when
+taking averages and constructing the mask. Default: True.
pols_to_flag : string, optional
Which polarizations to flag. "copol" only flags XX and YY baselines,
while "all" flags everything. Default: "all".
+per_freq_thresh : bool, optional
+If True, compute a separate threshold for each frequency and baseline.
+Otherwise, use a single threshold per baseline. Default: False.
"""

average_type = config.enum(["median", "mean"], default="median")
absolute_threshold = config.Property(proptype=float, default=1e-7)
relative_threshold = config.Property(proptype=float, default=1e-6)
-ignore_absolute_threshold = config.Property(proptype=float, default=0.0)
+exclude_zeros = config.Property(proptype=bool, default=True)
pols_to_flag = config.enum(["all", "copol"], default="all")
+per_freq_thresh = config.Property(proptype=bool, default=False)
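
Taken together, these options encode a simple per-sample rule. A scalar sketch of that rule (illustrative only; `avg` stands in for the per-baseline average weight computed in `process` below):

    # Sketch of the masking rule implied by the defaults above (not part of
    # this commit). A weight is flagged when it falls below
    # max(absolute_threshold, relative_threshold * average), and weights that
    # are already zero are skipped when exclude_zeros is True.
    def is_flagged(w, avg, absolute=1e-7, relative=1e-6, exclude_zeros=True):
        if exclude_zeros and w == 0.0:
            return False
        return w < max(absolute, relative * avg)

    print(is_flagged(1e-9, avg=2.0))  # True: below the relative threshold 2e-6
    print(is_flagged(0.0, avg=2.0))   # False: zero weights are excluded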

def setup(self, telescope):
"""Set the telescope model.
@@ -862,10 +864,17 @@ def setup(self, telescope):
"""
self.telescope = io.get_telescope(telescope)

+@overload
def process(
-    self,
-    stream,
-) -> Union[containers.BaselineMask, containers.SiderealBaselineMask]:
+    self, stream: containers.SiderealContainer
+) -> containers.SiderealBaselineMask:
+    ...
+
+@overload
+def process(self, stream: containers.TimeStream) -> containers.BaselineMask:
+    ...
+
+def process(self, stream):
"""Construct baseline-dependent mask.
Parameters
@@ -878,12 +887,6 @@ def process(
out : `BaselineMask` or `SiderealBaselineMask`
The output baseline-dependent mask.
"""
-from mpi4py import MPI
-
-# Only redistribute the weight dataset, because CorrData containers will have
-# other parallel datasets without a stack axis
-stream.weight.redistribute(axis=1)
-
# Make the output container, depending on input type
if "ra" in stream.axes:
mask_cont = containers.SiderealBaselineMask(
@@ -896,30 +899,67 @@ def process(self, stream):
f"Task requires TimeStream, SiderealStream, or CorrData. Got {type(stream)}"
)

-# Redistribute output container along stack axis
+# Redistribute the output container and input weights along the stack axis.
+# There is no need to redistribute the rest of the input container.
mask_cont.redistribute("stack")

-# Get local section of weights
+stream.weight.redistribute(axis=1)
+# Get the local view of the weights
local_weight = stream.weight[:].local_array

-# For each baseline (axis=1), take average over non-ignored time/freq samples
-average_func = np.ma.median if self.average_type == "median" else np.ma.mean
-average_weight = average_func(
-    np.ma.array(
-        local_weight, mask=(local_weight <= self.ignore_absolute_threshold)
-    ),
-    axis=(0, 2),
-).data
+# If a per-frequency threshold is desired, only average over the time-like axis
+axis_ = (2,) if self.per_freq_thresh else (0, 2)

+if self.average_type == "mean":
+    if self.exclude_zeros:
+        # Using a masked array is faster than using the standard
+        # mean with the `where` kwarg. It also avoids a nasty warning
+        # when there is no data in a slice.
+        average_weight = np.ma.mean(
+            np.ma.masked_array(local_weight, mask=(local_weight == 0.0)),
+            axis=axis_,
+        ).data
+    else:
+        average_weight = np.mean(local_weight, axis=axis_)
+else:
+    # Use the median. This is a bit more complicated since it uses a fast,
+    # parallel weighted median across the last axis.
+    dtype_ = local_weight.dtype
+    if len(axis_) == 2:
+        # We're taking the median across axes 0 and 2.
+        # Create the output array.
+        average_weight = np.zeros(local_weight.shape[1], dtype=dtype_)
+        # Make a default weight mask if nothing is excluded
+        if not self.exclude_zeros:
+            where_ = np.ones(local_weight[:, 0].size, dtype=dtype_)
+        # Iterate over baselines. This is faster and more memory efficient
+        # than reshaping and flattening the entire array.
+        for ki in range(len(average_weight)):
+            lw_ = local_weight[:, ki].ravel()
+            if self.exclude_zeros:
+                where_ = (lw_ > 0.0).astype(dtype_)
+            average_weight[ki] = weighted_median.weighted_median(lw_, where_)
+    else:
+        # Only taking the median across the last axis, so no need
+        # to do any reshaping.
+        where_ = (
+            (local_weight > 0.0).astype(dtype_)
+            if self.exclude_zeros
+            else np.ones_like(local_weight)
+        )
+        average_weight = weighted_median.weighted_median(local_weight, where_)
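
For intuition: with a 0/1 weight mask, the weighted median should reduce to an ordinary median over the retained samples. A minimal NumPy sketch of that assumption (not the caput implementation, which may resolve even-length ties differently):

    import numpy as np

    def median_excluding_zeros(values):
        # Assumed equivalent of weighted_median.weighted_median with a
        # binary weight mask: the median of only the nonzero samples.
        kept = values[values > 0.0]
        return np.median(kept) if kept.size else 0.0

    print(median_excluding_zeros(np.array([0.0, 1.0, 3.0, 2.0])))  # 2.0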

# Figure out which entries to keep
threshold = np.maximum(
self.absolute_threshold, self.relative_threshold * average_weight
)
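
A worked toy example of this floor-plus-relative rule (made-up average weights):

    import numpy as np

    average_weight = np.array([2.0, 1e-3])  # two baselines, hypothetical values
    threshold = np.maximum(1e-7, 1e-6 * average_weight)
    # Relative terms are [2e-6, 1e-9]; the absolute floor 1e-7 wins for the
    # second baseline.
    print(threshold)  # [2.e-06 1.e-07]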

-# Compute the mask, excluding samples that we want to ignore
-local_mask = (local_weight < threshold[np.newaxis, :, np.newaxis]) & (
-    local_weight > self.ignore_absolute_threshold
-)
+# Compute the mask and save it out
+local_mask = mask_cont.mask[:].local_array
+local_mask[:] = local_weight < np.expand_dims(threshold, axis=axis_)

+# Exclude zeros from the mask
+if self.exclude_zeros:
+    local_mask[:] &= local_weight > 0.0
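
The `np.expand_dims` call above restores the axes that were averaged away, so the threshold broadcasts against the (freq, stack, time) weights. A small shape check (illustrative sizes):

    import numpy as np

    # Single threshold per baseline: (nstack,) -> (1, nstack, 1)
    print(np.expand_dims(np.zeros(10), axis=(0, 2)).shape)  # (1, 10, 1)
    # Per-frequency thresholds: (nfreq, nstack) -> (nfreq, nstack, 1)
    print(np.expand_dims(np.zeros((4, 10)), axis=(2,)).shape)  # (4, 10, 1)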

# If only flagging co-pol baselines, make separate mask to select those,
# and multiply into low-weight mask
@@ -938,20 +978,16 @@
local_pol_mask = (pol_a == pol_b)[np.newaxis, :, np.newaxis]

# Apply pol mask to low-weight mask
-local_mask *= local_pol_mask
+local_mask[:] *= local_pol_mask

-# Compute the fraction of data that will be masked
-local_mask_sum = np.sum(local_mask)
-global_mask_total = np.zeros_like(local_mask_sum)
-stream.comm.Allreduce(local_mask_sum, global_mask_total, op=MPI.SUM)
-mask_frac = global_mask_total / float(np.prod(stream.weight.global_shape))
+# Log the additional fraction of data that will be masked. This
+# does not include data that was already flagged.
+# MPI.SUM is the default operation for `allreduce`.
+local_flagged = local_mask & (local_weight != 0)
+mask_sum = mpiarray.MPIArray.wrap(local_flagged, axis=1).sum().allreduce()
+mask_frac = mask_sum / float(np.prod(stream.weight.global_shape))

-self.log.info(
-    "%0.5f%% of data is below the weight threshold" % (100.0 * mask_frac)
-)
-
-# Save mask to output container
-mask_cont.mask[:] = mpiarray.MPIArray.wrap(local_mask, axis=1)
+self.log.info(f"Additional {100 * mask_frac:.5f}% of data will be flagged.")

# Distribute back across frequency
mask_cont.redistribute("freq")
Expand All @@ -965,12 +1001,32 @@ class CollapseBaselineMask(task.SingleTask):
The output is a frequency/time mask that is True for any freq/time sample
for which any baseline is masked in the input mask.
+Attributes
+----------
+method : string
+Which method to use when reducing. 'any' means that a sample will be masked
+if it is flagged for any baseline. 'all' means that a sample will only be
+masked if it is flagged for all baselines. 'frac' masks a sample when the
+fraction of flagged baselines exceeds `frac_flagged`.
+frac_flagged : float
+If `method` is `frac`, this sets the fraction of baselines that must be
+flagged in a sample for the entire sample to be masked. Default: 0.5.
"""

+method = config.enum(["all", "any", "frac"], default="any")
+frac_flagged = config.Property(proptype=float, default=0.5)
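
A toy illustration of the three reduction modes over the baseline (stack) axis, with a made-up mask of shape (freq, stack, time):

    import numpy as np

    m = np.array([[[True, False],
                   [True, False],
                   [False, False]]])  # 1 freq, 3 baselines, 2 time samples
    print(np.any(m, axis=1))                     # [[ True False]]
    print(np.all(m, axis=1))                     # [[False False]]
    # 'frac' with frac_flagged = 0.5: 2 of 3 flagged baselines exceeds 0.5
    print(np.sum(m, axis=1) > 0.5 * m.shape[1])  # [[ True False]]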

+@overload
+def process(self, baseline_mask: containers.BaselineMask) -> containers.RFIMask:
+    ...
+
+@overload
def process(
-    self,
-    baseline_mask: Union[containers.BaselineMask, containers.SiderealBaselineMask],
-) -> Union[containers.RFIMask, containers.SiderealRFIMask]:
+    self, baseline_mask: containers.SiderealBaselineMask
+) -> containers.SiderealRFIMask:
+    ...
+
+def process(self, baseline_mask):
"""Collapse input mask over baseline axis.
Parameters
@@ -996,21 +1052,29 @@ def process(
axes_from=baseline_mask, attrs_from=baseline_mask
)

+# Log the amount of data currently masked
+mask_sum = baseline_mask.mask[:].sum().allreduce()
+mask_frac = mask_sum / float(np.prod(baseline_mask.mask.global_shape))
+self.log.info(f"{100.0 * mask_frac:.5f}% of data masked before reduction.")

# Get local section of baseline-dependent mask
local_mask = baseline_mask.mask[:].local_array

# Collapse along stack axis
-local_mask = np.any(local_mask, axis=1)
+if self.method == "frac":
+    local_mask = (
+        np.sum(local_mask, axis=1) > self.frac_flagged * local_mask.shape[1]
+    )
+else:
+    reduce_ = getattr(np, self.method)
+    local_mask = reduce_(local_mask, axis=1)

# Gather full mask on each rank
full_mask = mpiarray.MPIArray.wrap(local_mask, axis=0).allgather()

# Log the percent of freq/time samples masked
-drop_frac = np.sum(full_mask) / np.prod(full_mask.shape)
-self.log.info(
-    f"After baseline collapse: {100.0 * drop_frac:.1f}%% of data"
-    " is below the weight threshold"
-)
+mask_frac = np.sum(full_mask) / np.prod(full_mask.shape)
+self.log.info(f"{100.0 * mask_frac:.5f}% of data masked after reduction.")

mask_cont.mask[:] = full_mask

@@ -1746,6 +1810,7 @@ class BlendStack(task.SingleTask):

def setup(self, data_stack):
"""Set the stacked data.
Parameters
----------
data_stack : VisContainer
