diff --git a/signal.py b/signal.py index d275552..9a44611 100644 --- a/signal.py +++ b/signal.py @@ -9,7 +9,7 @@ from draco.util import tools from draco.core.containers import FrequencyStackByPol, Powerspec1D, Powerspec2D -from draco.analysis.powerspec import get_1d_ps +from draco.analysis.powerspec import get_1d_ps, get_1d_ps_mcmc from . import utils @@ -584,6 +584,8 @@ def __init__( nbins: int = 10, logbins: bool = True, ): + # Add a new instance variable to store the power spectrum calculator + self._ps_calculator = None if derivs is None: derivs = { @@ -595,6 +597,7 @@ def __init__( self._aliases = aliases if aliases is not None else {} self._nbins = nbins self._logbins = logbins + self._mcmc_binning_cache = None logger.debug(f"Using deriv modes: {self._derivs}") logger.debug(f"Using aliases: {self._aliases}") logger.debug(f"Using factor: {self._factor}") @@ -603,6 +606,43 @@ def __init__( f"{'log-spaced' if self._logbins else 'linearly-spaced'} bins" ) + + def _cache_mcmc_binning(self): + """Cache the binning calculations from get_1d_ps.""" + # This will store all the pre-computed values needed for each polarization + cache = {} + + for ipol in range(self._signal_mask.shape[0]): + kpp, kll = np.meshgrid(self._kperp, self._kpar) + k = np.sqrt(kpp**2 + kll**2) + + # Apply signal window if present + if self._signal_mask is not None: + k = k[self._signal_mask[ipol]] + + # Flatten arrays + k1D = k.flatten() + + # Calculate bin edges + kmin = k1D[k1D > 0].min() + kmax = k1D.max() + + if self._logbins: + kbins = np.logspace(np.log10(kmin), np.log10(kmax), self._nbins + 1) + else: + kbins = np.linspace(kmin, kmax, self._nbins + 1) + + indices = np.digitize(k1D, kbins) + + # Store everything needed for this polarization + cache[ipol] = { + 'indices': indices, + 'kbins': kbins + } + + self._mcmc_binning_cache = cache + logger.debug("MCMC binning calculations cached") + @classmethod def load_from_ps2Dfiles( cls, @@ -714,14 +754,13 @@ def _interpret_ps2Ds( compterms = [k.split("-")[1] for k in ps2Ds.keys() if k.startswith("0")] ps2D_modes = {} - # Get the first kpar, kperp axes as references self._kpar = next(iter(ps2Ds.values())).kpar[:].copy() self._kperp = next(iter(ps2Ds.values())).kperp[:].copy() self._kpar.flags.writeable = False self._kperp.flags.writeable = False - print(self., self.kpar) + def _check_load_ps2D(key): # Validate the 2D power spectrum and extract the template and its variance @@ -866,8 +905,11 @@ def _combine(vec): return signal + + """ + def signal_1D(self, *, omega: float, b_HI: float, **kwargs: float) -> np.ndarray: - """Return the 1D power spectrum template, binned from 2D template. + Return the 1D power spectrum template, binned from 2D template. Parameters ---------- @@ -883,7 +925,7 @@ def signal_1D(self, *, omega: float, b_HI: float, **kwargs: float) -> np.ndarray ------- signal Signal template for the given parameters. An array of [pol, k]. - """ + _signal_2D = self.signal_2D(omega=omega, b_HI=b_HI, **kwargs) @@ -891,7 +933,7 @@ def signal_1D(self, *, omega: float, b_HI: float, **kwargs: float) -> np.ndarray for ipol in range(_signal_2D.shape[0]): - _, signal_1D[ipol], _, _ = get_1d_ps( + signal_1D[ipol]= get_1d_ps_mcmc( _signal_2D[ipol], self._kperp, self._kpar, @@ -900,9 +942,45 @@ def signal_1D(self, *, omega: float, b_HI: float, **kwargs: float) -> np.ndarray self._nbins + 1, self._logbins, ) + #print(f"\nFinal signal_1D values old:\n{signal_1D}") return signal_1D + """ + + def signal_1D(self, *, omega: float, b_HI: float, **kwargs: float) -> np.ndarray: + """Return the 1D power spectrum template with cached binning schemes.""" + + _signal_2D = self.signal_2D(omega=omega, b_HI=b_HI, **kwargs) + signal_1D = np.zeros((_signal_2D.shape[0], self._nbins)) + + # Cache binning calculations if not already cached + if self._mcmc_binning_cache is None: + self._cache_mcmc_binning() + + # Use cached binning for each polarization + for ipol in range(_signal_2D.shape[0]): + cache = self._mcmc_binning_cache[ipol] + indices = cache['indices'] + + if self._signal_mask is not None: + p1D = _signal_2D[ipol][self._signal_mask[ipol]].flatten() + w1D = self._ps2D_weight[ipol][self._signal_mask[ipol]].flatten() + else: + p1D = _signal_2D[ipol].flatten() + w1D = self._ps2D_weight[ipol].flatten() + + with np.errstate(divide="ignore", invalid="ignore"): + for i in np.arange(len(cache['kbins']) - 1) + 1: + w_b = w1D[indices == i] + p = np.nansum(w_b * p1D[indices == i]) / np.sum(w_b) + signal_1D[ipol, i-1] = p + #print(f"\nFinal signal_1D values old:\n{signal_1D}") + + return signal_1D + + + def multiply_pre_noncomp(self, signal: np.ndarray, **kwargs) -> np.ndarray: """Override in subclass to multiply signal by function pre-non-components.""" return signal