Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merge Juno mcmc #142

Closed
wants to merge 3 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 58 additions & 0 deletions examples/2024-CLi-juno-mwr/inspect_emcee_chain.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
# import emcee
import corner
# %config InlineBackend.figure_format = "retina"
import numpy as np
import matplotlib.pyplot as plt
from scipy.stats import norm
# reader = emcee.backends.HDFBackend("run_mcmc.h5")

# tau = reader.get_autocorr_time()
# burnin = int(2 * np.max(tau))
# thin = int(0.5 * np.min(tau))
# samples = reader.get_chain(discard=burnin, flat=True, thin=thin)
# log_prob_samples = reader.get_log_prob(discard=burnin, flat=True, thin=thin)

import h5py

h5=h5py.File('run_mcmc_5000.h5', 'r')
chain = h5['mcmc']['chain'][180:]
h5.close()

flattened_chain = chain.reshape(-1, 2)

# Create labels for the parameters
labels = ["qNH3 [ppmv]","Temperature [K]"]

# Create the corner plot
fig, ax = plt.subplots(2, 2, figsize=(10, 10))
corner.hist2d(flattened_chain[:, 0], flattened_chain[:, 1], ax=ax[1, 0])
ax[1, 0].set_xlabel("qNH3 [ppmv]")
ax[1, 0].set_ylabel("Temperature [K]")
ax[1, 0].set_xlim([0,700])

# Plot histograms for each parameter
x = np.linspace(1, 700, 700)
for i in range(2):
ax[i, i].hist(flattened_chain[:, i], bins=30, color='blue', alpha=0.7, density=True)
ax[i, i].set_xlabel(labels[i])
ax[i, i].set_ylabel('PDF')
means=np.mean(flattened_chain[:, i])
stdev=np.std(flattened_chain[:, i])
ax[i, i].axvline(means, color='b', linestyle='--', label=f'posterior: ({means:6.2f}, {stdev:5.2f}')

# Plot prior distribution
mean, stddev = [(300, 100), (169, 10)][i]

prior = norm.pdf(x, mean, stddev)
ax[i, i].plot(x, prior, color='red', linestyle='--', label=f'Prior: norm({mean}, {stddev})')
ax[i, i].legend()
x = np.linspace(131, 200, 70)

ax[0, 0].set_xlim([0,700])

ax[0, 1].axis('off') # Disable the axis
ax[0, 1].set_visible(False) # Hide the axis
# Show the plot
plt.tight_layout()
# Show the plot
plt.savefig("emcee_cornerplot_5000.png")
41 changes: 41 additions & 0 deletions examples/2024-CLi-juno-mwr/inspect_emcee_step.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
# import emcee

# %config InlineBackend.figure_format = "retina"
import numpy as np
import matplotlib.pyplot as plt
from scipy.stats import norm
# reader = emcee.backends.HDFBackend("run_mcmc.h5")

# tau = reader.get_autocorr_time()
# burnin = int(2 * np.max(tau))
# thin = int(0.5 * np.min(tau))
# samples = reader.get_chain(discard=burnin, flat=True, thin=thin)
# log_prob_samples = reader.get_log_prob(discard=burnin, flat=True, thin=thin)

import h5py

h5=h5py.File('run_mcmc_5000.h5', 'r')
chain = h5['mcmc']['chain'][:]
h5.close()

flattened_chain = chain.reshape(-1, 2)

# Create labels for the parameters
labels = ["qNH3 [ppmv]","Temperature [K]"]

# Create the corner plot
fig, ax = plt.subplots(2,1, figsize=(17, 10))
for iw in range(5):
ax[0].plot(range(5000),chain[:,iw,0])

ax[0].set_ylabel("qNH3 [ppmv]")
ax[0].set_xlim([0,5000])

for iw in range(5):
ax[1].plot(range(5000),chain[:,iw,1])
ax[1].set_ylabel("Temperature [K]")
ax[1].set_xlim([0,5000])
ax[1].set_xlabel("step")
plt.tight_layout()
# Show the plot
plt.savefig("emcee_inspect_step_5000.png")
2 changes: 1 addition & 1 deletion examples/2024-CLi-juno-mwr/juno_mwr.inp
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ PressureScaleHeight = 30.E3

<hydro>
gamma = 1.42 # gamma = C_p/C_v
grav_acc1 = -23.3
grav_acc1 = -27.01 #-23.3
sfloor = 0.

