-
Notifications
You must be signed in to change notification settings - Fork 6
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
PyRenew Demo Questions #83
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -49,8 +49,13 @@ with seed(rng_seed=np.random.randint(0,1000)): | |
q_samp = q.sample(duration=100) | ||
|
||
plt.plot(np.exp(q_samp[0])) | ||
# Damon: why is q_samp multideminsional? | ||
# Damon: Why do we generate a Normal random walk and exponentiate? Should we have a Log-Normal Random walk function? | ||
# Damon: Why do we generate random number to use as the rng seed? | ||
# Damon: Description could be updated to relate our example to a real world scenario. What are we simulating here? | ||
``` | ||
|
||
Damon: I believe the next section is totally separate from this first example. Perhaps we could make this clearer with # Section Labels. | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Agree |
||
Next, import several additional functions from the `latent` module of the `pyrenew` package to model infections, hospital admissions, initial infections, and hospitalization rate due to infection. | ||
|
||
```{python} | ||
|
@@ -103,6 +108,8 @@ inf_hosp_int = DeterministicPMF( | |
(jnp.array([0, 0, 0,0,0,0,0,0,0,0,0,0,0, 0.25, 0.5, 0.1, 0.1, 0.05]),), | ||
) | ||
|
||
# These sum to 1, so I assume they are probabilities, but what is the domain of the distribution? | ||
# Regardless, a single array doesn't seem like the appropriate data structure to use. | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Note that this is for a deterministic quantity that can be replaced with a probabilistic one. In the example, we assume the generation interval, but it can also be fitted. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I'm actually more confused now 😅 Two issues:
|
||
latent_hospitalizations = HospitalAdmissions( | ||
infection_to_admission_interval=inf_hosp_int, | ||
infect_hosp_rate_dist = InfectHospRate( | ||
|
@@ -112,10 +119,12 @@ latent_hospitalizations = HospitalAdmissions( | |
|
||
# 5) An observation process for the hospitalizations | ||
observed_hospitalizations = PoissonObservation() | ||
# Damon: What does it mean that there is a PoissonObservation? What are the parameters of the Poisson distrubtion? | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Good point, the |
||
|
||
# 6) A random walk process (it could be deterministic using | ||
# pyrenew.process.DeterministicProcess()) | ||
Rt_process = RtRandomWalkProcess() | ||
# What's the difference between this and the SimpleRandomWalkProcess used earlier? Why don't we need to specify a distribution this time? | ||
``` | ||
|
||
The `HospitalizationsModel` is then initialized using the initial conditions just defined: | ||
|
@@ -130,6 +139,8 @@ hospmodel = HospitalizationsModel( | |
latent_infections=latent_infections, | ||
Rt_process=Rt_process | ||
) | ||
# Damon: I don't really get why there is a hospitalizations model as a concept. | ||
# Damon: Maybe the scope of the project is so limited that it makes sense, but one can easily imagine additional data sources (Wastewater, separate hosp and ICU admissions). Would each of those get a separate class? | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Indeed, that's the whole point. The current state features a handful of models and classes, but needs to be extended. Most of the required features (including other data models) are listed under the repo's issues. |
||
``` | ||
|
||
Next, we sample from the `hospmodel` for 30 time steps and view the output of a single run: | ||
|
@@ -138,6 +149,7 @@ Next, we sample from the `hospmodel` for 30 time steps and view the output of a | |
with seed(rng_seed=np.random.randint(1, 60)): | ||
x = hospmodel.sample(n_timepoints=30) | ||
x | ||
# Damon: Why do we generate random number to use as the rng seed? | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Good question, @dylanhmorris should know as he wrote the initial code! |
||
``` | ||
|
||
Visualizations of the single model output show (top) infections over the 30 time steps, (middle) hospitalizations over the 30 time steps, and (bottom) | ||
|
@@ -152,9 +164,12 @@ ax[1].plot(x.latent) | |
ax[2].plot(x.sampled, 'o') | ||
for axis in ax[:-1]: | ||
axis.set_yscale("log") | ||
# Damon: We should label the figures. | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Agree |
||
``` | ||
|
||
To fit the `hospmodel` to the simulated data, we call `hospmodel.run()`, an MCMC algorithm, with the arguments generated in `hospmodel` object, using 1000 warmup stepts and 1000 samples to draw from the posterior distribution of the model parameters. The model is run for `len(x.sampled)-1` time steps with the seed set by `jax.random.PRNGKey()` | ||
Damon: Which MCMC algorithm is run? Does it come from another module? | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. NUTS, this should also be explained explicitly (or at least point to where there may be more details). |
||
Damon: Where did we specify the prior distribution? | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Nowhere explicitly, that's the beauty of |
||
|
||
```{python} | ||
# from numpyro.infer import MCMC, NUTS | ||
|
@@ -166,6 +181,8 @@ hospmodel.run( | |
rng_key=jax.random.PRNGKey(54), | ||
mcmc_args=dict(progress_bar=False), | ||
) | ||
|
||
# Damon: What is the relationship between `n_timepoints` and `observed_hospitalizations`? Is `n_timepoints` always one less than the number of hospitalization observtion times? If so, is this parameter redundant? | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Great points, that's something that needs to be addressed. When observed hospitalizations are passed, |
||
``` | ||
|
||
Print a summary of the model: | ||
|
@@ -182,6 +199,8 @@ samps = spread_draws(hospmodel.mcmc.get_samples(), [("Rt", "time")]) | |
``` | ||
|
||
We visualize these samples below, with individual possible Rt estimates over time shown in light blue, and the overall mean estimate Rt shown in dark blue. | ||
Damon: Phrasing could be improved on "individual possible Rt estimates." These are individual draws from the posterior R_t distribution. | ||
Damon: also, we can use mathjax formatting to format R subscript t. | ||
|
||
```{python} | ||
#| label: fig-sampled-rt | ||
|
@@ -199,4 +218,6 @@ for samp_id in samp_ids: | |
ax.set_ylim([0.4, 1/.4]) | ||
ax.set_yticks([0.5, 1, 2]) | ||
ax.set_yscale("log") | ||
|
||
# Can we use ArviZ for visualization? It is included in poetry dependencies. | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. What's that? If you think it improves things, for sure! There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. https://python.arviz.org/en/stable/index.html It is included here, so it seems that someone else may have been keen to use it. |
||
``` |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Good question, we need to document what is the output of this function. Every sample function from
pyrenew
returns a tuple or named tuple.