Skip to content

Commit

Permalink
Merge pull request #1 from echaussidon/blinding_edmond
Browse files Browse the repository at this point in the history
Blinding edmond
  • Loading branch information
echaussidon authored Apr 19, 2023
2 parents a0ad0c2 + d748e2f commit fa809e9
Showing 1 changed file with 126 additions and 98 deletions.
224 changes: 126 additions & 98 deletions scripts/main/apply_blinding_main_fromfile_fcomp.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

import sys
import os
import logging
import shutil
import unittest
from datetime import datetime
Expand Down Expand Up @@ -59,6 +60,18 @@
sys.exit('NERSC_HOST not known (code only works on NERSC), not proceeding')


# MPI setup. The communicator is later handed to the reconstruction step and
# used to broadcast the blinding parameters, so the MPI build of pyrecon is
# mandatory for this script.
try:
    mpicomm = pyrecon.mpi.COMM_WORLD  # MPI version
except AttributeError:
    # Non-MPI build of pyrecon: there is no communicator to work with, so stop
    # here with an explicit message instead of failing later on `mpicomm.rank`.
    sys.exit('The following script need to be run with the MPI version of pyrecon. Please use module swap pyrecon:mpi')

# Rank 0 is the "root" process: it does the printing and the file writing.
root = mpicomm.rank == 0


# Silence the jax CPU-fallback warning (jax is pulled in through cosmoprimo).
# NOTE: logging.Filter(<text>) filters on the *logger name*, not on the message,
# so it would drop every record of this logger. A callable filter that looks at
# the message only removes the one warning we want to hide.
def _drop_jax_cpu_fallback_warning(record):
    """Return False (drop the record) only for the jax 'No GPU/TPU found' warning."""
    return "No GPU/TPU found, falling back to CPU." not in record.getMessage()


logging.getLogger("jax._src.lib.xla_bridge").addFilter(_drop_jax_cpu_fallback_warning)


parser = argparse.ArgumentParser()
parser.add_argument("--type", help="tracer type to be selected")
parser.add_argument("--basedir_in", help="base directory for input, default is location for official catalogs",default='/global/cfs/cdirs/desi/survey/catalogs/')
Expand All @@ -80,110 +93,72 @@
parser.add_argument("--maxr", help="maximum for random files, default is 1",default=1,type=int) #use 2 for abacus mocks
parser.add_argument("--dorecon",help="if y, run the recon needed for RSD blinding",default='n')
parser.add_argument("--rsdblind",help="if y, do the RSD blinding shift",default='n')
parser.add_argument("--fnlblind",help="if y, do the fnl blinding",default='n')

# type=float so that a value given on the command line is parsed as a number:
# args.fiducial_f is used in arithmetic further down (f_shift / dfper), which
# would fail on the raw string argparse returns by default.
parser.add_argument("--fiducial_f",help="fiducial value for f",default=0.8,type=float)

parser.add_argument("--visnz",help="whether to look at the original, blinded, and weighted n(z)",default='n')


#parser.add_argument("--fix_monopole",help="whether to choose f such that the amplitude of the monopole is fixed",default='y')


args = parser.parse_args()

try:
mpicomm = pyrecon.mpi.COMM_WORLD # MPI version
except AttributeError:
mpicomm = None # non-MPI version
root = mpicomm is None or mpicomm.rank == 0


if root:
print(args)
if root: print(args)

type = args.type
version = args.version
specrel = args.verspec

notqso = ''
if args.notqso == 'y':
notqso = 'notqso'

if root:
print('blinding catalogs for tracer type '+type+notqso)


if type[:3] == 'BGS' or type == 'bright' or type == 'MWS_ANY':
prog = 'BRIGHT'

else:
prog = 'DARK'
notqso = 'notqso' if (args.notqso == 'y') else ''
if root: print('blinding catalogs for tracer type ' + type + notqso)

prog = 'BRIGHT' if (type[:3] == 'BGS' or type == 'bright' or type == 'MWS_ANY') else 'DARK'
progl = prog.lower()

mainp = main(args.type)
zmin = mainp.zmin
zmax = mainp.zmax
tsnrcol = mainp.tsnrcol


#share basedir location '/global/cfs/cdirs/desi/survey/catalogs'
if 'mock' not in args.verspec:
maindir = args.basedir_in +'/'+args.survey+'/LSS/'

ldirspec = maindir+specrel+'/'

