Skip to content

Commit

Permalink
various minor fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
daviesje committed Mar 22, 2024
1 parent 4561c5b commit 48ca369
Show file tree
Hide file tree
Showing 7 changed files with 331 additions and 190 deletions.
339 changes: 269 additions & 70 deletions docs/tutorials/halosampler.ipynb

Large diffs are not rendered by default.

6 changes: 3 additions & 3 deletions src/py21cmfast/inputs.py
Original file line number Diff line number Diff line change
Expand Up @@ -839,7 +839,7 @@ def USE_TS_FLUCT(self):
def PHOTON_CONS_TYPE(self):
"""Automatically setting PHOTON_CONS to False if USE_MINI_HALOS."""
if (self.USE_MINI_HALOS or self.USE_HALO_FIELD) and self._PHOTON_CONS_TYPE == 1:
logger.warn(
warnings.warn(
"USE_MINI_HALOS and USE_HALO_FIELD are not compatible with the redshift-based"
" photon conservation corrections (PHOTON_CONS_TYPE==1)! "
" Automatically setting PHOTON_CONS_TYPE to zero."
Expand All @@ -854,7 +854,7 @@ def PHOTON_CONS_TYPE(self):
def HALO_STOCHASTICITY(self):
"""Automatically setting HALO_STOCHASTICITY to False if not USE_HALO_FIELD."""
if not self.USE_HALO_FIELD and self._HALO_STOCHASTICITY:
logger.warning(
warnings.warn(
"HALO_STOCHASTICITY must be used with USE_HALO_FIELD"
"Turning off Stochastic Halos..."
)
Expand All @@ -866,7 +866,7 @@ def HALO_STOCHASTICITY(self):
def CELL_RECOMB(self):
"""Automatically setting CELL_RECOMB if USE_EXP_FILTER is active."""
if self.USE_EXP_FILTER:
logger.warning(
warnings.warn(
"CELL_RECOMB is automatically set to True if USE_EXP_FILTER is True."
)
return True
Expand Down
5 changes: 5 additions & 0 deletions src/py21cmfast/outputs.py
Original file line number Diff line number Diff line change
Expand Up @@ -1436,6 +1436,8 @@ def _write_particulars(self, fname):

f["node_redshifts"] = self.node_redshifts
f["distances"] = self.lightcone_distances
f["log10_mturnovers"] = self.log10_mturnovers
f["log10_mturnovers_mini"] = self.log10_mturnovers_mini

def make_checkpoint(self, fname, index: int, redshift: float):
"""Write updated lightcone data to file."""
Expand Down Expand Up @@ -1491,6 +1493,9 @@ def _read_particular(cls, fname):
kwargs["node_redshifts"] = fl["node_redshifts"][...]
kwargs["distances"] = fl["distances"][...]

kwargs["log10_mturnovers"] = fl["log10_mturnovers"][...]
kwargs["log10_mturnovers_mini"] = fl["log10_mturnovers_mini"][...]

return kwargs

def __eq__(self, other):
Expand Down
1 change: 1 addition & 0 deletions src/py21cmfast/src/21cmFAST.h
Original file line number Diff line number Diff line change
Expand Up @@ -329,6 +329,7 @@ double Nion_General(double z, double lnM_Min, double lnM_Max, double MassTurnove
double Fesc10, double Mlim_Fstar, double Mlim_Fesc);
double Nion_General_MINI(double z, double lnM_Min, double lnM_Max, double MassTurnover, double MassTurnover_upper, double Alpha_star,
double Alpha_esc, double Fstar7_MINI, double Fesc7_MINI, double Mlim_Fstar, double Mlim_Fesc);
double FgtrM_General(double z, double M);
double unconditional_mf(double growthf, double lnM, double z, int HMF);
double conditional_mf(double growthf, double lnM, double delta_cond, double sigma_cond, int HMF);
double atomic_cooling_threshold(float z);
10 changes: 5 additions & 5 deletions src/py21cmfast/src/SpinTemperatureBox.c
Original file line number Diff line number Diff line change
Expand Up @@ -620,12 +620,12 @@ int UpdateXraySourceBox(struct UserParams *user_params, struct CosmoParams *cosm
double sfr_avg = 0;
double sfr_avg_mini = 0;

#pragma omp parallel private(i,j,k) num_threads(user_params->N_THREADS)
#pragma omp parallel private(i,j,k) num_threads(user_params->N_THREADS) reduction(+:sfr_avg,sfr_avg_mini)
{
#pragma omp for
for (i=0; i<user_params->HII_DIM; i++){
for (j=0; j<user_params->HII_DIM; j++){
for (k=0; k<user_params->HII_DIM; k++){
for (k=0; k<HII_D_PARA; k++){
*((float *)unfiltered_box + HII_R_FFT_INDEX(i,j,k)) = halobox->halo_sfr[HII_R_INDEX(i,j,k)];
*((float *)unfiltered_box_mini + HII_R_FFT_INDEX(i,j,k)) = halobox->halo_sfr_mini[HII_R_INDEX(i,j,k)];
sfr_avg += halobox->halo_sfr[HII_R_INDEX(i,j,k)];
Expand Down Expand Up @@ -669,13 +669,13 @@ int UpdateXraySourceBox(struct UserParams *user_params, struct CosmoParams *cosm
dft_c2r_cube(user_params->USE_FFTW_WISDOM, user_params->HII_DIM, HII_D_PARA, user_params->N_THREADS, filtered_box_mini);

// copy over the values
#pragma omp parallel private(i,j,k) num_threads(user_params->N_THREADS) reduction(+:fsfr_avg)
#pragma omp parallel private(i,j,k) num_threads(user_params->N_THREADS) reduction(+:fsfr_avg,fsfr_avg_mini)
{
float curr,curr_mini;
#pragma omp for
for (i=0;i<user_params->HII_DIM; i++){
for (j=0;j<user_params->HII_DIM; j++){
for (k=0;k<user_params->HII_DIM; k++){
for (k=0;k<HII_D_PARA; k++){
curr = *((float *)filtered_box + HII_R_FFT_INDEX(i,j,k));
curr_mini = *((float *)filtered_box_mini + HII_R_FFT_INDEX(i,j,k));
// correct for aliasing in the filtering step
Expand All @@ -700,7 +700,7 @@ int UpdateXraySourceBox(struct UserParams *user_params, struct CosmoParams *cosm
if(R_ct == global_params.NUM_FILTER_STEPS_FOR_Ts - 1){
LOG_DEBUG("finished XraySourceBox");
}
LOG_SUPER_DEBUG("R = %8.3f | mean sfr = %10.3e (%10.3e MINI) Unfiltered %10.3e (%10.3e MINI) mean log10McritLW %.4e",
LOG_SUPER_DEBUG("R = %8.3f | mean filtered sfr = %10.3e (%10.3e MINI) unfiltered %10.3e (%10.3e MINI) mean log10McritLW %.4e",
R_outer,fsfr_avg,fsfr_avg_mini,sfr_avg,sfr_avg_mini,source_box->mean_log10_Mcrit_LW[R_ct]);

fftwf_free(filtered_box);
Expand Down
33 changes: 17 additions & 16 deletions src/py21cmfast/src/ps.c
Original file line number Diff line number Diff line change
Expand Up @@ -1394,10 +1394,13 @@ double MFIntegral_Approx(double lnM_lo, double lnM_hi, struct parameters_gsl_MF_
double sigma_hi_limit = EvaluateSigma(lnM_hi_limit);

//These nu use the CMF delta (subtracted the condition delta), but not the condition sigma
double nu_pivot1 = delta_arg / (sigma_pivot1*sigma_pivot1);
double nu_pivot2 = delta_arg / (sigma_pivot2*sigma_pivot2);
double nu_pivot1_umf = delta_arg / (sigma_pivot1*sigma_pivot1);
double nu_pivot2_umf = delta_arg / (sigma_pivot2*sigma_pivot2);
double nu_condition = delta_arg / (sigma_c*sigma_c);

double nu_pivot1 = delta_arg / (sigma_pivot1*sigma_pivot1 - sigma_c*sigma_c);
double nu_pivot2 = delta_arg / (sigma_pivot2*sigma_pivot2 - sigma_c*sigma_c);

//These nu subtract the condition sigma as in the CMF
double nu_lo_limit = delta_arg / (sigma_lo_limit*sigma_lo_limit - sigma_c*sigma_c);
double nu_hi_limit = delta_arg / (sigma_hi_limit*sigma_hi_limit - sigma_c*sigma_c);
Expand All @@ -1411,38 +1414,36 @@ double MFIntegral_Approx(double lnM_lo, double lnM_hi, struct parameters_gsl_MF_
if(fabs(type) == 4){
// re-written for further speedups
if (nu_hi_limit <= nu_pivot2){ //if both are below pivot2 don't bother adding and subtracting the high contribution
fcoll += (Fcollapprox(nu_lo_limit,beta3))*pow(nu_pivot2,-beta3);
fcoll -= (Fcollapprox(nu_hi_limit,beta3))*pow(nu_pivot2,-beta3);
fcoll += (Fcollapprox(nu_lo_limit,beta3))*pow(nu_pivot2_umf,-beta3);
fcoll -= (Fcollapprox(nu_hi_limit,beta3))*pow(nu_pivot2_umf,-beta3);
}
else {
fcoll -= (Fcollapprox(nu_hi_limit,beta2))*pow(nu_pivot1,-beta2);
fcoll -= (Fcollapprox(nu_hi_limit,beta2))*pow(nu_pivot1_umf,-beta2);
if (nu_lo_limit > nu_pivot2){
fcoll += (Fcollapprox(nu_lo_limit,beta2))*pow(nu_pivot1,-beta2);
fcoll += (Fcollapprox(nu_lo_limit,beta2))*pow(nu_pivot1_umf,-beta2);
}
else {
fcoll += (Fcollapprox(nu_pivot2,beta2))*pow(nu_pivot1,-beta2);
fcoll += (Fcollapprox(nu_lo_limit,beta3)-Fcollapprox(nu_pivot2,beta3) )*pow(nu_pivot2,-beta3);
fcoll += (Fcollapprox(nu_pivot2,beta2))*pow(nu_pivot1_umf,-beta2);
fcoll += (Fcollapprox(nu_lo_limit,beta3)-Fcollapprox(nu_pivot2,beta3) )*pow(nu_pivot2_umf,-beta3);
}
}
}
else{
if(nu_lo_limit >= nu_condition){ //fully in the flat part of sigma(nu), M^alpha is nu-independent.
return 1e-40;
}
else{ //we subtract the contribution from high nu, since the HMF is set to 0 if sigma2>sigma1
fcoll -= Fcollapprox(nu_condition,beta1)*pow(nu_pivot1,-beta1);
}

if(nu_lo_limit >= nu_pivot1){
fcoll += Fcollapprox(nu_lo_limit,beta1)*pow(nu_pivot1,-beta1);
fcoll += Fcollapprox(nu_lo_limit,beta1)*pow(nu_pivot1_umf,-beta1);
}
else{
fcoll += Fcollapprox(nu_pivot1,beta1)*pow(nu_pivot1,-beta1);
fcoll += Fcollapprox(nu_pivot1,beta1)*pow(nu_pivot1_umf,-beta1);
if (nu_lo_limit > nu_pivot2){
fcoll += (Fcollapprox(nu_lo_limit,beta2)-Fcollapprox(nu_pivot1,beta2))*pow(nu_pivot1,-beta2);
fcoll += (Fcollapprox(nu_lo_limit,beta2)-Fcollapprox(nu_pivot1,beta2))*pow(nu_pivot1_umf,-beta2);
}
else {
fcoll += (Fcollapprox(nu_pivot2,beta2)-Fcollapprox(nu_pivot1,beta2) )*pow(nu_pivot1,-beta2);
fcoll += (Fcollapprox(nu_lo_limit,beta3)-Fcollapprox(nu_pivot2,beta3) )*pow(nu_pivot2,-beta3);
fcoll += (Fcollapprox(nu_pivot2,beta2)-Fcollapprox(nu_pivot1,beta2) )*pow(nu_pivot1_umf,-beta2);
fcoll += (Fcollapprox(nu_lo_limit,beta3)-Fcollapprox(nu_pivot2,beta3) )*pow(nu_pivot2_umf,-beta3);
}
}
}
Expand Down
127 changes: 31 additions & 96 deletions src/py21cmfast/wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -1786,12 +1786,13 @@ def xray_source(

# call the box the initialize the memory, since I give some values before computing
box()

final_box_computed = False
for i in range(global_params.NUM_FILTER_STEPS_FOR_Ts):
R_inner = R_range[i - 1].to("Mpc").value if i > 0 else 0
R_outer = R_range[i].to("Mpc").value

if zpp_avg[i] >= z_max:
logger.debug(f"ignoring Radius {i} which is above Z_HEAT_MAX")
box.filtered_sfr[i, ...] = 0
continue

Expand Down Expand Up @@ -1824,12 +1825,21 @@ def xray_source(
R_ct=i,
hooks=hooks_in,
)
if i == global_params.NUM_FILTER_STEPS_FOR_Ts - 1:
final_box_computed = True

# HACK: sometimes we don't compute (if the first zpp > z_max or there are no halos)
# HACK: sometimes we don't compute on the last step
# (if the first zpp > z_max or there are no halos at max R)
# in which case the array is not marked as computed
for k, state in box._array_state.items():
if state.initialized:
state.computed_in_mem = True
if not final_box_computed:
# we need to pass the memory to C, mark it as computed and call the hooks
box()

for k, state in box._array_state.items():
if state.initialized:
state.computed_in_mem = True

box._call_hooks(hooks)

return box

Expand Down Expand Up @@ -2150,54 +2160,6 @@ def ionize_box(
"Automatic generation of halo boxes not yet implemented, \
Use run_coeval, run_lightcone or explicitly generate the box"
)
if not flag_options.FIXED_HALO_GRIDS:
# determine_halo_list will generate the descendant fields
# perturb and box only generate the current redshift
halo_field = determine_halo_list(
redshift=redshift,
init_boxes=init_boxes,
cosmo_params=cosmo_params,
user_params=user_params,
astro_params=astro_params,
flag_options=flag_options,
regenerate=regen_halos,
hooks=hooks,
direc=direc,
)
pt_halos = perturb_halo_list(
redshift=redshift,
init_boxes=init_boxes,
cosmo_params=cosmo_params,
user_params=user_params,
astro_params=astro_params,
flag_options=flag_options,
halo_field=halo_field,
regenerate=regenerate,
hooks=hooks,
direc=direc,
)
else:
pt_halos = PerturbHaloField(
redshift=0,
user_params=user_params,
cosmo_params=cosmo_params,
astro_params=astro_params,
flag_options=flag_options,
dummy=True,
)

halobox = halo_box(
redshift=redshift,
init_boxes=init_boxes,
astro_params=astro_params,
flag_options=flag_options,
cosmo_params=cosmo_params,
user_params=user_params,
regenerate=regenerate,
pt_halos=pt_halos,
perturbed_field=perturbed_field,
previous_ionize_box=previous_ionize_box,
)

# Set empty spin temp box if necessary.
if not flag_options.USE_TS_FLUCT:
Expand Down Expand Up @@ -2646,7 +2608,6 @@ def run_coeval(
init_box=None,
perturb=None,
use_interp_perturb_field=False,
halobox=None,
random_seed=None,
cleanup=True,
hooks=None,
Expand Down Expand Up @@ -2689,9 +2650,6 @@ def run_coeval(
to determine all spin temperature fields. If so, this field is interpolated in
the underlying C-code to the correct redshift. This is less accurate (and no more
efficient), but provides compatibility with older versions of 21cmFAST.
halobox : list of :class: `~HaloBox`, optional
If given, must be compatible with init_box. It will merely negate the necessity
of re-calculating the halo fields.
cleanup : bool, optional
A flag to specify whether the C routine cleans up its memory before returning.
Typically, if `spin_temperature` is called directly, you will want this to be
Expand Down Expand Up @@ -2730,19 +2688,6 @@ def run_coeval(
perturb = [perturb]
singleton = True

# Ensure perturbed halo field is a list of boxes, not just one.
if flag_options is None or halobox is None:
halobox = []

elif (
flag_options["USE_HALO_FIELD"]
if isinstance(flag_options, dict)
else flag_options.USE_HALO_FIELD
):
halobox = [halobox] if not hasattr(halobox, "__len__") else []
else:
halobox = []

(
random_seed,
user_params,
Expand Down Expand Up @@ -2787,13 +2732,6 @@ def run_coeval(
else:
redshift = [p.redshift for p in perturb]

if (
flag_options.USE_HALO_FIELD
and halobox
and any(p.redshift != z for p, z in zip(halobox, redshift))
):
raise ValueError("Input redshifts do not match the halo field redshifts")

kw = {
**{
"astro_params": astro_params,
Expand Down Expand Up @@ -2846,9 +2784,7 @@ def run_coeval(
pass

# get the halos (reverse redshift order)
generate_halobox = False
if flag_options.USE_HALO_FIELD and not halobox:
generate_halobox = True
if flag_options.USE_HALO_FIELD:
halos_desc = None

pt_halos = []
Expand Down Expand Up @@ -2913,22 +2849,18 @@ def run_coeval(
pf2.load_all()

if flag_options.USE_HALO_FIELD:
if generate_halobox:
ph2 = pt_halos[iz]
hb2 = halo_box(
redshift=z,
pt_halos=ph2,
perturbed_field=pf2,
previous_ionize_box=ib,
previous_spin_temp=st,
**kw,
)
# append the halo redshift array so we have all halo boxes [z,zmax]
z_halos += [z]
hbox_arr += [hb2]
# if haloboxes have been provided with correct redshifts...
else:
z_halos = redshifts
ph2 = pt_halos[iz]
hb2 = halo_box(
redshift=z,
pt_halos=ph2,
perturbed_field=pf2,
previous_ionize_box=ib,
previous_spin_temp=st,
**kw,
)
# append the halo redshift array so we have all halo boxes [z,zmax]
z_halos += [z]
hbox_arr += [hb2]

if flag_options.USE_TS_FLUCT:
if flag_options.USE_HALO_FIELD:
Expand Down Expand Up @@ -3675,6 +3607,9 @@ def run_lightcone(
"pt_halos": pth_files,
}

lightcone.log10_mturnovers = log10_mturnovers
lightcone.log10_mturnovers_mini = log10_mturnovers_mini

if coeval_callback is None:
return lightcone
else:
Expand Down

0 comments on commit 48ca369

Please sign in to comment.