Commit fba485e

Add the option to manually define the simulation start when calibrating a model (#92)

twallema authored Sep 12, 2024
1 parent 674adac commit fba485e
Showing 9 changed files with 349 additions and 178 deletions.
19 changes: 10 additions & 9 deletions docs/optimization.md
@@ -4,22 +4,23 @@

 ### Log posterior probability
 
-***class* log_posterior_probability(model, parameter_names, bounds, data, states, log_likelihood_fnc, log_likelihood_fnc_args, weights=None, log_prior_prob_fnc=None, log_prior_prob_fnc_args=None, initial_states=None, aggregation_function=None, labels=None)**
+***class* log_posterior_probability(model, parameter_names, bounds, data, states, log_likelihood_fnc, log_likelihood_fnc_args, start_sim=None, weights=None, log_prior_prob_fnc=None, log_prior_prob_fnc_args=None, initial_states=None, aggregation_function=None, labels=None)**
 
 **Parameters:**
 
-* **model** (object) - An initialized ODE or JumpProcess.
+* **model** (object) - An initialized `pySODM.models.base.ODE` or `pySODM.models.base.JumpProcess` model.
 * **parameter_names** (list) - Names of model parameters (type: str) to calibrate. Model parameters must be of type float (0D), list containing float (1D), or np.ndarray (nD).
-* **bounds** (list) - Lower and upper bound of calibrated parameters provided as lists/tuples containing lower and upper bound: example: `[(lb_1, ub_1), ..., (lb_n, ub_n)]`. Values falling outside these bounds will be restricted to the provided ranges before simulating the model.
-* **data** (list) - Contains the datasets (type: pd.Series/pd.DataFrame) the model should be calibrated to. For one dataset use `[dataset,]`. Must contain a time index named `time` or `date`. Additional axes must be implemented using a `pd.Multiindex` and must bear the names/contain the coordinates of a valid model dimension.
-* **states** (list) - Names of the model states (type: str) the respective datasets should be matched to.
-* **log_likelihood_fnc** (list) - Contains a log likelihood function for every provided dataset.
-* **log_likelihood_fnc_args** (list) - Contains the arguments of the log likelihood functions. If the log likelihood function has no arguments (`ll_poisson`), provide an empty list.
+* **bounds** (list) - Lower and upper bound of calibrated parameters. Provided as a list or tuple containing lower and upper bound: example: `bounds = [(lb_1, ub_1), ..., (lb_n, ub_n)]`.
+* **data** (list) - Contains the datasets (type: pd.Series/pd.DataFrame) the model should be calibrated to. If there is only one dataset use `data = [df,]`. Dataframes must contain an index named `time` or `date`. Stratified data can be incorporated using a `pd.MultiIndex`, whose index levels must have names corresponding to valid model dimensions, and whose indices must be valid dimension coordinates.
+* **states** (list) - Names of the model states (type: str) the respective datasets should be matched to. Must have the same length as `data`.
+* **log_likelihood_fnc** (list) - Contains a log likelihood function for every provided dataset. Must have the same length as `data`.
+* **log_likelihood_fnc_args** (list) - Contains the arguments of the log likelihood functions. If the log likelihood function has no arguments (such as `ll_poisson`), provide an empty list. Must have the same length as `data`.
+* **start_sim** (int/float or str/datetime) - optional - Can be used to alter the start of the simulation. By default, the start of the simulation is chosen as the earliest time/date found in the datasets.
 * **weights** (list) - optional - Contains the weights of every dataset in the final log posterior probability. Defaults to one for every dataset.
 * **log_prior_prob_fnc** (list) - optional - Contains a prior probability function for every calibrated parameter. Defaults to a uniform prior using the provided bounds.
 * **log_prior_prob_fnc_args** (list) - optional - Contains the arguments of the provided prior probability functions.
-* **initial_states** (list) - optional - Contains a dictionary of initial states for every dataset.
-* **aggregation_function** (callable function or list) - optional - A user-defined function to manipulate the model output before matching it to data. The function takes as input an `xarray.DataArray`, resulting from selecting the simulation output at the state we wish to match to the dataset (`model_output_xarray_Dataset['state_name']`), as its input. The output of the function must also be an `xarray.DataArray`. No checks are performed on the input or output of the aggregation function, use at your own risk. Illustrative use case: I have a spatially explicit epidemiological model and I desire to simulate it a high spatial resolutioni. However, data is only available on a lower level of spatial resolution. Hence, I use an aggregation function to properly aggregate the spatial levels. I change the coordinates on the spatial dimensions in the model output. Valid inputs for the argument `aggregation_function`are: 1) one callable function --> applied to every dataset. 2) A list containing one callable function --> applied to every dataset. 3) A list containing a callable function for every dataset --> every dataset has its own aggregation function.
+* **initial_states** (list) - optional - Contains a dictionary of initial states for every dataset.
+* **aggregation_function** (callable function or list) - optional - A user-defined function to manipulate the model output before matching it to data. The function takes as input an `xarray.DataArray`, obtained by selecting the simulation output at the state we wish to match to the dataset (`model_output_xarray_Dataset['state_name']`). The output of the function must also be an `xarray.DataArray`. No checks are performed on the input or output of the aggregation function, use at your own risk. Illustrative use case: I have a spatially explicit epidemiological model and I desire to simulate it at a fine spatial resolution. However, data is only available at a coarser level, so I use an aggregation function to aggregate the spatial levels by changing the coordinates of the spatial dimensions in the model output. Valid inputs for the argument `aggregation_function` are: 1) one callable function --> applied to every dataset. 2) A list containing one callable function --> applied to every dataset. 3) A list containing a callable function for every dataset --> every dataset has its own aggregation function.
 * **labels** (list) - optional - Contains a custom label for the calibrated parameters. Defaults to the names provided in `parameter_names`.

**Methods:**
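The `data` requirements above (an index named `time` or `date`, extra axes as a `pd.MultiIndex` whose level names match model dimensions) can be sketched with plain pandas. The dimension name `age` and the date range below are hypothetical, not taken from pySODM:

```python
import pandas as pd

# Hypothetical dataset in the shape log_posterior_probability expects:
# a pd.Series indexed by 'date', with one extra level ('age') whose name
# must match a model dimension and whose entries must be valid coordinates.
dates = pd.date_range("2022-12-01", "2022-12-03", freq="D")
ages = ["0-20", "20+"]
index = pd.MultiIndex.from_product([dates, ages], names=["date", "age"])
cases = pd.Series(range(len(index)), index=index, name="cases")

# The index level names are what gets matched against model dimensions
print(list(cases.index.names))  # ['date', 'age']
```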
8 changes: 6 additions & 2 deletions docs/quickstart.md
@@ -47,15 +47,19 @@ To initialize the model, provide a dictionary containing the initial values of t
 model = SIR(states={'S': 1000, 'I': 1}, parameters={'beta': 0.35, 'gamma': 5})
 ```
 
-Simulate the model using its `sim()` method. pySODM supports the use of `datetime` types as timesteps.
+Simulate the model using its `sim()` method. pySODM supports the use of dates to index simulations; both string representations of dates in the format `'yyyy-mm-dd'` and `datetime.datetime()` objects can be used.
 
 ```python
 # Timesteps
 out = model.sim(121)
 
-# Dates
+# String representation of dates: 'yyyy-mm-dd' only
 out = model.sim(['2022-12-01', '2023-05-01'])
+
+# Datetime representation of time + date
+from datetime import datetime as datetime
+out = model.sim([datetime(2022, 12, 1), datetime(2023, 5, 1)])
 
 # Tailor method and tolerance of integrator
 out = model.sim(121, method='RK45', rtol='1e-4')
 ```
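The `'yyyy-mm-dd'` convention above implies a straightforward conversion from strings to `datetime` objects. A standalone sketch of that normalisation (illustrative only, not pySODM's actual code; `parse_window` is a hypothetical helper):

```python
from datetime import datetime

def parse_window(start, stop):
    """Turn 'yyyy-mm-dd' strings (or datetimes) into datetimes plus a day count."""
    if isinstance(start, str):
        start = datetime.strptime(start, "%Y-%m-%d")
    if isinstance(stop, str):
        stop = datetime.strptime(stop, "%Y-%m-%d")
    return start, (stop - start).days

start, n_days = parse_window("2022-12-01", "2023-05-01")
print(n_days)  # 151
```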
47 changes: 9 additions & 38 deletions src/pySODM/models/base.py
Expand Up @@ -11,7 +11,7 @@
 from pySODM.models.validation import merge_parameter_names_parameter_stratified_names, validate_draw_function, validate_simulation_time, validate_dimensions, \
     validate_time_dependent_parameters, validate_integrate, check_duplicates, build_state_sizes_dimensions, validate_dimensions_per_state, \
     validate_initial_states, validate_integrate_or_compute_rates_signature, validate_provided_parameters, validate_parameter_stratified_sizes, \
-    validate_apply_transitionings_signature, validate_compute_rates, validate_apply_transitionings
+    validate_apply_transitionings_signature, validate_compute_rates, validate_apply_transitionings, validate_solution_methods_ODE, validate_solution_methods_JumpProcess

class JumpProcess:
"""
@@ -375,7 +375,7 @@ def _mp_sim_single(self, drawn_parameters, seed, time, actual_start_date, method
         np.random.seed(seed)
         return self._sim_single(time, actual_start_date, method, tau, output_timestep)
 
-    def sim(self, time, warmup=0, N=1, draw_function=None, draw_function_kwargs={}, processes=None, method='tau_leap', tau=0.5, output_timestep=1):
+    def sim(self, time, warmup=0, N=1, draw_function=None, draw_function_kwargs={}, processes=None, method='tau_leap', tau=1, output_timestep=1):
 
         """
         Simulate a model during a given time period.
@@ -422,27 +422,15 @@ def sim(self, time, warmup=0, N=1, draw_function=None, draw_function_kwargs={},
         output: xarray.Dataset
             Simulation output
         """
 
-        # Input checks on solution method and timestep
-        if not isinstance(method, str):
-            raise TypeError(
-                "solver method 'method' must be of type string"
-            )
-        if not isinstance(tau, (int,float)):
-            raise TypeError(
-                "discrete timestep 'tau' must be of type int or float"
-            )
+        # Input checks on solution settings
+        validate_solution_methods_JumpProcess(method, tau)
         # Input checks on supplied simulation time
         time, actual_start_date = validate_simulation_time(time, warmup)
         # Input checks related to draw functions
         if draw_function:
             # validate function
             validate_draw_function(draw_function, draw_function_kwargs, self.parameters)
-            # function provided but N=1 --> user likely forgot 'N'
-            if N == 1:
-                raise ValueError(
-                    "you specified a draw function but N=1, have you forgotten 'N'?"
-                )
 
         # Copy parameter dictionary --> dict is global
         cp = copy.deepcopy(self.parameters)
@@ -682,7 +670,7 @@ def _mp_sim_single(self, drawn_parameters, time, actual_start_date, method, rtol
         out = self._sim_single(time, actual_start_date, method, rtol, output_timestep, tau)
         return out
 
-    def sim(self, time, warmup=0, N=1, draw_function=None, draw_function_kwargs={}, processes=None, method='RK23', rtol=1e-3, tau=None, output_timestep=1):
+    def sim(self, time, warmup=0, N=1, draw_function=None, draw_function_kwargs={}, processes=None, method='RK23', rtol=1e-4, tau=None, output_timestep=1):
         """
         Simulate a model during a given time period.
         Can optionally perform `N` repeated simulations with sampling of model parameters using a function `draw_function`.
Expand Down Expand Up @@ -732,32 +720,15 @@ def sim(self, time, warmup=0, N=1, draw_function=None, draw_function_kwargs={},
         output: xarray.Dataset
             Simulation output
         """
 
-        # Input checks on solver settings
-        if not isinstance(rtol, float):
-            raise TypeError(
-                "relative solver tolerance 'rtol' must be of type float"
-            )
-        if not isinstance(method, str):
-            raise TypeError(
-                "solver method 'method' must be of type string"
-            )
-        if tau != None:
-            if not isinstance(tau, (int,float)):
-                raise TypeError(
-                    "discrete timestep 'tau' must be of type int or float"
-                )
+        # Input checks on solution settings
+        validate_solution_methods_ODE(rtol, method, tau)
         # Input checks on supplied simulation time
         time, actual_start_date = validate_simulation_time(time, warmup)
         # Input checks related to draw functions
         if draw_function:
             # validate function
             validate_draw_function(draw_function, draw_function_kwargs, self.parameters)
-            # function provided but N=1 --> user likely forgot 'N'
-            if N == 1:
-                raise ValueError(
-                    "you specified a draw function but N=1, have you forgotten 'N'?"
-                )
         # provinding 'N' but no draw function: wasteful of resources
         if ((N != 1) & (draw_function==None)):
             raise ValueError('attempting to perform N={0} repeated simulations without using a draw function'.format(N))
78 changes: 68 additions & 10 deletions src/pySODM/models/validation.py
@@ -21,7 +21,7 @@ def validate_simulation_time(time, warmup):
         time = [0-warmup, time]
     elif isinstance(time, list):
         if not len(time) == 2:
-            raise ValueError(f"'Time' must be of format: time=[start, stop]. You have supplied: time={time}.")
+            raise ValueError(f"wrong length of list-like simulation start and stop (length: {len(time)}). correct format: time=[start, stop] (length: 2).")
         else:
             # If they are all int or flat (or commonly occuring np.int64/np.float64)
             if all([isinstance(item, (int,float,np.int32,np.float32,np.int64,np.float64)) for item in time]):
@@ -37,29 +37,87 @@
                 actual_start_date = time[0] - timedelta(days=warmup)
                 time = [0, date_to_diff(actual_start_date, time[1])]
             else:
+                types = [type(t) for t in time]
                 raise ValueError(
-                    f"List-like input of simulation start and stop must contain either all int/float or all str/datetime, not a combination of the two "
+                    "simulation start and stop must have the format: time=[start, stop]."
+                    " 'start' and 'stop' must have the same datatype: int/float, str ('yyyy-mm-dd'), or datetime."
+                    f" mixing of types is not allowed. you supplied: {types} "
                 )
+    elif isinstance(time, (str,datetime)):
+        raise TypeError(
+            "You have only provided one date as input 'time', how am I supposed to know when to start/end this simulation?"
+        )
     else:
         raise TypeError(
-            "Input argument 'time' must be a single number (int or float), a list of format: time=[start, stop], a string representing of a timestamp, or a timestamp"
+            "'time' must be 1) a single int/float representing the end of the simulation, 2) a list of format: time=[start, stop]."
         )
 
     if time[1] < time[0]:
         raise ValueError(
-            "Start of simulation is chronologically after end of simulation"
+            "start of simulation is chronologically after end of simulation"
         )
     elif time[0] == time[1]:
         # TODO: Might be usefull to just return the initial condition in this case?
         raise ValueError(
-            "Start of simulation is the same as the end of simulation"
+            "start of simulation is the same as the end of simulation"
        )
     return time, actual_start_date
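The accepted shapes of `time` can be exercised in isolation. Below is a simplified re-implementation for illustration only (`normalise_time` is a hypothetical stand-in; the real `validate_simulation_time` also handles numpy scalar types and returns an `actual_start_date` for date inputs):

```python
from datetime import datetime

def normalise_time(time, warmup=0):
    # Single number: simulate from 0 (minus warmup) up to 'time'
    if isinstance(time, (int, float)) and not isinstance(time, bool):
        return [0 - warmup, time]
    if isinstance(time, list) and len(time) == 2:
        # All numeric: shift the start backward by the warmup
        if all(isinstance(t, (int, float)) for t in time):
            return [time[0] - warmup, time[1]]
        # All dates: convert to an integer day span
        if all(isinstance(t, (str, datetime)) for t in time):
            start, stop = [
                datetime.strptime(t, "%Y-%m-%d") if isinstance(t, str) else t
                for t in time
            ]
            return [0, (stop - start).days + warmup]
        raise ValueError("start and stop must share one datatype")
    raise TypeError("time must be an int/float or a [start, stop] list")

print(normalise_time(121, warmup=7))                 # [-7, 121]
print(normalise_time(["2022-12-01", "2023-05-01"]))  # [0, 151]
```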

+def validate_solution_methods_ODE(rtol, method, tau):
+    """
+    Validates the input arguments of the ODE.sim() function
+
+    input
+    -----
+
+    rtol: float
+        Relative solver tolerance
+    method: str
+        Solver method: 'RK23', 'RK45', 'DOP853', 'Radau', 'BDF', 'LSODA'
+    tau: int/float
+        Discrete integration size of timestep.
+    """
+
+    if not isinstance(rtol, float):
+        raise TypeError(
+            "relative solver tolerance 'rtol' must be of type float"
+        )
+    if not isinstance(method, str):
+        raise TypeError(
+            "solver method 'method' must be of type string"
+        )
+    if method not in ['RK23', 'RK45', 'DOP853', 'Radau', 'BDF', 'LSODA']:
+        raise ValueError(
+            f"invalid solution method '{method}'. valid methods: 'RK23', 'RK45', 'DOP853', 'Radau', 'BDF', 'LSODA'"
+        )
+    if tau != None:
+        if not isinstance(tau, (int,float)):
+            raise TypeError(
+                "discrete timestep 'tau' must be of type int or float"
+            )
+
+def validate_solution_methods_JumpProcess(method, tau):
+    """
+    Validates the input arguments of the JumpProcess.sim() function
+
+    method: str
+        Solver method: 'SSA' or 'tau_leap'
+    tau: int/float
+        If method == 'tau_leap' --> leap size
+    """
+
+    # Input checks on solution method and timestep
+    if not isinstance(method, str):
+        raise TypeError(
+            "solver method 'method' must be of type string"
+        )
+    if method not in ['tau_leap', 'SSA']:
+        raise ValueError(
+            f"invalid solution method '{method}'. valid methods: 'SSA' and 'tau_leap'"
+        )
+    if not isinstance(tau, (int,float)):
+        raise TypeError(
+            "discrete timestep 'tau' must be of type int or float"
+        )
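Because the new validators are pure functions of their arguments, their behaviour is easy to check standalone. A condensed copy of the ODE variant from this commit (error messages shortened), exercised with a valid and an invalid solver name:

```python
# Condensed copy of validate_solution_methods_ODE from this commit.
def validate_solution_methods_ODE(rtol, method, tau):
    if not isinstance(rtol, float):
        raise TypeError("relative solver tolerance 'rtol' must be of type float")
    if not isinstance(method, str):
        raise TypeError("solver method 'method' must be of type string")
    if method not in ['RK23', 'RK45', 'DOP853', 'Radau', 'BDF', 'LSODA']:
        raise ValueError(f"invalid solution method '{method}'")
    if tau is not None and not isinstance(tau, (int, float)):
        raise TypeError("discrete timestep 'tau' must be of type int or float")

validate_solution_methods_ODE(1e-4, 'RK45', None)  # valid settings pass silently

try:
    validate_solution_methods_ODE(1e-4, 'Euler', None)
except ValueError as e:
    print(e)  # invalid solution method 'Euler'
```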

def validate_draw_function(draw_function, draw_function_kwargs, parameters):
"""Validates the draw function's input and output. For use in the sim() functions of the ODE and JumpProcess classes (base.py).
