Set up Snakefile; sketch forecast.py
swo committed Jan 23, 2025
1 parent 42ba70e commit 2aab75e
Showing 4 changed files with 103 additions and 92 deletions.
26 changes: 0 additions & 26 deletions Makefile

This file was deleted.

62 changes: 62 additions & 0 deletions Snakefile
@@ -0,0 +1,62 @@
import polars as pl
import yaml

# Common filenames and paths -------------------------------------------------
NIS_CACHE = ".cache/nisapi"
RAW_DATA = "data/nis_raw.parquet"
FORECASTS = "data/forecasts.parquet"
SCORES = "data/scores.parquet"


# Read in workflow information from config files ------------------------------
def token():
    with open("scripts/socrata_app_token.txt") as f:
        return f.read().strip()


with open("scripts/config.yaml") as f:
    CONFIG = yaml.safe_load(f)

FORECAST_DATES = pl.date_range(
    start=CONFIG["forecast_dates"]["start"],
    end=CONFIG["forecast_dates"]["end"],
    interval=CONFIG["forecast_dates"]["interval"],
    eager=True,  # return a Series; without eager=True this is a lazy expression and .to_list() fails
).to_list()


# Define rules ----------------------------------------------------------------
rule score:
    """Score all forecasts (models & forecast dates) at once, producing a single output"""
    input:
        expand(
            "data/forecasts/model={model}/forecast_date={forecast_date}/part-0.parquet",
            model=CONFIG["models"],
            forecast_date=FORECAST_DATES,
        ),
        forecasts="data/forecasts",
        raw_data=RAW_DATA,
        script="scripts/eval.py",
    output:
        SCORES,
    shell:
        "python {input.script} --forecasts={input.forecasts} --obs={input.raw_data} --output={output}"


rule forecast:
    """Generate a forecast for a single model and forecast date"""
    input:
        raw_data=RAW_DATA,
        script="scripts/forecast.py",
    output:
        "data/forecasts/model={model}/forecast_date={forecast_date}/part-0.parquet",
    shell:
        "python {input.script} --input={input.raw_data} --model={wildcards.model} --forecast_date={wildcards.forecast_date} --output={output}"


rule raw_data:
    """Preprocess input data"""
    input:
        cache=NIS_CACHE + "/clean",
        script="scripts/preprocess.py",
    output:
        RAW_DATA,
    shell:
        "python {input.script} --cache={input.cache} --output={output}"


rule cache:
    """Cache NIS data"""
    output:
        NIS_CACHE + "/status.txt",
    params:
        token=token(),  # call the helper here; a bare function would be invoked with wildcards
        cache=NIS_CACHE,
    shell:
        """
        python -c "import nisapi; nisapi.cache_all_datasets({params.cache:q}, {params.token:q})"
        find {params.cache}/clean -type f | xargs sha1sum > {output}
        """
3 changes: 1 addition & 2 deletions scripts/config_template.yaml
@@ -15,8 +15,7 @@ data:
   # use these columns as grouping factors
   groups: [geography]
 
-# Timeframe for the longest desired forecast
-forecast_timeframe:
+forecast_dates:
   start: 2024-02-03
   end: 2024-04-30
   interval: 7d
104 changes: 40 additions & 64 deletions scripts/forecast.py
@@ -1,88 +1,52 @@
 import argparse
+import datetime
+import warnings
+from typing import List
 
 import polars as pl
 import yaml
 
 import iup.models
 
 
-def run_all_forecasts(data, config) -> pl.DataFrame:
-    """Run all forecasts
-    Returns:
-        pl.DataFrame: data frame of forecasts, organized by model and forecast date
-    """
-
-    models = [getattr(iup.models, model_name) for model_name in config["models"]]
-    assert all(issubclass(model, iup.models.UptakeModel) for model in models)
-
-    if config["evaluation_timeframe"]["interval"] is not None:
-        forecast_dates = pl.date_range(
-            config["forecast_timeframe"]["start"],
-            config["forecast_timeframe"]["end"],
-            config["evaluation_timeframe"]["interval"],
-            eager=True,
-        ).to_list()
-    else:
-        forecast_dates = [config["forecast_timeframe"]["start"]]
-
-    all_forecast = pl.DataFrame()
-
-    for model in models:
-        for forecast_date in forecast_dates:
-            forecast = run_forecast(
-                model,
-                data,
-                grouping_factors=config["data"]["groups"],
-                forecast_start=forecast_date,
-                forecast_end=config["forecast_timeframe"]["end"],
-            )
-
-            forecast = forecast.with_columns(
-                forecast_start=forecast_date,
-                forecast_end=config["forecast_timeframe"]["end"],
-                model=pl.lit(model.__name__),
-            )
-
-            all_forecast = pl.concat([all_forecast, forecast])
-
-    return all_forecast
-
-
 def run_forecast(
-    model,
-    data,
+    dataset_path: str,
+    model_name: str,
+    forecast_date: datetime.date,
+    target_dates: List[datetime.date],
     grouping_factors,
-    forecast_start,
-    forecast_end,
-) -> pl.DataFrame:
+    output_path: str,
+) -> None:
     """Run a single model for a single forecast date"""
+    # check that target dates are after the forecast date
+    warnings.warn("not implemented")
 
-    # preprocess.py returns cumulative data, so convert to incident for LinearIncidentUptakeModel
-    incident_data = data.to_incident(grouping_factors)
+    # get model object from name
+    model = getattr(iup.models, model_name)
+    assert issubclass(model, iup.models.UptakeModel)
 
-    # Prune to only the training portion
-    incident_train_data = iup.IncidentUptakeData.split_train_test(
-        incident_data, forecast_start, "train"
+    # get data to use for forecast
+    data = pl.scan_parquet(dataset_path)
+    training_data = iup.IncidentUptakeData.split_train_test(
+        data, forecast_date, "train"
     )
 
-    # Fit models using the training data and make projections
-    fit_model = model().fit(incident_train_data, grouping_factors)
+    # check that target dates are not present in the training data
+    warnings.warn("not implemented")
 
-    cumulative_projections = fit_model.predict(
-        forecast_start,
-        forecast_end,
-        config["forecast_timeframe"]["interval"],
-        grouping_factors,
-    )
+    # fit model and run predictions
+    fit = model().fit(training_data)
+    pred = fit.predict(target_dates, grouping_factors)
 
-    return cumulative_projections
+    # write output
+    pred.write_parquet(output_path)
 
 
 if __name__ == "__main__":
     p = argparse.ArgumentParser()
     p.add_argument("--config", help="config file", default="scripts/config.yaml")
-    p.add_argument("--input", help="input data")
+    p.add_argument("--input", help="input dataset")
+    p.add_argument("--model", help="model to forecast with")
+    p.add_argument("--forecast_date", help="forecast date")
     p.add_argument("--output", help="output parquet file")
     args = p.parse_args()
@@ -91,4 +55,16 @@ def run_forecast(
 
-    input_data = iup.CumulativeUptakeData(pl.scan_parquet(args.input).collect())
-
-    run_all_forecasts(input_data, config).write_parquet(args.output)
+    target_dates = None
+    warnings.warn("need to figure out target dates")
+    grouping_factors = None
+    warnings.warn("need to figure out grouping factors")
+
+    run_forecast(
+        dataset_path=args.input,
+        model_name=args.model,
+        forecast_date=datetime.date.fromisoformat(args.forecast_date),
+        target_dates=target_dates,
+        grouping_factors=grouping_factors,
+        output_path=args.output,
+    )
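For orientation, a hypothetical direct call equivalent to what rule forecast runs through the CLI; this is a sketch, not part of the commit. The model name is an assumption (it appears only in a removed comment above), the paths are illustrative, and target_dates / grouping_factors are still placeholders in this commit:

import datetime

# Shell form, as rule forecast would run it:
#   python scripts/forecast.py --input=data/nis_raw.parquet \
#       --model=LinearIncidentUptakeModel --forecast_date=2024-02-03 \
#       --output=data/forecasts/model=LinearIncidentUptakeModel/forecast_date=2024-02-03/part-0.parquet
run_forecast(
    dataset_path="data/nis_raw.parquet",
    model_name="LinearIncidentUptakeModel",  # assumed; any iup.models.UptakeModel subclass named in the config
    forecast_date=datetime.date(2024, 2, 3),
    target_dates=None,  # placeholder, per the warning in __main__
    grouping_factors=None,  # placeholder, per the warning in __main__
    output_path="data/forecasts/model=LinearIncidentUptakeModel/forecast_date=2024-02-03/part-0.parquet",
)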
