diff --git a/docs/source/tutorials/periodic_effects.qmd b/docs/source/tutorials/periodic_effects.qmd index 1603ed59..e215dedd 100644 --- a/docs/source/tutorials/periodic_effects.qmd +++ b/docs/source/tutorials/periodic_effects.qmd @@ -24,10 +24,11 @@ from pyrenew import process, deterministic ```{python} # The random process for Rt -rt_proc = process.RtWeeklyDiffARProcess( +rt_proc = process.RtPeriodicDiffARProcess( name="rt_weekly_diff", offset=0, - log_rt_rv=deterministic.DeterministicVariable( + period_size=7, + log_rt_init_rv=deterministic.DeterministicVariable( name="log_rt", value=jnp.array([0.1, 0.2]) ), autoreg_rv=deterministic.DeterministicVariable( @@ -57,7 +58,7 @@ for i in range(0, 30, 7): plt.show() ``` -The implementation of the `RtWeeklyDiffARProcess` (which is an instance of `RtPeriodicDiffARProcess`), uses `repeat_until_n` to repeating values: `repeat_until_n(..., period_size=7)`. The `RtWeeklyDiff` class is a particular case of `RtPeriodicDiff` with a period size of seven. +The implementation of the `RtPeriodicDiffARProcess` uses `repeat_until_n` to repeating values: `repeat_until_n(..., period_size=7)`. The `RtWeeklyDiff` class is a particular case of `RtPeriodicDiff` with a period size of seven. ## Repeated sequences (tiling) diff --git a/pyrenew/process/__init__.py b/pyrenew/process/__init__.py index 638ea45d..45996193 100644 --- a/pyrenew/process/__init__.py +++ b/pyrenew/process/__init__.py @@ -10,10 +10,7 @@ ) from pyrenew.process.periodiceffect import DayOfWeekEffect, PeriodicEffect from pyrenew.process.randomwalk import RandomWalk, StandardNormalRandomWalk -from pyrenew.process.rtperiodicdiffar import ( - RtPeriodicDiffARProcess, - RtWeeklyDiffARProcess, -) +from pyrenew.process.rtperiodicdiffar import RtPeriodicDiffARProcess __all__ = [ "IIDRandomSequence", @@ -25,5 +22,4 @@ "PeriodicEffect", "DayOfWeekEffect", "RtPeriodicDiffARProcess", - "RtWeeklyDiffARProcess", ] diff --git a/pyrenew/process/rtperiodicdiffar.py b/pyrenew/process/rtperiodicdiffar.py index 9186b9ef..2af85c7b 100644 --- a/pyrenew/process/rtperiodicdiffar.py +++ b/pyrenew/process/rtperiodicdiffar.py @@ -54,7 +54,7 @@ def __init__( name: str, offset: int, period_size: int, - log_rt_rv: RandomVariable, + log_rt_init_rv: RandomVariable, autoreg_rv: RandomVariable, periodic_diff_sd_rv: RandomVariable, ar_process_suffix: str = "_first_diff_ar_process_noise", @@ -69,7 +69,7 @@ def __init__( offset : int Relative point at which data starts, must be between 0 and period_size - 1. - log_rt_rv : RandomVariable + log_rt_init_rv : RandomVariable Log Rt prior for the first two observations. autoreg_rv : RandomVariable Autoregressive parameter. @@ -87,7 +87,7 @@ def __init__( """ self.validate( - log_rt_rv=log_rt_rv, + log_rt_init_rv=log_rt_init_rv, autoreg_rv=autoreg_rv, periodic_diff_sd_rv=periodic_diff_sd_rv, ) @@ -95,7 +95,7 @@ def __init__( self.name = name self.period_size = period_size self.offset = offset - self.log_rt_rv = log_rt_rv + self.log_rt_init_rv = log_rt_init_rv self.autoreg_rv = autoreg_rv self.periodic_diff_sd_rv = periodic_diff_sd_rv self.ar_diff = DifferencedProcess( @@ -109,7 +109,7 @@ def __init__( @staticmethod def validate( - log_rt_rv: any, + log_rt_init_rv: any, autoreg_rv: any, periodic_diff_sd_rv: any, ) -> None: @@ -118,7 +118,7 @@ def validate( Parameters ---------- - log_rt_rv : any + log_rt_init_rv : any Log Rt prior for the first two observations. autoreg_rv : any Autoregressive parameter. @@ -130,7 +130,7 @@ def validate( None """ - _assert_sample_and_rtype(log_rt_rv) + _assert_sample_and_rtype(log_rt_init_rv) _assert_sample_and_rtype(autoreg_rv) _assert_sample_and_rtype(periodic_diff_sd_rv) @@ -159,9 +159,9 @@ def sample( """ # Initial sample - log_rt_rv = self.log_rt_rv.sample(**kwargs)[0].value - b = self.autoreg_rv.sample(**kwargs)[0].value - s_r = self.periodic_diff_sd_rv.sample(**kwargs)[0].value + log_rt_init = self.log_rt_init_rv.sample(**kwargs)[0].value + autoreg = self.autoreg_rv.sample(**kwargs)[0].value + noise_sd = self.periodic_diff_sd_rv.sample(**kwargs)[0].value # How many periods to sample? n_periods = (duration + self.period_size - 1) // self.period_size @@ -170,11 +170,11 @@ def sample( log_rt = self.ar_diff( n=n_periods, - init_vals=jnp.array([log_rt_rv[0]]), - autoreg=b, - noise_sd=s_r, + init_vals=jnp.array([log_rt_init[0]]), + autoreg=autoreg, + noise_sd=noise_sd, fundamental_process_init_vals=jnp.array( - [log_rt_rv[1] - log_rt_rv[0]] + [log_rt_init[1] - log_rt_init[0]] ), )[0] @@ -190,49 +190,3 @@ def sample( t_unit=self.t_unit, ), ) - - -class RtWeeklyDiffARProcess(RtPeriodicDiffARProcess): - """ - Weekly Rt with autoregressive first differences. - """ - - def __init__( - self, - name: str, - offset: int, - log_rt_rv: RandomVariable, - autoreg_rv: RandomVariable, - periodic_diff_sd_rv: RandomVariable, - ) -> None: - """ - Default constructor for RtWeeklyDiffARProcess class. - - Parameters - ---------- - name : str - Name of the site. - offset : int - Relative point at which data starts, must be between 0 and 6. - log_rt_rv : RandomVariable - Log Rt prior for the first two observations. - autoreg_rv : RandomVariable - Autoregressive parameter. - periodic_diff_sd_rv : RandomVariable - Standard deviation of the noise. - - Returns - ------- - None - """ - - super().__init__( - name=name, - offset=offset, - period_size=7, - log_rt_rv=log_rt_rv, - autoreg_rv=autoreg_rv, - periodic_diff_sd_rv=periodic_diff_sd_rv, - ) - - return None diff --git a/test/test_rtperiodicdiff.py b/test/test_rtperiodicdiff.py index 8d1ac28a..413a9f9a 100644 --- a/test/test_rtperiodicdiff.py +++ b/test/test_rtperiodicdiff.py @@ -8,7 +8,7 @@ from numpy.testing import assert_array_equal from pyrenew.deterministic import DeterministicVariable -from pyrenew.process import RtWeeklyDiffARProcess +from pyrenew.process import RtPeriodicDiffARProcess def test_rtweeklydiff() -> None: @@ -17,7 +17,7 @@ def test_rtweeklydiff() -> None: params = { "name": "test", "offset": 0, - "log_rt_rv": DeterministicVariable( + "log_rt_init_rv": DeterministicVariable( name="log_rt", value=jnp.array([0.1, 0.2]) ), "autoreg_rv": DeterministicVariable( @@ -26,10 +26,11 @@ def test_rtweeklydiff() -> None: "periodic_diff_sd_rv": DeterministicVariable( name="periodic_diff_sd_rv", value=jnp.array([0.1]) ), + "period_size": 7, } duration = 30 - rtwd = RtWeeklyDiffARProcess(**params) + rtwd = RtPeriodicDiffARProcess(**params) with numpyro.handlers.seed(rng_seed=223): rt = rtwd(duration=duration).rt.value @@ -44,7 +45,7 @@ def test_rtweeklydiff() -> None: # Checking start off a different day of the week params["offset"] = 5 - rtwd = RtWeeklyDiffARProcess(**params) + rtwd = RtPeriodicDiffARProcess(**params) with numpyro.handlers.seed(rng_seed=223): rt2 = rtwd(duration=duration).rt.value @@ -65,7 +66,7 @@ def test_rtweeklydiff_no_autoregressive() -> None: params = { "name": "test", "offset": 0, - "log_rt_rv": DeterministicVariable( + "log_rt_init_rv": DeterministicVariable( name="log_rt", value=jnp.array([0.0, 0.0]) ), # No autoregression! @@ -76,9 +77,10 @@ def test_rtweeklydiff_no_autoregressive() -> None: name="periodic_diff_sd_rv", value=jnp.array([0.1]), ), + "period_size": 7, } - rtwd = RtWeeklyDiffARProcess(**params) + rtwd = RtPeriodicDiffARProcess(**params) duration = 1000 @@ -109,7 +111,7 @@ def test_rtperiodicdiff_smallsample(inits): params = { "name": "test", "offset": 0, - "log_rt_rv": DeterministicVariable( + "log_rt_init_rv": DeterministicVariable( name="log_rt", value=inits, ), @@ -120,9 +122,10 @@ def test_rtperiodicdiff_smallsample(inits): name="periodic_diff_sd_rv", value=jnp.array([0.1]), ), + "period_size": 7, } - rtwd = RtWeeklyDiffARProcess(**params) + rtwd = RtPeriodicDiffARProcess(**params) with numpyro.handlers.seed(rng_seed=223): rt = rtwd(duration=6).rt.value