-
Notifications
You must be signed in to change notification settings - Fork 3
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Pipeline setup step 2: Fill forecast.py #93
Changes from 1 commit
9b7be26
862faf8
857c080
da7fde2
6e07aef
c2cbdd0
a403811
aee047c
f3f56d6
f6a3ff4
99b8a66
de1b426
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -3,7 +3,10 @@ | |
import polars as pl | ||
import yaml | ||
|
||
def run_all_forecasts() -> pl.DataFrame: | ||
import iup.models | ||
|
||
|
||
def run_all_forecasts(clean_data, config) -> pl.DataFrame: | ||
"""Run all forecasts | ||
|
||
Returns: | ||
|
@@ -18,13 +21,39 @@ def run_all_forecasts() -> pl.DataFrame: | |
models = [getattr(iup.models, model_name) for model_name in config["models"]] | ||
assert all(issubclass(model, iup.models.UptakeModel) for model in models) | ||
|
||
for model in models: | ||
for forecast_date in forecast_dates: | ||
# Get data available as of the forecast date | ||
|
||
all_forecast = pl.DataFrame() | ||
|
||
def run_forecast() -> pl.DataFrame: | ||
for model in models: | ||
for forecast_date in forecast_dates: | ||
# Get data available as of the forecast date | ||
forecast = run_forecast( | ||
model, | ||
clean_data, | ||
grouping_factors=config["groups"], | ||
forecast_start=config["timeframe"]["start"], | ||
forecast_end=forecast_date, | ||
) | ||
|
||
forecast = forecast.with_columns( | ||
forecast_start=config["timeframe"]["start"], | ||
forecast_end=forecast_date, | ||
model=pl.lit(model.__name__), | ||
) | ||
|
||
all_forecast = pl.concat([all_forecast, forecast]) | ||
|
||
return all_forecast | ||
|
||
|
||
def run_forecast( | ||
model, | ||
incident_data, | ||
grouping_factors, | ||
forecast_start, | ||
forecast_end, | ||
) -> pl.DataFrame: | ||
"""Run a single model for a single forecast date""" | ||
|
||
incident_train_data = iup.IncidentUptakeData( | ||
iup.IncidentUptakeData.split_train_test( | ||
incident_data, config["timeframe"]["start"], "train" | ||
|
@@ -35,36 +64,33 @@ def run_forecast() -> pl.DataFrame: | |
fit_model = model().fit(incident_train_data, grouping_factors) | ||
|
||
cumulative_projections = fit_model.predict( | ||
config["timeframe"]["start"], | ||
config["timeframe"]["end"], | ||
forecast_start, | ||
forecast_end, | ||
config["timeframe"]["interval"], | ||
grouping_factors, | ||
) | ||
# save these projections somewhere | ||
|
||
incident_projections = cumulative_projections.to_incident( | ||
grouping_factors | ||
) | ||
# save these projections somewhere | ||
|
||
# Evaluation / Post-processing -------------------------------------------- | ||
incident_projections = cumulative_projections.to_incident(grouping_factors) | ||
|
||
incident_test_data = iup.IncidentUptakeData( | ||
iup.IncidentUptakeData.split_train_test( | ||
incident_data, config["timeframe"]["start"], "test" | ||
) | ||
).filter(pl.col("date") <= config["timeframe"]["end"]) | ||
return pl.concat( | ||
[ | ||
cumulative_projections.with_columns(estimate_type=pl.lit("cumulative")), | ||
incident_projections.with_columns(estimate_type=pl.lit("incident")), | ||
] | ||
) | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Same question I had on #92 - why combine cumulative and incident uptake into a single data frame? We could just return the cumulative data and derive the incident data in a future step if needed. There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Fixed in 862faf8 |
||
|
||
|
||
if __name__ == "__main__": | ||
p = argparse.ArgumentParser() | ||
p.add_argument("--config", help="config file", default="scripts/config.yaml") | ||
p.add_argument("--input", help="input data") | ||
p.add_argument("--output", help="output parquet file") | ||
args = p.parse_args() | ||
|
||
with open(args.config, "r") as f: | ||
config = yaml.safe_load(f) | ||
|
||
input_data = pl.scan_parquet(args.input) | ||
input_data = pl.scan_parquet(args.input).collect() | ||
|
||
run_all_forecasts(config=config, cache=args.cache) | ||
all_forecast = run_all_forecasts(config=config, clean_data=input_data) | ||
all_forecast.write_parquet(args.output) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Shouldn't it be
forecast_start
that is being incremented? As written here, with
forecast_end
being incremented, we are asking questions like: These questions do not require fitting the model multiple times in a loop! Just fit it once for the latest
forecast_end
date (e.g. May 29, 2025) and look up the projections "along the way" for all the intermediate dates (May 22, May 15, May 8, May 1, etc.).But instead, I think we had intended to increment
forecast_start
to ask questions like: In this case, because the starting date for forecasting keeps getting pushed later, the training data set keeps expanding, so the model does need to be refit multiple times in a loop.
So the latter is what we should be doing here.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Fixed in 857c080