Skip to content

Commit

Permalink
FIX-2633: Use correct time indices when running historical forecasts …
Browse files Browse the repository at this point in the history
…on regression models with 'output_chunk_shift > 0' and 'output_chunk_length == 1'. Extended unit tests to cover this
  • Loading branch information
MattiasDC committed Dec 26, 2024
1 parent aad1440 commit 5abcce3
Show file tree
Hide file tree
Showing 2 changed files with 56 additions and 12 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,20 @@
from darts.utils.likelihood_models import GaussianLikelihood, QuantileRegression

models = [LinearRegressionModel, NaiveDrift]
models_reg_no_cov_cls_kwargs = [(LinearRegressionModel, {"lags": 8}, {}, (8, 1))]
models_reg_no_cov_cls_kwargs = [
(LinearRegressionModel, {"lags": 8}, {}, (8, 1)),
# output_chunk_length only
(LinearRegressionModel, {"lags": 5, "output_chunk_length": 2}, {}, (5, 1)),
# output_chunk_shift only
(LinearRegressionModel, {"lags": 5, "output_chunk_shift": 1}, {}, (5, 1)),
# output_chunk_shift + output_chunk_length only
(
LinearRegressionModel,
{"lags": 5, "output_chunk_shift": 1, "output_chunk_length": 2},
{},
(5, 1),
),
]
if not isinstance(CatBoostModel, NotImportedModule):
models_reg_no_cov_cls_kwargs.append((
CatBoostModel,
Expand Down Expand Up @@ -656,6 +669,19 @@ def test_historical_forecasts(self, config):
model_cls, kwargs, model_kwarg, bounds = config
model = model_cls(**kwargs, **model_kwarg)

if model.output_chunk_shift > 0:
with pytest.raises(ValueError):
forecasts = model.historical_forecasts(
series=self.ts_pass_val,
forecast_horizon=forecast_horizon,
stride=1,
train_length=train_length,
retrain=True,
overlap_end=False,
)
# continue the test without autogregression if we are using shifts
forecast_horizon = model.output_chunk_length

# time index
forecasts = model.historical_forecasts(
series=self.ts_pass_val,
Expand Down Expand Up @@ -1153,6 +1179,19 @@ def test_regression_auto_start_multiple_no_cov(self, config):
)
model.fit(self.ts_pass_train)

if model.output_chunk_shift > 0:
with pytest.raises(ValueError):
forecasts = model.historical_forecasts(
series=[self.ts_pass_val, self.ts_pass_val],
forecast_horizon=forecast_horizon,
train_length=train_length,
stride=1,
retrain=True,
overlap_end=False,
)
# continue the test without autogregression if we are using shifts
forecast_horizon = model.output_chunk_length

forecasts = model.historical_forecasts(
series=[self.ts_pass_val, self.ts_pass_val],
forecast_horizon=forecast_horizon,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -162,19 +162,24 @@ def _optimized_historical_forecasts_last_points_only(
else:
forecast = forecast[:, 0]

if (
stride == 1
and model.output_chunk_length == 1
and model.output_chunk_shift == 0
):
times = times[0]
else:
times = generate_index(
start=hist_fct_start
+ (forecast_horizon + model.output_chunk_shift - 1) * freq,
length=forecast.shape[0],
freq=freq * stride,
name=series_.time_index.name,
)

forecasts_list.append(
TimeSeries.from_times_and_values(
times=(
times[0]
if stride == 1 and model.output_chunk_length == 1
else generate_index(
start=hist_fct_start
+ (forecast_horizon + model.output_chunk_shift - 1) * freq,
length=forecast.shape[0],
freq=freq * stride,
name=series_.time_index.name,
)
),
times=times,
values=forecast,
columns=forecast_components,
static_covariates=series_.static_covariates,
Expand Down

0 comments on commit 5abcce3

Please sign in to comment.