Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix/add fit/predict_kwargs argument to historical_forecasts #2050

Merged
merged 35 commits into from
Nov 16, 2023
Merged
Show file tree
Hide file tree
Changes from 10 commits
Commits
Show all changes
35 commits
Select commit Hold shift + click to select a range
aa8e341
fix: num_loader_workers can be passed to historical_forecasts, only r…
madtoinou Nov 2, 2023
a797b13
feat: added fit/predict_kwargs to historical_forecasts, backtest and …
madtoinou Nov 3, 2023
a3817e7
fix: default value None for dict
madtoinou Nov 3, 2023
7353db2
feat: increased the number of parameters handled by GlobalForecasting…
madtoinou Nov 3, 2023
4566858
Merge branch 'master' into fix/hfc_num_loader_workers
madtoinou Nov 3, 2023
e7ae306
fix: removed obsolete arg/docstring
madtoinou Nov 3, 2023
37fba40
fix: updated docstring
madtoinou Nov 3, 2023
2f780cf
fix: only pass the supported argument to GlobalForecastingModel.predi…
madtoinou Nov 3, 2023
2b2deca
Merge branch 'master' into fix/hfc_num_loader_workers
madtoinou Nov 6, 2023
294e326
Merge branch 'master' into fix/hfc_num_loader_workers
dennisbader Nov 6, 2023
d35b0d6
fix: simplify the logic of the fit/predict wrapper and hist fc sanity…
madtoinou Nov 7, 2023
2ff6422
Merge branch 'master' into fix/hfc_num_loader_workers
dennisbader Nov 8, 2023
8538885
fix: same signature for all _optimized...
madtoinou Nov 8, 2023
6dd2aa5
fix: changed the exception into warning
madtoinou Nov 8, 2023
e29bca9
fix: missing arg
madtoinou Nov 8, 2023
9031fbc
fix: harmonized signatures of optimized_hist, improved kwargs checks
madtoinou Nov 9, 2023
c4ef3fc
feat: added warning when fit_kwargs is set with retrain=False, possib…
madtoinou Nov 9, 2023
986a524
feat: improve fit/predict_kwargs handling, add tests
madtoinou Nov 10, 2023
fcc3767
feat: parametrize the tests to check both optimized and unoptimized m…
madtoinou Nov 10, 2023
c2e6327
Apply suggestions from code review
madtoinou Nov 13, 2023
8a7c260
Merge branch 'master' into fix/hfc_num_loader_workers
madtoinou Nov 13, 2023
3cea611
feat: updated changelog
madtoinou Nov 13, 2023
ed1e442
feat: added tests when fit_kwargs contains invalid arguments and retr…
madtoinou Nov 13, 2023
37e4afc
fix: set self._uses_future_covariates to True in FutureCovariatesLoca…
madtoinou Nov 13, 2023
783ba74
fix: predict_kwargs[trainer] is properly passed to fit_from_dataset
madtoinou Nov 13, 2023
32d8f14
feat: added exception when unsupported covariates are passed to the f…
madtoinou Nov 14, 2023
61d8171
feat: added tests checking that the exception are raised when expected
madtoinou Nov 14, 2023
c18e413
fix: ensemble model pass covariates only to forecasting models suppor…
madtoinou Nov 15, 2023
156ee62
update changelog
dennisbader Nov 16, 2023
30f27ff
Merge branch 'master' into fix/hfc_num_loader_workers
dennisbader Nov 16, 2023
bce4e86
update hist fc tests
dennisbader Nov 16, 2023
a5e21c3
fix failing bt test p2
dennisbader Nov 16, 2023
9f8cee9
uddate fit/predict wrappers
dennisbader Nov 16, 2023
0779d0c
update docs
dennisbader Nov 16, 2023
f6c0074
shorten electricity dataset to avoid source data updates
dennisbader Nov 16, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
160 changes: 150 additions & 10 deletions darts/models/forecasting/forecasting_model.py
madtoinou marked this conversation as resolved.
Show resolved Hide resolved
Original file line number Diff line number Diff line change
Expand Up @@ -31,10 +31,12 @@

