Skip to content

Commit

Permalink
add Spergel in the changelog, 2) use _shootxnorm Galsim variable inst…
Browse files Browse the repository at this point in the history
…ead of _Nnu & polish
  • Loading branch information
jecampagne committed Jan 15, 2024
1 parent 89fb1d8 commit 15508ea
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 21 deletions.
2 changes: 1 addition & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
* `makePhot`
* `drawPhot`
* Added implementation of simple light profiles:
* `Gaussian`, `Exponential`, `Pixel`, `Box`, `Moffat`, `DeltaFunction`
* `Gaussian`, `Exponential`, `Pixel`, `Box`, `Moffat`, `Spergel`, `DeltaFunction`
* Added implementation of simple WCS:
* `PixelScale`, `OffsetWCS`, `JacobianWCS`, `AffineTransform`, `ShearWCS`, `OffsetShearWCS`, `GSFitsWCS`, `FitsWCS`, `TanWCS`
* Added automated suite of tests using the reference GalSim and LSSTDESC-Coord test suites
Expand Down
30 changes: 10 additions & 20 deletions jax_galsim/spergel.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,23 +148,21 @@ def fnu(z, nu):

@jax.jit
def fz_nup1(z, nu):
"""Return z^(nu+1) K_{nu+1}(z)
Spergel index nu in [-0.85, 4.]
"""
"""z^(nu+1) K_{nu+1}(z)"""
return jnp.where(
z <= 1.0e-10, fsmallz_nup1(z, nu), jnp.power(z, nu + 1.0) * _Knu(nu + 1.0, z)
)


@jax.jit
def fluxfractionFunc(z, nu, alpha):
"""Return 1 - z^(nu+1) K_{nu+1}(z) / (2^nu Gamma(nu+1)) - alpha"""
"""1 - z^(nu+1) K_{nu+1}(z) / (2^nu Gamma(nu+1)) - alpha"""
return 1.0 - fz_nup1(z, nu) / (jnp.power(2.0, nu) * _gammap1(nu)) - alpha


@jax.jit
def reducedfluxfractionFunc(z, nu, norm):
"""Return (1 - z^(nu+1) K_{nu+1}(z) / (2^nu Gamma(nu+1)))/norm"""
"""(1 - z^(nu+1) K_{nu+1}(z) / (2^nu Gamma(nu+1)))/norm"""
return fluxfractionFunc(z, nu, alpha=0.0) / norm


Expand Down Expand Up @@ -261,11 +259,6 @@ def nu(self):
"""The Spergel index, nu"""
return self._params["nu"]

@property
def _Nnu(self):
"""2^nu Gamma(nu+1)"""
return jnp.power(2.0, self.nu) * _gammap1(self.nu)

@property
def scale_radius(self):
"""The scale radius of this `Spergel` profile."""
Expand Down Expand Up @@ -406,9 +399,10 @@ def _shoot_pos_cdf(self):
cdf = preducedfluxfractionFunc(z_cdf)
return z_cdf, cdf

def _shoot_pos(self, u, dum1=None, dum2=None):
def _shoot_pos(self, u):
# shoot r in case of nu>0
z_cdf, cdf = self._shoot_pos_cdf
z = jnp.interp(u, cdf, z_cdf) # inversion of the CDF
z = jnp.interp(u, cdf, z_cdf) # linear inversion of the CDF
r = z * self._r0
return r

Expand All @@ -434,7 +428,7 @@ def _shoot_neg_cdf(self):
shoot_rmin = calculateFluxRadius(flux_target, self.nu)
knur = fz_nu(shoot_rmin, self.nu)

corrFact = 1.0 / (2 * jnp.pi * self._Nnu) # this factor correct
corrFact = self._shootxnorm # this is the correct normalisation
b = knur - flux_target / (jnp.pi * shoot_rmin * shoot_rmin * corrFact)
b = 3.0 * b / shoot_rmin
a = knur - shoot_rmin * b
Expand All @@ -458,22 +452,18 @@ def cumulflux(z, a, b, zmin, nu, norm=1.0):
cdf = preducedfluxfractionFunc(z_cdf)
return z_cdf, cdf

def _shoot_neg(self, u, dum=None):
# for nu<=0
# computes as in Galsim even if suspicious code
def _shoot_neg(self, u):
# shoot r in case of nu<=0
z_cdf, cdf = self._shoot_neg_cdf
z = jnp.interp(u, cdf, z_cdf) # inversion of the CDF
z = jnp.interp(u, cdf, z_cdf) # linear inversion of the CDF
r = z * self._r0
return r

@_wraps(_galsim.Spergel._shoot)
def _shoot(self, photons, rng):
ud = UniformDeviate(rng)

u = ud.generate(photons.x)

r = jax.lax.select(self.nu > 0, self._shoot_pos(u), self._shoot_neg(u))

ang = ud.generate(photons.x) * 2.0 * jnp.pi
photons.x = r * jnp.cos(ang)
photons.y = r * jnp.sin(ang)
Expand Down

0 comments on commit 15508ea

Please sign in to comment.