Commit
no ... imports
inikishev committed Jan 17, 2025
1 parent 0074276 commit 0d33ec6
Showing 20 changed files with 49 additions and 41 deletions.
14 changes: 7 additions & 7 deletions docs/source/FAQ.rst
@@ -54,7 +54,7 @@ If you are intending to use gradient-free methods, :code:`backward` argument is

How to construct modular optimizers?
=====================================
A modular optimizer can be created using the :py:class:`tz.Modular` class. It can be constructed as :code:`tz.Modular(params, *modules)`, or as :code:`tz.Modular(params, [modules])`.
A modular optimizer can be created using the :py:class:`tz.Modular<torchzero.optim.Modular>` class. It can be constructed as :code:`tz.Modular(params, *modules)`, or as :code:`tz.Modular(params, [modules])`.

All modules are available in :code:`tz.m` namespace, e.g. :py:class:`tz.m.Adam`.

@@ -79,18 +79,18 @@ All modules are available in :code:`tz.m` namespace, e.g. :py:class:`tz.m.Adam`.
In the example above, :code:`Adam`, being the first module, takes in the gradient, applies the Adam update rule, and passes the resulting update to the next module - :code:`LR`. It multiplies the update by the learning rate and passes it to :code:`Cautious`, which applies cautioning and passes it to :code:`WeightDecay`, which adds a weight decay penalty. The resulting update is then subtracted from the model parameters.
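
A minimal sketch of the chain described above (:code:`model` is a placeholder and the exact constructor arguments are assumptions; check each module's docstring):

.. code:: python

    import torchzero as tz

    opt = tz.Modular(
        model.parameters(),
        tz.m.Adam(),
        tz.m.LR(1e-2),
        tz.m.Cautious(),
        tz.m.WeightDecay(1e-4),
    )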

It is recommended to always add an :py:class:`tz.m.LR` module to support lr schedulers and per-layer learning rates (see :ref:`how do we handle learning rates?`).
It is recommended to always add an :py:class:`tz.m.LR<torchzero.modules.LR>` module to support lr schedulers and per-layer learning rates (see :ref:`how do we handle learning rates?`).

Most modules perform gradient transformations: they take in an ascent direction, which is initially the gradient, transform it in some way, and pass it to the next module. The first module in the chain usually uses the gradient as the initial ascent direction.

Certain modules, such as gradient-approximation ones or :py:class:`tz.m.ExactNewton`, create an ascent direction "from scratch", so they should be placed first in the chain.
Certain modules, such as gradient-approximation ones or :py:class:`tz.m.ExactNewton<torchzero.modules.ExactNewton>`, create an ascent direction "from scratch", so they should be placed first in the chain.

Any external PyTorch optimizer can also be used as a chainable module by using :py:class:`tz.m.Wrap` and :py:class:`tz.m.WrapClosure` (see :ref:`How to use external PyTorch optimizers as chainable modules?`).
Any external PyTorch optimizer can also be used as a chainable module by using :py:class:`tz.m.Wrap<torchzero.modules.Wrap>` and :py:class:`tz.m.WrapClosure<torchzero.modules.WrapClosure>` (see :ref:`How to use external PyTorch optimizers as chainable modules?`).
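
A rough sketch of wrapping :code:`torch.optim.SGD` (the exact :py:class:`tz.m.Wrap` signature is an assumption here; it may instead take a constructed optimizer, so check its docstring):

.. code:: python

    import torch
    import torchzero as tz

    opt = tz.Modular(
        model.parameters(),
        # assumed signature: optimizer class followed by its keyword arguments
        tz.m.Wrap(torch.optim.SGD, lr=1.0, momentum=0.9),
        tz.m.LR(1e-2),
    )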

How to use learning rate schedulers?
=============================================
There are two primary methods for using learning rate schedulers.
One method is to pass learning rate scheduler class to the :py:class:`tz.m.LR` module like this:
One method is to pass the learning rate scheduler class to the :py:class:`tz.m.LR<torchzero.modules.LR>` module like this:

.. code:: python
@@ -105,7 +105,7 @@ One method is to pass learning rate scheduler class to the :py:class:`tz.m.LR` m
This method also supports cycling momentum, which some schedulers like OneCycleLR do. Momentum will be cycled on all modules that have :code:`momentum` or :code:`beta1` parameters.

Alternatively, learning rate scheduler can be created separately by passing it the LR module, which can be accessed with :code:`get_lr_module` method like this:
Alternatively, the learning rate scheduler can be created separately by passing it the LR module, which can be accessed with the :py:meth:`get_lr_module<torchzero.optim.Modular.get_lr_module>` method like this:

.. code:: python
@@ -116,7 +116,7 @@ Alternatively, learning rate scheduler can be created separately by passing it t
scheduler = OneCycleLR(opt.get_lr_module(), max_lr = 1e-1, total_steps=60_000)
Here :code:`get_lr_module` returns the :py:class:`tz.m.LR`, even if it is nested somewhere. You can then call :code:`scheduler.step()` as usual. This method does not support cycling momentum.
Here :code:`get_lr_module` returns the :py:class:`tz.m.LR<torchzero.modules.LR>` module, even if it is nested somewhere in the chain. You can then call :code:`scheduler.step()` as usual. This method does not support cycling momentum.
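
A short end-to-end sketch of this second method (:code:`model`, :code:`criterion` and :code:`loader` are placeholders, and the plain training loop is an assumption - gradient-free and line-search modules require a closure instead):

.. code:: python

    import torchzero as tz
    from torch.optim.lr_scheduler import OneCycleLR

    opt = tz.Modular(model.parameters(), tz.m.Adam(), tz.m.LR(1e-1))
    scheduler = OneCycleLR(opt.get_lr_module(), max_lr=1e-1, total_steps=60_000)

    for inputs, targets in loader:
        loss = criterion(model(inputs), targets)
        loss.backward()
        opt.step()
        opt.zero_grad()
        scheduler.step()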


How to specify per-parameter options?
4 changes: 4 additions & 0 deletions docs/source/implementing.rst
@@ -0,0 +1,4 @@
Implementing new modules
=========================

.. Optimizer modules are similar to standard PyTorch optimizers that inherit from :code:`torch.optim.Optimizer`; in fact, :py:class:`tz.core.OptimizerModule<torchzero.core.OptimizerModule>` inherits from it as well.
2 changes: 2 additions & 0 deletions docs/source/index.rst
@@ -12,13 +12,15 @@ Welcome to torchzero documentation!

installation
introduction

FAQ
.. automodule:: torchzero
:caption: API reference:
:members:
:show-inheritance:
* :ref:`search`

implementing


.. idk what this does
2 changes: 1 addition & 1 deletion docs/source/installation.rst
@@ -1,4 +1,4 @@
Installation
==================

TODO!
TODO
4 changes: 2 additions & 2 deletions docs/source/introduction.rst
@@ -1,4 +1,4 @@
introduction
Introduction
==================

TODO!
TODO
4 changes: 3 additions & 1 deletion src/torchzero/core/module.py
@@ -7,7 +7,8 @@

from ..utils.python_tools import _ScalarLoss, flatten
from ..tensorlist import TensorList
from .tensorlist_optimizer import TensorListOptimizer, ParamsT
from .tensorlist_optimizer import TensorListOptimizer
from torch.optim.optimizer import ParamsT


def _get_loss(fx0, fx0_approx):
@@ -492,6 +493,7 @@ def step(self, state: OptimizationState):
return self.children['first'].step(state)

# return ascent and pass it to next module
# we do this because updating parameters directly is often more efficient
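# i.e. _ReturnAscent, set as the last module's next_module, hands the computed ascent back instead of applying it to the parameters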
params = self.get_params()
self._last_module.next_module = _ReturnAscent(params) # type:ignore
state.ascent: TensorList = self.children['first'].step(state) # type:ignore
2 changes: 1 addition & 1 deletion src/torchzero/core/tensorlist_optimizer.py
@@ -6,7 +6,7 @@
import torch.optim.optimizer
from torch.optim.optimizer import ParamsT

from torchzero.tensorlist import TensorList, NumberList
from ..tensorlist import TensorList, NumberList

_StateInit = Literal['params', 'grad'] | Callable | TensorList
class TensorListOptimizer(torch.optim.Optimizer, ABC):
4 changes: 2 additions & 2 deletions src/torchzero/modules/experimental/quad_interp.py
@@ -3,7 +3,7 @@
import numpy as np
import torch

from ... import tl
from ...tensorlist import TensorList
from ...core import OptimizationState
from ..line_search.base_ls import LineSearchBase

@@ -47,7 +47,7 @@ def __init__(self, lr=1e-2, log_lrs = False, max_evals = 2, min_dist = 1e-2,):
self.min_dist = min_dist

@torch.no_grad
def _find_best_lr(self, state: OptimizationState, params: tl.TensorList) -> float:
def _find_best_lr(self, state: OptimizationState, params: TensorList) -> float:
if state.closure is None: raise ValueError('QuadraticLS requires closure')
closure = state.closure
if state.fx0 is None: state.fx0 = state.closure(False)
4 changes: 2 additions & 2 deletions src/torchzero/modules/line_search/__init__.py
@@ -2,7 +2,7 @@
Line searches.
"""

import typing as T
from typing import Literal

from ...core import OptimizerModule
from ..regularization import Normalize
@@ -13,7 +13,7 @@
from .scipy_minimize_scalar import ScipyMinimizeScalarLS
from .armijo import ArmijoLS

LineSearches = T.Literal['backtracking', 'brent', 'brent-exact', 'brent-norm', 'multiplicative', 'newton', 'newton3', 'armijo'] | OptimizerModule
LineSearches = Literal['backtracking', 'brent', 'brent-exact', 'brent-norm', 'multiplicative', 'newton', 'newton3', 'armijo'] | OptimizerModule

def get_line_search(name:str | OptimizerModule) -> OptimizerModule | list[OptimizerModule]:
if isinstance(name, str):
4 changes: 2 additions & 2 deletions src/torchzero/modules/line_search/armijo.py
@@ -1,6 +1,6 @@
import torch

from ... import tl
from ...tensorlist import TensorList
from ...core import OptimizationState
from .base_ls import LineSearchBase

@@ -32,7 +32,7 @@ def __init__(
self.max_iter = max_iter

@torch.no_grad
def _find_best_lr(self, state: OptimizationState, params: tl.TensorList) -> float:
def _find_best_lr(self, state: OptimizationState, params: TensorList) -> float:
if state.closure is None: raise ValueError("closure is not set")
ascent = state.maybe_use_grad_(params)
grad = state.maybe_compute_grad_(params)
12 changes: 6 additions & 6 deletions src/torchzero/modules/line_search/base_ls.py
@@ -4,7 +4,7 @@

import torch

from ... import tl
from ...tensorlist import TensorList
from ...core import _ClosureType, OptimizationState, OptimizerModule, _maybe_pass_backward
from ...utils.python_tools import _ScalarLoss

@@ -58,13 +58,13 @@ def _reset(self):
self._fx0_approx = None
self._current_iter = 0

def _set_lr_(self, lr: float, ascent_direction: tl.TensorList, params: tl.TensorList, ):
def _set_lr_(self, lr: float, ascent_direction: TensorList, params: TensorList, ):
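# note: params currently sit at x0 - last_lr * ascent, so adding (last_lr - lr) * ascent moves them to x0 - lr * ascent in place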
alpha = self._last_lr - lr
if alpha != 0: params.add_(ascent_direction, alpha = alpha)
self._last_lr = lr

# lr is first here so that we can use a partial
def _evaluate_lr_(self, lr: float, closure: _ClosureType, ascent: tl.TensorList, params: tl.TensorList, backward=False):
def _evaluate_lr_(self, lr: float, closure: _ClosureType, ascent: TensorList, params: TensorList, backward=False):
"""Evaluate `lr`, if loss is better than current lowest loss,
overrides `self._lowest_loss` and `self._best_lr`.
@@ -100,15 +100,15 @@ def _evaluate_lr_ensure_float(
self,
lr: float,
closure: _ClosureType,
ascent: tl.TensorList,
params: tl.TensorList,
ascent: TensorList,
params: TensorList,
) -> float:
"""Same as _evaluate_lr_ but ensures that the loss value is float."""
v = self._evaluate_lr_(lr, closure, ascent, params)
if isinstance(v, torch.Tensor): return v.detach().cpu().item()
return float(v)

def _find_best_lr(self, state: OptimizationState, params: tl.TensorList) -> float:
def _find_best_lr(self, state: OptimizationState, params: TensorList) -> float:
"""This should return the best lr."""
... # pylint:disable=unnecessary-ellipsis

6 changes: 3 additions & 3 deletions src/torchzero/modules/line_search/directional_newton.py
@@ -1,7 +1,7 @@
import numpy as np
import torch

from ... import tl
from ...tensorlist import TensorList
from ...core import OptimizationState
from .base_ls import LineSearchBase

@@ -57,7 +57,7 @@ def __init__(self, max_dist: float | None = 1e5, validate_step = True, alpha:flo
self.validate_step = validate_step

@torch.no_grad
def _find_best_lr(self, state: OptimizationState, params: tl.TensorList) -> float:
def _find_best_lr(self, state: OptimizationState, params: TensorList) -> float:
if state.closure is None: raise ValueError('QuadraticLS requires closure')
closure = state.closure

@@ -172,7 +172,7 @@ def __init__(self, max_dist: float | None = 1e4, validate_step = True, alpha: fl
self.validate_step = validate_step

@torch.no_grad
def _find_best_lr(self, state: OptimizationState, params: tl.TensorList) -> float:
def _find_best_lr(self, state: OptimizationState, params: TensorList) -> float:
if state.closure is None: raise ValueError('QuadraticLS requires closure')
closure = state.closure
ascent_direction = state.ascent
4 changes: 2 additions & 2 deletions src/torchzero/modules/line_search/grid_ls.py
@@ -4,7 +4,7 @@
import numpy as np
import torch

from ... import tl
from ...tensorlist import TensorList
from ...core import _ClosureType, OptimizationState
from .base_ls import LineSearchBase

@@ -34,7 +34,7 @@ def __init__(
self.stop_on_worsened = stop_on_worsened

@torch.no_grad
def _find_best_lr(self, state: OptimizationState, params: tl.TensorList) -> float:
def _find_best_lr(self, state: OptimizationState, params: TensorList) -> float:
if state.closure is None: raise ValueError("closure is not set")
if state.ascent is None: raise ValueError("ascent_direction is not set")

4 changes: 2 additions & 2 deletions src/torchzero/modules/line_search/scipy_minimize_scalar.py
@@ -6,7 +6,7 @@
except ModuleNotFoundError:
scopt = typing.cast(typing.Any, None)

from ... import tl
from ...tensorlist import TensorList
from ...core import OptimizationState

from .base_ls import LineSearchBase, MaxIterReached
@@ -45,7 +45,7 @@ def __init__(
self.options = options

@torch.no_grad
def _find_best_lr(self, state: OptimizationState, params: tl.TensorList) -> float:
def _find_best_lr(self, state: OptimizationState, params: TensorList) -> float:
try:
res = scopt.minimize_scalar(
self._evaluate_lr_ensure_float,
2 changes: 1 addition & 1 deletion src/torchzero/modules/misc/accumulate.py
@@ -2,7 +2,7 @@

import torch

from torchzero.tensorlist import TensorList
from ...tensorlist import TensorList

from ...core import OptimizerModule

2 changes: 1 addition & 1 deletion src/torchzero/modules/misc/basic.py
@@ -2,7 +2,7 @@

import torch

from torchzero.tensorlist import TensorList
from ...tensorlist import TensorList

from ...core import OptimizerModule, _Chainable

2 changes: 1 addition & 1 deletion src/torchzero/modules/misc/lr.py
@@ -5,7 +5,7 @@

import torch

from torchzero.tensorlist import TensorList
from ...tensorlist import TensorList

from ...core import OptimizerModule

2 changes: 1 addition & 1 deletion src/torchzero/modules/misc/multistep.py
@@ -2,7 +2,7 @@

import torch

from torchzero.tensorlist import TensorList
from ...tensorlist import TensorList

from ...core import OptimizerModule, _Chainable

4 changes: 2 additions & 2 deletions src/torchzero/modules/regularization/ortho_grad.py
@@ -6,7 +6,7 @@

import torch

from ... import tl
from ...tensorlist import TensorList
from ...core import OptimizerModule, _Targets


@@ -20,7 +20,7 @@ def orthograd_(params: Iterable[torch.Tensor], eps: float = 1e-30):
reference
https://arxiv.org/abs/2501.04697
"""
if not isinstance(params, tl.TensorList): params = tl.TensorList(params)
if not isinstance(params, TensorList): params = TensorList(params)
params = params.with_grad()
grad = params.grad
grad -= ((params*grad).total_sum() / ((params*params).total_sum() + eps)) * params
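# equivalently, for a single tensor w with gradient g: g <- g - (<w, g> / (<w, w> + eps)) * w, per the reference above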
8 changes: 4 additions & 4 deletions src/torchzero/modules/smoothing/gaussian_smoothing.py
@@ -3,7 +3,7 @@
import numpy as np
import torch

from ... import tl
from ...tensorlist import TensorList, Distributions, mean as tlmean
from ...utils.python_tools import _ScalarLoss
from ...core import _ClosureType, OptimizationState, OptimizerModule, _maybe_pass_backward

@@ -29,14 +29,14 @@
self,
n_samples: int = 4,
sigma: float = 0.1,
distribution: tl.Distributions = "normal",
distribution: Distributions = "normal",
sample_x0 = False,
randomize_every: int | None = 1,
):
defaults = dict(sigma = sigma)
super().__init__(defaults)
self.n_samples = n_samples
self.distribution: tl.Distributions = distribution
self.distribution: Distributions = distribution
self.randomize_every = randomize_every
self.current_step = 0
self.perturbations = None
@@ -78,7 +78,7 @@ def smooth_closure(backward = True):
params.sub_(p)

# set the new averaged grads and return average loss
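# note: this estimates the gradient of the smoothed objective f_sigma(x) = E[f(x + sigma * u)] by averaging gradients over the n_samples perturbations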
if backward: params.set_grad_(tl.mean(grads))
if backward: params.set_grad_(tlmean(grads))
return _numpy_or_torch_mean(losses)


Expand Down
