Fix/add fit/predict_kwargs argument to historical_forecasts #2050

Merged
merged 35 commits into from
Nov 16, 2023
Changes from 34 commits
aa8e341
fix: num_loader_workers can be passed to historical_forecasts, only r…
madtoinou Nov 2, 2023
a797b13
feat: added fit/predict_kwargs to historical_forecasts, backtest and …
madtoinou Nov 3, 2023
a3817e7
fix: default value None for dict
madtoinou Nov 3, 2023
7353db2
feat: increased the number of parameters handled by GlobalForecasting…
madtoinou Nov 3, 2023
4566858
Merge branch 'master' into fix/hfc_num_loader_workers
madtoinou Nov 3, 2023
e7ae306
fix: removed obsolete arg/docstring
madtoinou Nov 3, 2023
37fba40
fix: updated docstring
madtoinou Nov 3, 2023
2f780cf
fix: only pass the supported argument to GlobalForecastingModel.predi…
madtoinou Nov 3, 2023
2b2deca
Merge branch 'master' into fix/hfc_num_loader_workers
madtoinou Nov 6, 2023
294e326
Merge branch 'master' into fix/hfc_num_loader_workers
dennisbader Nov 6, 2023
d35b0d6
fix: simplify the logic of the fit/predict wrapper and hist fc sanity…
madtoinou Nov 7, 2023
2ff6422
Merge branch 'master' into fix/hfc_num_loader_workers
dennisbader Nov 8, 2023
8538885
fix: same signature for all _optimized...
madtoinou Nov 8, 2023
6dd2aa5
fix: changed the exception into warning
madtoinou Nov 8, 2023
e29bca9
fix: missing arg
madtoinou Nov 8, 2023
9031fbc
fix: harmonized signatures of optimized_hist, improved kwargs checks
madtoinou Nov 9, 2023
c4ef3fc
feat: added warning when fit_kwargs is set with retrain=False, possib…
madtoinou Nov 9, 2023
986a524
feat: improve fit/predict_kwargs handling, add tests
madtoinou Nov 10, 2023
fcc3767
feat: parametrize the tests to check both optimized and unoptimized m…
madtoinou Nov 10, 2023
c2e6327
Apply suggestions from code review
madtoinou Nov 13, 2023
8a7c260
Merge branch 'master' into fix/hfc_num_loader_workers
madtoinou Nov 13, 2023
3cea611
feat: updated changelog
madtoinou Nov 13, 2023
ed1e442
feat: added tests when fit_kwargs contains invalid arguments and retr…
madtoinou Nov 13, 2023
37e4afc
fix: set self._uses_future_covariates to True in FutureCovariatesLoca…
madtoinou Nov 13, 2023
783ba74
fix: predict_kwargs[trainer] is properly passed to fit_from_dataset
madtoinou Nov 13, 2023
32d8f14
feat: added exception when unsupported covariates are passed to the f…
madtoinou Nov 14, 2023
61d8171
feat: added tests checking that the exception are raised when expected
madtoinou Nov 14, 2023
c18e413
fix: ensemble model pass covariates only to forecasting models suppor…
madtoinou Nov 15, 2023
156ee62
update changelog
dennisbader Nov 16, 2023
30f27ff
Merge branch 'master' into fix/hfc_num_loader_workers
dennisbader Nov 16, 2023
bce4e86
update hist fc tests
dennisbader Nov 16, 2023
a5e21c3
fix failing bt test p2
dennisbader Nov 16, 2023
9f8cee9
update fit/predict wrappers
dennisbader Nov 16, 2023
0779d0c
update docs
dennisbader Nov 16, 2023
f6c0074
shorten electricity dataset to avoid source data updates
dennisbader Nov 16, 2023
3 changes: 2 additions & 1 deletion CHANGELOG.md
@@ -20,9 +20,10 @@ but cannot always guarantee backwards compatibility. Changes that may **break co
- Improvements to Regression Models:
- `XGBModel` now leverages XGBoost's native Quantile Regression support that was released in version 2.0.0 for improved probabilistic forecasts. [#2051](https://github.com/unit8co/darts/pull/2051) by [Dennis Bader](https://github.com/dennisbader).
- Other improvements:
- Added support for time index time zone conversion with parameter `tz` before generating/computing holidays and datetime attributes. Support was added to all Time Axis Encoders (standalone encoders and forecasting models' `add_encoders`, time series generation utils functions `holidays_timeseries()` and `datetime_attribute_timeseries()`, and `TimeSeries` methods `add_datetime_attribute()` and `add_holidays()`. [#2054](https://github.com/unit8co/darts/pull/2054) by [Dennis Bader](https://github.com/dennisbader).
- Added support for time index time zone conversion with parameter `tz` before generating/computing holidays and datetime attributes. Support was added to all Time Axis Encoders, standalone encoders and forecasting models' `add_encoders`, time series generation utils functions `holidays_timeseries()` and `datetime_attribute_timeseries()`, and `TimeSeries` methods `add_datetime_attribute()` and `add_holidays()`. [#2054](https://github.com/unit8co/darts/pull/2054) by [Dennis Bader](https://github.com/dennisbader).
- Added optional keyword arguments dict `kwargs` to `ExponentialSmoothing` that will be passed to the constructor of the underlying `statsmodels.tsa.holtwinters.ExponentialSmoothing` model. [#2059](https://github.com/unit8co/darts/pull/2059) by [Antoine Madrona](https://github.com/madtoinou).
- Added new dataset `ElectricityConsumptionZurichDataset`: The dataset contains the electricity consumption of households in Zurich, Switzerland from 2015-2022 on different grid levels. We also added weather measurements for Zurich which can be used as covariates for modelling. [#2039](https://github.com/unit8co/darts/pull/2039) by [Antoine Madrona](https://github.com/madtoinou) and [Dennis Bader](https://github.com/dennisbader).
- Added new arguments `fit_kwargs` and `predict_kwargs` to `historical_forecasts()`, `backtest()` and `gridsearch()` that will be passed to the model's `fit()` and/or `predict()` methods. E.g., you can now set a batch size, static validation series, ... depending on the model support. [#2050](https://github.com/unit8co/darts/pull/2050) by [Antoine Madrona](https://github.com/madtoinou).

**Fixed**
- Fixed a bug when calling optimized `historical_forecasts()` for a `RegressionModel` trained with unequal component-specific lags. [#2040](https://github.com/unit8co/darts/pull/2040) by [Antoine Madrona](https://github.com/madtoinou).
8 changes: 6 additions & 2 deletions darts/models/forecasting/baselines.py
@@ -332,8 +332,12 @@ def fit(
for model in self.forecasting_models:
model._fit_wrapper(
series=series,
past_covariates=past_covariates,
future_covariates=future_covariates,
past_covariates=past_covariates
if model.supports_past_covariates
else None,
future_covariates=future_covariates
if model.supports_future_covariates
else None,
)

return self
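The ensemble change above can be illustrated with a minimal, self-contained sketch. `ToyModel` and `fit_all` are hypothetical stand-ins invented for this illustration (they are not darts classes); only the `supports_past_covariates` / `supports_future_covariates` flags mirror the attributes used in the diff.

```python
class ToyModel:
    """Hypothetical stand-in for a forecasting model; not a darts class."""

    def __init__(self, supports_past: bool, supports_future: bool):
        self.supports_past_covariates = supports_past
        self.supports_future_covariates = supports_future
        self.received = None

    def fit(self, series, past_covariates=None, future_covariates=None):
        # record what was actually passed so it can be inspected afterwards
        self.received = (past_covariates, future_covariates)
        return self


def fit_all(models, series, past_covariates=None, future_covariates=None):
    """Fit every sub-model, forwarding only the covariates it supports."""
    for model in models:
        model.fit(
            series,
            past_covariates=(
                past_covariates if model.supports_past_covariates else None
            ),
            future_covariates=(
                future_covariates if model.supports_future_covariates else None
            ),
        )
    return models


models = fit_all(
    [ToyModel(True, False), ToyModel(False, True)],
    series=[1.0, 2.0, 3.0],
    past_covariates="past",
    future_covariates="future",
)
```

Filtering at the ensemble level keeps a stricter wrapper (one that raises on unsupported covariates) from rejecting a mixed ensemble.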
203 changes: 93 additions & 110 deletions darts/models/forecasting/forecasting_model.py
@@ -42,6 +42,7 @@
_get_historical_forecast_predict_index,
_get_historical_forecast_train_index,
_historical_forecasts_general_checks,
_historical_forecasts_sanitize_kwargs,
_reconciliate_historical_time_indices,
)
from darts.utils.timeseries_generation import (
@@ -316,23 +317,47 @@ def _fit_wrapper(
series: TimeSeries,
past_covariates: Optional[TimeSeries],
future_covariates: Optional[TimeSeries],
**kwargs,
):
self.fit(series)
supported_params = inspect.signature(self.fit).parameters
kwargs_ = {k: v for k, v in kwargs.items() if k in supported_params}

# handle past and future covariates based on model support
for covs, name in zip([past_covariates, future_covariates], ["past", "future"]):
covs_name = f"{name}_covariates"
if getattr(self, f"supports_{covs_name}"):
kwargs_[covs_name] = covs
elif covs is not None:
raise_log(
ValueError(f"Model cannot be fit/trained with `{covs_name}`."),
logger,
)
self.fit(series, **kwargs_)

def _predict_wrapper(
self,
n: int,
series: TimeSeries,
past_covariates: Optional[TimeSeries],
future_covariates: Optional[TimeSeries],
num_samples: int,
verbose: bool = False,
predict_likelihood_parameters: bool = False,
) -> TimeSeries:
kwargs = dict()
if self.supports_likelihood_parameter_prediction:
kwargs["predict_likelihood_parameters"] = predict_likelihood_parameters
return self.predict(n, num_samples=num_samples, verbose=verbose, **kwargs)
**kwargs,
) -> Union[TimeSeries, Sequence[TimeSeries]]:
supported_params = set(inspect.signature(self.predict).parameters)

# if predict() accepts covariates, the model might not support them at inference
for covs_name in ["past_covariates", "future_covariates"]:
if covs_name in kwargs and not getattr(self, f"supports_{covs_name}"):
if kwargs[covs_name] is None:
supported_params = supported_params - {covs_name}
else:
raise_log(
ValueError(
f"Model prediction does not support `{covs_name}`, either because it "
f"does not support `{covs_name}` in general, or because it was fit/trained "
f"without using `{covs_name}`."
),
logger,
)

kwargs_ = {k: v for k, v in kwargs.items() if k in supported_params}
return self.predict(n, **kwargs_)

@property
def min_train_series_length(self) -> int:
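For readers following the new wrapper logic in this hunk, here is a rough standalone sketch of the same idea: keep only the keyword arguments that `fit()` declares, then add covariates based on the model's support flags. `ToyModel` and the free function `fit_wrapper` are hypothetical names for this illustration, and the plain `raise` stands in for the `raise_log` helper used in the diff.

```python
import inspect


class ToyModel:
    """Hypothetical model that supports future covariates only."""

    supports_past_covariates = False
    supports_future_covariates = True

    def fit(self, series, future_covariates=None, epochs=1):
        self.fit_args = {"future_covariates": future_covariates, "epochs": epochs}
        return self


def fit_wrapper(model, series, past_covariates=None, future_covariates=None, **kwargs):
    # keep only the keyword arguments that model.fit() actually declares
    supported = inspect.signature(model.fit).parameters
    kwargs_ = {k: v for k, v in kwargs.items() if k in supported}

    # pass covariates if supported, fail loudly if an unsupported one is given
    for covs, name in [(past_covariates, "past"), (future_covariates, "future")]:
        covs_name = f"{name}_covariates"
        if getattr(model, f"supports_{covs_name}"):
            kwargs_[covs_name] = covs
        elif covs is not None:
            raise ValueError(f"Model cannot be fit/trained with `{covs_name}`.")
    return model.fit(series, **kwargs_)


# `batch_size` is not a parameter of ToyModel.fit(), so it is silently dropped
fitted = fit_wrapper(ToyModel(), [1, 2], future_covariates="fc", epochs=3, batch_size=32)

try:
    fit_wrapper(ToyModel(), [1, 2], past_covariates="pc")
    raised = False
except ValueError:
    raised = True
```

Note the asymmetry in this sketch: unknown keyword arguments are dropped quietly, while unsupported covariates raise, since silently ignoring data the caller supplied would be harder to notice.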
@@ -586,6 +611,8 @@ def historical_forecasts(
show_warnings: bool = True,
predict_likelihood_parameters: bool = False,
enable_optimization: bool = True,
fit_kwargs: Optional[Dict[str, Any]] = None,
predict_kwargs: Optional[Dict[str, Any]] = None,
) -> Union[
TimeSeries, List[TimeSeries], Sequence[TimeSeries], Sequence[List[TimeSeries]]
]:
@@ -692,6 +719,10 @@ def historical_forecasts(
Default: ``False``
enable_optimization
Whether to use the optimized version of historical_forecasts when supported and available.
fit_kwargs
Additional arguments passed to the model `fit()` method.
predict_kwargs
Additional arguments passed to the model `predict()` method.

Returns
-------
@@ -802,6 +833,15 @@ def retrain_func(
logger,
)

# remove unsupported arguments, raise an exception if they interfere with the historical forecasts logic
fit_kwargs, predict_kwargs = _historical_forecasts_sanitize_kwargs(
model=model,
fit_kwargs=fit_kwargs,
predict_kwargs=predict_kwargs,
retrain=retrain is not False and retrain != 0,
show_warnings=show_warnings,
)

series = series2seq(series)
past_covariates = series2seq(past_covariates)
future_covariates = series2seq(future_covariates)
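The expression `retrain is not False and retrain != 0` passed to `_historical_forecasts_sanitize_kwargs` collapses the `retrain` argument into a single boolean. A small sketch of how it evaluates, assuming `retrain` may be a bool, an int, or a callable (the `retrain_func` hunk context suggests the latter); `retrain_enabled` is a hypothetical helper name:

```python
def retrain_enabled(retrain) -> bool:
    """Mirror of the boolean expression used in the hunk above."""
    return retrain is not False and retrain != 0


# bools and ints: only False and 0 disable retraining
flags = {repr(v): retrain_enabled(v) for v in (True, False, 0, 1, 5)}

# a callable is neither False nor equal to 0, so it counts as "retraining enabled"
callable_flag = retrain_enabled(lambda **kwargs: True)
```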
@@ -829,6 +869,7 @@ def retrain_func(
verbose=verbose,
show_warnings=show_warnings,
predict_likelihood_parameters=predict_likelihood_parameters,
**predict_kwargs,
)

if len(series) == 1:
@@ -969,6 +1010,7 @@ def retrain_func(
series=train_series,
past_covariates=past_covariates_,
future_covariates=future_covariates_,
**fit_kwargs,
)
else:
# untrained model was not trained on the first trainable timestamp
@@ -1019,6 +1061,7 @@ def retrain_func(
num_samples=num_samples,
verbose=verbose,
predict_likelihood_parameters=predict_likelihood_parameters,
**predict_kwargs,
)
if forecast_components is None:
forecast_components = forecast.columns
@@ -1076,6 +1119,8 @@ def backtest(
reduction: Union[Callable[[np.ndarray], float], None] = np.mean,
verbose: bool = False,
show_warnings: bool = True,
fit_kwargs: Optional[Dict[str, Any]] = None,
predict_kwargs: Optional[Dict[str, Any]] = None,
) -> Union[float, List[float], Sequence[float], List[Sequence[float]]]:
"""Compute error values that the model would have produced when
used on (potentially multiple) `series`.
@@ -1185,6 +1230,10 @@ def backtest(
Whether to print progress.
show_warnings
Whether to show warnings related to parameters `start`, and `train_length`.
fit_kwargs
Additional arguments passed to the model `fit()` method.
predict_kwargs
Additional arguments passed to the model `predict()` method.

Returns
-------
@@ -1208,6 +1257,8 @@ def backtest(
last_points_only=last_points_only,
verbose=verbose,
show_warnings=show_warnings,
fit_kwargs=fit_kwargs,
predict_kwargs=predict_kwargs,
)
else:
forecasts = historical_forecasts
@@ -1261,6 +1312,8 @@ def gridsearch(
verbose=False,
n_jobs: int = 1,
n_random_samples: Optional[Union[int, float]] = None,
fit_kwargs: Optional[Dict[str, Any]] = None,
predict_kwargs: Optional[Dict[str, Any]] = None,
) -> Tuple["ForecastingModel", Dict[str, Any], float]:
"""
Find the best hyper-parameters among a given set using a grid search.
@@ -1374,6 +1427,10 @@ def gridsearch(
must be between `0` and the total number of parameter combinations.
If a float, `n_random_samples` is the ratio of parameter combinations selected from the full grid and must
be between `0` and `1`. Defaults to `None`, for which random selection will be ignored.
fit_kwargs
Additional arguments passed to the model `fit()` method.
predict_kwargs
Additional arguments passed to the model `predict()` method.

Returns
-------
@@ -1406,10 +1463,10 @@ def gridsearch(
logger,
)

# TODO: here too I'd say we can leave these checks to the models
# if covariates is not None:
# raise_if_not(series.has_same_time_as(covariates), 'The provided series and covariates must have the '
# 'same time axes.')
if fit_kwargs is None:
fit_kwargs = dict()
if predict_kwargs is None:
predict_kwargs = dict()

# compute all hyperparameter combinations from selection
params_cross_product = list(product(*parameters.values()))
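The `None` defaults for `fit_kwargs` / `predict_kwargs` are normalized to empty dicts before being unpacked with `**`. A minimal sketch of that pattern, with a hypothetical `normalize_kwargs` helper that is not part of the PR:

```python
from typing import Any, Dict, Optional, Tuple


def normalize_kwargs(
    fit_kwargs: Optional[Dict[str, Any]] = None,
    predict_kwargs: Optional[Dict[str, Any]] = None,
) -> Tuple[Dict[str, Any], Dict[str, Any]]:
    """Replace None defaults with fresh dicts so `**kwargs` unpacking is safe."""
    if fit_kwargs is None:
        fit_kwargs = dict()
    if predict_kwargs is None:
        predict_kwargs = dict()
    return fit_kwargs, predict_kwargs


first_fit, first_predict = normalize_kwargs()
second_fit, _ = normalize_kwargs()

# user-supplied dicts pass through untouched
custom_fit, _ = normalize_kwargs(fit_kwargs={"epochs": 3})
```

Using `None` as the default, rather than `{}` in the signature, gives every call its own dict, so state cannot leak between calls through a shared mutable default.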
@@ -1437,7 +1494,12 @@ def _evaluate_combination(param_combination) -> float:

model = model_class(**param_combination_dict)
if use_fitted_values: # fitted value mode
model._fit_wrapper(series, past_covariates, future_covariates)
model._fit_wrapper(
series=series,
past_covariates=past_covariates,
future_covariates=future_covariates,
**fit_kwargs,
)
fitted_values = TimeSeries.from_times_and_values(
series.time_index, model.fitted_values
)
@@ -1457,16 +1519,24 @@ def _evaluate_combination(param_combination) -> float:
last_points_only=last_points_only,
verbose=verbose,
show_warnings=show_warnings,
fit_kwargs=fit_kwargs,
predict_kwargs=predict_kwargs,
)
else: # split mode
model._fit_wrapper(series, past_covariates, future_covariates)
model._fit_wrapper(
series=series,
past_covariates=past_covariates,
future_covariates=future_covariates,
**fit_kwargs,
)
pred = model._predict_wrapper(
len(val_series),
series,
past_covariates,
future_covariates,
n=len(val_series),
series=series,
past_covariates=past_covariates,
future_covariates=future_covariates,
num_samples=1,
verbose=verbose,
**predict_kwargs,
)
error = metric(val_series, pred)

@@ -2211,43 +2281,6 @@ def predict(
)
)

def _predict_wrapper(
self,
n: int,
series: Union[TimeSeries, Sequence[TimeSeries]],
past_covariates: Optional[Union[TimeSeries, Sequence[TimeSeries]]],
future_covariates: Optional[Union[TimeSeries, Sequence[TimeSeries]]],
num_samples: int,
verbose: bool = False,
predict_likelihood_parameters: bool = False,
) -> Union[TimeSeries, Sequence[TimeSeries]]:
kwargs = dict()
if self.supports_likelihood_parameter_prediction:
kwargs["predict_likelihood_parameters"] = predict_likelihood_parameters
return self.predict(
n,
series,
past_covariates=past_covariates,
future_covariates=future_covariates,
num_samples=num_samples,
verbose=verbose,
**kwargs,
)

def _fit_wrapper(
self,
series: Union[TimeSeries, Sequence[TimeSeries]],
past_covariates: Optional[Union[TimeSeries, Sequence[TimeSeries]]],
future_covariates: Optional[Union[TimeSeries, Sequence[TimeSeries]]],
):
self.fit(
series=series,
past_covariates=past_covariates if self.supports_past_covariates else None,
future_covariates=future_covariates
if self.supports_future_covariates
else None,
)

@property
def _supports_non_retrainable_historical_forecasts(self) -> bool:
"""GlobalForecastingModel supports historical forecasts without retraining the model"""
@@ -2340,6 +2373,7 @@ def fit(self, series: TimeSeries, future_covariates: Optional[TimeSeries] = None
logger=logger,
)
self._expect_future_covariates = True
self._uses_future_covariates = True

self.encoders = self.initialize_encoders()
if self.encoders.encoding_available:
@@ -2448,35 +2482,6 @@ def _predict(
"""
pass

def _fit_wrapper(
self,
series: TimeSeries,
past_covariates: Optional[TimeSeries],
future_covariates: Optional[TimeSeries],
):
self.fit(series, future_covariates=future_covariates)

def _predict_wrapper(
self,
n: int,
series: TimeSeries,
past_covariates: Optional[TimeSeries],
future_covariates: Optional[TimeSeries],
num_samples: int,
verbose: bool = False,
predict_likelihood_parameters: bool = False,
) -> TimeSeries:
kwargs = dict()
if self.supports_likelihood_parameter_prediction:
kwargs["predict_likelihood_parameters"] = predict_likelihood_parameters
return self.predict(
n,
future_covariates=future_covariates,
num_samples=num_samples,
verbose=verbose,
**kwargs,
)

@property
def _model_encoder_settings(
self,
@@ -2673,28 +2678,6 @@ def _predict(
"""
pass

def _predict_wrapper(
self,
n: int,
series: TimeSeries,
past_covariates: Optional[TimeSeries],
future_covariates: Optional[TimeSeries],
num_samples: int,
verbose: bool = False,
predict_likelihood_parameters: bool = False,
) -> TimeSeries:
kwargs = dict()
if self.supports_likelihood_parameter_prediction:
kwargs["predict_likelihood_parameters"] = predict_likelihood_parameters
return self.predict(
n=n,
series=series,
future_covariates=future_covariates,
num_samples=num_samples,
verbose=verbose,
**kwargs,
)

@property
def _supports_non_retrainable_historical_forecasts(self) -> bool:
return True