From 2aab75e613f76d82c98ff8014af9b070d64c02ef Mon Sep 17 00:00:00 2001
From: Scott Olesen
Date: Thu, 23 Jan 2025 18:28:20 -0500
Subject: [PATCH] Set up Snakefile; sketch forecast.py

---
 Makefile                     |  26 ---------
 Snakefile                    |  62 +++++++++++++++++++++
 scripts/config_template.yaml |   3 +-
 scripts/forecast.py          | 104 ++++++++++++++---------------------
 4 files changed, 103 insertions(+), 92 deletions(-)
 delete mode 100644 Makefile
 create mode 100644 Snakefile

diff --git a/Makefile b/Makefile
deleted file mode 100644
index 695ada6..0000000
--- a/Makefile
+++ /dev/null
@@ -1,26 +0,0 @@
-NIS_CACHE = .cache/nisapi
-TOKEN_PATH = scripts/socrata_app_token.txt
-TOKEN = $(shell cat $(TOKEN_PATH))
-CONFIG = scripts/config.yaml
-RAW_DATA = data/nis_raw.parquet
-FORECASTS = data/forecasts.parquet
-SCORES = data/scores.parquet
-
-.PHONY: cache
-
-all: $(SCORES)
-
-$(SCORES): scripts/eval.py $(FORECASTS)
-	python $< --pred=$(FORECASTS) --obs=$(RAW_DATA) --output=$@
-
-$(FORECASTS): scripts/forecast.py $(RAW_DATA)
-	python $< --input=$(RAW_DATA) --output=$@
-
-$(RAW_DATA): scripts/preprocess.py cache
-	python $< --cache=$(NIS_CACHE)/clean --output=$@
-
-cache: $(NIS_CACHE)/status.txt
-
-$(NIS_CACHE)/status.txt $(TOKEN_PATH):
-	python -c "import nisapi; nisapi.cache_all_datasets('$(NIS_CACHE)', '$(TOKEN)')"
-	find $(NIS_CACHE)/clean -type f | xargs sha1sum > $@
diff --git a/Snakefile b/Snakefile
new file mode 100644
index 0000000..521251b
--- /dev/null
+++ b/Snakefile
@@ -0,0 +1,62 @@
+# Common filenames and paths --------------------------------------------------
+NIS_CACHE = ".cache/nisapi"
+RAW_DATA = "data/nis_raw.parquet"
+FORECASTS = "data/forecasts.parquet"
+SCORES = "data/scores.parquet"
+
+import yaml
+import polars as pl
+
+# Read in workflow information from config files ------------------------------
+def token():
+    with open("scripts/socrata_app_token.txt") as f:
+        return f.read().strip()
+
+with open("scripts/config.yaml") as f:
+    CONFIG = yaml.safe_load(f)
+
+FORECAST_DATES = pl.date_range(start=CONFIG['forecast_dates']['start'], end=CONFIG['forecast_dates']['end'], interval=CONFIG['forecast_dates']['interval'], eager=True).to_list()
+
+# Define rules -----------------------------------------------------------------
+rule score:
+    """Score all forecasts (models & forecast dates) at once, producing a single output"""
+    input:
+        expand("data/forecasts/model={model}/forecast_date={forecast_date}/part-0.parquet", model=CONFIG["models"], forecast_date=FORECAST_DATES),
+        forecasts="data/forecasts",
+        raw_data=RAW_DATA,
+        script="scripts/eval.py"
+    output:
+        SCORES
+    shell:
+        "python {input.script} --forecasts={input.forecasts} --obs={input.raw_data} --output={output}"
+
+rule forecast:
+    """Generate forecast for a single model and date"""
+    input:
+        raw_data=RAW_DATA,
+        script="scripts/forecast.py"
+    output:
+        "data/forecasts/model={model}/forecast_date={forecast_date}/part-0.parquet"
+    shell:
+        "python {input.script} --input={input.raw_data} --model={wildcards.model} --forecast_date={wildcards.forecast_date} --output={output}"
+
+rule raw_data:
+    """Preprocess input data"""
+    input:
+        cache=".cache/nisapi/clean",
+        script="scripts/preprocess.py"
+    output:
+        RAW_DATA
+    shell:
+        "python {input.script} --cache={input.cache} --output={output}"
+
+rule cache:
+    """Cache NIS data"""
+    output:
+        ".cache/nisapi/status.txt"
+    params:
+        token=token(),
+        cache=".cache/nisapi"
+    shell:
+        "python -c 'import nisapi; nisapi.cache_all_datasets({params.cache:q}, {params.token:q})' && "
+        "find {params.cache}/clean -type f | xargs sha1sum > {output}"
diff --git a/scripts/config_template.yaml b/scripts/config_template.yaml
index 27e8561..97605dd 100644
--- a/scripts/config_template.yaml
+++ b/scripts/config_template.yaml
@@ -15,8 +15,7 @@ data:
   # use these columns as grouping factors
   groups: [geography]
 
-# Timeframe for the longest desired forecast
-forecast_timeframe:
+forecast_dates:
   start: 2024-02-03
   end: 2024-04-30
   interval: 7d
diff --git a/scripts/forecast.py b/scripts/forecast.py
index 75b7fa8..4f42519 100644
--- a/scripts/forecast.py
+++ b/scripts/forecast.py
@@ -1,4 +1,7 @@
 import argparse
+import datetime
+import warnings
+from typing import List
 
 import polars as pl
 import yaml
@@ -6,83 +9,44 @@
 import iup.models
 
 
-def run_all_forecasts(data, config) -> pl.DataFrame:
-    """Run all forecasts
-
-    Returns:
-        pl.DataFrame: data frame of forecasts, organized by model and forecast date
-    """
-
-    models = [getattr(iup.models, model_name) for model_name in config["models"]]
-    assert all(issubclass(model, iup.models.UptakeModel) for model in models)
-
-    if config["evaluation_timeframe"]["interval"] is not None:
-        forecast_dates = pl.date_range(
-            config["forecast_timeframe"]["start"],
-            config["forecast_timeframe"]["end"],
-            config["evaluation_timeframe"]["interval"],
-            eager=True,
-        ).to_list()
-    else:
-        forecast_dates = [config["forecast_timeframe"]["start"]]
-
-    all_forecast = pl.DataFrame()
-
-    for model in models:
-        for forecast_date in forecast_dates:
-            forecast = run_forecast(
-                model,
-                data,
-                grouping_factors=config["data"]["groups"],
-                forecast_start=forecast_date,
-                forecast_end=config["forecast_timeframe"]["end"],
-            )
-
-            forecast = forecast.with_columns(
-                forecast_start=forecast_date,
-                forecast_end=config["forecast_timeframe"]["end"],
-                model=pl.lit(model.__name__),
-            )
-
-            all_forecast = pl.concat([all_forecast, forecast])
-
-    return all_forecast
-
-
 def run_forecast(
-    model,
-    data,
+    dataset_path: str,
+    model_name: str,
+    forecast_date: datetime.date,
+    target_dates: List[datetime.date],
     grouping_factors,
-    forecast_start,
-    forecast_end,
+    output_path: str,
 ) -> pl.DataFrame:
     """Run a single model for a single forecast date"""
+    # check that target dates are after the forecast date
+    warnings.warn("not implemented")
 
-    # preprocess.py returns cumulative data, so convert to incident for LinearIncidentUptakeModel
-    incident_data = data.to_incident(grouping_factors)
+    # get model object from name
+    model = getattr(iup.models, model_name)
+    assert issubclass(model, iup.models.UptakeModel)
 
-    # Prune to only the training portion
-    incident_train_data = iup.IncidentUptakeData.split_train_test(
-        incident_data, forecast_start, "train"
+    # get data to use for forecast
+    data = pl.scan_parquet(dataset_path)
+    training_data = iup.IncidentUptakeData.split_train_test(
+        data, forecast_date, "train"
     )
 
-    # Fit models using the training data and make projections
-    fit_model = model().fit(incident_train_data, grouping_factors)
+    # check that target dates are not present in the training data
+    warnings.warn("not implemented")
 
-    cumulative_projections = fit_model.predict(
-        forecast_start,
-        forecast_end,
-        config["forecast_timeframe"]["interval"],
-        grouping_factors,
-    )
+    # fit model and run predictions
+    fit = model().fit(training_data)
+    pred = fit.predict(target_dates, grouping_factors)
 
-    return cumulative_projections
+    # write output
+    pred.write_parquet(output_path)
 
 
 if __name__ == "__main__":
     p = argparse.ArgumentParser()
-    p.add_argument("--config", help="config file", default="scripts/config.yaml")
-    p.add_argument("--input", help="input data")
+    p.add_argument("--input", help="input dataset")
+    p.add_argument("--model", help="model to forecast with")
+    p.add_argument("--forecast_date", help="forecast date")
     p.add_argument("--output", help="output parquet file")
     args = p.parse_args()
@@ -91,4 +55,16 @@ def run_forecast(
 
     input_data = iup.CumulativeUptakeData(pl.scan_parquet(args.input).collect())
 
-    run_all_forecasts(input_data, config).write_parquet(args.output)
+    target_dates = None
+    warnings.warn("need to figure out target dates")
+    grouping_factors = None
+    warnings.warn("need to figure out grouping factors")
+
+    run_forecast(
+        dataset_path=args.input,
+        model_name=args.model,
+        forecast_date=datetime.date.fromisoformat(args.forecast_date),
+        target_dates=target_dates,
+        grouping_factors=grouping_factors,
+        output_path=args.output,
+    )
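
Usage sketch: with this Snakefile, Snakemake builds whatever file is requested by
matching rule outputs, so the pipeline is driven by naming targets rather than by
`make all`. Assuming the NIS cache and scripts/config.yaml are already in place,
and taking LinearIncidentUptakeModel and 2024-02-03 purely as illustrative values
for a configured model and forecast date, invocation could look like:

    # build everything up to the combined scores file
    snakemake --cores 4 data/scores.parquet

    # build one model / forecast-date combination via the forecast rule's wildcards
    snakemake --cores 1 "data/forecasts/model=LinearIncidentUptakeModel/forecast_date=2024-02-03/part-0.parquet"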