<species>
Expand Down
24 changes: 0 additions & 24 deletions examples/2024-CLi-juno-mwr/jupiter_atmos.yaml

This file was deleted.

210 changes: 210 additions & 0 deletions examples/2024-CLi-juno-mwr/run_juno_emcee.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,210 @@
#! /usr/bin/env python3
import numpy as np
import emcee
import sys, os
import matplotlib.pyplot as plt
import h5py

sys.path.append("../python")
sys.path.append(".")

from canoe import def_species, load_configure
from canoe.snap import def_thermo
from canoe.athena import Mesh, ParameterInput, Outputs, MeshBlock
# from canoe.harp import radiation_band, radiation

# forward operater
def run_RT_modify_atmos(mb: MeshBlock,
adlnTdlnP: float=0.,
pmin: float=0. ,
pmax: float =0.
)-> np.array:
# adlnTdlnP=0.0 ## set as insensitive
mb.modify_dlnTdlnP(adlnTdlnP, pmin, pmax)
# adlnNH3dlnP = 0#.25
# mb.modify_dlnNH3dlnP(adlnNH3dlnP, pmin, pmax)

# for k in range(mb.k_st, mb.k_ed + 1):
# for j in range(mb.j_st, mb.j_ed + 1):
rad=mb.get_rad()
rad.cal_radiance(mb, mb.k_st, mb.j_st)

nb=rad.get_num_bands()
tb=np.array([0.]*4*nb)

for ib in range(nb):
print(rad.get_band(ib))
toa=rad.get_band(ib).get_toa()[0]
tb[ib*4:ib*4+4]=toa
return tb

def set_atmos_run_RT(qNH3: float, # ppmv
T0: float=180. # Kelvin
):

mb.construct_atmosphere(pin,qNH3,T0)
rad=mb.get_rad()
rad.cal_radiance(mb, mb.k_st, mb.j_st)

nb=rad.get_num_bands()
tb=np.array([0.]*4*nb)

for ib in range(nb):
# print(rad.get_band(ib))
toa=rad.get_band(ib).get_toa()[0]
tb[ib*4:ib*4+4]=toa
# print(tb[4:])
return tb[4:]

# Define likelihood function
def ln_likelihood(theta, observations, observation_errors):
nh3, temperature = theta

simulations = set_atmos_run_RT(nh3,temperature) # Use your forward operator here
residuals = observations - simulations
print(simulations)
print(observations)
# print(residuals)
chi_squared = np.sum((residuals / observation_errors) ** 2)
# print(chi_squared)
return -0.5 * chi_squared

# Define priors for NH3 and temperature
def ln_prior(theta):
nh3, temperature = theta

nh3_mean = 300 # Mean value for NH3
nh3_stddev = 100 # Standard deviation for NH3
temperature_mean = 169 # Mean value for temperature
temperature_stddev = 10 # Standard deviation for temperature 0.5%

ln_prior_nh3 = -0.5 * ((nh3 - nh3_mean) / nh3_stddev)**2 - np.log(nh3_stddev * np.sqrt(2 * np.pi))
ln_prior_temperature = -0.5 * ((temperature - temperature_mean) / temperature_stddev)**2 - np.log(temperature_stddev * np.sqrt(2 * np.pi))

if (0 < nh3 < 1000):# and (100 < temperature < 200):
return ln_prior_nh3 + ln_prior_temperature
return -np.inf # return negative infinity if parameters are outside allowed range

# Combine likelihood and prior to get posterior
def ln_posterior(theta, observations, observation_errors):
prior = ln_prior(theta)
if not np.isfinite(prior):
return -np.inf
return prior + ln_likelihood(theta, observations, observation_errors)


## main
if __name__ == "__main__":

## extract TB observations from ZZ fitting results
observations=np.zeros((20,))
obs=np.zeros((24,))
pj=51
mu=np.cos(np.array([0.,15.,30.,45.])/180.*np.pi)
print(mu)
for ch in range(6):
tb_file=h5py.File(f"/nfs/nuke/chengcli/JUNOMWR/zzhang/PJ{pj:02d}_Freq{ch}.h5","r")
if ch==0:
c0=tb_file['ModelTypeupdate1_MultiPJ_Mode1/Iter1/c0'][-1] ## the north polar
c1=tb_file['ModelTypeupdate1_MultiPJ_Mode1/Iter1/c1'][-1] ## the north polar
c2=tb_file['ModelTypeupdate1_MultiPJ_Mode1/Iter1/c2'][-1] ## the north polar
else:
c0=tb_file['ModelTypeupdate1_MultiPJ_Mode3/Iter2/c0'][-1] ## the north polar
c1=tb_file['ModelTypeupdate1_MultiPJ_Mode3/Iter2/c1'][-1] ## the north polar
c2=tb_file['ModelTypeupdate1_MultiPJ_Mode3/Iter2/c2'][-1] ## the north polar
tb_file.close()