dirin = ldirspec+'LSScats/'+version+'/'
LSSdir = ldirspec+'LSScats/'
maindir = args.basedir_in + '/' + args.survey + '/LSS/'
ldirspec = maindir + specrel + '/'
dirin = ldirspec + 'LSScats/' + version + '/'
tsnrcut = mainp.tsnrcut
dchi2 = mainp.dchi2
randens = 2500.
nzmd = 'data'
elif 'Y1/mock' in args.verspec: #e.g., use 'mocks/FirstGenMocks/AbacusSummit/Y1/mock1' to get the 1st mock with fiberassign
dirin = args.basedir_in +'/'+args.survey+'/'+args.verspec+'/LSScats/'+version+'/'
LSSdir = args.basedir_in +'/'+args.survey+'/'+args.verspec+'/LSScats/'
dirin = args.basedir_in + '/' + args.survey + '/' + args.verspec + '/LSScats/' + version + '/'
dchi2=None
tsnrcut=0
randens = 10460.
nzmd = 'mock'

else:
sys.exit('verspec '+args.verspec+' not supported')


dirout = args.basedir_out+'/LSScats/'+version+'/blinded/'

sys.exit('verspec ' + args.verspec + ' not supported')

dirout = args.basedir_out + '/LSScats/' + version + '/blinded/'

if root and (not os.path.exists(dirout)):
os.makedirs(dirout)
print('made '+dirout)

tp2z = {'LRG':0.8,'ELG':1.1,'QSO':1.6}
tp2bias = {'LRG':2.,'ELG':1.3,'QSO':2.3}
ztp = tp2z[args.type]
bias = tp2bias[args.type]


if root:
if not os.path.exists(dirout):
os.makedirs(dirout)
print('made '+dirout)

ztp = tp2z[args.type]
bias = tp2bias[args.type]

w0wa = np.loadtxt('/global/cfs/cdirs/desi/survey/catalogs/Y1/LSS/w0wa_initvalues_zeffcombined_1000realisations.txt')

if args.get_par_mode == 'random':
#if args.type != 'LRG':
# sys.exit('Only do LRG in random mode, read from LRG file for other tracers')
if args.type != 'LRG':
sys.exit('Only do LRG in random mode, read from LRG file for other tracers')
ind = int(random()*1000)
[w0_blind,wa_blind] = w0wa[ind]

if args.get_par_mode == 'from_file' and root:
fn = LSSdir + 'filerow.txt'
if not os.path.isfile(fn):
ind_samp = int(random()*1000)
fo = open(fn,'w')
fo.write(str(ind_samp)+'\n')
fo.close()
ind = int(np.loadtxt(fn))
if args.get_par_mode == 'from_file':
hd = fitsio.read_header(dirout+ 'LRG_full.dat.fits',ext='LSS')
ind = hd['FILEROW']
[w0_blind,wa_blind] = w0wa[ind]

#choose f_shift to compensate shift in monopole amplitude
Expand All @@ -196,32 +171,26 @@
DM_shift = cosmo_shift.comoving_angular_distance(ztp)
DH_shift = 1./cosmo_shift.hubble_function(ztp)


vol_fac = (DM_shift**2*DH_shift)/(DM_fid**2*DH_fid)
vol_fac = (DM_shift**2 * DH_shift) / (DM_fid**2 * DH_fid)

#a, b, c for quadratic formula
a = 0.2/bias**2.
b = 2/(3*bias)
c = 1-(1+0.2*(args.fiducial_f/bias)**2.+2/3*args.fiducial_f/bias)/vol_fac

f_shift = (-b+np.sqrt(b**2.-4.*a*c))/(2*a)

dfper = (f_shift-args.fiducial_f)/args.fiducial_f
a = 0.2 / bias**2
b = 2 / (3 * bias)
c = 1 - (1 + 0.2 * (args.fiducial_f / bias)**2. + 2/3 * args.fiducial_f / bias) / vol_fac

f_shift = (-b + np.sqrt(b**2. - 4.*a*c))/(2*a)
dfper = (f_shift - args.fiducial_f)/args.fiducial_f
maxfper = 0.1
if abs(dfper) > maxfper:
dfper = maxfper*dfper/abs(dfper)
f_shift = (1+dfper)*args.fiducial_f

fgrowth_blind = f_shift


#if args.reg_md == 'NS':
regl = ['_S','_N']
#if args.reg_md == 'GC':
gcl = ['_SGC','_NGC']


