diff --git a/README.md b/README.md index f0431e7..fa12ebb 100644 --- a/README.md +++ b/README.md @@ -13,7 +13,7 @@ ## Acknowledgement During the development, the following repositories were referred to: -* [Kaldi](https://github.com/kaldi-asr/kaldi), for most utility scripts in `utils/`. +* [Kaldi](https://github.com/kaldi-asr/kaldi) and [UniCATS-CTX-vec2wav](https://github.com/cantabile-kwok/UniCATS-CTX-vec2wav) for most utility scripts in `utils/`. * [GradTTS](https://github.com/huawei-noah/Speech-Backbones/tree/main/Grad-TTS), where most of the model architecture and training pipelines are adopted. * [VITS](https://github.com/jaywalnut310/vits), whose distributed bucket sampler is used. * [CFM](https://github.com/atong01/conditional-flow-matching), for the ODE samplers. diff --git a/path.sh b/path.sh index 5632c2d..83e583a 100644 --- a/path.sh +++ b/path.sh @@ -1,2 +1,3 @@ -conda activate py39 -export PATH=$PWD/tools:$PATH \ No newline at end of file +conda activate vflow +export PATH=$PWD/tools:$PATH +chmod +x tools/* diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..26969bd --- /dev/null +++ b/requirements.txt @@ -0,0 +1,27 @@ +audioread==3.0.0 +Cython==0.29.28 +decorator==5.1.1 +h5py==3.7.0 +joblib==1.3.2 +kaldiio==2.18.0 +matplotlib==3.4.3 +numba==0.56.4 +numpy==1.21.6 +packaging==21.3 +pooch==1.6.0 +POT==0.9.0 +resampy==0.4.0 +setuptools==52.0.0 +soundfile==0.12.1 +soxr==0.3.5 +torch==1.11.0 +tqdm==4.62.2 +tensorboard==2.14.1 +Pillow==9.5.0 +pyyaml==6.0.1 +einops==0.7.0 +scikit-learn==1.3.1 +attrs==22.1.0 +torchsde>=0.2.5 +torchcde>=0.2.3 +pytorch-lightning>=0.8.4 diff --git a/tools/espnet_transform/perturb.py b/tools/espnet_transform/perturb.py index 7842b35..a657273 100644 --- a/tools/espnet_transform/perturb.py +++ b/tools/espnet_transform/perturb.py @@ -1,4 +1,4 @@ -import librosa +import custom_librosa as librosa import numpy import scipy import soundfile diff --git a/tools/espnet_transform/spec_augment.py b/tools/espnet_transform/spec_augment.py index 4ea96ea..8878e73 100644 --- a/tools/espnet_transform/spec_augment.py +++ b/tools/espnet_transform/spec_augment.py @@ -38,7 +38,7 @@ def time_warp(x, max_time_warp=80, inplace=False, mode="PIL"): elif mode == "sparse_image_warp": import torch - from espnet.utils import spec_augment + from espnet_utils import spec_augment # TODO(karita): make this differentiable again return spec_augment.time_warp(torch.from_numpy(x), window).numpy() diff --git a/tools/espnet_transform/transformation.py b/tools/espnet_transform/transformation.py index b0ccac7..2a7940c 100644 --- a/tools/espnet_transform/transformation.py +++ b/tools/espnet_transform/transformation.py @@ -23,26 +23,26 @@ # TODO(karita): inherit TransformInterface # TODO(karita): register cmd arguments in asr_train.py import_alias = dict( - identity='espnet.transform.transform_interface:Identity', - time_warp='espnet.transform.spec_augment:TimeWarp', - time_mask='espnet.transform.spec_augment:TimeMask', - freq_mask='espnet.transform.spec_augment:FreqMask', - spec_augment='espnet.transform.spec_augment:SpecAugment', - speed_perturbation='espnet.transform.perturb:SpeedPerturbation', - volume_perturbation='espnet.transform.perturb:VolumePerturbation', - noise_injection='espnet.transform.perturb:NoiseInjection', - bandpass_perturbation='espnet.transform.perturb:BandpassPerturbation', - rir_convolve='espnet.transform.perturb:RIRConvolve', - delta='espnet.transform.add_deltas:AddDeltas', - cmvn='espnet.transform.cmvn:CMVN', - 
utterance_cmvn='espnet.transform.cmvn:UtteranceCMVN', - fbank='espnet.transform.spectrogram:LogMelSpectrogram', - spectrogram='espnet.transform.spectrogram:Spectrogram', - stft='espnet.transform.spectrogram:Stft', - istft='espnet.transform.spectrogram:IStft', - stft2fbank='espnet.transform.spectrogram:Stft2LogMelSpectrogram', - wpe='espnet.transform.wpe:WPE', - channel_selector='espnet.transform.channel_selector:ChannelSelector') + identity='espnet_transform.transform_interface:Identity', + time_warp='espnet_transform.spec_augment:TimeWarp', + time_mask='espnet_transform.spec_augment:TimeMask', + freq_mask='espnet_transform.spec_augment:FreqMask', + spec_augment='espnet_transform.spec_augment:SpecAugment', + speed_perturbation='espnet_transform.perturb:SpeedPerturbation', + volume_perturbation='espnet_transform.perturb:VolumePerturbation', + noise_injection='espnet_transform.perturb:NoiseInjection', + bandpass_perturbation='espnet_transform.perturb:BandpassPerturbation', + rir_convolve='espnet_transform.perturb:RIRConvolve', + delta='espnet_transform.add_deltas:AddDeltas', + cmvn='espnet_transform.cmvn:CMVN', + utterance_cmvn='espnet_transform.cmvn:UtteranceCMVN', + fbank='espnet_transform.spectrogram:LogMelSpectrogram', + spectrogram='espnet_transform.spectrogram:Spectrogram', + stft='espnet_transform.spectrogram:Stft', + istft='espnet_transform.spectrogram:IStft', + stft2fbank='espnet_transform.spectrogram:Stft2LogMelSpectrogram', + wpe='espnet_transform.wpe:WPE', + channel_selector='espnet_transform.channel_selector:ChannelSelector') class Transformation(object): diff --git a/tools/espnet_utils/dynamic_import.py b/tools/espnet_utils/dynamic_import.py index 8eeccd4..77a7434 100644 --- a/tools/espnet_utils/dynamic_import.py +++ b/tools/espnet_utils/dynamic_import.py @@ -5,14 +5,14 @@ def dynamic_import(import_path, alias=dict()): """dynamic import module and class :param str import_path: syntax 'module_name:class_name' - e.g., 'espnet.transform.add_deltas:AddDeltas' + e.g., 'espnet_transform.add_deltas:AddDeltas' :param dict alias: shortcut for registered class :return: imported class """ if import_path not in alias and ':' not in import_path: raise ValueError( 'import_path should be one of {} or ' - 'include ":", e.g. "espnet.transform.add_deltas:AddDeltas" : ' + 'include ":", e.g. 
"espnet_transform.add_deltas:AddDeltas" : ' '{}'.format(set(alias), import_path)) if ':' not in import_path: import_path = alias[import_path] diff --git a/tools/feat-to-len.py b/tools/feat-to-len.py index 01c3b38..9114ebe 100644 --- a/tools/feat-to-len.py +++ b/tools/feat-to-len.py @@ -42,7 +42,7 @@ def main(): logging.info(get_commandline_args()) if args.preprocess_conf is not None: - from espnet.transform.transformation import Transformation + from espnet_transform.transformation import Transformation preprocessing = Transformation(args.preprocess_conf) logging.info('Apply preprocessing: {}'.format(preprocessing)) else: diff --git a/tools/feat-to-shape.py b/tools/feat-to-shape.py index e6034a4..ac9224d 100644 --- a/tools/feat-to-shape.py +++ b/tools/feat-to-shape.py @@ -42,7 +42,7 @@ def main(): logging.info(get_commandline_args()) if args.preprocess_conf is not None: - from espnet.transform.transformation import Transformation + from espnet_transform.transformation import Transformation preprocessing = Transformation(args.preprocess_conf) logging.info('Apply preprocessing: {}'.format(preprocessing)) else: diff --git a/torchdyn/__init__.py b/torchdyn/__init__.py new file mode 100644 index 0000000..47bda20 --- /dev/null +++ b/torchdyn/__init__.py @@ -0,0 +1,19 @@ +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +__version__ = '1.0.6' +__author__ = 'Michael Poli, Stefano Massaroli et al.' + +from torch import Tensor +from typing import Tuple + +TTuple = Tuple[Tensor, Tensor] diff --git a/torchdyn/core/__init__.py b/torchdyn/core/__init__.py new file mode 100644 index 0000000..0ef1c94 --- /dev/null +++ b/torchdyn/core/__init__.py @@ -0,0 +1,21 @@ +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from torchdyn.core.defunc import DEFunc +from torchdyn.core.neuralde import NeuralODE, NeuralSDE, MultipleShootingLayer +from torchdyn.core.problems import ODEProblem, SDEProblem, MultipleShootingProblem + +# backward-compatibility (pre v0.2.0) +NeuralDE = NeuralODE + +__all__ = ['DEFunc', 'NeuralODE', 'NeuralDE', 'NeuralSDE', 'ODEProblem', 'SDEProblem', + 'MultipleShootingProblem', 'MultipleShootingLayer'] \ No newline at end of file diff --git a/torchdyn/core/defunc.py b/torchdyn/core/defunc.py new file mode 100644 index 0000000..693a9b5 --- /dev/null +++ b/torchdyn/core/defunc.py @@ -0,0 +1,117 @@ +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Callable, Dict +import torch +from torch import Tensor, cat +import torch.nn as nn + + +class DEFuncBase(nn.Module): + def __init__(self, vector_field:Callable, has_time_arg:bool=True): + """Basic wrapper to ensure call signature compatibility between generic torch Modules and vector fields. + Args: + vector_field (Callable): callable defining the dynamics / vector field / `dxdt` / forcing function + has_time_arg (bool, optional): Internal arg. to indicate whether the callable has `t` in its `__call__' + or `forward` method. Defaults to True. + """ + super().__init__() + self.nfe, self.vf, self.has_time_arg = 0., vector_field, has_time_arg + + def forward(self, t:Tensor, x:Tensor, args:Dict={}) -> Tensor: + self.nfe += 1 + if self.has_time_arg: return self.vf(t, x, args=args) + else: return self.vf(x) + + +class DEFunc(nn.Module): + def __init__(self, vector_field:Callable, order:int=1): + """Special vector field wrapper for Neural ODEs. + + Handles auxiliary tasks: time ("depth") concatenation, higher-order dynamics and forward propagated integral losses. + + Args: + vector_field (Callable): callable defining the dynamics / vector field / `dxdt` / forcing function + order (int, optional): order of the differential equation. Defaults to 1. + + Notes: + Currently handles the following: + (1) assigns time tensor to each submodule requiring it (e.g. `GalLinear`). + (2) in case of integral losses + reverse-mode differentiation, propagates the loss in the first dimension of `x` + and automatically splits the Tensor into `x[:, 0]` and `x[:, 1:]` for vector field computation + (3) in case of higher-order dynamics, adjusts the vector field forward to recursively compute various orders. + """ + super().__init__() + self.vf, self.nfe, = vector_field, 0. + self.order, self.integral_loss, self.sensitivity = order, None, None + # identify whether vector field already has time arg + + def forward(self, t:Tensor, x:Tensor, args:Dict={}) -> Tensor: + self.nfe += 1 + # set `t` depth-variable to DepthCat modules + for _, module in self.vf.named_modules(): + if hasattr(module, 't'): + module.t = t + + # if-else to handle autograd training with integral loss propagated in x[:, 0] + if (self.integral_loss is not None) and self.sensitivity == 'autograd': + x_dyn = x[:, 1:] + dlds = self.integral_loss(t, x_dyn) + if len(dlds.shape) == 1: dlds = dlds[:, None] + if self.order > 1: x_dyn = self.horder_forward(t, x_dyn, args) + else: x_dyn = self.vf(t, x_dyn) + return cat([dlds, x_dyn], 1).to(x_dyn) + + # regular forward + else: + if self.order > 1: x = self.higher_order_forward(t, x) + else: x = self.vf(t, x, args=args) + return x + + def higher_order_forward(self, t:Tensor, x:Tensor, args:Dict={}) -> Tensor: + x_new = [] + size_order = x.size(1) // self.order + for i in range(1, self.order): + x_new.append(x[:, size_order*i : size_order*(i+1)]) + x_new.append(self.vf(t, x)) + return cat(x_new, dim=1).to(x) + + +class SDEFunc(nn.Module): + def __init__(self, f:Callable, g:Callable, order:int=1): + """"Special vector field wrapper for Neural SDEs. 
+ + Args: + f (Callable): callable defining the drift + g (Callable): callable defining the diffusion term + order (int, optional): order of the differential equation. Defaults to 1. + """ + super().__init__() + self.order, self.intloss, self.sensitivity = order, None, None + self.f_func, self.g_func = f, g + self.nfe = 0 + + def forward(self, t:Tensor, x:Tensor, args:Dict={}) -> Tensor: + pass + + def f(self, t:Tensor, x:Tensor, args:Dict={}) -> Tensor: + self.nfe += 1 + for _, module in self.f_func.named_modules(): + if hasattr(module, 't'): + module.t = t + return self.f_func(x, args) + + def g(self, t:Tensor, x:Tensor, args:Dict={}) -> Tensor: + for _, module in self.g_func.named_modules(): + if hasattr(module, 't'): + module.t = t + return self.g_func(x, args) diff --git a/torchdyn/core/neuralde.py b/torchdyn/core/neuralde.py new file mode 100644 index 0000000..1bb5c28 --- /dev/null +++ b/torchdyn/core/neuralde.py @@ -0,0 +1,219 @@ +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Callable, Union, Iterable, Generator, Dict + +from torchdyn.core.problems import MultipleShootingProblem, ODEProblem, SDEProblem +from torchdyn.numerics import odeint +from torchdyn.core.defunc import SDEFunc +from torchdyn.core.utils import standardize_vf_call_signature + +import pytorch_lightning as pl +import torch +from torch import Tensor +import torch.nn as nn +import torchsde + +import warnings + + +class NeuralODE(ODEProblem, pl.LightningModule): + def __init__(self, vector_field:Union[Callable, nn.Module], solver:Union[str, nn.Module]='tsit5', order:int=1, + atol:float=1e-3, rtol:float=1e-3, sensitivity='autograd', solver_adjoint:Union[str, nn.Module, None] = None, + atol_adjoint:float=1e-4, rtol_adjoint:float=1e-4, interpolator:Union[str, Callable, None]=None, \ + integral_loss:Union[Callable, None]=None, seminorm:bool=False, return_t_eval:bool=True, optimizable_params:Union[Iterable, Generator]=()): + """Generic Neural Ordinary Differential Equation. + + Args: + vector_field ([Callable]): the vector field, called with `vector_field(t, x)` for `vector_field(x)`. + In the second case, the Callable is automatically wrapped for consistency + solver (Union[str, nn.Module]): + order (int, optional): Order of the ODE. Defaults to 1. + atol (float, optional): Absolute tolerance of the solver. Defaults to 1e-4. + rtol (float, optional): Relative tolerance of the solver. Defaults to 1e-4. + sensitivity (str, optional): Sensitivity method ['autograd', 'adjoint', 'interpolated_adjoint']. Defaults to 'autograd'. + solver_adjoint (Union[str, nn.Module, None], optional): ODE solver for the adjoint. Defaults to None. + atol_adjoint (float, optional): Defaults to 1e-6. + rtol_adjoint (float, optional): Defaults to 1e-6. + integral_loss (Union[Callable, None], optional): Defaults to None. + seminorm (bool, optional): Whether to use seminorms for adaptive stepping in backsolve adjoints. Defaults to False. + return_t_eval (bool): Whether to return (t_eval, sol) or only sol. 
Useful for chaining NeuralODEs in `nn.Sequential`. + optimizable_parameters (Union[Iterable, Generator]): parameters to calculate sensitivies for. Defaults to (). + Notes: + In `torchdyn`-style, forward calls to a Neural ODE return both a tensor `t_eval` of time points at which the solution is evaluated + as well as the solution itself. This behavior can be controlled by setting `return_t_eval` to False. Calling `trajectory` also returns + the solution only. + + The Neural ODE class automates certain delicate steps that must be done depending on the solver and model used. + The `prep_odeint` method carries out such steps. Neural ODEs wrap `ODEProblem`. + """ + super().__init__(vector_field=standardize_vf_call_signature(vector_field, order, defunc_wrap=True), order=order, sensitivity=sensitivity, + solver=solver, atol=atol, rtol=rtol, solver_adjoint=solver_adjoint, atol_adjoint=atol_adjoint, rtol_adjoint=rtol_adjoint, + seminorm=seminorm, interpolator=interpolator, integral_loss=integral_loss, optimizable_params=optimizable_params) + self._control, self.controlled, self.t_span = None, False, None # data-control conditioning + self.return_t_eval = return_t_eval + if integral_loss is not None: self.vf.integral_loss = integral_loss + self.vf.sensitivity = sensitivity + + def _prep_integration(self, x:Tensor, t_span:Tensor) -> Tensor: + "Performs generic checks before integration. Assigns data control inputs and augments state for CNFs" + + # assign a basic value to `t_span` for `forward` calls that do no explicitly pass an integration interval + if t_span is None and self.t_span is None: t_span = torch.linspace(0, 1, 2) + elif t_span is None: t_span = self.t_span + + # loss dimension detection routine; for CNF div propagation and integral losses w/ autograd + excess_dims = 0 + if (not self.integral_loss is None) and self.sensitivity == 'autograd': + excess_dims += 1 + + # handle aux. operations required for some jacobian trace CNF estimators e.g Hutchinson's + # as well as datasets-control set to DataControl module + for _, module in self.vf.named_modules(): + if hasattr(module, 'trace_estimator'): + if module.noise_dist is not None: module.noise = module.noise_dist.sample((x.shape[0],)) + excess_dims += 1 + + # data-control set routine. 
Is performed once at the beginning of odeint since the control is fixed to IC + if hasattr(module, '_control'): + self.controlled = True + module._control = x[:, excess_dims:].detach() + return x, t_span + + def forward(self, x:Union[Tensor, Dict], t_span:Tensor=None, save_at:Iterable=(), args={}): + x, t_span = self._prep_integration(x, t_span) + t_eval, sol = super().forward(x, t_span, save_at, args) + if self.return_t_eval: return t_eval, sol + else: return sol + + def trajectory(self, x:torch.Tensor, t_span:Tensor): + x, t_span = self._prep_integration(x, t_span) + _, sol = odeint(self.vf, x, t_span, solver=self.solver, atol=self.atol, rtol=self.rtol) + return sol + + def __repr__(self): + npar = sum([p.numel() for p in self.vf.parameters()]) + return f"Neural ODE:\n\t- order: {self.order}\ + \n\t- solver: {self.solver}\n\t- adjoint solver: {self.solver_adjoint}\ + \n\t- tolerances: relative {self.rtol} absolute {self.atol}\ + \n\t- adjoint tolerances: relative {self.rtol_adjoint} absolute {self.atol_adjoint}\ + \n\t- num_parameters: {npar}\ + \n\t- NFE: {self.vf.nfe}" + + +class NeuralSDE(SDEProblem, pl.LightningModule): + def __init__(self, drift_func, diffusion_func, noise_type ='diagonal', sde_type = 'ito', order=1, + sensitivity='autograd', s_span=torch.linspace(0, 1, 2), solver='srk', + atol=1e-4, rtol=1e-4, ds = 1e-3, intloss=None): + """Generic Neural Stochastic Differential Equation. Follows the same design of the `NeuralODE` class. + + Args: + drift_func ([type]): drift function + diffusion_func ([type]): diffusion function + noise_type (str, optional): Defaults to 'diagonal'. + sde_type (str, optional): Defaults to 'ito'. + order (int, optional): Defaults to 1. + sensitivity (str, optional): Defaults to 'autograd'. + s_span ([type], optional): Defaults to torch.linspace(0, 1, 2). + solver (str, optional): Defaults to 'srk'. + atol ([type], optional): Defaults to 1e-4. + rtol ([type], optional): Defaults to 1e-4. + ds ([type], optional): Defaults to 1e-3. + intloss ([type], optional): Defaults to None. + + Raises: + NotImplementedError: higher-order Neural SDEs are not yet implemented, raised by setting `order` to >1. + + Notes: + The current implementation is rougher around the edges compared to `NeuralODE`, and is not guaranteed to have the same features. + """ + super().__init__(func=SDEFunc(f=drift_func, g=diffusion_func, order=order), order=order, sensitivity=sensitivity, s_span=s_span, solver=solver, + atol=atol, rtol=rtol) + if order != 1: raise NotImplementedError + self.defunc.noise_type, self.defunc.sde_type = noise_type, sde_type + self.adaptive = False + self.intloss = intloss + self._control, self.controlled = None, False # datasets-control + self.ds = ds + + def _prep_sdeint(self, x:torch.Tensor): + self.s_span = self.s_span.to(x) + # datasets-control set routine. 
Is performed once at the beginning of odeint since the control is fixed to IC + excess_dims = 0 + for _, module in self.defunc.named_modules(): + if hasattr(module, '_control'): + self.controlled = True + module._control = x[:, excess_dims:].detach() + + return x + + def forward(self, x:torch.Tensor): + x = self._prep_sdeint(x) + switcher = { + 'autograd': self._autograd, + 'adjoint': self._adjoint, + } + sdeint = switcher.get(self.sensitivity) + out = sdeint(x) + return out + + def trajectory(self, x:torch.Tensor, s_span:torch.Tensor): + x = self._prep_sdeint(x) + sol = torchsde.sdeint(self.defunc, x, s_span, rtol=self.rtol, atol=self.atol, + method=self.solver, dt=self.ds) + return sol + + def backward_trajectory(self, x:torch.Tensor, s_span:torch.Tensor): + raise NotImplementedError + + def _autograd(self, x): + self.defunc.intloss, self.defunc.sensitivity = self.intloss, self.sensitivity + return torchsde.sdeint(self.defunc, x, self.s_span, rtol=self.rtol, atol=self.atol, + adaptive=self.adaptive, method=self.solver, dt=self.ds)[-1] + + def _adjoint(self, x): + out = torchsde.sdeint_adjoint(self.defunc, x, self.s_span, rtol=self.rtol, atol=self.atol, + adaptive=self.adaptive, method=self.solver, dt=self.ds)[-1] + return out + + +class MultipleShootingLayer(MultipleShootingProblem, pl.LightningModule): + def __init__(self, vector_field:Callable, solver:str, sensitivity:str='autograd', + maxiter:int=4, fine_steps:int=4, solver_adjoint:Union[str, nn.Module, None] = None, atol_adjoint:float=1e-6, + rtol_adjoint:float=1e-6, seminorm:bool=False, integral_loss:Union[Callable, None]=None): + """Multiple Shooting Layer as defined in https://arxiv.org/abs/2106.03885. + + Uses parallel-in-time ODE solvers to solve an ODE parametrized by neural network `vector_field`. + + Args: + vector_field ([Callable]): the vector field, called with `vector_field(t, x)` for `vector_field(x)`. + In the second case, the Callable is automatically wrapped for consistency + solver (Union[str, nn.Module]): parallel-in-time solver, ['zero', 'direct'] + sensitivity (str, optional): Sensitivity method ['autograd', 'adjoint', 'interpolated_adjoint']. Defaults to 'autograd'. + maxiter (int): number of iterations of the root finding routine defined to parallel solve the ODE. + fine_steps (int): number of fine-solver steps to perform in each subinterval of the parallel solution. + solver_adjoint (Union[str, nn.Module, None], optional): Standard sequential ODE solver for the adjoint system. + atol_adjoint (float, optional): Defaults to 1e-6. + rtol_adjoint (float, optional): Defaults to 1e-6. + integral_loss (Union[Callable, None], optional): Currently not implemented + seminorm (bool, optional): Whether to use seminorms for adaptive stepping in backsolve adjoints. Defaults to False. + Notes: + The number of shooting parameters (first dimension in `B0`) is implicitly defined by passing `t_span` during forward calls. + For example, a `t_span=torch.linspace(0, 1, 10)` will define 9 intervals and 10 shooting parameters. + + For the moment only a thin wrapper around `MultipleShootingProblem`. At this level will be convenience routines for special + initializations of shooting parameters `B0`, as well as usual convenience checks for integral losses. 
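+
+            A minimal usage sketch (illustrative only; `vf` is any `(t, x)` vector field module and `x0` a batch of initial states):
+                msl = MultipleShootingLayer(vf, solver='zero')
+                t_eval, sol = msl(x0, t_span=torch.linspace(0, 1, 10))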
+ """ + super().__init__(vector_field, solver, sensitivity, maxiter, fine_steps, solver_adjoint, atol_adjoint, + rtol_adjoint, seminorm, integral_loss) + + diff --git a/torchdyn/core/problems.py b/torchdyn/core/problems.py new file mode 100644 index 0000000..bac382c --- /dev/null +++ b/torchdyn/core/problems.py @@ -0,0 +1,142 @@ +import torch +from torch import Tensor +import torch.nn as nn +from typing import Callable, Generator, Iterable, Union + +from torchdyn.numerics.sensitivity import _gather_odefunc_adjoint, _gather_odefunc_interp_adjoint +from torchdyn.numerics.odeint import odeint, odeint_mshooting +from torchdyn.numerics.solvers.ode import str_to_solver, str_to_ms_solver +from torchdyn.core.utils import standardize_vf_call_signature + + +class ODEProblem(nn.Module): + def __init__(self, vector_field:Union[Callable, nn.Module], solver:Union[str, nn.Module], interpolator:Union[str, Callable, None]=None, order:int=1, + atol:float=1e-4, rtol:float=1e-4, sensitivity:str='autograd', solver_adjoint:Union[str, nn.Module, None] = None, atol_adjoint:float=1e-6, + rtol_adjoint:float=1e-6, seminorm:bool=False, integral_loss:Union[Callable, None]=None, optimizable_params:Union[Iterable, Generator]=()): + """An ODE Problem coupling a given vector field with solver and sensitivity algorithm to compute gradients w.r.t different quantities. + + Args: + vector_field ([Callable]): the vector field, called with `vector_field(t, x)` for `vector_field(x)`. + In the second case, the Callable is automatically wrapped for consistency + solver (Union[str, nn.Module]): + order (int, optional): Order of the ODE. Defaults to 1. + atol (float, optional): Absolute tolerance of the solver. Defaults to 1e-4. + rtol (float, optional): Relative tolerance of the solver. Defaults to 1e-4. + sensitivity (str, optional): Sensitivity method ['autograd', 'adjoint', 'interpolated_adjoint']. Defaults to 'autograd'. + solver_adjoint (Union[str, nn.Module, None], optional): ODE solver for the adjoint. Defaults to None. + atol_adjoint (float, optional): Defaults to 1e-6. + rtol_adjoint (float, optional): Defaults to 1e-6. + seminorm (bool, optional): Indicates whether the a seminorm should be used for error estimation during adjoint backsolves. Defaults to False. + integral_loss (Union[Callable, None]): Integral loss to optimize for. Defaults to None. + optimizable_parameters (Union[Iterable, Generator]): parameters to calculate sensitivies for. Defaults to (). + Notes: + Integral losses can be passed as generic function or `nn.Modules`. 
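+
+            A minimal usage sketch (illustrative only, assuming a plain `nn.Sequential` vector field):
+                vf = nn.Sequential(nn.Linear(2, 32), nn.Tanh(), nn.Linear(32, 2))
+                problem = ODEProblem(vf, solver='tsit5')
+                t_eval, sol = problem.odeint(torch.randn(8, 2), torch.linspace(0, 1, 10))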
+ """ + super().__init__() + # instantiate solver at initialization + if type(solver) == str: solver = str_to_solver(solver) + if solver_adjoint is None: + solver_adjoint = solver + else: solver_adjoint = str_to_solver(solver_adjoint) + + self.solver, self.interpolator, self.atol, self.rtol = solver, interpolator, atol, rtol + self.solver_adjoint, self.atol_adjoint, self.rtol_adjoint = solver_adjoint, atol_adjoint, rtol_adjoint + self.sensitivity, self.integral_loss = sensitivity, integral_loss + + # wrap vector field if `t, x` is not the call signature + vector_field = standardize_vf_call_signature(vector_field) + + self.vf, self.order, self.sensalg = vector_field, order, sensitivity + optimizable_params = tuple(optimizable_params) + + if len(tuple(self.vf.parameters())) > 0: + self.vf_params = torch.cat([p.contiguous().flatten() for p in self.vf.parameters()]) + + elif len(optimizable_params) > 0: + # use `optimizable_parameters` if f itself does not have a .parameters() iterable + # TODO: advanced logic to retain naming in case `state_dicts()` are passed + for k, p in enumerate(optimizable_params): self.vf.register_parameter(f'optimizable_parameter_{k}', p) + self.vf_params = torch.cat([p.contiguous().flatten() for p in optimizable_params]) + + else: + print("Your vector field does not have `nn.Parameters` to optimize.") + dummy_parameter = nn.Parameter(torch.zeros(1)) + self.vf.register_parameter('dummy_parameter', dummy_parameter) + self.vf_params = torch.cat([p.contiguous().flatten() for p in self.vf.parameters()]) + + def _autograd_func(self): + "create autograd functions for backward pass" + self.vf_params = torch.cat([p.contiguous().flatten() for p in self.vf.parameters()]) + if self.sensalg == 'adjoint': # alias .apply as direct call to preserve consistency of call signature + return _gather_odefunc_adjoint(self.vf, self.vf_params, self.solver, self.atol, self.rtol, self.interpolator, + self.solver_adjoint, self.atol_adjoint, self.rtol_adjoint, self.integral_loss, + problem_type='standard').apply + elif self.sensalg == 'interpolated_adjoint': + return _gather_odefunc_interp_adjoint(self.vf, self.vf_params, self.solver, self.atol, self.rtol, self.interpolator, + self.solver_adjoint, self.atol_adjoint, self.rtol_adjoint, self.integral_loss, + problem_type='standard').apply + + def odeint(self, x:Tensor, t_span:Tensor, save_at:Tensor=(), args={}): + "Returns Tuple(`t_eval`, `solution`)" + if self.sensalg == 'autograd': + return odeint(self.vf, x, t_span, self.solver, self.atol, self.rtol, interpolator=self.interpolator, + save_at=save_at, args=args) + else: + return self._autograd_func()(self.vf_params, x, t_span, save_at, args) + + def forward(self, x:Tensor, t_span:Tensor, save_at:Tensor=(), args={}): + "For safety redirects to intended method `odeint`" + return self.odeint(x, t_span, save_at, args) + + +class MultipleShootingProblem(ODEProblem): + def __init__(self, vector_field:Callable, solver:str, sensitivity:str='autograd', + maxiter:int=4, fine_steps:int=4, solver_adjoint:Union[str, nn.Module, None] = None, atol_adjoint:float=1e-6, + rtol_adjoint:float=1e-6, seminorm:bool=False, integral_loss:Union[Callable, None]=None): + """An ODE problem solved with parallel-in-time methods. + Args: + vector_field (Callable): the vector field, called with `vector_field(t, x)` for `vector_field(x)`. + In the second case, the Callable is automatically wrapped for consistency + solver (str): parallel-in-time solver. + sensitivity (str, optional): . Defaults to 'autograd'. 
+ solver_adjoint (Union[str, nn.Module, None], optional): . Defaults to None. + atol_adjoint (float, optional): . Defaults to 1e-6. + rtol_adjoint (float, optional): . Defaults to 1e-6. + seminorm (bool, optional): . Defaults to False. + integral_loss (Union[Callable, None], optional): . Defaults to None. + """ + super().__init__(vector_field=vector_field, solver=None, interpolator=None, order=1, + sensitivity=sensitivity, solver_adjoint=solver_adjoint, atol_adjoint=atol_adjoint, + rtol_adjoint=rtol_adjoint, seminorm=seminorm, integral_loss=integral_loss) + self.parallel_solver = solver + self.fine_steps, self.maxiter = fine_steps, maxiter + + def _autograd_func(self): + "create autograd functions for backward pass" + self.vf_params = torch.cat([p.contiguous().flatten() for p in self.vf.parameters()]) + if self.sensalg == 'adjoint': # alias .apply as direct call to preserve consistency of call signature + return _gather_odefunc_adjoint(self.vf, self.vf_params, self.solver, 0, 0, None, + self.solver_adjoint, self.atol_adjoint, self.rtol_adjoint, self.integral_loss, + 'multiple_shooting', self.fine_steps, self.maxiter).apply + elif self.sensalg == 'interpolated_adjoint': + return _gather_odefunc_interp_adjoint(self.vf, self.vf_params, self.solver, 0, 0, None, + self.solver_adjoint, self.atol_adjoint, self.rtol_adjoint, self.integral_loss, + 'multiple_shooting', self.fine_steps, self.maxiter).apply + + def odeint(self, x:Tensor, t_span:Tensor, B0:Tensor=None): + "Returns Tuple(`t_eval`, `solution`)" + if self.sensalg == 'autograd': + return odeint_mshooting(self.vf, x, t_span, self.parallel_solver, B0, self.fine_steps, self.maxiter) + else: + return self._autograd_func()(self.vf_params, x, t_span, B0) + + def forward(self, x:Tensor, t_span:Tensor, B0:Tensor=None): + "For safety redirects to intended method `odeint`" + return self.odeint(x, t_span, B0) + + +class SDEProblem(nn.Module): + def __init__(self): + "Extension of `ODEProblem` to SDE" + super().__init__() + raise NotImplementedError("Hopefully soon...") \ No newline at end of file diff --git a/torchdyn/core/utils.py b/torchdyn/core/utils.py new file mode 100644 index 0000000..460f3f3 --- /dev/null +++ b/torchdyn/core/utils.py @@ -0,0 +1,36 @@ +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from inspect import getfullargspec + +import torch +from torchdyn.core.defunc import DEFuncBase, DEFunc +import torch.nn as nn + +def standardize_vf_call_signature(vector_field, order=1, defunc_wrap=False): + "Ensures Callables or nn.Modules passed to `ODEProblems` and `NeuralODE` have consistent `__call__` signature (t, x)" + + if issubclass(type(vector_field), nn.Module): + if 't' not in getfullargspec(vector_field.forward).args: + print("Your vector field callable (nn.Module) should have both time `t` and state `x` as arguments, " + "we've wrapped it for you.") + vector_field = DEFuncBase(vector_field, has_time_arg=False) + else: + # argspec for lambda functions needs to be done on the function itself + if 't' not in getfullargspec(vector_field).args: + print("Your vector field callable (lambda) should have both time `t` and state `x` as arguments, " + "we've wrapped it for you.") + vector_field = DEFuncBase(vector_field, has_time_arg=False) + else: vector_field = DEFuncBase(vector_field, has_time_arg=True) + if defunc_wrap: return DEFunc(vector_field, order) + else: return vector_field + diff --git a/torchdyn/datasets/__init__.py b/torchdyn/datasets/__init__.py new file mode 100644 index 0000000..7526156 --- /dev/null +++ b/torchdyn/datasets/__init__.py @@ -0,0 +1,13 @@ +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .static_datasets import * diff --git a/torchdyn/datasets/static_datasets.py b/torchdyn/datasets/static_datasets.py new file mode 100644 index 0000000..ec67452 --- /dev/null +++ b/torchdyn/datasets/static_datasets.py @@ -0,0 +1,250 @@ +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Commonly used static datasets. 
Several can be used in both `density estimation` as well as classification +""" + +import math + +import numpy as np +import torch +from torch import sqrt, pow, cat, zeros, Tensor +from scipy.integrate import solve_ivp +from torchdyn import TTuple, Tuple +from sklearn.neighbors import KernelDensity +from torch.distributions import Normal + + +def randnsphere(dim:int, radius:float) -> Tensor: + """Uniform sampling on a sphere of `dim` and `radius` + + :param dim: dimension of the sphere + :type dim: int + :param radius: radius of the sphere + :type radius: float + """ + v = torch.randn(dim) + inv_len = radius / sqrt(pow(v, 2).sum()) + return v * inv_len + + +def generate_concentric_spheres(n_samples:int=100, noise:float=1e-4, dim:int=3, + inner_radius:float=0.5, outer_radius:int=1) -> TTuple: + """Creates a *concentric spheres* dataset of `n_samples` datasets points. + + :param n_samples: number of datasets points in the generated dataset + :type n_samples: int + :param noise: standard deviation of noise magnitude added to each datasets point + :type noise: float + :param dim: dimension of the spheres + :type dim: float + :param inner_radius: radius of the inner sphere + :type inner_radius: float + :param outer_radius: radius of the outer sphere + :type outer_radius: float + """ + X, y = zeros((n_samples, dim)), torch.zeros(n_samples) + y[:n_samples // 2] = 1 + samples = [] + for i in range(n_samples // 2): + samples.append(randnsphere(dim, inner_radius)[None, :]) + X[:n_samples // 2] = cat(samples) + X[:n_samples // 2] += zeros((n_samples // 2, dim)).normal_(0, std=noise) + samples = [] + for i in range(n_samples // 2): + samples.append(randnsphere(dim, outer_radius)[None, :]) + X[n_samples // 2:] = cat(samples) + X[n_samples // 2:] += zeros((n_samples // 2, dim)).normal_(0, std=noise) + return X, y + + +def generate_moons(n_samples:int=100, noise:float=1e-4, **kwargs) -> TTuple: + """Creates a *moons* dataset of `n_samples` datasets points. + + :param n_samples: number of datasets points in the generated dataset + :type n_samples: int + :param noise: standard deviation of noise magnitude added to each datasets point + :type noise: float + """ + n_samples_out = n_samples // 2 + n_samples_in = n_samples - n_samples_out + outer_circ_x = np.cos(np.linspace(0, np.pi, n_samples_out)) + outer_circ_y = np.sin(np.linspace(0, np.pi, n_samples_out)) + inner_circ_x = 1 - np.cos(np.linspace(0, np.pi, n_samples_in)) + inner_circ_y = 1 - np.sin(np.linspace(0, np.pi, n_samples_in)) - .5 + + X = np.vstack([np.append(outer_circ_x, inner_circ_x), + np.append(outer_circ_y, inner_circ_y)]).T + y = np.hstack([np.zeros(n_samples_out, dtype=np.intp), + np.ones(n_samples_in, dtype=np.intp)]) + + if noise is not None: + X += np.random.rand(n_samples, 1) * noise + + X, y = Tensor(X), Tensor(y).long() + return X, y + + +def generate_spirals(n_samples=100, noise=1e-4, **kwargs) -> TTuple: + """Creates a *spirals* dataset of `n_samples` datasets points. 
+ + :param n_samples: number of datasets points in the generated dataset + :type n_samples: int + :param noise: standard deviation of noise magnitude added to each datasets point + :type noise: float + """ + n = np.sqrt(np.random.rand(n_samples, 1)) * 780 * (2 * np.pi) / 360 + d1x = -np.cos(n) * n + np.random.rand(n_samples, 1) * noise + d1y = np.sin(n) * n + np.random.rand(n_samples, 1) * noise + X, y = (np.vstack((np.hstack((d1x, d1y)), np.hstack((-d1x, -d1y)))), + np.hstack((np.zeros(n_samples), np.ones(n_samples)))) + X, y = torch.Tensor(X), torch.Tensor(y).long() + return X, y + + +def generate_gaussians(n_samples=100, n_gaussians=7, dim=2, + radius=0.5, std_gaussians=0.1, noise=1e-3) -> TTuple: + """Creates `dim`-dimensional `n_gaussians` on a ring of radius `radius`. + + :param n_samples: number of datasets points in the generated dataset + :type n_samples: int + :param n_gaussians: number of gaussians distributions placed on the circle of radius `radius` + :type n_gaussians: int + :param dim: dimension of the dataset. The distributions are placed on the hyperplane (x1, x2, 0, 0..) if dim > 2 + :type dim: int + :param radius: radius of the circle on which the distributions lie + :type radius: int + :param std_gaussians: standard deviation of the gaussians. + :type std_gaussians: int + :param noise: standard deviation of noise magnitude added to each datasets point + :type noise: float + """ + X = torch.zeros(n_samples * n_gaussians, dim) ; y = torch.zeros(n_samples * n_gaussians).long() + angle = torch.zeros(1) + if dim > 2: loc = torch.cat([radius*torch.cos(angle), radius*torch.sin(angle), torch.zeros(dim-2)]) + else: loc = torch.cat([radius*torch.cos(angle), radius*torch.sin(angle)]) + dist = Normal(loc, scale=std_gaussians) + + for i in range(n_gaussians): + angle += 2*math.pi / n_gaussians + if dim > 2: dist.loc = torch.Tensor([radius*torch.cos(angle), torch.sin(angle), radius*torch.zeros(dim-2)]) + else: dist.loc = torch.Tensor([radius*torch.cos(angle), radius*torch.sin(angle)]) + X[i*n_samples:(i+1)*n_samples] = dist.sample(sample_shape=(n_samples,)) + torch.randn(dim)*noise + y[i*n_samples:(i+1)*n_samples] = i + return X, y + + +def generate_gaussians_spiral(n_samples=100, n_gaussians=7, n_gaussians_per_loop=4, dim=2, + radius_start=1, radius_end=0.2, std_gaussians_start=0.3, + std_gaussians_end=0.1, noise=1e-3) -> TTuple: + """Creates `dim`-dimensional `n_gaussians` on a spiral. + + :param n_samples: number of datasets points in the generated dataset + :type n_samples: int + :param n_gaussians: number of total gaussians distributions placed on the spirals + :type n_gaussians: int + :param n_gaussians_per_loop: number of gaussians distributions per loop of the spiral + :type n_gaussians_per_loop: int + :param dim: dimension of the dataset. The distributions are placed on the hyperplane (x1, x2, 0, 0..) if dim > 2 + :type dim: int + :param radius_start: starting radius of the spiral + :type radius_start: int + :param radius_end: end radius of the spiral + :type radius_end: int + :param std_gaussians_start: standard deviation of the gaussians at the start of the spiral. 
Linear interpolation (start, end, num_gaussians) + :type std_gaussians_start: int + :param std_gaussians_end: standard deviation of the gaussians at the end of the spiral + :type std_gaussians_end: int + :param noise: standard deviation of noise magnitude added to each datasets point + :type noise: float + """ + X = torch.zeros(n_samples * n_gaussians, dim) ; y = torch.zeros(n_samples * n_gaussians).long() + angle = torch.zeros(1) + radiuses = torch.linspace(radius_start, radius_end, n_gaussians) + std_devs = torch.linspace(std_gaussians_start, std_gaussians_end, n_gaussians) + + if dim > 2: loc = torch.cat([radiuses[0]*torch.cos(angle), radiuses[0]*torch.sin(angle), torch.zeros(dim-2)]) + else: loc = torch.cat([radiuses[0]*torch.cos(angle), radiuses[0]*torch.sin(angle)]) + dist = Normal(loc, scale=std_devs[0]) + + for i in range(n_gaussians): + angle += 2*math.pi / n_gaussians_per_loop + if dim > 2: dist.loc = torch.Tensor([radiuses[i]*torch.cos(angle), torch.sin(angle), radiuses[i]*torch.zeros(dim-2)]) + else: dist.loc = torch.Tensor([radiuses[i]*torch.cos(angle), radiuses[i]*torch.sin(angle)]) + dist.scale = std_devs[i] + + X[i*n_samples:(i+1)*n_samples] = dist.sample(sample_shape=(n_samples,)) + torch.randn(dim)*noise + y[i*n_samples:(i+1)*n_samples] = i + return X, y + + +def generate_diffeqml(n_samples=100, noise=1e-3) -> Tuple[Tensor, None]: + """Samples `n_samples` 2-dim points from the DiffEqML logo. + + :param n_samples: number of datasets points in the generated dataset + :type n_samples: int + :param noise: standard deviation of noise magnitude added to each datasets point + :type noise: float + """ + mu = 1 + X0 = [[0,2],[-1.6, -1.2],[1.6, -1.2],] + ti, tf = 0., 3.2 + t = np.linspace(ti,tf,500) + # define the ODE model + def odefunc(t,x): + dxdt = -x[1] + mu*x[0]*(1- x[0]**2 - x[1]**2) + dydt = x[0] + mu*x[1]*(1- x[0]**2 - x[1]**2) + return np.array([dxdt,dydt]).T + # integrate ODE + X = [] + for x0 in X0: + sol = solve_ivp(odefunc, [ti, tf], x0, method='LSODA', t_eval=t) + X.append(torch.tensor(sol.y.T).float()) + + theta = torch.linspace(0,2*np.pi, 1000) + X.append(torch.cat([2*torch.cos(theta)[:,None], 2*torch.sin(theta)[:,None]],1)) + X = torch.cat(X) + k = KernelDensity(kernel='gaussian',bandwidth=.01) + k.fit(X) + + X = torch.tensor(k.sample(n_samples) + noise*np.random.randn(n_samples, 2)).float() + return X, None + + +class ToyDataset: + """Handles the generation of classification toy datasets""" + def generate(self, n_samples:int, dataset_type:str, **kwargs) -> TTuple: + """Handles the generation of classification toy datasets + :param n_samples: number of datasets points in the generated dataset + :type n_samples: int + :param dataset_type: {'moons', 'spirals', 'spheres', 'gaussians', 'gaussians_spiral', diffeqml'} + :type dataset_type: str + :param dim: if 'spheres': dimension of the spheres + :type dim: float + :param inner_radius: if 'spheres': radius of the inner sphere + :type inner_radius: float + :param outer_radius: if 'spheres': radius of the outer sphere + :type outer_radius: float + """ + if dataset_type == 'moons': + return generate_moons(n_samples=n_samples, **kwargs) + elif dataset_type == 'spirals': + return generate_spirals(n_samples=n_samples, **kwargs) + elif dataset_type == 'spheres': + return generate_concentric_spheres(n_samples=n_samples, **kwargs) + elif dataset_type == 'gaussians': + return generate_gaussians(n_samples=n_samples, **kwargs) + elif dataset_type == 'gaussians_spiral': + return generate_gaussians_spiral(n_samples=n_samples, 
**kwargs)
+        elif dataset_type == 'diffeqml':
+            return generate_diffeqml(n_samples=n_samples, **kwargs)
diff --git a/torchdyn/models/README.md b/torchdyn/models/README.md
new file mode 100644
index 0000000..43774ee
--- /dev/null
+++ b/torchdyn/models/README.md
@@ -0,0 +1,22 @@
+### Goals of `torchdyn`
+Our aim with `torchdyn` is to provide a unified, flexible API to aid in the implementation of recent advances in continuous and implicit learning. Some models already implemented, either here under `torchdyn.models` or in the tutorials, are:
+
+* Neural Ordinary Differential Equations (Neural ODE) [[1](https://arxiv.org/abs/1806.07366)]
+* Galerkin Neural ODE [[2](https://arxiv.org/abs/2002.08071)]
+* Neural Stochastic Differential Equations (Neural SDE) [[3](https://arxiv.org/abs/1905.09883),[4](https://arxiv.org/abs/1906.02355)]
+* Graph Neural ODEs [[5](https://arxiv.org/abs/1911.07532)]
+* Hamiltonian Neural Networks [[6](https://arxiv.org/abs/1906.01563)]
+
+Recurrent or "hybrid" versions for sequences
+* ODE-RNN [[7](https://arxiv.org/abs/1907.03907)]
+
+Neural numerical methods
+* Hypersolvers [[12](https://arxiv.org/pdf/2007.09601.pdf)]
+
+Augmentation strategies to relieve neural differential equations of their expressivity limitations and reduce the computational burden of the numerical solver
+* ANODE (0-augmentation) [[8](https://arxiv.org/abs/1904.01681)]
+* Input-layer augmentation [[9](https://arxiv.org/abs/2002.08071)]
+* Higher-order augmentation [[10](https://arxiv.org/abs/2002.08071)]
+
+Various sensitivity algorithms / variants
+* Integral loss adjoint [[11](https://arxiv.org/abs/2003.08063)]
\ No newline at end of file
diff --git a/torchdyn/models/__init__.py b/torchdyn/models/__init__.py
new file mode 100644
index 0000000..b819e74
--- /dev/null
+++ b/torchdyn/models/__init__.py
@@ -0,0 +1,15 @@
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from .energy import *
+from torchdyn.core.neuralde import *
+from .cnf import *
\ No newline at end of file
diff --git a/torchdyn/models/cnf.py b/torchdyn/models/cnf.py
new file mode 100644
index 0000000..017855c
--- /dev/null
+++ b/torchdyn/models/cnf.py
@@ -0,0 +1,66 @@
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import torch
+import torch.nn as nn
+from typing import Union, Callable
+from torch.autograd import grad
+
+
+def autograd_trace(x_out, x_in, **kwargs):
+    """Standard brute-force means of obtaining trace of the Jacobian, O(d) calls to autograd"""
+    trJ = 0.
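+    # Brute-force trace: one autograd call per input dimension; grad of x_out[:, i] w.r.t. x_in
+    # recovers row i of the Jacobian for each batch element, and its i-th entry is added to trJ.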
+ for i in range(x_in.shape[1]): + trJ += grad(x_out[:, i].sum(), x_in, allow_unused=False, create_graph=True)[0][:, i] + return trJ + +def hutch_trace(x_out, x_in, noise=None, **kwargs): + """Hutchinson's trace Jacobian estimator, O(1) call to autograd""" + jvp = grad(x_out, x_in, noise, create_graph=True)[0] + trJ = torch.einsum('bi,bi->b', jvp, noise) + + return trJ + +REQUIRES_NOISE = [hutch_trace] + +class CNF(nn.Module): + def __init__(self, net:nn.Module, trace_estimator:Union[Callable, None]=None, noise_dist=None, order=1): + """Continuous Normalizing Flow + + :param net: function parametrizing the datasets vector field. + :type net: nn.Module + :param trace_estimator: specifies the strategy to otbain Jacobian traces. Options: (autograd_trace, hutch_trace) + :type trace_estimator: Callable + :param noise_dist: distribution of noise vectors sampled for stochastic trace estimators. Needs to have a `.sample` method. + :type noise_dist: torch.distributions.Distribution + :param order: specifies parameters of the Neural DE. + :type order: int + """ + super().__init__() + self.net, self.order = net, order # order at the CNF level will be merged with DEFunc + self.trace_estimator = trace_estimator if trace_estimator is not None else autograd_trace; + self.noise_dist, self.noise = noise_dist, None + if self.trace_estimator in REQUIRES_NOISE: + assert self.noise_dist is not None, 'This type of trace estimator requires specification of a noise distribution' + + def forward(self, x): + with torch.set_grad_enabled(True): + # first dimension is reserved to divergence propagation + x_in = x[:,1:].requires_grad_(True) + + # the neural network will handle the datasets-dynamics here + if self.order > 1: x_out = self.higher_order(x_in) + else: x_out = self.net(x_in) + + trJ = self.trace_estimator(x_out, x_in, noise=self.noise) + return torch.cat([-trJ[:, None], x_out], 1) + 0*x # `+ 0*x` has the only purpose of connecting x[:, 0] to autograd graph + diff --git a/torchdyn/models/energy.py b/torchdyn/models/energy.py new file mode 100644 index 0000000..0cedd20 --- /dev/null +++ b/torchdyn/models/energy.py @@ -0,0 +1,112 @@ +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from functools import partial + +import torch +from torch import Tensor +import torch.nn as nn +from torch.autograd import grad +from torch.autograd.functional import hessian, jacobian + + +class ConservativeLinearSNF(nn.Module): + def __init__(self, energy, J): + """Stable Neural Flows: https://arxiv.org/abs/2003.08063 + A generalization of Hamiltonian Neural Networks and other energy-based parametrization of Neural ODEs + Conservative version with energy preservation. Input assumed to be of dimensions `batch, dim` + + Args: + energy: function parametrizing the energy. 
+ J: network parametrizing the skew-symmetric interconnection matrix + """ + super().__init__() + self.energy = energy + self.J = J + + def forward(self, x: Tensor): + with torch.set_grad_enabled(True): + self.n = x.shape[1] // 2 + x = x.requires_grad_(True) + dHdx = torch.autograd.grad(self.H(x).sum(), x, create_graph=True)[0] + dHdx = torch.einsum('ijk, ij -> ik', self._skew(x), dHdx) + return dHdx + + def _generate_skew(self, x): + M = self.J(x).reshape(-1, *x.shape[1:]) + return (M - M.transpose(0, 2, 1)) / 2 + +class GNF(nn.Module): + def __init__(self, energy:nn.Module): + """Gradient Neural Flows version of SNFs: https://arxiv.org/abs/2003.08063 + Args: + energy (nn.Module): function parametrizing the energy. + """ + super().__init__() + self.energy = energy + + def forward(self, x): + with torch.set_grad_enabled(True): + x = x.requires_grad_(True) + eps = self.energy(x).sum() + out = -torch.autograd.grad(eps, x, allow_unused=False, create_graph=True)[0] + return out + + +class HNN(nn.Module): + def __init__(self, net:nn.Module): + """Hamiltonian Neural ODE + + Args: + net (nn.Module): function parametrizing the vector field. + """ + super().__init__() + self.net = net + + def forward(self, x): + with torch.set_grad_enabled(True): + n = x.shape[1] // 2 + x = x.requires_grad_(True) + gradH = grad(self.net(x).sum(), x, create_graph=True)[0] + return torch.cat([gradH[:, n:], -gradH[:, :n]], 1).to(x) + + +class LNN(nn.Module): + def __init__(self, net): + """Lagrangian Neural Network. + + Args: + net (nn.Module) + Notes: + LNNs are currently quite slow. Improvements will be made whenever `functorch` is either merged upstream or included + as a dependency. + """ + super().__init__() + self.net = net + + def forward(self, x): + self.n = n = x.shape[1]//2 + bs = x.shape[0] + x = x.requires_grad_(True) + qqd_batch = tuple(x[i, :] for i in range(bs)) + jac = tuple(map(partial(jacobian, self._lagrangian, create_graph=True), qqd_batch)) + hess = tuple(map(partial(hessian, self._lagrangian, create_graph=True), qqd_batch)) + qdd_batch = tuple(map(self._qdd, zip(jac, hess, qqd_batch))) + qd, qdd = x[:, n:], torch.cat([qdd[None] for qdd in qdd_batch]) + return torch.cat([qd, qdd], 1) + + def _lagrangian(self, qqd): + return self.net(qqd).sum() + + def _qdd(self, inp): + n = self.n ; jac, hess, qqd = inp + return hess[n:, n:].pinverse()@(jac[:n] - hess[n:, :n]@qqd[n:]) diff --git a/torchdyn/models/hybrid.py b/torchdyn/models/hybrid.py new file mode 100644 index 0000000..b95958a --- /dev/null +++ b/torchdyn/models/hybrid.py @@ -0,0 +1,126 @@ +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"Experimental API for hybrid Neural DEs and continuous models applied to sequences -> [ODE-RNN, Neural CDE]" + +import math +import torch +import torch.nn as nn +from torch.distributions import Normal, kl_divergence +import pytorch_lightning as pl +import torchsde + +from torchdyn.models import LSDEFunc + + +class HybridNeuralDE(nn.Module): + def __init__(self, flow, jump, out, last_output=True, reverse=False): + """ODE-RNN / LSTM / GRU""" + super().__init__() + self.flow, self.jump, self.out = flow, jump, out + self.reverse, self.last_output = reverse, last_output + + # determine type of `jump` func + # jump can be of two types: + # either take hidden and element of sequence (e.g RNNCell) + # or h, x_t and c (LSTMCell). Custom implementation assumes call + # signature of type (x_t, h) and .hidden_size property + if type(jump) == nn.modules.rnn.LSTMCell: + self.jump_func = self._jump_latent_cell + else: + self.jump_func = self._jump_latent + + def forward(self, x): + h = c = self._init_latent(x) + Y = torch.zeros(x.shape[0], *h.shape).to(x) + if self.reverse: x_t = x_t.flip(0) + for t, x_t in enumerate(x): + h, c = self.jump_func(x_t, h, c) + h = self.flow(h) + Y[t] = h + Y = self.out(Y) + return Y[-1] if self.last_output else Y + + def _init_latent(self, x): + x = x[0] + return torch.zeros(x.shape[0], self.jump.hidden_size).to(x.device) + + def _jump_latent(self, *args): + x_t, h, c = args[:3] + return self.jump(x_t, h), c + + def _jump_latent_cell(self, *args): + x_t, h, c = args[:3] + return self.jump(x_t, (h, c)) + + +class LatentNeuralSDE(NeuralSDE, pl.LightningModule): # pragma: no cover + def __init__(self, post_drift, diffusion, prior_drift, sigma, theta, mu, options, + noise_type, order, sensitivity, s_span, solver, atol, rtol, intloss): + """Latent Neural SDEs.""" + + super().__init__(drift_func=post_drift, diffusion_func=diffusion, noise_type=noise_type, + order=order, sensitivity=sensitivity, s_span=s_span, solver=solver, + atol=atol, rtol=rtol, intloss=intloss) + + self.defunc = LSDEFunc(f=post_drift, g=diffusion, h=prior_drift) + self.defunc.noise_type, self.defunc.sde_type = noise_type, 'ito' + self.options = options + + # p(y0). + logvar = math.log(sigma ** 2. / (2. * theta)) + self.py0_mean = nn.Parameter(torch.tensor([[mu]]), requires_grad=False) + self.py0_logvar = nn.Parameter(torch.tensor([[logvar]]), requires_grad=False) + + # q(y0). + self.qy0_mean = nn.Parameter(torch.tensor([[mu]]), requires_grad=True) + self.qy0_logvar = nn.Parameter(torch.tensor([[logvar]]), requires_grad=True) + + def forward(self, eps: torch.Tensor, s_span=None): + eps = eps.to(self.qy0_std) + x0 = self.qy0_mean + eps * self.qy0_std + + qy0 = Normal(loc=self.qy0_mean, scale=self.qy0_std) + py0 = Normal(loc=self.py0_mean, scale=self.py0_std) + logqp0 = kl_divergence(qy0, py0).sum(1).mean(0) # KL(time=0). + + if s_span is not None: + s_span_ext = s_span + else: + s_span_ext = self.s_span.cpu() + + zs, logqp = torchsde.sdeint(sde=self.defunc, x0=x0, s_span=s_span_ext, + rtol=self.rtol, atol=self.atol, logqp=True, options=self.options, + adaptive=self.adaptive, method=self.solver) + + logqp = logqp.sum(0).mean(0) + log_ratio = logqp0 + logqp # KL(time=0) + KL(path). 
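+        # zs are the sampled posterior latent paths; log_ratio is the total KL penalty
+        # (initial-state KL plus pathwise KL), typically used as the KL term of an ELBO.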
+ + return zs, log_ratio + + def sample_p(self, vis_span, n_sim, eps=None, bm=None, dt=0.01): + eps = torch.randn(n_sim, 1).to(self.py0_mean).to(self.device) if eps is None else eps + y0 = self.py0_mean + eps.to(self.device) * self.py0_std + return torchsde.sdeint(self.defunc, y0, vis_span, bm=bm, method='srk', dt=dt, names={'drift': 'h'}) + + def sample_q(self, vis_span, n_sim, eps=None, bm=None, dt=0.01): + eps = torch.randn(n_sim, 1).to(self.qy0_mean) if eps is None else eps + y0 = self.qy0_mean + eps.to(self.device) * self.qy0_std + return torchsde.sdeint(self.defunc, y0, vis_span, bm=bm, method='srk', dt=dt) + + @property + def py0_std(self): + return torch.exp(.5 * self.py0_logvar) + + @property + def qy0_std(self): + return torch.exp(.5 * self.qy0_logvar) diff --git a/torchdyn/nn/__init__.py b/torchdyn/nn/__init__.py new file mode 100644 index 0000000..7800a9e --- /dev/null +++ b/torchdyn/nn/__init__.py @@ -0,0 +1,19 @@ +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from torchdyn.nn.galerkin import GalLayer, GalLinear, GalConv2d, Fourier, Polynomial, Chebychev, VanillaRBF, MultiquadRBF, GaussianRBF +from torchdyn.nn.node_layers import Augmenter, DepthCat, DataControl + + +__all__ = ['Augmenter', 'DepthCat', 'DataControl', + 'GalLinear', 'GalConv2d', 'VanillaRBF', 'MultiquadRBF', 'GaussianRBF', + 'Fourier', 'Polynomial', 'Chebychev'] \ No newline at end of file diff --git a/torchdyn/nn/galerkin.py b/torchdyn/nn/galerkin.py new file mode 100644 index 0000000..217d12b --- /dev/null +++ b/torchdyn/nn/galerkin.py @@ -0,0 +1,273 @@ +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch +import torch.nn as nn +import math + + +class GaussianRBF(nn.Module): + """Eigenbasis expansion using gaussian radial basis functions. $phi(r) = e^{-(\eps r)^2}$ with $r := || x - x0 ||_2$" + :param deg: degree of the eigenbasis expansion + :type deg: int + :param adaptive: whether to adjust `centers` and `eps_scales` during training. + :type adaptive: bool + :param eps_scales: scaling in the rbf formula ($\eps$) + :type eps_scales: int + :param centers: centers of the radial basis functions (one per degree). Same center across all degrees. 
x0 in the radius formulas + :type centers: int + """ + + def __init__(self, deg, adaptive=False, eps_scales=2, centers=0): + super().__init__() + self.deg, self.n_eig = deg, 1 + if adaptive: + self.centers = torch.nn.Parameter(centers * torch.ones(deg + 1)) + self.eps_scales = torch.nn.Parameter(eps_scales * torch.ones((deg + 1))) + else: + self.centers = 0 + self.eps_scales = 2 + + def forward(self, n_range, s): + n_range_scaled = (n_range - self.centers) / self.eps_scales + r = torch.norm(s - self.centers, p=2) + basis = [math.e ** (-(r * n_range_scaled) ** 2)] + return basis + + +class VanillaRBF(nn.Module): + """Eigenbasis expansion using vanilla radial basis functions." + :param deg: degree of the eigenbasis expansion + :type deg: int + :param adaptive: whether to adjust `centers` and `eps_scales` during training. + :type adaptive: bool + :param eps_scales: scaling in the rbf formula ($\eps$) + :type eps_scales: int + :param centers: centers of the radial basis functions (one per degree). Same center across all degrees. x0 in the radius formulas + :type centers: int + """ + + def __init__(self, deg, adaptive=False, eps_scales=2, centers=0): + super().__init__() + self.deg, self.n_eig = deg, 1 + if adaptive: + self.centers = torch.nn.Parameter(centers * torch.ones(deg + 1)) + self.eps_scales = torch.nn.Parameter(eps_scales * torch.ones((deg + 1))) + else: + self.centers = 0 + self.eps_scales = 2 + + def forward(self, n_range, s): + n_range_scaled = n_range / self.eps_scales + r = torch.norm(s - self.centers, p=2) + basis = [r * n_range_scaled] + return basis + + +class MultiquadRBF(nn.Module): + """Eigenbasis expansion using multiquadratic radial basis functions." + :param deg: degree of the eigenbasis expansion + :type deg: int + :param adaptive: whether to adjust `centers` and `eps_scales` during training. + :type adaptive: bool + :param eps_scales: scaling in the rbf formula ($\eps$) + :type eps_scales: int + :param centers: centers of the radial basis functions (one per degree). Same center across all degrees. x0 in the radius formulas + :type centers: int + """ + + def __init__(self, deg, adaptive=False, eps_scales=2, centers=0): + super().__init__() + self.deg, self.n_eig = deg, 1 + if adaptive: + self.centers = torch.nn.Parameter(centers * torch.ones(deg + 1)) + self.eps_scales = torch.nn.Parameter(eps_scales * torch.ones((deg + 1))) + else: + self.centers = 0 + self.eps_scales = 2 + + def forward(self, n_range, s): + n_range_scaled = n_range / self.eps_scales + r = torch.norm(s - self.centers, p=2) + basis = [1 + torch.sqrt(1 + (r * n_range_scaled) ** 2)] + return basis + + +class Fourier(nn.Module): + """Eigenbasis expansion using fourier functions." + :param deg: degree of the eigenbasis expansion + :type deg: int + :param adaptive: does nothing (for now) + :type adaptive: bool + """ + + def __init__(self, deg, adaptive=False): + super().__init__() + self.deg, self.n_eig = deg, 2 + + def forward(self, n_range, s): + s_n_range = s * n_range + basis = [torch.cos(s_n_range), torch.sin(s_n_range)] + return basis + + +class Polynomial(nn.Module): + """Eigenbasis expansion using polynomials." 
+ :param deg: degree of the eigenbasis expansion + :type deg: int + :param adaptive: does nothing (for now) + :type adaptive: bool + """ + + def __init__(self, deg, adaptive=False): + super().__init__() + self.deg, self.n_eig = deg, 1 + + def forward(self, n_range, s): + basis = [s ** n_range] + return basis + + +class Chebychev(nn.Module): + """Eigenbasis expansion using chebychev polynomials." + :param deg: degree of the eigenbasis expansion + :type deg: int + :param adaptive: does nothing (for now) + :type adaptive: bool + """ + + def __init__(self, deg, adaptive=False): + super().__init__() + self.deg, self.n_eig = deg, 1 + + def forward(self, n_range, s): + max_order = n_range[-1].int().item() + basis = [1] + # Based on numpy's Cheb code + if max_order > 0: + s2 = 2 * s + basis += [s.item()] + for i in range(2, max_order): + basis += [basis[-1] * s2 - basis[-2]] + return [torch.tensor(basis).to(n_range)] + + +class GalLayer(nn.Module): + """Galerkin layer template. Introduced in https://arxiv.org/abs/2002.08071""" + def __init__(self, bias=True, expfunc=Fourier(5), dilation=True, shift=True): + super().__init__() + self.dilation = torch.ones(1) if not dilation else nn.Parameter(data=torch.ones(1), requires_grad=True) + self.shift = torch.zeros(1) if not shift else nn.Parameter(data=torch.zeros(1), requires_grad=True) + self.expfunc = expfunc + self.n_eig = n_eig = self.expfunc.n_eig + self.deg = deg = self.expfunc.deg + + def reset_parameters(self): + torch.nn.init.zeros_(self.coeffs) + + def calculate_weights(self, t): + "Expands `t` following the chosen eigenbasis" + n_range = torch.linspace(0, self.deg, self.deg).to(self.coeffs.device) + basis = self.expfunc(n_range, t*self.dilation.to(self.coeffs.device) + self.shift.to(self.coeffs.device)) + B = [] + for i in range(self.n_eig): + Bin = torch.eye(self.deg).to(self.coeffs.device) + Bin[range(self.deg), range(self.deg)] = basis[i] + B.append(Bin) + B = torch.cat(B, 1).to(self.coeffs.device) + coeffs = torch.cat([self.coeffs[:,:,i] for i in range(self.n_eig)],1).transpose(0,1).to(self.coeffs.device) + X = torch.matmul(B, coeffs) + return X.sum(0) + + +class GalLinear(GalLayer): + """Linear Galerkin layer for depth--variant neural differential equations. Introduced in https://arxiv.org/abs/2002.08071 + :param in_features: input dimensions + :type in_features: int + :param out_features: output dimensions + :type out_features: int + :param bias: include bias parameter vector in the layer computation + :type bias: bool + :param expfunc: {'Fourier', 'Polynomial', 'Chebychev', 'VanillaRBF', 'MultiquadRBF', 'GaussianRBF'}. Choice of eigenfunction expansion. + :type expfunc: str + :param dilation: whether to optimize for `dilation` parameter. Allows the GalLayer to dilate the eigenfunction period. + :type dilation: bool + :param shift: whether to optimize for `shift` parameter. Allows the GalLayer to shift the eigenfunction period. 
+ :type shift: bool + """ + def __init__(self, in_features, out_features, bias=True, expfunc=Fourier(5), dilation=True, shift=True): + super().__init__(bias, expfunc, dilation, shift) + + self.in_features, self.out_features = in_features, out_features + self.weight = torch.Tensor(out_features, in_features) + if bias: + self.bias = torch.Tensor(out_features) + else: + self.register_parameter('bias', None) + self.coeffs = torch.nn.Parameter(torch.Tensor((in_features+1)*out_features, self.deg, self.n_eig)) + self.reset_parameters() + + def forward(self, input): + # For the moment, GalLayers rely on DepthCat to access the `t` variable. + t = input[-1,-1] + input = input[:,:-1] + w = self.calculate_weights(t) + self.weight = w[0:self.in_features*self.out_features].reshape(self.out_features, self.in_features) + self.bias = w[self.in_features*self.out_features:(self.in_features+1)*self.out_features].reshape(self.out_features) + return torch.nn.functional.linear(input, self.weight, self.bias) + + +class GalConv2d(GalLayer): + """2D convolutional Galerkin layer for depth--variant neural differential equations. Introduced in https://arxiv.org/abs/2002.08071 + :param in_channels: number of channels in the input image + :type in_channels: int + :param out_channels: number of channels produced by the convolution + :type out_channels: int + :param kernel_size: size of the convolving kernel + :type kernel_size: int + :param stride: stride of the convolution. Default: 1 + :type stride: int + :param padding: zero-padding added to both sides of the input. Default: 0 + :type padding: int + :param bias: include bias parameter vector in the layer computation + :type bias: bool + :param expfunc: {'Fourier', 'Polynomial', 'Chebychev', 'VanillaRBF', 'MultiquadRBF', 'GaussianRBF'}. Choice of eigenfunction expansion. + :type expfunc: str + :param dilation: whether to optimize for `dilation` parameter. Allows the GalLayer to dilate the eigenfunction period. + :type dilation: bool + :param shift: whether to optimize for `shift` parameter. Allows the GalLayer to shift the eigenfunction period. + :type shift: bool + """ + __constants__ = ['bias', 'in_channels', 'out_channels', 'kernel_size', 'stride', 'padding', 'deg'] + def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=0, bias=True, + expfunc=Fourier(5), dilation=True, shift=True): + super().__init__(bias, expfunc, dilation, shift) + + self.ic, self.oc, self.ks = in_channels, out_channels, kernel_size + self.pad, self.stride = padding, stride + + self.weight = torch.Tensor(out_channels, in_channels, kernel_size, kernel_size) + if bias: + self.bias = torch.Tensor(out_channels) + else: + self.register_parameter('bias', None) + self.coeffs = torch.nn.Parameter(torch.Tensor(((out_channels)*in_channels*(kernel_size**2)+out_channels), self.deg, 2)) + self.reset_parameters() + + def forward(self, input): + t = input[-1,-1,0,0] + input = input[:,:-1] + w = self.calculate_weights(t) + n = self.oc*self.ic*self.ks*self.ks + self.weight = w[0:n].reshape(self.oc, self.ic, self.ks, self.ks) + self.bias = w[n:].reshape(self.oc) + return torch.nn.functional.conv2d(input, self.weight, self.bias, stride=self.stride, padding=self.pad) diff --git a/torchdyn/nn/node_layers.py b/torchdyn/nn/node_layers.py new file mode 100644 index 0000000..fbf479b --- /dev/null +++ b/torchdyn/nn/node_layers.py @@ -0,0 +1,82 @@ +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch +import torch.nn as nn + +class Augmenter(nn.Module): + """Augmentation class. Can handle several types of augmentation strategies for Neural DEs. + + :param augment_dims: number of augmented dimensions to initialize + :type augment_dims: int + :param augment_idx: index of dimension to augment + :type augment_idx: int + :param augment_func: nn.Module applied to the input datasets of dimension `d` to determine the augmented initial condition of dimension `d + a`. + `a` is defined implicitly in `augment_func` e.g. augment_func=nn.Linear(2, 5) augments a 2 dimensional input with 3 additional dimensions. + :type augment_func: nn.Module + :param order: whether to augment before datasets [augmentation, x] or after [x, augmentation] along dimension `augment_idx`. Options: ('first', 'last') + :type order: str + """ + def __init__(self, augment_idx:int=1, augment_dims:int=5, augment_func=None, order='first'): + super().__init__() + self.augment_dims, self.augment_idx, self.augment_func = augment_dims, augment_idx, augment_func + self.order = order + + def forward(self, x: torch.Tensor): + if not self.augment_func: + new_dims = list(x.shape) + new_dims[self.augment_idx] = self.augment_dims + + # if-else check for augmentation order + if self.order == 'first': + x = torch.cat([torch.zeros(new_dims).to(x), x], + self.augment_idx) + else: + x = torch.cat([x, torch.zeros(new_dims).to(x)], + self.augment_idx) + else: + # if-else check for augmentation order + if self.order == 'first': + x = torch.cat([self.augment_func(x).to(x), x], + self.augment_idx) + else: + x = torch.cat([x, self.augment_func(x).to(x)], + self.augment_idx) + return x + + +class DepthCat(nn.Module): + """Depth variable `t` concatenation module. Allows for easy concatenation of `t` each call of the numerical solver, at specified nn of the DEFunc. + + :param idx_cat: index of the datasets dimension to concatenate `t` to. + :type idx_cat: int + """ + def __init__(self, idx_cat=1): + super().__init__() + self.idx_cat, self.t = idx_cat, None + + def forward(self, x): + t_shape = list(x.shape) + t_shape[self.idx_cat] = 1 + t = self.t * torch.ones(t_shape).to(x) + return torch.cat([x, t], self.idx_cat).to(x) + + +class DataControl(nn.Module): + """Data-control module. Allows for datasets-control inputs at arbitrary points of the DEFunc + """ + def __init__(self): + super().__init__() + self._control = None + + def forward(self, x): + return torch.cat([x, self._control], 1).to(x) \ No newline at end of file diff --git a/torchdyn/numerics/__init__.py b/torchdyn/numerics/__init__.py new file mode 100644 index 0000000..492edfc --- /dev/null +++ b/torchdyn/numerics/__init__.py @@ -0,0 +1,20 @@ +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
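A short sketch of how the layers above compose into a depth-variant vector field (illustrative; the hidden width is arbitrary, and `DepthCat.t` is assigned by hand here although during integration it is set by the surrounding machinery at every solver step):

    import torch
    import torch.nn as nn
    from torchdyn.nn import Augmenter, DepthCat, GalLinear, Fourier

    # DepthCat appends the depth variable t as an extra feature; GalLinear strips it
    # again and uses it to evaluate its Galerkin-expanded weights
    field = nn.Sequential(
        DepthCat(1),
        GalLinear(2, 32, expfunc=Fourier(5)),
        nn.Tanh(),
        nn.Linear(32, 2),
    )

    x = torch.randn(8, 2)
    field[0].t = torch.tensor(0.5)   # normally set by the integration loop
    print(field(x).shape)            # torch.Size([8, 2])

    # Augmenter pads the state with zero-initialized extra dimensions before integration
    aug = Augmenter(augment_idx=1, augment_dims=3)
    print(aug(x).shape)              # torch.Size([8, 5])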
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from torchdyn.numerics.solvers.ode import Euler, RungeKutta4, Tsitouras45, DormandPrince45, AsynchronousLeapfrog, MSZero, MSBackward +from torchdyn.numerics.solvers.hyper import HyperEuler +from torchdyn.numerics.odeint import odeint, odeint_symplectic, odeint_mshooting, odeint_hybrid +from torchdyn.numerics.systems import VanDerPol, Lorenz + +__all__ = ['odeint', 'odeint_symplectic', 'Euler', 'RungeKutta4', 'DormandPrince45', 'Tsitouras45', + 'AsynchronousLeapfrog', 'HyperEuler', 'MSZero', 'MSBackward', 'Lorenz', 'VanDerPol'] + \ No newline at end of file diff --git a/torchdyn/numerics/interpolators.py b/torchdyn/numerics/interpolators.py new file mode 100644 index 0000000..ea3c7b4 --- /dev/null +++ b/torchdyn/numerics/interpolators.py @@ -0,0 +1,73 @@ +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Contains several Interpolator classes""" + +import torch +from torchdyn.numerics.solvers._constants import construct_4th +class Interpolator: + def __init__(self, order): + self.order = order + + def sync_device_dtype(self, x, t_span): + "Ensures `x`, `t_span`, `tableau` and other interpolator tensors are on the same device with compatible dtypes" + if self.bmid is not None: self.bmid = self.bmid.to(x) + return x, t_span + + def fit(self, f0, f1, x0, x1, t, dt, **kwargs): + pass + + def evaluate(self, coefs, t0, t1, t): + "Evaluates a generic interpolant given coefs between [t0, t1]." + theta = (t - t0) / (t1 - t0) + result = coefs[0] + theta * coefs[1] + theta_power = theta + for coef in coefs[2:]: + theta_power = theta_power * theta + result += theta_power * coef + return result + + +class Linear(Interpolator): + def __init__(self): + raise NotImplementedError + + +class ThirdHermite(Interpolator): + def __init__(self): + super().__init__(order=3) + raise NotImplementedError + + +class FourthOrder(Interpolator): + def __init__(self, dtype): + """4th order interpolation scheme.""" + super().__init__(order=4) + self.bmid = construct_4th(dtype) + + def fit(self, dt, f0, f1, x0, x1, x_mid, **kwargs): + c1 = 2 * dt * (f1 - f0) - 8 * (x1 + x0) + 16 * x_mid + c2 = dt * (5 * f0 - 3 * f1) + 18 * x0 + 14 * x1 - 32 * x_mid + c3 = dt * (f1 - 4 * f0) - 11 * x0 - 5 * x1 + 16 * x_mid + c4 = dt * f0 + c5 = x0 + return [c5, c4, c3, c2, c1] + + + +INTERP_DICT = {'4th': FourthOrder} + + +def str_to_interp(solver_name, dtype=torch.float32): + "Transforms string specifying desired interpolation scheme into an instance of the Interpolator class." 
+ interpolator = INTERP_DICT[solver_name] + return interpolator(dtype) \ No newline at end of file diff --git a/torchdyn/numerics/odeint.py b/torchdyn/numerics/odeint.py new file mode 100644 index 0000000..7f10158 --- /dev/null +++ b/torchdyn/numerics/odeint.py @@ -0,0 +1,507 @@ +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" + Functional API of ODE integration routines, with specialized functions for different options + `odeint` and `odeint_mshooting` prepare and redirect to more specialized routines, detected automatically. +""" +from typing import List, Tuple, Union, Callable, Dict, Iterable +from warnings import warn + +import torch +from torch import Tensor +import torch.nn as nn + +from torchdyn.numerics.solvers.ode import AsynchronousLeapfrog, Tsitouras45, str_to_solver, str_to_ms_solver +from torchdyn.numerics.interpolators import str_to_interp +from torchdyn.numerics.utils import hairer_norm, init_step, adapt_step, EventState + + +def odeint(f:Callable, x:Tensor, t_span:Union[List, Tensor], solver:Union[str, nn.Module], atol:float=1e-3, rtol:float=1e-3, + t_stops:Union[List, Tensor, None]=None, verbose:bool=False, interpolator:Union[str, Callable, None]=None, return_all_eval:bool=False, + save_at:Union[Iterable, Tensor]=(), args:Dict={}, seminorm:Tuple[bool, Union[int, None]]=(False, None)) -> Tuple[Tensor, Tensor]: + """Solve an initial value problem (IVP) determined by function `f` and initial condition `x`. + + Functional `odeint` API of the `torchdyn` package. + + Args: + f (Callable): + x (Tensor): + t_span (Union[List, Tensor]): + solver (Union[str, nn.Module]): + atol (float, optional): Defaults to 1e-3. + rtol (float, optional): Defaults to 1e-3. + t_stops (Union[List, Tensor, None], optional): Defaults to None. + verbose (bool, optional): Defaults to False. + interpolator (bool, optional): Defaults to False. + return_all_eval (bool, optional): Defaults to False. + save_at (Union[List, Tensor], optional): Defaults to t_span + args (Dict): Arbitrary parameters used in step + seminorm (Tuple[bool, Union[int, None]], optional): Whether to use seminorms in local error computation. + + Returns: + Tuple[Tensor, Tensor]: returns a Tuple (t_eval, solution). 
+ """ + if t_span[1] < t_span[0]: # time is reversed + if verbose: warn("You are integrating on a reversed time domain, adjusting the vector field automatically") + f_ = lambda t, x: -f(-t, x) + t_span = -t_span + else: f_ = f + + if type(t_span) == list: t_span = torch.cat(t_span) + # instantiate the solver in case the user has specified preference via a `str` and ensure compatibility of device ~ dtype + if type(solver) == str: + solver = str_to_solver(solver, x.dtype) + x, t_span = solver.sync_device_dtype(x, t_span) + stepping_class = solver.stepping_class + + # instantiate the interpolator similar to the solver steps above + if isinstance(solver, Tsitouras45): + if verbose: warn("Running interpolation not yet implemented for `tsit5`") + interpolator = None + + if type(interpolator) == str: + interpolator = str_to_interp(interpolator, x.dtype) + x, t_span = interpolator.sync_device_dtype(x, t_span) + + # access parallel integration routines with different t_spans for each sample in `x`. + if len(t_span.shape) > 1: + raise NotImplementedError("Parallel routines not implemented yet, check experimental versions of `torchdyn`") + # odeint routine with a single t_span for all samples + elif len(t_span.shape) == 1: + if stepping_class == 'fixed': + if atol != odeint.__defaults__[0] or rtol != odeint.__defaults__[1]: + warn("Setting tolerances has no effect on fixed-step methods") + # instantiate save_at tensor + return _fixed_odeint(f_, x, t_span, solver, save_at=save_at, args=args) + elif stepping_class == 'adaptive': + t = t_span[0] + k1 = f_(t, x) + dt = init_step(f, k1, x, t, solver.order, atol, rtol) + if len(save_at) > 0: warn("Setting save_at has no effect on adaptive-step methods") + return _adaptive_odeint(f_, k1, x, dt, t_span, solver, atol, rtol, args, interpolator, return_all_eval, seminorm) + + +# TODO (qol) state augmentation for symplectic methods +def odeint_symplectic(f:Callable, x:Tensor, t_span:Union[List, Tensor], solver:Union[str, nn.Module], atol:float=1e-3, rtol:float=1e-3, + verbose:bool=False, return_all_eval:bool=False, save_at:Union[List, Tensor]=()): + """Solve an initial value problem (IVP) determined by function `f` and initial condition `x` using symplectic methods. + + Designed to be a subroutine of `odeint` (i.e. will eventually automatically be dispatched to here, much like `_adaptive_odeint`) + + Args: + f (Callable): + x (Tensor): + t_span (Union[List, Tensor]): + solver (Union[str, nn.Module]): + atol (float, optional): Defaults to 1e-3. + rtol (float, optional): Defaults to 1e-3. + verbose (bool, optional): Defaults to False. + return_all_eval (bool, optional): Defaults to False. 
+ save_at (Union[List, Tensor], optional): Defaults to t_span + """ + if t_span[1] < t_span[0]: # time is reversed + if verbose: warn("You are integrating on a reversed time domain, adjusting the vector field automatically") + f_ = lambda t, x: -f(-t, x) + t_span = -t_span + else: f_ = f + if type(t_span) == list: t_span = torch.cat(t_span) + + # instantiate the solver in case the user has specified preference via a `str` and ensure compatibility of device ~ dtype + if type(solver) == str: + solver = str_to_solver(solver, x.dtype) + x, t_span = solver.sync_device_dtype(x, t_span) + stepping_class = solver.stepping_class + + # additional bookkeeping for symplectic solvers + if not hasattr(f, 'order'): + raise RuntimeError('The system order should be specified as an attribute `order` of `vector_field`') + if isinstance(solver, AsynchronousLeapfrog) and f.order == 2: + raise RuntimeError('ALF solver should be given a vector field specified as a first-order symplectic system: v = f(t, x)') + solver.x_shape = x.shape[-1] // 2 + + # access parallel integration routines with different t_spans for each sample in `x`. + if len(t_span.shape) > 1: + raise NotImplementedError("Parallel routines not implemented yet, check experimental versions of `torchdyn`") + # odeint routine with a single t_span for all samples + elif len(t_span.shape) == 1: + if stepping_class == 'fixed': + if atol != odeint_symplectic.__defaults__[0] or rtol != odeint_symplectic.__defaults__[1]: + warn("Setting tolerances has no effect on fixed-step methods") + return _fixed_odeint(f_, x, t_span, solver, save_at=save_at) + elif stepping_class == 'adaptive': + t = t_span[0] + if f.order == 1: + pos = x[..., : solver.x_shape] + k1 = f(t, pos) + dt = init_step(f, k1, pos, t, solver.order, atol, rtol) + else: + k1 = f(t, x) + dt = init_step(f, k1, x, t, solver.order, atol, rtol) + return _adaptive_odeint(f_, k1, x, dt, t_span, solver, atol, rtol, return_all_eval) + + +def odeint_mshooting(f:Callable, x:Tensor, t_span:Tensor, solver:Union[str, nn.Module], B0=None, fine_steps=2, maxiter=4): + """Solve an initial value problem (IVP) determined by function `f` and initial condition `x` using parallel-in-time solvers. + + Args: + f (Callable): vector field + x (Tensor): batch of initial conditions + t_span (Tensor): integration interval + solver (Union[str, nn.Module]): parallel-in-time solver. + B0 ([type], optional): Initialized shooting parameters. If left to None, will compute automatically + using the coarse method of solver. Defaults to None. + fine_steps (int, optional): Defaults to 2. + maxiter (int, optional): Defaults to 4. + + Notes: + TODO: At the moment assumes the ODE to NOT be time-varying. An extension is possible by adaptive the step + function of a parallel-in-time solvers. + """ + if type(solver) == str: + solver = str_to_ms_solver(solver) + x, t_span = solver.sync_device_dtype(x, t_span) + # first-guess B0 of shooting parameters + if B0 is None: + _, B0 = odeint(f, x, t_span, solver.coarse_method) + # determine which odeint to apply to MS solver. 
This is where time-variance can be introduced + odeint_func = _fixed_odeint + B = solver.root_solve(odeint_func, f, x, t_span, B0, fine_steps, maxiter) + return t_span, B + + + +def odeint_hybrid(f, x, t_span, j_span, solver, callbacks, atol=1e-3, rtol=1e-3, event_tol=1e-4, priority='jump', + seminorm:Tuple[bool, Union[int, None]]=(False, None)): + """Solve an initial value problem (IVP) determined by function `f` and initial condition `x`, with jump events defined + by a callbacks. + + Args: + f ([type]): + x ([type]): + t_span ([type]): + j_span ([type]): + solver ([type]): + callbacks ([type]): + t_eval (list, optional): Defaults to []. + atol ([type], optional): Defaults to 1e-3. + rtol ([type], optional): Defaults to 1e-3. + event_tol ([type], optional): Defaults to 1e-4. + priority (str, optional): Defaults to 'jump'. + """ + # instantiate the solver in case the user has specified preference via a `str` and ensure compatibility of device ~ dtype + if type(solver) == str: solver = str_to_solver(solver, x.dtype) + x, t_span = solver.sync_device_dtype(x, t_span) + x_shape = x.shape + ckpt_counter, ckpt_flag, jnum = 0, False, 0 + t_eval, t, T = t_span[1:], t_span[:1], t_span[-1] + + ###### initial jumps ########### + event_states = EventState([False for _ in range(len(callbacks))]) + + if priority == 'jump': + new_event_states = EventState([cb.check_event(t, x) for cb in callbacks]) + triggered_events = event_states != new_event_states + # check if any event flag changed from `False` to `True` within last step + triggered_events = sum([(a_ != b_) & (b_ == False) + for a_, b_ in zip(new_event_states.evid, event_states.evid)]) + if triggered_events > 0: + i = min([i for i, idx in enumerate(new_event_states.evid) if idx == True]) + x = callbacks[i].jump_map(t, x) + jnum = jnum + 1 + + ################## initial step size setting ################ + k1 = f(t, x) + dt = init_step(f, k1, x, t, solver.order, atol, rtol) + + #### init solution & time vector #### + eval_times, sol = [t], [x] + + while t < T and jnum < j_span: + + ############### checkpointing ############################### + if t + dt > t_span[-1]: + dt = t_span[-1] - t + if t_eval is not None: + if (ckpt_counter < len(t_eval)) and (t + dt > t_eval[ckpt_counter]): + dt_old, ckpt_flag = dt, True + dt = t_eval[ckpt_counter] - t + ckpt_counter += 1 + + ################ step + f_new, x_new, x_err, _ = solver.step(f, x, t, dt, k1=k1) + + ################ callback and events ######################## + new_event_states = EventState([cb.check_event(t + dt, x_new) for cb in callbacks]) + triggered_events = sum([(a_ != b_) & (b_ == False) + for a_, b_ in zip(new_event_states.evid, event_states.evid)]) + + + # if event, close in on switching state in [t, t + Δt] via bisection + if triggered_events > 0: + + dt_pre, t_inner, dt_inner, x_inner, niters = dt, t, dt, x, 0 + max_iters = 100 # TODO (numerics): compute tol as function of tolerances + + while niters < max_iters and event_tol < dt_inner: + with torch.no_grad(): + dt_inner = dt_inner / 2 + f_new, x_, x_err, _ = solver.step(f, x_inner, t_inner, dt_inner, k1=k1) + + new_event_states = EventState([cb.check_event(t_inner + dt_inner, x_) + for cb in callbacks]) + triggered_events = sum([(a_ != b_) & (b_ == False) + for a_, b_ in zip(new_event_states.evid, event_states.evid)]) + niters = niters + 1 + + if triggered_events == 0: # if no event, advance start point of bisection search + x_inner = x_ + t_inner = t_inner + dt_inner + dt_inner = dt + k1 = f_new + # TODO (qol): optional save + 
#sol.append(x_inner.reshape(x_shape)) + #eval_times.append(t_inner.reshape(t.shape)) + x = x_inner + t = t_inner + i = min([i for i, x in enumerate(new_event_states.evid) if x == True]) + + # save state and time BEFORE jump + sol.append(x.reshape(x_shape)) + eval_times.append(t.reshape(t.shape)) + + # apply jump func. + x = callbacks[i].jump_map(t, x) + + # save state and time AFTER jump + sol.append(x.reshape(x_shape)) + eval_times.append(t.reshape(t.shape)) + + # reset k1 + k1 = None + dt = dt_pre + + else: + ################# compute error ############################# + if seminorm[0] == True: + state_dim = seminorm[1] + error = x_err[:state_dim] + error_scaled = error / (atol + rtol * torch.max(x[:state_dim].abs(), x_new[:state_dim].abs())) + else: + error = x_err + error_scaled = error / (atol + rtol * torch.max(x.abs(), x_new.abs())) + + error_ratio = hairer_norm(error_scaled) + accept_step = error_ratio <= 1 + + if accept_step: + t = t + dt + x = x_new + sol.append(x.reshape(x_shape)) + eval_times.append(t.reshape(t.shape)) + k1 = f_new + + if ckpt_flag: + dt = dt_old - dt + ckpt_flag = False + ################ stepsize control ########################### + dt = adapt_step(dt, error_ratio, + solver.safety, + solver.min_factor, + solver.max_factor, + solver.order) + + return torch.cat(eval_times), torch.stack(sol) + + +def _adaptive_odeint(f, k1, x, dt, t_span, solver, atol=1e-4, rtol=1e-4, args=None, interpolator=None, return_all_eval=False, seminorm=(False, None)): + """Adaptive ODE solve routine, called by `odeint`. + + Args: + f ([type]): + k1 ([type]): + x ([type]): + dt ([type]): + t_span ([type]): + solver ([type]): + atol ([type], optional): Defaults to 1e-4. + rtol ([type], optional): Defaults to 1e-4. + args (Dict): + use_interp (bool, optional): + return_all_eval (bool, optional): Defaults to False. + + + Notes: + (1) We check if the user wants all evaluated solution points, not only those + corresponding to times in `t_span`. 
This is automatically set to `True` when `odeint` + is called for interpolated adjoints + """ + t_eval, t, T = t_span[1:], t_span[:1], t_span[-1] + ckpt_counter, ckpt_flag = 0, False + eval_times, sol = [t], [x] + while t < T: + if t + dt > T: + dt = T - t + ############### checkpointing ############################### + if t_eval is not None: + # satisfy checkpointing by using interpolation scheme or resetting `dt` + if (ckpt_counter < len(t_eval)) and (t + dt > t_eval[ckpt_counter]): + if interpolator == None: + # save old dt, raise "checkpoint" flag and repeat step + dt_old, ckpt_flag = dt, True + dt = t_eval[ckpt_counter] - t + + f_new, x_new, x_err, stages = solver.step(f, x, t, dt, k1=k1, args=args) + ################# compute error ############################# + if seminorm[0] == True: + state_dim = seminorm[1] + error = x_err[:state_dim] + error_scaled = error / (atol + rtol * torch.max(x[:state_dim].abs(), x_new[:state_dim].abs())) + else: + error = x_err + error_scaled = error / (atol + rtol * torch.max(x.abs(), x_new.abs())) + error_ratio = hairer_norm(error_scaled) + accept_step = error_ratio <= 1 + + if accept_step: + ############### checkpointing via interpolation ############################### + if t_eval is not None and interpolator is not None: + coefs = None + while (ckpt_counter < len(t_eval)) and (t + dt > t_eval[ckpt_counter]): + t0, t1 = t, t + dt + x_mid = x + dt * sum([interpolator.bmid[i] * stages[i] for i in range(len(stages))]) + f0, f1, x0, x1 = k1, f_new, x, x_new + if coefs == None: coefs = interpolator.fit(dt, f0, f1, x0, x1, x_mid) + x_in = interpolator.evaluate(coefs, t0, t1, t_eval[ckpt_counter]) + sol.append(x_in) + eval_times.append(t_eval[ckpt_counter][None]) + ckpt_counter += 1 + + if t + dt == t_eval[ckpt_counter] or return_all_eval: # note (1) + sol.append(x_new) + eval_times.append(t + dt) + # we only increment the ckpt counter if the solution points corresponds to a time point in `t_span` + if t + dt == t_eval[ckpt_counter]: ckpt_counter += 1 + t, x = t + dt, x_new + k1 = f_new + + ################ stepsize control ########################### + # reset "dt" in case of checkpoint without interp + if ckpt_flag: + dt = dt_old - dt + ckpt_flag = False + + dt = adapt_step(dt, error_ratio, + solver.safety, + solver.min_factor, + solver.max_factor, + solver.order) + return torch.cat(eval_times), torch.stack(sol) + + +def _fixed_odeint(f, x, t_span, solver, save_at=(), args={}): + """Solves IVPs with same `t_span`, using fixed-step methods""" + if len(save_at) == 0: save_at = t_span + if not isinstance(save_at, torch.Tensor): + save_at = torch.tensor(save_at) + + assert all(torch.isclose(t, save_at).sum() == 1 for t in save_at),\ + "each element of save_at [torch.Tensor] must be contained in t_span [torch.Tensor] once and only once" + + t, T, dt = t_span[0], t_span[-1], t_span[1] - t_span[0] + + sol = [] + if torch.isclose(t, save_at).sum(): + sol = [x] + + steps = 1 + while steps <= len(t_span) - 1: + _, x, _ = solver.step(f, x, t, dt, k1=None, args=args) + t = t + dt + + if torch.isclose(t, save_at).sum(): + sol.append(x) + if steps < len(t_span) - 1: dt = t_span[steps+1] - t + steps += 1 + + if isinstance(sol[0], dict): + final_out = {k: [v] for k, v in sol[0].items()} + _ = [final_out[k].append(x[k]) for k in x.keys() for x in sol[1:]] + final_out = {k: torch.stack(v) for k, v in final_out.items()} + elif isinstance(sol[0], torch.Tensor): + final_out = torch.stack(sol) + else: + raise NotImplementedError(f"{type(x)} is not supported as the state 
variable") + + return save_at, final_out + + +def _shifted_fixed_odeint(f, x, t_span): + """Solves ``n_segments'' jagged IVPs in parallel with fixed-step methods. All subproblems + have equal step sizes and number of solution points + + Notes: + Assumes `dt` fixed. TODO: update in each loop evaluation.""" + t, T = t_span[..., 0], t_span[..., -1] + dt = t_span[..., 1] - t + sol, k1 = [], f(t, x) + + not_converged = ~((t - T).abs() <= 1e-6).bool() + while not_converged.any(): + x[:, ~not_converged] = torch.zeros_like(x[:, ~not_converged]) + k1, _, x = solver.step(f, x, t, dt[..., None], k1=k1) # dt will be broadcasted on dim1 + sol.append(x) + t = t + dt + not_converged = ~((t - T).abs() <= 1e-6).bool() + # stacking is only possible since the number of steps in each of the ``n_segments'' + # is assumed to be the same. Otherwise require jagged tensors or a [] + return torch.stack(sol) + + + +def _jagged_fixed_odeint(f, x, + t_span: List, solver): + """ + Solves ``n_segments'' jagged IVPs in parallel with fixed-step methods. Each sub-IVP can vary in number + of solution steps and step sizes + + Returns: + A list of `len(t_span)' containing solutions of each IVP computed in parallel. + """ + t, T = [t_sub[0] for t_sub in t_span], [t_sub[-1] for t_sub in t_span] + t, T = torch.stack(t), torch.stack(T) + + dt = torch.stack([t_[1] - t0 for t_, t0 in zip(t_span, t)]) + sol = [[x_] for x_ in x] + not_converged = ~((t - T).abs() <= 1e-6).bool() + steps = 0 + while not_converged.any(): + _, _, x = solver.step(f, x, t, dt[..., None, None]) # dt will be to x dims + + for n, sol_ in enumerate(sol): + sol_.append(x[n]) + t = t + dt + not_converged = ~((t - T).abs() <= 1e-6).bool() + + steps += 1 + dt = [] + for t_, tcur in zip(t_span, t): + if steps > len(t_) - 1: + dt.append(torch.zeros_like(tcur)) # subproblem already solved + else: + dt.append(t_[steps] - tcur) + + dt = torch.stack(dt) + # prune solutions to remove noop steps + sol = [sol_[:len(t_)] for sol_, t_ in zip(sol, t_span)] + return [torch.stack(sol_, 0) for sol_ in sol] + diff --git a/torchdyn/numerics/sensitivity.py b/torchdyn/numerics/sensitivity.py new file mode 100644 index 0000000..503b21e --- /dev/null +++ b/torchdyn/numerics/sensitivity.py @@ -0,0 +1,161 @@ +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from inspect import getfullargspec +import torch +from torch.autograd import Function, grad +from torchcde import CubicSpline, natural_cubic_coeffs +from torchdyn.numerics.odeint import odeint, odeint_mshooting + + +def generic_odeint(problem_type, vf, x, t_span, solver, atol, rtol, interpolator, B0=None, + return_all_eval=False, maxiter=4, fine_steps=4, save_at=()): + "Dispatches to appropriate `odeint` function depending on `Problem` class (ODEProblem, MultipleShootingProblem)" + if problem_type == 'standard': + return odeint(vf, x, t_span, solver, atol=atol, rtol=rtol, interpolator=interpolator, return_all_eval=return_all_eval, + save_at=save_at) + elif problem_type == 'multiple_shooting': + return odeint_mshooting(vf, x, t_span, solver, B0=B0, fine_steps=fine_steps, maxiter=maxiter) + + +# TODO: optimize and make conditional gradient computations w.r.t end times +# TODO: link `seminorm` arg from `ODEProblem` +def _gather_odefunc_adjoint(vf, vf_params, solver, atol, rtol, interpolator, solver_adjoint, + atol_adjoint, rtol_adjoint, integral_loss, problem_type, maxiter=4, fine_steps=4): + "Prepares definition of autograd.Function for adjoint sensitivity analysis of the above `ODEProblem`" + class _ODEProblemFunc(Function): + @staticmethod + def forward(ctx, vf_params, x, t_span, B=None, save_at=()): + t_sol, sol = generic_odeint(problem_type, vf, x, t_span, solver, atol, rtol, interpolator, B, + False, maxiter, fine_steps, save_at) + ctx.save_for_backward(sol, t_sol) + return t_sol, sol + + @staticmethod + def backward(ctx, *grad_output): + sol, t_sol = ctx.saved_tensors + vf_params = torch.cat([p.contiguous().flatten() for p in vf.parameters()]) + # initialize flattened adjoint state + xT, λT, μT = sol[-1], grad_output[-1][-1], torch.zeros_like(vf_params) + xT_nel, λT_nel, μT_nel = xT.numel(), λT.numel(), μT.numel() + xT_shape, λT_shape, μT_shape = xT.shape, λT.shape, μT.shape + + λT_flat = λT.flatten() + λtT = λT_flat @ vf(t_sol[-1], xT).flatten() + # concatenate all states of adjoint system + A = torch.cat([xT.flatten(), λT_flat, μT.flatten(), λtT[None]]) + + def adjoint_dynamics(t, A): + if len(t.shape) > 0: t = t[0] + x, λ, μ = A[:xT_nel], A[xT_nel:xT_nel+λT_nel], A[-μT_nel-1:-1] + x, λ, μ = x.reshape(xT.shape), λ.reshape(λT.shape), μ.reshape(μT.shape) + with torch.set_grad_enabled(True): + x, t = x.requires_grad_(True), t.requires_grad_(True) + dx = vf(t, x) + dλ, dt, *dμ = tuple(grad(dx, (x, t) + tuple(vf.parameters()), -λ, + allow_unused=True, retain_graph=False)) + + if integral_loss: + dg = torch.autograd.grad(integral_loss(t, x).sum(), x, allow_unused=True, retain_graph=True)[0] + dλ = dλ - dg + + dμ = torch.cat([el.flatten() if el is not None else torch.zeros(1) + for el in dμ], dim=-1) + if dt == None: dt = torch.zeros(1).to(t) + if len(t.shape) == 0: dt = dt.unsqueeze(0) + return torch.cat([dx.flatten(), dλ.flatten(), dμ.flatten(), dt.flatten()]) + + # solve the adjoint equation + n_elements = (xT_nel, λT_nel, μT_nel) + dLdt = torch.zeros(len(t_sol)).to(xT) + dLdt[-1] = λtT + for i in range(len(t_sol) - 1, 0, -1): + t_adj_sol, A = odeint(adjoint_dynamics, A, t_sol[i - 1:i + 1].flip(0), + solver_adjoint, atol=atol_adjoint, rtol=rtol_adjoint, + seminorm=(True, xT_nel+λT_nel)) + # prepare adjoint state for next interval + #TODO: reuse vf_eval for dLdt calculations + xt = A[-1, :xT_nel].reshape(xT_shape) + dLdt_ = A[-1, xT_nel:xT_nel + λT_nel]@vf(t_sol[i], xt).flatten() + A[-1, -1:] -= grad_output[0][i - 1] + dLdt[i-1] = dLdt_ + + A = torch.cat([A[-1, :xT_nel], A[-1, 
xT_nel:xT_nel + λT_nel], A[-1, -μT_nel-1:-1], A[-1, -1:]]) + A[xT_nel:xT_nel + λT_nel] += grad_output[-1][i - 1].flatten() + + λ, μ = A[xT_nel:xT_nel + λT_nel], A[-μT_nel-1:-1] + λ, μ = λ.reshape(λT.shape), μ.reshape(μT.shape) + λ_tspan = torch.stack([dLdt[0], dLdt[-1]]) + return (μ, λ, λ_tspan, None, None, None) + + return _ODEProblemFunc + + +#TODO: introduce `t_span` grad as above +def _gather_odefunc_interp_adjoint(vf, vf_params, solver, atol, rtol, interpolator, solver_adjoint, + atol_adjoint, rtol_adjoint, integral_loss, problem_type, maxiter=4, fine_steps=4): + "Prepares definition of autograd.Function for interpolated adjoint sensitivity analysis of the above `ODEProblem`" + class _ODEProblemFunc(Function): + @staticmethod + def forward(ctx, vf_params, x, t_span, B=None, save_at=()): + t_sol, sol = generic_odeint(problem_type, vf, x, t_span, solver, atol, rtol, interpolator, B, + True, maxiter, fine_steps, save_at) + ctx.save_for_backward(sol, t_span, t_sol) + return t_sol, sol + + @staticmethod + def backward(ctx, *grad_output): + sol, t_span, t_sol = ctx.saved_tensors + vf_params = torch.cat([p.contiguous().flatten() for p in vf.parameters()]) + + # initialize adjoint state + xT, λT, μT = sol[-1], grad_output[-1][-1], torch.zeros_like(vf_params) + λT_nel, μT_nel = λT.numel(), μT.numel() + xT_shape, λT_shape, μT_shape = xT.shape, λT.shape, μT.shape + A = torch.cat([λT.flatten(), μT.flatten()]) + + spline_coeffs = natural_cubic_coeffs(x=sol.permute(1, 0, 2).detach(), t=t_sol) + x_spline = CubicSpline(coeffs=spline_coeffs, t=t_sol) + + # define adjoint dynamics + def adjoint_dynamics(t, A): + if len(t.shape) > 0: t = t[0] + x = x_spline.evaluate(t).requires_grad_(True) + t = t.requires_grad_(True) + λ, μ = A[:λT_nel], A[-μT_nel:] + λ, μ = λ.reshape(λT.shape), μ.reshape(μT.shape) + with torch.set_grad_enabled(True): + dx = vf(t, x) + dλ, dt, *dμ = tuple(grad(dx, (x, t) + tuple(vf.parameters()), -λ, + allow_unused=True, retain_graph=False)) + + if integral_loss: + dg = torch.autograd.grad(integral_loss(t, x).sum(), x, allow_unused=True, retain_graph=True)[0] + dλ = dλ - dg + + dμ = torch.cat([el.flatten() if el is not None else torch.zeros(1) + for el in dμ], dim=-1) + return torch.cat([dλ.flatten(), dμ.flatten()]) + + # solve the adjoint equation + n_elements = (λT_nel, μT_nel) + for i in range(len(t_span) - 1, 0, -1): + t_adj_sol, A = odeint(adjoint_dynamics, A, t_span[i - 1:i + 1].flip(0), solver, atol=atol, rtol=rtol) + # prepare adjoint state for next interval + A = torch.cat([A[-1, :λT_nel], A[-1, -μT_nel:]]) + A[:λT_nel] += grad_output[-1][i - 1].flatten() + + λ, μ = A[:λT_nel], A[-μT_nel:] + λ, μ = λ.reshape(λT.shape), μ.reshape(μT.shape) + return (μ, λ, None, None, None) + + return _ODEProblemFunc \ No newline at end of file diff --git a/torchdyn/numerics/solvers/_constants.py b/torchdyn/numerics/solvers/_constants.py new file mode 100644 index 0000000..e7d4fff --- /dev/null +++ b/torchdyn/numerics/solvers/_constants.py @@ -0,0 +1,128 @@ +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
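For orientation, the system assembled in `adjoint_dynamics` above is the standard continuous-adjoint ODE: writing $\lambda$ for the adjoint of the state, $\mu$ for the accumulated parameter sensitivity and $\theta$ for the flattened vector-field parameters, a single vector-Jacobian product with $-\lambda$ yields

$$\dot\lambda = -\lambda^\top \frac{\partial f}{\partial x}, \qquad \dot\mu = -\lambda^\top \frac{\partial f}{\partial \theta},$$

with an additional $-\partial g/\partial x$ term on $\dot\lambda$ whenever an integral loss $g$ is attached. Integrating this system backward over each sub-interval and adding the incoming `grad_output` at the checkpoints produces the gradients that `backward` returns.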
+# See the License for the specific language governing permissions and +# limitations under the License. + +"This file contains helper functions to construct Tableaus of explicit solvers RungeKutta4, DormandPrince45, Tsit45" +import torch +from collections import namedtuple + +ExplicitRKTableau = namedtuple('ExplicitRKTableau', 'c, A, b_sol, b_err') + + +def construct_rk4(dtype): + c = torch.tensor([0., 1 / 2, 1 / 2, 1], dtype=dtype) + a = [ + torch.tensor([1 / 2], dtype=dtype), + torch.tensor([0., 1 / 2], dtype=dtype), + torch.tensor([0., 0., 1], dtype=dtype)] + bsol = torch.tensor([1 / 6, 1 / 3, 1 / 3, 1 / 6], dtype=dtype) + berr = torch.tensor([0.]) # for improved compatibility with utilities of other solvers, not technically true + return (c, a, bsol, berr) + + +def construct_dopri5(dtype): + c = torch.tensor([1 / 5, 3 / 10, 4 / 5, 8 / 9, 1., 1.], dtype=dtype) + a = [ + torch.tensor([1 / 5], dtype=dtype), + torch.tensor([3 / 40, 9 / 40], dtype=dtype), + torch.tensor([44 / 45, -56 / 15, 32 / 9], dtype=dtype), + torch.tensor([19372 / 6561, -25360 / 2187, 64448 / 6561, -212 / 729], dtype=dtype), + torch.tensor([9017 / 3168, -355 / 33, 46732 / 5247, 49 / 176, -5103 / 18656], dtype=dtype), + torch.tensor([35 / 384, 0, 500 / 1113, 125 / 192, -2187 / 6784, 11 / 84], dtype=dtype), + ] + bsol = torch.tensor([35 / 384, 0, 500 / 1113, 125 / 192, -2187 / 6784, 11 / 84, 0], dtype=dtype) + berr = torch.tensor([1951 / 21600, 0, 22642 / 50085, 451 / 720, -12231 / 42400, 649 / 6300, 1 / 60.], dtype=dtype) + + dmid = torch.tensor([-1.1270175653862835, 0., 2.675424484351598, -5.685526961588504, 3.5219323679207912, + -1.7672812570757455, 2.382468931778144]) + return (c, a, bsol, bsol - berr) + + +def construct_tsit5(dtype): + + c = torch.tensor([ + 161 / 1000, + 327 / 1000, + 9 / 10, + .9800255409045096857298102862870245954942137979563024768854764293221195950761080302604, + 1., + 1. 
+ ], dtype=dtype) + a = [ + torch.tensor([ + 161 / 1000 + ], dtype=dtype), + torch.tensor([ + -.8480655492356988544426874250230774675121177393430391537369234245294192976164141156943e-2, + .3354806554923569885444268742502307746751211773934303915373692342452941929761641411569 + ], dtype=dtype), + torch.tensor([ + 2.897153057105493432130432594192938764924887287701866490314866693455023795137503079289, + -6.359448489975074843148159912383825625952700647415626703305928850207288721235210244366, + 4.362295432869581411017727318190886861027813359713760212991062156752264926097707165077 + ], dtype=dtype), + torch.tensor([ + 5.325864828439256604428877920840511317836476253097040101202360397727981648835607691791, + -11.74888356406282787774717033978577296188744178259862899288666928009020615663593781589, + 7.495539342889836208304604784564358155658679161518186721010132816213648793440552049753, + -.9249506636175524925650207933207191611349983406029535244034750452930469056411389539635e-1 + ], dtype=dtype), + torch.tensor([ + 5.861455442946420028659251486982647890394337666164814434818157239052507339770711679748, + -12.92096931784710929170611868178335939541780751955743459166312250439928519268343184452, + 8.159367898576158643180400794539253485181918321135053305748355423955009222648673734986, + -.7158497328140099722453054252582973869127213147363544882721139659546372402303777878835e-1, + -.2826905039406838290900305721271224146717633626879770007617876201276764571291579142206e-1 + ], dtype=dtype), + torch.tensor([ + .9646076681806522951816731316512876333711995238157997181903319145764851595234062815396e-1, + 1 / 100, + .4798896504144995747752495322905965199130404621990332488332634944254542060153074523509, + 1.379008574103741893192274821856872770756462643091360525934940067397245698027561293331, + -3.290069515436080679901047585711363850115683290894936158531296799594813811049925401677, + 2.324710524099773982415355918398765796109060233222962411944060046314465391054716027841 + ], dtype=dtype), + ] + bsol = torch.tensor([ + .9646076681806522951816731316512876333711995238157997181903319145764851595234062815396e-1, + 1 / 100, + .4798896504144995747752495322905965199130404621990332488332634944254542060153074523509, + 1.379008574103741893192274821856872770756462643091360525934940067397245698027561293331, + -3.290069515436080679901047585711363850115683290894936158531296799594813811049925401677, + 2.324710524099773982415355918398765796109060233222962411944060046314465391054716027841, + 0. + ], dtype=dtype) + berr = torch.tensor([ + .9468075576583945807478876255758922856117527357724631226139574065785592789071067303271e-1, + .9183565540343253096776363936645313759813746240984095238905939532922955247253608687270e-2, + .4877705284247615707855642599631228241516691959761363774365216240304071651579571959813, + 1.234297566930478985655109673884237654035539930748192848315425833500484878378061439761, + -2.707712349983525454881109975059321670689605166938197378763992255714444407154902012702, + 1.866628418170587035753719399566211498666255505244122593996591602841258328965767580089, + 1 / 66., + ], dtype=dtype) + return (c, a, bsol, bsol - berr) + + +######################## +# Interpolator coeffs +######################## + +"Once we have enough combinations implemented, these will go in each solver's tableau constructor and will be accessed by `Interpolators` through the solver." 
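For reference, the `(c, a, bsol, berr)` tuples constructed above are consumed by the explicit Runge-Kutta steppers later in this patch in the standard way; with `c` and the rows of `a` listed from the second stage onward (the convention of `construct_dopri5` and `construct_tsit5`), a step of size $dt$ from $(t, x)$ computes

$$k_1 = f(t, x), \qquad k_{i+1} = f\Big(t + c_i\, dt,\; x + dt \sum_{j \le i} a_{i,j}\, k_j\Big),$$
$$x_{\mathrm{next}} = x + dt \sum_i b^{\mathrm{sol}}_i k_i, \qquad \mathrm{err} = dt \sum_i b^{\mathrm{err}}_i k_i,$$

where the stored `berr` is already the difference between the solution weights and the embedded lower-order weights, so `err` is a direct estimate of the local error used by the adaptive step-size controller.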
+ +def construct_4th(dtype): + "4th order interpolator for `dopri5`" + bmid = torch.tensor([ + 0.10013431883002395, 0, 0.3918321794184259, -0.02982460176594817, + 0.05893268337240795, -0.04497888809104361, 0.023904308236133973 + ], dtype=dtype) + return bmid + \ No newline at end of file diff --git a/torchdyn/numerics/solvers/hyper.py b/torchdyn/numerics/solvers/hyper.py new file mode 100644 index 0000000..69340f5 --- /dev/null +++ b/torchdyn/numerics/solvers/hyper.py @@ -0,0 +1,47 @@ +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch +import torch.nn as nn +from torchdyn.numerics.solvers.ode import Euler, Midpoint, RungeKutta4 + +class HyperEuler(Euler): + def __init__(self, hypernet, dtype=torch.float32): + super().__init__(dtype) + self.hypernet = hypernet + self.stepping_class = 'fixed' + self.op1 = self.order + 1 + + def step(self, f, x, t, dt, k1=None, args=None): + _, x_sol, _ = super().step(f, x, t, dt, k1) + return None, x_sol + dt**(self.op1) * self.hypernet(t, x), None + +class HyperMidpoint(Midpoint): + def __init__(self, hypernet, dtype=torch.float32): + super().__init__(dtype) + self.hypernet = hypernet + self.stepping_class = 'fixed' + self.op1 = self.order + 1 + + def step(self, f, x, t, dt, k1=None, args=None): + _, x_sol, _ = super().step(f, x, t, dt, k1) + return None, x_sol + dt**(self.op1) * self.hypernet(t, x), None + +class HyperRungeKutta4(RungeKutta4): + def __init__(self, hypernet, dtype=torch.float32): + super().__init__(dtype) + self.hypernet = hypernet + self.op1 = self.order + 1 + + def step(self, f, x, t, dt, k1=None, args=None): + _, x_sol, _ = super().step(f, x, t, dt, k1) + return None, x_sol + dt**(self.op1) * self.hypernet(t, x), None diff --git a/torchdyn/numerics/solvers/ode.py b/torchdyn/numerics/solvers/ode.py new file mode 100644 index 0000000..7406fb9 --- /dev/null +++ b/torchdyn/numerics/solvers/ode.py @@ -0,0 +1,345 @@ +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" + Contains ODE solvers, both sequential as well as time-parallel multiple shooting methods necessary for multiple-shoting layers [1]. + The stateful design allows users to modify or tweak each Tableau during training, ensuring compatibility with hybrid methods such as Hypersolvers [2] + [1]: Massaroli S., Poli M. et al "Differentiable Multiple Shooting Layers." + [2]: Poli M., Massaroli S. et al "Hypersolvers: Toward fast continuous-depth models." 
NeurIPS 2020 +""" + +from typing import Tuple +import torch +import torch.nn as nn +from torchdyn.numerics.solvers.templates import DiffEqSolver, MultipleShootingDiffeqSolver +from torchdyn.numerics.solvers._constants import construct_rk4, construct_dopri5, construct_tsit5 + + +class SolverTemplate(nn.Module): + def __init__(self, order, min_factor:float=0.2, max_factor:float=10, safety:float=0.9): + super().__init__() + self.order = order + self.min_factor = torch.tensor([min_factor]) + self.max_factor = torch.tensor([max_factor]) + self.safety = torch.tensor([safety]) + self.tableau = None + + def sync_device_dtype(self, x, t_span): + "Ensures `x`, `t_span`, `tableau` and other solver tensors are on the same device with compatible dtypes" + + if isinstance(x, dict): + proto_arr = x[list(x.keys())[0]] + elif isinstance(x, torch.Tensor): + proto_arr = x + else: + raise NotImplementedError(f"{type(x)} is not supported as the state variable") + + device = proto_arr.device + + if self.tableau is not None: + c, a, bsol, berr = self.tableau + self.tableau = c.to(proto_arr), [a.to(proto_arr) for a in a], bsol.to(proto_arr), berr.to(proto_arr) + t_span = t_span.to(device) + self.safety = self.safety.to(device) + self.min_factor = self.min_factor.to(device) + self.max_factor = self.max_factor.to(device) + return x, t_span + + def step(self, f, x, t, dt, k1=None, args=None): + pass + + +class Euler(SolverTemplate): + def __init__(self, dtype=torch.float32): + """Explicit Euler ODE stepper, order 1""" + super().__init__(order=1) + self.dtype = dtype + self.stepping_class = 'fixed' + + def step(self, f, x, t, dt, k1=None, args=None): + if k1 == None: k1 = f(t, x) + x_sol = x + dt * k1 + return None, x_sol, None + + + +class Midpoint(DiffEqSolver): + def __init__(self, dtype=torch.float32): + """Explicit Midpoint ODE stepper, order 2""" + super().__init__(order=2) + self.dtype = dtype + self.stepping_class = 'fixed' + + def step(self, f, x, t, dt, k1=None, args=None): + if k1 == None: k1 = f(t, x) + x_mid = x + 0.5 * dt * k1 + x_sol = x + dt * f(t + 0.5 * dt, x_mid) + return None, x_sol, None + + +class RungeKutta4(DiffEqSolver): + def __init__(self, dtype=torch.float32): + """Explicit Midpoint ODE stepper, order 4""" + super().__init__(order=4) + self.dtype = dtype + self.stepping_class = 'fixed' + self.tableau = construct_rk4(self.dtype) + + def step(self, f, x, t, dt, k1=None, args=None): + c, a, bsol, _ = self.tableau + if k1 == None: k1 = f(t, x) + k2 = f(t + c[0] * dt, x + dt * (a[0] * k1)) + k3 = f(t + c[1] * dt, x + dt * (a[1][0] * k1 + a[1][1] * k2)) + k4 = f(t + c[2] * dt, x + dt * (a[2][0] * k1 + a[2][1] * k2 + a[2][2] * k3)) + x_sol = x + dt * (bsol[0] * k1 + bsol[1] * k2 + bsol[2] * k3 + bsol[3] * k4) + return None, x_sol, None + + +class AsynchronousLeapfrog(DiffEqSolver): + def __init__(self, channel_index:int=-1, stepping_class:str='fixed', dtype=torch.float32): + """Explicit Leapfrog symplectic ODE stepper. + Can return local error estimates if adaptive stepping is required""" + super().__init__(order=2) + self.dtype = dtype + self.channel_index = channel_index + self.stepping_class = stepping_class + self.const = 1 + self.tableau = construct_rk4(self.dtype) + # an additional overhead, necessary to preserve a certain degree of sanity + # in the implementation and to avoid API bloating. 
+ self.x_shape = None + + + def step(self, f, xv, t, dt, k1=None, args=None): + half_state_dim = xv.shape[-1] // 2 + x, v = xv[..., :half_state_dim], xv[..., half_state_dim:] + if k1 == None: k1 = f(t, x) + x1 = x + 0.5 * dt * v + vt1 = f(t + 0.5 * dt, x1) + v1 = 2 * self.const * (vt1 - v) + v + x2 = x1 + 0.5 * dt * v1 + x_sol = torch.cat([x2, v1], -1) + if self.stepping_class == 'adaptive': + xv_err = torch.cat([torch.zeros_like(x), v], -1) + else: + xv_err = None + return None, x_sol, xv_err + + +class DormandPrince45(DiffEqSolver): + def __init__(self, dtype=torch.float32): + super().__init__(order=5) + self.dtype = dtype + self.stepping_class = 'adaptive' + self.tableau = construct_dopri5(self.dtype) + + def step(self, f, x, t, dt, k1=None, args=None) -> Tuple: + c, a, bsol, berr = self.tableau + if k1 == None: k1 = f(t, x) + k2 = f(t + c[0] * dt, x + dt * a[0] * k1) + k3 = f(t + c[1] * dt, x + dt * (a[1][0] * k1 + a[1][1] * k2)) + k4 = f(t + c[2] * dt, x + dt * a[2][0] * k1 + dt * a[2][1] * k2 + dt * a[2][2] * k3) + k5 = f(t + c[3] * dt, x + dt * a[3][0] * k1 + dt * a[3][1] * k2 + dt * a[3][2] * k3 + dt * a[3][3] * k4) + k6 = f(t + c[4] * dt, x + dt * a[4][0] * k1 + dt * a[4][1] * k2 + dt * a[4][2] * k3 + dt * a[4][3] * k4 + dt * a[4][4] * k5) + k7 = f(t + c[5] * dt, x + dt * a[5][0] * k1 + dt * a[5][1] * k2 + dt * a[5][2] * k3 + dt * a[5][3] * k4 + dt * a[5][4] * k5 + dt * a[5][5] * k6) + x_sol = x + dt * (bsol[0] * k1 + bsol[1] * k2 + bsol[2] * k3 + bsol[3] * k4 + bsol[4] * k5 + bsol[5] * k6) + err = dt * (berr[0] * k1 + berr[1] * k2 + berr[2] * k3 + berr[3] * k4 + berr[4] * k5 + berr[5] * k6 + berr[6] * k7) + return k7, x_sol, err, (k1, k2, k3, k4, k5, k6, k7) + + + +class Tsitouras45(DiffEqSolver): + def __init__(self, dtype=torch.float32): + super().__init__(order=5) + self.dtype = dtype + self.stepping_class = 'adaptive' + self.tableau = construct_tsit5(self.dtype) + + def step(self, f, x, t, dt, k1=None, args=None) -> Tuple: + c, a, bsol, berr = self.tableau + if k1 == None: k1 = f(t, x) + k2 = f(t + c[0] * dt, x + dt * a[0][0] * k1) + k3 = f(t + c[1] * dt, x + dt * (a[1][0] * k1 + a[1][1] * k2)) + k4 = f(t + c[2] * dt, x + dt * a[2][0] * k1 + dt * a[2][1] * k2 + dt * a[2][2] * k3) + k5 = f(t + c[3] * dt, x + dt * a[3][0] * k1 + dt * a[3][1] * k2 + dt * a[3][2] * k3 + dt * a[3][3] * k4) + k6 = f(t + c[4] * dt, x + dt * a[4][0] * k1 + dt * a[4][1] * k2 + dt * a[4][2] * k3 + dt * a[4][3] * k4 + dt * a[4][4] * k5) + k7 = f(t + c[5] * dt, x + dt * a[5][0] * k1 + dt * a[5][1] * k2 + dt * a[5][2] * k3 + dt * a[5][3] * k4 + dt * a[5][4] * k5 + dt * a[5][5] * k6) + x_sol = x + dt * (bsol[0] * k1 + bsol[1] * k2 + bsol[2] * k3 + bsol[3] * k4 + bsol[4] * k5 + bsol[5] * k6) + err = dt * (berr[0] * k1 + berr[1] * k2 + berr[2] * k3 + berr[3] * k4 + berr[4] * k5 + berr[5] * k6 + berr[6] * k7) + return k7, x_sol, err, (k1, k2, k3, k4, k5, k6, k7) + + +class ImplicitEuler(DiffEqSolver): + def __init__(self, dtype=torch.float32): + super().__init__(order=1) + self.dtype = dtype + self.stepping_class = 'fixed' + self.opt = torch.optim.LBFGS + self.max_iters = 200 + + @staticmethod + def _residual(f, x, t, dt, x_sol): + f_sol = f(t, x_sol) + return torch.sum((x_sol - x - dt*f_sol)**2) + + def step(self, f, x, t, dt, k1=None, args=None): + x_sol = x.clone() + x_sol = nn.Parameter(data=x_sol) + opt = self.opt([x_sol], lr=1, max_iter=self.max_iters, max_eval=10*self.max_iters, + tolerance_grad=1.e-12, tolerance_change=1.e-12, history_size=100, line_search_fn='strong_wolfe') + def closure(): + 
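+            # LBFGS closure: recompute the implicit-Euler residual and expose its gradient through x_sol.grad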
opt.zero_grad() + residual = ImplicitEuler._residual(f, x, t, dt, x_sol) + x_sol.grad, = torch.autograd.grad(residual, x_sol, only_inputs=True, allow_unused=False) + return residual + opt.step(closure) + return None, x_sol, None + + + + + +class MSForward(MultipleShootingDiffeqSolver): + """Multiple shooting solver using forward sensitivity analysis on the matching conditions of shooting parameters""" + def __init__(self, coarse_method='euler', fine_method='rk4'): + super().__init__(coarse_method, fine_method) + + def root_solve(self, f, x, t_span, B): + raise NotImplementedError("Waiting for `functorch` to be merged in the stable version of Pytorch" + "we need their vjp for efficient implementation of forward sensitivity" + "Refer to DiffEqML/diffeqml-research/multiple-shooting-layers for a manual implementation") + + +class MSZero(MultipleShootingDiffeqSolver): + def __init__(self, coarse_method='euler', fine_method='rk4'): + """Multiple shooting solver using Parareal updates (zero-order approximation of the Jacobian) + + Args: + coarse_method (str, optional): . Defaults to 'euler'. + fine_method (str, optional): . Defaults to 'rk4'. + """ + super().__init__(coarse_method, fine_method) + + # TODO (qol): extend to time-variant ODEs by using shifted_odeint + def root_solve(self, odeint_func, f, x, t_span, B, fine_steps, maxiter): + dt, n_subinterv = t_span[1] - t_span[0], len(t_span) + sub_t_span = torch.linspace(0, dt, fine_steps).to(x) + i = 0 + while i <= maxiter: + i += 1 + B_coarse = odeint_func(f, B[i-1:], sub_t_span, solver=self.coarse_method)[1][-1] + B_fine = odeint_func(f, B[i-1:], sub_t_span, solver=self.fine_method)[1][-1] + B_out = torch.zeros_like(B) + B_out[:i] = B[:i] + B_in = B[i-1] + for m in range(i, n_subinterv): + B_in = odeint_func(f, B_in, sub_t_span, solver=self.coarse_method)[1][-1] + B_in = B_in - B_coarse[m-i] + B_fine[m-i] + B_out[m] = B_in + B = B_out + return B + + +class MSBackward(MultipleShootingDiffeqSolver): + def __init__(self, coarse_method='euler', fine_method='rk4'): + """Multiple shooting solver using discrete adjoints for the Jacobian + + Args: + coarse_method (str, optional): . Defaults to 'euler'. + fine_method (str, optional): . Defaults to 'rk4'. + """ + super().__init__(coarse_method, fine_method) + + def root_solve(self, odeint_func, f, x, t_span, B, fine_steps, maxiter): + dt, n_subinterv = t_span[1] - t_span[0], len(t_span) + sub_t_span = torch.linspace(0, dt, fine_steps).to(x) + i = 0 + B = B.requires_grad_(True) + while i <= maxiter: + i += 1 + B_fine = odeint_func(f, B[i-1:], sub_t_span, solver=self.fine_method)[1][-1] + B_out = torch.zeros_like(B) + B_out[:i] = B[:i] + B_in = B[i-1] + for m in range(i, n_subinterv): + # instead of jvps here the full jacobian can be computed and the vector products + # which involve `B_in` can be performed. 
Trading memory ++ for speed ++ + J_blk = torch.autograd.grad(B_fine[m-1], B, B_in - B[m-1], retain_graph=True)[0][m-1] + B_in = B_fine[m-1] + J_blk + B_out[m] = B_in + del B # manually free graph + B = B_out + return B + + +class ParallelImplicitEuler(MultipleShootingDiffeqSolver): + def __init__(self, coarse_method='euler', fine_method='euler'): + """Parallel Implicit Euler Method""" + super().__init__(coarse_method, fine_method) + self.solver = torch.optim.LBFGS + self.max_iters = 200 + + def sync_device_dtype(self, x, t_span): + return x, t_span + + @staticmethod + def _residual(f, x, B, t_span): + dt = t_span[1:] - t_span[:-1] + F = f(0., B[1:]) + residual = torch.sum((B[2:] - B[1:-1] - dt[1:, None, None] * F[1:]) ** 2) + residual += torch.sum((B[1] - x - dt[0] * F[0]) ** 2) + return residual + + # TODO (qol): extend to time-variant ODEs by model parallelization + def root_solve(self, odeint_func, f, x, t_span, B, fine_steps, maxiter): + B = B.clone() + B = nn.Parameter(data=B) + solver = self.solver([B], lr=1, max_iter=self.max_iters, max_eval=10 * self.max_iters, + tolerance_grad=1.e-12, tolerance_change=1.e-12, history_size=100, + line_search_fn='strong_wolfe') + + def closure(): + solver.zero_grad() + residual = ParallelImplicitEuler._residual(f, x, B, t_span) + B.grad, = torch.autograd.grad(residual, B, only_inputs=True, allow_unused=False) + return residual + + solver.step(closure) + return B + + +SOLVER_DICT = {'euler': Euler, 'midpoint': Midpoint, + 'rk4': RungeKutta4, 'rk-4': RungeKutta4, 'RungeKutta4': RungeKutta4, + 'dopri5': DormandPrince45, 'DormandPrince45': DormandPrince45, 'DormandPrince5': DormandPrince45, + 'tsit5': Tsitouras45, 'Tsitouras45': Tsitouras45, 'Tsitouras5': Tsitouras45, + 'ieuler': ImplicitEuler, 'implicit_euler': ImplicitEuler, + 'alf': AsynchronousLeapfrog, 'AsynchronousLeapfrog': AsynchronousLeapfrog} + + +MS_SOLVER_DICT = {'mszero': MSZero, 'zero': MSZero, 'parareal': MSZero, + 'msbackward': MSBackward, 'backward': MSBackward, 'discrete-adjoint': MSBackward, + 'ieuler': ParallelImplicitEuler, 'parallel-implicit-euler': ParallelImplicitEuler} + + +def str_to_solver(solver_name, dtype=torch.float32): + "Transforms string specifying desired solver into an instance of the Solver class." + solver = SOLVER_DICT[solver_name] + return solver(dtype) + + +def str_to_ms_solver(solver_name, dtype=torch.float32): + "Returns MSSolver class corresponding to a given string." + solver = MS_SOLVER_DICT[solver_name] + return solver() + + + diff --git a/torchdyn/numerics/solvers/root.py b/torchdyn/numerics/solvers/root.py new file mode 100644 index 0000000..d17b6d9 --- /dev/null +++ b/torchdyn/numerics/solvers/root.py @@ -0,0 +1,300 @@ +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Root finding solvers, line search utilities and root find API""" +import torch +from torch import einsum +from torch import norm +import numpy as np +from ..utils import RootLogger +from torch.autograd.functional import jacobian + + +class Broyden: + """ + Template class for Broyden-type low-rank methods + """ + type = 'Quasi-Newton' + + def __init__(self): + pass + + def step(self, g, z0, J_inv, geval_old, alpha=1, **kwargs): + dz = einsum('...o,...io->...i', geval_old, J_inv) + z = z0 - alpha * dz + geval = g(z) + Δz, Δg = z - z0, geval - geval_old + J_inv = self.update_jacobian(Δg, Δz, J_inv) + return z, dz, geval, J_inv + + def update_jacobian(Δg, Δz, J_inv, **kwargs): + raise NotImplementedError("") + +class BroydenFull(Broyden): + """ + """ + type = 'Quasi-Newton' + + def __init__(self): + super().__init__() + self.ε = 1e-6 + + def update_jacobian(self, Δg, Δz, J_inv, **kwargs): + num = Δz - einsum('...io, ...o -> ...i', J_inv, Δg) + den = einsum('...i, ...io, ...o -> ...', Δz, J_inv, Δg) + self.ε + prod = einsum('...i, ...io -> ...o', Δz, J_inv) + ΔJ_inv = einsum('...i, ...o -> ...io', num / den[..., None], prod) + J_inv = J_inv + ΔJ_inv + return J_inv + +class BroydenBad(Broyden): + """Faster approximate Broyden method + + References + ---------- + [Bro1965] Broyden, C G. *A class of methods for solving nonlinear + simultaneous equations*. Mathematics of computation, 33 (1965), + pp 577--593. + + [Kva1991] Kvaalen, E. *A faster Broyden method*. BIT Numerical + Mathematics 31 (1991), pp 369--372. + """ + type = 'Quasi-Newton' + + def __init__(self): + super().__init__() + self.Δg_tol = 1e-6 + + def update_jacobian(self, Δg, Δz, J_inv, **kwargs): + num = Δz - torch.einsum('...io, ...o -> ...i', J_inv, Δg) + den = torch.sum(Δg**2, dim=1, keepdim=True) + self.Δg_tol + ΔJ_inv = torch.einsum('...i, o... -> ...io', num, Δg.T) / den[..., None] + J_inv = J_inv + ΔJ_inv + return J_inv + +class Newton(): + "Standard Newton iteration" + type = 'Quasi-Newton' + + @staticmethod + def step(g, geval_old, z0, J_inv, alpha=1, **kwargs): + raise NotImplementedError + +class Chord(): + "Standard newton iteration with precomputed J_inv" + type = 'Quasi-Newton' + + @staticmethod + def step(g, geval_old, z0, J_inv, alpha=1, **kwargs): + raise NotImplementedError + + +############################## +### LINE SEARCH ALGORITHMS ### +############################## +def _safe_norm(v): + if not torch.isfinite(v).all(): + return np.inf + return torch.norm(v) + + +class LineSearcher(object): + def __init__(self, g, g0, dz, z0): + self.g, self.g0, self.dz, self.z0 = g, g0, dz, z0 + self.phi0 = _safe_norm(g0)**2 + + def search(self, alpha0=1): + raise NotImplementedError("") + + def phi(self, alpha): + "Objective function for line search min_alpha phi(alpha)" + return _safe_norm(self.g(self.z0 + alpha * self.dz))**2 + + +class NaiveSearch(LineSearcher): + def __init__(self, g, g0, dz, z0): + super().__init__(g, g0, dz, z0) + + def search(self, alpha0=1, min_alpha=1e-6, mult_factor=0.1): + alpha = alpha0 + phi_a0 = self.phi(alpha) + while phi_a0 > self.phi0 and alpha > min_alpha: + alpha = mult_factor*alpha + phi_a0 = self.phi(alpha) + return alpha, phi_a0 + + +class LineSearchArmijo(LineSearcher): + def __init__(self, g, g0, dz, z0): + super().__init__(g, g0, dz, z0) + + def search(self, alpha0=1, c1=1e-4): + """Minimize over alpha, the function ``phi(alpha)``. + Uses the interpolation algorithm (Armijo backtracking) as suggested by + Wright and Nocedal in 'Numerical Optimization', 1999, pp. 
56-57 + alpha > 0 is assumed to be a descent direction. + Returns + ------- + alpha + phi1 + """ + phi_a0 = self.phi(alpha0) + derphi0 = -phi_a0 #TODO: check if this is correct <- derphi + + # if the objective function is <, return + if phi_a0 < self.phi0 + c1*alpha0*derphi0: return alpha0, phi_a0 + + # Otherwise, compute the minimizer of a quadratic interpolant + alpha1 = -(derphi0) * alpha0**2 / 2.0 / (phi_a0 - self.phi0 - derphi0 * alpha0) + phi_a1 = self.phi(alpha1) + + # Loop with cubic interpolation until we find an alpha which + # satisfies the first Wolfe condition (since we are backtracking, we will + # assume that the value of alpha is not too small and satisfies the second + # condition. + while alpha1 > 1e-2: # we are assuming alpha>0 is a descent direction + factor = alpha0**2 * alpha1**2 * (alpha1-alpha0) + a = alpha0**2 * (phi_a1 - self.phi0 - derphi0*alpha1) - \ + alpha1**2 * (phi_a0 - self.phi0 - derphi0*alpha0) + a = a / factor + b = -alpha0**3 * (phi_a1 - self.phi0 - derphi0*alpha1) + \ + alpha1**3 * (phi_a0 - self.phi0 - derphi0*alpha0) + b = b / factor + + alpha2 = (-b + torch.sqrt(torch.abs(b**2 - 3 * a * derphi0))) / (3.0*a) + phi_a2 = self.phi(alpha2) + + if (phi_a2 <= self.phi0 + c1*alpha2*derphi0): + return alpha2, phi_a2 + + if (alpha1 - alpha2) > alpha1 / 2.0 or (1 - alpha2/alpha1) < 0.96: + alpha2 = alpha1 / 2.0 + + alpha0 = alpha1 + alpha1 = alpha2 + phi_a0 = phi_a1 + phi_a1 = phi_a2 + + # Failed to find a suitable step length + return 1e-2, phi_a0 + + +class LineSearchGriewank(LineSearcher): + def __init__(self): + super().__init__() + + def search(self): + raise NotImplementedError + + +class LineSearchWolfe1(LineSearcher): + def __init__(self): + super().__init__() + + def search(self): + raise NotImplementedError + + +class LineSearchWolfe2(LineSearcher): + def __init__(self): + super().__init__() + + def search(self): + raise NotImplementedError + +############################## +### TERMINATION CONDITIONS ### +############################## + +class TerminationCondition(object): + """ + Termination condition for an iteration. 
It is terminated if + - (|F| < f_rtol*|F_0| && |F| < f_tol) && (|dx| < x_rtol*|x| && |dx| < x_tol) + """ + + def __init__(self, f_tol=1e-2, f_rtol=1e-1, x_tol=1e-2, x_rtol=1, iter=None): + + self.x_tol, self.x_rtol = x_tol, x_rtol + self.f_tol, self.f_rtol = f_tol, f_rtol + self.norm = norm + self.iter = iter + self.f0_norm = None + self.iteration = 0 + + def check(self, geval, z, dz): + self.iteration += 1 + g_norm, z_norm, dz_norm = norm(geval, p=2, dim=1), norm(z, p=2, dim=1), norm(dz, p=2, dim=1) + + if self.f0_norm is None: self.f0_norm = g_norm + + cond1 = (g_norm <= self.f_tol).all() and (g_norm / self.f_rtol <= self.f0_norm).all() + cond2 = (dz_norm <= self.x_tol).all() and (dz_norm / self.x_rtol <= z_norm).all() + if cond1 and cond2: return 2 + + return 0 + +################# +### ROOT FIND ### +################# + +SEARCH_METHODS = {'naive': NaiveSearch, 'armijo': LineSearchArmijo, 'none': None} +ROOT_SOLVER_DICT = {'broyden_fast': BroydenBad(), 'broyden': BroydenFull(), 'newton': Newton, 'chord': Chord} + +RETURN_CODES = {1: 'convergence', + 2: 'total condition'} + +def batch_jacobian(func, x): + "Batch Jacobian for 2D inputs of dimensions `bs, dims`" + return torch.stack([jacobian(func, el) for el in x], 0) + +def root_find(g, z, alpha=0.1, f_tol=1e-2, f_rtol=1e-1, x_tol=1e-2, x_rtol=1, maxiters=100, + method='broyden', search_method='naive', verbose=True): + assert method in ROOT_SOLVER_DICT, f'{method} not supported' + assert search_method in SEARCH_METHODS, f'{search_method} not supported' + + tc = TerminationCondition(f_tol=f_tol, f_rtol=f_rtol, x_tol=x_tol, x_rtol=x_rtol) + logger = RootLogger() + solver = ROOT_SOLVER_DICT[method] + + # first evaluation of g(z) + geval = g(z) + + # initialize inverse jacobian J^-1g(z) + J_inv = batch_jacobian(g, z).pinverse() + + iteration = 0 + while iteration <= maxiters: + iteration += 1 + + # solver step + z, dz, geval, J_inv = solver.step(g=g, z0=z, J_inv=J_inv, geval_old=geval, alpha=alpha) + + # line search subroutines + if SEARCH_METHODS[search_method] is not None: + line_searcher = SEARCH_METHODS[search_method](g, geval, dz, z) + alpha, phi = line_searcher.search() + + # logging + if verbose and logger: + logger.log({'geval': geval, + 'z': z, + 'dz': dz, + 'iteration': iteration}) + + # full termination check + code = tc.check(geval, z, dz) + if code > 0: + if verbose and logger: + logger.log({'termination_condition': RETURN_CODES[code]}) + break + + return z, logger \ No newline at end of file diff --git a/torchdyn/numerics/solvers/sde.py b/torchdyn/numerics/solvers/sde.py new file mode 100644 index 0000000..6d594fd --- /dev/null +++ b/torchdyn/numerics/solvers/sde.py @@ -0,0 +1,5 @@ +from torchdyn.numerics.solvers.templates import DiffEqSolver, MultipleShootingDiffeqSolver + +class EulerMaruyama(DiffEqSolver): + def __init__(self): + raise NotImplementedError \ No newline at end of file diff --git a/torchdyn/numerics/solvers/templates.py b/torchdyn/numerics/solvers/templates.py new file mode 100644 index 0000000..f7e5e36 --- /dev/null +++ b/torchdyn/numerics/solvers/templates.py @@ -0,0 +1,82 @@ +import torch +import torch.nn as nn + +# TODO: work around circular imports +# multiple shooting solvers are "composite": they use +# several "base" solvers and this are further down in the dependency chain +# likely solution: place composite solver templates in a different file + +# from torchdyn.numerics.solvers.ode import str_to_solver, str_to_ms_solver + + +class DiffEqSolver(nn.Module): + def __init__( + self, + order, + 
stepping_class:str="fixed", + min_factor:float=0.2, + max_factor:float=10, + safety:float=0.9 + ): + + super(DiffEqSolver, self).__init__() + self.order = order + self.min_factor = torch.tensor([min_factor]) + self.max_factor = torch.tensor([max_factor]) + self.safety = torch.tensor([safety]) + self.tableau = None + self.stepping_class = stepping_class + + def sync_device_dtype(self, x, t_span): + "Ensures `x`, `t_span`, `tableau` and other solver tensors are on the same device with compatible dtypes" + device = x.device + if self.tableau is not None: + c, a, bsol, berr = self.tableau + self.tableau = c.to(x), [a.to(x) for a in a], bsol.to(x), berr.to(x) + t_span = t_span.to(device) + self.safety = self.safety.to(device) + self.min_factor = self.min_factor.to(device) + self.max_factor = self.max_factor.to(device) + return x, t_span + + def step(self, f, x, t, dt, k1=None, args=None): + raise NotImplementedError("Stepping rule not implemented for the solver") + + +class BaseExplicit(DiffEqSolver): + def __init__(self, *args, **kwargs): + """Base template for an explicit differential equation solver + """ + super(BaseExplicit, DiffEqSolver).__init__(*args, **kwargs) + assert self.stepping_class in ["fixed", "adaptive"] + + + +class BaseImplicit(DiffEqSolver): + def __init__(self, *args, **kwargs): + """Base template for an implicit differential equation solver + """ + super(BaseImplicit, DiffEqSolver).__init__(*args, **kwargs) + assert self.stepping_class in ["fixed", "adaptive"] + + @staticmethod + def _residual(f, x, t, dt, x_sol): + raise NotImplementedError + + +class MultipleShootingDiffeqSolver(nn.Module): + def __init__(self, coarse_method, fine_method): + from torchdyn.numerics.solvers.ode import str_to_solver + + super(MultipleShootingDiffeqSolver, self).__init__() + if type(coarse_method) == str: self.coarse_method = str_to_solver(coarse_method) + if type(fine_method) == str: self.fine_method = str_to_solver(fine_method) + + def sync_device_dtype(self, x, t_span): + "Ensures `x`, `t_span`, `tableau` and other solver tensors are on the same device with compatible dtypes" + x, t_span = self.coarse_method.sync_device_dtype(x, t_span) + x, t_span = self.fine_method.sync_device_dtype(x, t_span) + return x, t_span + + def root_solve(self, odeint_func, f, x, t_span, B, fine_steps, maxiter): + raise NotImplementedError \ No newline at end of file diff --git a/torchdyn/numerics/systems.py b/torchdyn/numerics/systems.py new file mode 100644 index 0000000..14e2505 --- /dev/null +++ b/torchdyn/numerics/systems.py @@ -0,0 +1,245 @@ +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
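+# Benchmark dynamical systems (Lorenz, Van der Pol, Fourier mixtures and simple MLP vector fields) used to exercise the solvers.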
+ +import torch +import torch.nn as nn +import torch.distributions +from torch.distributions import Uniform + + +class Lorenz(nn.Module): + def __init__(self): + super().__init__() + self.p = nn.Linear(1, 1) + + def forward(self, t, x, **kwargs): + x1, x2, x3 = x[..., :1], x[..., 1:2], x[..., 2:] + dx1 = 10 * (x2 - x1) + dx2 = x1 * (28 - x3) - x2 + dx3 = x1 * x2 - 8 / 3 * x3 + return torch.cat([dx1, dx2, dx3], -1) + + +class VanDerPol(nn.Module): + def __init__(self, alpha=10): + super().__init__() + self.alpha = alpha + self.nfe = 0 + + def forward(self, t, x, **kwargs): + self.nfe += 1 + x1, x2 = x[..., :1], x[..., 1:2] + return torch.cat([x2, self.alpha * (1 - x1 ** 2) * x2 - x1], -1) + + +class ODEProblem2(nn.Module): + def __init__(self): + super().__init__() + + def forward(self, s, z): + return 0.5 * z + + +class ODEProblem3(nn.Module): + def __init__(self): + super().__init__() + + def forward(self, s, z): + return -0.1 * z + + +class ODEProblem4(nn.Module): + "Rabinovich-Fabrikant" + + def __init__(self): + super().__init__() + + def forward(self, s, z): + x1, x2, x3 = z[..., :1], z[..., 1:2], z[..., -1:] + dx1 = x2 * (x3 - 1 + x1 ** 2) + 0.87 * x1 + dx2 = x1 * (3 * x3 + 1 - x1 ** 2) + 0.87 * x2 + dx3 = -2 * x3 * (1.1 + x1 * x2) + return torch.cat([dx1, dx2, dx3], -1) + + +class SineSystem(nn.Module): + def __init__(self): + super().__init__() + + def forward(self, s, z): + s = s * torch.ones_like(z) + return torch.sin(s) + + +class LTISystem(nn.Module): + def __init__(self, dim=2, randomizable=True): + super().__init__() + self.dim = dim + self.randomizable = randomizable + self.l = nn.Linear(dim, dim) + + def forward(self, s, x): + return self.l(x) + + def randomize_parameters(self): + self.l = nn.Linear(self.dim, self.dim) + + +class FourierSystem(nn.Module): + def __init__(self, + dim=2, + A_dist=Uniform(-10, 10), + phi_dist=Uniform(-1, 1), + w_dist=Uniform(-20, 20), + randomizable=True + ): + + super().__init__() + self.n_harmonics = n_harmonics = torch.randint(2, 20, size=(1,)) + self.A_dist = A_dist; + self.A = A_dist.sample(torch.Size([dim, n_harmonics, 2])) + self.phi_dist = phi_dist; + self.phi = phi_dist.sample(torch.Size([dim, n_harmonics, 2])) + self.w_dist = w_dist; + self.w = w_dist.sample(torch.Size([dim, n_harmonics, 2])) + self.dim = dim + self.randomizable = randomizable + + def forward(self, s, x): + if len(s.shape) == 0: + return (self.A[:, :, 0] * torch.cos(self.w[:, :, 0] * s + self.phi[:, :, 0]) + + self.A[:, :, 1] * torch.cos(self.w[:, :, 1] * s + self.phi[:, :, 1])).sum(1)[None, :] + else: + sol = [] + for s_ in s: + sol += [(self.A[:, :, 0] * torch.cos(self.w[:, :, 0] * s_ + self.phi[:, :, 0]) + + self.A[:, :, 1] * torch.cos(self.w[:, :, 1] * s_ + self.phi[:, :, 1])).sum(1)[None, :]] + return torch.cat(sol) + + def randomize_parameters(self): + self.A = self.A_dist.sample(torch.Size([self.dim, self.n_harmonics, 2])) + self.phi = self.phi_dist.sample(torch.Size([self.dim, self.n_harmonics, 2])) + self.w = self.w_dist.sample(torch.Size([self.dim, self.n_harmonics, 2])) + + +class StiffFourierSystem(nn.Module): + def __init__(self, + dim=2, + A_dist=Uniform(-10, 10), + phi_dist=Uniform(-1, 1), + w_dist=Uniform(-20, 20), + randomizable=True + ): + + super().__init__() + self.n_harmonics = n_harmonics = torch.randint(20, 100, size=(1,)) + self.A_dist = A_dist; + self.A = A_dist.sample(torch.Size([dim, n_harmonics, 2])) + self.phi_dist = phi_dist; + self.phi = phi_dist.sample(torch.Size([dim, n_harmonics, 2])) + self.w_dist = w_dist; + self.w = 
w_dist.sample(torch.Size([dim, n_harmonics, 2])) + self.dim = dim + self.randomizable = randomizable + + def forward(self, s, x): + if len(s.shape) == 0: + return (self.A[:, :, 0] * torch.cos(self.w[:, :, 0] * s + self.phi[:, :, 0]) + + self.A[:, :, 1] * torch.cos(self.w[:, :, 1] * s + self.phi[:, :, 1])).sum(1)[None, :] + else: + sol = [] + for s_ in s: + sol += [(self.A[:, :, 0] * torch.cos(self.w[:, :, 0] * s_ + self.phi[:, :, 0]) + + self.A[:, :, 1] * torch.cos(self.w[:, :, 1] * s_ + self.phi[:, :, 1])).sum(1)[None, :]] + return torch.cat(sol) + + def randomize_parameters(self): + self.A = self.A_dist.sample(torch.Size([self.dim, self.n_harmonics, 2])) + self.phi = self.phi_dist.sample(torch.Size([self.dim, self.n_harmonics, 2])) + self.w = self.w_dist.sample(torch.Size([self.dim, self.n_harmonics, 2])) + + +class CoupledFourierSystem(nn.Module): + def __init__(self, + dim=2, + A_dist=Uniform(-10, 10), + phi_dist=Uniform(-1, 1), + w_dist=Uniform(-20, 20), + randomizable=True + ): + + super().__init__() + self.n_harmonics = n_harmonics = torch.randint(2, 20, size=(1,)) + self.A_dist = A_dist; + self.A = A_dist.sample(torch.Size([dim, n_harmonics, 2])) + self.phi_dist = phi_dist; + self.phi = phi_dist.sample(torch.Size([dim, n_harmonics, 2])) + self.w_dist = w_dist; + self.w = w_dist.sample(torch.Size([dim, n_harmonics, 2])) + self.dim = dim + self.randomizable = randomizable + self.mixing_l = nn.Linear(dim, dim) + + def forward(self, s, x): + if len(s.shape) == 0: + pre_sol = (self.A[:, :, 0] * torch.cos(self.w[:, :, 0] * s + self.phi[:, :, 0]) + + self.A[:, :, 1] * torch.cos(self.w[:, :, 1] * s + self.phi[:, :, 1])).sum(1)[None, :] + return self.mixing_l(pre_sol) + + else: + sol = [] + for s_ in s: + sol += [(self.A[:, :, 0] * torch.cos(self.w[:, :, 0] * s_ + self.phi[:, :, 0]) + + self.A[:, :, 1] * torch.cos(self.w[:, :, 1] * s_ + self.phi[:, :, 1])).sum(1)[None, None, :]] + return self.mixing_l(torch.cat(sol, 0))[:, 0, :] + + def randomize_parameters(self): + self.A = self.A_dist.sample(torch.Size([self.dim, self.n_harmonics, 2])) + self.phi = self.phi_dist.sample(torch.Size([self.dim, self.n_harmonics, 2])) + self.w = self.w_dist.sample(torch.Size([self.dim, self.n_harmonics, 2])) + self.mixing_l = nn.Linear(self.dim, self.dim) + + +class MatMulSystem(nn.Module): + def __init__(self, dim=2, activation=nn.Tanh(), layers=5, hdim=32, randomizable=True): + super().__init__() + self.dim = dim + self.net = nn.Sequential(nn.Sequential(nn.Linear(dim, hdim), nn.Tanh(), + *[nn.Sequential(nn.Linear(hdim, hdim), nn.Tanh()) for i in range(4)], + nn.Linear(hdim, dim))) + self.randomizable = randomizable + + def forward(self, s, x): + return self.net(x) + + def randomize_parameters(self): + for p in self.net.parameters(): + torch.nn.init.normal_(p, 0, 1) + + +class MatMulBoundedSystem(nn.Module): + def __init__(self, dim=2, activation=nn.Tanh(), layers=5, hdim=32, randomizable=True): + super().__init__() + self.dim = dim + self.net = nn.Sequential(nn.Sequential(nn.Linear(dim, hdim), nn.Tanh(), + *[nn.Sequential(nn.Linear(hdim, hdim), nn.Tanh()) for i in range(4)], + nn.Linear(hdim, dim), + nn.Tanh())) + self.randomizable = randomizable + + def forward(self, s, x): + return self.net(x) + + def randomize_parameters(self): + for p in self.net.parameters(): + torch.nn.init.normal_(p, 0, 1) + diff --git a/torchdyn/numerics/utils.py b/torchdyn/numerics/utils.py new file mode 100644 index 0000000..c8f7c3e --- /dev/null +++ b/torchdyn/numerics/utils.py @@ -0,0 +1,129 @@ +# Licensed under the Apache 
License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" + Contains various utilities for `odeint` and numerical methods. Various norms, step size initialization, event callbacks for hybrid systems, vmapped matrix-Jacobian products and some + additional goodies. +""" +import attr +import torch +import torch.nn as nn +from torch.distributions import Exponential +from torchcde import CubicSpline, hermite_cubic_coefficients_with_backward_differences + + +def make_norm(state): + state_size = state.numel() + def norm_(aug_state): + y = aug_state[1:1 + state_size] + adj_y = aug_state[1 + state_size:1 + 2 * state_size] + return max(hairer_norm(y), hairer_norm(adj_y)) + return norm_ + + +def hairer_norm(tensor): + return tensor.abs().pow(2).mean().sqrt() + + +def init_step(f, f0, x0, t0, order, atol, rtol): + scale = atol + torch.abs(x0) * rtol + d0, d1 = hairer_norm(x0 / scale), hairer_norm(f0 / scale) + + if d0 < 1e-5 or d1 < 1e-5: + h0 = torch.tensor(1e-6, dtype=t0.dtype, device=t0.device) + else: + h0 = 0.01 * d0 / d1 + + x_new = x0 + h0 * f0 + f_new = f(t0 + h0, x_new) + d2 = hairer_norm((f_new - f0) / scale) / h0 + if d1 <= 1e-15 and d2 <= 1e-15: + h1 = torch.max(torch.tensor(1e-6, dtype=t0.dtype, device=t0.device), h0 * 1e-3) + else: + h1 = (0.01 / max(d1, d2)) ** (1. / float(order + 1)) + dt = torch.min(100 * h0, h1).to(t0) + return dt + + +@torch.no_grad() +def adapt_step(dt, error_ratio, safety, min_factor, max_factor, order): + if error_ratio == 0: return dt * max_factor + if error_ratio < 1: min_factor = torch.ones_like(dt) + exponent = torch.tensor(order, dtype=dt.dtype, device=dt.device).reciprocal() + factor = torch.min(max_factor, torch.max(safety / error_ratio ** exponent, min_factor)) + return dt * factor + + +def dense_output(sol, t_sol, t_eval, return_spline=False): + t_sol = t_sol.to(sol) + spline_coeff = hermite_cubic_coefficients_with_backward_differences(t_sol, sol.permute(1, 0, 2)) + sol_spline = CubicSpline(t_sol, spline_coeff) + sol_eval = torch.stack([sol_spline.evaluate(t) for t in t_eval]) + if return_spline: + return sol_eval, sol_spline + return sol_eval + + +class EventState: + def __init__(self, evid): + self.evid = evid + + def __ne__(self, other): + return sum([a_ != b_ for a_, b_ in zip(self.evid, other.evid)]) + + +@attr.s +class EventCallback(nn.Module): + "Basic callback for hybrid differential equations. 
Must define an event condition and a state-jump" + def __attrs_post_init__(self): + super().__init__() + + def check_event(self, t, x): + raise NotImplementedError + + def jump_map(self, t, x): + raise NotImplementedError + + +@attr.s +class StochasticEventCallback(nn.Module): + def __attrs_post_init__(self): + super().__init__() + self.expdist = Exponential(1) + + def initialize(self, x0): + self.s = self.expdist.sample(x0.shape[:1]) + + def check_event(self, t, x): + raise NotImplementedError + + def jump_map(self, t, x): + raise NotImplementedError + +class RootLogger(object): + def __init__(self): + self.data = {'geval': [], 'z': [], 'dz': [], 'iteration': [], 'alpha': [], 'phi': []} + + def log(self, logged_data): + self.data.update(**logged_data) + + def permanent_log(self, logged_data): + for key in self.data.keys(): + self.data.update({key: list(self.data[key] + logged_data[key])}) + + +class WrapFunc(nn.Module): + def __init__(self, f): + super().__init__() + self.f = f + def forward(self, t, x): return self.f(x) + diff --git a/torchdyn/utils.py b/torchdyn/utils.py new file mode 100644 index 0000000..2e44ee2 --- /dev/null +++ b/torchdyn/utils.py @@ -0,0 +1,199 @@ +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +General plotting utilities. These are used in tutorials and are designed for narrow uses. +""" +import matplotlib as mpl +import matplotlib.pyplot as plt +import numpy as np +import torch +import torch.nn as nn +from mpl_toolkits.mplot3d import Axes3D + + +def plot_2d_boundary(model, X, y, mesh, num_classes=2, figsize=(8,4), alpha=0.8): + "Plots decision boundary of a 2-dimensional task" + preds = torch.argmax(nn.Softmax(1)(model(mesh)), dim=1) + preds = preds.detach().cpu().reshape(mesh.size(0), mesh.size(1)) + plt.figure(figsize=figsize) + plt.contourf(torch.linspace(0, mesh.size(0), mesh.size(0)), torch.linspace(0, mesh.size(1), mesh.size(1)), + preds, cmap='winter', alpha=alpha, levels=10) + for i in range(num_classes): + plt.scatter(X[y==i,0], X[y==i,1], alpha=alpha) + + +def plot_2d_flows(trajectory, num_flows=2, figsize=(8,4), alpha=0.8): + "Plots datasets flows learned by a neural differential equation." + plt.figure(figsize=figsize) + plt.subplot(121) + plt.title('Dimension: 0') + for i in range(num_flows): + plt.plot(trajectory[:,i,0], color='red', alpha=alpha) + plt.subplot(122) + plt.title('Dimension: 1') + for i in range(num_flows): + plt.plot(trajectory[:,i,1], color='blue', alpha=alpha) + + +defaults_1D = {'n_grid':100, 'n_levels':30, 'x_span':[-1,1], + 'contour_alpha':0.7, 'cmap':'winter', + 'traj_color':'orange', 'traj_alpha':0.1, + 'device':'cuda:0'} + +def plot_traj_vf_1D(model, s_span, traj, device, x_span, n_grid, + n_levels=30, contour_alpha=0.7, cmap='winter', traj_color='orange', traj_alpha=0.1): + "Plots 1D datasets flows." 
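+    # evaluate the model's vector field on an (s, x) grid; trajectories, if given, are overlaid on the stream plot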
+ ss = torch.linspace(s_span[0], s_span[-1], n_grid) + xx = torch.linspace(x_span[0], x_span[-1], n_grid) + + S, X = torch.meshgrid(ss,xx) + + if model.controlled: + ax = st['ax'] + u_traj = traj[0,:,0].repeat(traj.shape[1],1) + e = torch.abs(st['y'].T - traj[:,:,0]) + color = plt.cm.coolwarm(e) + for i in range(traj.shape[1]): + tr = ax.scatter(s_span, u_traj[:,i],traj[:,i,0], + c=color[:,i],alpha=1, cmap=color[:,i],zdir='z') + norm = mpl.colors.Normalize(e.min(),e.max()) + plt.colorbar(mpl.cm.ScalarMappable(norm=norm, cmap='coolwarm'), + label='Approximation Error', orientation='horizontal') + ax.set_xlabel(r"$s$ [depth]") + ax.set_ylabel(r"$u$") + ax.set_zlabel(r"$h(s)$") + # make the panes transparent + ax.xaxis.set_pane_color((1.0, 1.0, 1.0, 0.0)) + ax.yaxis.set_pane_color((1.0, 1.0, 1.0, 0.0)) + ax.zaxis.set_pane_color((1.0, 1.0, 1.0, 0.0)) + # make the grid lines transparent + ax.xaxis._axinfo["grid"]['color'] = (1,1,1,0) + ax.yaxis._axinfo["grid"]['color'] = (1,1,1,0) + ax.zaxis._axinfo["grid"]['color'] = (1,1,1,0) + + + else: + U, V = torch.ones(n_grid, n_grid), torch.zeros(n_grid, n_grid) + for i in range(n_grid): + for j in range(n_grid): + V[i,j] = model.vf( + S[i,j].reshape(1,-1).to(device), + X[i,j].reshape(1,-1).to(device) + ).detach().cpu() + F = torch.sqrt(U**2 + V**2) + + plt.contourf(S,X,F,n_levels,cmap=cmap,alpha=contour_alpha) + plt.streamplot(S.T.numpy(),X.T.numpy(), + U.T.numpy(),V.T.numpy(), + color='black',linewidth=1) + if not traj==None: + plt.plot(s_span, traj[:,:,0], + color=traj_color,alpha=traj_alpha) + + plt.xlabel(r"$s$ [Depth]") + plt.ylabel(r"$h(s)$") + + return (S, X, U, V) + +def plot_2D_depth_trajectory(s_span, trajectory, yn, n_samples=128): + "Plots 2-dimensional trajectories of points." + color=['orange', 'blue'] + + fig = plt.figure(figsize=(8,2)) + ax0 = fig.add_subplot(121) + ax1 = fig.add_subplot(122) + for i in range(n_samples): + ax0.plot(s_span, trajectory[:,i,0], color=color[int(yn[i])], alpha=.1) + ax1.plot(s_span, trajectory[:,i,1], color=color[int(yn[i])], alpha=.1) + + ax0.set_xlabel(r"$s$ [Depth]") + ax0.set_ylabel(r"$h_0(s)$") + ax0.set_title("Dimension 0") + ax1.set_xlabel(r"$s$ [Depth]") + ax1.set_ylabel(r"$h_1(s)$") + ax1.set_title("Dimension 1") + + +def plot_2D_state_space(trajectory, yn, n_samples=128): + "Plots state-space trajectories." + color=['orange', 'blue'] + + fig = plt.figure(figsize=(3,3)) + ax = fig.add_subplot(111) + for i in range(n_samples): + ax.plot(trajectory[:,i,0], trajectory[:,i,1], color=color[int(yn[i])], alpha=.1); + + ax.set_xlabel(r"$h_0$") + ax.set_ylabel(r"$h_1$") + ax.set_title("Flows in the state-space") + + +def plot_2D_space_depth(s_span, trajectory, yn, n_lines): + "Plots 2D trajectories in a 3D space (2 dimensions of the system + time)." + colors = ['orange', 'blue'] + fig = plt.figure(figsize=(6,3)) + ax = Axes3D(fig, auto_add_to_figure=False) + fig.add_axes(ax) + for i in range(n_lines): + ax.plot(s_span, trajectory[:,i,0], trajectory[:,i,1], color=colors[yn[i].int()], alpha = .1) + ax.view_init(30, -110) + + ax.set_xlabel(r"$s$ [Depth]") + ax.set_ylabel(r"$h_0$") + ax.set_zlabel(r"$h_1$") + ax.set_title("Flows in the space-depth") + ax.xaxis._axinfo["grid"]['color'] = (1,1,1,0) + ax.yaxis._axinfo["grid"]['color'] = (1,1,1,0) + ax.zaxis._axinfo["grid"]['color'] = (1,1,1,0) + + +def plot_static_vector_field(model, trajectory, t=0., N=50, device='cuda'): + "Plots vector field and trajectories on it." 
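+    # sample an N x N grid over the trajectory's bounding box and query the model's vector field at each node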
+ x = torch.linspace(trajectory[:,:,0].min(), trajectory[:,:,0].max(), N) + y = torch.linspace(trajectory[:,:,1].min(), trajectory[:,:,1].max(), N) + X, Y = torch.meshgrid(x,y) + U, V = torch.zeros(N,N), torch.zeros(N,N) + + for i in range(N): + for j in range(N): + p = torch.cat([X[i,j].reshape(1,1), Y[i,j].reshape(1,1)],1).to(device) + O = model.defunc(t,p).detach().cpu() + U[i,j], V[i,j] = O[0,0], O[0,1] + + fig = plt.figure(figsize=(3,3)) + ax = fig.add_subplot(111) + ax.contourf(X, Y, torch.sqrt(U**2 + V**2), cmap='RdYlBu') + ax.streamplot(X.T.numpy(),Y.T.numpy(),U.T.numpy(),V.T.numpy(), color='k') + + ax.set_xlim([x.min(),x.max()]) + ax.set_ylim([y.min(),y.max()]) + ax.set_xlabel(r"$h_0$") + ax.set_ylabel(r"$h_1$") + ax.set_title("Learned Vector Field") + + +def plot_3D_dataset(X, yn): + "Plots set of points in 3D." + colors = ['orange', 'blue'] + fig = plt.figure(figsize=(4,4)) + ax = Axes3D(fig) + for i in range(len(X)): + ax.scatter(X[:,0],X[:,1],X[:,2], color=colors[yn[i].int()], alpha = .1) + ax.set_xlabel(r"$h_0$") + ax.set_ylabel(r"$h_1$") + ax.set_zlabel(r"$h_2$") + ax.set_title("Data Points") + ax.xaxis._axinfo["grid"]['color'] = (1,1,1,0) + ax.yaxis._axinfo["grid"]['color'] = (1,1,1,0) + ax.zaxis._axinfo["grid"]['color'] = (1,1,1,0) + diff --git a/train.py b/train.py index 2f5093c..8047526 100644 --- a/train.py +++ b/train.py @@ -185,8 +185,8 @@ def run(rank, n_gpus, hps): fm_losses.append(fm_loss.item()) if batch_idx % 5 == 0: - msg = (f'Epoch: {epoch}, iteration: {iteration} | dur_loss: {dur_loss.item():.4f}, prior_loss: {prior_loss.item():.4f}, ' - f'flow_loss: {fm_loss.item():.4f}, mle loss: {l_mle.item():.4f}') + msg = (f'Epoch: {epoch}, iter: {iteration} | dur_loss: {dur_loss.item():.3f}, prior_loss: {prior_loss.item():.3f}, ' + f'flow_loss: {fm_loss.item():.3f}, mle loss: {l_mle.item():.3f}') # logger_text.info(msg) progress_bar.set_description(msg)
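A minimal usage sketch, assuming the vendored `torchdyn/` package above is on the import path; `f` and the initial batch `x` are illustrative placeholders rather than anything defined in this repository:

import torch
from torchdyn.numerics.solvers.ode import str_to_solver

f = lambda t, x: -x                              # toy vector field (placeholder)
x = torch.randn(8, 2)                            # batch of initial states
t_span = torch.linspace(0., 1., 11)

solver = str_to_solver('rk4')                    # looked up in SOLVER_DICT
x, t_span = solver.sync_device_dtype(x, t_span)  # move tableau and t_span alongside x
for t0, t1 in zip(t_span[:-1], t_span[1:]):
    _, x, _ = solver.step(f, x, t0, t1 - t0)     # explicit fixed-step update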