Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adding time-varying hierarchical model #109

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions config/mlr-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ model:
generation_time: 4.8
pivot: "24A"
hierarchical: true
time_varying: true

inference:
method: "NUTS"
Expand Down
43 changes: 32 additions & 11 deletions scripts/run-mlr-model.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,15 @@
# coding: utf-8

import argparse
import json
import os
from datetime import date

import evofr as ef
import numpy as np
import pandas as pd
import os
import yaml
import json
import evofr as ef
from datetime import date


def parse_with_default(cf, var, dflt):
if var in cf:
Expand Down Expand Up @@ -86,7 +88,7 @@ def load_data(self, override_seq_path=None):

return raw_seq, locations

def load_model(self, override_hier=None):
def load_model(self, override_hier=None, override_time_varying=None):
model_cf = self.config["model"]

# Processing generation time
Expand All @@ -96,17 +98,25 @@ def load_model(self, override_hier=None):
if override_hier is not None:
hier = override_hier

time_varying = parse_with_default(model_cf, "time_varying", dflt=False)
if override_time_varying is not None:
time_varying = override_time_varying

print("hierarchical:", hier)
print("time varying:", time_varying)

# Processing likelihoods
if hier:
ps = parse_pool_scale(model_cf)
print("Hierarchical pool scale:", ps)
model = ef.HierMLR(tau=tau, pool_scale=ps)
if time_varying:
model = ef.HierMLRTimeVarying(tau=tau) # TODO: Add options
else:
model = ef.HierMLR(tau=tau, pool_scale=ps)
else:
model = ef.MultinomialLogisticRegression(tau=tau)
model.forecast_L = forecast_L
return model, hier
return model, hier, time_varying

def load_optim(self):
infer_cf = self.config["inference"]
Expand Down Expand Up @@ -276,9 +286,9 @@ def make_raw_freq_tidy(data, location):
return {"metadata": metadata, "data": entries}


def export_results(multi_posterior, ps, path, data_name, hier):
def export_results(multi_posterior, ps, path, data_name, hier, time_varying):
EXPORT_SITES = ["freq", "ga", "freq_forecast"]
EXPORT_DATED = [True, False, True]
EXPORT_DATED = [True, time_varying, True]
EXPORT_FORECASTS = [False, False, True]
EXPORT_ATTRS = ["pivot"]

Expand Down Expand Up @@ -383,6 +393,13 @@ def get_group_samples(samples, sites, group):
help="Whether to run the model as hierarchical. Overrides model.hierarchical in config. "
+ "Default is false if unspecified."
)

parser.add_argument(
"--time-varying", action='store_true', default=False,
help="Whether to run the model as time-varaying. Overrides model.time_varying in config."
+ "Default is false if unspecified."
)

args = parser.parse_args()

# Load configuration, data, and create model
Expand All @@ -396,7 +413,11 @@ def get_group_samples(samples, sites, group):
if args.hier:
override_hier = args.hier

mlr_model, hier = config.load_model(override_hier=override_hier)
override_time_varying = None
if args.time_varying:
override_time_varying = args.time_varying

mlr_model, hier, time_varying = config.load_model(override_hier=override_hier, override_time_varying=override_time_varying)
print("Model created.")

inference_method = config.load_optim()
Expand Down Expand Up @@ -452,4 +473,4 @@ def get_group_samples(samples, sites, group):
config.config["settings"], "ps", dflt=[0.5, 0.8, 0.95]
)
data_name = args.data_name or config.config["data"]["name"]
export_results(multi_posterior, ps, export_path, data_name, hier)
export_results(multi_posterior, ps, export_path, data_name, hier, time_varying)
Loading