fb_in = dirin+type+notqso
fcr_in = fb_in+'_1_full.ran.fits'
fcd_in = fb_in+'_full.dat.fits'
Expand All @@ -238,7 +207,7 @@
dz = 0.01
#zmin = 0.01
#zmax = 1.6

if type[:3] == 'LRG':
P0 = 10000
#zmin = 0.4
Expand Down Expand Up @@ -286,7 +255,8 @@
fd['WEIGHT_SYS'] *= wl
common.write_LSS(fd,fcd_out)

if args.visnz == 'y':

if nzmd == 'mock':
print('min/max of weights for nz:')
print(np.min(wl),np.max(wl))
fdin = fitsio.read(fcd_in)
Expand All @@ -295,8 +265,7 @@
c = plt.hist(fd['Z'][gz],bins=100,range=(zmin,zmax),histtype='step',weights=fd['WEIGHT_SYS'][gz],label='blinded+reweight')
plt.legend()
plt.show()




if args.type == 'LRG':
hdul = fits.open(fcd_out,mode='update')
Expand All @@ -305,12 +274,10 @@
hdtest = fitsio.read_header(dirout+ 'LRG_full.dat.fits', ext='LSS')['FILEROW']
if hdtest != ind:
sys.exit('ERROR writing/reading row from blind file')




if args.mkclusdat == 'y':
ct.mkclusdat(dirout+type+notqso,tp=type,dchi2=dchi2,tsnrcut=tsnrcut,zmin=zmin,zmax=zmax)
ct.mkclusdat(dirout + type + notqso, tp=type, dchi2=dchi2, tsnrcut=tsnrcut, zmin=zmin, zmax=zmax)


if args.mkclusran == 'y':
Expand All @@ -327,27 +294,24 @@
ranfm = dirout+args.type+notqso+reg+'_'+str(rannum-1)+'_clustering.ran.fits'
os.system('mv '+ranf+' '+ranfm)

reg_md = args.reg_md
if args.split_GC == 'y':
fb = dirout+args.type+notqso+'_'
ct.clusNStoGC(fb,args.maxr-args.minr)

if args.split_GC == 'y' and root:
fb = dirout+args.type+notqso+'_'
ct.clusNStoGC(fb,args.maxr-args.minr)

sys.stdout.flush()
sys.stdout.flush()

if args.dorecon == 'y':
nran = args.maxr-args.minr

nran = args.maxr - args.minr

if root: print('on est la')

distance = TabulatedDESI().comoving_radial_distance

f, bias = rectools.get_f_bias(args.type)
from pyrecon import MultiGridReconstruction
Reconstruction = MultiGridReconstruction

setup_logging()
Reconstruction = MultiGridReconstruction


if reg_md == 'NS':
if args.reg_md == 'NS':
regions = ['N','S']
else:
regions = ['NGC','SGC']
Expand All @@ -358,22 +322,86 @@
randoms_fn = catalog_fn(**catalog_kwargs, cat_dir=dirout, name='randoms')
data_rec_fn = catalog_fn(**catalog_kwargs, cat_dir=dirout, rec_type='MGrsd', name='data')
randoms_rec_fn = catalog_fn(**catalog_kwargs, cat_dir=dirout, rec_type='MGrsd', name='randoms')
rectools.run_reconstruction(Reconstruction, distance, data_fn, randoms_fn, data_rec_fn, randoms_rec_fn, f=f, bias=bias, convention='rsd', dtype='f8', zlim=(zmin, zmax),mpicomm=mpicomm)
rectools.run_reconstruction(Reconstruction, distance, data_fn, randoms_fn, data_rec_fn, randoms_rec_fn, f=f, bias=bias, convention='rsd', dtype='f8', zlim=(zmin, zmax), mpicomm=mpicomm)

if root and (args.rsdblind == 'y'):

if args.rsdblind == 'y' and root:
if reg_md == 'NS':
if root: print('on est ici')

if args.reg_md == 'NS':
cl = regl
if reg_md == 'GC':
if args.reg_md == 'GC':
cl = gcl
for reg in cl:
fnd = dirout+type+notqso+reg+'_clustering.dat.fits'
fndr = dirout+type+notqso+reg+'_clustering.MGrsd.dat.fits'
fnd = dirout + type + notqso + reg + '_clustering.dat.fits'
fndr = dirout + type + notqso + reg + '_clustering.MGrsd.dat.fits'
data = Table(fitsio.read(fnd))
data_real = Table(fitsio.read(fndr))

