Skip to content

Commit

Permalink
Bringing changes from main
Browse files Browse the repository at this point in the history
  • Loading branch information
gvegayon committed Apr 16, 2024
2 parents 338dee0 + 437016d commit b373b17
Show file tree
Hide file tree
Showing 22 changed files with 1,783 additions and 329 deletions.
5 changes: 1 addition & 4 deletions .github/workflows/docs.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -28,10 +28,7 @@ jobs:
run: pip install poetry

- name: install package
run: poetry install -C model

- name: installing other dependencies
run: poetry run -C model pip install pyyaml nbclient nbformat
run: poetry install --with dev -C model

- name: Render documents
run: |
Expand Down
5 changes: 1 addition & 4 deletions .github/workflows/test_model.yml
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,7 @@ jobs:
run: pip install poetry

- name: install package
run: poetry install -C model

- name: install pytest-cov
run: poetry run -C model pip install pytest-cov
run: poetry install --with dev -C model

- name: run tests
run: |
Expand Down
17 changes: 16 additions & 1 deletion .github/workflows/test_pipeline.yml
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,22 @@ jobs:
key: ${{ runner.os }}-poetry
- name: install poetry
run: pip install poetry

- name: install package
run: poetry install -C pipeline

- name: install pytest-cov
run: poetry run -C pipeline pip install pytest-cov

- name: run tests
run: poetry run -C pipeline pytest pipeline
run: poetry run -C pipeline pytest pipeline --cov=pipeline --cov-report term --cov-report xml
- name: Upload coverage to Codecov
uses: codecov/codecov-action@v4
with:
env_vars: OS,PYTHON
fail_ci_if_error: true
flags: unittests
file: coverage.xml
plugin: pycoverage
token: ${{ secrets.CODECOV_TOKEN }}
verbose: true
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -337,3 +337,6 @@ replay_pid*

# VS Code
.vscode

