Merge branch 'main' into fy_config_eval
Fuhan-Yang authored Jan 15, 2025
2 parents f9f1526 + c543676 commit 51b9f6d
Showing 7 changed files with 105 additions and 96 deletions.
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
@@ -14,7 +14,7 @@ repos:
#####
# Python
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.8.2
rev: v0.9.1
hooks:
- id: ruff
args: ["check", "--select", "I", "--fix"]
20 changes: 10 additions & 10 deletions iup/__init__.py
@@ -162,9 +162,9 @@ def validate(self):
# same validations as UptakeData
super().validate()
# and also require that uptake be a proportion
assert (
self["estimate"].is_between(0.0, 1.0).all()
), "cumulative uptake `estimate` must be a proportion"
assert self["estimate"].is_between(0.0, 1.0).all(), (
"cumulative uptake `estimate` must be a proportion"
)

def to_incident(self, group_cols: List[str,] | None) -> IncidentUptakeData:
"""
@@ -187,8 +187,8 @@ def to_incident(self, group_cols: List[str,] | None) -> IncidentUptakeData:

return IncidentUptakeData(out)

def insert_rollout(
self, rollout: List[dt.date], group_cols: List[str] | None
def insert_rollouts(
self, rollouts: List[dt.date], group_cols: List[str] | None
) -> pl.DataFrame:
"""
Insert into cumulative uptake data rows with 0 uptake on rollout dates.
@@ -211,12 +211,12 @@ def insert_rollout(
rollout_rows = (
frame.select(group_cols)
.unique()
.join(pl.DataFrame({"time_end": rollout}), how="cross")
.join(pl.DataFrame({"time_end": rollouts}), how="cross")
.with_columns(estimate=0.0)
)
group_cols = group_cols + ["season"]
else:
rollout_rows = pl.DataFrame({"time_end": rollout, "estimate": 0.0})
rollout_rows = pl.DataFrame({"time_end": rollouts, "estimate": 0.0})
group_cols = ["season"]

frame = frame.vstack(rollout_rows.select(frame.columns)).sort("time_end")
@@ -252,9 +252,9 @@ def validate(self):
)

# all quantiles should be between 0 and 1
assert (
self["quantile"].is_between(0.0, 1.0).all()
), "quantiles must be between 0 and 1"
assert self["quantile"].is_between(0.0, 1.0).all(), (
"quantiles must be between 0 and 1"
)


class PointForecast(QuantileForecast):
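For orientation, a minimal usage sketch of the renamed insert_rollouts method (usage mirrors the updated tests further down); the frame contents are placeholders, and it is assumed these columns satisfy CumulativeUptakeData validation:

import datetime as dt

import polars as pl

import iup

frame = iup.CumulativeUptakeData(
    pl.DataFrame(
        {
            "time_end": [dt.date(2023, 10, 1), dt.date(2024, 10, 1)],
            "geography": ["nation", "nation"],
            "estimate": [0.1, 0.2],
        }
    )
)

# one zero-uptake row per rollout date and per group is added, then the frame is re-sorted by time_end
rollouts = [dt.date(2023, 9, 1), dt.date(2024, 9, 1)]
with_rollouts = frame.insert_rollouts(rollouts, group_cols=["geography"])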
29 changes: 14 additions & 15 deletions scripts/config_template.yaml
@@ -1,20 +1,19 @@
# The data sets to load and how to interpret them
data:
data_set_1:
rollout: [2023-09-01, 2024-09-01]
filters:
geography_type: nation
domain_type: age
domain: 18+ years
indicator_type: 4-level vaccination and intent
indicator: received a vaccination
time_type: week

# Columns to keep from each data set
keep: [geography, estimate, time_end]

# Grouping factors across data sets
groups: [geography]
rollouts: [2023-09-01, 2024-09-01]
# filter data; eg `vaccine: covid` means filter column "vaccine" for value "covid"
filters:
vaccine: covid
geography_type: nation
domain_type: age
domain: 18+ years
indicator_type: 4-level vaccination and intent
indicator: received a vaccination
time_type: week
# keep only these data columns
keep: [geography, estimate, time_end]
# use these columns as grouping factors
groups: [geography]

# The timeframe over which to generate projections
timeframe:
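A minimal sketch of how this flattened data block appears to be consumed (see the preprocess.py change below): each key under filters becomes a column-equals-value constraint. The frame contents here are placeholders:

import polars as pl

raw = pl.LazyFrame(
    {
        "vaccine": ["covid", "flu"],
        "geography": ["nation", "nation"],
        "estimate": [0.5, 0.4],
        "time_end": ["2023-10-01", "2023-10-01"],
    }
)

filters = {"vaccine": "covid"}                 # config["data"]["filters"]
keep = ["geography", "estimate", "time_end"]   # config["data"]["keep"]

subset = raw.filter(**filters).select(keep).sort("time_end").collect()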
7 changes: 5 additions & 2 deletions scripts/eval.py
@@ -1,3 +1,4 @@

import argparse

import polars as pl
@@ -24,13 +25,14 @@ def eval_all_forecasts(data, pred, config):
pl.col("model") == model, pl.col("forecast_start") == forecast_start
)

# only 'incident' type is evaluated #
# convert cumulative predictions to incident predictions given certain forecast period and model #
incident_pred = iup.CumulativeUptakeData(this_pred).to_incident(
config["data"]["groups"]
)
# This step is arbitrary, but it is necessary to pass PointForecast validation #
incident_pred = incident_pred.with_columns(quantile=0.5)
incident_pred = iup.PointForecast(incident_pred)
# This step is arbitrary, but it is necessary to pass PointForecast validation #


test = data.filter(
pl.col("time_end") >= forecast_start,
@@ -68,3 +70,4 @@ def eval_all_forecasts(data, pred, config):

all_scores = eval_all_forecasts(obs_data, pred_data, config)
all_scores.write_parquet(args.output)

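A rough sketch of the cumulative-to-incident conversion and PointForecast wrapping used above; the prediction frame is a placeholder, and it is assumed these columns satisfy the CumulativeUptakeData and PointForecast validations:

import datetime as dt

import polars as pl

import iup

this_pred = pl.DataFrame(
    {
        "time_end": [dt.date(2024, 10, 1), dt.date(2024, 11, 1)],
        "geography": ["nation", "nation"],
        "estimate": [0.10, 0.25],
    }
)

groups = ["geography"]  # config["data"]["groups"]

# cumulative -> incident, then tag a nominal 0.5 quantile so PointForecast validation passes
incident_pred = iup.CumulativeUptakeData(this_pred).to_incident(groups)
incident_pred = iup.PointForecast(incident_pred.with_columns(quantile=0.5))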
84 changes: 52 additions & 32 deletions scripts/forecast.py
@@ -3,75 +3,95 @@
import polars as pl
import yaml

import iup
import iup.models


def run_all_forecasts() -> pl.DataFrame:
def run_all_forecasts(clean_data, config) -> pl.DataFrame:

"""Run all forecasts
Returns:
pl.DataFrame: data frame of forecasts, organized by model and forecast date
"""
raise NotImplementedError
models = None

forecast_dates = pl.date_range(
config["timeframe"]["start"],
config["timeframe"]["end"],
config["timeframe"]["interval"],
eager=True,
)

models = [getattr(iup.models, model_name) for model_name in config["models"]]
assert all(issubclass(model, iup.models.UptakeModel) for model in models)

all_forecast = pl.DataFrame()

for model in models:
for forecast_date in forecast_dates:
# Get data available as of the forecast date
pass

forecast = run_forecast(
model,
clean_data,
grouping_factors=config["groups"],
forecast_start=forecast_date,
forecast_end=config["timeframe"]["end"],
)

def run_forecast() -> pl.DataFrame:
"""Run a single model for a single forecast date"""
raise NotImplementedError
forecast = forecast.with_columns(
forecast_start=forecast_date,
forecast_end=config["timeframe"]["end"],
model=pl.lit(model.__name__),
)

# incident_train_data = iup.IncidentUptakeData(
# iup.IncidentUptakeData.split_train_test(
# incident_data, config["timeframe"]["start"], "train"
# )
# )
all_forecast = pl.concat([all_forecast, forecast])

# Fit models using the training data and make projections
# fit_model = model().fit(incident_train_data, grouping_factors)
return all_forecast

# cumulative_projections = fit_model.predict(
# config["timeframe"]["start"],
# config["timeframe"]["end"],
# config["timeframe"]["interval"],
# grouping_factors,
# )
# save these projections somewhere

# incident_projections = cumulative_projections.to_incident(grouping_factors)
# save these projections somewhere
def run_forecast(
model,
observed_data,
grouping_factors,
forecast_start,
forecast_end,
) -> pl.DataFrame:
"""Run a single model for a single forecast date"""

# Evaluation / Post-processing --------------------------------------------
# preprocess.py returns cumulative data, need to convert to incidence for LinearIncidentUptakeModel #
incident_data = iup.CumulativeUptakeData(observed_data).to_incident(
grouping_factors
)

# incident_test_data = iup.IncidentUptakeData(
# iup.IncidentUptakeData.split_train_test(
# incident_data, config["timeframe"]["start"], "test"
# )
# ).filter(pl.col("date") <= config["timeframe"]["end"])
incident_train_data = iup.IncidentUptakeData(
iup.IncidentUptakeData.split_train_test(incident_data, forecast_start, "train")
)

# Fit models using the training data and make projections
fit_model = model().fit(incident_train_data, grouping_factors)

cumulative_projections = fit_model.predict(
forecast_start,
forecast_end,
config["timeframe"]["interval"],
grouping_factors,
)

return cumulative_projections

if __name__ == "__main__":
p = argparse.ArgumentParser()
p.add_argument("--config", help="config file", default="scripts/config.yaml")
p.add_argument("--input", help="input data")
p.add_argument("--output", help="output parquet file")
args = p.parse_args()

with open(args.config, "r") as f:
config = yaml.safe_load(f)

input_data = pl.scan_parquet(args.input)
input_data = pl.scan_parquet(args.input).collect()

input_data = iup.CumulativeUptakeData(input_data)

run_all_forecasts(config=config, cache=args.cache)
all_forecast = run_all_forecasts(config=config, clean_data=input_data)
all_forecast.write_parquet(args.output)
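For reference, a small sketch of the forecast-date grid that run_all_forecasts builds with pl.date_range; the timeframe values are placeholders, and the interval is assumed to be a polars interval string such as "7d":

import datetime as dt

import polars as pl

timeframe = {  # e.g. config["timeframe"] from scripts/config.yaml
    "start": dt.date(2024, 9, 1),
    "end": dt.date(2025, 4, 30),
    "interval": "7d",
}

# one forecast is run per model and per date in this grid
forecast_dates = pl.date_range(
    timeframe["start"], timeframe["end"], timeframe["interval"], eager=True
)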
39 changes: 13 additions & 26 deletions scripts/preprocess.py
@@ -14,29 +14,19 @@ def preprocess(
filters: dict,
keep: List[str],
groups: List[str],
rollout_dates: List[datetime.date],
) -> pl.DataFrame:
rollouts: List[datetime.date],
) -> iup.CumulativeUptakeData:
# Prune data to correct rows and columns
cumulative_data = iup.CumulativeUptakeData(
raw_data.filter(filters).select(keep).sort("time_end").collect()
)
data = raw_data.filter(**filters).select(keep).sort("time_end").collect()

# Ensure that the desired grouping factors are found in all data sets
assert set(cumulative_data.columns).issuperset(groups)
assert set(data.columns).issuperset(groups)

# Insert rollout dates into the data
cumulative_data = iup.CumulativeUptakeData(
cumulative_data.insert_rollout(rollout_dates, groups)
)

# Convert to incident data
incident_data = cumulative_data.to_incident(groups)

return pl.concat(
[
cumulative_data.with_columns(estimate_type="cumulative"),
incident_data.with_columns(estimate_type="incident"),
]
# note the awkward wrapping with the class, because insert_rollouts returns
# a normal data frame
return iup.CumulativeUptakeData(
iup.CumulativeUptakeData(data).insert_rollouts(rollouts, groups)
)


@@ -46,23 +36,20 @@ def preprocess(
p.add_argument(
"--cache", help="NIS cache directory", default=".cache/nisapi/clean/"
)
p.add_argument("--cache", help="clean cache directory")
p.add_argument("--output", help="output parquet file")
p.add_argument("--output", help="output parquet file", required=True)
args = p.parse_args()

with open(args.config, "r") as f:
config = yaml.safe_load(f)

assert len(config["data"]) == 1, "Don't know how to preprocess multiple data sets"

raw_data = nisapi.get_nis(path=args.cache)

clean_data = preprocess(
raw_data,
filters=config["data"][0]["filters"],
keep=config["data"][0]["keep"],
groups=config["groups"],
rollout_dates=config["data"][0]["rollout"],
filters=config["data"]["filters"],
keep=config["data"]["keep"],
groups=config["data"]["groups"],
rollouts=config["data"]["rollouts"],
)

clean_data.write_parquet(args.output)
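A minimal end-to-end sketch of the revised preprocess() contract under the single-data-set config layout; the import path, frame contents, and config values are illustrative only:

import datetime as dt

import polars as pl

from preprocess import preprocess  # assumes scripts/ is on the import path

raw = pl.LazyFrame(
    {
        "vaccine": ["covid", "covid"],
        "geography": ["nation", "nation"],
        "estimate": [0.1, 0.2],
        "time_end": [dt.date(2023, 10, 1), dt.date(2023, 11, 1)],
    }
)

clean = preprocess(
    raw,
    filters={"vaccine": "covid"},
    keep=["geography", "estimate", "time_end"],
    groups=["geography"],
    rollouts=[dt.date(2023, 9, 1)],
)
# clean is a CumulativeUptakeData with one extra zero-uptake row per rollout date and group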
20 changes: 10 additions & 10 deletions tests/test_data_cleaning.py
@@ -24,49 +24,49 @@ def frame():
return frame


def test_insert_rollout_handles_groups(frame):
def test_insert_rollouts_handles_groups(frame):
"""
If grouping columns are given to insert_rollout, a separate rollout is inserted for each group.
If grouping columns are given to insert_rollouts, a separate rollout is inserted for each group.
"""
frame = frame.with_columns(
time_end=pl.col("time_end").str.strptime(pl.Date, "%Y-%m-%d")
)
rollout = [dt.date(2020, 1, 1), dt.date(2021, 1, 1)]
rollouts = [dt.date(2020, 1, 1), dt.date(2021, 1, 1)]
group_cols = [
"geography",
]
frame = iup.CumulativeUptakeData(frame.drop("indicator"))

output = frame.insert_rollout(rollout, group_cols)
output = frame.insert_rollouts(rollouts, group_cols)

assert output.shape[0] == 7
assert (
output["time_end"]
.value_counts()
.filter(pl.col("time_end") == rollout[0])["count"][0]
.filter(pl.col("time_end") == rollouts[0])["count"][0]
== 2
)
assert output["time_end"].is_sorted()


def test_insert_rollout_handles_no_groups(frame):
def test_insert_rollouts_handles_no_groups(frame):
"""
If no grouping columns are given to insert_rollout, only one of each rollout is inserted.
If no grouping columns are given to insert_rollouts, only one of each rollout is inserted.
"""
frame = frame.with_columns(
time_end=pl.col("time_end").str.strptime(pl.Date, "%Y-%m-%d")
)
rollout = [dt.date(2020, 1, 1), dt.date(2021, 1, 1)]
rollouts = [dt.date(2020, 1, 1), dt.date(2021, 1, 1)]
group_cols = None
frame = iup.CumulativeUptakeData(frame.drop(["indicator", "geography"]))

output = frame.insert_rollout(rollout, group_cols)
output = frame.insert_rollouts(rollouts, group_cols)

assert output.shape[0] == 5
assert (
output["time_end"]
.value_counts()
.filter(pl.col("time_end") == rollout[0])["count"][0]
.filter(pl.col("time_end") == rollouts[0])["count"][0]
== 1
)
assert output["time_end"].is_sorted()
