Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feat/wrapper model gridsearch #2594

Open
wants to merge 14 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ but cannot always guarantee backwards compatibility. Changes that may **break co
**Improved**

- New model: `StatsForecastAutoTBATS`. This model offers the [AutoTBATS](https://nixtlaverse.nixtla.io/statsforecast/src/core/models.html#autotbats) model from Nixtla's `statsforecasts` library. [#2611](https://github.com/unit8co/darts/pull/2611) by [He Weilin](https://github.com/cnhwl).
- Improvement to `gridsearch()`, now supports optimization of models wrapped in `RegressionModel`. [#2594](https://github.com/unit8co/darts/pull/2594) by [Andrés Sandoval](https://github.com/andresliszt)

**Fixed**

Expand Down
136 changes: 131 additions & 5 deletions darts/models/forecasting/forecasting_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -1670,6 +1670,69 @@ def gridsearch(
Currently this method only supports deterministic predictions (i.e. when models' predictions
have only 1 sample).

Some darts models wrap scikit-learn like models (See, for instance :class:`RegressionModel`), i.e
they accept an argument ``model``. With the purpose of including 'model' in the grid search,
there are two possible options:

1. Give the key ``model`` in parameter as a valid list of instances of the scikit-learn like model used.

Example
.. highlight:: python
.. code-block:: python

from sklearn.ensemble import RandomForestRegressor

from darts.models import RegressionModel
from darts.utils import timeseries_generation as tg

parameters = {
"model": [
RandomForestRegressor(min_samples_split=2, min_samples_leaf=1),
RandomForestRegressor(min_samples_split=3, min_samples_leaf=2),
],
"lags": [1,2,3],
}
series = tg.sine_timeseries(length=100)

RegressionModel.gridsearch(
parameters=parameters, series=series, forecast_horizon=1
)
..

2. Give the key ``model`` in parameter as dictionary containing a special key
``model_class`` which is the scikit-learn like model class that will be used
to pass arguments from the grid. The other keys/values are arguments passed to the wrapped
model class and they behave as an inner parameters dictionary

Example
.. highlight:: python
.. code-block:: python

from sklearn.ensemble import RandomForestRegressor

from darts.models import RegressionModel
from darts.utils import timeseries_generation as tg

parameters = {
"model": {
"model_class": RandomForestRegressor,
"min_samples_split": [2,3],
"min_samples_leaf": [1,2],
},
"lags": [1,2,3],
}
series = tg.sine_timeseries(length=100)

RegressionModel.gridsearch(
parameters=parameters, series=series, forecast_horizon=1
)
..

In order to keep consistency in the best-performing hyper-parameters returned in this method,
wrapped model arguments are returned with a suffix containing the name of the wrapped model class
and a dot separator. For example, the parameter ``min_samples_split`` in the example above will be
returned as ``RandomForestRegressor.min_samples_split``

Parameters
----------
model_class
Expand Down Expand Up @@ -1796,12 +1859,34 @@ def gridsearch(
)
)

if "model" in parameters:
valid_model_list = isinstance(parameters["model"], list)
valid_nested_params = (
not valid_model_list
and parameters["model"].get("wrapped_model_class")
and all(
isinstance(params, (list, np.ndarray))
for p_name, params in parameters["model"].items()
if p_name != "wrapped_model_class"
)
)
if not (valid_model_list or valid_nested_params):
raise_log(
ValueError(
"When the 'model' key is set as a dictionary, it must contain the 'wrapped_model_class' key, "
"which represents the class of the model to be wrapped.",
logger,
)
)

if not all(
isinstance(params, (list, np.ndarray)) for params in parameters.values()
isinstance(params, (list, np.ndarray))
for p_name, params in parameters.items()
if p_name != "model"
):
raise_log(
ValueError(
"Every value in the `parameters` dictionary should be a list or a np.ndarray."
"Every hyper-parameter value in the `parameters` dictionary should be a list or a np.ndarray."
),
logger,
)
Expand Down Expand Up @@ -1832,6 +1917,24 @@ def gridsearch(
if predict_kwargs is None:
predict_kwargs = dict()

# Used if the darts model wraps a scikit-learn like model
wrapped_model_class = None

if "model" in parameters:
# Ask if model has been passed as a dictionary. This implies that the arguments
# of the wrapped model must be passed to the grid.
if (
isinstance(parameters["model"], dict)
and "wrapped_model_class" in parameters["model"]
):
wrapped_model_class = parameters["model"].pop("wrapped_model_class")
# Create a flat dictionary by adding a suffix to the arguments of the wrapped model in
# order to distinguish them from the other arguments of the Darts model
parameters.update({
f"{wrapped_model_class.__name__}.{k}": v
for k, v in parameters.pop("model").items()
})

# compute all hyperparameter combinations from selection
params_cross_product = list(product(*parameters.values()))

Expand All @@ -1849,6 +1952,25 @@ def gridsearch(
desc="gridsearch",
)

def _init_model_from_combination(param_combination_dict):
if wrapped_model_class is None:
return model_class(**param_combination_dict)

# Decode new keys created with the suffix.
wrapped_model_kwargs = {}
darts_model_kwargs = {}
for k, v in param_combination_dict.items():
if k.startswith(f"{wrapped_model_class.__name__}."):
wrapped_model_kwargs[
k.replace(f"{wrapped_model_class.__name__}.", "")
] = v
else:
darts_model_kwargs[k] = v
return model_class(
model=wrapped_model_class(**wrapped_model_kwargs),
**darts_model_kwargs,
)

def _evaluate_combination(param_combination) -> float:
param_combination_dict = dict(
list(zip(parameters.keys(), param_combination))
Expand All @@ -1859,7 +1981,8 @@ def _evaluate_combination(param_combination) -> float:
f"{current_time}_{param_combination_dict['model_name']}"
)

model = model_class(**param_combination_dict)
model = _init_model_from_combination(param_combination_dict)

if use_fitted_values: # fitted value mode
if data_transformers:
series_, past_covariates_, future_covariates_ = (
Expand Down Expand Up @@ -1968,8 +2091,11 @@ def _evaluate_combination(param_combination) -> float:
)

logger.info("Chosen parameters: " + str(best_param_combination))

return model_class(**best_param_combination), best_param_combination, min_error
return (
_init_model_from_combination(best_param_combination),
best_param_combination,
min_error,
)

def residuals(
self,
Expand Down
57 changes: 57 additions & 0 deletions darts/tests/models/forecasting/test_regression_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -3243,6 +3243,63 @@ def test_lgbm_categorical_features_passed_to_fit_correctly(self, lgb_fit_patch):
) = self.lgbm_w_categorical_covariates._categorical_fit_param
assert kwargs[cat_param_name] == [2, 3, 5]

def test_grid_search(self):
# Create grid over wrapped model parameters too
parameters = {
"model": {
"wrapped_model_class": RandomForestRegressor,
"min_samples_split": [2, 3],
},
"lags": [1],
}
result = RegressionModel.gridsearch(
parameters=parameters, series=self.sine_multivariate1, forecast_horizon=1
)
assert isinstance(result[0], RegressionModel)
assert {
"lags",
"RandomForestRegressor.min_samples_split",
} == set(result[1])
assert isinstance(result[2], float)

# Use model as instances of RandomForestRegressor directly
parameters = {
"model": [
RandomForestRegressor(min_samples_split=2),
RandomForestRegressor(min_samples_split=3),
],
"lags": [1],
}

result = RegressionModel.gridsearch(
parameters=parameters, series=self.sine_multivariate1, forecast_horizon=1
)

assert isinstance(result[0], RegressionModel)
assert {
"lags",
"model",
} == set(result[1])
assert isinstance(result[1]["model"], RandomForestRegressor)
assert isinstance(result[2], float)

def test_grid_search_invalid_wrapped_model_dict(self):
parameters = {
"model": {"fit_intercept": [True, False]},
"lags": [1, 2, 3],
}
with pytest.raises(
ValueError,
match="When the 'model' key is set as a dictionary, it must contain "
"the 'wrapped_model_class' key, which represents the class of the model "
"to be wrapped.",
):
RegressionModel.gridsearch(
parameters=parameters,
series=self.sine_multivariate1,
forecast_horizon=1,
)

def helper_create_LinearModel(self, multi_models=True, extreme_lags=False):
if not extreme_lags:
lags, lags_pc, lags_fc = 3, 3, [-3, -2, -1, 0]
Expand Down
Loading