out_file = fnd
blind.apply_zshift_RSD(data,data_real,out_file,
blind.apply_zshift_RSD(data, data_real, out_file,
fgrowth_fid=args.fiducial_f,
fgrowth_blind=fgrowth_blind)#,
#comments=f"f_blind: {fgrowth_blind}, w0_blind: {w0_blind}, wa_blind: {wa_blind}")

if args.fnlblind == 'y':
    # fnl blinding: the WEIGHT column of each clustering data catalog is replaced
    # by the weights returned by CutskyCatalogBlinding.png, region by region.
    from mockfactory.blinding import get_cosmo_blind, CutskyCatalogBlinding

    if root: print('on est ici')

    # The blinding parameters are fixed on the root rank and then broadcast, so
    # that every rank builds the same blinded cosmology.
    if root:
        f_blind = fgrowth_blind
        # generate the fnl blinding value from the row index chosen above
        np.random.seed(ind)
        fnl_blind = np.random.uniform(low=-15, high=15, size=1)[0]
    else:
        w0_blind, wa_blind, f_blind, fnl_blind = None, None, None, None
    w0_blind = mpicomm.bcast(w0_blind, root=0)
    wa_blind = mpicomm.bcast(wa_blind, root=0)
    f_blind = mpicomm.bcast(f_blind, root=0)
    fnl_blind = mpicomm.bcast(fnl_blind, root=0)

    # collect effective redshift and bias for the considered tracer
    zeff = tp2z[args.type]
    bias = tp2bias[args.type]

    # build blinding cosmology
    cosmo_blind = get_cosmo_blind('DESI', z=zeff)
    cosmo_blind.params['w0_fld'] = w0_blind
    cosmo_blind.params['wa_fld'] = wa_blind
    cosmo_blind._derived['f'] = f_blind
    cosmo_blind._derived['fnl'] = fnl_blind  # pin the value for good
    blinding = CutskyCatalogBlinding(cosmo_fid='DESI', cosmo_blind=cosmo_blind, bias=bias, z=zeff,
                                     position_type='rdz', mpicomm=mpicomm, mpiroot=0)

    # Number of random files; defined here because this branch can run without
    # the reconstruction branch (the only other place where nran is set).
    nran = args.maxr - args.minr

    # Loop over the different regions of the sky, using the same labels and the
    # same NS / otherwise-GC fallback as the reconstruction step.
    # NOTE(review): assumes catalog_fn expects un-prefixed labels ('N', 'NGC', ...)
    # rather than the '_N' / '_NGC' suffixes stored in regl / gcl -- confirm.
    regions = ['N', 'S'] if args.reg_md == 'NS' else ['NGC', 'SGC']
    for region in regions:
        # path of data and randoms:
        catalog_kwargs = dict(tracer=args.type, region=region, ctype='clustering', nrandoms=nran)
        data_fn = catalog_fn(**catalog_kwargs, cat_dir=dirout, name='data')
        randoms_fn = catalog_fn(**catalog_kwargs, cat_dir=dirout, name='randoms')
        if np.ndim(randoms_fn) == 0: randoms_fn = [randoms_fn]

        # Catalogs are only read on the root rank; the other ranks pass None
        # (the blinding object was created with mpiroot=0).
        data_positions, data_weights = None, None
        randoms_positions, randoms_weights = None, None
        if root:
            print('Loading {}.'.format(data_fn))
            data = Table.read(data_fn)
            data_positions, data_weights = [data['RA'], data['DEC'], data['Z']], data['WEIGHT']

            print('Loading {}'.format(randoms_fn))
            # NOTE(review): `vstack` must be in scope at file level (its import is
            # not visible in this view) -- confirm.
            randoms = vstack([Table.read(fn) for fn in randoms_fn])
            randoms_positions, randoms_weights = [randoms['RA'], randoms['DEC'], randoms['Z']], randoms['WEIGHT']

        # add fnl blinding weight to the data weight
        new_data_weights = blinding.png(data_positions, data_weights=data_weights,
                                        randoms_positions=randoms_positions, randoms_weights=randoms_weights,
                                        method='data_weights', shotnoise_correction=True)

        # overwrite the data!
        if root:
            data['WEIGHT'] = new_data_weights
            common.write_LSS(data, data_fn)

0 comments on commit fa809e9

Please sign in to comment.