Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Remove spectral index and allow input of stokes varying by source, time and channel. #244

Merged
merged 46 commits into from
Apr 21, 2020
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
Show all changes
46 commits
Select commit Hold shift + click to select a range
a593498
Remove spi and input stokes per (src,time,channel)
sjperkins Jan 12, 2018
0c182cf
Merge branch 'master' into ddfacet
sjperkins Feb 14, 2018
4b5f507
Merge branch 'master' into ddfacet
sjperkins Jun 11, 2018
06b91c4
Install tensorflow CPU by default
sjperkins Jul 13, 2018
0e42a40
Update the standalone.py example
sjperkins Jul 13, 2018
700e78c
RadecToLm operator
sjperkins Jul 26, 2018
f4c9ed6
Swap translation and rotation in the beam
sjperkins Jul 27, 2018
dbad6d5
Makefile fixes for later tensorflow versions
sjperkins Jul 27, 2018
3e50b32
Convert all source position inputs to radec
sjperkins Jul 27, 2018
8127ba4
Offset source from phase centre in beam
sjperkins Jul 27, 2018
04f7d08
Adds pointing offsets to PA computation
ratt-priv-ci Jul 31, 2018
50bf33b
Swap ant and time axis in PA computation
ratt-priv-ci Jul 31, 2018
af32df3
Typos on previous commit
ratt-priv-ci Jul 31, 2018
1811be8
use lm coordinates to look up positions in the beam (this is may be w…
ratt-priv-ci Aug 28, 2018
cee1231
Fix GPU sgn brightness index in sum coherency operator (#263)
sjperkins Nov 19, 2018
bfa80dd
Constrain python2 astropy deps to ">= 2, < 3" (#264)
sjperkins Nov 19, 2018
0720c95
Support disjoint antenna UVW decompositions in antenna_uvw (#268)
sjperkins Jan 24, 2019
858a6c5
Add feed angle to parallactic angle in feed rotation matrix (#269)
sjperkins Jan 25, 2019
98928f1
Update cuda.py
bennahugo May 23, 2019
073ce49
Update cuda.py
bennahugo May 23, 2019
23eb2a0
Update cub.py
bennahugo May 23, 2019
0e0e012
Update cub.py
bennahugo May 23, 2019
12a095d
Update tensorflow_ops_ext.py
bennahugo May 23, 2019
aaf2380
py3 fixes
ratt-priv-ci May 23, 2019
c01a613
merge beamfix branch into main ddf branch
ratt-priv-ci May 24, 2019
986756d
Fix #270
ratt-priv-ci May 24, 2019
78a976e
Semi-automatic py2 to py3 conversion
ratt-priv-ci May 27, 2019
cc39f4e
residual py3 issues
ratt-priv-ci May 27, 2019
a7ab593
Fix py2 backwards compat
ratt-priv-ci May 27, 2019
3f095ce
depend on hypercube py3 fixes
ratt-priv-ci May 27, 2019
02da662
py3 fixes
ratt-priv-ci May 27, 2019
36e0550
Use six to write py2-3 compatible metaclasses
ratt-priv-ci May 27, 2019
64567df
py2-3 install issue
ratt-priv-ci May 27, 2019
f167e01
remove metadata class printing
ratt-priv-ci May 27, 2019
6c4dad5
py23 agnostic __builtin__ import
sjperkins May 28, 2019
878ce89
Fix method inspection within MetaClass constructor
sjperkins May 28, 2019
03afcc9
Fix residual py3 issues
ratt-priv-ci May 28, 2019
916c430
merge master into ddfacet branch
ratt-priv-ci May 29, 2019
5a456cd
Fix variable usage in setup.py
sjperkins May 31, 2019
85684b2
Make a proper Sersic Shape test
sjperkins May 31, 2019
a99dcf7
Make a proper gaussians shape test
sjperkins May 31, 2019
4f0af23
Feed rotation test spacing
sjperkins May 31, 2019
2061238
Fix radec_to_lm test library load
sjperkins May 31, 2019
65d7ec2
Make a proper complex phase test case
sjperkins May 31, 2019
77e501c
Remove py3 only constructs
ratt-priv-ci May 31, 2019
547008f
Merge branch 'ddfacet' of https://github.com/ska-sa/montblanc into dd…
ratt-priv-ci May 31, 2019
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 5 additions & 6 deletions montblanc/impl/rime/tensorflow/RimeSolver.py
Original file line number Diff line number Diff line change
Expand Up @@ -951,7 +951,7 @@ def _construct_tensorflow_expression(slvr_cfg, feed_data, device, shard):
feed_rotation = rime.feed_rotation(pa_sin, pa_cos, CT=CT,
feed_type=polarisation_type)

def antenna_jones(lm, stokes, alpha, ref_freq):
def antenna_jones(lm, stokes):
"""
Compute the jones terms for each antenna.

Expand All @@ -972,8 +972,7 @@ def antenna_jones(lm, stokes, alpha, ref_freq):

# Compute the square root of the brightness matrix
# (as well as the sign)
bsqrt, sgn_brightness = rime.b_sqrt(stokes, alpha,
D.frequency, ref_freq, CT=CT,
bsqrt, sgn_brightness = rime.b_sqrt(stokes, CT=CT,
polarisation_type=polarisation_type)

# Check for nans/infs in the bsqrt
Expand Down Expand Up @@ -1023,7 +1022,7 @@ def point_body(coherencies, npsrc, src_count):
npsrc += nsrc

ant_jones, sgn_brightness = antenna_jones(S.point_lm,
S.point_stokes, S.point_alpha, S.point_ref_freq)
S.point_stokes)
shape = tf.ones(shape=[nsrc,ntime,nbl,nchan], dtype=FT)
coherencies = rime.sum_coherencies(D.antenna1, D.antenna2,
shape, ant_jones, sgn_brightness, coherencies)
Expand All @@ -1040,7 +1039,7 @@ def gaussian_body(coherencies, ngsrc, src_count):
ngsrc += nsrc

ant_jones, sgn_brightness = antenna_jones(S.gaussian_lm,
S.gaussian_stokes, S.gaussian_alpha, S.gaussian_ref_freq)
S.gaussian_stokes)
gauss_shape = rime.gauss_shape(D.uvw, D.antenna1, D.antenna2,
D.frequency, S.gaussian_shape)
coherencies = rime.sum_coherencies(D.antenna1, D.antenna2,
Expand All @@ -1058,7 +1057,7 @@ def sersic_body(coherencies, nssrc, src_count):
nssrc += nsrc

ant_jones, sgn_brightness = antenna_jones(S.sersic_lm,
S.sersic_stokes, S.sersic_alpha, S.sersic_ref_freq)
S.sersic_stokes)
sersic_shape = rime.sersic_shape(D.uvw, D.antenna1, D.antenna2,
D.frequency, S.sersic_shape)
coherencies = rime.sum_coherencies(D.antenna1, D.antenna2,
Expand Down
42 changes: 3 additions & 39 deletions montblanc/impl/rime/tensorflow/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -366,24 +366,12 @@ def test_sersic_shape(self, context):
tags = "input, constant",
description = LM_DESCRIPTION.format(st="point"),
units = RADIANS),
array_dict('point_stokes', ('npsrc','ntime', 4), 'ft',
array_dict('point_stokes', ('npsrc','ntime', 'nchan', 4), 'ft',
default = default_stokes,
test = rand_stokes,
tags = "input, constant",
description = STOKES_DESCRIPTION,
units = JANSKYS),
array_dict('point_alpha', ('npsrc','ntime'), 'ft',
default = lambda s, c: np.zeros(c.shape, c.dtype),
test = lambda s, c: rf(c.shape, c.dtype)*0.1,
tags = "input, constant",
description = ALPHA_DESCRIPTION,
units = DIMENSIONLESS),
array_dict('point_ref_freq', ('npsrc',), 'ft',
default = lambda s, c: np.full(c.shape, _ref_freq, c.dtype),
test = lambda s, c: np.full(c.shape, _ref_freq, c.dtype),
tags = "input, constant",
description = REF_FREQ_DESCRIPTION.format(st="point"),
units = HERTZ),

# Gaussian Source Definitions
array_dict('gaussian_lm', ('ngsrc',2), 'ft',
Expand All @@ -392,24 +380,12 @@ def test_sersic_shape(self, context):
tags = "input, constant",
description = LM_DESCRIPTION.format(st="gaussian"),
units = RADIANS),
array_dict('gaussian_stokes', ('ngsrc','ntime', 4), 'ft',
array_dict('gaussian_stokes', ('ngsrc','ntime', 'nchan', 4), 'ft',
default = default_stokes,
test = rand_stokes,
tags = "input, constant",
description = STOKES_DESCRIPTION,
units = JANSKYS),
array_dict('gaussian_alpha', ('ngsrc','ntime'), 'ft',
default = lambda s, c: np.zeros(c.shape, c.dtype),
test = lambda s, c: rf(c.shape, c.dtype)*0.1,
tags = "input, constant",
description = ALPHA_DESCRIPTION,
units = DIMENSIONLESS),
array_dict('gaussian_ref_freq', ('ngsrc',), 'ft',
default = lambda s, c: np.full(c.shape, _ref_freq, c.dtype),
test = lambda s, c: np.full(c.shape, _ref_freq, c.dtype),
tags = "input, constant",
description = REF_FREQ_DESCRIPTION.format(st="gaussian"),
units = HERTZ),
array_dict('gaussian_shape', (3, 'ngsrc'), 'ft',
default = default_gaussian_shape,
test = rand_gaussian_shape,
Expand All @@ -427,24 +403,12 @@ def test_sersic_shape(self, context):
tags = "input, constant",
description = LM_DESCRIPTION.format(st="sersic"),
units = "Radians"),
array_dict('sersic_stokes', ('nssrc','ntime', 4), 'ft',
array_dict('sersic_stokes', ('nssrc','ntime', 'nchan', 4), 'ft',
default = default_stokes,
test = rand_stokes,
tags = "input, constant",
description = STOKES_DESCRIPTION,
units = JANSKYS),
array_dict('sersic_alpha', ('nssrc','ntime'), 'ft',
default = lambda s, c: np.zeros(c.shape, c.dtype),
test = lambda s, c: rf(c.shape, c.dtype)*0.1,
tags = "input, constant",
description = ALPHA_DESCRIPTION,
units = DIMENSIONLESS),
array_dict('sersic_ref_freq', ('nssrc',), 'ft',
default = lambda s, c: np.full(c.shape, _ref_freq, c.dtype),
test = lambda s, c: np.full(c.shape, _ref_freq, c.dtype),
tags = "input, constant",
description = REF_FREQ_DESCRIPTION.format(st="sersic"),
units = HERTZ),
array_dict('sersic_shape', (3, 'nssrc'), 'ft',
default = default_sersic_shape,
test = test_sersic_shape,
Expand Down
31 changes: 9 additions & 22 deletions montblanc/impl/rime/tensorflow/rime_ops/b_sqrt_op_cpu.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,35 +16,25 @@ auto bsqrt_shape_function = [](InferenceContext* c) {
DimensionHandle d;

ShapeHandle stokes = c->input(0);
ShapeHandle alpha = c->input(1);
ShapeHandle frequency = c->input(2);
ShapeHandle ref_freq = c->input(3);

TF_RETURN_WITH_CONTEXT_IF_ERROR(c->WithRank(stokes, 3, &input),
"stokes shape must be [nsrc, ntime, 4] but is " + c->DebugString(stokes));
TF_RETURN_WITH_CONTEXT_IF_ERROR(c->WithValue(c->Dim(stokes, 2), 4, &d),
"stokes shape must be [nsrc, ntime, 4] but is " + c->DebugString(stokes));

TF_RETURN_WITH_CONTEXT_IF_ERROR(c->WithRank(alpha, 2, &input),
"alpha shape must be [nsrc, ntime] but is " + c->DebugString(alpha));

TF_RETURN_WITH_CONTEXT_IF_ERROR(c->WithRank(frequency, 1, &input),
"frequency shape must be [nchan,] but is " + c->DebugString(frequency));

TF_RETURN_WITH_CONTEXT_IF_ERROR(c->WithRank(ref_freq, 1, &input),
"ref_freq shape must be [nsrc,] but is " + c->DebugString(ref_freq));
TF_RETURN_WITH_CONTEXT_IF_ERROR(c->WithRank(stokes, 4, &input),
"stokes shape must be [nsrc, ntime, nchan, 4] but is " + c->DebugString(stokes));
TF_RETURN_WITH_CONTEXT_IF_ERROR(c->WithValue(c->Dim(stokes, 3), 4, &d),
"stokes shape must be [nsrc, ntime, nchan, 4] but is " + c->DebugString(stokes));

// bsqrt output is (nsrc, ntime, nchan, 4)
ShapeHandle bsqrt = c->MakeShape({
c->Dim(stokes, 0),
c->Dim(stokes, 1),
c->Dim(frequency, 0),
c->Dim(stokes, 2),
4});

// sgn_brightness output is (nsrc, ntime)
// sgn_brightness output is (nsrc, ntime, nchan)
ShapeHandle sgn_brightness = c->MakeShape({
c->Dim(stokes, 0),
c->Dim(stokes, 1)});
c->Dim(stokes, 1),
c->Dim(stokes, 2),
});

// Set the output shape
c->set_output(0, bsqrt);
Expand All @@ -55,9 +45,6 @@ auto bsqrt_shape_function = [](InferenceContext* c) {

REGISTER_OP("BSqrt")
.Input("stokes: FT")
.Input("alpha: FT")
.Input("frequency: FT")
.Input("ref_freq: FT")
.Output("b_sqrt: CT")
.Output("sgn_brightness: int8")
.Attr("FT: {float, double} = DT_FLOAT")
Expand Down
107 changes: 48 additions & 59 deletions montblanc/impl/rime/tensorflow/rime_ops/b_sqrt_op_cpu.h
Original file line number Diff line number Diff line change
Expand Up @@ -35,14 +35,11 @@ class BSqrt<CPUDevice, FT, CT> : public tensorflow::OpKernel

// Sanity check the input tensors
const tf::Tensor & in_stokes = context->input(0);
const tf::Tensor & in_alpha = context->input(1);
const tf::Tensor & in_frequency = context->input(2);
const tf::Tensor & in_ref_freq = context->input(3);

// Extract problem dimensions
int nsrc = in_stokes.dim_size(0);
int ntime = in_stokes.dim_size(1);
int nchan = in_frequency.dim_size(0);
int nchan = in_stokes.dim_size(2);

// Reason about the shape of the b_sqrt tensor and
// create a pointer to it
Expand All @@ -56,20 +53,17 @@ class BSqrt<CPUDevice, FT, CT> : public tensorflow::OpKernel
if (b_sqrt_ptr->NumElements() == 0)
{ return; }

// Reason about shape of the invert tensor
// Reason about shape of the sgn_brightness tensor
// and create a pointer to it
tf::TensorShape invert_shape({nsrc, ntime});
tf::Tensor * invert_ptr = nullptr;
tf::TensorShape sgn_brightness_shape({nsrc, ntime, nchan});
tf::Tensor * sgn_brightness_ptr = nullptr;

OP_REQUIRES_OK(context, context->allocate_output(
1, invert_shape, &invert_ptr));
1, sgn_brightness_shape, &sgn_brightness_ptr));

auto stokes = in_stokes.tensor<FT, 3>();
auto alpha = in_alpha.tensor<FT, 2>();
auto frequency = in_frequency.tensor<FT, 1>();
auto ref_freq = in_ref_freq.tensor<FT, 1>();
auto stokes = in_stokes.tensor<FT, 4>();
auto b_sqrt = b_sqrt_ptr->tensor<CT, 4>();
auto sgn_brightness = invert_ptr->tensor<tf::int8, 2>();
auto sgn_brightness = sgn_brightness_ptr->tensor<tf::int8, 3>();

// Linear polarisation or circular polarisation
bool linear = (polarisation_type == "linear");
Expand All @@ -89,58 +83,53 @@ class BSqrt<CPUDevice, FT, CT> : public tensorflow::OpKernel
{
for(int time=0; time < ntime; ++time)
{
// Reference stokes parameters.
// Input order of stokes parameters differs
// depending on whether linear or circular polarisation
// is used, but the rest of the calculation is the same...
FT I = stokes(src, time, iI);
FT Q = stokes(src, time, iQ);
FT U = stokes(src, time, iU);
FT V = stokes(src, time, iV);

// sgn variable, used to indicate whether
// brightness matrix is negative, zero or positive
// and a valid Cholesky decomposition
FT IQ = I + Q;
FT sgn = (zero < IQ) - (IQ < zero);
// I *= sign;
// Q *= sign;
U *= sgn;
V *= sgn;
IQ *= sgn;

// Indicate negative, zero or positive brightness matrix
sgn_brightness(src, time) = sgn;

// Compute cholesky decomposition
CT L00 = std::sqrt(CT(IQ, zero));
// Store L00 as a divisor of L10
CT div = L00;

// Gracefully handle zero matrices
if(IQ == zero)
{
div = CT(one, zero);
IQ = one;
}

CT L10 = CT(U, -V) / div;
FT L11_real = (I*I - Q*Q - U*U - V*V)/IQ;
CT L11 = std::sqrt(CT(L11_real, zero));

for(int chan=0; chan < nchan; ++chan)
{
// Compute square root of spectral index
FT psqrt = std::pow(
frequency(chan)/ref_freq(src),
alpha(src, time)*0.5);
// Reference stokes parameters.
// Input order of stokes parameters differs
// depending on whether linear or circular polarisation
// is used, but the rest of the calculation is the same...
FT I = stokes(src, time, chan, iI);
FT Q = stokes(src, time, chan, iQ);
FT U = stokes(src, time, chan, iU);
FT V = stokes(src, time, chan, iV);

// sgn variable, used to indicate whether
// brightness matrix is negative, zero or positive
// and a valid Cholesky decomposition
FT IQ = I + Q;
FT sgn = (zero < IQ) - (IQ < zero);
// I *= sign;
// Q *= sign;
U *= sgn;
V *= sgn;
IQ *= sgn;

// Indicate negative, zero or positive brightness matrix
sgn_brightness(src, time, chan) = sgn;

// Compute cholesky decomposition
CT L00 = std::sqrt(CT(IQ, zero));
// Store L00 as a divisor of L10
CT div = L00;

// Gracefully handle zero matrices
if(IQ == zero)
{
div = CT(one, zero);
IQ = one;
}

CT L10 = CT(U, -V) / div;
FT L11_real = (I*I - Q*Q - U*U - V*V)/IQ;
CT L11 = std::sqrt(CT(L11_real, zero));

// Assign square root of the brightness matrix,
// computed via cholesky decomposition
b_sqrt(src, time, chan, XX) = L00*psqrt;
b_sqrt(src, time, chan, XX) = L00;
b_sqrt(src, time, chan, XY) = 0.0;
b_sqrt(src, time, chan, YX) = L10*psqrt;
b_sqrt(src, time, chan, YY) = L11*psqrt;
b_sqrt(src, time, chan, YX) = L10;
b_sqrt(src, time, chan, YY) = L11;
}
}
}
Expand Down
Loading