Skip to content

Commit

Permalink
Merge pull request #245 from astro-informatics/mmg/reindexing-loops
Browse files Browse the repository at this point in the history
Avoid loops in `s2fft.sampling.reindex` functions to reduce compile and run times
  • Loading branch information
matt-graham authored Nov 26, 2024
2 parents 909e6f1 + d2c8f54 commit 5210481
Showing 1 changed file with 22 additions and 34 deletions.
56 changes: 22 additions & 34 deletions s2fft/sampling/reindex.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
from functools import partial

import jax.numpy as jnp
import numpy as np
from jax import jit


@partial(jit, static_argnums=(1))
def flm_1d_to_2d_fast(flm_1d: jnp.ndarray, L: int) -> jnp.ndarray:
r"""
Convert from 1D indexed harmnonic coefficients to 2D indexed coefficients (JAX).
Convert from 1D indexed harmonic coefficients to 2D indexed coefficients (JAX).
Note:
Storage conventions for harmonic coefficients :math:`flm_{(\ell,m)}`, for
Expand Down Expand Up @@ -35,13 +36,12 @@ def flm_1d_to_2d_fast(flm_1d: jnp.ndarray, L: int) -> jnp.ndarray:
jnp.ndarray: 2D indexed harmonic coefficients.
"""
flm_2d = jnp.zeros((L, 2 * L - 1), dtype=jnp.complex128)
els = jnp.arange(L)
offset = els**2 + els
for el in range(L):
m_array = jnp.arange(-el, el + 1)
flm_2d = flm_2d.at[el, L - 1 + m_array].set(flm_1d[offset[el] + m_array])
return flm_2d
flm_2d = jnp.zeros((L, 2 * L - 1), dtype=flm_1d.dtype)
row_indices, col_indices = np.arange(L)[:, None], np.arange(2 * L - 1)[None, :]
el_indices, m_indices = np.where(
(row_indices <= col_indices)[::-1, :] & (row_indices <= col_indices)[::-1, ::-1]
)
return flm_2d.at[el_indices, m_indices].set(flm_1d)


@partial(jit, static_argnums=(1))
Expand Down Expand Up @@ -75,13 +75,11 @@ def flm_2d_to_1d_fast(flm_2d: jnp.ndarray, L: int) -> jnp.ndarray:
jnp.ndarray: 1D indexed harmonic coefficients.
"""
flm_1d = jnp.zeros(L**2, dtype=jnp.complex128)
els = jnp.arange(L)
offset = els**2 + els
for el in range(L):
m_array = jnp.arange(-el, el + 1)
flm_1d = flm_1d.at[offset[el] + m_array].set(flm_2d[el, L - 1 + m_array])
return flm_1d
row_indices, col_indices = np.arange(L)[:, None], np.arange(2 * L - 1)[None, :]
el_indices, m_indices = np.where(
(row_indices <= col_indices)[::-1, :] & (row_indices <= col_indices)[::-1, ::-1]
)
return flm_2d[el_indices, m_indices]


@partial(jit, static_argnums=(1))
Expand Down Expand Up @@ -127,17 +125,13 @@ def flm_hp_to_2d_fast(flm_hp: jnp.ndarray, L: int) -> jnp.ndarray:
jnp.ndarray: 2D indexed harmonic coefficients.
"""
flm_2d = jnp.zeros((L, 2 * L - 1), dtype=jnp.complex128)

for el in range(L):
flm_2d = flm_2d.at[el, L - 1].set(flm_hp[el])
m_array = jnp.arange(1, el + 1)
hp_idx = m_array * (2 * L - 1 - m_array) // 2 + el
flm_2d = flm_2d.at[el, L - 1 + m_array].set(flm_hp[hp_idx])
flm_2d = flm_2d.at[el, L - 1 - m_array].set(
(-1) ** m_array * jnp.conj(flm_hp[hp_idx])
)

flm_2d = jnp.zeros((L, 2 * L - 1), dtype=flm_hp.dtype)
m_indices, el_indices = np.triu_indices(n=L, k=1, m=L) + np.array([[1], [0]])
flm_2d = flm_2d.at[:L, L - 1].set(flm_hp[:L])
flm_2d = flm_2d.at[el_indices, L - 1 + m_indices].set(flm_hp[L:])
flm_2d = flm_2d.at[el_indices, L - 1 - m_indices].set(
(-1) ** m_indices * flm_hp[L:].conj()
)
return flm_2d


Expand Down Expand Up @@ -185,11 +179,5 @@ def flm_2d_to_hp_fast(flm_2d: jnp.ndarray, L: int) -> jnp.ndarray:
jnp.ndarray: HEALPix indexed harmonic coefficients.
"""
flm_hp = jnp.zeros(int(L * (L + 1) / 2), dtype=jnp.complex128)

for el in range(L):
m_array = jnp.arange(el + 1)
hp_idx = m_array * (2 * L - 1 - m_array) // 2 + el
flm_hp = flm_hp.at[hp_idx].set(flm_2d[el, L - 1 + m_array])

return flm_hp
m_indices, el_indices = np.triu_indices(n=L + 1, m=L)
return flm_2d[el_indices, L - 1 + m_indices]

0 comments on commit 5210481

Please sign in to comment.