# macOS
.DS_Store
12 changes: 8 additions & 4 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -23,10 +23,14 @@ repos:
- id: isort
args: ['--profile', 'black',
'--line-length', '79']
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.1.0
hooks:
- id: ruff
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.1.0
hooks:
- id: ruff
# - repo: https://github.com/numpy/numpydoc
# rev: v1.6.0
# hooks:
# - id: numpydoc-validation
#####
# Secrets
- repo: https://github.com/Yelp/detect-secrets
Expand Down
4 changes: 4 additions & 0 deletions model/docs/example-with-datasets.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,10 @@ from pyrenew.datasets import load_wastewater
from pyrenew.model import HospitalizationsModel
```

/mnt/c/Users/xrd4/Documents/repos/msr/model/.venv/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html
from .autonotebook import tqdm as notebook_tqdm
An NVIDIA GPU may be present on this machine, but a CUDA-enabled jaxlib is not installed. Falling back to cpu.

## Model definition

In this section we provide the formal definition of the model. The
Expand Down
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
157 changes: 148 additions & 9 deletions model/docs/pyrenew_demo.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,62 @@
This demo simulates some basic renewal process data and then fits to it
using `pyrenew`.

You’ll need to install `pyrenew` first. You’ll also need working
installations of `matplotlib`, `numpy`, `jax`, `numpyro`, and `polars`
Assuming you’ve already installed Python and pip, you’ll need to first
install `pyrenew`:

``` python
pip install pyrenew
```

Requirement already satisfied: pyrenew in /mnt/c/Users/xrd4/Documents/repos/msr/model/.venv/lib/python3.10/site-packages (0.1.0)
Requirement already satisfied: jax<0.5.0,>=0.4.24 in /mnt/c/Users/xrd4/Documents/repos/msr/model/.venv/lib/python3.10/site-packages (from pyrenew) (0.4.26)
Requirement already satisfied: numpy<2.0.0,>=1.26.4 in /mnt/c/Users/xrd4/Documents/repos/msr/model/.venv/lib/python3.10/site-packages (from pyrenew) (1.26.4)
Requirement already satisfied: numpyro<0.14.0,>=0.13.2 in /mnt/c/Users/xrd4/Documents/repos/msr/model/.venv/lib/python3.10/site-packages (from pyrenew) (0.13.2)
Requirement already satisfied: pillow<11.0.0,>=10.3.0 in /mnt/c/Users/xrd4/Documents/repos/msr/model/.venv/lib/python3.10/site-packages (from pyrenew) (10.3.0)
Requirement already satisfied: polars<0.21.0,>=0.20.13 in /mnt/c/Users/xrd4/Documents/repos/msr/model/.venv/lib/python3.10/site-packages (from pyrenew) (0.20.19)
Requirement already satisfied: ml-dtypes>=0.2.0 in /mnt/c/Users/xrd4/Documents/repos/msr/model/.venv/lib/python3.10/site-packages (from jax<0.5.0,>=0.4.24->pyrenew) (0.4.0)
Requirement already satisfied: opt-einsum in /mnt/c/Users/xrd4/Documents/repos/msr/model/.venv/lib/python3.10/site-packages (from jax<0.5.0,>=0.4.24->pyrenew) (3.3.0)
Requirement already satisfied: scipy>=1.9 in /mnt/c/Users/xrd4/Documents/repos/msr/model/.venv/lib/python3.10/site-packages (from jax<0.5.0,>=0.4.24->pyrenew) (1.13.0)
Requirement already satisfied: jaxlib>=0.4.14 in /mnt/c/Users/xrd4/Documents/repos/msr/model/.venv/lib/python3.10/site-packages (from numpyro<0.14.0,>=0.13.2->pyrenew) (0.4.26)
Requirement already satisfied: multipledispatch in /mnt/c/Users/xrd4/Documents/repos/msr/model/.venv/lib/python3.10/site-packages (from numpyro<0.14.0,>=0.13.2->pyrenew) (1.0.0)
Requirement already satisfied: tqdm in /mnt/c/Users/xrd4/Documents/repos/msr/model/.venv/lib/python3.10/site-packages (from numpyro<0.14.0,>=0.13.2->pyrenew) (4.66.2)
Note: you may need to restart the kernel to use updated packages.

You’ll also need working installations of `matplotlib`, `numpy`, `jax`,
`numpyro`, and `polars`:

``` python
pip install matplotlib numpy jax numpyro polars
```

Requirement already satisfied: matplotlib in /mnt/c/Users/xrd4/Documents/repos/msr/model/.venv/lib/python3.10/site-packages (3.8.3)
Requirement already satisfied: numpy in /mnt/c/Users/xrd4/Documents/repos/msr/model/.venv/lib/python3.10/site-packages (1.26.4)
Requirement already satisfied: jax in /mnt/c/Users/xrd4/Documents/repos/msr/model/.venv/lib/python3.10/site-packages (0.4.26)
Requirement already satisfied: numpyro in /mnt/c/Users/xrd4/Documents/repos/msr/model/.venv/lib/python3.10/site-packages (0.13.2)
Requirement already satisfied: polars in /mnt/c/Users/xrd4/Documents/repos/msr/model/.venv/lib/python3.10/site-packages (0.20.19)
Requirement already satisfied: contourpy>=1.0.1 in /mnt/c/Users/xrd4/Documents/repos/msr/model/.venv/lib/python3.10/site-packages (from matplotlib) (1.2.1)
Requirement already satisfied: cycler>=0.10 in /mnt/c/Users/xrd4/Documents/repos/msr/model/.venv/lib/python3.10/site-packages (from matplotlib) (0.12.1)
Requirement already satisfied: fonttools>=4.22.0 in /mnt/c/Users/xrd4/Documents/repos/msr/model/.venv/lib/python3.10/site-packages (from matplotlib) (4.50.0)
Requirement already satisfied: kiwisolver>=1.3.1 in /mnt/c/Users/xrd4/Documents/repos/msr/model/.venv/lib/python3.10/site-packages (from matplotlib) (1.4.5)
Requirement already satisfied: packaging>=20.0 in /mnt/c/Users/xrd4/Documents/repos/msr/model/.venv/lib/python3.10/site-packages (from matplotlib) (24.0)
Requirement already satisfied: pillow>=8 in /mnt/c/Users/xrd4/Documents/repos/msr/model/.venv/lib/python3.10/site-packages (from matplotlib) (10.3.0)
Requirement already satisfied: pyparsing>=2.3.1 in /mnt/c/Users/xrd4/Documents/repos/msr/model/.venv/lib/python3.10/site-packages (from matplotlib) (3.1.2)
Requirement already satisfied: python-dateutil>=2.7 in /mnt/c/Users/xrd4/Documents/repos/msr/model/.venv/lib/python3.10/site-packages (from matplotlib) (2.9.0.post0)
Requirement already satisfied: ml-dtypes>=0.2.0 in /mnt/c/Users/xrd4/Documents/repos/msr/model/.venv/lib/python3.10/site-packages (from jax) (0.4.0)
Requirement already satisfied: opt-einsum in /mnt/c/Users/xrd4/Documents/repos/msr/model/.venv/lib/python3.10/site-packages (from jax) (3.3.0)
Requirement already satisfied: scipy>=1.9 in /mnt/c/Users/xrd4/Documents/repos/msr/model/.venv/lib/python3.10/site-packages (from jax) (1.13.0)
Requirement already satisfied: jaxlib>=0.4.14 in /mnt/c/Users/xrd4/Documents/repos/msr/model/.venv/lib/python3.10/site-packages (from numpyro) (0.4.26)
Requirement already satisfied: multipledispatch in /mnt/c/Users/xrd4/Documents/repos/msr/model/.venv/lib/python3.10/site-packages (from numpyro) (1.0.0)
Requirement already satisfied: tqdm in /mnt/c/Users/xrd4/Documents/repos/msr/model/.venv/lib/python3.10/site-packages (from numpyro) (4.66.2)
Requirement already satisfied: six>=1.5 in /mnt/c/Users/xrd4/Documents/repos/msr/model/.venv/lib/python3.10/site-packages (from python-dateutil>=2.7->matplotlib) (1.16.0)
Note: you may need to restart the kernel to use updated packages.

To begin, run the following import section to call external modules and
functions necessary to run the `pyrenew` demo. The `import` statement
imports the module and the `as` statement renames the module for use
within this script. The `from` statement imports a specific function
from a module (named after the `.`) within a package (named before the
`.`).

``` python
import matplotlib as mpl
Expand All @@ -23,6 +77,19 @@ from pyrenew.process import SimpleRandomWalkProcess

