Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Cleanup rtperiodicdiffar #419

Draft
wants to merge 2 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 4 additions & 3 deletions docs/source/tutorials/periodic_effects.qmd
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,11 @@ from pyrenew import process, deterministic

```{python}
# The random process for Rt
rt_proc = process.RtWeeklyDiffARProcess(
rt_proc = process.RtPeriodicDiffARProcess(
name="rt_weekly_diff",
offset=0,
log_rt_rv=deterministic.DeterministicVariable(
period_size=7,
log_rt_init_rv=deterministic.DeterministicVariable(
name="log_rt", value=jnp.array([0.1, 0.2])
),
autoreg_rv=deterministic.DeterministicVariable(
Expand Down Expand Up @@ -57,7 +58,7 @@ for i in range(0, 30, 7):
plt.show()
```

The implementation of the `RtWeeklyDiffARProcess` (which is an instance of `RtPeriodicDiffARProcess`), uses `repeat_until_n` to repeating values: `repeat_until_n(..., period_size=7)`. The `RtWeeklyDiff` class is a particular case of `RtPeriodicDiff` with a period size of seven.
The implementation of the `RtPeriodicDiffARProcess` uses `repeat_until_n` to repeating values: `repeat_until_n(..., period_size=7)`. The `RtWeeklyDiff` class is a particular case of `RtPeriodicDiff` with a period size of seven.

## Repeated sequences (tiling)

Expand Down
6 changes: 1 addition & 5 deletions pyrenew/process/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,7 @@
)
from pyrenew.process.periodiceffect import DayOfWeekEffect, PeriodicEffect
from pyrenew.process.randomwalk import RandomWalk, StandardNormalRandomWalk
from pyrenew.process.rtperiodicdiffar import (
RtPeriodicDiffARProcess,
RtWeeklyDiffARProcess,
)
from pyrenew.process.rtperiodicdiffar import RtPeriodicDiffARProcess

