Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

PyRenew Demo Questions #83

Closed
wants to merge 2 commits into from
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 21 additions & 0 deletions model/docs/pyrenew_demo.qmd
Original file line number Diff line number Diff line change
Expand Up @@ -49,8 +49,13 @@ with seed(rng_seed=np.random.randint(0,1000)):
q_samp = q.sample(duration=100)

plt.plot(np.exp(q_samp[0]))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good question, we need to document what is the output of this function. Every sample function from pyrenew returns a tuple or named tuple.

# Damon: why is q_samp multideminsional?
# Damon: Why do we generate a Normal random walk and exponentiate? Should we have a Log-Normal Random walk function?
# Damon: Why do we generate random number to use as the rng seed?
# Damon: Description could be updated to relate our example to a real world scenario. What are we simulating here?
```

Damon: I believe the next section is totally separate from this first example. Perhaps we could make this clearer with # Section Labels.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Agree

Next, import several additional functions from the `latent` module of the `pyrenew` package to model infections, hospital admissions, initial infections, and hospitalization rate due to infection.

```{python}
Expand Down Expand Up @@ -103,6 +108,8 @@ inf_hosp_int = DeterministicPMF(
(jnp.array([0, 0, 0,0,0,0,0,0,0,0,0,0,0, 0.25, 0.5, 0.1, 0.1, 0.05]),),
)

# These sum to 1, so I assume they are probabilities, but what is the domain of the distribution?
# Regardless, a single array doesn't seem like the appropriate data structure to use.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Note that this is for a deterministic quantity that can be replaced with a probabilistic one. In the example, we assume the generation interval, but it can also be fitted.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm actually more confused now 😅

Two issues:

  1. The doc mentions 18 possible outcomes and corresponding probabilities given by the values in the array. I understand that the array presents 18 probabilities. What values do those probabilities map to? Why are so many of them 0 instead of being omitted as possibilities?
  2. I'm not sure what is meant to be communicated by "deterministic PMF." The docstrings mention a degenerate random variable. If that's the case, I don't understand why there are probabilities involved at all. Based on your description, I think it means that the prior and posterior distribution of this random variable are exactly the same. If that is the case, I think we could come up with a better descriptor. Something like fixed, constant, invariant, unestimated, or known

latent_hospitalizations = HospitalAdmissions(
infection_to_admission_interval=inf_hosp_int,
infect_hosp_rate_dist = InfectHospRate(
Expand All @@ -112,10 +119,12 @@ latent_hospitalizations = HospitalAdmissions(

# 5) An observation process for the hospitalizations
observed_hospitalizations = PoissonObservation()
# Damon: What does it mean that there is a PoissonObservation? What are the parameters of the Poisson distrubtion?
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good point, the pyrenew_demo.qmd needs to include a model description as initial section.


# 6) A random walk process (it could be deterministic using
# pyrenew.process.DeterministicProcess())
Rt_process = RtRandomWalkProcess()
# What's the difference between this and the SimpleRandomWalkProcess used earlier? Why don't we need to specify a distribution this time?
```

The `HospitalizationsModel` is then initialized using the initial conditions just defined:
Expand All @@ -130,6 +139,8 @@ hospmodel = HospitalizationsModel(
latent_infections=latent_infections,
Rt_process=Rt_process
)
# Damon: I don't really get why there is a hospitalizations model as a concept.
# Damon: Maybe the scope of the project is so limited that it makes sense, but one can easily imagine additional data sources (Wastewater, separate hosp and ICU admissions). Would each of those get a separate class?
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Indeed, that's the whole point. The current state features a handful of models and classes, but needs to be extended. Most of the required features (including other data models) are listed under the repo's issues.

```

Next, we sample from the `hospmodel` for 30 time steps and view the output of a single run:
Expand All @@ -138,6 +149,7 @@ Next, we sample from the `hospmodel` for 30 time steps and view the output of a
with seed(rng_seed=np.random.randint(1, 60)):
x = hospmodel.sample(n_timepoints=30)
x
# Damon: Why do we generate random number to use as the rng seed?
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good question, @dylanhmorris should know as he wrote the initial code!

```

Visualizations of the single model output show (top) infections over the 30 time steps, (middle) hospitalizations over the 30 time steps, and (bottom)
Expand All @@ -152,9 +164,12 @@ ax[1].plot(x.latent)
ax[2].plot(x.sampled, 'o')
for axis in ax[:-1]:
axis.set_yscale("log")
# Damon: We should label the figures.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Agree

```

To fit the `hospmodel` to the simulated data, we call `hospmodel.run()`, an MCMC algorithm, with the arguments generated in `hospmodel` object, using 1000 warmup stepts and 1000 samples to draw from the posterior distribution of the model parameters. The model is run for `len(x.sampled)-1` time steps with the seed set by `jax.random.PRNGKey()`
Damon: Which MCMC algorithm is run? Does it come from another module?
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

NUTS, this should also be explained explicitly (or at least point to where there may be more details).

Damon: Where did we specify the prior distribution?
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nowhere explicitly, that's the beauty of numpyro. Priors are embedded in the sample() functions.


```{python}
# from numpyro.infer import MCMC, NUTS
Expand All @@ -166,6 +181,8 @@ hospmodel.run(
rng_key=jax.random.PRNGKey(54),
mcmc_args=dict(progress_bar=False),
)

# Damon: What is the relationship between `n_timepoints` and `observed_hospitalizations`? Is `n_timepoints` always one less than the number of hospitalization observtion times? If so, is this parameter redundant?
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Great points, that's something that needs to be addressed. When observed hospitalizations are passed, n_timepoints should be directly computed.

```

Print a summary of the model:
Expand All @@ -182,6 +199,8 @@ samps = spread_draws(hospmodel.mcmc.get_samples(), [("Rt", "time")])
```

We visualize these samples below, with individual possible Rt estimates over time shown in light blue, and the overall mean estimate Rt shown in dark blue.
Damon: Phrasing could be improved on "individual possible Rt estimates." These are individual draws from the posterior R_t distribution.
Damon: also, we can use mathjax formatting to format R subscript t.

```{python}
#| label: fig-sampled-rt
Expand All @@ -199,4 +218,6 @@ for samp_id in samp_ids:
ax.set_ylim([0.4, 1/.4])
ax.set_yticks([0.5, 1, 2])
ax.set_yscale("log")

# Can we use ArviZ for visualization? It is included in poetry dependencies.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What's that? If you think it improves things, for sure!

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

https://python.arviz.org/en/stable/index.html

It is included here, so it seems that someone else may have been keen to use it.

```