An NVIDIA GPU may be present on this machine, but a CUDA-enabled jaxlib is not installed. Falling back to cpu.

To understand the simple random walk process underlying the sampling
within the renewal process model, we first examine a single random walk
path. Using the `sample` method from an instance of the
`SimpleRandomWalkProcess` class, we first create an instance of the
`SimpleRandomWalkProcess` class with a normal distribution of mean = 0
and standard deviation = 0.0001 as its input. Next, the `with` statement
sets the seed for the random number generator for the duration of the
block that follows. Inside the `with` block, the
`q_samp = q.sample(duration=100)` generates the sample instance over a
duration of 100 time units. Finally, this single random walk process is
visualized using `matplot.pyplot` to plot the exponential of the sample
instance.

``` python
np.random.seed(3312)
q = SimpleRandomWalkProcess(dist.Normal(0, 0.001))
Expand All @@ -34,47 +101,93 @@ plt.plot(np.exp(q_samp[0]))

![](pyrenew_demo_files/figure-commonmark/fig-randwalk-output-1.png)

Next, import several additional functions from the `latent` module of
the `pyrenew` package to model infections, hospital admissions, initial
infections, and hospitalization rate due to infection.

``` python
from pyrenew.latent import (
Infections, HospitalAdmissions, Infections0, InfectHospRate,
)
```

Additionally, import several classes from Pyrenew, including a Poisson
observation process, determininstic PMF and variable classes, the
Pyrenew hospitalization model, and a renewal modle (Rt) random walk
process:

``` python
from pyrenew.observation import PoissonObservation
from pyrenew.deterministic import DeterministicPMF, DeterministicVariable
from pyrenew.model import HospitalizationsModel
from pyrenew.process import RtRandomWalkProcess
```

To initialize the model, we first define initial conditions, including:

1) deterministic generation time, defined as an instance of the
`DeterministicPMF` class, which gives the probability of each
possible outcome for a discrete random variable given as a JAX NumPy
array of four possible outcomes

2) initial infections at the start of simulation as a log-normal
distribution with mean = 0 and standard deviation = 1

3) latent infections as an instance of the `Infections` class with
default settings

4) latent hospitalization process, modeled by first defining the time
interval from infections to hospitalizations as a `DeterministicPMF`
input with 18 possible outcomes and corresponding probabilities
given by the values in the array. The `HospitalAdmissions` function
then takes in this defined time interval, as well as defining the
rate at which infections are admitted to the hospital due to
infection, modeled as a log-normal distribution with mean =
`jnp.log(0.05)` and standard deviation = 0.05.

5) hospitalization observation process, modeled with a Poisson
distribution

6) an Rt random walk process with default settings

``` python
# Initializing model components:

