Skip to content

Commit

Permalink
Add a parallactic angle gain type.
Browse files Browse the repository at this point in the history
  • Loading branch information
JSKenyon committed Nov 7, 2024
1 parent b0fa235 commit ce8e375
Show file tree
Hide file tree
Showing 20 changed files with 265 additions and 70 deletions.
10 changes: 5 additions & 5 deletions quartical/calibration/constructor.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@


term_spec_tup = namedtuple("term_spec_tup", "name type shape pshape")
aux_info_fields = ("SCAN_NUMBER", "FIELD_ID", "DATA_DESC_ID")
log_info_fields = ("SCAN_NUMBER", "FIELD_ID", "DATA_DESC_ID")


def construct_solver(
Expand Down Expand Up @@ -54,9 +54,9 @@ def construct_solver(
corr_mode = data_xds.sizes["corr"]

block_id_arr = get_block_id_arr(data_col)
aux_block_info = {
k: data_xds.attrs.get(k, "?") for k in aux_info_fields
}
data_xds_meta = data_xds.attrs.copy()
for k in log_info_fields:
data_xds_meta[k] = data_xds_meta.get(k, "?")

# Grab the number of input chunks - doing this on the data should be
# safe.
Expand Down Expand Up @@ -87,7 +87,7 @@ def construct_solver(
)
blocker.add_input("term_spec_list", spec_list, ("row", "chan"))
blocker.add_input("corr_mode", corr_mode)
blocker.add_input("aux_block_info", aux_block_info)
blocker.add_input("data_xds_meta", data_xds_meta)
blocker.add_input("solver_opts", solver_opts)
blocker.add_input("chain", chain)

Expand Down
27 changes: 16 additions & 11 deletions quartical/calibration/solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ def solver_wrapper(
solver_opts,
chain,
block_id_arr,
aux_block_info,
data_xds_meta,
corr_mode,
**kwargs
):
Expand Down Expand Up @@ -108,11 +108,11 @@ def solver_wrapper(
# Perform term specific setup e.g. init gains and params.
if term.is_parameterized:
gains, gain_flags, params, param_flags = term.init_term(
term_spec, ref_ant, ms_kwargs, term_kwargs
term_spec, ref_ant, ms_kwargs, term_kwargs, meta=data_xds_meta
)
else:
gains, gain_flags = term.init_term(
term_spec, ref_ant, ms_kwargs, term_kwargs
term_spec, ref_ant, ms_kwargs, term_kwargs, meta=data_xds_meta
)
# Dummy arrays with standard dtypes - aids compilation.
params = np.empty(term_pshape, dtype=np.float64)
Expand Down Expand Up @@ -190,6 +190,7 @@ def solver_wrapper(
for ind, (term, iters) in enumerate(zip(cycle(chain), iter_recipe)):

active_term = chain.index(term)
active_spec = term_spec_list[term_ind]

ms_fields = term.ms_inputs._fields
ms_inputs = term.ms_inputs(
Expand Down Expand Up @@ -219,13 +220,17 @@ def solver_wrapper(
term.solve_per
)

jhj, conv_iter, conv_perc = term.solver(
ms_inputs,
mapping_inputs,
chain_inputs,
meta_inputs,
corr_mode
)
if term.solver:
jhj, conv_iter, conv_perc = term.solver(
ms_inputs,
mapping_inputs,
chain_inputs,
meta_inputs,
corr_mode
)
else:
jhj = np.zeros(getattr(active_spec, "pshape", active_spec.shape))
conv_iter, conv_perc = 0, 1

# If reweighting is enabled, do it when the epoch changes, except
# for the final epoch - we don't reweight if we won't solve again.
Expand Down Expand Up @@ -269,7 +274,7 @@ def solver_wrapper(
corr_mode
)

log_chisq(presolve_chisq, postsolve_chisq, aux_block_info, block_id)
log_chisq(presolve_chisq, postsolve_chisq, data_xds_meta, block_id)

results_dict["presolve_chisq"] = presolve_chisq
results_dict["postsolve_chisq"] = postsolve_chisq
Expand Down
1 change: 1 addition & 0 deletions quartical/config/gain_schema.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ gain:
- crosshand_phase
- crosshand_phase_null_v
- leakage
- parallactic_angle
info:
Type of gain to solve for.

Expand Down
82 changes: 41 additions & 41 deletions quartical/data_handling/angles.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,46 @@
_thread_local = threading.local()


def assign_parangle_data(ms_path, data_xds_list):

anttab = xds_from_storage_table(ms_path + "::ANTENNA")[0]
feedtab = xds_from_storage_table(ms_path + "::FEED")[0]
fieldtab = xds_from_storage_table(ms_path + "::FIELD")[0]

# We do the following eagerly to reduce graph complexity.
feeds = feedtab.POLARIZATION_TYPE.values
unique_feeds = np.unique(feeds)

if np.all([feed in "XxYy" for feed in unique_feeds]):
feed_type = "linear"
elif np.all([feed in "LlRr" for feed in unique_feeds]):
feed_type = "circular"
else:
raise ValueError("Unsupported feed type/configuration.")

phase_dirs = fieldtab.PHASE_DIR.values

updated_data_xds_list = []
for xds in data_xds_list:
xds = xds.assign(
{
"RECEPTOR_ANGLE": (
("ant", "feed"), clone(feedtab.RECEPTOR_ANGLE.data)
),
"POSITION": (
("ant", "xyz"),
clone(anttab.POSITION.data)
)
}
)
xds.attrs["FEED_TYPE"] = feed_type
xds.attrs["FIELD_CENTRE"] = tuple(phase_dirs[xds.FIELD_ID, 0])

updated_data_xds_list.append(xds)

return updated_data_xds_list


def make_parangle_xds_list(ms_path, data_xds_list):
"""Create a list of xarray.Datasets containing the parallactic angles."""

Expand Down Expand Up @@ -266,7 +306,7 @@ def nb_apply_parangle_rot(data_col, parangles, utime_ind, ant1_col, ant2_col,
v1_imul_v2 = factories.v1_imul_v2_factory(corr_mode)
v1_imul_v2ct = factories.v1_imul_v2ct_factory(corr_mode)
valloc = factories.valloc_factory(corr_mode)
rotmat = rotation_factory(corr_mode, feed_type)
rotmat = factories.rotation_factory(corr_mode, feed_type)

def impl(data_col, parangles, utime_ind, ant1_col, ant2_col,
corr_mode, feed_type):
Expand Down Expand Up @@ -299,43 +339,3 @@ def impl(data_col, parangles, utime_ind, ant1_col, ant2_col,
return data_col

return impl


def rotation_factory(corr_mode, feed_type):

if feed_type.literal_value == "circular":
if corr_mode.literal_value == 4:
def impl(rot0, rot1, out):
out[0] = np.exp(-1j*rot0)
out[1] = 0
out[2] = 0
out[3] = np.exp(1j*rot1)
elif corr_mode.literal_value == 2: # TODO: Is this sensible?
def impl(rot0, rot1, out):
out[0] = np.exp(-1j*rot0)
out[1] = np.exp(1j*rot1)
elif corr_mode.literal_value == 1: # TODO: Is this sensible?
def impl(rot0, rot1, out):
out[0] = np.exp(-1j*rot0)
else:
raise ValueError("Unsupported number of correlations.")
elif feed_type.literal_value == "linear":
if corr_mode.literal_value == 4:
def impl(rot0, rot1, out):
out[0] = np.cos(rot0)
out[1] = np.sin(rot0)
out[2] = -np.sin(rot1)
out[3] = np.cos(rot1)
elif corr_mode.literal_value == 2: # TODO: Is this sensible?
def impl(rot0, rot1, out):
out[0] = np.cos(rot0)
out[1] = np.cos(rot1)
elif corr_mode.literal_value == 1: # TODO: Is this sensible?
def impl(rot0, rot1, out):
out[0] = np.cos(rot0)
else:
raise ValueError("Unsupported number of correlations.")
else:
raise ValueError("Unsupported feed type.")

return factories.qcjit(impl)
4 changes: 3 additions & 1 deletion quartical/executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,8 @@
preprocess_xds_list,
postprocess_xds_list)
from quartical.data_handling.model_handler import add_model_graph
from quartical.data_handling.angles import make_parangle_xds_list
from quartical.data_handling.angles import (make_parangle_xds_list,
assign_parangle_data)
from quartical.calibration.calibrate import add_calibration_graph
from quartical.statistics.statistics import make_stats_xds_list
from quartical.statistics.logging import log_summary_stats
Expand Down Expand Up @@ -110,6 +111,7 @@ def _execute(exitstack):

# Preprocess the xds_list - initialise some values and fix bad data.
data_xds_list = preprocess_xds_list(data_xds_list, ms_opts)
data_xds_list = assign_parangle_data(ms_opts.path, data_xds_list)

# Make a list of datasets containing the parallactic angles as these
# can be expensive to compute and may be used several times. NOTE: At
Expand Down
4 changes: 3 additions & 1 deletion quartical/gains/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from quartical.gains.crosshand_phase import CrosshandPhase, CrosshandPhaseNullV
from quartical.gains.leakage import Leakage
from quartical.gains.delay_and_tec import DelayAndTec
from quartical.gains.parallactic_angle import ParallacticAngle


TERM_TYPES = {
Expand All @@ -24,5 +25,6 @@
"crosshand_phase": CrosshandPhase,
"crosshand_phase_null_v": CrosshandPhaseNullV,
"leakage": Leakage,
"delay_and_tec": DelayAndTec
"delay_and_tec": DelayAndTec,
"parallactic_angle": ParallacticAngle
}
2 changes: 1 addition & 1 deletion quartical/gains/amplitude/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ def make_param_names(cls, correlations):

return [f"amplitude_{c}" for c in param_corr]

def init_term(self, term_spec, ref_ant, ms_kwargs, term_kwargs):
def init_term(self, term_spec, ref_ant, ms_kwargs, term_kwargs, meta=None):
"""Initialise the gains (and parameters)."""

gains, gain_flags, params, param_flags = super().init_term(
Expand Down
2 changes: 1 addition & 1 deletion quartical/gains/crosshand_phase/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ def make_param_names(cls, correlations):

return [f"crosshand_phase_{c}" for c in param_corr]

def init_term(self, term_spec, ref_ant, ms_kwargs, term_kwargs):
def init_term(self, term_spec, ref_ant, ms_kwargs, term_kwargs, meta=None):
"""Initialise the gains (and parameters)."""

gains, gain_flags, params, param_flags = super().init_term(
Expand Down
2 changes: 1 addition & 1 deletion quartical/gains/delay/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ def make_param_names(cls, correlations):

return [n.format(c) for c in param_corr for n in template]

def init_term(self, term_spec, ref_ant, ms_kwargs, term_kwargs):
def init_term(self, term_spec, ref_ant, ms_kwargs, term_kwargs, meta=None):
"""Initialise the gains (and parameters)."""

gains, gain_flags, params, param_flags = super().init_term(
Expand Down
2 changes: 1 addition & 1 deletion quartical/gains/delay_and_offset/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ def make_param_names(cls, correlations):

return [n.format(c) for c in param_corr for n in template]

def init_term(self, term_spec, ref_ant, ms_kwargs, term_kwargs):
def init_term(self, term_spec, ref_ant, ms_kwargs, term_kwargs, meta=None):
"""Initialise the gains (and parameters)."""

gains, gain_flags, params, param_flags = super().init_term(
Expand Down
2 changes: 1 addition & 1 deletion quartical/gains/delay_and_tec/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ def make_param_names(cls, correlations):

return [n.format(c) for c in param_corr for n in template]

def init_term(self, term_spec, ref_ant, ms_kwargs, term_kwargs):
def init_term(self, term_spec, ref_ant, ms_kwargs, term_kwargs, meta=None):
"""Initialise the gains (and parameters)."""

gains, gain_flags, params, param_flags = super().init_term(
Expand Down
2 changes: 1 addition & 1 deletion quartical/gains/gain.py
Original file line number Diff line number Diff line change
Expand Up @@ -344,7 +344,7 @@ def _make_dir_map(cls, n_dir, direction_dependent):

return dir_map

def init_term(self, term_spec, ref_ant, ms_kwargs, term_kwargs):
def init_term(self, term_spec, ref_ant, ms_kwargs, term_kwargs, meta=None):
"""Initialise the gains (and parameters)."""

(_, _, gain_shape, _) = term_spec
Expand Down
40 changes: 40 additions & 0 deletions quartical/gains/general/factories.py
Original file line number Diff line number Diff line change
Expand Up @@ -798,3 +798,43 @@ def impl(a, b, out):
out[3, 3] = a11 * b11

return qcjit(impl)


def rotation_factory(corr_mode, feed_type):

if feed_type.literal_value == "circular":
if corr_mode.literal_value == 4:
def impl(rot0, rot1, out):
out[0] = np.exp(-1j*rot0)
out[1] = 0
out[2] = 0
out[3] = np.exp(1j*rot1)
elif corr_mode.literal_value == 2: # TODO: Is this sensible?
def impl(rot0, rot1, out):
out[0] = np.exp(-1j*rot0)
out[1] = np.exp(1j*rot1)
elif corr_mode.literal_value == 1: # TODO: Is this sensible?
def impl(rot0, rot1, out):
out[0] = np.exp(-1j*rot0)
else:
raise ValueError("Unsupported number of correlations.")
elif feed_type.literal_value == "linear":
if corr_mode.literal_value == 4:
def impl(rot0, rot1, out):
out[0] = np.cos(rot0)
out[1] = np.sin(rot0)
out[2] = -np.sin(rot1)
out[3] = np.cos(rot1)
elif corr_mode.literal_value == 2: # TODO: Is this sensible?
def impl(rot0, rot1, out):
out[0] = np.cos(rot0)
out[1] = np.cos(rot1)
elif corr_mode.literal_value == 1: # TODO: Is this sensible?
def impl(rot0, rot1, out):
out[0] = np.cos(rot0)
else:
raise ValueError("Unsupported number of correlations.")
else:
raise ValueError("Unsupported feed type.")

return qcjit(impl)
Loading

0 comments on commit ce8e375

Please sign in to comment.