Skip to content

Commit

Permalink
Merge pull request #241 from astro-informatics/mmg/iterative-refinement
Browse files Browse the repository at this point in the history
Iterative refinement support for JAX and NumPy forward (spherical) transform implementations
  • Loading branch information
matt-graham authored Dec 18, 2024
2 parents 8f6e4d5 + 23fb7af commit a9e7c0c
Show file tree
Hide file tree
Showing 11 changed files with 273 additions and 76 deletions.
38 changes: 27 additions & 11 deletions s2fft/base_transforms/spherical.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
from functools import partial
from warnings import warn

import numpy as np

from s2fft import recursions
from s2fft.sampling import s2_samples as samples
from s2fft.utils import healpix_ffts as hp
from s2fft.utils import quadrature, resampling
from s2fft.utils import iterative_refinement, quadrature, resampling


def inverse(
Expand Down Expand Up @@ -138,6 +139,7 @@ def forward(
nside: int = None,
reality: bool = False,
L_lower: int = 0,
iter: int = 0,
) -> np.ndarray:
r"""
Compute forward spherical harmonic transform.
Expand All @@ -164,20 +166,34 @@ def forward(
L_lower (int, optional): Harmonic lower-bound. Transform will only be computed
for :math:`\texttt{L_lower} \leq \ell < \texttt{L}`. Defaults to 0.
iter (int, optional): Number of iterative refinement iterations to use to
improve accuracy of forward transform (as an inverse of inverse transform).
Primarily of use with HEALPix sampling for which there is not a sampling
theorem, and round-tripping through the forward and inverse transforms will
introduce an error.
Returns:
np.ndarray: Spherical harmonic coefficients.
"""
return _forward(
f,
L,
spin,
sampling,
nside=nside,
method="sov_fft_vectorized",
reality=reality,
L_lower=L_lower,
)
common_kwargs = {
"L": L,
"spin": spin,
"sampling": sampling,
"nside": nside,
"method": "sov_fft_vectorized",
"reality": reality,
"L_lower": L_lower,
}
if iter == 0:
return _forward(f, **common_kwargs)
else:
return iterative_refinement.forward_with_iterative_refinement(
f,
n_iter=iter,
forward_function=partial(_forward, **common_kwargs),
backward_function=partial(_inverse, **common_kwargs),
)


def _forward(
Expand Down
101 changes: 81 additions & 20 deletions s2fft/precompute_transforms/spherical.py
Original file line number Diff line number Diff line change
@@ -1,25 +1,32 @@
from functools import partial
from typing import Optional
from warnings import warn

import jax.numpy as jnp
import numpy as np
import torch
from jax import jit

from s2fft.precompute_transforms import construct
from s2fft.sampling import s2_samples as samples
from s2fft.utils import healpix_ffts as hp
from s2fft.utils import resampling, resampling_jax, resampling_torch
from s2fft.utils import (
iterative_refinement,
resampling,
resampling_jax,
resampling_torch,
)


def inverse(
flm: np.ndarray,
L: int,
spin: int = 0,
kernel: np.ndarray = None,
kernel: Optional[np.ndarray] = None,
sampling: str = "mw",
reality: bool = False,
method: str = "jax",
nside: int = None,
nside: Optional[int] = None,
) -> np.ndarray:
r"""
Compute the inverse spherical harmonic transform via precompute.
Expand Down Expand Up @@ -55,21 +62,28 @@ def inverse(
np.ndarray: Pixel-space coefficients with shape.
"""
if method not in _inverse_functions:
raise ValueError(f"Method {method} not recognised.")
if reality and spin != 0:
reality = False
warn(
"Reality acceleration only supports spin 0 fields. "
+ "Defering to complex transform.",
stacklevel=2,
)
if method == "numpy":
return inverse_transform(flm, kernel, L, sampling, reality, spin, nside)
elif method == "jax":
return inverse_transform_jax(flm, kernel, L, sampling, reality, spin, nside)
elif method == "torch":
return inverse_transform_torch(flm, kernel, L, sampling, reality, spin, nside)
else:
raise ValueError(f"Method {method} not recognised.")
common_kwargs = {
"L": L,
"sampling": sampling,
"reality": reality,
"spin": spin,
"nside": nside,
}
kernel = (
_kernel_functions[method](forward=False, **common_kwargs)
if kernel is None
else kernel
)
return _inverse_functions[method](flm, kernel, **common_kwargs)


def inverse_transform(
Expand Down Expand Up @@ -290,11 +304,12 @@ def forward(
f: np.ndarray,
L: int,
spin: int = 0,
kernel: np.ndarray = None,
kernel: Optional[np.ndarray] = None,
sampling: str = "mw",
reality: bool = False,
method: str = "jax",
nside: int = None,
nside: Optional[int] = None,
iter: int = 0,
) -> np.ndarray:
r"""
Compute the forward spherical harmonic transform via precompute.
Expand All @@ -321,6 +336,12 @@ def forward(
nside (int): HEALPix Nside resolution parameter. Only required
if sampling="healpix".
iter (int, optional): Number of iterative refinement iterations to use to
improve accuracy of forward transform (as an inverse of inverse transform).
Primarily of use with HEALPix sampling for which there is not a sampling
theorem, and round-tripping through the forward and inverse transforms will
introduce an error.
Raises:
ValueError: Transform method not recognised.
Expand All @@ -330,21 +351,41 @@ def forward(
np.ndarray: Spherical harmonic coefficients.
"""
if method not in _forward_functions:
raise ValueError(f"Method {method} not recognised.")
if reality and spin != 0:
reality = False
warn(
"Reality acceleration only supports spin 0 fields. "
+ "Defering to complex transform.",
stacklevel=2,
)
if method == "numpy":
return forward_transform(f, kernel, L, sampling, reality, spin, nside)
elif method == "jax":
return forward_transform_jax(f, kernel, L, sampling, reality, spin, nside)
elif method == "torch":
return forward_transform_torch(f, kernel, L, sampling, reality, spin, nside)
common_kwargs = {
"L": L,
"sampling": sampling,
"reality": reality,
"spin": spin,
"nside": nside,
}
kernel = (
_kernel_functions[method](forward=True, **common_kwargs)
if kernel is None
else kernel
)
if iter == 0:
return _forward_functions[method](f, kernel, **common_kwargs)
else:
raise ValueError(f"Method {method} not recognised.")
inverse_kernel = _kernel_functions[method](forward=False, **common_kwargs)
return iterative_refinement.forward_with_iterative_refinement(
f=f,
n_iter=iter,
forward_function=partial(
_forward_functions[method], kernel=kernel, **common_kwargs
),
backward_function=partial(
_inverse_functions[method], kernel=inverse_kernel, **common_kwargs
),
)


def forward_transform(
Expand Down Expand Up @@ -567,3 +608,23 @@ def forward_transform_torch(
)

return flm * (-1) ** spin


_inverse_functions = {
"numpy": inverse_transform,
"jax": inverse_transform_jax,
"torch": inverse_transform_torch,
}


_forward_functions = {
"numpy": forward_transform,
"jax": forward_transform_jax,
"torch": forward_transform_torch,
}

_kernel_functions = {
"numpy": partial(construct.spin_spherical_kernel, using_torch=False),
"jax": construct.spin_spherical_kernel_jax,
"torch": partial(construct.spin_spherical_kernel, using_torch=True),
}
15 changes: 9 additions & 6 deletions s2fft/transforms/c_backend_spherical.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from functools import partial

import healpy
import jax.numpy as jnp
import numpy as np
Expand All @@ -8,7 +10,7 @@
from jax.interpreters import ad

from s2fft.sampling import reindex
from s2fft.utils import quadrature_jax
from s2fft.utils import iterative_refinement, quadrature_jax


@custom_vjp
Expand Down Expand Up @@ -427,11 +429,12 @@ def healpy_forward(f: jnp.ndarray, L: int, nside: int, iter: int = 3) -> jnp.nda
Astrophysical Journal 622.2 (2005): 759
"""
flm = healpy_map2alm(f, L, nside)
for _ in range(iter):
f_recov = healpy_alm2map(flm, L, nside)
f_error = f - f_recov
flm += healpy_map2alm(f_error, L, nside)
flm = iterative_refinement.forward_with_iterative_refinement(
f=f,
n_iter=iter,
forward_function=partial(healpy_map2alm, L=L, nside=nside),
backward_function=partial(healpy_alm2map, L=L, nside=nside),
)
return reindex.flm_hp_to_2d_fast(flm, L)


Expand Down
6 changes: 6 additions & 0 deletions s2fft/transforms/otf_recursions.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,9 @@ def inverse_latitudinal_step(
precomps = generate_precomputes(L, -mm, sampling, nside, L_lower)
lrenorm, vsign, cpi, cp2, indices = precomps

# Create copy to prevent in-place updates propagating to caller
lrenorm = lrenorm.copy()

for i in range(2):
if not (reality and i == 0):
m_offset = 1 if sampling in ["mwss", "healpix"] and i == 0 else 0
Expand Down Expand Up @@ -490,6 +493,9 @@ def forward_latitudinal_step(
precomps = generate_precomputes(L, -mm, sampling, nside, True, L_lower)
lrenorm, vsign, cpi, cp2, indices = precomps

# Create copy to prevent in-place updates propagating to caller
lrenorm = lrenorm.copy()

for i in range(2):
if not (reality and i == 0):
m_offset = 1 if sampling in ["mwss", "healpix"] and i == 0 else 0
Expand Down
Loading

0 comments on commit a9e7c0c

Please sign in to comment.