diff --git a/s2fft/sampling/reindex.py b/s2fft/sampling/reindex.py index 97d3bb85..15e880d7 100644 --- a/s2fft/sampling/reindex.py +++ b/s2fft/sampling/reindex.py @@ -1,13 +1,14 @@ from functools import partial import jax.numpy as jnp +import numpy as np from jax import jit @partial(jit, static_argnums=(1)) def flm_1d_to_2d_fast(flm_1d: jnp.ndarray, L: int) -> jnp.ndarray: r""" - Convert from 1D indexed harmnonic coefficients to 2D indexed coefficients (JAX). + Convert from 1D indexed harmonic coefficients to 2D indexed coefficients (JAX). Note: Storage conventions for harmonic coefficients :math:`flm_{(\ell,m)}`, for @@ -35,13 +36,12 @@ def flm_1d_to_2d_fast(flm_1d: jnp.ndarray, L: int) -> jnp.ndarray: jnp.ndarray: 2D indexed harmonic coefficients. """ - flm_2d = jnp.zeros((L, 2 * L - 1), dtype=jnp.complex128) - els = jnp.arange(L) - offset = els**2 + els - for el in range(L): - m_array = jnp.arange(-el, el + 1) - flm_2d = flm_2d.at[el, L - 1 + m_array].set(flm_1d[offset[el] + m_array]) - return flm_2d + flm_2d = jnp.zeros((L, 2 * L - 1), dtype=flm_1d.dtype) + row_indices, col_indices = np.arange(L)[:, None], np.arange(2 * L - 1)[None, :] + el_indices, m_indices = np.where( + (row_indices <= col_indices)[::-1, :] & (row_indices <= col_indices)[::-1, ::-1] + ) + return flm_2d.at[el_indices, m_indices].set(flm_1d) @partial(jit, static_argnums=(1)) @@ -75,13 +75,11 @@ def flm_2d_to_1d_fast(flm_2d: jnp.ndarray, L: int) -> jnp.ndarray: jnp.ndarray: 1D indexed harmonic coefficients. """ - flm_1d = jnp.zeros(L**2, dtype=jnp.complex128) - els = jnp.arange(L) - offset = els**2 + els - for el in range(L): - m_array = jnp.arange(-el, el + 1) - flm_1d = flm_1d.at[offset[el] + m_array].set(flm_2d[el, L - 1 + m_array]) - return flm_1d + row_indices, col_indices = np.arange(L)[:, None], np.arange(2 * L - 1)[None, :] + el_indices, m_indices = np.where( + (row_indices <= col_indices)[::-1, :] & (row_indices <= col_indices)[::-1, ::-1] + ) + return flm_2d[el_indices, m_indices] @partial(jit, static_argnums=(1)) @@ -127,17 +125,13 @@ def flm_hp_to_2d_fast(flm_hp: jnp.ndarray, L: int) -> jnp.ndarray: jnp.ndarray: 2D indexed harmonic coefficients. """ - flm_2d = jnp.zeros((L, 2 * L - 1), dtype=jnp.complex128) - - for el in range(L): - flm_2d = flm_2d.at[el, L - 1].set(flm_hp[el]) - m_array = jnp.arange(1, el + 1) - hp_idx = m_array * (2 * L - 1 - m_array) // 2 + el - flm_2d = flm_2d.at[el, L - 1 + m_array].set(flm_hp[hp_idx]) - flm_2d = flm_2d.at[el, L - 1 - m_array].set( - (-1) ** m_array * jnp.conj(flm_hp[hp_idx]) - ) - + flm_2d = jnp.zeros((L, 2 * L - 1), dtype=flm_hp.dtype) + m_indices, el_indices = np.triu_indices(n=L, k=1, m=L) + np.array([[1], [0]]) + flm_2d = flm_2d.at[:L, L - 1].set(flm_hp[:L]) + flm_2d = flm_2d.at[el_indices, L - 1 + m_indices].set(flm_hp[L:]) + flm_2d = flm_2d.at[el_indices, L - 1 - m_indices].set( + (-1) ** m_indices * flm_hp[L:].conj() + ) return flm_2d @@ -185,11 +179,5 @@ def flm_2d_to_hp_fast(flm_2d: jnp.ndarray, L: int) -> jnp.ndarray: jnp.ndarray: HEALPix indexed harmonic coefficients. """ - flm_hp = jnp.zeros(int(L * (L + 1) / 2), dtype=jnp.complex128) - - for el in range(L): - m_array = jnp.arange(el + 1) - hp_idx = m_array * (2 * L - 1 - m_array) // 2 + el - flm_hp = flm_hp.at[hp_idx].set(flm_2d[el, L - 1 + m_array]) - - return flm_hp + m_indices, el_indices = np.triu_indices(n=L + 1, m=L) + return flm_2d[el_indices, L - 1 + m_indices]