# Xr=1.0 ## \mu >0.6
obs[(ch)*4:(ch+1)*4]=c0-c1*5.*(1-mu)+c2/0.04*0.5*(mu-0.8)*(1-mu)
## discard CH1
observations=obs[4:]
print(observations)

# [740.51932939 732.39178625 708.02076917 667.58562359 474.58510281
# 469.42513666 454.00808555 428.59559118 338.13016122 335.65949356
# 328.01197674 314.60534003 251.9730167 250.46642377 245.71888005
# 237.15115289 194.47971955 193.67185714 191.10407859 186.40702401
# 141.18445694 141.06723252 140.59821156 139.46693178]

## initialize Canoe
global pin
pin = ParameterInput()
pin.load_from_file("juno_mwr.inp")

vapors = pin.get_string("species", "vapor").split(", ")
clouds = pin.get_string("species", "cloud").split(", ")
tracers = pin.get_string("species", "tracer").split(", ")

def_species(vapors=vapors, clouds=clouds, tracers=tracers)
def_thermo(pin)

config = load_configure("juno_mwr.yaml")
# print(pin.get_real("problem", "qH2O.ppmv"))

mesh = Mesh(pin)
mesh.initialize(pin)

global mb
mb = mesh.meshblock(0)

## run MCMC

# Generate synthetic observations and errors (replace with your real data)
# observations = np.random.normal(size=20) # Replace with your observations
observation_errors_stddev = 0.03 * observations ## error = 5%

# Initialize walkers
n_walkers = 5
n_dimensions = 2 # nh3, temperature
initial_guess = [200.0, 150.0] # Initial guess for NH3 and temperature
initial_guesses = [initial_guess + 10*np.random.randn(n_dimensions) for _ in range(n_walkers)]

filename = "run_mcmc_5000.h5"
backend = emcee.backends.HDFBackend(filename)
backend.reset(n_walkers, n_dimensions)

# Set up the sampler
sampler = emcee.EnsembleSampler(n_walkers, n_dimensions, ln_posterior, args=(observations, observation_errors_stddev), backend=backend)

# Run MCMC
n_steps = 5000
sampler.run_mcmc(initial_guesses, n_steps, progress=True)

# Extract samples
samples = sampler.get_chain()

# Compute mean and standard deviation of samples
mean_nh3 = np.mean(samples[:,:,0])
std_nh3 = np.std(samples[:,:,0])
mean_temperature = np.mean(samples[:,:,1])
std_temperature = np.std(samples[:,:,1])

print("NH3: Mean =", mean_nh3, "Standard Deviation =", std_nh3)
print("Temperature: Mean =", mean_temperature, "Standard Deviation =", std_temperature)


# Plot convergence diagnostics
fig, axes = plt.subplots(2, figsize=(10, 7), sharex=True)
for iwalk in range(n_walkers):
axes[0].plot(range(n_steps),sampler.get_chain()[:,iwalk,0].T, alpha=0.4)
axes[0].set_ylabel("NH3")
for iwalk in range(n_walkers):
axes[1].plot(range(n_steps),sampler.get_chain()[:,iwalk,1].T, alpha=0.4)
axes[1].plot(sampler.get_chain()[:,:,1].T, alpha=0.4)
axes[1].set_ylabel("Temperature")
axes[1].set_xlabel("Step")
plt.savefig("MCMC-5000Steps.png")

# Flatten samples for posterior distribution plotting
flat_samples = sampler.get_chain(discard=100, thin=15, flat=True)

# Plot posterior distributions
labels = ["NH3", "Temperature"]
fig, axes = plt.subplots(2, figsize=(10, 7), sharex=True)
for i in range(2):
ax = axes[i]
ax.hist(flat_samples[:, i], bins=50, color="skyblue", histtype="stepfilled", alpha=0.7)
ax.set_title(labels[i])
ax.grid(True)
# plt.tight_layout()
plt.savefig("posterior-distributions-5000.png")
Loading
Loading