diff --git a/CHANGELOG.md b/CHANGELOG.md
index 66f6f90538..fd6d9cdcd9 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -12,6 +12,7 @@ but cannot always guarantee backwards compatibility. Changes that may **break co
 **Improved**
 
 - Improvements to `ForecastingModel`: Improved `start` handling for historical forecasts, backtest, residuals, and gridsearch. If `start` is not within the trainable / forecastable points, uses the closest valid start point that is a round multiple of `stride` ahead of start. Raises a ValueError, if no valid start point exists. This guarantees that all historical forecasts are `n * stride` points away from start, and will simplify many downstream tasks. [#2560](https://github.com/unit8co/darts/issues/2560) by [Dennis Bader](https://github.com/dennisbader).
+- 🚀🚀 New forecasting model: `TimesNetModel` as proposed in [this paper](https://arxiv.org/abs/2210.02186). [#2538](https://github.com/unit8co/darts/pull/2538) by [Greg DeVosNouri](https://github.com/gdevos010).
 
 **Fixed**
 
@@ -548,7 +549,7 @@ Patch release
   [#1256](https://github.com/unit8co/darts/pull/1256) by [Julien Adda](https://github.com/julien12234) and [Julien Herzen](https://github.com/hrzn).
 - New forecasting models: `DLinearModel` and `NLinearModel` as proposed in [this paper](https://arxiv.org/pdf/2205.13504.pdf).
-  [#1139](https://github.com/unit8co/darts/pull/1139) by [Julien Herzen](https://github.com/hrzn) and [Greg DeVos](https://github.com/gdevos010).
+  [#1139](https://github.com/unit8co/darts/pull/1139) by [Julien Herzen](https://github.com/hrzn) and [Greg DeVosNouri](https://github.com/gdevos010).
 - New forecasting model: `XGBModel` implementing XGBoost.
   [#1405](https://github.com/unit8co/darts/pull/1405) by [Julien Herzen](https://github.com/hrzn).
 - New `multi_models` option for all `RegressionModel`s: when set to False, uses only a single underlying
@@ -619,13 +620,13 @@
 - Added support for past and future covariates to `residuals()` function. [#1223](https://github.com/unit8co/darts/pull/1223) by [Eliane Maalouf](https://github.com/eliane-maalouf).
 - Added support for retraining model(s) every `n` iteration and on custom conditions in `historical_forecasts` method of `ForecastingModel`s. [#1139](https://github.com/unit8co/darts/pull/1139) by [Francesco Bruzzesi](https://github.com/fbruzzesi).
 - Added support for beta-NLL in `GaussianLikelihood`s, as proposed in [this paper](https://arxiv.org/abs/2203.09168). [#1162](https://github.com/unit8co/darts/pull/1162) by [Julien Herzen](https://github.com/hrzn).
-- New LayerNorm alternatives, RMSNorm and LayerNormNoBias [#1113](https://github.com/unit8co/darts/pull/1113) by [Greg DeVos](https://github.com/gdevos010).
+- New LayerNorm alternatives, RMSNorm and LayerNormNoBias [#1113](https://github.com/unit8co/darts/issues/1113) by [Greg DeVosNouri](https://github.com/gdevos010).
 - 🔴 Improvements to encoders: improve fitting behavior of encoders' transformers and solve a couple of issues. Remove support for absolute index encoding. [#1257](https://github.com/unit8co/darts/pull/1257) by [Dennis Bader](https://github.com/dennisbader).
 - Overwrite min_train_series_length for Catboost and LightGBM [#1214](https://github.com/unit8co/darts/pull/1214) by [Anne de Vries](https://github.com/anne-devries).
 - New example notebook showcasing and end-to-end example of hyperparameter optimization with Optuna [#1242](https://github.com/unit8co/darts/pull/1242) by [Julien Herzen](https://github.com/hrzn).
 - New user guide section on hyperparameter optimization with Optuna and Ray Tune [#1242](https://github.com/unit8co/darts/pull/1242) by [Julien Herzen](https://github.com/hrzn).
 - Documentation on model saving and loading. [#1210](https://github.com/unit8co/darts/pull/1210) by [Amadej Kocbek](https://github.com/amadejkocbek).
-- 🔴 `torch_device_str` has been removed from all torch models in favor of Pytorch Lightning's `pl_trainer_kwargs` method [#1244](https://github.com/unit8co/darts/pull/1244) by [Greg DeVos](https://github.com/gdevos010).
+- 🔴 `torch_device_str` has been removed from all torch models in favor of Pytorch Lightning's `pl_trainer_kwargs` method [#1244](https://github.com/unit8co/darts/pull/1244) by [Greg DeVosNouri](https://github.com/gdevos010).
 
 **Fixed**
 
@@ -682,16 +683,16 @@
 - New Reconciliation transformers for forecast reconciliation: bottom up, top down and MinT. [#1012](https://github.com/unit8co/darts/pull/1012) by [Julien Herzen](https://github.com/hrzn).
 - Added support for Monte Carlo Dropout, as a way to capture model uncertainty with torch models at inference time. [#1013](https://github.com/unit8co/darts/pull/1013) by [Julien Herzen](https://github.com/hrzn).
 - New datasets: ETT and Electricity. [#617](https://github.com/unit8co/darts/pull/617)
-  by [Greg DeVos](https://github.com/gdevos010)
-- New dataset, [Uber TLC](https://github.com/fivethirtyeight/uber-tlc-foil-response). [#1003](https://github.com/unit8co/darts/pull/1003) by [Greg DeVos](https://github.com/gdevos010).
-- Model Improvements: Option for changing activation function for NHiTs and NBEATS. NBEATS support for dropout. NHiTs Support for AvgPooling1d. [#955](https://github.com/unit8co/darts/pull/955) by [Greg DeVos](https://github.com/gdevos010).
-- Implemented ["GLU Variants Improve Transformer"](https://arxiv.org/abs/2002.05202) for transformer based models (transformer and TFT). [#968](https://github.com/unit8co/darts/pull/968) by [Greg DeVos](https://github.com/gdevos010).
-- Added support for torch metrics during training and validation. [#996](https://github.com/unit8co/darts/pull/996) by [Greg DeVos](https://github.com/gdevos010).
+  by [Greg DeVosNouri](https://github.com/gdevos010)
+- New dataset, [Uber TLC](https://github.com/fivethirtyeight/uber-tlc-foil-response). [#1003](https://github.com/unit8co/darts/pull/1003) by [Greg DeVosNouri](https://github.com/gdevos010).
+- Model Improvements: Option for changing activation function for NHiTs and NBEATS. NBEATS support for dropout. NHiTs Support for AvgPooling1d. [#955](https://github.com/unit8co/darts/pull/955) by [Greg DeVosNouri](https://github.com/gdevos010).
+- Implemented ["GLU Variants Improve Transformer"](https://arxiv.org/abs/2002.05202) for transformer based models (transformer and TFT). [#959](https://github.com/unit8co/darts/issues/959) by [Greg DeVosNouri](https://github.com/gdevos010).
+- Added support for torch metrics during training and validation. [#996](https://github.com/unit8co/darts/pull/996) by [Greg DeVosNouri](https://github.com/gdevos010).
 - Better handling of logging [#1010](https://github.com/unit8co/darts/pull/1010) by [Dustin Brunner](https://github.com/brunnedu).
 - Better support for Python 3.10, and dropping `prophet` as a dependency (`Prophet` model still works if `prophet` package is installed separately) [#1023](https://github.com/unit8co/darts/pull/1023) by [Julien Herzen](https://github.com/hrzn).
 - Option to avoid global matplotlib configuration changes.
   [#924](https://github.com/unit8co/darts/pull/924) by [Mike Richman](https://github.com/zgana).
-- 🔴 `HNiTSModel` renamed to `HNiTS` [#1000](https://github.com/unit8co/darts/pull/1000) by [Greg DeVos](https://github.com/gdevos010).
+- 🔴 `HNiTSModel` renamed to `HNiTS` [#1000](https://github.com/unit8co/darts/pull/1000) by [Greg DeVosNouri](https://github.com/gdevos010).
 
 **Fixed**
 
diff --git a/darts/models/__init__.py b/darts/models/__init__.py
index 17640b195d..0836d9cb7d 100644
--- a/darts/models/__init__.py
+++ b/darts/models/__init__.py
@@ -50,6 +50,7 @@
     from darts.models.forecasting.tcn_model import TCNModel
     from darts.models.forecasting.tft_model import TFTModel
     from darts.models.forecasting.tide_model import TiDEModel
+    from darts.models.forecasting.times_net_model import TimesNetModel
     from darts.models.forecasting.transformer_model import TransformerModel
     from darts.models.forecasting.tsmixer_model import TSMixerModel
 except ModuleNotFoundError:
@@ -71,6 +72,7 @@
     TFTModel = NotImportedModule(module_name="(Py)Torch", warn=False)
     TiDEModel = NotImportedModule(module_name="(Py)Torch", warn=False)
+    TimesNetModel = NotImportedModule(module_name="(Py)Torch", warn=False)
     TransformerModel = NotImportedModule(module_name="(Py)Torch", warn=False)
     TSMixerModel = NotImportedModule(module_name="(Py)Torch", warn=False)
 
 try:
@@ -151,6 +153,7 @@
     "TFTModel",
     "TiDEModel",
+    "TimesNetModel",
     "TransformerModel",
     "TSMixerModel",
     "Prophet",
     "CatBoostModel",
diff --git a/darts/models/components/embed.py b/darts/models/components/embed.py
new file mode 100644
index 0000000000..b59040822c
--- /dev/null
+++ b/darts/models/components/embed.py
@@ -0,0 +1,237 @@
+"""
+Embedding Components
+--------------------
+The implementation is built upon the embedding modules of the Time Series Library's TimesNet model
+
+
+-------
+MIT License
+
+Copyright (c) 2021 THUML @ Tsinghua University
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
+"""
+
+import math
+
+import torch
+import torch.nn as nn
+from packaging import version
+
+from darts.utils.torch import MonteCarloDropout
+
+
+class PositionalEmbedding(nn.Module):
+    def __init__(self, d_model, max_len=5000):
+        super().__init__()
+        # Compute the positional encodings once in log space.
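+        # Even indices hold sin, odd indices cos:
+        #   PE[pos, 2i]   = sin(pos / 10000^(2i / d_model))
+        #   PE[pos, 2i+1] = cos(pos / 10000^(2i / d_model))
+        # div_term below computes 10000^(-2i / d_model) via exp/log for numerical stability.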
+        pe = torch.zeros(max_len, d_model).float()
+        pe.requires_grad = False
+
+        position = torch.arange(0, max_len).float().unsqueeze(1)
+        div_term = (
+            torch.arange(0, d_model, 2).float() * -(math.log(10000.0) / d_model)
+        ).exp()
+
+        pe[:, 0::2] = torch.sin(position * div_term)
+        pe[:, 1::2] = torch.cos(position * div_term)
+
+        pe = pe.unsqueeze(0)
+        # registered as a buffer: saved with the module, excluded from training
+        self.register_buffer("pe", pe)
+
+    def forward(self, x):
+        return self.pe[:, : x.size(1)]
+
+
+class TokenEmbedding(nn.Module):
+    def __init__(self, c_in, d_model):
+        super().__init__()
+        # circular padding of 1 preserves the sequence length with kernel_size=3 on torch >= 1.5
+        padding = (
+            1 if version.parse(torch.__version__) >= version.parse("1.5.0") else 2
+        )
+        self.tokenConv = nn.Conv1d(
+            in_channels=c_in,
+            out_channels=d_model,
+            kernel_size=3,
+            padding=padding,
+            padding_mode="circular",
+            bias=False,
+        )
+        for m in self.modules():
+            if isinstance(m, nn.Conv1d):
+                nn.init.kaiming_normal_(
+                    m.weight, mode="fan_in", nonlinearity="leaky_relu"
+                )
+
+    def forward(self, x):
+        x = self.tokenConv(x.permute(0, 2, 1)).transpose(1, 2)
+        return x
+
+
+class FixedEmbedding(nn.Module):
+    def __init__(self, c_in, d_model):
+        super().__init__()
+
+        w = torch.zeros(c_in, d_model).float()
+        w.requires_grad = False
+
+        position = torch.arange(0, c_in).float().unsqueeze(1)
+        div_term = (
+            torch.arange(0, d_model, 2).float() * -(math.log(10000.0) / d_model)
+        ).exp()
+
+        w[:, 0::2] = torch.sin(position * div_term)
+        w[:, 1::2] = torch.cos(position * div_term)
+
+        self.emb = nn.Embedding(c_in, d_model)
+        self.emb.weight = nn.Parameter(w, requires_grad=False)
+
+    def forward(self, x):
+        return self.emb(x).detach()
+
+
+class TemporalEmbedding(nn.Module):
+    def __init__(self, d_model, embed_type="fixed", freq="h"):
+        super().__init__()
+
+        minute_size = 4  # minutes are bucketed into 15-minute intervals (4 per hour)
+        hour_size = 24
+        weekday_size = 7
+        day_size = 32
+        month_size = 13
+
+        Embed = FixedEmbedding if embed_type == "fixed" else nn.Embedding
+        if freq == "t":
+            self.minute_embed = Embed(minute_size, d_model)
+        self.hour_embed = Embed(hour_size, d_model)
+        self.weekday_embed = Embed(weekday_size, d_model)
+        self.day_embed = Embed(day_size, d_model)
+        self.month_embed = Embed(month_size, d_model)
+
+    def forward(self, x):
+        # expected feature order along the last axis: [month, day, weekday, hour, minute]
+        x = x.long()
+        minute_x = (
+            self.minute_embed(x[:, :, 4]) if hasattr(self, "minute_embed") else 0.0
+        )
+        hour_x = self.hour_embed(x[:, :, 3])
+        weekday_x = self.weekday_embed(x[:, :, 2])
+        day_x = self.day_embed(x[:, :, 1])
+        month_x = self.month_embed(x[:, :, 0])
+
+        return hour_x + weekday_x + day_x + month_x + minute_x
+
+
+class TimeFeatureEmbedding(nn.Module):
+    def __init__(self, d_model, embed_type="timeF", freq="h"):
+        super().__init__()
+
+        freq_map = {"h": 4, "t": 5, "s": 6, "m": 1, "a": 1, "w": 2, "d": 3, "b": 3}
+        d_inp = freq_map[freq]
+        self.embed = nn.Linear(d_inp, d_model, bias=False)
+
+    def forward(self, x):
+        return self.embed(x)
+
+
+class DataEmbedding(nn.Module):
+    def __init__(self, c_in, d_model, embed_type="fixed", freq="h", dropout=0.1):
+        super().__init__()
+
+        self.value_embedding = TokenEmbedding(c_in=c_in, d_model=d_model)
+        self.position_embedding = PositionalEmbedding(d_model=d_model)
+        self.temporal_embedding = (
+            TemporalEmbedding(d_model=d_model, embed_type=embed_type, freq=freq)
+            if embed_type != "timeF"
+            else TimeFeatureEmbedding(d_model=d_model, embed_type=embed_type, freq=freq)
+        )
+        self.dropout = MonteCarloDropout(p=dropout)
+
+    def forward(self, x, x_mark):
+        if x_mark is None:
+            x = self.value_embedding(x) + self.position_embedding(x)
+        else:
+            x = (
+                self.value_embedding(x)
+                + self.temporal_embedding(x_mark)
+                + self.position_embedding(x)
+            )
+        return self.dropout(x)
+
+
+class DataEmbedding_inverted(nn.Module):
+    def __init__(self, c_in, d_model, embed_type="fixed", freq="h", dropout=0.1):
+        super().__init__()
+        # embed_type and freq are kept for interface parity with the other
+        # embeddings; the inverted embedding does not use them
+        self.value_embedding = nn.Linear(c_in, d_model)
+        self.dropout = MonteCarloDropout(p=dropout)
+
+    def forward(self, x, x_mark):
+        x = x.permute(0, 2, 1)
+        # x: [Batch Variate Time]
+        if x_mark is None:
+            x = self.value_embedding(x)
+        else:
+            x = self.value_embedding(torch.cat([x, x_mark.permute(0, 2, 1)], 1))
+        # x: [Batch Variate d_model]
+        return self.dropout(x)
+
+
+class DataEmbedding_wo_pos(nn.Module):
+    def __init__(self, c_in, d_model, embed_type="fixed", freq="h", dropout=0.1):
+        super().__init__()
+
+        self.value_embedding = TokenEmbedding(c_in=c_in, d_model=d_model)
+        # kept for parity with DataEmbedding; not used in forward
+        self.position_embedding = PositionalEmbedding(d_model=d_model)
+        self.temporal_embedding = (
+            TemporalEmbedding(d_model=d_model, embed_type=embed_type, freq=freq)
+            if embed_type != "timeF"
+            else TimeFeatureEmbedding(d_model=d_model, embed_type=embed_type, freq=freq)
+        )
+        self.dropout = MonteCarloDropout(p=dropout)
+
+    def forward(self, x, x_mark):
+        if x_mark is None:
+            x = self.value_embedding(x)
+        else:
+            x = self.value_embedding(x) + self.temporal_embedding(x_mark)
+        return self.dropout(x)
+
+
+class PatchEmbedding(nn.Module):
+    def __init__(self, d_model, patch_len, stride, padding, dropout):
+        super().__init__()
+        # Patching
+        self.patch_len = patch_len
+        self.stride = stride
+        self.padding_patch_layer = nn.ReplicationPad1d((0, padding))
+
+        # Backbone, Input encoding: projection of feature vectors onto a d-dim vector space
+        self.value_embedding = nn.Linear(patch_len, d_model, bias=False)
+
+        # Positional embedding
+        self.position_embedding = PositionalEmbedding(d_model)
+
+        # Residual dropout
+        self.dropout = MonteCarloDropout(p=dropout)
+
+    def forward(self, x):
+        # do patching
+        n_vars = x.shape[1]
+        x = self.padding_patch_layer(x)
+        x = x.unfold(dimension=-1, size=self.patch_len, step=self.stride)
+        x = torch.reshape(x, (x.shape[0] * x.shape[1], x.shape[2], x.shape[3]))
+        # Input encoding
+        x = self.value_embedding(x) + self.position_embedding(x)
+        return self.dropout(x), n_vars
diff --git a/darts/models/forecasting/__init__.py b/darts/models/forecasting/__init__.py
index 37a50aa4bc..3c1cb01306 100644
--- a/darts/models/forecasting/__init__.py
+++ b/darts/models/forecasting/__init__.py
@@ -42,6 +42,7 @@
     - :class:`~darts.models.forecasting.nhits.NHiTSModel`
     - :class:`~darts.models.forecasting.tcn_model.TCNModel`
     - :class:`~darts.models.forecasting.transformer_model.TransformerModel`
+    - :class:`~darts.models.forecasting.times_net_model.TimesNetModel`
     - :class:`~darts.models.forecasting.tft_model.TFTModel`
     - :class:`~darts.models.forecasting.dlinear.DLinearModel`
     - :class:`~darts.models.forecasting.nlinear.NLinearModel`
diff --git a/darts/models/forecasting/times_net_model.py b/darts/models/forecasting/times_net_model.py
new file mode 100644
index 0000000000..883ceda945
--- /dev/null
+++ b/darts/models/forecasting/times_net_model.py
@@ -0,0 +1,472 @@
+"""
+TimesNet Model
+--------------
+The implementation is built upon the Time Series Library's TimesNet model
+
+
+-------
+MIT License
+
+Copyright (c) 2021 THUML @ Tsinghua University
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
+"""
+
+from typing import Tuple
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from packaging import version
+
+from darts.models.components.embed import DataEmbedding
+from darts.models.forecasting.pl_forecasting_module import (
+    PLPastCovariatesModule,
+    io_processor,
+)
+from darts.models.forecasting.torch_forecasting_model import PastCovariatesTorchModel
+
+
+class Inception_Block_V1(nn.Module):
+    def __init__(self, in_channels, out_channels, num_kernels=6, init_weight=True):
+        super().__init__()
+        self.in_channels = in_channels
+        self.out_channels = out_channels
+        self.num_kernels = num_kernels
+
+        # odd kernel sizes 1, 3, 5, ... with matching padding, so the spatial size is preserved
+        self.kernels = nn.ModuleList([
+            nn.Conv2d(in_channels, out_channels, kernel_size=2 * i + 1, padding=i)
+            for i in range(self.num_kernels)
+        ])
+
+        if init_weight:
+            self._initialize_weights()
+
+    def _initialize_weights(self):
+        for m in self.modules():
+            if isinstance(m, nn.Conv2d):
+                nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
+                if m.bias is not None:
+                    nn.init.constant_(m.bias, 0)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        res = torch.stack([kernel(x) for kernel in self.kernels], dim=-1)
+        return torch.mean(res, dim=-1)
+
+
+def FFT_for_Period(x, k: int = 2):
+    # [B, T, C]
+    xf = torch.fft.rfft(x, dim=1)
+    # find period by amplitudes
+    frequency_list = abs(xf).mean(0).mean(-1)
+    frequency_list[0] = 0  # exclude the DC (zero-frequency) component from the ranking
+    _, top_list = torch.topk(frequency_list, k)
+
+    T = torch.tensor(x.shape[1], dtype=torch.int64)
+    period = (T / top_list).to(torch.int64)
+    return period, abs(xf).mean(-1)[:, top_list]
+
+
+if version.parse(torch.__version__) >= version.parse("2.0.0"):
+    import torch.jit
+
+    FFT_for_Period = torch.jit.script(FFT_for_Period)
+
+
+class TimesBlock(nn.Module):
+    def __init__(self, seq_len, pred_len, top_k, d_model, d_ff, num_kernels):
+        super().__init__()
+        self.seq_len = seq_len
+        self.pred_len = pred_len
+        self.k = top_k
+        self.conv = nn.Sequential(
+            Inception_Block_V1(d_model, d_ff, num_kernels=num_kernels),
+            nn.GELU(),
+            Inception_Block_V1(d_ff, d_model, num_kernels=num_kernels),
+        )
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        B, T, N = x.size()
+        period_list, period_weight = FFT_for_Period(x, self.k)
+
+        res = []
+        for period in period_list:
+            # padding
+            if (self.seq_len + self.pred_len) % period != 0:
+                length = (((self.seq_len + self.pred_len) // period) + 1) * period
+                padding = torch.zeros([
+                    x.shape[0],
+                    (length - (self.seq_len + self.pred_len)),
+                    x.shape[2],
+                ]).to(x.device)
+                out = torch.cat([x, padding], dim=1)
+            else:
+                length = self.seq_len + self.pred_len
+                out = x
+            # reshape
+            out = (
+                out.reshape(B, length // period, period, N)
+                .permute(0, 3, 1, 2)
+                .contiguous()
+            )
+            # 2D conv: from 1d Variation to 2d Variation
+            out = self.conv(out)
+            # reshape back
+            out = out.permute(0, 2, 3, 1).reshape(B, -1, N)
+            res.append(out[:, : (self.seq_len + self.pred_len), :])
+        res = torch.stack(res, dim=-1)
+        # adaptive aggregation
+        period_weight = F.softmax(period_weight, dim=1)
+        period_weight = period_weight.unsqueeze(1).unsqueeze(1).repeat(1, T, N, 1)
+        res = torch.sum(res * period_weight, -1)
+        # residual connection
+        res = res + x
+        return res
+
+
+class _TimesNetModule(PLPastCovariatesModule):
+    def __init__(
+        self,
+        input_dim: int,
+        output_dim: int,
+        nr_params: int,
+        hidden_size: int,
+        num_layers: int,
+        num_kernels: int,
+        top_k: int,
+        embed_type: str = "fixed",
+        freq: str = "h",
+        **kwargs,
+    ):
+        """
+        Parameters
+        ----------
+        input_dim
+            The dimensionality of the TimeSeries instances that will be fed to the fit and predict functions.
+        output_dim
+            The dimensionality of the output time series.
+        nr_params
+            The number of parameters of the likelihood (or 1 if no likelihood is used).
+        hidden_size : int
+            The size of the hidden layers in the model.
+        num_layers : int
+            The number of TimesBlock layers in the model.
+        num_kernels : int
+            The number of kernels in each Inception block within the TimesBlock.
+        top_k : int
+            The number of top frequencies to consider in the FFT analysis.
+        embed_type : str, optional
+            The type of embedding to use. Default is "fixed".
+        freq : str, optional
+            The frequency of the time series. Default is "h" (hourly).
+        **kwargs
+            Additional keyword arguments passed to the parent PLPastCovariatesModule.
+
+        Notes
+        -----
+        - The `embed_type` and `freq` parameters are currently placeholders and are not fully utilized
+          in the current implementation.
+        """
+        super().__init__(**kwargs)
+
+        self.input_dim = input_dim
+        self.output_dim = output_dim
+        self.nr_params = nr_params
+
+        # embed_type and freq are placeholders; they stay unused until future
+        # covariate handling in the forward method is implemented
+        self.embedding = DataEmbedding(
+            input_dim, hidden_size, embed_type=embed_type, freq=freq, dropout=0.1
+        )
+
+        self.model = nn.ModuleList([
+            TimesBlock(
+                seq_len=self.input_chunk_length,
+                pred_len=self.output_chunk_length,
+                top_k=top_k,
+                d_model=hidden_size,
+                d_ff=hidden_size * 4,
+                num_kernels=num_kernels,
+            )
+            for _ in range(num_layers)
+        ])
+        self.layer_norm = nn.LayerNorm(hidden_size)
+        self.predict_linear = nn.Linear(
+            self.input_chunk_length, self.output_chunk_length + self.input_chunk_length
+        )
+        self.projection = nn.Linear(hidden_size, output_dim * nr_params)
+
+    @io_processor
+    def forward(self, x_in: Tuple) -> torch.Tensor:
+        x, _ = x_in
+
+        # Embedding
+        x = self.embedding(x, None)  # TODO: future covariate would go here
+        # extend the embedded sequence along time from input_chunk_length to
+        # input_chunk_length + output_chunk_length
+        x = self.predict_linear(x.transpose(1, 2)).transpose(1, 2)
+
+        # TimesNet
+        for layer in self.model:
+            x = self.layer_norm(layer(x))
+
+        y = self.projection(x)
+
+        y = y[:, -self.output_chunk_length :, :]
+        y = y.view(
+            y.shape[0], self.output_chunk_length, self.output_dim, self.nr_params
+        )
+
+        return y
+
+
+class TimesNetModel(PastCovariatesTorchModel):
+    def __init__(
+        self,
+        input_chunk_length: int,
+        output_chunk_length: int,
+        output_chunk_shift: int = 0,
+        hidden_size: int = 32,
+        num_layers: int = 2,
+        num_kernels: int = 6,
+        top_k: int = 5,
+        **kwargs,
+    ):
+        """
+        TimesNet model for time series forecasting.
+
+        This model is based on the paper "TimesNet: Temporal 2D-Variation Modeling for General Time
+        Series Analysis" by Haixu Wu et al. (2023). TimesNet uses a combination of 2D convolutions
+        and frequency domain analysis to capture both temporal patterns and periodic variations in
+        time series data.
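+
+        In each ``TimesBlock``, the dominant periods of the input are first estimated with an FFT
+        (see ``FFT_for_Period``). The 1D sequence is then folded into a 2D tensor of shape
+        (cycles, period), processed with inception-style 2D convolutions, unfolded back, and the
+        per-period results are aggregated with softmax weights derived from the FFT amplitudes.
+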
+        https://arxiv.org/abs/2210.02186
+
+        Parameters
+        ----------
+        input_chunk_length
+            Number of time steps in the past to take as a model input (per chunk). Applies to the target
+            series, and past and/or future covariates (if the model supports it).
+        output_chunk_length
+            Number of time steps predicted at once (per chunk) by the internal model. Also, the number of future
+            values from future covariates to use as a model input (if the model supports future covariates). It is
+            not the same as forecast horizon `n` used in `predict()`, which is the desired number of prediction
+            points generated using either a one-shot- or autoregressive forecast. Setting `n <= output_chunk_length`
+            prevents auto-regression. This is useful when the covariates don't extend far enough into the future, or
+            to prohibit the model from using future values of past and / or future covariates for prediction
+            (depending on the model's covariate support).
+        output_chunk_shift
+            Optionally, the number of steps to shift the start of the output chunk into the future (relative to the
+            input chunk end). This will create a gap between the input and output. If the model supports
+            `future_covariates`, the future values are extracted from the shifted output chunk. Predictions will start
+            `output_chunk_shift` steps after the end of the target `series`. If `output_chunk_shift` is set, the model
+            cannot generate autoregressive predictions (`n > output_chunk_length`).
+        hidden_size : int, optional
+            The hidden size of the model, controlling the dimensionality of the internal representations.
+            Default: 32.
+        num_layers : int, optional
+            The number of TimesBlock layers in the model. Each layer processes the input sequence
+            using 2D convolutions and frequency domain analysis. Default: 2.
+        num_kernels : int, optional
+            The number of kernels in each Inception block within the TimesBlock. This controls
+            the variety of convolution operations applied to the input. Default: 6.
+        top_k : int, optional
+            The number of top frequencies to consider in the FFT analysis. This parameter influences
+            how many periodic components are extracted from the input sequence. Default: 5.
+        **kwargs
+            Optional arguments to initialize the pytorch_lightning.Module, pytorch_lightning.Trainer, and
+            Darts' :class:`TorchForecastingModel`.
+
+        loss_fn
+            PyTorch loss function used for training.
+            This parameter will be ignored for probabilistic models if the ``likelihood`` parameter is specified.
+            Default: ``torch.nn.MSELoss()``.
+        likelihood
+            One of Darts' :meth:`Likelihood <darts.utils.likelihood_models.Likelihood>` models to be used for
+            probabilistic forecasts. Default: ``None``.
+        torch_metrics
+            A torch metric or a ``MetricCollection`` used for evaluation. A full list of available metrics can be
+            found at https://torchmetrics.readthedocs.io/en/latest/. Default: ``None``.
+        optimizer_cls
+            The PyTorch optimizer class to be used. Default: ``torch.optim.Adam``.
+        optimizer_kwargs
+            Optionally, some keyword arguments for the PyTorch optimizer (e.g., ``{'lr': 1e-3}``
+            for specifying a learning rate). Otherwise, the default values of the selected ``optimizer_cls``
+            will be used. Default: ``None``.
+        lr_scheduler_cls
+            Optionally, the PyTorch learning rate scheduler class to be used. Specifying ``None`` corresponds
+            to using a constant learning rate. Default: ``None``.
+        lr_scheduler_kwargs
+            Optionally, some keyword arguments for the PyTorch learning rate scheduler. Default: ``None``.
+        use_reversible_instance_norm
+            Whether to use reversible instance normalization `RINorm` against distribution shift as shown in [1]_.
+            It is only applied to the features of the target series and not the covariates.
+        batch_size
+            Number of time series (input and output sequences) used in each training pass. Default: ``32``.
+        n_epochs
+            Number of epochs over which to train the model. Default: ``100``.
+        model_name
+            Name of the model. Used for creating checkpoints and saving tensorboard data. If not specified,
+            defaults to the following string ``"YYYY-mm-dd_HH_MM_SS_torch_model_run_PID"``, where the initial part
+            of the name is formatted with the local date and time, while PID is the process ID (preventing models
+            spawned at the same time by different processes from sharing the same model_name). E.g.,
+            ``"2021-06-14_09_53_32_torch_model_run_44607"``.
+        work_dir
+            Path of the working directory, where to save checkpoints and tensorboard summaries.
+            Default: current working directory.
+        log_tensorboard
+            If set, use tensorboard to log the different parameters. The logs will be located in:
+            ``"{work_dir}/darts_logs/{model_name}/logs/"``. Default: ``False``.
+        nr_epochs_val_period
+            Number of epochs to wait before evaluating the validation loss (if a validation
+            ``TimeSeries`` is passed to the :func:`fit()` method). Default: ``1``.
+        force_reset
+            If set to ``True``, any previously-existing model with the same name will be reset (all checkpoints will
+            be discarded). Default: ``False``.
+        save_checkpoints
+            Whether to automatically save the untrained model and checkpoints from training.
+            To load the model from checkpoint, call :func:`MyModelClass.load_from_checkpoint()`, where
+            :class:`MyModelClass` is the :class:`TorchForecastingModel` class that was used (such as
+            :class:`TFTModel`, :class:`NBEATSModel`, etc.). If set to ``False``, the model can still be manually saved
+            using :func:`save()` and loaded using :func:`load()`. Default: ``False``.
+        add_encoders
+            A large number of past and future covariates can be automatically generated with `add_encoders`.
+            This can be done by adding multiple pre-defined index encoders and/or custom user-made functions that
+            will be used as index encoders. Additionally, a transformer such as Darts' :class:`Scaler` can be added to
+            transform the generated covariates. This happens all under one hood and only needs to be specified at
+            model creation.
+            Read :meth:`SequentialEncoder <darts.dataprocessing.encoders.SequentialEncoder>` to find out more about
+            ``add_encoders``. Default: ``None``. An example showing some of ``add_encoders`` features:
+
+            .. highlight:: python
+            .. code-block:: python
+
+                def encode_year(idx):
+                    return (idx.year - 1950) / 50
+
+                add_encoders={
+                    'cyclic': {'future': ['month']},
+                    'datetime_attribute': {'future': ['hour', 'dayofweek']},
+                    'position': {'past': ['relative'], 'future': ['relative']},
+                    'custom': {'past': [encode_year]},
+                    'transformer': Scaler(),
+                    'tz': 'CET'
+                }
+            ..
+        random_state
+            Control the randomness of the weights' initialization. Check this
+            `link <https://scikit-learn.org/stable/glossary.html#term-random_state>`_ for more details.
+            Default: ``None``.
+        pl_trainer_kwargs
+            By default :class:`TorchForecastingModel` creates a PyTorch Lightning Trainer with several useful presets
+            that performs the training, validation and prediction processes. These presets include automatic
+            checkpointing, tensorboard logging, setting the torch device and more.
+            With ``pl_trainer_kwargs`` you can add additional kwargs to instantiate the PyTorch Lightning trainer
+            object. Check the `PL Trainer documentation
+            <https://pytorch-lightning.readthedocs.io/en/stable/common/trainer.html#trainer-flags>`_ for more
+            information about the supported kwargs. Default: ``None``.
+            Running on GPU(s) is also possible using ``pl_trainer_kwargs`` by specifying keys ``"accelerator",
+            "devices", and "auto_select_gpus"``. Some examples for setting the devices inside the
+            ``pl_trainer_kwargs`` dict:
+
+            - ``{"accelerator": "cpu"}`` for CPU,
+            - ``{"accelerator": "gpu", "devices": [i]}`` to use only GPU ``i`` (``i`` must be an integer),
+            - ``{"accelerator": "gpu", "devices": -1, "auto_select_gpus": True}`` to use all available GPUs.
+
+            For more info, see here:
+            https://pytorch-lightning.readthedocs.io/en/stable/common/trainer.html#trainer-flags , and
+            https://pytorch-lightning.readthedocs.io/en/stable/accelerators/gpu_basic.html#train-on-multiple-gpus
+
+            With parameter ``"callbacks"`` you can add custom or PyTorch-Lightning built-in callbacks to Darts'
+            :class:`TorchForecastingModel`. Below is an example for adding EarlyStopping to the training process.
+            The model will stop training early if the validation loss `val_loss` does not improve beyond
+            specifications. For more information on callbacks, visit:
+            `PyTorch Lightning Callbacks
+            <https://pytorch-lightning.readthedocs.io/en/stable/extensions/callbacks.html>`_
+
+            .. highlight:: python
+            .. code-block:: python
+
+                from pytorch_lightning.callbacks.early_stopping import EarlyStopping
+
+                # stop training when validation loss does not decrease more than 0.05 (`min_delta`) over
+                # a period of 5 epochs (`patience`)
+                my_stopper = EarlyStopping(
+                    monitor="val_loss",
+                    patience=5,
+                    min_delta=0.05,
+                    mode='min',
+                )
+
+                pl_trainer_kwargs={"callbacks": [my_stopper]}
+            ..
+
+            Note that you can also use a custom PyTorch Lightning Trainer for training and prediction with optional
+            parameter ``trainer`` in :func:`fit()` and :func:`predict()`.
+        show_warnings
+            whether to show warnings raised from PyTorch Lightning. Useful to detect potential issues of
+            your forecasting use case. Default: ``False``.
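+
+        References
+        ----------
+        .. [1] T. Kim et al. "Reversible Instance Normalization for Accurate Time-Series Forecasting against
+                Distribution Shift", https://openreview.net/forum?id=cGDAkQo1C0p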
+
+        Examples
+        --------
+        >>> from darts.datasets import WeatherDataset
+        >>> from darts.models import TimesNetModel
+        >>> series = WeatherDataset().load()
+        >>> # predicting atmospheric pressure
+        >>> target = series['p (mbar)'][:100]
+        >>> # optionally, use past observed rainfall (pretending to be unknown beyond index 100)
+        >>> past_cov = series['rain (mm)'][:100]
+        >>> model = TimesNetModel(
+        >>>     input_chunk_length=6,
+        >>>     output_chunk_length=6,
+        >>>     n_epochs=20
+        >>> )
+        >>> model.fit(target, past_covariates=past_cov)
+        >>> pred = model.predict(6)
+        >>> pred.values()
+        array([[5.40498034],
+               [5.36561899],
+               [5.80616883],
+               [6.48695488],
+               [7.63158655],
+               [5.65417736]])
+        """
+        super().__init__(**self._extract_torch_model_params(**self.model_params))
+
+        # extract pytorch lightning module kwargs
+        self.pl_module_params = self._extract_pl_module_params(**self.model_params)
+        self.hidden_size = hidden_size
+        self.num_layers = num_layers
+        self.num_kernels = num_kernels
+        self.top_k = top_k
+
+    def _create_model(self, train_sample: Tuple[torch.Tensor]) -> torch.nn.Module:
+        # input dimension = target features + past covariate features (if any)
+        input_dim = train_sample[0].shape[1] + (
+            train_sample[1].shape[1] if train_sample[1] is not None else 0
+        )
+        output_dim = train_sample[-1].shape[1]
+        nr_params = 1 if self.likelihood is None else self.likelihood.num_parameters
+
+        return _TimesNetModule(
+            input_dim=input_dim,
+            output_dim=output_dim,
+            nr_params=nr_params,
+            hidden_size=self.hidden_size,
+            num_layers=self.num_layers,
+            num_kernels=self.num_kernels,
+            top_k=self.top_k,
+            **self.pl_module_params,
+        )
+
+    @property
+    def supports_multivariate(self) -> bool:
+        return True
diff --git a/darts/models/forecasting/transformer_model.py b/darts/models/forecasting/transformer_model.py
index 59870e4975..4511f23cbe 100644
--- a/darts/models/forecasting/transformer_model.py
+++ b/darts/models/forecasting/transformer_model.py
@@ -480,7 +480,7 @@ def encode_year(idx):
                 }
             ..
         random_state
-            Control the randomness of the weights initialization. Check this
+            Control the randomness of the weights' initialization. Check this
             `link <https://scikit-learn.org/stable/glossary.html#term-random_state>`_ for more details.
             Default: ``None``.
         pl_trainer_kwargs
diff --git a/darts/models/forecasting/tsmixer_model.py b/darts/models/forecasting/tsmixer_model.py
index 56100ed4c6..f8b165521e 100644
--- a/darts/models/forecasting/tsmixer_model.py
+++ b/darts/models/forecasting/tsmixer_model.py
@@ -475,7 +475,7 @@ def forward(
         Returns
         -------
         torch.torch.Tensor
-            The output Tensorof shape `(batch_size, output_chunk_length, output_dim, nr_params)`.
+            The output Tensor of shape `(batch_size, output_chunk_length, output_dim, nr_params)`.
         """
         # B: batch size
         # L: input chunk length
diff --git a/darts/tests/models/components/test_embed.py b/darts/tests/models/components/test_embed.py
new file mode 100644
index 0000000000..8fff02e13f
--- /dev/null
+++ b/darts/tests/models/components/test_embed.py
@@ -0,0 +1,198 @@
+import pytest
+import torch
+
+from darts.models.components.embed import (
+    DataEmbedding,
+    DataEmbedding_inverted,
+    DataEmbedding_wo_pos,
+    FixedEmbedding,
+    PatchEmbedding,
+    PositionalEmbedding,
+    TemporalEmbedding,
+    TimeFeatureEmbedding,
+    TokenEmbedding,
+)
+
+
+class TestEmbedding:
+    def test_PositionalEmbedding(self):
+        d_model = 64
+        max_len = 500
+        embedding = PositionalEmbedding(d_model=d_model, max_len=max_len)
+        x = torch.randn(10, 100, d_model)  # batch_size=10, seq_len=100, d_model
+        pe = embedding(x)
+        assert pe.shape == (
+            1,
+            100,
+            d_model,
+        ), "PositionalEmbedding output shape mismatch."
+        # Test that pe does not require gradient
+        assert (
+            not pe.requires_grad
+        ), "PositionalEmbedding output should not require grad."
+
+    def test_TokenEmbedding(self):
+        c_in = 10
+        d_model = 64
+        batch_size = 32
+        seq_len = 100
+        embedding = TokenEmbedding(c_in=c_in, d_model=d_model)
+        x = torch.randn(batch_size, seq_len, c_in)  # [B, L, C_in]
+        output = embedding(x)
+        assert output.shape == (
+            batch_size,
+            seq_len,
+            d_model,
+        ), "TokenEmbedding output shape mismatch."
+
+    def test_FixedEmbedding(self):
+        c_in = 32
+        d_model = 64
+        embedding = FixedEmbedding(c_in, d_model)
+        x = torch.arange(0, c_in).unsqueeze(0)  # [1, c_in]
+        output = embedding(x)
+        assert output.shape == (
+            1,
+            c_in,
+            d_model,
+        ), "FixedEmbedding output shape mismatch."
+
+    def test_TemporalEmbedding(self):
+        d_model = 64
+        embed_type = "fixed"
+        freq = "h"
+        embedding = TemporalEmbedding(d_model, embed_type, freq)
+        batch_size = 32
+        seq_len = 100
+
+        month = torch.randint(0, 13, (batch_size, seq_len, 1))  # 0-12 for months
+        day = torch.randint(1, 32, (batch_size, seq_len, 1))  # 1-31 for days
+        weekday = torch.randint(0, 7, (batch_size, seq_len, 1))  # 0-6 for weekdays
+        hour = torch.randint(0, 24, (batch_size, seq_len, 1))  # 0-23 for hours
+        minute = torch.randint(
+            0, 4, (batch_size, seq_len, 1)
+        )  # 0-3 for minutes (assuming 15-minute intervals)
+
+        x = torch.cat([month, day, weekday, hour, minute], dim=2)
+
+        output = embedding(x)
+        assert output.shape == (
+            batch_size,
+            seq_len,
+            d_model,
+        ), "TemporalEmbedding output shape mismatch."
+
+    def test_DataEmbedding_no_x_mark(self):
+        c_in = 10
+        d_model = 64
+        embed_type = "fixed"
+        freq = "h"
+        dropout = 0.1
+        embedding = DataEmbedding(c_in, d_model, embed_type, freq, dropout)
+        batch_size = 32
+        seq_len = 100
+        x = torch.randn(batch_size, seq_len, c_in)
+        output = embedding(x, None)
+        assert output.shape == (
+            batch_size,
+            seq_len,
+            d_model,
+        ), "DataEmbedding output shape mismatch when x_mark is None."
+
+    def test_DataEmbedding_wo_pos_no_x_mark(self):
+        c_in = 10
+        d_model = 64
+        embed_type = "fixed"
+        freq = "h"
+        dropout = 0.1
+        embedding = DataEmbedding_wo_pos(c_in, d_model, embed_type, freq, dropout)
+        batch_size = 32
+        seq_len = 100
+        x = torch.randn(batch_size, seq_len, c_in)
+        output = embedding(x, None)
+        assert output.shape == (
+            batch_size,
+            seq_len,
+            d_model,
+        ), "DataEmbedding_wo_pos output shape mismatch when x_mark is None."
+
+    def test_DataEmbedding_with_x_mark(self):
+        c_in = 10
+        d_model = 64
+        embed_type = "fixed"
+        freq = "h"
+        dropout = 0.1
+        embedding = DataEmbedding(c_in, d_model, embed_type, freq, dropout)
+        batch_size = 32
+        seq_len = 100
+        x = torch.randn(batch_size, seq_len, c_in)
+
+        # Create x_mark with appropriate integer indices
+        month = torch.randint(0, 12, (batch_size, seq_len, 1))
+        day = torch.randint(1, 32, (batch_size, seq_len, 1))
+        weekday = torch.randint(0, 7, (batch_size, seq_len, 1))
+        hour = torch.randint(0, 24, (batch_size, seq_len, 1))
+        minute = torch.randint(0, 60, (batch_size, seq_len, 1))
+        x_mark = torch.cat([month, day, weekday, hour, minute], dim=2)
+
+        output = embedding(x, x_mark)
+
+        assert output.shape == (
+            batch_size,
+            seq_len,
+            d_model,
+        ), "DataEmbedding output shape mismatch."
+
+    def test_DataEmbedding_inverted(self):
+        c_in = 10
+        d_model = 64
+        embed_type = "fixed"
+        freq = "h"
+        dropout = 0.1
+        embedding = DataEmbedding_inverted(c_in, d_model, embed_type, freq, dropout)
+        batch_size = 32
+        seq_len = 100
+
+        # input shape is (batch_size, c_in, seq_len); forward permutes it to
+        # (batch_size, seq_len, c_in) before the linear value embedding
+        x = torch.randn(batch_size, c_in, seq_len)
+
+        # x_mark is not used in this test; otherwise it would be a
+        # (batch_size, seq_len, n_features) tensor
+        x_mark = None
+
+        output = embedding(x, x_mark)
+
+        # the linear embedding maps the last axis (c_in) to d_model, so the
+        # expected output shape is (batch_size, seq_len, d_model)
+        assert output.shape == (
+            batch_size,
+            seq_len,
+            d_model,
+        ), "DataEmbedding_inverted output shape mismatch."
+
+    def test_PatchEmbedding(self):
+        d_model = 64
+        patch_len = 16
+        stride = 8
+        padding = 8
+        dropout = 0.1
+        embedding = PatchEmbedding(d_model, patch_len, stride, padding, dropout)
+        batch_size = 32
+        n_vars = 10
+        seq_len = 100
+        x = torch.randn(batch_size, n_vars, seq_len)
+        output, n_vars_output = embedding(x)
+        num_patches = ((seq_len + padding) - patch_len) // stride + 1
+        expected_shape = (batch_size * n_vars, num_patches, d_model)
+        assert output.shape == expected_shape, "PatchEmbedding output shape mismatch."
+        assert n_vars_output == n_vars, "PatchEmbedding n_vars output mismatch."
+
+    def test_TimeFeatureEmbedding_invalid_input(self):
+        d_model = 64
+        embed_type = "timeF"
+        freq = "h"
+        embedding = TimeFeatureEmbedding(d_model, embed_type, freq)
+        batch_size = 32
+        seq_len = 100
+        x = torch.randn(batch_size, seq_len, 10)  # Incorrect feature size
+        with pytest.raises(RuntimeError):
+            embedding(x)
diff --git a/darts/tests/models/forecasting/test_global_forecasting_models.py b/darts/tests/models/forecasting/test_global_forecasting_models.py
index f8eea72615..c0772ddedd 100644
--- a/darts/tests/models/forecasting/test_global_forecasting_models.py
+++ b/darts/tests/models/forecasting/test_global_forecasting_models.py
@@ -33,6 +33,7 @@
     TCNModel,
     TFTModel,
     TiDEModel,
+    TimesNetModel,
     TransformerModel,
     TSMixerModel,
 )
@@ -159,6 +160,14 @@
         },
         60.0,
     ),
+    (
+        TimesNetModel,
+        {
+            "n_epochs": 10,
+            "pl_trainer_kwargs": tfm_kwargs["pl_trainer_kwargs"],
+        },
+        60.0,
+    ),
     (
         GlobalNaiveAggregate,
         {
diff --git a/darts/tests/models/forecasting/test_historical_forecasts.py b/darts/tests/models/forecasting/test_historical_forecasts.py
index 967e5e5e7c..7895f9a8ec 100644
--- a/darts/tests/models/forecasting/test_historical_forecasts.py
+++ b/darts/tests/models/forecasting/test_historical_forecasts.py
@@ -39,6 +39,7 @@
     TCNModel,
     TFTModel,
     TiDEModel,
+    TimesNetModel,
     TransformerModel,
     TSMixerModel,
 )
@@ -185,6 +186,22 @@
         (IN_LEN, OUT_LEN),
         "PastCovariates",
     ),
+    (
+        TimesNetModel,
+        {
+            "input_chunk_length": IN_LEN,
+            "output_chunk_length": OUT_LEN,
+            "hidden_size": 4,
+            "num_layers": 1,
+            "num_kernels": 2,
+            "top_k": 1,
+            "batch_size": 32,
+            "n_epochs": NB_EPOCH,
+            **tfm_kwargs,
+        },
+        (IN_LEN, OUT_LEN),
+        "PastCovariates",
+    ),
     (
         NBEATSModel,
         {
diff --git a/darts/tests/models/forecasting/test_probabilistic_models.py b/darts/tests/models/forecasting/test_probabilistic_models.py
index 141fd43dcd..4953b39fb4 100644
--- a/darts/tests/models/forecasting/test_probabilistic_models.py
+++ b/darts/tests/models/forecasting/test_probabilistic_models.py
@@ -34,6 +34,7 @@
     TCNModel,
     TFTModel,
     TiDEModel,
+    TimesNetModel,
     TransformerModel,
     TSMixerModel,
 )
@@ -165,6 +166,19 @@
         0.03,
         0.04,
     ),
+    (
+        TimesNetModel,
+        {
+            "input_chunk_length": 10,
"output_chunk_length": 5, + "n_epochs": 20, + "random_state": 0, + "likelihood": GaussianLikelihood(), + **tfm_kwargs, + }, + 0.03, + 0.04, + ), ( NBEATSModel, { diff --git a/darts/tests/models/forecasting/test_times_net_model.py b/darts/tests/models/forecasting/test_times_net_model.py new file mode 100644 index 0000000000..7d6dbb538d --- /dev/null +++ b/darts/tests/models/forecasting/test_times_net_model.py @@ -0,0 +1,174 @@ +import numpy as np +import pandas as pd +import pytest +import torch + +from darts import TimeSeries +from darts.models.forecasting.times_net_model import FFT_for_Period, TimesNetModel +from darts.tests.conftest import TORCH_AVAILABLE, tfm_kwargs +from darts.utils import timeseries_generation as tg + +if not TORCH_AVAILABLE: + pytest.skip( + f"Torch not available. {__name__} tests will be skipped.", + allow_module_level=True, + ) + + +class TestTimesNetModel: + times = pd.date_range("20130101", "20130410") + pd_series = pd.Series(range(100), index=times) + series: TimeSeries = TimeSeries.from_series(pd_series) + series_multivariate = series.stack(series * 2) + + def test_fit_and_predict(self): + model = TimesNetModel( + input_chunk_length=12, + output_chunk_length=12, + hidden_size=16, + num_layers=1, + num_kernels=2, + top_k=3, + n_epochs=2, + random_state=42, + **tfm_kwargs, + ) + model.fit(self.series) + pred = model.predict(n=2) + assert len(pred) == 2 + assert isinstance(pred, TimeSeries) + + def test_multivariate(self): + model = TimesNetModel( + input_chunk_length=12, + output_chunk_length=12, + n_epochs=2, + hidden_size=16, + num_layers=1, + num_kernels=2, + top_k=3, + **tfm_kwargs, + ) + model.fit(self.series_multivariate) + pred = model.predict(n=3) + assert pred.width == 2 + assert len(pred) == 3 + + def test_past_covariates(self): + target = tg.sine_timeseries(length=100) + covariates = tg.sine_timeseries(length=100, value_frequency=0.1) + + model = TimesNetModel( + input_chunk_length=12, + output_chunk_length=12, + n_epochs=2, + hidden_size=16, + num_layers=1, + num_kernels=2, + top_k=3, + **tfm_kwargs, + ) + model.fit(target, past_covariates=covariates) + pred = model.predict(n=3, past_covariates=covariates) + assert len(pred) == 3 + + def test_save_load(self, tmpdir_module): + model = TimesNetModel( + input_chunk_length=12, + output_chunk_length=12, + n_epochs=2, + model_name="unittest-model-TimesNet", + work_dir=tmpdir_module, + save_checkpoints=True, + force_reset=True, + hidden_size=16, + num_layers=1, + num_kernels=2, + top_k=3, + **tfm_kwargs, + ) + model.fit(self.series) + model_loaded = model.load_from_checkpoint( + model_name="unittest-model-TimesNet", + work_dir=tmpdir_module, + best=False, + map_location="cpu", + ) + pred1 = model.predict(n=1) + pred2 = model_loaded.predict(n=1) + + # Two models with the same parameters should deterministically yield the same output + np.testing.assert_array_equal(pred1.values(), pred2.values()) + + def test_prediction_with_custom_encoders(self): + target = tg.sine_timeseries(length=100, freq="H") + model = TimesNetModel( + input_chunk_length=12, + output_chunk_length=12, + add_encoders={ + "cyclic": {"future": ["hour"]}, + "datetime_attribute": {"future": ["dayofweek"]}, + }, + n_epochs=2, + hidden_size=16, + num_layers=1, + num_kernels=2, + top_k=3, + **tfm_kwargs, + ) + model.fit(target) + pred = model.predict(n=12) + assert len(pred) == 12 + + +class TestT_FFT_for_Period: + sample_input = ( + torch.sin(torch.linspace(0, 4 * torch.pi, 100)).unsqueeze(0).unsqueeze(-1) + ) + + def 
+    def test_FFT_for_Period_output_shape(self):
+        period, amplitudes = FFT_for_Period(self.sample_input)
+
+        assert isinstance(period, torch.Tensor)
+        assert isinstance(amplitudes, torch.Tensor)
+        assert period.shape == (2,)  # Default k=2
+        assert amplitudes.shape == (1, 2)  # (B, k)
+
+    def test_FFT_for_Period_custom_k(self):
+        k = 3
+        period, amplitudes = FFT_for_Period(self.sample_input, k=k)
+
+        assert period.shape == (k,)
+        assert amplitudes.shape == (1, k)
+
+    def test_FFT_for_Period_period_values(self):
+        period, _ = FFT_for_Period(self.sample_input)
+
+        # The main period should be close to 50 (half of the input length)
+        assert torch.isclose(period[0], torch.tensor(50), rtol=0.1)
+
+    def test_FFT_for_Period_amplitude_values(self):
+        _, amplitudes = FFT_for_Period(self.sample_input)
+
+        # Amplitudes should be non-negative
+        assert torch.all(amplitudes >= 0)
+
+    def test_FFT_for_Period_different_shapes(self):
+        # Test with different input shapes
+        x1 = torch.randn(2, 100, 3)  # [B, T, C] = [2, 100, 3]
+        x2 = torch.randn(1, 200, 1)  # [B, T, C] = [1, 200, 1]
+
+        period1, amplitudes1 = FFT_for_Period(x1)
+        period2, amplitudes2 = FFT_for_Period(x2)
+
+        assert period1.shape == (2,)
+        assert amplitudes1.shape == (2, 2)
+        assert period2.shape == (2,)
+        assert amplitudes2.shape == (1, 2)
+
+    def test_FFT_for_Period_zero_frequency_removal(self):
+        x = torch.ones(1, 100, 1)  # Constant input
+        _, amplitudes = FFT_for_Period(x)
+
+        # The amplitude of the zero frequency should be zero
+        assert torch.isclose(amplitudes[0, 0], torch.tensor(0.0))
diff --git a/darts/tests/models/forecasting/test_torch_forecasting_model.py b/darts/tests/models/forecasting/test_torch_forecasting_model.py
index 81941317a4..f21992515f 100644
--- a/darts/tests/models/forecasting/test_torch_forecasting_model.py
+++ b/darts/tests/models/forecasting/test_torch_forecasting_model.py
@@ -45,6 +45,7 @@
     TCNModel,
     TFTModel,
     TiDEModel,
+    TimesNetModel,
     TransformerModel,
     TSMixerModel,
 )
@@ -102,6 +103,7 @@
     (TFTModel, {"add_relative_index": 2, **kwargs, **tft_light_kwargs}),
     (TiDEModel, kwargs),
+    (TimesNetModel, kwargs),
     (TransformerModel, dict(kwargs, **trafo_light_kwargs)),
     (TSMixerModel, kwargs),
     (GlobalNaiveSeasonal, kwargs),
     (GlobalNaiveAggregate, kwargs),
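
A minimal end-to-end usage sketch of the new model (illustrative only: it assumes the
`TimesNetModel` API introduced in this diff, darts' standard dataset and `Scaler`
utilities, and a working torch install; the hyperparameter values are arbitrary):

    from darts.datasets import AirPassengersDataset
    from darts.dataprocessing.transformers import Scaler
    from darts.models import TimesNetModel

    # load and normalize a univariate monthly series
    series = AirPassengersDataset().load().astype("float32")
    scaler = Scaler()
    series_scaled = scaler.fit_transform(series)
    train, val = series_scaled[:-36], series_scaled[-36:]

    # small configuration; the four model-specific knobs map directly onto the
    # internals shown above (hidden_size -> d_model of the TimesBlocks,
    # num_layers -> stacked TimesBlocks, num_kernels -> Inception kernels,
    # top_k -> number of periods selected by FFT_for_Period)
    model = TimesNetModel(
        input_chunk_length=24,
        output_chunk_length=12,
        hidden_size=32,
        num_layers=2,
        num_kernels=6,
        top_k=5,
        n_epochs=20,
    )
    model.fit(train)

    # n > output_chunk_length triggers autoregressive forecasting
    pred = scaler.inverse_transform(model.predict(n=36))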