Skip to content

Commit

Permalink
fix: revert some changes
Browse files Browse the repository at this point in the history
  • Loading branch information
madtoinou committed Dec 17, 2024
1 parent c2c00f3 commit f376f09
Show file tree
Hide file tree
Showing 2 changed files with 0 additions and 83 deletions.
58 changes: 0 additions & 58 deletions darts/models/forecasting/rnn_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -587,64 +587,6 @@ def _verify_train_dataset_type(self, train_dataset: TrainingDataset):
"RNNModel requires a shifted training dataset with shift=1.",
)

def to_onnx(self, path: Optional[str] = None, **kwargs):
"""Export model to ONNX format for optimized inference, wrapping around PyTorch Lightning's
:func:`torch.onnx.export` method ()`official documentation <https://lightning.ai/docs/pytorch/
stable/common/lightning_module.html#to-onnx>`_).
Note: onnx library (optionnal dependency) must be installed in order to call this method
Parameters
----------
path
Path under which to save the model at its current state. Please avoid path starting with "last-" or
"best-" to avoid collision with Pytorch-Ligthning checkpoints. If no path is specified, the model
is automatically saved under ``"{ModelClass}_{YYYY-mm-dd_HH_MM_SS}.pt"``.
**kwargs
Additional kwargs for PyTorch's :func:`torch.onnx.export` method, such as ``verbose`` prints a
description of the model being exported to stdout.
For more information, read the `official documentation <https://pytorch.org/docs/master/
onnx.html#torch.onnx.export>`_.
"""
raise_if_not(
self._fit_called, "`fit()` needs to be called before `to_onnx()`.", logger
)

if path is None:
# default path
path = self._default_save_path() + ".onnx"

# mimic preprocessing performed by RNNModule_._get_batch_prediction()
(
past_target,
past_covariates,
historic_future_covariates,
future_covariates,
future_past_covariates,
) = (
torch.Tensor(x).unsqueeze(0) if x is not None else None
for x in self.train_sample
)

if historic_future_covariates is not None:
# RNNs need as inputs (target[t] and covariates[t+1]) so here we shift the covariates
all_covariates = torch.cat(
[historic_future_covariates[:, 1:, :], future_covariates], dim=1
)
cov_past, _ = (
all_covariates[:, : past_target.shape[1], :],
all_covariates[:, past_target.shape[1] :, :],
)
input_past = torch.cat([past_target, cov_past], dim=2)
else:
input_past = past_target

input_sample = [
input_past.double(),
future_covariates.double() if future_covariates is not None else None,
]
self.model.to_onnx(file_path=path, input_sample=input_sample, **kwargs)

@property
def supports_multivariate(self) -> bool:
return True
Expand Down
25 changes: 0 additions & 25 deletions darts/models/forecasting/torch_forecasting_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -2621,13 +2621,6 @@ def _verify_inference_dataset_type(self, inference_dataset: InferenceDataset):
def _verify_predict_sample(self, predict_sample: tuple):
_basic_compare_sample(self.train_sample, predict_sample)

def _verify_past_future_covariates(self, past_covariates, future_covariates):
raise_if_not(
future_covariates is None,
"Some future_covariates have been provided to a PastCovariates model. These models "
"support only past_covariates.",
)

@property
def _model_encoder_settings(
self,
Expand Down Expand Up @@ -2721,13 +2714,6 @@ def _verify_inference_dataset_type(self, inference_dataset: InferenceDataset):
def _verify_predict_sample(self, predict_sample: tuple):
_basic_compare_sample(self.train_sample, predict_sample)

def _verify_past_future_covariates(self, past_covariates, future_covariates):
raise_if_not(
past_covariates is None,
"Some past_covariates have been provided to a PastCovariates model. These models "
"support only future_covariates.",
)

@property
def _model_encoder_settings(
self,
Expand Down Expand Up @@ -2822,13 +2808,6 @@ def _verify_inference_dataset_type(self, inference_dataset: InferenceDataset):
def _verify_predict_sample(self, predict_sample: tuple):
_basic_compare_sample(self.train_sample, predict_sample)

def _verify_past_future_covariates(self, past_covariates, future_covariates):
raise_if_not(
past_covariates is None,
"Some past_covariates have been provided to a DualCovariates Torch model. These models "
"support only future_covariates.",
)

@property
def _model_encoder_settings(
self,
Expand Down Expand Up @@ -2923,10 +2902,6 @@ def _verify_inference_dataset_type(self, inference_dataset: InferenceDataset):
def _verify_predict_sample(self, predict_sample: tuple):
_mixed_compare_sample(self.train_sample, predict_sample)

def _verify_past_future_covariates(self, past_covariates, future_covariates):
# both covariates are supported; do nothing
pass

@property
def _model_encoder_settings(
self,
Expand Down

0 comments on commit f376f09

Please sign in to comment.