# A deterministic generation time
# 1) A deterministic generation time
gen_int = DeterministicPMF(jnp.array([0.25, 0.25, 0.25, 0.25]))

# Initial infections
# 2) Initial infections
I0 = Infections0(I0_dist=dist.LogNormal(0, 1))

# The latent infections process
# 3) The latent infections process
latent_infections = Infections()

# A deterministic infection to hosp pmf
# 4) The latent hospitalization process:

# First, define a deterministic infection to hosp pmf
inf_hosp_int = DeterministicPMF(
jnp.array([0, 0, 0,0,0,0,0,0,0,0,0,0,0, 0.25, 0.5, 0.1, 0.1, 0.05]),
)

# The latent hospitalization process
latent_hospitalizations = HospitalAdmissions(
infection_to_admission_interval=inf_hosp_int,
infect_hosp_rate_dist = InfectHospRate(
dist=dist.LogNormal(jnp.log(0.05), 0.05),
),
)

# And observation process for the hospitalizations
# 5) An observation process for the hospitalizations
observed_hospitalizations = PoissonObservation()

# And a random walk process (it could be deterministic using
# 6) A random walk process (it could be deterministic using
# pyrenew.process.DeterministicProcess())
Rt_process = RtRandomWalkProcess()
```

The `HospitalizationsModel` is then initialized using the initial
conditions just defined:

``` python
# Initializing the model
hospmodel = HospitalizationsModel(
gen_int=gen_int,
Expand All @@ -86,6 +199,9 @@ hospmodel = HospitalizationsModel(
)
```

Next, we sample from the `hospmodel` for 30 time steps and view the
output of a single run:

``` python
with seed(rng_seed=np.random.randint(1, 60)):
x = hospmodel.sample(n_timepoints=30)
Expand All @@ -112,6 +228,10 @@ x
0.01345188], dtype=float32), sampled=Array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 1, 0, 0, 0], dtype=int32))

Visualizations of the single model output show (top) infections over the
30 time steps, (middle) hospitalizations over the 30 time steps, and
(bottom)

``` python
fig, ax = plt.subplots(nrows=3, sharex=True)
ax[0].plot(x.infections)
Expand All @@ -124,6 +244,13 @@ for axis in ax[:-1]:

![](pyrenew_demo_files/figure-commonmark/fig-hosp-output-1.png)

To fit the `hospmodel` to the simulated data, we call `hospmodel.run()`,
an MCMC algorithm, with the arguments generated in `hospmodel` object,
using 1000 warmup stepts and 1000 samples to draw from the posterior
distribution of the model parameters. The model is run for
`len(x.sampled)-1` time steps with the seed set by
`jax.random.PRNGKey()`

``` python
# from numpyro.infer import MCMC, NUTS
hospmodel.run(
Expand All @@ -136,6 +263,8 @@ hospmodel.run(
)
```

Print a summary of the model:

``` python
hospmodel.print_summary()
```
Expand Down Expand Up @@ -178,11 +307,21 @@ hospmodel.print_summary()

Number of divergences: 0

Next, we will use the `spread_draws` function from the
`pyrenew.mcmcutils` module to process the MCMC samples. The
`spread_draws` function reformats the samples drawn from the
`mcmc.get_samples()` from the `hospmodel`. The samples are simulated Rt
values over time.

``` python
from pyrenew.mcmcutils import spread_draws
samps = spread_draws(hospmodel.mcmc.get_samples(), [("Rt", "time")])
```

We visualize these samples below, with individual possible Rt estimates
over time shown in light blue, and the overall mean estimate Rt shown in
dark blue.

``` python
import numpy as np
import polars as pl
Expand Down
Loading

0 comments on commit b373b17

Please sign in to comment.