diff --git a/benchmark.py b/benchmark.py deleted file mode 100644 index 18c1b538..00000000 --- a/benchmark.py +++ /dev/null @@ -1,36 +0,0 @@ -# %% -import timeit - -import jax.config -import jax.numpy as jnp -import numpyro - -from config.config_base import ConfigBase -from mechanistic_compartments import build_basic_mechanistic_model -from model_odes.seir_model_v5 import seirw_ode, seirw_ode2 - -# Use 4 cores -numpyro.set_host_device_count(4) -jax.config.update("jax_enable_x64", True) - -cb = ConfigBase() -cb.STRAIN_SPECIFIC_R0 = jnp.array([1.5, 2.5, 3.5]) -model = build_basic_mechanistic_model(cb) - - -def func1(): - return model.run(seirw_ode, tf=100, plot=False, save=False) - - -def func2(): - return model.run(seirw_ode2, tf=100, plot=False, save=False) - - -num_runs = 20 -duration1 = timeit.Timer(func1).timeit(number=num_runs) -avg_duration1 = duration1 / num_runs -duration2 = timeit.Timer(func2).timeit(number=num_runs) -avg_duration2 = duration2 / num_runs - -print(f"func1: On average it took {avg_duration1} seconds") -print(f"func2: On average it took {avg_duration2} seconds") diff --git a/R/abm_resampling.R b/data_manipulation_scripts/abm_resampling.R similarity index 100% rename from R/abm_resampling.R rename to data_manipulation_scripts/abm_resampling.R diff --git a/R/hhs_hospitalization_formatter.R b/data_manipulation_scripts/hhs_hospitalization_formatter.R similarity index 100% rename from R/hhs_hospitalization_formatter.R rename to data_manipulation_scripts/hhs_hospitalization_formatter.R diff --git a/R/seroprevalence.R b/data_manipulation_scripts/seroprevalence.R similarity index 100% rename from R/seroprevalence.R rename to data_manipulation_scripts/seroprevalence.R diff --git a/sim_data_to_sero_generator.py b/data_manipulation_scripts/sim_data_to_sero_generator.py similarity index 100% rename from sim_data_to_sero_generator.py rename to data_manipulation_scripts/sim_data_to_sero_generator.py diff --git a/example_end_to_end_run.py b/example_end_to_end_run.py index a45b5557..d6acc1ac 100644 --- a/example_end_to_end_run.py +++ b/example_end_to_end_run.py @@ -2,28 +2,78 @@ from mechanistic_model.mechanistic_runner import MechanisticRunner from mechanistic_model.solution_iterpreter import SolutionInterpreter from mechanistic_model.static_value_parameters import StaticValueParameters +from mechanistic_model.mechanistic_inferer import MechanisticInferer from model_odes.seip_model import seip_ode +import matplotlib.pyplot as plt +import sys +import jax.numpy as jnp +import numpy as np if __name__ == "__main__": + # step 1: define your paths config_path = "config/" + # global_config include definitions such as age bin bounds and strain definitions + # Any value or data structure that needs context to be interpretted is here. GLOBAL_CONFIG_PATH = config_path + "config_global.json" + # defines the init conditions of the scenario: pop size, initial infections etc. INITIALIZER_CONFIG_PATH = config_path + "config_initializer_covid.json" + # defines the running variables, strain R0s, external strain introductions etc. RUNNER_CONFIG_PATH = config_path + "config_runner_covid.json" + # defines __distributions__ that act like priors to infer runner variables. + INFERER_CONFIG_PATH = config_path + "config_inferer_covid.json" + # defines how the solution should be viewed, what slices examined, how to save. INTERPRETER_CONFIG_PATH = config_path + "config_interpreter_covid.json" - # model = build_basic_mechanistic_model(ConfigBase()) + # sets up the initial conditions, initializer.get_initial_state() passed to runner initializer = CovidInitializer(INITIALIZER_CONFIG_PATH, GLOBAL_CONFIG_PATH) + # reads and interprets values from config, sets up downstream parameters + # like beta = STRAIN_R0s / INFECTIOUS_PERIOD static_params = StaticValueParameters( initializer.get_initial_state(), RUNNER_CONFIG_PATH, GLOBAL_CONFIG_PATH, - ).get_parameters() + ) + # A runner that does ODE solving of a single run. runner = MechanisticRunner(seip_ode) + # run for 200 days, using init state and parameters from StaticValueParameters solution = runner.run( - initializer.get_initial_state(), tf=200, args=static_params - ) - interpreter = SolutionInterpreter( - solution, INTERPRETER_CONFIG_PATH, GLOBAL_CONFIG_PATH - ) - fig, ax = interpreter.summarize_solution( - plot_commands=["S[0, :, :, :]", "E[0, :, :, :]", "I[0, :, :, :]"] + initializer.get_initial_state(), + tf=200, + args=static_params.get_parameters(), ) + if "-infer" in sys.argv: + # for an example inference, lets jumble our solution up a bit and attempt to fit back to it + ihr = [0.002, 0.004, 0.008, 0.06] + model_incidence = jnp.sum(solution.ys[3], axis=(2, 3, 4)) + model_incidence = jnp.diff(model_incidence, axis=0) + rng = np.random.default_rng(seed=8675399) + m = np.asarray(model_incidence) * ihr + k = 10.0 + p = k / (k + m) + fake_obs = rng.negative_binomial(k, p) + inferer = MechanisticInferer( + GLOBAL_CONFIG_PATH, + INFERER_CONFIG_PATH, + runner, + initializer.get_initial_state(), + ) + # artificially shortening inference since this is a toy example + inferer.config.INFERENCE_NUM_WARMUP = 30 + inferer.config.INFERENCE_NUM_SAMPLES = 30 + inferer.set_infer_algo() + # this will print a summary of the inferred variables + # those distributions in the Config are now posteriors + inferer.infer(fake_obs) + print( + "Toy inference finished, see the distributions of posteriors above, " + "in only 60 samples how well do they match with the actual parameters " + "used to generate the fake data? \n" + ) + else: + # interpret the solution object in a variety of ways + interpreter = SolutionInterpreter( + solution, INTERPRETER_CONFIG_PATH, GLOBAL_CONFIG_PATH + ) + # plot the 4 compartments summed across all age bins and immunity status + fig, ax = interpreter.summarize_solution() + print("Please see output/example_end_to_end_run.png for your plot!") + plt.savefig("output/example_end_to_end_run.png") diff --git a/gen_fake_dat_and_fit.py b/gen_fake_dat_and_fit.py deleted file mode 100644 index 97a9a0ee..00000000 --- a/gen_fake_dat_and_fit.py +++ /dev/null @@ -1,76 +0,0 @@ -# %% -import jax.config -import jax.numpy as jnp -import matplotlib.pyplot as plt -import numpy as np -import numpyro - -from config.config_base import ConfigBase -from mechanistic_compartments import build_basic_mechanistic_model -from model_odes.seip_model import seip_ode # , seirw_ode - -# from jax.random import PRNGKey -# from numpyro.infer import MCMC, NUTS -# from code_fragments_deprecated.inference import infer_model - -# Use 4 cores -numpyro.set_host_device_count(4) -jax.config.update("jax_enable_x64", True) - -# %% -# True model -cb = ConfigBase( - POP_SIZE=1e8, - INITIAL_INFECTIONS=1e6, - INTRODUCTION_PERCENTAGE=0.01, - INTRODUCTION_TIMES=[60], -) -cb.STRAIN_SPECIFIC_R0 = jnp.array([1.5, 2.5, 2.5], dtype=jnp.float32) # R0s -model = build_basic_mechanistic_model(cb) -ihr = [0.002, 0.004, 0.008, 0.06] - -solution = model.run(seip_ode, tf=300, show=True, save=False) -model_incidence = jnp.sum(solution.ys[3], axis=(2, 3, 4)) -model_incidence = jnp.diff(model_incidence, axis=0) -rng = np.random.default_rng(seed=8675399) -m = np.asarray(model_incidence) * ihr -k = 10.0 -p = k / (k + m) -fake_obs = rng.negative_binomial(k, p) - -# %% -fig, ax = plt.subplots(1) -ax.plot(m, label=[1, 2, 3, 4]) -fig.legend() -plt.show() - -# %% -# Perform inference -# Reducing max_tree_depth here to reduce fitting time -# The new way of doing it shown below! -model.MCMC_NUM_WARMUP = 100 -model.MCMC_NUM_SAMPLES = 100 -model.MCMC_NUM_CHAINS = 1 -model.infer( - seip_ode, - fake_obs, - negbin=True, -) - -# mcmc = MCMC( -# NUTS(infer_model, dense_mass=True, max_tree_depth=5), -# num_warmup=500, -# num_samples=500, -# thinning=1, -# num_chains=4, -# progress_bar=True, -# ) -# mcmc.run( -# PRNGKey(8675328), -# times=np.linspace(0.0, 100.0, 101), -# incidence=fake_obs, -# model=model, -# ) -# mcmc.print_summary() - -# %%