parameter groups and lr scheduling support
inikishev committed Jan 14, 2025
1 parent 827e3f1 commit 48c071b
Showing 27 changed files with 751 additions and 142 deletions.
35 changes: 24 additions & 11 deletions README.md
@@ -1,20 +1,25 @@
![example workflow](https://github.com/inikishev/torchzero/actions/workflows/tests.yml/badge.svg)

# torchzero

This is a work-in-progress optimizer library for PyTorch with composable zeroth-, first- and second-order and quasi-Newton methods, gradient approximation, line searches and a whole lot of other stuff.

Most optimizers are modular, meaning you can chain them like this:

```py
optimizer = torchzero.optim.Modular(model.parameters(), [*list of modules*])
```

For example you might use `[ClipNorm(4), LR(1e-3), NesterovMomentum(0.9)]` for standard SGD with gradient clipping and Nesterov momentum. Move `ClipNorm` to the end to clip the update instead of the gradients. If you don't have access to gradients, add a `RandomizedFDM()` at the beginning to approximate them via randomized finite differences. Add `Cautious()` to make the optimizer cautious.
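
A minimal sketch of that first example, using the `tz.m` alias for `torchzero.modules` from this repo's `__init__` (the surrounding `model` is assumed):

```py
import torchzero as tz

# standard SGD with gradient norm clipping and Nesterov momentum, as a module chain
optimizer = tz.optim.Modular(
    model.parameters(),
    [tz.m.ClipNorm(4), tz.m.LR(1e-3), tz.m.NesterovMomentum(0.9)],
)
```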

Each new module takes the previous module's update and works on it. That way there is no need to reimplement stuff like Laplacian smoothing for all optimizers, and it is easy to experiment with grafting, interpolation between different optimizers, and perhaps some weirder combinations like nested momentum.

# How to use

All modules are defined in `torchzero.modules`. You can generally mix and match them however you want. Some pre-made optimizers are available in `torchzero.optim`.

Some optimizers require a closure, which should look like this:

```py
def closure(backward = True):
    preds = model(inputs)
    # the rest of the body is folded in the diff view; a typical closure
    # computes the loss and backpropagates only when backward=True
    loss = loss_fn(preds, targets)  # loss_fn and targets assumed
    if backward:
        optimizer.zero_grad()
        loss.backward()
    return loss

optimizer.step(closure)
```

This closure will also work with all built-in PyTorch optimizers, including LBFGS, all optimizers in this library, as well as most custom ones.
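
Since `backward` defaults to `True`, the same closure can be passed verbatim to a built-in optimizer, e.g. (sketch, reusing `model` and `closure` from above):

```py
import torch

lbfgs = torch.optim.LBFGS(model.parameters())
lbfgs.step(closure)  # LBFGS calls closure() itself and uses the returned loss
```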

# Contents

There will be docs with a more exhaustive list and explanations. A preliminary list of all modules is available here <https://torchzero.readthedocs.io/en/latest/autoapi/torchzero/modules/index.html#classes>. For now, I hope everything is reasonably straightforward to use.

- SGD/Rprop/RMSProp/AdaGrad/Adam as composable modules. They are also tested to exactly match built-in PyTorch versions.
- Cautious Optimizers (<https://huggingface.co/papers/2411.16085>)
- Optimizer grafting (<https://openreview.net/forum?id=FpKgG31Z_i9>)
- Laplacian smoothing (<https://arxiv.org/abs/1806.06317>)
- Polyak momentum, Nesterov momentum
- Gradient norm and value clipping, gradient normalization
- Gradient centralization (<https://arxiv.org/abs/2004.01461>)
- Learning rate dropout (<https://pubmed.ncbi.nlm.nih.gov/35286266/>)
- Forward gradient (<https://arxiv.org/abs/2202.08587>)
- Gradient approximation via finite difference or randomized finite difference, which includes SPSA, RDSA, FDSA and Gaussian smoothing (<https://arxiv.org/abs/2211.13566v3>)
- Various line searches
- Exact Newton's method (with Levenberg-Marquardt regularization), Newton with Hessian approximation via finite differences, subspace finite-differences Newton.
- Directional newton via one additional forward pass
@@ -52,18 +60,23 @@
All modules should be quite fast, especially on models with many different parameters.

I am getting to the point where I can start focusing on good docs and tests. As of now, the code should be considered experimental, untested and subject to change, so feel free to use it, but be careful if using it for an actual project.


# Wrappers

### scipy.optimize.minimize wrapper

A `scipy.optimize.minimize` wrapper with support for both gradient and Hessian via batched autograd.

```py
from torchzero.optim.wrappers.scipy import ScipyMinimize
opt = ScipyMinimize(model.parameters(), method = 'trust-krylov')
```

Use as any other optimizer (make sure the closure accepts a `backward` argument like the one from **How to use**). Note that it performs full minimization on each step.
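
A hedged end-to-end sketch, reusing the closure pattern from **How to use** (`inputs`, `targets` and the choice of loss are assumptions):

```py
import torch.nn.functional as F
from torchzero.optim.wrappers.scipy import ScipyMinimize

opt = ScipyMinimize(model.parameters(), method = 'trust-krylov')

def closure(backward = True):
    loss = F.mse_loss(model(inputs), targets)
    if backward:
        opt.zero_grad()
        loss.backward()
    return loss

opt.step(closure)  # one step runs a full scipy.optimize.minimize
```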

### Nevergrad wrapper

```py
import nevergrad as ng
# wrapper import path assumed by analogy with the scipy wrapper above
from torchzero.optim.wrappers.nevergrad import NevergradOptimizer

opt = NevergradOptimizer(bench.parameters(), ng.optimizers.NGOptBase, budget = 1000)
```

Use as any other optimizer (make sure the closure accepts a `backward` argument like the one from **How to use**).
3 changes: 2 additions & 1 deletion src/torchzero/__init__.py
@@ -1,2 +1,3 @@
from . import tensorlist as tl # this needs to be imported first to avoid circular imports
from . import optim, modules as m, core
from .optim import Modular
2 changes: 1 addition & 1 deletion src/torchzero/core/__init__.py
@@ -1,2 +1,2 @@
from .module import OptimizerModule, _get_loss, _ClosureType, OptimizationState, _Chain, _Chainable
from .tensorlist_optimizer import TensorListOptimizer, ParamsT
58 changes: 49 additions & 9 deletions src/torchzero/core/module.py
@@ -28,6 +28,20 @@ def closure(backward = True):
This closure will also work with all built-in PyTorch optimizers including LBFGS, as well as most custom ones.
"""

class _WeakDict(dict): pass

def _get_param_groups_to_pass_to_child(optimizer: torch.optim.Optimizer):
"""propagate only per-parameter settings that are not in optimizer.defaults"""
param_groups: list[dict[str, Any]] = []
for g in optimizer.param_groups.copy():
child_g = {"params": g["params"]}
for k,v in g.copy().items():
if k not in optimizer.defaults:
child_g[k] = v
param_groups.append(child_g)

return param_groups
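
A hypothetical illustration of the rule (names and values invented for the example):

```py
# suppose parent.defaults == {'lr': 0.1} and
# parent.param_groups[0] == {'params': ps, 'lr': 0.05, 'custom_setting': True}
groups = _get_param_groups_to_pass_to_child(parent)
# groups == [{'params': ps, 'custom_setting': True}]
# 'lr' is in parent.defaults, so it stays behind and can't overwrite a child's own 'lr'
```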

class OptimizationState:
"""Holds optimization state. This is usually automatically created by :any:`torchzero.optim.Modular`."""
def __init__(self, closure: _ClosureType | None, model: torch.nn.Module | None):
@@ -59,12 +73,15 @@ def __init__(self, closure: _ClosureType | None, model: torch.nn.Module | None):
self.model = model
"""Model (for higher order derivatives)"""

self.post_step_hooks = []
"""callables that get executed after each step. Used by periodic SWA to reset momentum when setting model parameters to SWA."""

def maybe_compute_grad_(self, params: TensorList) -> TensorList:
"""Computes gradient if it hasn't been computed already, and returns it"""
if self.grad is None:

if self.closure is not None:
with torch.enable_grad(): self.fx0 = self.closure() # pylint:disable = not-callable (???)
self.grad = params.ensure_grad_().grad

return self.grad
@@ -115,6 +132,13 @@ def update_attrs_(self, state: "OptimizationState"):
if state.fx0_approx is not None: self.fx0_approx = state.fx0_approx


def add_post_step_hook(self, hook: Callable):
"""add a hook that runs after each step. The hook should look like this:
.. code:: py
def hook(optimizer: tz.optim.Modular, state: tz.core.OptimizationState): ...
"""
self.post_step_hooks.append(hook)
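
For instance, a hypothetical hook matching the signature in the docstring above:

```py
def log_loss_hook(optimizer, state):
    # runs once after each step of the Modular optimizer
    print('loss after step:', state.fx0)

state.add_post_step_hook(log_loss_hook)
```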

class OptimizerModule(TensorListOptimizer, ABC):
r"""Base class for all modules.
@@ -154,6 +178,7 @@ def _initialize_(self, params: ParamsT):
"""Initializes this optimizer and all children with the given parameters."""
if isinstance(params, torch.Tensor): raise ValueError("Params must be an iterable of tensors, not torch.Tensor")
params = list(params) # type:ignore

# super().__init__, which is torch.optim.Optimizer.__init__,
# calls self.add_param_group on each param group,
# which in turn calls _update_child_params_,
@@ -163,7 +188,7 @@ def _initialize_(self, params: ParamsT):

def _set_child_(self, name, child: "_Chainable"):
"""Set a child and initialize it's params."""
if not isinstance(child, OptimizerModule): child = _Chain(child)
self.children[name] = child
if self._initialized:
self._update_child_params_(child)
@@ -185,20 +210,23 @@ def _update_next_module_params_(self, next_module: "OptimizerModule"):

# if child is not initialized, torch.optim.Optimizer.__init__ is called on it by _initialize_ method
if not next_module._initialized:

# propagate only per-parameter settings that are not in self.defaults
next_module._initialize_(_get_param_groups_to_pass_to_child(self))

# otherwise, to avoid calling __init__ twice, we erase the param groups and re-add them
elif not next_module._has_custom_params:
next_module.param_groups = []
# it is important not to propagate all the settings
# for example if this module has `lr` setting, and the child has a different `lr` setting,
# we don't want to overwrite the child's `lr` setting
for group in _get_param_groups_to_pass_to_child(self):
next_module.add_param_group(group)


def add_param_group(self, param_group: dict[str, Any]) -> None:
super().add_param_group(param_group)

if self.next_module is not None: self._update_next_module_params_(self.next_module)
for c in self.children.values():
self._update_child_params_(c)
@@ -304,6 +332,18 @@ def return_ascent(self, state: OptimizationState, params=None) -> TensorList:
self.next_module = true_next
return ascent

def reset_stats(self):
"""Resets running stats of this optimizer such as momentum. This is meant to be used stop all
momentum when significantly changing model parameters, for example when setting model parameters
to weighted average every once in a while, like periodic SWA does. Pediodic resetting
may also be beneficial for some optimizers.
By default this method completely clears per-parameter state.
Modules may override this to provide different functionality."""
for g in self.param_groups:
for p in g['params']:
state = self.state[p]
for k in state.copy().keys(): del state[k]

class _ReturnAscent:
def __init__(self, params):
self.params = params
@@ -330,7 +370,7 @@ def step(self, state: OptimizationState):

_Chainable = OptimizerModule | Iterable[OptimizerModule]

class _Chain(OptimizerModule):
"""
Utility module that chains multiple modules together, usually used by other modules.
"""
5 changes: 3 additions & 2 deletions src/torchzero/modules/__init__.py
@@ -3,6 +3,7 @@
"""

from ..core.module import OptimizerModule
from . import experimental
from .adaptive import *
from .gradient_approximation import *
from .line_search import *
@@ -11,10 +12,10 @@
from .momentum import *
from .operations import *
from .optimizers import *
from .orthogonalization import *
from .quasi_newton import *
from .regularization import *
from .scheduling import *
from .second_order import *
from .smoothing import *
from .weight_averaging import *
2 changes: 1 addition & 1 deletion src/torchzero/modules/experimental/__init__.py
@@ -1,7 +1,7 @@
"""Optimizers that I haven't tested and various (mostly stupid) ideas go there.
If something works well I will move it outside of experimental folder.
Otherwise all optimizers in this category should be considered unlikely to good for most tasks."""
from .experimental import GradMin, HVPDiagNewton, MinibatchRprop, CyclicSWA
from .subspace import (
Proj2Masks,
ProjAscent,
76 changes: 74 additions & 2 deletions src/torchzero/modules/experimental/experimental.py
@@ -58,7 +58,7 @@ def step(self, state):
# first step
ascent = g1_sign.mul_(magnitudes).mul_(allowed)
params -= ascent
with torch.enable_grad(): state.fx0_approx = state.closure()
f0 = state.fx0; f1 = state.fx0_approx
assert f0 is not None and f1 is not None

@@ -191,7 +191,7 @@ def step(self, state):
state.grad = grad_fx0 # set state grad to the cloned version, since it will be overwritten

params += grad_fx0 * eps
with torch.enable_grad(): _ = state.closure()

params -= grad_fx0 * eps

@@ -200,3 +200,75 @@ def step(self, state):

state.ascent = newton
return self._update_params_or_step_with_next(state)


def _reset_stats_hook(optimizer, state):
for module in optimizer.modules:
module: OptimizerModule
module.reset_stats()

class CyclicSWA(OptimizerModule):
"""I remember reading about this but I have no idea if I actually red it or if I dreamed it up. I am not able to find the paper.
This is just periodic SWA with cyclic learning rate. So it samples the weights, increases lr to `peak_lr`, samples the weights again,
decreases lr back to `init_lr`, and samples the weights last time. Then model weights are replaced with the average of the three sampled weights,
and next cycle starts.
It is easier to tune than PeriodicSWA and seems to work better too.
"""
def __init__(self, cswa_start: int, cycle_length: int, steps_between: int, init_lr: float = 0, peak_lr: float = 1):

super().__init__({})
self.cswa_start = cswa_start
self.cycle_length = cycle_length
self.init_lr = init_lr
self.peak_lr = peak_lr
self.steps_between = steps_between

self.cur = 0
self.cycle_cur = 0
self.n_models = 0

self.cur_lr = self.init_lr

def step(self, state):
params = self.get_params()

# start first period after `cswa_start` steps
if self.cur >= self.cswa_start:

ascent = state.maybe_use_grad_(params)

# determine the lr: a piecewise-linear (triangular) schedule that ramps
# from init_lr up to peak_lr over the first half of the cycle and back down
point = self.cycle_cur / self.cycle_length
if point < 0.5:
p2 = point*2
lr = self.init_lr * (1-p2) + self.peak_lr * p2
else:
p2 = (1 - point)*2
lr = self.init_lr * (1-p2) + self.peak_lr * p2

ascent *= lr
ret = self._update_params_or_step_with_next(state, params)

if self.cycle_cur in (0, self.cycle_length, self.cycle_length // 2):
swa = self.get_state_key('swa')
swa.mul_(self.n_models).add_(params).div_(self.n_models + 1)
self.n_models += 1

if self.cycle_cur == self.cycle_length:
assert self.n_models == 3, self.n_models
self.n_models = 0
self.cycle_cur = -1

params.set_(swa)
state.add_post_step_hook(_reset_stats_hook)

self.cycle_cur += 1

else:
ret = self._update_params_or_step_with_next(state, params)

self.cur += 1

return ret
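
A hedged usage sketch (constructor arguments are from this diff; the `tz.m.experimental` path follows from the `modules/__init__.py` change above, and all numbers are invented):

```py
import torchzero as tz

opt = tz.optim.Modular(
    model.parameters(),
    [tz.m.experimental.CyclicSWA(cswa_start=1000, cycle_length=200,
                                 steps_between=50, init_lr=0.0, peak_lr=0.5)],
)
```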
2 changes: 1 addition & 1 deletion src/torchzero/modules/experimental/subspace.py
@@ -5,7 +5,7 @@
import torch

from ... import tensorlist as tl
from ...core import OptimizationState, OptimizerModule, _Chain
# this whole thing can also be implemented via parameter vectors.
# Need to test which one is more efficient...

@@ -72,7 +72,7 @@ def get_forward_gradient(
jvp = fwAD.unpack_dual(loss).tangent

elif mode == 'grad':
with torch.enable_grad(): loss = closure()
jvp = tangents.mul(params.ensure_grad_().grad).sum()

elif mode == 'fd':