From 8e244191766a7cc7a1dc8372732bb4c035b57808 Mon Sep 17 00:00:00 2001 From: Jonathan Kenyon Date: Thu, 19 Dec 2024 14:27:19 +0200 Subject: [PATCH 1/4] Checkpoint WIP on generic scalar terms. --- quartical/calibration/solver.py | 2 + quartical/config/gain_schema.yaml | 10 ++++- quartical/gains/delay_and_offset/kernel.py | 50 ++++++++++++++-------- quartical/gains/gain.py | 1 + quartical/gains/general/generics.py | 40 +++++++++++++++++ 5 files changed, 83 insertions(+), 20 deletions(-) diff --git a/quartical/calibration/solver.py b/quartical/calibration/solver.py index ddf3452f..aeb6ccfc 100644 --- a/quartical/calibration/solver.py +++ b/quartical/calibration/solver.py @@ -18,6 +18,7 @@ "threads", "robust", "reference_antenna", + "scalar", "dd_term", "pinned_directions", "solve_per", @@ -215,6 +216,7 @@ def solver_wrapper( solver_opts.threads, solver_opts.robust, solver_opts.reference_antenna, + term.scalar, term.direction_dependent, term.pinned_directions, term.solve_per diff --git a/quartical/config/gain_schema.yaml b/quartical/config/gain_schema.yaml index 266c23b7..4d42b2e5 100644 --- a/quartical/config/gain_schema.yaml +++ b/quartical/config/gain_schema.yaml @@ -28,9 +28,17 @@ gain: Determines whether this term should be solved per antenna (conventional) or over the entire array (doesn't vary with antenna). + scalar: + dtype: bool + default: false + info: + Determines whether the term is treated as scalar i.e. whether it is + solved for as a single effect over all correlations. This is only + supported for terms which would otherwise be diagonal. + direction_dependent: dtype: bool - default: False + default: false info: Determines whether this term is treated as direction dependent. diff --git a/quartical/gains/delay_and_offset/kernel.py b/quartical/gains/delay_and_offset/kernel.py index 70e41d95..5c02f3ad 100644 --- a/quartical/gains/delay_and_offset/kernel.py +++ b/quartical/gains/delay_and_offset/kernel.py @@ -2,26 +2,34 @@ import numpy as np from numba import prange, njit from numba.extending import overload -from quartical.utils.numba import (coerce_literal, - JIT_OPTIONS, - PARALLEL_JIT_OPTIONS) -from quartical.gains.general.generics import (native_intermediaries, - upsampled_itermediaries, - per_array_jhj_jhr, - resample_solints, - downsample_jhj_jhr) -from quartical.gains.general.flagging import (flag_intermediaries, - update_gain_flags, - finalize_gain_flags, - apply_gain_flags_to_flag_col, - update_param_flags, - apply_gain_flags_to_gains, - apply_param_flags_to_params) -from quartical.gains.general.convenience import (get_row, - get_extents) +from quartical.utils.numba import ( + coerce_literal, + JIT_OPTIONS, + PARALLEL_JIT_OPTIONS +) +from quartical.gains.general.generics import ( + native_intermediaries, + upsampled_itermediaries, + per_array_jhj_jhr, + resample_solints, + downsample_jhj_jhr, + scalar_jhj_jhr +) +from quartical.gains.general.flagging import ( + flag_intermediaries, + update_gain_flags, + finalize_gain_flags, + apply_gain_flags_to_flag_col, + update_param_flags, + apply_gain_flags_to_gains, + apply_param_flags_to_params +) +from quartical.gains.general.convenience import get_row, get_extents import quartical.gains.general.factories as factories -from quartical.gains.general.inversion import (invert_factory, - inversion_buffer_factory) +from quartical.gains.general.inversion import ( + invert_factory, + inversion_buffer_factory +) def get_identity_params(corr_mode): @@ -88,6 +96,7 @@ def impl( active_term = meta_inputs.active_term max_iter = meta_inputs.iters solve_per = meta_inputs.solve_per + scalar = meta_inputs.scalar dd_term = meta_inputs.dd_term n_thread = meta_inputs.threads @@ -151,6 +160,9 @@ def impl( if solve_per == "array": per_array_jhj_jhr(native_imdry) + if scalar: + scalar_jhj_jhr(native_imdry, 2) + if not max_iter: # Non-solvable term, we just want jhj. conv_perc = 0 # Didn't converge. loop_idx = -1 # Did zero iterations. diff --git a/quartical/gains/gain.py b/quartical/gains/gain.py index 0b18d4f2..2d103102 100644 --- a/quartical/gains/gain.py +++ b/quartical/gains/gain.py @@ -85,6 +85,7 @@ def __init__(self, term_name, term_opts): self.name = term_name self.type = term_opts.type self.solve_per = term_opts.solve_per + self.scalar = term_opts.scalar self.direction_dependent = term_opts.direction_dependent self.pinned_directions = term_opts.pinned_directions self.time_interval = term_opts.time_interval diff --git a/quartical/gains/general/generics.py b/quartical/gains/general/generics.py index 277e9ec0..f0028c56 100644 --- a/quartical/gains/general/generics.py +++ b/quartical/gains/general/generics.py @@ -705,6 +705,46 @@ def per_array_jhj_jhr(solver_imdry): jhr[t, f, a] = jhr[t, f, 0] +@njit(**JIT_OPTIONS) +def scalar_jhj_jhr(solver_imdry, values_per_correlation): + """This manipulates the entries of jhj and jhr to be scalar.""" + + jhj = solver_imdry.jhj + jhr = solver_imdry.jhr + + vpc = values_per_correlation # For brevity. + + n_tint, n_fint, n_ant, n_dir, n_par, _ = jhj.shape + + for t in range(n_tint): + for f in range(n_fint): + for a in range(n_ant): + for d in range(n_dir): + + jhr_sel = jhr[t, f, a, d] + jhj_sel = jhj[t, f, a, d] + + # Sum bottom half into top half. + for p in range(vpc, n_par): + jhr_sel[p % vpc] += jhr_sel[p] + + # Sum right half into left half and zero. + for p0 in range(n_par): + for p1 in range(vpc, n_par): + jhj_sel[p0, p1 % vpc] += jhj_sel[p0, p1] + jhj_sel[p0, p1] = 0 + + # Sum bottom half into top half and zero. + for p0 in range(vpc, n_par): + for p1 in range(vpc): + jhj_sel[p0 % vpc, p1] += jhj_sel[p0, p1] + jhj_sel[p0, p1] = 0 + + # Repopluate zeroed values from scalar sum. + jhr_sel[vpc:] = jhr_sel[:vpc] + jhj_sel[vpc:, vpc:] = jhj_sel[:vpc, :vpc] + + @njit(**JIT_OPTIONS) def resample_solints(native_map, native_shape, n_thread): From 1a53bb160497f4cebac1730f3c4b366441f92d4c Mon Sep 17 00:00:00 2001 From: Jonathan Kenyon Date: Thu, 19 Dec 2024 14:59:28 +0200 Subject: [PATCH 2/4] Update all terms to accept/error out appropriately when invoked in scalar mode. --- quartical/gains/amplitude/kernel.py | 7 ++++++- quartical/gains/complex/diag_kernel.py | 7 +++++++ quartical/gains/complex/kernel.py | 4 ++++ quartical/gains/crosshand_phase/kernel.py | 6 ++++++ quartical/gains/crosshand_phase/null_v_kernel.py | 6 ++++++ quartical/gains/delay/kernel.py | 7 ++++++- quartical/gains/delay_and_tec/kernel.py | 7 ++++++- quartical/gains/leakage/kernel.py | 4 ++++ quartical/gains/phase/kernel.py | 7 ++++++- quartical/gains/rotation/kernel.py | 6 ++++++ quartical/gains/rotation_measure/kernel.py | 6 ++++++ quartical/gains/tec_and_offset/kernel.py | 7 ++++++- 12 files changed, 69 insertions(+), 5 deletions(-) diff --git a/quartical/gains/amplitude/kernel.py b/quartical/gains/amplitude/kernel.py index b197d601..84d9fcda 100644 --- a/quartical/gains/amplitude/kernel.py +++ b/quartical/gains/amplitude/kernel.py @@ -9,7 +9,8 @@ upsampled_itermediaries, per_array_jhj_jhr, resample_solints, - downsample_jhj_jhr) + downsample_jhj_jhr, + scalar_jhj_jhr) from quartical.gains.general.flagging import (flag_intermediaries, update_gain_flags, finalize_gain_flags, @@ -86,6 +87,7 @@ def impl( active_term = meta_inputs.active_term max_iter = meta_inputs.iters solve_per = meta_inputs.solve_per + scalar = meta_inputs.scalar dd_term = meta_inputs.dd_term n_thread = meta_inputs.threads @@ -144,6 +146,9 @@ def impl( if solve_per == "array": per_array_jhj_jhr(native_imdry) + if scalar: + scalar_jhj_jhr(native_imdry, 1) + if not max_iter: # Non-solvable term, we just want jhj. conv_perc = 0 # Didn't converge. loop_idx = -1 # Did zero iterations. diff --git a/quartical/gains/complex/diag_kernel.py b/quartical/gains/complex/diag_kernel.py index 8bd6ba59..f1657761 100644 --- a/quartical/gains/complex/diag_kernel.py +++ b/quartical/gains/complex/diag_kernel.py @@ -74,6 +74,7 @@ def impl( active_term = meta_inputs.active_term max_iter = meta_inputs.iters solve_per = meta_inputs.solve_per + scalar = meta_inputs.scalar dd_term = meta_inputs.dd_term n_thread = meta_inputs.threads @@ -129,6 +130,12 @@ def impl( if solve_per == "array": per_array_jhj_jhr(native_imdry) + if scalar: + raise ValueError( + "Scalar mode not (yet) supported for diag complex terms." + "Please raise an issue if you require this functionality." + ) + if not max_iter: # Non-solvable term, we just want jhj. conv_perc = 0 # Didn't converge. loop_idx = -1 # Did zero iterations. diff --git a/quartical/gains/complex/kernel.py b/quartical/gains/complex/kernel.py index f8ca97dc..2e96bd8a 100644 --- a/quartical/gains/complex/kernel.py +++ b/quartical/gains/complex/kernel.py @@ -75,6 +75,7 @@ def impl( active_term = meta_inputs.active_term max_iter = meta_inputs.iters solve_per = meta_inputs.solve_per + scalar = meta_inputs.scalar dd_term = meta_inputs.dd_term n_thread = meta_inputs.threads @@ -132,6 +133,9 @@ def impl( if solve_per == "array": per_array_jhj_jhr(native_imdry) + if scalar: + raise ValueError("Scalar mode not supported for complex terms.") + if not max_iter: # Non-solvable term, we just want jhj. conv_perc = 0 # Didn't converge. loop_idx = -1 # Did zero iterations. diff --git a/quartical/gains/crosshand_phase/kernel.py b/quartical/gains/crosshand_phase/kernel.py index 53adfff4..dc264f8c 100644 --- a/quartical/gains/crosshand_phase/kernel.py +++ b/quartical/gains/crosshand_phase/kernel.py @@ -84,6 +84,7 @@ def impl( active_term = meta_inputs.active_term max_iter = meta_inputs.iters solve_per = meta_inputs.solve_per + scalar = meta_inputs.scalar dd_term = meta_inputs.dd_term n_thread = meta_inputs.threads @@ -141,6 +142,11 @@ def impl( if solve_per == "array": per_array_jhj_jhr(native_imdry) + if scalar: + raise ValueError( + "Scalar mode not supported for crosshand phase terms." + ) + if not max_iter: # Non-solvable term, we just want jhj. conv_perc = 0 # Didn't converge. loop_idx = -1 # Did zero iterations. diff --git a/quartical/gains/crosshand_phase/null_v_kernel.py b/quartical/gains/crosshand_phase/null_v_kernel.py index 33fb4f78..130ee2d8 100644 --- a/quartical/gains/crosshand_phase/null_v_kernel.py +++ b/quartical/gains/crosshand_phase/null_v_kernel.py @@ -91,6 +91,7 @@ def impl( active_term = meta_inputs.active_term max_iter = meta_inputs.iters solve_per = meta_inputs.solve_per + scalar = meta_inputs.scalar dd_term = meta_inputs.dd_term n_thread = meta_inputs.threads @@ -149,6 +150,11 @@ def impl( if solve_per == "array": per_array_jhj_jhr(native_imdry) + if scalar: + raise ValueError( + "Scalar mode not supported for crosshand phase terms." + ) + if not max_iter: # Non-solvable term, we just want jhj. conv_perc = 0 # Didn't converge. loop_idx = -1 # Did zero iterations. diff --git a/quartical/gains/delay/kernel.py b/quartical/gains/delay/kernel.py index 209da746..a45952c9 100644 --- a/quartical/gains/delay/kernel.py +++ b/quartical/gains/delay/kernel.py @@ -12,7 +12,8 @@ upsampled_itermediaries, per_array_jhj_jhr, resample_solints, - downsample_jhj_jhr + downsample_jhj_jhr, + scalar_jhj_jhr ) from quartical.gains.general.flagging import ( flag_intermediaries, @@ -94,6 +95,7 @@ def impl( active_term = meta_inputs.active_term max_iter = meta_inputs.iters solve_per = meta_inputs.solve_per + scalar = meta_inputs.scalar dd_term = meta_inputs.dd_term n_thread = meta_inputs.threads @@ -157,6 +159,9 @@ def impl( if solve_per == "array": per_array_jhj_jhr(native_imdry) + if scalar: + scalar_jhj_jhr(native_imdry, 1) + if not max_iter: # Non-solvable term, we just want jhj. conv_perc = 0 # Didn't converge. loop_idx = -1 # Did zero iterations. diff --git a/quartical/gains/delay_and_tec/kernel.py b/quartical/gains/delay_and_tec/kernel.py index 3b0b9202..a90d0bd0 100644 --- a/quartical/gains/delay_and_tec/kernel.py +++ b/quartical/gains/delay_and_tec/kernel.py @@ -9,7 +9,8 @@ upsampled_itermediaries, per_array_jhj_jhr, resample_solints, - downsample_jhj_jhr) + downsample_jhj_jhr, + scalar_jhj_jhr) from quartical.gains.general.flagging import (flag_intermediaries, update_gain_flags, finalize_gain_flags, @@ -88,6 +89,7 @@ def impl( active_term = meta_inputs.active_term max_iter = meta_inputs.iters solve_per = meta_inputs.solve_per + scalar = meta_inputs.scalar dd_term = meta_inputs.dd_term n_thread = meta_inputs.threads @@ -155,6 +157,9 @@ def impl( if solve_per == "array": per_array_jhj_jhr(native_imdry) + if scalar: + scalar_jhj_jhr(native_imdry, 2) + if not max_iter: # Non-solvable term, we just want jhj. conv_perc = 0 # Didn't converge. loop_idx = -1 # Did zero iterations. diff --git a/quartical/gains/leakage/kernel.py b/quartical/gains/leakage/kernel.py index 8fa6921a..1cb06874 100644 --- a/quartical/gains/leakage/kernel.py +++ b/quartical/gains/leakage/kernel.py @@ -73,6 +73,7 @@ def impl( active_term = meta_inputs.active_term max_iter = meta_inputs.iters solve_per = meta_inputs.solve_per + scalar = meta_inputs.scalar dd_term = meta_inputs.dd_term n_thread = meta_inputs.threads @@ -129,6 +130,9 @@ def impl( if solve_per == "array": per_array_jhj_jhr(native_imdry) + if scalar: + raise ValueError("Scalar mode not supported for leakage terms.") + if not max_iter: # Non-solvable term, we just want jhj. conv_perc = 0 # Didn't converge. loop_idx = -1 # Did zero iterations. diff --git a/quartical/gains/phase/kernel.py b/quartical/gains/phase/kernel.py index 7fa76da0..2ab9ec80 100644 --- a/quartical/gains/phase/kernel.py +++ b/quartical/gains/phase/kernel.py @@ -9,7 +9,8 @@ upsampled_itermediaries, per_array_jhj_jhr, resample_solints, - downsample_jhj_jhr) + downsample_jhj_jhr, + scalar_jhj_jhr) from quartical.gains.general.flagging import (flag_intermediaries, update_gain_flags, finalize_gain_flags, @@ -88,6 +89,7 @@ def impl( active_term = meta_inputs.active_term max_iter = meta_inputs.iters solve_per = meta_inputs.solve_per + scalar = meta_inputs.scalar dd_term = meta_inputs.dd_term n_thread = meta_inputs.threads @@ -145,6 +147,9 @@ def impl( if solve_per == "array": per_array_jhj_jhr(native_imdry) + if scalar: + scalar_jhj_jhr(native_imdry, 1) + if not max_iter: # Non-solvable term, we just want jhj. conv_perc = 0 # Didn't converge. loop_idx = -1 # Did zero iterations. diff --git a/quartical/gains/rotation/kernel.py b/quartical/gains/rotation/kernel.py index 66cc58c2..95f98821 100644 --- a/quartical/gains/rotation/kernel.py +++ b/quartical/gains/rotation/kernel.py @@ -84,6 +84,7 @@ def impl( active_term = meta_inputs.active_term max_iter = meta_inputs.iters solve_per = meta_inputs.solve_per + scalar = meta_inputs.scalar dd_term = meta_inputs.dd_term n_thread = meta_inputs.threads @@ -142,6 +143,11 @@ def impl( if solve_per == "array": per_array_jhj_jhr(native_imdry) + if scalar: + raise ValueError( + "Scalar mode not supported for rotation terms." + ) + if not max_iter: # Non-solvable term, we just want jhj. conv_perc = 0 # Didn't converge. loop_idx = -1 # Did zero iterations. diff --git a/quartical/gains/rotation_measure/kernel.py b/quartical/gains/rotation_measure/kernel.py index 3c935f23..348501a1 100644 --- a/quartical/gains/rotation_measure/kernel.py +++ b/quartical/gains/rotation_measure/kernel.py @@ -84,6 +84,7 @@ def impl( active_term = meta_inputs.active_term max_iter = meta_inputs.iters solve_per = meta_inputs.solve_per + scalar = meta_inputs.scalar dd_term = meta_inputs.dd_term n_thread = meta_inputs.threads @@ -145,6 +146,11 @@ def impl( if solve_per == "array": per_array_jhj_jhr(native_imdry) + if scalar: + raise ValueError( + "Scalar mode not supported for rotation measure terms." + ) + if not max_iter: # Non-solvable term, we just want jhj. conv_perc = 0 # Didn't converge. loop_idx = -1 # Did zero iterations. diff --git a/quartical/gains/tec_and_offset/kernel.py b/quartical/gains/tec_and_offset/kernel.py index e41ab75c..b855e68d 100644 --- a/quartical/gains/tec_and_offset/kernel.py +++ b/quartical/gains/tec_and_offset/kernel.py @@ -9,7 +9,8 @@ upsampled_itermediaries, per_array_jhj_jhr, resample_solints, - downsample_jhj_jhr) + downsample_jhj_jhr, + scalar_jhj_jhr) from quartical.gains.general.flagging import (flag_intermediaries, update_gain_flags, finalize_gain_flags, @@ -90,6 +91,7 @@ def impl( active_term = meta_inputs.active_term max_iter = meta_inputs.iters solve_per = meta_inputs.solve_per + scalar = meta_inputs.scalar dd_term = meta_inputs.dd_term n_thread = meta_inputs.threads @@ -153,6 +155,9 @@ def impl( if solve_per == "array": per_array_jhj_jhr(native_imdry) + if scalar: + scalar_jhj_jhr(native_imdry, 2) + if not max_iter: # Non-solvable term, we just want jhj. conv_perc = 0 # Didn't converge. loop_idx = -1 # Did zero iterations. From dcdc0504d9bbcc54495ecb2288ab223ac6ac6173 Mon Sep 17 00:00:00 2001 From: Jonathan Kenyon Date: Thu, 19 Dec 2024 15:16:44 +0200 Subject: [PATCH 3/4] Add scalar mode support for diag_complex terms. --- quartical/gains/complex/diag_kernel.py | 36 +++++++++++++++++++++++--- 1 file changed, 32 insertions(+), 4 deletions(-) diff --git a/quartical/gains/complex/diag_kernel.py b/quartical/gains/complex/diag_kernel.py index f1657761..6343d1cf 100644 --- a/quartical/gains/complex/diag_kernel.py +++ b/quartical/gains/complex/diag_kernel.py @@ -131,10 +131,7 @@ def impl( per_array_jhj_jhr(native_imdry) if scalar: - raise ValueError( - "Scalar mode not (yet) supported for diag complex terms." - "Please raise an issue if you require this functionality." - ) + scalar_jhj_jhr(native_imdry) if not max_iter: # Non-solvable term, we just want jhj. conv_perc = 0 # Didn't converge. @@ -684,3 +681,34 @@ def impl(chain_inputs, meta_inputs, mode): apply_gain_flags_to_gains(gain_flags, gains) return impl + + +@njit(**JIT_OPTIONS) +def scalar_jhj_jhr(solver_imdry): + """This manipulates the entries of jhj and jhr to be scalar.""" + + # NOTE: This differes from the generic implmenentation in generics.py. + + jhj = solver_imdry.jhj + jhr = solver_imdry.jhr + + n_tint, n_fint, n_ant, n_dir, n_corr = jhj.shape + + for t in range(n_tint): + for f in range(n_fint): + for a in range(n_ant): + for d in range(n_dir): + + jhr_sel = jhr[t, f, a, d] + jhj_sel = jhj[t, f, a, d] + + # Sum to a single scalar element. + for p in range(1, n_corr): + jhr_sel[0] += jhr_sel[p] + jhr_sel[p] = 0 + jhj_sel[0] += jhj_sel[p] + jhj_sel[p] = 0 + + # Repopluate appropriate zeroed values from scalar sum. + jhr_sel[-1] = jhr_sel[0] + jhj_sel[-1] = jhj_sel[0] From a40fe3a3ed9cd46a62d8f6f53abc9fe1cb616933 Mon Sep 17 00:00:00 2001 From: Jonathan Kenyon Date: Thu, 16 Jan 2025 15:51:40 +0200 Subject: [PATCH 4/4] Add some tests for scalar mode. Does not test all terms, but probes all functionality. --- quartical/gains/amplitude/kernel.py | 2 +- quartical/gains/complex/diag_kernel.py | 2 +- quartical/gains/delay/kernel.py | 2 +- quartical/gains/delay_and_offset/kernel.py | 2 +- quartical/gains/delay_and_tec/kernel.py | 2 +- quartical/gains/phase/kernel.py | 2 +- quartical/gains/tec_and_offset/kernel.py | 2 +- testing/fixtures/gains.py | 5 +++++ testing/tests/gains/test_amplitude.py | 8 ++++++-- testing/tests/gains/test_delay_and_offset.py | 8 ++++++-- testing/tests/gains/test_diag_complex.py | 8 ++++++-- testing/tests/gains/test_phase.py | 8 ++++++-- 12 files changed, 36 insertions(+), 15 deletions(-) diff --git a/quartical/gains/amplitude/kernel.py b/quartical/gains/amplitude/kernel.py index 84d9fcda..16c70962 100644 --- a/quartical/gains/amplitude/kernel.py +++ b/quartical/gains/amplitude/kernel.py @@ -146,7 +146,7 @@ def impl( if solve_per == "array": per_array_jhj_jhr(native_imdry) - if scalar: + if scalar and corr_mode != 1: scalar_jhj_jhr(native_imdry, 1) if not max_iter: # Non-solvable term, we just want jhj. diff --git a/quartical/gains/complex/diag_kernel.py b/quartical/gains/complex/diag_kernel.py index 6343d1cf..066fb828 100644 --- a/quartical/gains/complex/diag_kernel.py +++ b/quartical/gains/complex/diag_kernel.py @@ -130,7 +130,7 @@ def impl( if solve_per == "array": per_array_jhj_jhr(native_imdry) - if scalar: + if scalar and corr_mode != 1: scalar_jhj_jhr(native_imdry) if not max_iter: # Non-solvable term, we just want jhj. diff --git a/quartical/gains/delay/kernel.py b/quartical/gains/delay/kernel.py index a45952c9..a4bc35e0 100644 --- a/quartical/gains/delay/kernel.py +++ b/quartical/gains/delay/kernel.py @@ -159,7 +159,7 @@ def impl( if solve_per == "array": per_array_jhj_jhr(native_imdry) - if scalar: + if scalar and corr_mode != 1: scalar_jhj_jhr(native_imdry, 1) if not max_iter: # Non-solvable term, we just want jhj. diff --git a/quartical/gains/delay_and_offset/kernel.py b/quartical/gains/delay_and_offset/kernel.py index 5c02f3ad..384c1e48 100644 --- a/quartical/gains/delay_and_offset/kernel.py +++ b/quartical/gains/delay_and_offset/kernel.py @@ -160,7 +160,7 @@ def impl( if solve_per == "array": per_array_jhj_jhr(native_imdry) - if scalar: + if scalar and corr_mode != 1: scalar_jhj_jhr(native_imdry, 2) if not max_iter: # Non-solvable term, we just want jhj. diff --git a/quartical/gains/delay_and_tec/kernel.py b/quartical/gains/delay_and_tec/kernel.py index a90d0bd0..6d2cf61b 100644 --- a/quartical/gains/delay_and_tec/kernel.py +++ b/quartical/gains/delay_and_tec/kernel.py @@ -157,7 +157,7 @@ def impl( if solve_per == "array": per_array_jhj_jhr(native_imdry) - if scalar: + if scalar and corr_mode != 1: scalar_jhj_jhr(native_imdry, 2) if not max_iter: # Non-solvable term, we just want jhj. diff --git a/quartical/gains/phase/kernel.py b/quartical/gains/phase/kernel.py index 2ab9ec80..b3607535 100644 --- a/quartical/gains/phase/kernel.py +++ b/quartical/gains/phase/kernel.py @@ -147,7 +147,7 @@ def impl( if solve_per == "array": per_array_jhj_jhr(native_imdry) - if scalar: + if scalar and corr_mode != 1: scalar_jhj_jhr(native_imdry, 1) if not max_iter: # Non-solvable term, we just want jhj. diff --git a/quartical/gains/tec_and_offset/kernel.py b/quartical/gains/tec_and_offset/kernel.py index b855e68d..a867d153 100644 --- a/quartical/gains/tec_and_offset/kernel.py +++ b/quartical/gains/tec_and_offset/kernel.py @@ -155,7 +155,7 @@ def impl( if solve_per == "array": per_array_jhj_jhr(native_imdry) - if scalar: + if scalar and corr_mode != 1: scalar_jhj_jhr(native_imdry, 2) if not max_iter: # Non-solvable term, we just want jhj. diff --git a/testing/fixtures/gains.py b/testing/fixtures/gains.py index ab03669d..575cd4e2 100644 --- a/testing/fixtures/gains.py +++ b/testing/fixtures/gains.py @@ -25,3 +25,8 @@ def cmp_post_solve_data_xds_list(cmp_calibration_graph_outputs): @pytest.fixture(params=["antenna", "array"], scope="module") def solve_per(request): return request.param + + +@pytest.fixture(params=[True, False], scope="module") +def scalar_mode(request): + return request.param diff --git a/testing/tests/gains/test_amplitude.py b/testing/tests/gains/test_amplitude.py index 25bfd0e4..2ed8cbba 100644 --- a/testing/tests/gains/test_amplitude.py +++ b/testing/tests/gains/test_amplitude.py @@ -7,7 +7,7 @@ @pytest.fixture(scope="module") -def opts(base_opts, select_corr, solve_per): +def opts(base_opts, select_corr, solve_per, scalar_mode): # Don't overwrite base config - instead create a copy and update. @@ -22,6 +22,7 @@ def opts(base_opts, select_corr, solve_per): _opts.solver.threads = 2 _opts.G.type = "amplitude" _opts.G.solve_per = solve_per + _opts.G.scalar = scalar_mode return _opts @@ -33,7 +34,7 @@ def raw_xds_list(read_xds_list_output): @pytest.fixture(scope="module") -def true_gain_list(predicted_xds_list, solve_per): +def true_gain_list(predicted_xds_list, solve_per, scalar_mode): gain_list = [] @@ -66,6 +67,9 @@ def true_gain_list(predicted_xds_list, solve_per): gains = amp*da.exp(1j*phase) + if scalar_mode: + gains[..., -1] = gains[..., 0] + if solve_per == "array": gains = da.broadcast_to(gains[:, :, :1], gains.shape) diff --git a/testing/tests/gains/test_delay_and_offset.py b/testing/tests/gains/test_delay_and_offset.py index ed58db58..48785ce1 100644 --- a/testing/tests/gains/test_delay_and_offset.py +++ b/testing/tests/gains/test_delay_and_offset.py @@ -7,7 +7,7 @@ @pytest.fixture(scope="module") -def opts(base_opts, select_corr): +def opts(base_opts, select_corr, scalar_mode): # Don't overwrite base config - instead create a copy and update. @@ -23,6 +23,7 @@ def opts(base_opts, select_corr): _opts.G.type = "delay_and_offset" _opts.G.freq_interval = 0 _opts.G.initial_estimate = True + _opts.G.scalar = scalar_mode return _opts @@ -34,7 +35,7 @@ def raw_xds_list(read_xds_list_output): @pytest.fixture(scope="module") -def true_gain_list(predicted_xds_list): +def true_gain_list(predicted_xds_list, scalar_mode): gain_list = [] @@ -82,6 +83,9 @@ def true_gain_list(predicted_xds_list): phase = 2*np.pi*delays*origin_chan_freq + offsets gains = amp*da.exp(1j*phase) + if scalar_mode: + gains[..., -1] = gains[..., 0] + gain_list.append(gains) return gain_list diff --git a/testing/tests/gains/test_diag_complex.py b/testing/tests/gains/test_diag_complex.py index 4f193c5d..90f87d98 100644 --- a/testing/tests/gains/test_diag_complex.py +++ b/testing/tests/gains/test_diag_complex.py @@ -12,7 +12,7 @@ def solver_type(request): @pytest.fixture(scope="module") -def opts(base_opts, solver_type, select_corr, solve_per): +def opts(base_opts, solver_type, select_corr, solve_per, scalar_mode): # Don't overwrite base config - instead create a copy and update. @@ -27,6 +27,7 @@ def opts(base_opts, solver_type, select_corr, solve_per): _opts.solver.threads = 2 _opts.G.type = solver_type _opts.G.solve_per = solve_per + _opts.G.scalar = scalar_mode return _opts @@ -38,7 +39,7 @@ def raw_xds_list(read_xds_list_output): @pytest.fixture(scope="module") -def true_gain_list(predicted_xds_list, solve_per): +def true_gain_list(predicted_xds_list, solve_per, scalar_mode): gain_list = [] @@ -69,6 +70,9 @@ def true_gain_list(predicted_xds_list, solve_per): gains = amp*da.exp(1j*phase) + if scalar_mode: + gains[..., -1] = gains[..., 0] + if solve_per == "array": gains = da.broadcast_to(gains[:, :, :1], gains.shape) diff --git a/testing/tests/gains/test_phase.py b/testing/tests/gains/test_phase.py index 12e38192..500f673d 100644 --- a/testing/tests/gains/test_phase.py +++ b/testing/tests/gains/test_phase.py @@ -7,7 +7,7 @@ @pytest.fixture(scope="module") -def opts(base_opts, select_corr, solve_per): +def opts(base_opts, select_corr, solve_per, scalar_mode): # Don't overwrite base config - instead create a copy and update. @@ -22,6 +22,7 @@ def opts(base_opts, select_corr, solve_per): _opts.solver.threads = 2 _opts.G.type = "phase" _opts.G.solve_per = solve_per + _opts.G.scalar = scalar_mode return _opts @@ -33,7 +34,7 @@ def raw_xds_list(read_xds_list_output): @pytest.fixture(scope="module") -def true_gain_list(predicted_xds_list, solve_per): +def true_gain_list(predicted_xds_list, solve_per, scalar_mode): gain_list = [] @@ -64,6 +65,9 @@ def true_gain_list(predicted_xds_list, solve_per): gains = amp*da.exp(1j*phase) + if scalar_mode: + gains[..., -1] = gains[..., 0] + if solve_per == "array": gains = da.broadcast_to(gains[:, :, :1], gains.shape)