diff --git a/CHANGELOG.md b/CHANGELOG.md index cf805cad84..b47b190504 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,6 +7,8 @@ but cannot always guarantee backwards compatibility. Changes that may **break co [Full Changelog](https://github.com/unit8co/darts/compare/0.32.0...master) +- Fix the bug in [#2579 ](https://github.com/unit8co/darts/issues/2579) that causes an error when `val_sample_weight` is set in the CatBoost and XGBoost models. + ### For users of the library: **Improved** @@ -1440,7 +1442,7 @@ ts: TimeSeries = AirPassengers().load() ```python # Assuming a multivariate TimeSeries named series with 3 columns or variables. # To apply fn to columns with names '0' and '2': - + #old syntax series.map(fn, cols=['0', '2']) # returned a time series with 3 columns #new syntax @@ -1452,13 +1454,13 @@ ts: TimeSeries = AirPassengers().load() ```python #old syntax fillna(series, fill=0) - + #new syntax fill_missing_values(series, fill=0) - + #old syntax auto_fillna(series, **interpolate_kwargs) - + #new syntax fill_missing_values(series, fill='auto', **interpolate_kwargs) fill_missing_values(series, **interpolate_kwargs) # fill='auto' by default @@ -1496,13 +1498,13 @@ ts: TimeSeries = AirPassengers().load() ```python # old syntax: backtest_forecasting(forecasting_model, *args, **kwargs) - + # new syntax: forecasting_model.backtest(*args, **kwargs) - + # old syntax: backtest_regression(regression_model, *args, **kwargs) - + # new syntax: regression_model.backtest(*args, **kwargs) ``` @@ -1511,13 +1513,13 @@ ts: TimeSeries = AirPassengers().load() ```python # old syntax: multivariate_model.fit(multivariate_series, target_indices=[0, 1]) - + # new syntax: multivariate_model.fit(multivariate_series, multivariate_series[["0", "1"]]) - + # old syntax: univariate_model.fit(multivariate_series, component_index=2) - + # new syntax: univariate_model.fit(multivariate_series["2"]) ``` diff --git a/darts/models/forecasting/regression_model.py b/darts/models/forecasting/regression_model.py index b5c76a0f0e..b7f02513d8 100644 --- a/darts/models/forecasting/regression_model.py +++ b/darts/models/forecasting/regression_model.py @@ -588,7 +588,7 @@ def _add_val_set_to_kwargs( val_weights = val_weights or None else: val_sets = [(val_samples, val_labels)] - val_weights = val_weight + val_weights = [val_weight] val_set_name, val_weight_name = self.val_set_params return dict(kwargs, **{val_set_name: val_sets, val_weight_name: val_weights})