Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Mcmc test #2

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
90 changes: 84 additions & 6 deletions signal.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

from draco.util import tools
from draco.core.containers import FrequencyStackByPol, Powerspec1D, Powerspec2D
from draco.analysis.powerspec import get_1d_ps
from draco.analysis.powerspec import get_1d_ps, get_1d_ps_mcmc

from . import utils

Expand Down Expand Up @@ -584,6 +584,8 @@ def __init__(
nbins: int = 10,
logbins: bool = True,
):
# Add a new instance variable to store the power spectrum calculator
self._ps_calculator = None

if derivs is None:
derivs = {
Expand All @@ -595,6 +597,7 @@ def __init__(
self._aliases = aliases if aliases is not None else {}
self._nbins = nbins
self._logbins = logbins
self._mcmc_binning_cache = None
logger.debug(f"Using deriv modes: {self._derivs}")
logger.debug(f"Using aliases: {self._aliases}")
logger.debug(f"Using factor: {self._factor}")
Expand All @@ -603,6 +606,43 @@ def __init__(
f"{'log-spaced' if self._logbins else 'linearly-spaced'} bins"
)


def _cache_mcmc_binning(self):
"""Cache the binning calculations from get_1d_ps."""
# This will store all the pre-computed values needed for each polarization
cache = {}

for ipol in range(self._signal_mask.shape[0]):
kpp, kll = np.meshgrid(self._kperp, self._kpar)
k = np.sqrt(kpp**2 + kll**2)

# Apply signal window if present
if self._signal_mask is not None:
k = k[self._signal_mask[ipol]]

# Flatten arrays
k1D = k.flatten()

# Calculate bin edges
kmin = k1D[k1D > 0].min()
kmax = k1D.max()

if self._logbins:
kbins = np.logspace(np.log10(kmin), np.log10(kmax), self._nbins + 1)
else:
kbins = np.linspace(kmin, kmax, self._nbins + 1)

indices = np.digitize(k1D, kbins)

# Store everything needed for this polarization
cache[ipol] = {
'indices': indices,
'kbins': kbins
}

self._mcmc_binning_cache = cache
logger.debug("MCMC binning calculations cached")

@classmethod
def load_from_ps2Dfiles(
cls,
Expand Down Expand Up @@ -714,14 +754,13 @@ def _interpret_ps2Ds(
compterms = [k.split("-")[1] for k in ps2Ds.keys() if k.startswith("0")]

ps2D_modes = {}

# Get the first kpar, kperp axes as references
self._kpar = next(iter(ps2Ds.values())).kpar[:].copy()
self._kperp = next(iter(ps2Ds.values())).kperp[:].copy()
self._kpar.flags.writeable = False
self._kperp.flags.writeable = False

print(self., self.kpar)


def _check_load_ps2D(key):
# Validate the 2D power spectrum and extract the template and its variance
Expand Down Expand Up @@ -866,8 +905,11 @@ def _combine(vec):

return signal


"""

def signal_1D(self, *, omega: float, b_HI: float, **kwargs: float) -> np.ndarray:
"""Return the 1D power spectrum template, binned from 2D template.
Return the 1D power spectrum template, binned from 2D template.

Parameters
----------
Expand All @@ -883,15 +925,15 @@ def signal_1D(self, *, omega: float, b_HI: float, **kwargs: float) -> np.ndarray
-------
signal
Signal template for the given parameters. An array of [pol, k].
"""


_signal_2D = self.signal_2D(omega=omega, b_HI=b_HI, **kwargs)

signal_1D = np.zeros((_signal_2D.shape[0], self._nbins))

for ipol in range(_signal_2D.shape[0]):

_, signal_1D[ipol], _, _ = get_1d_ps(
signal_1D[ipol]= get_1d_ps_mcmc(
_signal_2D[ipol],
self._kperp,
self._kpar,
Expand All @@ -900,9 +942,45 @@ def signal_1D(self, *, omega: float, b_HI: float, **kwargs: float) -> np.ndarray
self._nbins + 1,
self._logbins,
)
#print(f"\nFinal signal_1D values old:\n{signal_1D}")

return signal_1D

"""

def signal_1D(self, *, omega: float, b_HI: float, **kwargs: float) -> np.ndarray:
"""Return the 1D power spectrum template with cached binning schemes."""

_signal_2D = self.signal_2D(omega=omega, b_HI=b_HI, **kwargs)
signal_1D = np.zeros((_signal_2D.shape[0], self._nbins))

# Cache binning calculations if not already cached
if self._mcmc_binning_cache is None:
self._cache_mcmc_binning()

# Use cached binning for each polarization
for ipol in range(_signal_2D.shape[0]):
cache = self._mcmc_binning_cache[ipol]
indices = cache['indices']

if self._signal_mask is not None:
p1D = _signal_2D[ipol][self._signal_mask[ipol]].flatten()
w1D = self._ps2D_weight[ipol][self._signal_mask[ipol]].flatten()
else:
p1D = _signal_2D[ipol].flatten()
w1D = self._ps2D_weight[ipol].flatten()

with np.errstate(divide="ignore", invalid="ignore"):
for i in np.arange(len(cache['kbins']) - 1) + 1:
w_b = w1D[indices == i]
p = np.nansum(w_b * p1D[indices == i]) / np.sum(w_b)
signal_1D[ipol, i-1] = p
#print(f"\nFinal signal_1D values old:\n{signal_1D}")

return signal_1D



def multiply_pre_noncomp(self, signal: np.ndarray, **kwargs) -> np.ndarray:
"""Override in subclass to multiply signal by function pre-non-components."""
return signal
Expand Down