diff --git a/config/mlr-config.yaml b/config/mlr-config.yaml index 3ded58b..666e2e0 100644 --- a/config/mlr-config.yaml +++ b/config/mlr-config.yaml @@ -15,6 +15,7 @@ model: generation_time: 4.8 pivot: "24A" hierarchical: true + time_varying: true inference: method: "NUTS" diff --git a/scripts/run-mlr-model.py b/scripts/run-mlr-model.py index 6facdfd..82ec9b0 100644 --- a/scripts/run-mlr-model.py +++ b/scripts/run-mlr-model.py @@ -2,13 +2,15 @@ # coding: utf-8 import argparse +import json +import os +from datetime import date + +import evofr as ef import numpy as np import pandas as pd -import os import yaml -import json -import evofr as ef -from datetime import date + def parse_with_default(cf, var, dflt): if var in cf: @@ -86,7 +88,7 @@ def load_data(self, override_seq_path=None): return raw_seq, locations - def load_model(self, override_hier=None): + def load_model(self, override_hier=None, override_time_varying=None): model_cf = self.config["model"] # Processing generation time @@ -96,17 +98,25 @@ def load_model(self, override_hier=None): if override_hier is not None: hier = override_hier + time_varying = parse_with_default(model_cf, "time_varying", dflt=False) + if override_time_varying is not None: + time_varying = override_time_varying + print("hierarchical:", hier) + print("time varying:", time_varying) # Processing likelihoods if hier: ps = parse_pool_scale(model_cf) print("Hierarchical pool scale:", ps) - model = ef.HierMLR(tau=tau, pool_scale=ps) + if time_varying: + model = ef.HierMLRTimeVarying(tau=tau) # TODO: Add options + else: + model = ef.HierMLR(tau=tau, pool_scale=ps) else: model = ef.MultinomialLogisticRegression(tau=tau) model.forecast_L = forecast_L - return model, hier + return model, hier, time_varying def load_optim(self): infer_cf = self.config["inference"] @@ -276,9 +286,9 @@ def make_raw_freq_tidy(data, location): return {"metadata": metadata, "data": entries} -def export_results(multi_posterior, ps, path, data_name, hier): +def export_results(multi_posterior, ps, path, data_name, hier, time_varying): EXPORT_SITES = ["freq", "ga", "freq_forecast"] - EXPORT_DATED = [True, False, True] + EXPORT_DATED = [True, time_varying, True] EXPORT_FORECASTS = [False, False, True] EXPORT_ATTRS = ["pivot"] @@ -383,6 +393,13 @@ def get_group_samples(samples, sites, group): help="Whether to run the model as hierarchical. Overrides model.hierarchical in config. " + "Default is false if unspecified." ) + + parser.add_argument( + "--time-varying", action='store_true', default=False, + help="Whether to run the model as time-varaying. Overrides model.time_varying in config." + + "Default is false if unspecified." + ) + args = parser.parse_args() # Load configuration, data, and create model @@ -396,7 +413,11 @@ def get_group_samples(samples, sites, group): if args.hier: override_hier = args.hier - mlr_model, hier = config.load_model(override_hier=override_hier) + override_time_varying = None + if args.time_varying: + override_time_varying = args.time_varying + + mlr_model, hier, time_varying = config.load_model(override_hier=override_hier, override_time_varying=override_time_varying) print("Model created.") inference_method = config.load_optim() @@ -452,4 +473,4 @@ def get_group_samples(samples, sites, group): config.config["settings"], "ps", dflt=[0.5, 0.8, 0.95] ) data_name = args.data_name or config.config["data"]["name"] - export_results(multi_posterior, ps, export_path, data_name, hier) + export_results(multi_posterior, ps, export_path, data_name, hier, time_varying)