Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Inference refactor with some config work too #42

Merged
merged 5 commits into from
Feb 16, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
146 changes: 108 additions & 38 deletions config/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,9 @@

import git
import jax.numpy as jnp
import numpy as np
import numpyro.distributions as distributions
import numpyro.distributions.transforms as transforms


class Config:
Expand All @@ -14,7 +17,9 @@ def __init__(self, config_json_path) -> None:

def add_file(self, config_json_path):
# adds another config to self.__dict__ and resets downstream parameters again
config = json.load(open(config_json_path, "r"))
config = json.load(
open(config_json_path, "r"), object_hook=distribution_converter
)
config = self.convert_types(config)
self.__dict__.update(**config)
self.assert_valid_configuration()
Expand All @@ -26,9 +31,9 @@ def convert_types(self, config):
takes a dictionary of config parameters, consults the PARAMETERS global list and attempts to convert the type
of each parameter whos name matches.
"""
for validator in PARAMETERS:
key = validator["name"]
cast_type = validator.get("type", False)
for param in PARAMETERS:
key = param["name"]
cast_type = param.get("type", False)
# if this validator needs to be cast
if cast_type:
config_val = config.get(key, False)
Expand All @@ -46,16 +51,15 @@ def set_downstream_parameters(self):
for validator in PARAMETERS:
key = validator["name"]
downstream_function = validator.get("downstream", False)
# if the key has no downstream functions, dont bother
# if the key has no downstream functions, do nothing
if downstream_function:
# validator requires multiple params checked against eachother
if isinstance(key, list):
if all([hasattr(self, k) for k in key]):
downstream_function(self, key)
# just one param being tested
else:
if hasattr(self, key):
downstream_function(self, key)
# turn into list of len(1) if not already
if not isinstance(key, list):
key = [key]
# dont try to create downstream unless config has all necessary keys
if all([hasattr(self, k) for k in key]):
downstream_function(self, key)
# take note of the current git hash for reproducibility reasons
self.GIT_HASH = (
subprocess.check_output(["git", "rev-parse", "HEAD"])
.decode("ascii")
Expand All @@ -65,29 +69,78 @@ def set_downstream_parameters(self):

def assert_valid_configuration(self):
"""
checks the soundness of parameters passed into Config. Does not check for the existence of certain key parameters
checks the soundness of parameters passed into Config by referencing the name of parameters passed to the config
with the PARAMETERS global variable. If a distribution is passed instead of a value, blindly accepts the distribution.

Raises assert errors if parameter(s) are incongruent in some way.
"""
for param in PARAMETERS:
key = param["name"]
# converting to list now makes less if branches, will convert back later
key = key if isinstance(key, list) else [key]
validator_funcs = param.get("validate", False)
# check to make sure the key/keys are all in the config
key_in_config = (
all([hasattr(self, k) for k in key])
if isinstance(key, list)
else hasattr(self, key)
)
# if there are validators to test, and the key is found in our config, lets test them
if validator_funcs and key_in_config:
# ensure type is list
if not isinstance(validator_funcs, list):
validator_funcs = [validator_funcs]
# mutiple keys need to be tested against eachother
if isinstance(key, list):
vals = [getattr(self, k) for k in key]
else: # single key being tested
vals = getattr(self, key)
# if there are validators to test, and the key(s) are found in our config, lets test them
if validator_funcs and all([hasattr(self, k) for k in key]):
validator_funcs = (
validator_funcs
if isinstance(validator_funcs, list)
else [validator_funcs]
)
# converting to list now makes less if branches, will convert back later
vals = [getattr(self, k) for k in key]
# can not validate a distribution since it does not have 1 fixed value
distribution_involved = False
for val in vals:
val_temp = (
val if isinstance(val, (list, np.ndarray)) else [val]
)
if any(
[
issubclass(
type(v),
(
distributions.Distribution,
transforms.Transform,
),
)
]
for v in val_temp
):
distribution_involved = True
break
if distribution_involved:
continue
# val_func() throws assert errors if incongruence arrises
[val_func(key, vals) for val_func in validator_funcs]
[
(
val_func(key[0], vals[0])
if len(key) == 1 # convert back to floats if needed
else val_func(key, vals)
)
for val_func in validator_funcs
]


def distribution_converter(dct):
# a distribution is identified by the "distribution" and "params" keys
if "distribution" in dct.keys() and "params" in dct.keys():
try:
if dct["distribution"] in distribution_types.keys():
return distribution_types[dct["distribution"]](**dct["params"])
else:
raise KeyError(
"The distribution name is not found in the available distributions, "
"see distribution names here: https://num.pyro.ai/en/stable/distributions.html#distributions"
)
except Exception as e:
# reraise the error
raise Exception(
"There was an error parsing the following name as a distribution: %s \n "
"see docs to make sure you didnt misspell something: https://num.pyro.ai/en/stable/distributions.html#distributions"
% str(dct)
) from e
else: # do nothing if this isnt a distribution
return dct


#############################################################################
Expand All @@ -107,6 +160,10 @@ def set_downstream_age_variables(conf, _):
conf.AGE_GROUP_IDX = IntEnum("age", conf.AGE_GROUP_STRS, start=0)


def set_num_waning_compartments(conf, _):
conf.NUM_WANING_COMPARTMENTS = len(conf.WANING_TIMES)


def set_num_introduced_strains(conf, _):
"""
given INTRODUCTION_TIMES, set downstream variables from there
Expand Down Expand Up @@ -206,6 +263,18 @@ def test_zero(key, val):
assert val == 0, "value in %s must be zero" % key