import numpy as np
import pandas as pd
from sklearn.multioutput import MultiOutputRegressor
madtoinou marked this conversation as resolved.
Show resolved Hide resolved

from darts import metrics
from darts.dataprocessing.encoders import SequentialEncoder
from darts.logging import get_logger, raise_if, raise_if_not, raise_log
from darts.models.utils import _check_kwargs_keys
from darts.timeseries import TimeSeries
from darts.utils import _build_tqdm_iterator, _parallel_apply, _with_sanity_checks
from darts.utils.historical_forecasts.utils import (
Expand Down Expand Up @@ -316,6 +318,7 @@ def _fit_wrapper(
series: TimeSeries,
past_covariates: Optional[TimeSeries],
future_covariates: Optional[TimeSeries],
**kwargs,
):
self.fit(series)

Expand All @@ -328,10 +331,21 @@ def _predict_wrapper(
num_samples: int,
verbose: bool = False,
predict_likelihood_parameters: bool = False,
num_loader_workers: int = 0,
madtoinou marked this conversation as resolved.
Show resolved Hide resolved
batch_size: Optional[int] = None,
n_jobs: int = 1,
roll_size: Optional[int] = None,
mc_dropout: bool = False,
) -> TimeSeries:
kwargs = dict()
if self.supports_likelihood_parameter_prediction:
kwargs["predict_likelihood_parameters"] = predict_likelihood_parameters
if getattr(self, "trainer_params", False):
kwargs["num_loader_workers"] = num_loader_workers
kwargs["batch_size"] = batch_size
kwargs["n_jobs"] = n_jobs
kwargs["roll_size"] = roll_size
kwargs["mc_dropout"] = mc_dropout
return self.predict(n, num_samples=num_samples, verbose=verbose, **kwargs)

@property
Expand Down Expand Up @@ -586,6 +600,8 @@ def historical_forecasts(
show_warnings: bool = True,
predict_likelihood_parameters: bool = False,
enable_optimization: bool = True,
fit_kwargs: Optional[Dict[str, Any]] = None,
predict_kwargs: Optional[Dict[str, Any]] = None,
) -> Union[
TimeSeries, List[TimeSeries], Sequence[TimeSeries], Sequence[List[TimeSeries]]
]:
Expand Down Expand Up @@ -692,6 +708,12 @@ def historical_forecasts(
Default: ``False``
enable_optimization
Whether to use the optimized version of historical_forecasts when supported and available.
fit_kwargs
Additional arguments passed to the model `fit()` method, for example `max_samples_per_ts`,
`n_jobs_multioutput_wrapper` or `num_loader_workers`.
predict_kwargs
Additional arguments passed to the model `predict()` method, for example `num_samples`,
`predict_likelihood_parameters` or `num_loader_workers`.

Returns
-------
Expand Down Expand Up @@ -802,6 +824,51 @@ def retrain_func(
logger,
)

if fit_kwargs is None:
fit_kwargs = dict()
if predict_kwargs is None:
predict_kwargs = dict()

# sanity checks of the arguments directly exposed by historical_forecasts
if "predict_likelihood_parameters" not in predict_kwargs:
madtoinou marked this conversation as resolved.
Show resolved Hide resolved
predict_kwargs[
"predict_likelihood_parameters"
] = predict_likelihood_parameters
elif (
predict_kwargs["predict_likelihood_parameters"]
!= predict_likelihood_parameters
):
logger.warning(
"`predict_likelihood_parameters` was provided with contradictory values, "
"retaining the value passed with `predict_kwargs`."
)
if "num_samples" not in predict_kwargs:
predict_kwargs["num_samples"] = num_samples
elif predict_kwargs["num_samples"] != num_samples:
logger.warning(
"`num_samples` was provided with contradictory values, "
"retaining the value passed with `predict_kwargs`."
)

        # fit_kwargs/predict_kwargs cannot be used to pass arguments handled by the historical_forecasts logic
forbiden_args = ["series", "past_covariates", "future_covariates"]
madtoinou marked this conversation as resolved.
Show resolved Hide resolved
fit_invalid_args = forbiden_args + [
madtoinou marked this conversation as resolved.
Show resolved Hide resolved
"val_series",
"val_past_covariates",
"val_future_covariates",
]
_check_kwargs_keys(
param_name="fit_kwargs",
kwargs_dict=fit_kwargs,
invalid_keys=fit_invalid_args,
)
predict_invalid_args = forbiden_args + ["n", "trainer"]
_check_kwargs_keys(
param_name="predict_kwargs",
kwargs_dict=predict_kwargs,
invalid_keys=predict_invalid_args,
)

series = series2seq(series)
past_covariates = series2seq(past_covariates)
future_covariates = series2seq(future_covariates)
Expand All @@ -819,7 +886,6 @@ def retrain_func(
series=series,
past_covariates=past_covariates,
future_covariates=future_covariates,
num_samples=num_samples,
start=start,
start_format=start_format,
forecast_horizon=forecast_horizon,
Expand All @@ -828,7 +894,7 @@ def retrain_func(
last_points_only=last_points_only,
verbose=verbose,
show_warnings=show_warnings,
predict_likelihood_parameters=predict_likelihood_parameters,
predict_kwargs=predict_kwargs,
)

if len(series) == 1:
Expand Down Expand Up @@ -969,6 +1035,7 @@ def retrain_func(
series=train_series,
past_covariates=past_covariates_,
future_covariates=future_covariates_,
**fit_kwargs,
)
else:
# untrained model was not trained on the first trainable timestamp
Expand Down Expand Up @@ -1016,9 +1083,8 @@ def retrain_func(
series=train_series,
past_covariates=past_covariates_,
future_covariates=future_covariates_,
num_samples=num_samples,
verbose=verbose,
predict_likelihood_parameters=predict_likelihood_parameters,
**predict_kwargs,
)
if forecast_components is None:
forecast_components = forecast.columns
Expand Down Expand Up @@ -1076,6 +1142,8 @@ def backtest(
reduction: Union[Callable[[np.ndarray], float], None] = np.mean,
verbose: bool = False,
show_warnings: bool = True,
fit_kwargs: Optional[Dict[str, Any]] = None,
predict_kwargs: Optional[Dict[str, Any]] = None,
) -> Union[float, List[float], Sequence[float], List[Sequence[float]]]:
"""Compute error values that the model would have produced when
used on (potentially multiple) `series`.
Expand Down Expand Up @@ -1185,6 +1253,12 @@ def backtest(
Whether to print progress.
show_warnings
Whether to show warnings related to parameters `start`, and `train_length`.
fit_kwargs
Additional arguments passed to the model `fit()` method, for example `max_samples_per_ts`,
`n_jobs_multioutput_wrapper` or `num_loader_workers`.
predict_kwargs
Additional arguments passed to the model `predict()` method, for example `num_samples`,
`predict_likelihood_parameters` or `num_loader_workers`.

Returns
-------
Expand All @@ -1208,6 +1282,8 @@ def backtest(
last_points_only=last_points_only,
verbose=verbose,
show_warnings=show_warnings,
fit_kwargs=fit_kwargs,
predict_kwargs=predict_kwargs,
)
else:
forecasts = historical_forecasts
Expand Down Expand Up @@ -1261,6 +1337,8 @@ def gridsearch(
verbose=False,
n_jobs: int = 1,
n_random_samples: Optional[Union[int, float]] = None,
fit_kwargs: Optional[Dict[str, Any]] = None,
predict_kwargs: Optional[Dict[str, Any]] = None,
) -> Tuple["ForecastingModel", Dict[str, Any], float]:
"""
Find the best hyper-parameters among a given set using a grid search.
Expand Down Expand Up @@ -1374,6 +1452,12 @@ def gridsearch(
must be between `0` and the total number of parameter combinations.
If a float, `n_random_samples` is the ratio of parameter combinations selected from the full grid and must
be between `0` and `1`. Defaults to `None`, for which random selection will be ignored.
fit_kwargs
Additional arguments passed to the model `fit()` method, for example `max_samples_per_ts`,
`n_jobs_multioutput_wrapper` or `num_loader_workers`.
predict_kwargs
Additional arguments passed to the model `predict()` method, for example `predict_likelihood_parameters` or
`num_loader_workers`.

Returns
-------
Expand Down Expand Up @@ -1406,10 +1490,28 @@ def gridsearch(
logger,
)

# TODO: here too I'd say we can leave these checks to the models
# if covariates is not None:
# raise_if_not(series.has_same_time_as(covariates), 'The provided series and covariates must have the '
# 'same time axes.')
if fit_kwargs is None:
fit_kwargs = dict()
if predict_kwargs is None:
predict_kwargs = dict()

forbiden_args = ["series", "past_covariates", "future_covariates"]
madtoinou marked this conversation as resolved.
Show resolved Hide resolved
fit_invalid_args = forbiden_args + [
"val_series",
"val_past_covariates",
"val_future_covariates",
]
_check_kwargs_keys(
param_name="fit_kwargs",
kwargs_dict=fit_kwargs,
invalid_keys=fit_invalid_args,
)
predict_invalid_args = forbiden_args + ["n", "trainer", "num_samples"]
_check_kwargs_keys(
param_name="predict_kwargs",
kwargs_dict=predict_kwargs,
invalid_keys=predict_invalid_args,
)

# compute all hyperparameter combinations from selection
params_cross_product = list(product(*parameters.values()))
Expand Down Expand Up @@ -1437,7 +1539,12 @@ def _evaluate_combination(param_combination) -> float:

model = model_class(**param_combination_dict)
if use_fitted_values: # fitted value mode
model._fit_wrapper(series, past_covariates, future_covariates)
model._fit_wrapper(
series,
past_covariates,
future_covariates,
**fit_kwargs,
)
fitted_values = TimeSeries.from_times_and_values(
series.time_index, model.fitted_values
)
Expand All @@ -1457,16 +1564,20 @@ def _evaluate_combination(param_combination) -> float:
last_points_only=last_points_only,
verbose=verbose,
show_warnings=show_warnings,
predict_kwargs=predict_kwargs,
madtoinou marked this conversation as resolved.
Show resolved Hide resolved
)
else: # split mode
model._fit_wrapper(series, past_covariates, future_covariates)
model._fit_wrapper(
series, past_covariates, future_covariates, **fit_kwargs
)
pred = model._predict_wrapper(
len(val_series),
series,
past_covariates,
future_covariates,
num_samples=1,
verbose=verbose,
**predict_kwargs,
)
error = metric(val_series, pred)

Expand Down Expand Up @@ -2220,10 +2331,21 @@ def _predict_wrapper(
num_samples: int,
verbose: bool = False,
predict_likelihood_parameters: bool = False,
num_loader_workers: int = 0,
batch_size: Optional[int] = None,
n_jobs: int = 1,
roll_size: Optional[int] = None,
mc_dropout: bool = False,
) -> Union[TimeSeries, Sequence[TimeSeries]]:
kwargs = dict()
if self.supports_likelihood_parameter_prediction:
kwargs["predict_likelihood_parameters"] = predict_likelihood_parameters
if getattr(self, "trainer_params", False):
kwargs["num_loader_workers"] = num_loader_workers
kwargs["batch_size"] = batch_size
kwargs["n_jobs"] = n_jobs
kwargs["roll_size"] = roll_size
kwargs["mc_dropout"] = mc_dropout
return self.predict(
n,
series,
Expand All @@ -2239,13 +2361,30 @@ def _fit_wrapper(
series: Union[TimeSeries, Sequence[TimeSeries]],
past_covariates: Optional[Union[TimeSeries, Sequence[TimeSeries]]],
future_covariates: Optional[Union[TimeSeries, Sequence[TimeSeries]]],
max_samples_per_ts: Optional[int] = None,
n_jobs_multioutput_wrapper: Optional[int] = None,
trainer=None,
verbose: Optional[bool] = None,
epochs: int = 0,
num_loader_workers: int = 0,
):
kwargs = dict()
if getattr(self, "trainer_params", False):
kwargs["trainer"] = trainer
kwargs["epochs"] = epochs
kwargs["verbose"] = verbose
kwargs["num_loader_workers"] = num_loader_workers
kwargs["max_samples_per_ts"] = max_samples_per_ts
elif isinstance(self, MultiOutputRegressor):
kwargs["n_jobs_multioutput_wrapper"] = n_jobs_multioutput_wrapper
kwargs["max_samples_per_ts"] = max_samples_per_ts
self.fit(
series=series,
past_covariates=past_covariates if self.supports_past_covariates else None,
future_covariates=future_covariates
if self.supports_future_covariates
else None,
**kwargs,
)

@property
Expand Down Expand Up @@ -2453,6 +2592,7 @@ def _fit_wrapper(
series: TimeSeries,
past_covariates: Optional[TimeSeries],
future_covariates: Optional[TimeSeries],
**kwargs,
):
self.fit(series, future_covariates=future_covariates)

Expand Down
12 changes: 6 additions & 6 deletions darts/models/forecasting/regression_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -1094,7 +1094,6 @@ def _optimized_historical_forecasts(
series: Optional[Sequence[TimeSeries]],
past_covariates: Optional[Sequence[TimeSeries]] = None,
future_covariates: Optional[Sequence[TimeSeries]] = None,
num_samples: int = 1,
start: Optional[Union[pd.Timestamp, float, int]] = None,
start_format: Literal["position", "value"] = "value",
forecast_horizon: int = 1,
Expand All @@ -1103,7 +1102,7 @@ def _optimized_historical_forecasts(
last_points_only: bool = True,
verbose: bool = False,
show_warnings: bool = True,
predict_likelihood_parameters: bool = False,
predict_kwargs: Optional[Dict[str, Any]] = None,
madtoinou marked this conversation as resolved.
Show resolved Hide resolved
) -> Union[
TimeSeries, List[TimeSeries], Sequence[TimeSeries], Sequence[List[TimeSeries]]
]:
Expand All @@ -1124,36 +1123,37 @@ def _optimized_historical_forecasts(
allow_autoregression=False,
)

if predict_kwargs is None:
predict_kwargs = dict()

# TODO: move the loop here instead of duplicated code in each sub-routine?
if last_points_only:
return _optimized_historical_forecasts_last_points_only(
model=self,
series=series,
past_covariates=past_covariates,
future_covariates=future_covariates,
num_samples=num_samples,
start=start,
start_format=start_format,
forecast_horizon=forecast_horizon,
stride=stride,
overlap_end=overlap_end,
show_warnings=show_warnings,
predict_likelihood_parameters=predict_likelihood_parameters,
**predict_kwargs,
madtoinou marked this conversation as resolved.
Show resolved Hide resolved
)
else:
return _optimized_historical_forecasts_all_points(
model=self,
series=series,
past_covariates=past_covariates,
future_covariates=future_covariates,
num_samples=num_samples,
start=start,
start_format=start_format,
forecast_horizon=forecast_horizon,
stride=stride,
overlap_end=overlap_end,
show_warnings=show_warnings,
predict_likelihood_parameters=predict_likelihood_parameters,
**predict_kwargs,
)


Expand Down
Loading