Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add a scalar mode to diagonal parameterised Jones terms #358

Merged
merged 5 commits into from
Jan 16, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions quartical/calibration/solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
"threads",
"robust",
"reference_antenna",
"scalar",
"dd_term",
"pinned_directions",
"solve_per",
Expand Down Expand Up @@ -233,6 +234,7 @@ def solver_wrapper(
solver_opts.threads,
solver_opts.robust,
solver_opts.reference_antenna,
term.scalar,
term.direction_dependent,
term.pinned_directions,
term.solve_per
Expand Down
10 changes: 9 additions & 1 deletion quartical/config/gain_schema.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -29,9 +29,17 @@ gain:
Determines whether this term should be solved per antenna (conventional)
or over the entire array (doesn't vary with antenna).

scalar:
dtype: bool
default: false
info:
Determines whether the term is treated as scalar i.e. whether it is
solved for as a single effect over all correlations. This is only
supported for terms which would otherwise be diagonal.

direction_dependent:
dtype: bool
default: False
default: false
info:
Determines whether this term is treated as direction dependent.

Expand Down
7 changes: 6 additions & 1 deletion quartical/gains/amplitude/kernel.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,8 @@
upsampled_itermediaries,
per_array_jhj_jhr,
resample_solints,
downsample_jhj_jhr)
downsample_jhj_jhr,
scalar_jhj_jhr)
from quartical.gains.general.flagging import (flag_intermediaries,
update_gain_flags,
finalize_gain_flags,
Expand Down Expand Up @@ -86,6 +87,7 @@ def impl(
active_term = meta_inputs.active_term
max_iter = meta_inputs.iters
solve_per = meta_inputs.solve_per
scalar = meta_inputs.scalar
dd_term = meta_inputs.dd_term
n_thread = meta_inputs.threads

Expand Down Expand Up @@ -144,6 +146,9 @@ def impl(
if solve_per == "array":
per_array_jhj_jhr(native_imdry)

if scalar and corr_mode != 1:
scalar_jhj_jhr(native_imdry, 1)

if not max_iter: # Non-solvable term, we just want jhj.
conv_perc = 0 # Didn't converge.
loop_idx = -1 # Did zero iterations.
Expand Down
35 changes: 35 additions & 0 deletions quartical/gains/complex/diag_kernel.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@ def impl(
active_term = meta_inputs.active_term
max_iter = meta_inputs.iters
solve_per = meta_inputs.solve_per
scalar = meta_inputs.scalar
dd_term = meta_inputs.dd_term
n_thread = meta_inputs.threads

Expand Down Expand Up @@ -129,6 +130,9 @@ def impl(
if solve_per == "array":
per_array_jhj_jhr(native_imdry)

if scalar and corr_mode != 1:
scalar_jhj_jhr(native_imdry)

if not max_iter: # Non-solvable term, we just want jhj.
conv_perc = 0 # Didn't converge.
loop_idx = -1 # Did zero iterations.
Expand Down Expand Up @@ -677,3 +681,34 @@ def impl(chain_inputs, meta_inputs, mode):
apply_gain_flags_to_gains(gain_flags, gains)

return impl


@njit(**JIT_OPTIONS)
def scalar_jhj_jhr(solver_imdry):
"""This manipulates the entries of jhj and jhr to be scalar."""

# NOTE: This differes from the generic implmenentation in generics.py.

jhj = solver_imdry.jhj
jhr = solver_imdry.jhr

n_tint, n_fint, n_ant, n_dir, n_corr = jhj.shape

for t in range(n_tint):
for f in range(n_fint):
for a in range(n_ant):
for d in range(n_dir):

jhr_sel = jhr[t, f, a, d]
jhj_sel = jhj[t, f, a, d]

# Sum to a single scalar element.
for p in range(1, n_corr):
jhr_sel[0] += jhr_sel[p]
jhr_sel[p] = 0
jhj_sel[0] += jhj_sel[p]
jhj_sel[p] = 0

# Repopluate appropriate zeroed values from scalar sum.
jhr_sel[-1] = jhr_sel[0]
jhj_sel[-1] = jhj_sel[0]
4 changes: 4 additions & 0 deletions quartical/gains/complex/kernel.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@ def impl(
active_term = meta_inputs.active_term
max_iter = meta_inputs.iters
solve_per = meta_inputs.solve_per
scalar = meta_inputs.scalar
dd_term = meta_inputs.dd_term
n_thread = meta_inputs.threads

Expand Down Expand Up @@ -132,6 +133,9 @@ def impl(
if solve_per == "array":
per_array_jhj_jhr(native_imdry)

if scalar:
raise ValueError("Scalar mode not supported for complex terms.")

if not max_iter: # Non-solvable term, we just want jhj.
conv_perc = 0 # Didn't converge.
loop_idx = -1 # Did zero iterations.
Expand Down
6 changes: 6 additions & 0 deletions quartical/gains/crosshand_phase/kernel.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,7 @@ def impl(
active_term = meta_inputs.active_term
max_iter = meta_inputs.iters
solve_per = meta_inputs.solve_per
scalar = meta_inputs.scalar
dd_term = meta_inputs.dd_term
n_thread = meta_inputs.threads

Expand Down Expand Up @@ -141,6 +142,11 @@ def impl(
if solve_per == "array":
per_array_jhj_jhr(native_imdry)

if scalar:
raise ValueError(
"Scalar mode not supported for crosshand phase terms."
)

if not max_iter: # Non-solvable term, we just want jhj.
conv_perc = 0 # Didn't converge.
loop_idx = -1 # Did zero iterations.
Expand Down
6 changes: 6 additions & 0 deletions quartical/gains/crosshand_phase/null_v_kernel.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,7 @@ def impl(
active_term = meta_inputs.active_term
max_iter = meta_inputs.iters
solve_per = meta_inputs.solve_per
scalar = meta_inputs.scalar
dd_term = meta_inputs.dd_term
n_thread = meta_inputs.threads

Expand Down Expand Up @@ -149,6 +150,11 @@ def impl(
if solve_per == "array":
per_array_jhj_jhr(native_imdry)

if scalar:
raise ValueError(
"Scalar mode not supported for crosshand phase terms."
)

if not max_iter: # Non-solvable term, we just want jhj.
conv_perc = 0 # Didn't converge.
loop_idx = -1 # Did zero iterations.
Expand Down
7 changes: 6 additions & 1 deletion quartical/gains/delay/kernel.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,8 @@
upsampled_itermediaries,
per_array_jhj_jhr,
resample_solints,
downsample_jhj_jhr
downsample_jhj_jhr,
scalar_jhj_jhr
)
from quartical.gains.general.flagging import (
flag_intermediaries,
Expand Down Expand Up @@ -94,6 +95,7 @@ def impl(
active_term = meta_inputs.active_term
max_iter = meta_inputs.iters
solve_per = meta_inputs.solve_per
scalar = meta_inputs.scalar
dd_term = meta_inputs.dd_term
n_thread = meta_inputs.threads

Expand Down Expand Up @@ -157,6 +159,9 @@ def impl(
if solve_per == "array":
per_array_jhj_jhr(native_imdry)

if scalar and corr_mode != 1:
scalar_jhj_jhr(native_imdry, 1)

if not max_iter: # Non-solvable term, we just want jhj.
conv_perc = 0 # Didn't converge.
loop_idx = -1 # Did zero iterations.
Expand Down
50 changes: 31 additions & 19 deletions quartical/gains/delay_and_offset/kernel.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,26 +2,34 @@
import numpy as np
from numba import prange, njit
from numba.extending import overload
from quartical.utils.numba import (coerce_literal,
JIT_OPTIONS,
PARALLEL_JIT_OPTIONS)
from quartical.gains.general.generics import (native_intermediaries,
upsampled_itermediaries,
per_array_jhj_jhr,
resample_solints,
downsample_jhj_jhr)
from quartical.gains.general.flagging import (flag_intermediaries,
update_gain_flags,
finalize_gain_flags,
apply_gain_flags_to_flag_col,
update_param_flags,
apply_gain_flags_to_gains,
apply_param_flags_to_params)
from quartical.gains.general.convenience import (get_row,
get_extents)
from quartical.utils.numba import (
coerce_literal,
JIT_OPTIONS,
PARALLEL_JIT_OPTIONS
)
from quartical.gains.general.generics import (
native_intermediaries,
upsampled_itermediaries,
per_array_jhj_jhr,
resample_solints,
downsample_jhj_jhr,
scalar_jhj_jhr
)
from quartical.gains.general.flagging import (
flag_intermediaries,
update_gain_flags,
finalize_gain_flags,
apply_gain_flags_to_flag_col,
update_param_flags,
apply_gain_flags_to_gains,
apply_param_flags_to_params
)
from quartical.gains.general.convenience import get_row, get_extents
import quartical.gains.general.factories as factories
from quartical.gains.general.inversion import (invert_factory,
inversion_buffer_factory)
from quartical.gains.general.inversion import (
invert_factory,
inversion_buffer_factory
)


def get_identity_params(corr_mode):
Expand Down Expand Up @@ -88,6 +96,7 @@ def impl(
active_term = meta_inputs.active_term
max_iter = meta_inputs.iters
solve_per = meta_inputs.solve_per
scalar = meta_inputs.scalar
dd_term = meta_inputs.dd_term
n_thread = meta_inputs.threads

Expand Down Expand Up @@ -151,6 +160,9 @@ def impl(
if solve_per == "array":
per_array_jhj_jhr(native_imdry)

if scalar and corr_mode != 1:
scalar_jhj_jhr(native_imdry, 2)

if not max_iter: # Non-solvable term, we just want jhj.
conv_perc = 0 # Didn't converge.
loop_idx = -1 # Did zero iterations.
Expand Down
7 changes: 6 additions & 1 deletion quartical/gains/delay_and_tec/kernel.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,8 @@
upsampled_itermediaries,
per_array_jhj_jhr,
resample_solints,
downsample_jhj_jhr)
downsample_jhj_jhr,
scalar_jhj_jhr)
from quartical.gains.general.flagging import (flag_intermediaries,
update_gain_flags,
finalize_gain_flags,
Expand Down Expand Up @@ -88,6 +89,7 @@ def impl(
active_term = meta_inputs.active_term
max_iter = meta_inputs.iters
solve_per = meta_inputs.solve_per
scalar = meta_inputs.scalar
dd_term = meta_inputs.dd_term
n_thread = meta_inputs.threads

Expand Down Expand Up @@ -155,6 +157,9 @@ def impl(
if solve_per == "array":
per_array_jhj_jhr(native_imdry)

if scalar and corr_mode != 1:
scalar_jhj_jhr(native_imdry, 2)

if not max_iter: # Non-solvable term, we just want jhj.
conv_perc = 0 # Didn't converge.
loop_idx = -1 # Did zero iterations.
Expand Down
1 change: 1 addition & 0 deletions quartical/gains/gain.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,7 @@ def __init__(self, term_name, term_opts):
self.name = term_name
self.type = term_opts.type
self.solve_per = term_opts.solve_per
self.scalar = term_opts.scalar
self.direction_dependent = term_opts.direction_dependent
self.pinned_directions = term_opts.pinned_directions
self.time_interval = term_opts.time_interval
Expand Down
40 changes: 40 additions & 0 deletions quartical/gains/general/generics.py
Original file line number Diff line number Diff line change
Expand Up @@ -705,6 +705,46 @@ def per_array_jhj_jhr(solver_imdry):
jhr[t, f, a] = jhr[t, f, 0]


@njit(**JIT_OPTIONS)
def scalar_jhj_jhr(solver_imdry, values_per_correlation):
"""This manipulates the entries of jhj and jhr to be scalar."""

jhj = solver_imdry.jhj
jhr = solver_imdry.jhr

vpc = values_per_correlation # For brevity.

n_tint, n_fint, n_ant, n_dir, n_par, _ = jhj.shape

for t in range(n_tint):
for f in range(n_fint):
for a in range(n_ant):
for d in range(n_dir):

jhr_sel = jhr[t, f, a, d]
jhj_sel = jhj[t, f, a, d]

# Sum bottom half into top half.
for p in range(vpc, n_par):
jhr_sel[p % vpc] += jhr_sel[p]

# Sum right half into left half and zero.
for p0 in range(n_par):
for p1 in range(vpc, n_par):
jhj_sel[p0, p1 % vpc] += jhj_sel[p0, p1]
jhj_sel[p0, p1] = 0

# Sum bottom half into top half and zero.
for p0 in range(vpc, n_par):
for p1 in range(vpc):
jhj_sel[p0 % vpc, p1] += jhj_sel[p0, p1]
jhj_sel[p0, p1] = 0

# Repopluate zeroed values from scalar sum.
jhr_sel[vpc:] = jhr_sel[:vpc]
jhj_sel[vpc:, vpc:] = jhj_sel[:vpc, :vpc]


@njit(**JIT_OPTIONS)
def resample_solints(native_map, native_shape, n_thread):

Expand Down
4 changes: 4 additions & 0 deletions quartical/gains/leakage/kernel.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ def impl(
active_term = meta_inputs.active_term
max_iter = meta_inputs.iters
solve_per = meta_inputs.solve_per
scalar = meta_inputs.scalar
dd_term = meta_inputs.dd_term
n_thread = meta_inputs.threads

Expand Down Expand Up @@ -129,6 +130,9 @@ def impl(
if solve_per == "array":
per_array_jhj_jhr(native_imdry)

if scalar:
raise ValueError("Scalar mode not supported for leakage terms.")

if not max_iter: # Non-solvable term, we just want jhj.
conv_perc = 0 # Didn't converge.
loop_idx = -1 # Did zero iterations.
Expand Down
Loading