
Commit

better lr scheduling + docs
inikishev committed Jan 16, 2025
1 parent bfbf3a2 commit 7cb158e
Showing 15 changed files with 563 additions and 119 deletions.
254 changes: 252 additions & 2 deletions docs/source/FAQ.rst
@@ -1,4 +1,254 @@
FAQ
==================
###########

TODO!

How to perform optimization?
============================
Most torchzero optimizers can be used in the same way as built-in PyTorch optimizers:

.. code:: python

    import torchzero as tz

    opt = tz.Modular(
        model.parameters(),
        [tz.m.Adam(), tz.m.LR(1e-3), tz.m.WeightDecay()]
    )

    for inputs, targets in dataloader:
        preds = model(inputs)
        loss = loss_fn(preds, targets)
        loss.backward()
        opt.step()
        opt.zero_grad()
A few modules and optimizers require a closure, similar to :code:`torch.optim.LBFGS`, but with an additional :code:`backward` argument which, if True, calls :code:`opt.zero_grad()` and :code:`loss.backward()`. The name of the argument doesn't matter, but I will refer to it as :code:`backward`.

All line-searches and gradient-approximation modules, as well as a few others, require a closure. A training loop with a closure looks like this:

.. code:: python

    import torchzero as tz

    opt = tz.Modular(
        model.parameters(),
        [tz.m.Adam(), tz.m.LR(1e-3), tz.m.WeightDecay()]
    )

    for inputs, targets in dataloader:
        def closure(backward=True):
            preds = model(inputs)
            loss = loss_fn(preds, targets)
            if backward:
                opt.zero_grad()
                loss.backward()
            return loss

        loss = opt.step(closure)
Note that all built-in PyTorch optimizers, as well as most custom ones, support closures too! So the training loop above works with all other optimizers out of the box, and switching to it means you don't have to rewrite the training loop when changing optimizers.
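
For example, the same closure works with a built-in optimizer (a minimal sketch, reusing :code:`model`, :code:`loss_fn` and :code:`dataloader` from above; :code:`torch.optim.SGD` calls the closure without arguments, so :code:`backward` defaults to True):

.. code:: python

    import torch

    opt = torch.optim.SGD(model.parameters(), lr=1e-2)

    for inputs, targets in dataloader:
        def closure(backward=True):
            preds = model(inputs)
            loss = loss_fn(preds, targets)
            if backward:
                opt.zero_grad()
                loss.backward()
            return loss

        loss = opt.step(closure)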

If you intend to use gradient-free methods, the :code:`backward` argument is still required in the closure; simply leave it unused. Gradient-free and gradient-approximation methods always call the closure with :code:`backward=False`.
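
A minimal sketch of such a closure (the :code:`backward` argument is accepted but never used, since gradient-free methods call the closure with :code:`backward=False`):

.. code:: python

    def closure(backward=True):
        # backward is ignored: gradient-free methods never need gradients
        preds = model(inputs)
        return loss_fn(preds, targets)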

How to construct modular optimizers?
=====================================
A modular optimizer can be created using the :py:class:`tz.Modular` class. It can be constructed as :code:`tz.Modular(params, *modules)`, or as :code:`tz.Modular(params, [modules])`.

All modules are available in :code:`tz.m` namespace, e.g. :py:class:`tz.m.Adam`.

.. code:: python

    import torchzero as tz

    # construct it like this
    opt = tz.Modular(
        model.parameters(),
        [tz.m.Adam(), tz.m.LR(1e-3), tz.m.Cautious(), tz.m.WeightDecay()]
    )

    # or like this
    opt = tz.Modular(
        model.parameters(),
        tz.m.Adam(),
        tz.m.LR(1e-3),
        tz.m.Cautious(),
        tz.m.WeightDecay(),
    )
In the example above, :code:`Adam`, being the first module, takes in the gradient, applies the Adam update rule, and passes the resulting update to the next module, :code:`LR`. :code:`LR` multiplies the update by the learning rate and passes it to :code:`Cautious`, which applies cautioning and passes it to :code:`WeightDecay`, which adds a weight decay penalty. The resulting update is then subtracted from the model parameters.

It is recommended to always add an :py:class:`tz.m.LR` module to support lr schedulers and per-layer learning rates (see :ref:`How do we handle learning rates?`).

Most modules perform gradient transformations: they take in an ascent direction, which is initially the gradient, transform it in some way, and pass it to the next module. The first module in the chain usually uses the gradient as the initial ascent direction.

Certain modules, such as gradient-approximation ones or :py:class:`tz.m.ExactNewton`, create an ascent direction "from scratch", so they should be placed first in the chain.
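
For instance, a hedged sketch that places a gradient-approximation module first (:py:class:`tz.m.RandomizedFDM` is used here because it is mentioned elsewhere in this FAQ; such a setup requires the closure-based training loop shown above):

.. code:: python

    import torchzero as tz

    # RandomizedFDM creates the initial ascent direction from scratch,
    # Adam transforms it, and LR scales it
    opt = tz.Modular(
        model.parameters(),
        [tz.m.RandomizedFDM(), tz.m.Adam(), tz.m.LR(1e-3)]
    )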

Any external PyTorch optimizer can also be used as a chainable module by using :py:class:`tz.m.Wrap` and :py:class:`tz.m.WrapClosure` (see :ref:`How to use external PyTorch optimizers as chainable modules?`).

How to use learning rate schedulers?
=============================================
There are two primary methods for using learning rate schedulers.
One method is to pass a learning rate scheduler factory to the :py:class:`tz.m.LR` module through its :code:`scheduler_cls` argument, like this:

.. code:: python

    from torch.optim.lr_scheduler import OneCycleLR

    opt = tz.Modular(
        model.parameters(),
        tz.m.Adam(),
        tz.m.LR(1e-1, scheduler_cls = lambda opt: OneCycleLR(opt, max_lr = 1e-1, total_steps = 60_000)),
        tz.m.WeightDecay(),
    )
This method also supports cycling momentum, which some schedulers, such as OneCycleLR, do. Momentum will be cycled on all modules that have :code:`momentum` or :code:`beta1` parameters.

Alternatively, a learning rate scheduler can be created separately by passing it the LR module, which can be accessed with the :code:`get_lr_module` method like this:

.. code:: python

    opt = tz.Modular(
        model.parameters(),
        [tz.m.Adam(), tz.m.LR(1e-3), tz.m.WeightDecay()]
    )
    scheduler = OneCycleLR(opt.get_lr_module(), max_lr = 1e-1, total_steps = 60_000)
Here :code:`get_lr_module` returns the :py:class:`tz.m.LR` module, even if it is nested somewhere inside other modules. You can then call :code:`scheduler.step()` as usual. This method does not support cycling momentum.
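
A minimal sketch of the corresponding training loop, assuming a per-step scheduler such as OneCycleLR:

.. code:: python

    for inputs, targets in dataloader:
        preds = model(inputs)
        loss = loss_fn(preds, targets)
        loss.backward()
        opt.step()
        opt.zero_grad()
        scheduler.step()  # steps the LR module returned by get_lr_module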


How to specify per-parameter options?
=============================================
In PyTorch it is possible to specify per-layer options, such as the learning rate, using parameter groups. In torchzero those are specified in almost the same way (although there is a catch):

.. code:: python

    param_groups = [
        {'params': model.encoder.parameters(), 'lr': 1e-2, 'eps': 1e-5},
        {'params': model.decoder.parameters()}
    ]
    optimizer = tz.Modular(
        param_groups,
        [tz.m.Adam(), tz.m.LR(1e-3), tz.m.WeightDecay()]
    )
In the example above, :code:`model.encoder` will use a custom learning rate of 1e-2 and a custom Adam epsilon of 1e-5, while :code:`model.decoder` will stick to the default learning rate of 1e-3 and the default epsilon value.

The catch is that when you specify a setting such as `eps`, it will be applied to ALL modules that have that setting, which may lead to unexpected behavior. For example, both :py:class:`tz.m.Adam` and :py:class:`tz.m.RandomizedFDM` have an `eps` parameter, but with completely different functions and value ranges. To avoid this, per-parameter settings can be specified for a specific module by using its `set_params` method:

.. code:: python

    adam_param_groups = [
        {'params': model.encoder.parameters(), 'lr': 1e-2, 'eps': 1e-5},
        {'params': model.decoder.parameters()}
    ]

    # 1. create adam
    adam = tz.m.Adam()

    # 2. pass custom parameter groups to adam
    adam.set_params(adam_param_groups)

    # 3. create modular optimizer after passing custom parameter groups
    optimizer = tz.Modular(
        param_groups,
        [adam, tz.m.LR(1e-3), tz.m.WeightDecay()]
    )
You don't have to worry about this if you are only setting per-layer lr, because the only module that has an :code:`lr` setting is :py:class:`tz.m.LR` (see :ref:`How do we handle learning rates?`).
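
For example, a minimal sketch that only sets per-layer learning rates, which therefore only affects the :py:class:`tz.m.LR` module:

.. code:: python

    param_groups = [
        {'params': model.encoder.parameters(), 'lr': 1e-2},
        {'params': model.decoder.parameters()}
    ]
    optimizer = tz.Modular(
        param_groups,
        [tz.m.Adam(), tz.m.LR(1e-3), tz.m.WeightDecay()]
    )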

How do we handle learning rates?
=================================
Certain optimizers, like Adam, have the learning rate built into the update rule. Using multiple such modules can result in unintended compounding of learning rate modifications.

To avoid this, the learning rate should be applied by a single :py:class:`tz.m.LR` module. All other modules with a learning rate, such as :py:class:`tz.m.Adam`, have `lr` renamed to `alpha` with a default value of 1, so they don't rescale the update.

For example:

.. code:: python

    tz.Modular(
        model.parameters(),
        [tz.m.Adam(), tz.m.LR(1e-3), tz.m.WeightDecay()]
    )
Here, instead of using Adam's `alpha` setting, we added an :code:`LR` module. This allows the modular optimizer to support per-parameter `lr` settings and learning rate schedulers, without having to worry about learning rate compounding.
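
For comparison, a hedged sketch of setting the learning rate through Adam's `alpha` instead (assuming `alpha` is a constructor argument, as the description above suggests); this bypasses the :code:`LR` module, so lr schedulers and per-parameter `lr` would not apply:

.. code:: python

    # works, but not recommended: there is no LR module for schedulers to hook into
    tz.Modular(
        model.parameters(),
        [tz.m.Adam(alpha=1e-3), tz.m.WeightDecay()]
    )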

See also:

* :ref:`How to use learning rate schedulers?`
* :ref:`How to specify per-parameter options?`

How to use external PyTorch optimizers as chainable modules?
============================================================
In addition to torchzero modules, any PyTorch optimizer can be used as a module using :py:class:`tz.m.Wrap`.

There are two slightly different ways to construct a :code:`Wrap` module. Here I will convert the :code:`LaProp` optimizer from the `pytorch_optimizer <https://pytorch-optimizers.readthedocs.io/en/latest/optimizer/#pytorch_optimizer.LaProp>`_ library into a module and chain it with :py:class:`tz.m.Cautious`:

.. code:: py

    from pytorch_optimizer import LaProp

    # first way
    tz.Modular(
        model.parameters(),
        tz.m.ClipNorm(1),
        tz.m.Wrap(LaProp, lr = 1, betas = (0.9, 0.99)),
        tz.m.LR(1e-3),
        tz.m.Cautious(),
    )

    # second way (identical but more verbose)
    tz.Modular(
        model.parameters(),
        tz.m.ClipNorm(1),
        tz.m.Wrap(LaProp(model.parameters(), lr = 1, betas = (0.9, 0.99))),
        tz.m.LR(1e-3),
        tz.m.Cautious(),
    )
Most PyTorch optimizers update model parameters by using their :code:`.grad` attribute. Wrap puts the current update into :code:`.grad`, making the wrapped optimizer use it instead of the gradient.

Note that since the wrapped optimizer updates model parameters directly, if :py:class:`tz.m.Wrap` is not the last module, it stores the model parameters before the step, performs a step with the wrapped optimizer, calculates the update as the difference between the parameters before and after the step, and then undoes the step. That may introduce additional overhead compared to using native torchzero modules.

However when :py:class:`tz.m.Wrap` is the last module in the chain, it simply makes a step with the wrapped optimizer, so no overhead is introduced.

Also notice how I set `lr` to 1 in LaProp, and instead used an :py:class:`tz.m.LR` module. As usual, to make the optimizer support lr scheduling and per-layer learning rates, use the :py:class:`tz.m.LR` module to set the learning rate.

There is also a :py:class:`tz.m.WrapClosure` module for optimizers that require a closure, such as :code:`torch.optim.LBFGS`. It modifies the closure to set the :code:`.grad` attribute on each closure evaluation, so you can use LBFGS with FDM or gradient-smoothing methods.
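
A hedged sketch of that combination, assuming :code:`WrapClosure` accepts an optimizer class plus keyword arguments the same way :code:`Wrap` does, and using :py:class:`tz.m.RandomizedFDM` as the gradient-approximation module; the training loop would use the closure form shown earlier:

.. code:: python

    import torch
    import torchzero as tz

    opt = tz.Modular(
        model.parameters(),
        tz.m.RandomizedFDM(),                       # estimates gradients without backward()
        tz.m.WrapClosure(torch.optim.LBFGS, lr=1),  # LBFGS consumes the estimated gradients
    )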

How to save/serialize a modular optimizer?
============================================
TODO

How much overhead does a torchzero modular optimizer have compared to a normal optimizer?
==========================================================================================
A thorough benchmark will be posted to this section very soon. There is no overhead other than what is described below.

Some optimizers, like Adam, have the learning rate baked into the update rule, but torchzero applies it with a separate LR module instead, which requires an extra operation. Currently, if :py:class:`tz.m.Adam` is directly followed by a :py:class:`tz.m.LR`, they will be automatically fused. However, adding LR fusing to all modules with a learning rate is not a priority, unless I find that it makes a non-negligible difference to performance.

Whenever possible, I use :code:`_foreach_xxx` operations. Those make the optimizers significantly faster, especially with many separate parameter tensors. All modules also modify the update in-place whenever possible.

Is there support for complex-valued parameters?
=================================================
Currently no, as I have not designed the modules with complex-valued parameters in mind, although some might still work. I do use complex-valued networks, so I am looking into adding support. There may actually be a way to support them automatically.

Is there support for optimized parameters being on different devices?
======================================================================
TODO

Is there support for FSDP (FullyShardedDataParallel)?
======================================================
There is no support for FSDP. It may be possible to add an FSDP module; I will look into it at some point. Currently I can't even test FSDP because I only have one laptop.

Is there support for differentiable optimizers?
======================================================
There is no support for differentiable optimizers.

In PyTorch, most optimizers have a :code:`differentiable` argument that runs autograd through the optimizer step, for example :code:`torch.optim.Adam(params, 1e-3, differentiable=True)`.

I have not looked into this yet; adding support may or may not be as easy as switching the :code:`@torch.no_grad` decorator to :code:`@_use_grad_for_differentiable`.
31 changes: 28 additions & 3 deletions src/torchzero/core/module.py
@@ -1,4 +1,4 @@
from typing import Any, TypeAlias
from typing import Any, TypeAlias, Self
from abc import ABC, abstractmethod
from collections.abc import Callable, Sequence, Iterable
import warnings
@@ -17,8 +17,11 @@ def _get_loss(fx0, fx0_approx):

_ClosureType = Callable[..., _ScalarLoss] #
"""
Closure example:

.. code-block:: python

    def closure(backward = True):
        loss = model(inputs)
        if backward:
@@ -72,10 +75,18 @@ def __init__(self, closure: _ClosureType | None, model: torch.nn.Module | None):
Gradient must be evaluated strictly with initial parameters of the current step"""

self.model = model
"""Model (for higher order derivatives)"""
"""model itself (torch.nn.Module) if it was passed, otherwise None."""

self.post_step_hooks = []
"""callables that get executed after each step. Used by periodic SWA to reset momentum when setting model parameters to SWA."""
"""callables that get executed after each step. Used by periodic SWA to reset momentum when setting model parameters to SWA.
Signature:
.. code:: py
def hook(optimizer: ModularOptimizer, state: OptimizationState) -> None:
...
"""

def maybe_compute_grad_(self, params: TensorList) -> TensorList:
"""Computes gradient if it hasn't been computed already, and returns it"""
@@ -183,6 +194,19 @@ def __init__(self, defaults: dict[str, Any], make_closure = False): # pylint:dis
self._passed_params: list[torch.Tensor] | list[dict[str, Any]] | None = None
"""list of parameters or parameter groups that were passed to this module and will get passed to child modules."""

self.post_init_hooks: list[Callable[[Any, Self], Any]] = []
"""Hooks that run once after a ModularOptimizer is initialized with this module.

Signature:

.. code:: py

    def hook(optimizer: ModularOptimizer, module: OptimizerModule) -> None:
        ...

where `module` is this module.
"""

def __repr__(self):
    if self._initialized: return super().__repr__()
    return f"uninitialized {self.__class__.__name__}()"
@@ -373,6 +397,7 @@ def reset_stats(self):
for k in state.copy().keys(): del state[k]

class _ReturnAscent:
    IS_LR_MODULE = False
    def __init__(self, params):
        self.params = params

@@ -109,7 +109,8 @@ class ForwardGradient(OptimizerModule):
n_samples (int): number of forward gradients to evaluate and average.
distribution (Distributions): distribution for random tangent vector.
mode (str):
"jvp" - uses forward mode AD, usually slightly slower than backward mode AD but uses significantly less memory.
"jvp" - uses forward mode AD, usually slightly slower than backward mode AD but uses significantly less memory,
because it doesn't have to store intermediate activations.
"grad" - evaluates gradient with `loss.backward()` which may be faster but uses all the memory, mainly useful for
benchmarking as there is probably no point in forward gradient if full gradient is available.
3 changes: 2 additions & 1 deletion src/torchzero/modules/misc/__init__.py
@@ -3,7 +3,8 @@
as well as gradient/update clipping and normalization.
"""

from .basic import LR, Clone, Fill, Grad, Identity, Lambda, Zeros
from .basic import Clone, Fill, Grad, Identity, Lambda, Zeros, Alpha
from .lr import LR
from .on_increase import NegateOnLossIncrease
from .multistep import Multistep
from .accumulate import Accumulate
13 changes: 0 additions & 13 deletions src/torchzero/modules/misc/basic.py
@@ -6,19 +6,6 @@

from ...core import OptimizerModule

class LR(OptimizerModule):
    """Multiplies update by the learning rate."""
    IS_LR_MODULE = True
    def __init__(self, lr = 1e-3):
        defaults = dict(lr = lr)
        super().__init__(defaults)

    @torch.no_grad
    def _update(self, state, ascent):
        # multiply ascent direction by lr in-place
        lr = self.get_group_key('lr')
        ascent *= lr
        return ascent

class Alpha(OptimizerModule):
    """Multiplies update by the learning rate, won't get picked up by learning rate schedulers."""
