-
Notifications
You must be signed in to change notification settings - Fork 900
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Feat/time net model #2538
base: master
Are you sure you want to change the base?
Feat/time net model #2538
Conversation
d60d5e4
to
3edce3a
Compare
Codecov ReportAttention: Patch coverage is
Additional details and impacted files@@ Coverage Diff @@
## master #2538 +/- ##
==========================================
- Coverage 94.14% 94.11% -0.04%
==========================================
Files 139 141 +2
Lines 14884 15113 +229
==========================================
+ Hits 14013 14223 +210
- Misses 871 890 +19 ☔ View full report in Codecov by Sentry. |
@dennisbader This model is ready for review. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks a lot for this great PR @gdevos010 and sorry for the delay, we have been quite busy with others responsibilities.
Some comments to reduce the code redundancy in the implementation of this quite complex model.
Also, could you please also include an example notebook to compare it to models such as TiDEModel
and/or TSMixerModel
on a toy example, just to make sure that the default configuration yield acceptable results?
from darts.utils.torch import MonteCarloDropout | ||
|
||
|
||
class PositionalEmbedding(nn.Module): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
the PositionalEncoding
class already implements this logic in darts/models/forecasting/transformer_model.py
, let's maybe move it here and keep the old name
w = torch.zeros(c_in, d_model).float() | ||
w.require_grad = False | ||
|
||
position = torch.arange(0, c_in).float().unsqueeze(1) | ||
div_term = ( | ||
torch.arange(0, d_model, 2).float() * -(math.log(10000.0) / d_model) | ||
).exp() | ||
|
||
w[:, 0::2] = torch.sin(position * div_term) | ||
w[:, 1::2] = torch.cos(position * div_term) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
this code snippet can also be found in PositionalEncoding
, let's abstract it to reduce redundancy
return self.dropout(x) | ||
|
||
|
||
class DataEmbedding_inverted(nn.Module): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
class DataEmbedding_inverted(nn.Module): | |
class DataEmbeddingInverted(nn.Module): |
return self.dropout(x) | ||
|
||
|
||
class DataEmbedding_wo_pos(nn.Module): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The logic is so similar to DataEmbedding
and DataEmbedding_inverted
that they should probably be combined into a single class and the logic difference should be implemented in the forward()
method instead (by adding a parameter/attribute type
that could take the values "normal", "inverted" or "wopos".
from darts.models.forecasting.torch_forecasting_model import PastCovariatesTorchModel | ||
|
||
|
||
class Inception_Block_V1(nn.Module): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Just of the sake of homogeneity.
class Inception_Block_V1(nn.Module): | |
class InceptionBlock(nn.Module): |
@@ -480,7 +480,7 @@ def encode_year(idx): | |||
} | |||
.. | |||
random_state | |||
Control the randomness of the weights initialization. Check this | |||
Control the randomness of the weight's initialization. Check this |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Let's revert this change, since there are several weights, the apostrophe does not seem appropriate.
Or we could eventually rephrase it into "Control the randomness in the initialization of the weights" if you have the impression that the original sentence was not clear.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Let's rename the file "timesnet_model.py" so that it's homogeneous with the others
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Let's rename the file "test_timesnet_model.py" to keep the number of underscores to the minimum
@@ -619,13 +620,13 @@ Patch release | |||
- Added support for past and future covariates to `residuals()` function. [#1223](https://github.com/unit8co/darts/pull/1223) by [Eliane Maalouf](https://github.com/eliane-maalouf). | |||
- Added support for retraining model(s) every `n` iteration and on custom conditions in `historical_forecasts` method of `ForecastingModel`s. [#1139](https://github.com/unit8co/darts/pull/1139) by [Francesco Bruzzesi](https://github.com/fbruzzesi). | |||
- Added support for beta-NLL in `GaussianLikelihood`s, as proposed in [this paper](https://arxiv.org/abs/2203.09168). [#1162](https://github.com/unit8co/darts/pull/1162) by [Julien Herzen](https://github.com/hrzn). | |||
- New LayerNorm alternatives, RMSNorm and LayerNormNoBias [#1113](https://github.com/unit8co/darts/pull/1113) by [Greg DeVos](https://github.com/gdevos010). | |||
- New LayerNorm alternatives, RMSNorm and LayerNormNoBias [#1113](https://github.com/unit8co/darts/issues/1113) by [Greg DeVosNouri](https://github.com/gdevos010). |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Incorrect link
Checklist before merging this PR:
TODO
Fixes #2537.
Summary
Adds the TimesNet mode based on this code
https://github.com/thuml/Time-Series-Library/blob/main/models/TimesNet.py
Other Information