Merge v0.2.4 dev branch for release. (#349)
* Add experimental combined delay and tec solver. (#339)

* Commit initial attempt at delay and tec solver.

* Add poor test for delay and tec.

* Add missing coerce_literal calls. (#341)

* Make schema consistent with that required by cult-cargo. (#346)

* Update dependencies (#345)

* Partially update dependencies/lock while waiting for dask-ms release.

* Remove 3.9 from test matrix.

* Bump dependency versions.

* Fix bad python version specification.

* Fix for reindex change. Update Poetry lock.

* Remove 3.12 support for now due to difficulties in tigger-lsm.

* Update lock.

* Add support for Python3.12 (#348)

* Add python3.12 support.

* Update lock.

* Further lock update.

* Separate versioning for astro-tigger-lsm when using Python 3.12.

* EXPERIMENTAL: Add a crosshand phase solver which exploits the zero Stokes V assumption. (#344)

* Initial commit of null V WIP.

* Seemingly working implementation of solver exploiting the null-v trick.

* Add new term type to allowed gain types.

* Add a parallactic angle gain type.

* Add test for parallactic angle term. Fix minor bugs.

* Remove unused imports.

* Improve removal of coords/attrs added by QuartiCal during writes.

* Remove question marks when plotting (#347)

* Remove question marks when plotting (treated as a wildcard by ls in radio-padre).

* Replace question marks with N/A.

---------

Co-authored-by: Jonathan Kenyon <[email protected]>

* Mark test as xfail due to sign ambiguity.

---------

Co-authored-by: Landman Bester <[email protected]>

---------

Co-authored-by: Landman Bester <[email protected]>
JSKenyon and landmanbester authored Nov 29, 2024
1 parent 6c1cd95 commit a971cc6
Showing 35 changed files with 4,588 additions and 1,694 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/ci.yaml
@@ -21,7 +21,7 @@ jobs:
strategy:
matrix:
os: [ubuntu-20.04, ubuntu-22.04]
python-version: ["3.9", "3.10", "3.11"]
python-version: ["3.10", "3.11", "3.12"]

steps:
- name: Set up Python ${{ matrix.python-version }}
3,535 changes: 1,936 additions & 1,599 deletions poetry.lock

Large diffs are not rendered by default.

25 changes: 14 additions & 11 deletions pyproject.toml
@@ -24,21 +24,24 @@ include = [
]

[tool.poetry.dependencies]
python = ">=3.9, <3.12"
astro-tigger-lsm = ">=1.7.2, <=1.7.3"
codex-africanus = {extras = ["dask", "scipy", "astropy", "python-casacore"], version = ">=0.3.6, <=0.3.6"}
python = ">=3.10, <3.13"
astro-tigger-lsm = [
    { version = ">=1.7.2, <=1.7.3", python = "<3.12" },
    { version = ">=1.7.4, <=1.7.4", python = ">=3.12"}
]
codex-africanus = {extras = ["dask", "scipy", "astropy", "python-casacore"], version = ">=0.4.1, <=0.4.1"}
colorama = ">=0.4.6, <=0.4.6"
columnar = ">=1.4.1, <=1.4.1"
dask = {extras = ["diagnostics"], version = ">=2023.5.0, <=2024.4.2"}
dask-ms = {extras = ["s3", "xarray", "zarr"], version = ">=0.2.20, <=0.2.20"}
distributed = ">=2023.5.0, <=2024.4.2"
dask = {extras = ["diagnostics"], version = ">=2023.5.0, <=2024.10.0"}
dask-ms = {extras = ["s3", "xarray", "zarr"], version = ">=0.2.23, <=0.2.23"}
distributed = ">=2023.5.0, <=2024.10.0"
loguru = ">=0.7.0, <=0.7.2"
matplotlib = ">=3.5.1, <=3.8.2"
matplotlib = ">=3.5.1, <=3.9.2"
omegaconf = ">=2.3.0, <=2.3.0"
pytest = ">=7.3.1, <=7.4.4"
requests = ">=2.31.0, <=2.31.0"
"ruamel.yaml" = ">=0.17.26, <=0.17.40"
stimela = "^2.0rc17" # Volatile - be less strict.
pytest = ">=7.3.1, <=8.3.3"
requests = ">=2.31.0, <=2.32.3"
"ruamel.yaml" = ">=0.17.26, <=0.18.6"
stimela = ">=2.0"
tbump = ">=6.10.0, <=6.11.0"

[tool.poetry.scripts]
3 changes: 3 additions & 0 deletions quartical/apps/plotter.py
@@ -193,6 +193,9 @@ def to_plot_dict(xdsl, iter_attrs):


def _plot(group, xds, args):
    # get rid of question marks
    qstrip = lambda x: x.replace('?', 'N/A')
    group = tuple(map(qstrip, group))

    xds = xds.compute(scheduler="single-threaded")

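The hunk above replaces "?" placeholders in the group labels before they are used to name plots, since ls in radio-padre treats "?" as a wildcard. A minimal sketch of the substitution, using a hypothetical group tuple rather than real plotter output:

```python
# Hypothetical group tuple - the real values are produced by to_plot_dict.
group = ("FIELD_0", "?", "SPW_0")

qstrip = lambda x: x.replace('?', 'N/A')  # as in the hunk above
group = tuple(map(qstrip, group))

print(group)  # ('FIELD_0', 'N/A', 'SPW_0')
```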
10 changes: 5 additions & 5 deletions quartical/calibration/constructor.py
@@ -5,7 +5,7 @@


term_spec_tup = namedtuple("term_spec_tup", "name type shape pshape")
aux_info_fields = ("SCAN_NUMBER", "FIELD_ID", "DATA_DESC_ID")
log_info_fields = ("SCAN_NUMBER", "FIELD_ID", "DATA_DESC_ID")


def construct_solver(
@@ -54,9 +54,9 @@ def construct_solver(
corr_mode = data_xds.sizes["corr"]

block_id_arr = get_block_id_arr(data_col)
aux_block_info = {
k: data_xds.attrs.get(k, "?") for k in aux_info_fields
}
data_xds_meta = data_xds.attrs.copy()
for k in log_info_fields:
data_xds_meta[k] = data_xds_meta.get(k, "?")

# Grab the number of input chunks - doing this on the data should be
# safe.
@@ -87,7 +87,7 @@ def construct_solver(
)
blocker.add_input("term_spec_list", spec_list, ("row", "chan"))
blocker.add_input("corr_mode", corr_mode)
blocker.add_input("aux_block_info", aux_block_info)
blocker.add_input("data_xds_meta", data_xds_meta)
blocker.add_input("solver_opts", solver_opts)
blocker.add_input("chain", chain)

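The change above swaps the small aux_block_info dict for a full copy of the dataset attributes, back-filling the fields used for logging with "?" when they are absent. A sketch of that pattern on a hypothetical attrs dict (the real one lives on data_xds):

```python
log_info_fields = ("SCAN_NUMBER", "FIELD_ID", "DATA_DESC_ID")

# Hypothetical stand-in for data_xds.attrs - DATA_DESC_ID is deliberately missing.
attrs = {"SCAN_NUMBER": 3, "FIELD_ID": 0, "FIELD_NAME": "3C147"}

data_xds_meta = attrs.copy()
for k in log_info_fields:
    data_xds_meta[k] = data_xds_meta.get(k, "?")

print(data_xds_meta["DATA_DESC_ID"])  # '?'
```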
27 changes: 16 additions & 11 deletions quartical/calibration/solver.py
@@ -54,7 +54,7 @@ def solver_wrapper(
solver_opts,
chain,
block_id_arr,
aux_block_info,
data_xds_meta,
corr_mode,
**kwargs
):
@@ -108,11 +108,11 @@
# Perform term specific setup e.g. init gains and params.
if term.is_parameterized:
gains, gain_flags, params, param_flags = term.init_term(
term_spec, ref_ant, ms_kwargs, term_kwargs
term_spec, ref_ant, ms_kwargs, term_kwargs, meta=data_xds_meta
)
else:
gains, gain_flags = term.init_term(
term_spec, ref_ant, ms_kwargs, term_kwargs
term_spec, ref_ant, ms_kwargs, term_kwargs, meta=data_xds_meta
)
# Dummy arrays with standard dtypes - aids compilation.
params = np.empty(term_pshape, dtype=np.float64)
@@ -190,6 +190,7 @@ def solver_wrapper(
for ind, (term, iters) in enumerate(zip(cycle(chain), iter_recipe)):

active_term = chain.index(term)
active_spec = term_spec_list[term_ind]

ms_fields = term.ms_inputs._fields
ms_inputs = term.ms_inputs(
@@ -219,13 +220,17 @@
term.solve_per
)

jhj, conv_iter, conv_perc = term.solver(
ms_inputs,
mapping_inputs,
chain_inputs,
meta_inputs,
corr_mode
)
if term.solver:
jhj, conv_iter, conv_perc = term.solver(
ms_inputs,
mapping_inputs,
chain_inputs,
meta_inputs,
corr_mode
)
else:
jhj = np.zeros(getattr(active_spec, "pshape", active_spec.shape))
conv_iter, conv_perc = 0, 1

# If reweighting is enabled, do it when the epoch changes, except
# for the final epoch - we don't reweight if we won't solve again.
@@ -269,7 +274,7 @@
corr_mode
)

log_chisq(presolve_chisq, postsolve_chisq, aux_block_info, block_id)
log_chisq(presolve_chisq, postsolve_chisq, data_xds_meta, block_id)

results_dict["presolve_chisq"] = presolve_chisq
results_dict["postsolve_chisq"] = postsolve_chisq
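The new branch above lets a term without a solver pass through the chain without solving; the parallactic angle term added in this merge is presumably such a case. A sketch of the fallback, with an assumed term spec and purely illustrative shapes:

```python
from collections import namedtuple

import numpy as np

# Hypothetical term spec - the shapes are illustrative, not taken from the diff.
term_spec_tup = namedtuple("term_spec_tup", "name type shape pshape")
active_spec = term_spec_tup(
    name="P",
    type="parallactic_angle",
    shape=(1, 1, 3, 1, 4),
    pshape=(1, 1, 3, 1, 1),
)

solver = None  # stands in for term.solver on a solver-less term

if solver is None:
    # No solve is performed: contribute a zero J^H.J of the parameter shape
    # (falling back to the gain shape) and report trivial convergence.
    jhj = np.zeros(getattr(active_spec, "pshape", active_spec.shape))
    conv_iter, conv_perc = 0, 1

print(jhj.shape, conv_iter, conv_perc)  # (1, 1, 3, 1, 1) 0 1
```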
12 changes: 9 additions & 3 deletions quartical/config/argument_schema.yaml
@@ -1,7 +1,8 @@
input_ms:
path:
required: true
dtype: str
dtype: URI
writable: true
info:
Path to input measurement set.

@@ -164,15 +165,20 @@ input_model:
output:
gain_directory:
default: gains.qc
dtype: str
dtype: URI
writable: true
must_exist: false
write_parent_dir: true
info:
Name of directory in which QuartiCal gain outputs will be stored.
Accepts both local and s3 paths. QuartiCal will always produce gain
outputs.

log_directory:
default: logs.qc
dtype: str
dtype: Directory
writable: true
must_exist: false
info:
Name of directory in which QuartiCal logging outputs will be stored.
s3 is not currently supported for these outputs.
3 changes: 3 additions & 0 deletions quartical/config/gain_schema.yaml
@@ -8,12 +8,15 @@ gain:
- amplitude
- delay
- delay_and_offset
- delay_and_tec
- phase
- tec_and_offset
- rotation_measure
- rotation
- crosshand_phase
- crosshand_phase_null_v
- leakage
- parallactic_angle
info:
Type of gain to solve for.

82 changes: 41 additions & 41 deletions quartical/data_handling/angles.py
@@ -19,6 +19,46 @@
_thread_local = threading.local()


def assign_parangle_data(ms_path, data_xds_list):

    anttab = xds_from_storage_table(ms_path + "::ANTENNA")[0]
    feedtab = xds_from_storage_table(ms_path + "::FEED")[0]
    fieldtab = xds_from_storage_table(ms_path + "::FIELD")[0]

    # We do the following eagerly to reduce graph complexity.
    feeds = feedtab.POLARIZATION_TYPE.values
    unique_feeds = np.unique(feeds)

    if np.all([feed in "XxYy" for feed in unique_feeds]):
        feed_type = "linear"
    elif np.all([feed in "LlRr" for feed in unique_feeds]):
        feed_type = "circular"
    else:
        raise ValueError("Unsupported feed type/configuration.")

    phase_dirs = fieldtab.PHASE_DIR.values

    updated_data_xds_list = []
    for xds in data_xds_list:
        xds = xds.assign(
            {
                "RECEPTOR_ANGLE": (
                    ("ant", "feed"), clone(feedtab.RECEPTOR_ANGLE.data)
                ),
                "POSITION": (
                    ("ant", "xyz"),
                    clone(anttab.POSITION.data)
                )
            }
        )
        xds.attrs["FEED_TYPE"] = feed_type
        xds.attrs["FIELD_CENTRE"] = tuple(phase_dirs[xds.FIELD_ID, 0])

        updated_data_xds_list.append(xds)

    return updated_data_xds_list


def make_parangle_xds_list(ms_path, data_xds_list):
"""Create a list of xarray.Datasets containing the parallactic angles."""

@@ -266,7 +306,7 @@ def nb_apply_parangle_rot(data_col, parangles, utime_ind, ant1_col, ant2_col,
v1_imul_v2 = factories.v1_imul_v2_factory(corr_mode)
v1_imul_v2ct = factories.v1_imul_v2ct_factory(corr_mode)
valloc = factories.valloc_factory(corr_mode)
rotmat = rotation_factory(corr_mode, feed_type)
rotmat = factories.rotation_factory(corr_mode, feed_type)

def impl(data_col, parangles, utime_ind, ant1_col, ant2_col,
corr_mode, feed_type):
@@ -299,43 +339,3 @@ def impl(data_col, parangles, utime_ind, ant1_col, ant2_col,
return data_col

return impl


def rotation_factory(corr_mode, feed_type):

    if feed_type.literal_value == "circular":
        if corr_mode.literal_value == 4:
            def impl(rot0, rot1, out):
                out[0] = np.exp(-1j*rot0)
                out[1] = 0
                out[2] = 0
                out[3] = np.exp(1j*rot1)
        elif corr_mode.literal_value == 2:  # TODO: Is this sensible?
            def impl(rot0, rot1, out):
                out[0] = np.exp(-1j*rot0)
                out[1] = np.exp(1j*rot1)
        elif corr_mode.literal_value == 1:  # TODO: Is this sensible?
            def impl(rot0, rot1, out):
                out[0] = np.exp(-1j*rot0)
        else:
            raise ValueError("Unsupported number of correlations.")
    elif feed_type.literal_value == "linear":
        if corr_mode.literal_value == 4:
            def impl(rot0, rot1, out):
                out[0] = np.cos(rot0)
                out[1] = np.sin(rot0)
                out[2] = -np.sin(rot1)
                out[3] = np.cos(rot1)
        elif corr_mode.literal_value == 2:  # TODO: Is this sensible?
            def impl(rot0, rot1, out):
                out[0] = np.cos(rot0)
                out[1] = np.cos(rot1)
        elif corr_mode.literal_value == 1:  # TODO: Is this sensible?
            def impl(rot0, rot1, out):
                out[0] = np.cos(rot0)
        else:
            raise ValueError("Unsupported number of correlations.")
    else:
        raise ValueError("Unsupported feed type.")

    return factories.qcjit(impl)
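The new assign_parangle_data function above eagerly classifies the array as linear- or circular-feed before attaching receptor angles, antenna positions and the field centre to each dataset. A sketch of the feed-type detection on a hypothetical POLARIZATION_TYPE array (real values come from the FEED subtable):

```python
import numpy as np

# Hypothetical FEED::POLARIZATION_TYPE values for a three-antenna array.
feeds = np.array([["X", "Y"], ["X", "Y"], ["X", "Y"]])

unique_feeds = np.unique(feeds)

if np.all([feed in "XxYy" for feed in unique_feeds]):
    feed_type = "linear"
elif np.all([feed in "LlRr" for feed in unique_feeds]):
    feed_type = "circular"
else:
    raise ValueError("Unsupported feed type/configuration.")

print(feed_type)  # linear
```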
26 changes: 19 additions & 7 deletions quartical/data_handling/ms_handler.py
@@ -15,6 +15,13 @@
from quartical.data_handling.selection import filter_xds_list
from quartical.data_handling.angles import apply_parangles

DASKMS_ATTRS = {
    "__daskms_partition_schema__",
    "SCAN_NUMBER",
    "FIELD_ID",
    "DATA_DESC_ID"
}


def read_xds_list(model_columns, ms_opts):
"""Reads a measurement set and generates a list of xarray data sets.
@@ -237,7 +244,8 @@ def write_xds_list(xds_list, ref_xds_list, ms_path, output_opts):

# If the xds has fewer correlations than the measurement set, reindex.
if xds.sizes["corr"] < ms_n_corr:
xds = xds.reindex(corr=corr_types, fill_value=0)
# Note that we have to remove chunks from the reindexed axis.
xds = xds.reindex(corr=corr_types, fill_value=0).chunk({"corr": -1})

# Do some special handling on the flag column if we reindexed -
# we need a value dependent fill value.
@@ -292,14 +300,18 @@ def write_xds_list(xds_list, ref_xds_list, ms_path, output_opts):

logger.info("Outputs will be written to {}.", ", ".join(output_cols))

# Select only the output columns to simplify datasets.
xds_list = [xds[list(output_cols)] for xds in xds_list]

# Remove all coords bar ROWID so that they do not get written.
xds_list = [
xds.drop_vars(set(xds.coords.keys()) - {"ROWID"}, errors='ignore')
for xds in xds_list
]

# Remove attrs added by QuartiCal so that they do not get written.
for xds in xds_list:
xds.attrs.pop("UTIME_CHUNKS", None)
xds.attrs.pop("FIELD_NAME", None)

# Remove coords added by QuartiCal so that they do not get written.
xds_list = [xds.drop_vars(["chan", "corr"], errors='ignore')
for xds in xds_list]
xds.attrs = {k: v for k, v in xds.attrs.items() if k in DASKMS_ATTRS}

with warnings.catch_warnings(): # We anticipate spurious warnings.
warnings.simplefilter("ignore")
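The DASKMS_ATTRS set introduced above is used to strip everything QuartiCal added to xds.attrs before writing, so that only the dask-ms partitioning attributes reach disk. A sketch of the filter on a hypothetical attrs dict:

```python
DASKMS_ATTRS = {
    "__daskms_partition_schema__",
    "SCAN_NUMBER",
    "FIELD_ID",
    "DATA_DESC_ID",
}

# Hypothetical attrs - FIELD_NAME and UTIME_CHUNKS stand in for values that
# QuartiCal attaches during processing and must not be written back.
attrs = {
    "SCAN_NUMBER": 1,
    "FIELD_ID": 0,
    "DATA_DESC_ID": 0,
    "FIELD_NAME": "PKS1934-638",
    "UTIME_CHUNKS": (10, 10, 10),
}

attrs = {k: v for k, v in attrs.items() if k in DASKMS_ATTRS}
print(sorted(attrs))  # ['DATA_DESC_ID', 'FIELD_ID', 'SCAN_NUMBER']
```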
4 changes: 3 additions & 1 deletion quartical/executor.py
@@ -13,7 +13,8 @@
preprocess_xds_list,
postprocess_xds_list)
from quartical.data_handling.model_handler import add_model_graph
from quartical.data_handling.angles import make_parangle_xds_list
from quartical.data_handling.angles import (make_parangle_xds_list,
assign_parangle_data)
from quartical.calibration.calibrate import add_calibration_graph
from quartical.statistics.statistics import make_stats_xds_list
from quartical.statistics.logging import log_summary_stats
@@ -110,6 +111,7 @@ def _execute(exitstack):

# Preprocess the xds_list - initialise some values and fix bad data.
data_xds_list = preprocess_xds_list(data_xds_list, ms_opts)
data_xds_list = assign_parangle_data(ms_opts.path, data_xds_list)

# Make a list of datasets containing the parallactic angles as these
# can be expensive to compute and may be used several times. NOTE: At
9 changes: 7 additions & 2 deletions quartical/gains/__init__.py
@@ -6,8 +6,10 @@
from quartical.gains.tec_and_offset import TecAndOffset
from quartical.gains.rotation import Rotation
from quartical.gains.rotation_measure import RotationMeasure
from quartical.gains.crosshand_phase import CrosshandPhase
from quartical.gains.crosshand_phase import CrosshandPhase, CrosshandPhaseNullV
from quartical.gains.leakage import Leakage
from quartical.gains.delay_and_tec import DelayAndTec
from quartical.gains.parallactic_angle import ParallacticAngle


TERM_TYPES = {
@@ -21,5 +23,8 @@
"rotation": Rotation,
"rotation_measure": RotationMeasure,
"crosshand_phase": CrosshandPhase,
"leakage": Leakage
"crosshand_phase_null_v": CrosshandPhaseNullV,
"leakage": Leakage,
"delay_and_tec": DelayAndTec,
"parallactic_angle": ParallacticAngle
}
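The registry above is what ties the new schema strings (delay_and_tec, crosshand_phase_null_v, parallactic_angle) to their implementing classes. A small lookup sketch, assuming QuartiCal is installed:

```python
from quartical.gains import TERM_TYPES

# The key is the value a user supplies for a gain term's `type` option.
term_cls = TERM_TYPES["delay_and_tec"]
print(term_cls.__name__)  # DelayAndTec
```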
2 changes: 1 addition & 1 deletion quartical/gains/amplitude/__init__.py
@@ -38,7 +38,7 @@ def make_param_names(cls, correlations):

return [f"amplitude_{c}" for c in param_corr]

def init_term(self, term_spec, ref_ant, ms_kwargs, term_kwargs):
def init_term(self, term_spec, ref_ant, ms_kwargs, term_kwargs, meta=None):
"""Initialise the gains (and parameters)."""

gains, gain_flags, params, param_flags = super().init_term(