diff --git a/iup/__init__.py b/iup/__init__.py
index e586372..1bb758e 100644
--- a/iup/__init__.py
+++ b/iup/__init__.py
@@ -162,9 +162,9 @@ def validate(self):
         # same validations as UptakeData
         super().validate()
         # and also require that uptake be a proportion
-        assert (
-            self["estimate"].is_between(0.0, 1.0).all()
-        ), "cumulative uptake `estimate` must be a proportion"
+        assert self["estimate"].is_between(0.0, 1.0).all(), (
+            "cumulative uptake `estimate` must be a proportion"
+        )
 
     def to_incident(self, group_cols: List[str,] | None) -> IncidentUptakeData:
         """
@@ -252,9 +252,9 @@ def validate(self):
         )
 
         # all quantiles should be between 0 and 1
-        assert (
-            self["quantile"].is_between(0.0, 1.0).all()
-        ), "quantiles must be between 0 and 1"
+        assert self["quantile"].is_between(0.0, 1.0).all(), (
+            "quantiles must be between 0 and 1"
+        )
 
 
 class PointForecast(QuantileForecast):
diff --git a/iup/eval.py b/iup/eval.py
index 2170d6d..ee4c27b 100644
--- a/iup/eval.py
+++ b/iup/eval.py
@@ -30,9 +30,9 @@ def check_date_match(data: IncidentUptakeData, pred: PointForecast):
     (data["time_end"] == pred["time_end"]).all()
 
     # 2. There should not be any duplicated date in either data or prediction.
-    assert not (
-        any(data["time_end"].is_duplicated())
-    ), "Duplicated dates are found in data and prediction."
+    assert not (any(data["time_end"].is_duplicated())), (
+        "Duplicated dates are found in data and prediction."
+    )
 
 
 def score(
diff --git a/scripts/forecast.py b/scripts/forecast.py
index f722035..488407e 100644
--- a/scripts/forecast.py
+++ b/scripts/forecast.py
@@ -3,17 +3,15 @@
 import polars as pl
 import yaml
 
-import iup
+import iup.models
 
 
-def run_all_forecasts() -> pl.DataFrame:
+def run_all_forecasts(clean_data, config) -> pl.DataFrame:
     """Run all forecasts
 
     Returns:
         pl.DataFrame: data frame of forecasts, organized by model and forecast date
     """
-    raise NotImplementedError
-
     models = None
     forecast_dates = pl.date_range(
         config["timeframe"]["start"],
@@ -24,54 +22,74 @@
     models = [getattr(iup.models, model_name) for model_name in config["models"]]
     assert all(issubclass(model, iup.models.UptakeModel) for model in models)
 
+    all_forecast = pl.DataFrame()
+
     for model in models:
         for forecast_date in forecast_dates:
             # Get data available as of the forecast date
-            pass
-
-
-def run_forecast() -> pl.DataFrame:
+            forecast = run_forecast(
+                model,
+                clean_data,
+                grouping_factors=config["groups"],
+                forecast_start=forecast_date,
+                forecast_end=config["timeframe"]["end"],
+            )
+
+            forecast = forecast.with_columns(
+                forecast_start=forecast_date,
+                forecast_end=config["timeframe"]["end"],
+                model=pl.lit(model.__name__),
+            )
+
+            all_forecast = pl.concat([all_forecast, forecast])
+
+    return all_forecast
+
+
+def run_forecast(
+    model,
+    observed_data,
+    grouping_factors,
+    forecast_start,
+    forecast_end,
+) -> pl.DataFrame:
     """Run a single model for a single forecast date"""
 
-    raise NotImplementedError
-    # incident_train_data = iup.IncidentUptakeData(
-    #     iup.IncidentUptakeData.split_train_test(
-    #         incident_data, config["timeframe"]["start"], "train"
-    #     )
-    # )
-
-    # Fit models using the training data and make projections
-    # fit_model = model().fit(incident_train_data, grouping_factors)
+    # preprocess.py returns cumulative data, need to convert to incidence for LinearIncidentUptakeModel #
+    incident_data = iup.CumulativeUptakeData(observed_data).to_incident(
+        grouping_factors
+    )
 
-    # cumulative_projections = fit_model.predict(
-    #     config["timeframe"]["start"],
-    #     config["timeframe"]["end"],
-    #     config["timeframe"]["interval"],
-    #     grouping_factors,
-    # )
-    # save these projections somewhere
+    incident_train_data = iup.IncidentUptakeData(
+        iup.IncidentUptakeData.split_train_test(incident_data, forecast_start, "train")
+    )
 
-    # incident_projections = cumulative_projections.to_incident(grouping_factors)
-    # save these projections somewhere
+    # Fit models using the training data and make projections
+    fit_model = model().fit(incident_train_data, grouping_factors)
 
-    # Evaluation / Post-processing --------------------------------------------
+    cumulative_projections = fit_model.predict(
+        forecast_start,
+        forecast_end,
+        config["timeframe"]["interval"],
+        grouping_factors,
+    )
 
-    # incident_test_data = iup.IncidentUptakeData(
-    #     iup.IncidentUptakeData.split_train_test(
-    #         incident_data, config["timeframe"]["start"], "test"
-    #     )
-    # ).filter(pl.col("date") <= config["timeframe"]["end"])
+    return cumulative_projections
 
 
 if __name__ == "__main__":
     p = argparse.ArgumentParser()
     p.add_argument("--config", help="config file", default="scripts/config.yaml")
     p.add_argument("--input", help="input data")
+    p.add_argument("--output", help="output parquet file")
     args = p.parse_args()
 
     with open(args.config, "r") as f:
         config = yaml.safe_load(f)
 
-    input_data = pl.scan_parquet(args.input)
+    input_data = pl.scan_parquet(args.input).collect()
+
+    input_data = iup.CumulativeUptakeData(input_data)
 
-    run_all_forecasts(config=config, cache=args.cache)
+    all_forecast = run_all_forecasts(config=config, clean_data=input_data)
+    all_forecast.write_parquet(args.output)