Skip to content

Commit

Permalink
implementing feedback on PR
Browse files Browse the repository at this point in the history
  • Loading branch information
arik-shurygin committed Feb 16, 2024
1 parent e8adf86 commit e070161
Show file tree
Hide file tree
Showing 11 changed files with 286 additions and 320 deletions.
87 changes: 34 additions & 53 deletions config/config.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import datetime
import json
import os
import subprocess
from enum import IntEnum

import git
Expand Down Expand Up @@ -31,41 +30,36 @@ def convert_types(self, config):
takes a dictionary of config parameters, consults the PARAMETERS global list and attempts to convert the type
of each parameter whos name matches.
"""
for param in PARAMETERS:
key = param["name"]
cast_type = param.get("type", False)
for parameter in PARAMETERS:
key = parameter["name"]
# if this validator needs to be cast
if cast_type:
config_val = config.get(key, False)
if "type" in parameter.keys():
cast_type = parameter["type"]
# make sure we actually have the value in our incoming config
if config_val:
config[key] = cast_type(config_val)
if key in config.keys():
config[key] = cast_type(config[key])
return config

def set_downstream_parameters(self):
"""
A parameter that checks if a specific parameter exists, then sets any parameters that depend on it.
A function that checks if a specific parameter exists, then sets any parameters that depend on it.
E.g. `NUM_AGE_GROUPS` = len(`AGE_LIMITS`) if `AGE_LIMITS` exists, set `NUM_AGE_GROUPS`
E.g, `NUM_AGE_GROUPS` = len(`AGE_LIMITS`) if `AGE_LIMITS` exists, set `NUM_AGE_GROUPS`
"""
for validator in PARAMETERS:
key = validator["name"]
downstream_function = validator.get("downstream", False)
for parameter in PARAMETERS:
key = parameter["name"]
# if the key has no downstream functions, do nothing
if downstream_function:
if "downstream" in parameter.keys():
downstream_function = parameter["downstream"]
# turn into list of len(1) if not already
if not isinstance(key, list):
key = [key]
# dont try to create downstream unless config has all necessary keys
if all([hasattr(self, k) for k in key]):
downstream_function(self, key)
# take note of the current git hash for reproducibility reasons
self.GIT_HASH = (
subprocess.check_output(["git", "rev-parse", "HEAD"])
.decode("ascii")
.strip()
)
self.LOCAL_REPO = git.Repo()
self.GIT_HASH = self.LOCAL_REPO.head.object.hexsha

def assert_valid_configuration(self):
"""
Expand All @@ -76,24 +70,16 @@ def assert_valid_configuration(self):
"""
for param in PARAMETERS:
key = param["name"]
# converting to list now makes less if branches, will convert back later
key = key if isinstance(key, list) else [key]
key = make_list_if_not(key)
validator_funcs = param.get("validate", False)
# if there are validators to test, and the key(s) are found in our config, lets test them
if validator_funcs and all([hasattr(self, k) for k in key]):
validator_funcs = (
validator_funcs
if isinstance(validator_funcs, list)
else [validator_funcs]
)
# converting to list now makes less if branches, will convert back later
validator_funcs = make_list_if_not(validator_funcs)
vals = [getattr(self, k) for k in key]
# can not validate a distribution since it does not have 1 fixed value
distribution_involved = False
for val in vals:
val_temp = (
val if isinstance(val, (list, np.ndarray)) else [val]
)
val_temp = make_list_if_not(val)
if any(
[
issubclass(
Expand Down Expand Up @@ -121,6 +107,10 @@ def assert_valid_configuration(self):
]


def make_list_if_not(obj):
return obj if isinstance(obj, (list, np.ndarray)) else [obj]


def distribution_converter(dct):
# a distribution is identified by the "distribution" and "params" keys
if "distribution" in dct.keys() and "params" in dct.keys():
Expand Down Expand Up @@ -201,25 +191,18 @@ def test_not_negative(key, value):


def age_limit_checks(key, age_limits):
assert all(
[
age_limits[idx] > age_limits[idx - 1]
for idx in range(1, len(age_limits))
]
), ("%s must be strictly increasing" % key)
test_ascending(key, age_limits)
assert (
age_limits[-1] < 85
), "age limits can not exceed 84 years of age, the last age bin is implied and does not need to be included"


def compare_geq(keys, vals):
key1, key2 = keys[0], keys[1]
val1, val2 = vals[0], vals[1]
assert val1 >= val2, "%s must be >= %s, however got %d >= %d" % (
key1,
key2,
val1,
val2,
assert vals[0] >= vals[1], "%s must be >= %s, however got %d >= %d" % (
keys[0],
keys[1],
vals[0],
vals[1],
)


Expand All @@ -236,11 +219,9 @@ def test_non_empty(key, val):


def test_len(keys, vals):
key1, key2 = keys[0], keys[1]
len_of_array, array = vals[0], vals[1]
assert len_of_array == len(array), "len(%s) must equal to %s" % (
key2,
key1,
assert vals[0] == len(vals[1]), "len(%s) must equal to %s" % (
keys[1],
keys[0],
)


Expand Down Expand Up @@ -341,11 +322,6 @@ class is accepted to modify/create the downstream parameters.
"validate": test_non_empty,
"type": np.array,
},
{
"name": "NUM_WANING_COMPARTMENTS",
"validate": test_positive,
"downstream": set_wane_enum,
},
{
"name": "WANING_TIMES",
"validate": [
Expand All @@ -355,6 +331,11 @@ class is accepted to modify/create the downstream parameters.
],
"downstream": set_num_waning_compartments,
},
{
"name": "NUM_WANING_COMPARTMENTS",
"validate": test_positive,
"downstream": set_wane_enum,
},
{
"name": "WANING_PROTECTIONS",
"validate": lambda key, vals: [
Expand Down
2 changes: 1 addition & 1 deletion config/config_inferer_covid.json
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
"transforms": {
"distribution": "AffineTransform",
"params": {
"loc": 1,
"loc": 2.5,
"scale": 1
}
}
Expand Down
9 changes: 5 additions & 4 deletions example_end_to_end_run.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,11 @@
from model_odes.seip_model import seip_ode

if __name__ == "__main__":
GLOBAL_CONFIG_PATH = "config/config_global.json"
INITIALIZER_CONFIG_PATH = "config/config_initializer_covid.json"
RUNNER_CONFIG_PATH = "config/config_runner_covid.json"
INTERPRETER_CONFIG_PATH = "config/config_interpreter_covid.json"
config_path = "config/"
GLOBAL_CONFIG_PATH = config_path + "config_global.json"
INITIALIZER_CONFIG_PATH = config_path + "config_initializer_covid.json"
RUNNER_CONFIG_PATH = config_path + "config_runner_covid.json"
INTERPRETER_CONFIG_PATH = config_path + "config_interpreter_covid.json"
# model = build_basic_mechanistic_model(ConfigBase())
initializer = CovidInitializer(INITIALIZER_CONFIG_PATH, GLOBAL_CONFIG_PATH)
static_params = StaticValueParameters(
Expand Down
26 changes: 22 additions & 4 deletions mechanistic_model/abstract_initializer.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,21 +4,39 @@
mechanistic_initializers will often be tasked with reading, parsing, and combining data sources
to produce an initial state representing some analyzed population
"""

from abc import ABC, abstractmethod

import utils


class MechanisticInitializer(ABC):
@abstractmethod
def __init__(self, initializer_config):
self.INITIAL_STATE
pass

@abstractmethod
def assert_valid_configuration(self):
pass

def get_initial_state(self):
"""
Returns the initial state of the model as defined by the child class in __init__
"""
return self.INITIAL_STATE

def load_initial_population_fractions(self):
"""
a wrapper function which loads age demographics for the US and sets the inital population fraction by age bin.
Updates
----------
`self.config.INITIAL_POPULATION_FRACTIONS` : numpy.ndarray
proportion of the total population that falls into each age group,
length of this array is equal the number of age groups and will sum to 1.0.
"""
populations_path = (
self.config.DEMOGRAPHIC_DATA_PATH
+ "population_rescaled_age_distributions/"
)
# TODO support getting more regions than just 1
self.config.INITIAL_POPULATION_FRACTIONS = utils.load_age_demographics(
populations_path, self.config.REGIONS, self.config.AGE_LIMITS
)[self.config.REGIONS[0]]
Loading

0 comments on commit e070161

Please sign in to comment.