# look at numpyro.distributions, copy over all the distribution names into the dictionary, along with their class constructors.
distribution_types = {
dist_name: distributions.__dict__.get(dist_name)
for dist_name in distributions.__all__
}
distribution_types.update(
**{
transform_name: transforms.__dict__.get(transform_name)
for transform_name in transforms.__all__
}
)

#############################################################################
###############################PARAMETERS####################################
#############################################################################
Expand Down Expand Up @@ -261,7 +330,7 @@ class is accepted to modify/create the downstream parameters.
},
{
"name": "INFECTIOUS_PERIOD",
"validate": test_not_negative,
# "validate": test_not_negative,
},
{
"name": "EXPOSED_TO_INFECTIOUS",
Expand All @@ -270,7 +339,7 @@ class is accepted to modify/create the downstream parameters.
{
"name": "STRAIN_SPECIFIC_R0",
"validate": test_non_empty,
"type": jnp.array,
"type": np.array,
},
{
"name": "NUM_WANING_COMPARTMENTS",
Expand All @@ -284,13 +353,14 @@ class is accepted to modify/create the downstream parameters.
lambda key, vals: test_zero(key, vals[-1]),
lambda key, vals: [test_type(key, val, int) for val in vals],
],
"downstream": set_num_waning_compartments,
},
{
"name": "WANING_PROTECTIONS",
"validate": lambda key, vals: [
test_not_negative(key, val) for val in vals
],
"type": jnp.array,
"type": np.array,
},
{
"name": ["NUM_WANING_COMPARTMENTS", "WANING_TIMES"],
Expand All @@ -303,7 +373,7 @@ class is accepted to modify/create the downstream parameters.
{
"name": "STRAIN_INTERACTIONS",
"validate": test_non_empty,
"type": jnp.array,
"type": np.array,
},
{
"name": ["NUM_STRAINS", "STRAIN_INTERACTIONS"],
Expand All @@ -326,14 +396,14 @@ class is accepted to modify/create the downstream parameters.
{
"name": "VAX_EFF_MATRIX",
"validate": test_non_empty,
"type": jnp.array,
"type": np.array,
},
{
"name": "BETA_TIMES",
"validate": lambda key, lst: [
test_not_negative(key, beta_time) for beta_time in lst
],
"type": jnp.array,
"type": np.array,
},
{
"name": "BETA_COEFICIENTS",
Expand All @@ -347,7 +417,7 @@ class is accepted to modify/create the downstream parameters.
"validate": lambda key, lst: [
test_not_negative(key, r0) for r0 in lst
],
"type": jnp.array,
"type": np.array,
},
{
"name": ["NUM_STRAINS", "MAX_VAX_COUNT", "VAX_EFF_MATRIX"],
Expand Down
1 change: 0 additions & 1 deletion config/config_global.json
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
"United States"
],
"INIT_DATE": "2022-02-11",
"MINIMUM_AGE": 0,
"AGE_LIMITS": [
0,
18,
Expand Down
116 changes: 116 additions & 0 deletions config/config_inferer_covid.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,116 @@
{
"SCENARIO_NAME": "test covid run for testing suite",
"CONTACT_MATRIX_PATH": "data/demographic-data/contact_matrices",
"SAVE_PATH": "output/",
"HOSP_PATH": "data/hospital_220213_220108.csv",
"VAX_MODEL_DATA": "data/spline_fits.csv",
"VAX_MODEL_NUM_KNOTS": 18,
"STRAIN_R0s": [
1.2,
1.8,
{
"distribution": "TransformedDistribution",
"params": {
"base_distribution": {
"distribution": "Beta",
"params": {
"concentration1": 8,
"concentration0": 2
}
},
"transforms": {
"distribution": "AffineTransform",
"params": {
"loc": 1,
"scale": 1
}
}
}
}
],
"INFECTIOUS_PERIOD": {
"distribution": "TruncatedNormal",
"params": {
"loc": 10,
"scale": 2,
"low": 1.0
}
},
"EXPOSED_TO_INFECTIOUS": 3.6,
"INITIAL_INFECTIONS_SCALE": 1.0,
"INTRODUCTION_TIMES": [
{
"distribution": "TruncatedNormal",
"params": {
"loc": 60,
"scale": 20,
"low": 10
}
}
],
"INTRODUCTION_PERCENTAGE": 0.01,
"INTRODUCTION_SCALE": 10,
"INTRODUCTION_AGE_MASK": [
false,
true,
false,
false
],
"WANING_PROTECTIONS": [
1.0,
0.942,
0.942,
0.942,
0.0
],
"STRAIN_INTERACTIONS": [
[
1.0,
0.7,
0.49
],
[
0.7,
1.0,
0.7
],
[
0.49,
0.7,
1.0
]
],
"VAX_EFF_MATRIX": [
[
0,
0.34,
0.68
],
[
0,
0.24,
0.48
],
[
0,
0.14,
0.28
]
],
"BETA_TIMES": [
0.0,
120.0,
150
],
"BETA_COEFICIENTS": [
1.0,
1.0,
1.0
],
"INFERENCE_PRNGKEY": 8675309,
"INFERENCE_NUM_WARMUP": 100,
"INFERENCE_NUM_SAMPLES": 500,
"INFERENCE_NUM_CHAINS": 4,
"INFERENCE_PROGRESS_BAR": true,
"MODEL_RAND_SEED": 8675309
}
Loading
Loading