Skip to content

Commit

Permalink
chore: code clean up (#117)
Browse files Browse the repository at this point in the history
  • Loading branch information
beckermr authored Sep 13, 2024
1 parent 047cf19 commit 746975b
Show file tree
Hide file tree
Showing 15 changed files with 169 additions and 436 deletions.
2 changes: 1 addition & 1 deletion LICENSE
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
Copyright (c) 2012-2022 by the GalSim developers team on GitHub
Copyright (c) 2012-2024 by the GalSim developers team on GitHub
https://github.com/GalSim-developers

Redistribution and use in source and binary forms, with or without
Expand Down
10 changes: 3 additions & 7 deletions jax_galsim/angle.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
# original source license:
#
# Copyright (c) 2013-2017 LSST Dark Energy Science Collaboration (DESC)
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
Expand Down Expand Up @@ -176,21 +178,15 @@ def __sub__(self, other):
return _Angle(self._rad - other._rad)

def __mul__(self, other):
# if other != float(other):
# raise TypeError("Cannot multiply Angle by %s of type %s" % (other, type(other)))
return _Angle(self._rad * other)

__rmul__ = __mul__

def __div__(self, other):
if isinstance(other, AngleUnit):
return self._rad / other.value
elif other == float(other):
return _Angle(self._rad / other)
else:
raise TypeError(
"Cannot divide Angle by %s of type %s" % (other, type(other))
)
return _Angle(self._rad / other)

__truediv__ = __div__

Expand Down
152 changes: 152 additions & 0 deletions jax_galsim/bessel.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,3 +111,155 @@ def kv(nu, x):
nu = 1.0 * nu
x = 1.0 * x
return _tfp_bessel_kve(nu, x) / jnp.exp(jnp.abs(x))


@jax.jit
def _R(z, num, denom):
return jnp.polyval(num, z) / jnp.polyval(denom, z)


@jax.jit
def _evaluate_rational(z, num, denom):
return _R(z, num[::-1], denom[::-1])


# jitted & vectorized version
_v_rational = jax.jit(jax.vmap(_evaluate_rational, in_axes=(0, None, None)))


@implements(
_galsim.bessel.j0,
lax_description="""\
The JAX-GalSim implementation of ``j0`` is a vectorized version of the Boost C++
algorith for the Bessel function of the first kind J0(x).""",
)
@jax.jit
def j0(x):
orig_shape = x.shape

x = jnp.atleast_1d(x)

P1 = jnp.array(
[
-4.1298668500990866786e11,
2.7282507878605942706e10,
-6.2140700423540120665e08,
6.6302997904833794242e06,
-3.6629814655107086448e04,
1.0344222815443188943e02,
-1.2117036164593528341e-01,
]
)
Q1 = jnp.array(
[
2.3883787996332290397e12,
2.6328198300859648632e10,
1.3985097372263433271e08,
4.5612696224219938200e05,
9.3614022392337710626e02,
1.0,
0.0,
]
)

P2 = jnp.array(
[
-1.8319397969392084011e03,
-1.2254078161378989535e04,
-7.2879702464464618998e03,
1.0341910641583726701e04,
1.1725046279757103576e04,
4.4176707025325087628e03,
7.4321196680624245801e02,
4.8591703355916499363e01,
]
)
Q2 = jnp.array(
[
-3.5783478026152301072e05,
2.4599102262586308984e05,
-8.4055062591169562211e04,
1.8680990008359188352e04,
-2.9458766545509337327e03,
3.3307310774649071172e02,
-2.5258076240801555057e01,
1.0,
]
)

PC = jnp.array(
[
2.2779090197304684302e04,
4.1345386639580765797e04,
2.1170523380864944322e04,
3.4806486443249270347e03,
1.5376201909008354296e02,
8.8961548424210455236e-01,
]
)
QC = jnp.array(
[
2.2779090197304684318e04,
4.1370412495510416640e04,
2.1215350561880115730e04,
3.5028735138235608207e03,
1.5711159858080893649e02,
1.0,
]
)

PS = jnp.array(
[
-8.9226600200800094098e01,
-1.8591953644342993800e02,
-1.1183429920482737611e02,
-2.2300261666214198472e01,
-1.2441026745835638459e00,
-8.8033303048680751817e-03,
]
)
QS = jnp.array(
[
5.7105024128512061905e03,
1.1951131543434613647e04,
7.2642780169211018836e03,
1.4887231232283756582e03,
9.0593769594993125859e01,
1.0,
]
)

x1 = 2.4048255576957727686e00
x2 = 5.5200781102863106496e00
x11 = 6.160e02
x12 = -1.42444230422723137837e-03
x21 = 1.4130e03
x22 = 5.46860286310649596604e-04
one_div_root_pi = 5.641895835477562869480794515607725858e-01

def t1(x): # x<=4
y = x * x
r = _v_rational(y, P1, Q1)
factor = (x + x1) * ((x - x11 / 256) - x12)
return factor * r

def t2(x): # x<=8
y = 1 - (x * x) / 64
r = _v_rational(y, P2, Q2)
factor = (x + x2) * ((x - x21 / 256) - x22)
return factor * r

def t3(x): # x>8
y = 8 / x
y2 = y * y
rc = _v_rational(y2, PC, QC)
rs = _v_rational(y2, PS, QS)
factor = one_div_root_pi / jnp.sqrt(x)
sx = jnp.sin(x)
cx = jnp.cos(x)
return factor * (rc * (cx + sx) - y * rs * (sx - cx))

x = jnp.abs(x)
return jnp.select(
[x == 0, x <= 4, x <= 8, x > 8], [1, t1(x), t2(x), t3(x)], default=x
).reshape(orig_shape)
1 change: 0 additions & 1 deletion jax_galsim/bounds.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@
"""


# The reason for avoid these tests is that they are not easy to do for jitted code.
@implements(_galsim.Bounds, lax_description=BOUNDS_LAX_DESCR)
@register_pytree_node_class
class Bounds(_galsim.Bounds):
Expand Down
2 changes: 1 addition & 1 deletion jax_galsim/box.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,7 @@ def tree_unflatten(cls, aux_data, children):
**aux_data,
)

@implements(_galsim.Box._shoot)
def _shoot(self, photons, rng):
ud = UniformDeviate(rng)

Expand All @@ -135,7 +136,6 @@ def __init__(self, scale, flux=1.0, gsparams=None):
@property
@implements(_galsim.Pixel.scale)
def scale(self):
"""The linear scale size of the `Pixel`."""
return self.width

def __repr__(self):
Expand Down
7 changes: 2 additions & 5 deletions jax_galsim/celestial.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
# original source license:
#
# Copyright (c) 2013-2017 LSST Dark Energy Science Collaboration (DESC)
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
Expand Down Expand Up @@ -71,8 +73,6 @@ def __init__(self, ra, dec=None):
raise TypeError("ra must be a galsim.Angle")
elif not isinstance(dec, Angle):
raise TypeError("dec must be a galsim.Angle")
# elif dec/degrees > 90. or dec/degrees < -90.:
# raise ValueError("dec must be between -90 deg and +90 deg.")
else:
# Normal case
self._ra = ra
Expand Down Expand Up @@ -130,9 +130,6 @@ def get_xyz(self):
)
def from_xyz(x, y, z):
norm = jnp.sqrt(x * x + y * y + z * z)
# JAX cannot check this condition
# if norm == 0.:
# raise ValueError("CelestialCoord for position (0,0,0) is undefined.")
ret = CelestialCoord.__new__(CelestialCoord)
ret._x = x / norm
ret._y = y / norm
Expand Down
5 changes: 1 addition & 4 deletions jax_galsim/convolve.py
Original file line number Diff line number Diff line change
Expand Up @@ -300,6 +300,7 @@ def _kValue(self, kpos):
def _drawReal(self, image, jac=None, offset=(0.0, 0.0), flux_scaling=1.0):
raise NotImplementedError("Real-space convolutions are not implemented")

@implements(_galsim.Convolution._shoot)
def _shoot(self, photons, rng):
self.obj_list[0]._shoot(photons, rng)
# It may be necessary to shuffle when convolving because we do not have a
Expand Down Expand Up @@ -342,10 +343,6 @@ def tree_unflatten(cls, aux_data, children):
lax_description="Does not support ChromaticDeconvolution",
)
def Deconvolve(obj, gsparams=None, propagate_gsparams=True):
# from .chromatic import ChromaticDeconvolution
# if isinstance(obj, ChromaticObject):
# return ChromaticDeconvolution(obj, gsparams=gsparams, propagate_gsparams=propagate_gsparams)
# elif isinstance(obj, GSObject):
if isinstance(obj, GSObject):
return Deconvolution(
obj, gsparams=gsparams, propagate_gsparams=propagate_gsparams
Expand Down
Loading

0 comments on commit 746975b

Please sign in to comment.