Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix/add fit/predict_kwargs argument to historical_forecasts #2050

Merged
merged 35 commits into from
Nov 16, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
35 commits
Select commit Hold shift + click to select a range
aa8e341
fix: num_loader_workers can be passed to historical_forecasts, only r…
madtoinou Nov 2, 2023
a797b13
feat: added fit/predict_kwargs to historical_forecasts, backtest and …
madtoinou Nov 3, 2023
a3817e7
fix: default value None for dict
madtoinou Nov 3, 2023
7353db2
feat: increased the number of parameters handled by GlobalForecasting…
madtoinou Nov 3, 2023
4566858
Merge branch 'master' into fix/hfc_num_loader_workers
madtoinou Nov 3, 2023
e7ae306
fix: removed obsolete arg/docstring
madtoinou Nov 3, 2023
37fba40
fix: updated docstring
madtoinou Nov 3, 2023
2f780cf
fix: only pass the supported argument to GlobalForecastingModel.predi…
madtoinou Nov 3, 2023
2b2deca
Merge branch 'master' into fix/hfc_num_loader_workers
madtoinou Nov 6, 2023
294e326
Merge branch 'master' into fix/hfc_num_loader_workers
dennisbader Nov 6, 2023
d35b0d6
fix: simplify the logic of the fit/predict wrapper and hist fc sanity…
madtoinou Nov 7, 2023
2ff6422
Merge branch 'master' into fix/hfc_num_loader_workers
dennisbader Nov 8, 2023
8538885
fix: same signature for all _optimized...
madtoinou Nov 8, 2023
6dd2aa5
fix: changed the exception into warning
madtoinou Nov 8, 2023
e29bca9
fix: missing arg
madtoinou Nov 8, 2023
9031fbc
fix: harmonized signatures of optimized_hist, improved kwargs checks
madtoinou Nov 9, 2023
c4ef3fc
feat: added warning when fit_kwargs is set with retrain=False, possib…
madtoinou Nov 9, 2023
986a524
feat: improve fit/predict_kwargs handling, add tests
madtoinou Nov 10, 2023
fcc3767
feat: parametrize the tests to check both optimized and unoptimized m…
madtoinou Nov 10, 2023
c2e6327
Apply suggestions from code review
madtoinou Nov 13, 2023
8a7c260
Merge branch 'master' into fix/hfc_num_loader_workers
madtoinou Nov 13, 2023
3cea611
feat: updated changelog
madtoinou Nov 13, 2023
ed1e442
feat: added tests when fit_kwargs contains invalid arguments and retr…
madtoinou Nov 13, 2023
37e4afc
fix: set self._uses_future_covariates to True in FutureCovariatesLoca…
madtoinou Nov 13, 2023
783ba74
fix: predict_kwargs[trainer] is properly passed to fit_from_dataset
madtoinou Nov 13, 2023
32d8f14
feat: added exception when unsupported covariates are passed to the f…
madtoinou Nov 14, 2023
61d8171
feat: added tests checking that the exception are raised when expected
madtoinou Nov 14, 2023
c18e413
fix: ensemble model pass covariates only to forecasting models suppor…
madtoinou Nov 15, 2023
156ee62
update changelog
dennisbader Nov 16, 2023
30f27ff
Merge branch 'master' into fix/hfc_num_loader_workers
dennisbader Nov 16, 2023
bce4e86
update hist fc tests
dennisbader Nov 16, 2023
a5e21c3
fix failing bt test p2
dennisbader Nov 16, 2023
9f8cee9
uddate fit/predict wrappers
dennisbader Nov 16, 2023
0779d0c
update docs
dennisbader Nov 16, 2023
f6c0074
shorten electricity dataset to avoid source data updates
dennisbader Nov 16, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,10 @@ but cannot always guarantee backwards compatibility. Changes that may **break co
- Improvements to Regression Models:
- `XGBModel` now leverages XGBoost's native Quantile Regression support that was released in version 2.0.0 for improved probabilistic forecasts. [#2051](https://github.com/unit8co/darts/pull/2051) by [Dennis Bader](https://github.com/dennisbader).
- Other improvements:
- Added support for time index time zone conversion with parameter `tz` before generating/computing holidays and datetime attributes. Support was added to all Time Axis Encoders (standalone encoders and forecasting models' `add_encoders`, time series generation utils functions `holidays_timeseries()` and `datetime_attribute_timeseries()`, and `TimeSeries` methods `add_datetime_attribute()` and `add_holidays()`. [#2054](https://github.com/unit8co/darts/pull/2054) by [Dennis Bader](https://github.com/dennisbader).
- Added support for time index time zone conversion with parameter `tz` before generating/computing holidays and datetime attributes. Support was added to all Time Axis Encoders, standalone encoders and forecasting models' `add_encoders`, time series generation utils functions `holidays_timeseries()` and `datetime_attribute_timeseries()`, and `TimeSeries` methods `add_datetime_attribute()` and `add_holidays()`. [#2054](https://github.com/unit8co/darts/pull/2054) by [Dennis Bader](https://github.com/dennisbader).
- Added optional keyword arguments dict `kwargs` to `ExponentialSmoothing` that will be passed to the constructor of the underlying `statsmodels.tsa.holtwinters.ExponentialSmoothing` model. [#2059](https://github.com/unit8co/darts/pull/2059) by [Antoine Madrona](https://github.com/madtoinou).
- Added new dataset `ElectricityConsumptionZurichDataset`: The dataset contains the electricity consumption of households in Zurich, Switzerland from 2015-2022 on different grid levels. We also added weather measurements for Zurich which can be used as covariates for modelling. [#2039](https://github.com/unit8co/darts/pull/2039) by [Antoine Madrona](https://github.com/madtoinou) and [Dennis Bader](https://github.com/dennisbader).
- Added new arguments `fit_kwargs` and `predict_kwargs` to `historical_forecasts()`, `backtest()` and `gridsearch()` that will be passed to the model's `fit()` and / or `predict` methods. E.g., you can now set a batch size, static validation series, ... depending on the model support. [#2050](https://github.com/unit8co/darts/pull/2050) by [Antoine Madrona](https://github.com/madtoinou)

**Fixed**
- Fixed a bug when calling optimized `historical_forecasts()` for a `RegressionModel` trained with unequal component-specific lags. [#2040](https://github.com/unit8co/darts/pull/2040) by [Antoine Madrona](https://github.com/madtoinou).
Expand Down
9 changes: 5 additions & 4 deletions darts/datasets/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -829,7 +829,8 @@ class ElectricityConsumptionZurichDataset(DatasetLoaderCSV):
To simplify the dataset, the measurements from the Zch_Schimmelstrasse and Zch_Rosengartenstrasse weather
stations are discarded to keep only the data recorded in the Zch_Stampfenbachstrasse station.

Both dataset sources are updated continuously, but this dataset only retrains values between 2015 and 2022.
Both dataset sources are updated continuously, but this dataset only retrains values between 2015-01-01 and
2022-08-31.
The time index was converted from CET time zone to UTC.

Components Descriptions:
Expand Down Expand Up @@ -864,7 +865,7 @@ def pre_process_dataset(dataset_path):
# extract pre-determined period
df = df.loc[
(pd.Timestamp("2015-01-01") <= df.index)
& (df.index <= pd.Timestamp("2022-12-31"))
& (df.index <= pd.Timestamp("2022-08-31"))
]
# download and preprocess the weather information
df_weather = self._download_weather_data()
Expand Down Expand Up @@ -894,7 +895,7 @@ def pre_process_dataset(dataset_path):
"ewz_stromabgabe_netzebenen_stadt_zuerich/"
"download/ewz_stromabgabe_netzebenen_stadt_zuerich.csv"
),
hash="c2fea1a0974611ff1c276abcc1d34619",
hash="a019125b7f9c1afeacb0ae60ce7455ef",
header_time="Timestamp",
freq="15min",
pre_process_csv_fn=pre_process_dataset,
Expand All @@ -919,6 +920,6 @@ def _download_weather_data():
)
df = df.loc[
(pd.Timestamp("2015-01-01") <= df.index)
& (df.index <= pd.Timestamp("2022-12-31"))
& (df.index <= pd.Timestamp("2022-08-31"))
]
return df
8 changes: 6 additions & 2 deletions darts/models/forecasting/baselines.py
Original file line number Diff line number Diff line change
Expand Up @@ -332,8 +332,12 @@ def fit(
for model in self.forecasting_models:
model._fit_wrapper(
series=series,
past_covariates=past_covariates,
future_covariates=future_covariates,
past_covariates=past_covariates
if model.supports_past_covariates
else None,
future_covariates=future_covariates
if model.supports_future_covariates
else None,
)

return self
Expand Down
Loading
Loading