__all__ = [
"IIDRandomSequence",
Expand All @@ -25,5 +22,4 @@
"PeriodicEffect",
"DayOfWeekEffect",
"RtPeriodicDiffARProcess",
"RtWeeklyDiffARProcess",
]
74 changes: 14 additions & 60 deletions pyrenew/process/rtperiodicdiffar.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ def __init__(
name: str,
offset: int,
period_size: int,
log_rt_rv: RandomVariable,
log_rt_init_rv: RandomVariable,
autoreg_rv: RandomVariable,
periodic_diff_sd_rv: RandomVariable,
ar_process_suffix: str = "_first_diff_ar_process_noise",
Expand All @@ -69,7 +69,7 @@ def __init__(
offset : int
Relative point at which data starts, must be between 0 and
period_size - 1.
log_rt_rv : RandomVariable
log_rt_init_rv : RandomVariable
Log Rt prior for the first two observations.
autoreg_rv : RandomVariable
Autoregressive parameter.
Expand All @@ -87,15 +87,15 @@ def __init__(
"""

self.validate(
log_rt_rv=log_rt_rv,
log_rt_init_rv=log_rt_init_rv,
autoreg_rv=autoreg_rv,
periodic_diff_sd_rv=periodic_diff_sd_rv,
)

self.name = name
self.period_size = period_size
self.offset = offset
self.log_rt_rv = log_rt_rv
self.log_rt_init_rv = log_rt_init_rv
self.autoreg_rv = autoreg_rv
self.periodic_diff_sd_rv = periodic_diff_sd_rv
self.ar_diff = DifferencedProcess(
Expand All @@ -109,7 +109,7 @@ def __init__(

@staticmethod
def validate(
log_rt_rv: any,
log_rt_init_rv: any,
autoreg_rv: any,
periodic_diff_sd_rv: any,
) -> None:
Expand All @@ -118,7 +118,7 @@ def validate(

Parameters
----------
log_rt_rv : any
log_rt_init_rv : any
Log Rt prior for the first two observations.
autoreg_rv : any
Autoregressive parameter.
Expand All @@ -130,7 +130,7 @@ def validate(
None
"""

_assert_sample_and_rtype(log_rt_rv)
_assert_sample_and_rtype(log_rt_init_rv)
_assert_sample_and_rtype(autoreg_rv)
_assert_sample_and_rtype(periodic_diff_sd_rv)

Expand Down Expand Up @@ -159,9 +159,9 @@ def sample(
"""

# Initial sample
log_rt_rv = self.log_rt_rv.sample(**kwargs)[0].value
b = self.autoreg_rv.sample(**kwargs)[0].value
s_r = self.periodic_diff_sd_rv.sample(**kwargs)[0].value
log_rt_init = self.log_rt_init_rv.sample(**kwargs)[0].value
autoreg = self.autoreg_rv.sample(**kwargs)[0].value
noise_sd = self.periodic_diff_sd_rv.sample(**kwargs)[0].value

# How many periods to sample?
n_periods = (duration + self.period_size - 1) // self.period_size
Expand All @@ -170,11 +170,11 @@ def sample(

log_rt = self.ar_diff(
n=n_periods,
init_vals=jnp.array([log_rt_rv[0]]),
autoreg=b,
noise_sd=s_r,
init_vals=jnp.array([log_rt_init[0]]),
autoreg=autoreg,
noise_sd=noise_sd,
fundamental_process_init_vals=jnp.array(
[log_rt_rv[1] - log_rt_rv[0]]
[log_rt_init[1] - log_rt_init[0]]
),
)[0]

Expand All @@ -190,49 +190,3 @@ def sample(
t_unit=self.t_unit,
),
)


class RtWeeklyDiffARProcess(RtPeriodicDiffARProcess):
"""
Weekly Rt with autoregressive first differences.
"""

def __init__(
self,
name: str,
offset: int,
log_rt_rv: RandomVariable,
autoreg_rv: RandomVariable,
periodic_diff_sd_rv: RandomVariable,
) -> None:
"""
Default constructor for RtWeeklyDiffARProcess class.

Parameters
----------
name : str
Name of the site.
offset : int
Relative point at which data starts, must be between 0 and 6.
log_rt_rv : RandomVariable
Log Rt prior for the first two observations.
autoreg_rv : RandomVariable
Autoregressive parameter.
periodic_diff_sd_rv : RandomVariable
Standard deviation of the noise.

Returns
-------
None
"""

super().__init__(
name=name,
offset=offset,
period_size=7,
log_rt_rv=log_rt_rv,
autoreg_rv=autoreg_rv,
periodic_diff_sd_rv=periodic_diff_sd_rv,
)

return None
19 changes: 11 additions & 8 deletions test/test_rtperiodicdiff.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from numpy.testing import assert_array_equal

from pyrenew.deterministic import DeterministicVariable
from pyrenew.process import RtWeeklyDiffARProcess
from pyrenew.process import RtPeriodicDiffARProcess


def test_rtweeklydiff() -> None:
Expand All @@ -17,7 +17,7 @@ def test_rtweeklydiff() -> None:
params = {
"name": "test",
"offset": 0,
"log_rt_rv": DeterministicVariable(
"log_rt_init_rv": DeterministicVariable(
name="log_rt", value=jnp.array([0.1, 0.2])
),
"autoreg_rv": DeterministicVariable(
Expand All @@ -26,10 +26,11 @@ def test_rtweeklydiff() -> None:
"periodic_diff_sd_rv": DeterministicVariable(
name="periodic_diff_sd_rv", value=jnp.array([0.1])
),
"period_size": 7,
}
duration = 30

rtwd = RtWeeklyDiffARProcess(**params)
rtwd = RtPeriodicDiffARProcess(**params)

with numpyro.handlers.seed(rng_seed=223):
rt = rtwd(duration=duration).rt.value
Expand All @@ -44,7 +45,7 @@ def test_rtweeklydiff() -> None:

# Checking start off a different day of the week
params["offset"] = 5
rtwd = RtWeeklyDiffARProcess(**params)
rtwd = RtPeriodicDiffARProcess(**params)

with numpyro.handlers.seed(rng_seed=223):
rt2 = rtwd(duration=duration).rt.value
Expand All @@ -65,7 +66,7 @@ def test_rtweeklydiff_no_autoregressive() -> None:
params = {
"name": "test",
"offset": 0,
"log_rt_rv": DeterministicVariable(
"log_rt_init_rv": DeterministicVariable(
name="log_rt", value=jnp.array([0.0, 0.0])
),
# No autoregression!
Expand All @@ -76,9 +77,10 @@ def test_rtweeklydiff_no_autoregressive() -> None:
name="periodic_diff_sd_rv",
value=jnp.array([0.1]),
),
"period_size": 7,
}

rtwd = RtWeeklyDiffARProcess(**params)
rtwd = RtPeriodicDiffARProcess(**params)

duration = 1000

Expand Down Expand Up @@ -109,7 +111,7 @@ def test_rtperiodicdiff_smallsample(inits):
params = {
"name": "test",
"offset": 0,
"log_rt_rv": DeterministicVariable(
"log_rt_init_rv": DeterministicVariable(
name="log_rt",
value=inits,
),
Expand All @@ -120,9 +122,10 @@ def test_rtperiodicdiff_smallsample(inits):
name="periodic_diff_sd_rv",
value=jnp.array([0.1]),
),
"period_size": 7,
}

rtwd = RtWeeklyDiffARProcess(**params)
rtwd = RtPeriodicDiffARProcess(**params)

with numpyro.handlers.seed(rng_seed=223):
rt = rtwd(duration=6).rt.value
Expand Down