Commit 0adf91f: operations revamp
inikishev committed Jan 10, 2025 (1 parent: d83688e)
Showing 31 changed files with 994 additions and 411 deletions.
1 change: 1 addition & 0 deletions docs/source/requirements.txt
@@ -2,6 +2,7 @@ torch
numpy
scipy
nevergrad
nlopt
git+https://github.com/inikishev/torchzero.git
furo
sphinx-autoapi
24 changes: 14 additions & 10 deletions src/torchzero/core/module.py
@@ -18,11 +18,11 @@ def _get_loss(fx0, fx0_approx):
"""
Closure example:
.. code-block:: python
- def closure(backward = True, **k):
+ def closure(backward = True):
loss = model(inputs)
if backward:
optimizer.zero_grad()
- loss.backward(**k)
+ loss.backward()
return loss
This closure will also work with all built-in PyTorch optimizers, including LBFGS, as well as most custom ones.
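A minimal, runnable sketch of how such a closure is driven through `step(closure)`; it is shown here with PyTorch's built-in LBFGS, which the note above says this closure style also supports (the model and data are made up for illustration):

import torch

model = torch.nn.Linear(10, 1)
inputs, targets = torch.randn(16, 10), torch.randn(16, 1)

def closure(backward=True):
    loss = torch.nn.functional.mse_loss(model(inputs), targets)
    if backward:
        model.zero_grad()   # clear stale gradients
        loss.backward()     # populate .grad on every learnable parameter
    return loss

opt = torch.optim.LBFGS(model.parameters(), lr=0.1)
loss = opt.step(closure)    # the optimizer re-evaluates the closure as many times as it needs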
@@ -36,7 +36,7 @@ def __init__(self, closure: ClosureType | None, model: torch.nn.Module | None):
"""A closure that reevaluates the model and returns the loss.
The closure should accept `backward` boolean argument that is True by default, which,
if True, sets `.grad` attributes of all learnable params, for example via `loss.backward()`.
- Closure can be None for some first order optimizers."""
+ Closure can be None for most first order optimizers."""

self.ascent: TensorList | None = None
"""Ascent direction, for example the gradients.
@@ -52,13 +52,15 @@ def __init__(self, closure: ClosureType | None, model: torch.nn.Module | None):
This is mainly used as the return value of the step method when fx0 is None."""

self.grad: TensorList | None = None
"""Gradient if it has been computed, otherwise None."""
"""Gradient if it has been computed, otherwise None.
Gradient must be evaluated strictly with initial parameters of the current step"""

self.model = model
"""Model (for higher order derivatives)"""

def maybe_compute_grad_(self, params: TensorList) -> TensorList:
"""Computes gradient if it hasn't been computed already."""
"""Computes gradient if it hasn't been computed already, and returns it"""
if self.grad is None:

if self.closure is not None:
@@ -68,8 +70,8 @@ def maybe_compute_grad_(self, params: TensorList) -> TensorList:
return self.grad

def maybe_use_grad_(self, params: TensorList | None) -> TensorList:
"""If ascent direction is None, use cloned gradient as ascent direction.
otherwise returns existing ascent direction.
"""If ascent direction is None, use cloned gradient as ascent direction and returns it.
Otherwise does nothing and returns existing ascent direction.
If gradient hasn't been computed, this also sets `fx0`."""
if self.ascent is None:
if params is None: raise ValueError()
@@ -105,7 +107,9 @@ def copy(self, clone_ascent = False):
return state

def update_attrs_(self, state: "OptimizationState"):
"""Updates attributes of this state with attributes of another state."""
"""Updates attributes of this state with attributes of another state.
This updates `grad`, `fx0` and `fx0_approx`."""
if state.grad is not None: self.grad = state.grad
if state.fx0 is not None: self.fx0 = state.fx0
if state.fx0_approx is not None: self.fx0_approx = state.fx0_approx
@@ -165,7 +169,7 @@ def _set_child_(self, name, child: "OptimizerModule"):
def _update_child_params_(self, child: "OptimizerModule"):
"""Initializes or updates child params with parameters of this module."""
return self._update_next_module_params_(child)

def _set_next_module(self, next_module: "OptimizerModule"):
"""Set next module and initialize it's params."""
self.next_module = next_module
@@ -297,7 +301,7 @@ def return_ascent(self, state: OptimizationState, params=None) -> TensorList:
ascent: TensorList = self.step(state) # type:ignore
self.next_module = true_next
return ascent

class _ReturnAscent:
def __init__(self, params):
self.params = params
10 changes: 5 additions & 5 deletions src/torchzero/core/tensorlist_optimizer.py
@@ -20,13 +20,15 @@ def __init__(self, params: ParamsT, defaults):
super().__init__(params, defaults)
self._params: list[torch.Tensor] = [param for group in self.param_groups for param in group['params']]
self.has_complex = any(torch.is_complex(x) for x in self._params)
"""True if any of the params are complex"""

def add_param_group(self, param_group: dict[str, Any]) -> None:
super().add_param_group(param_group)
self._params: list[torch.Tensor] = [param for group in self.param_groups for param in group['params']]
self.has_complex = any(torch.is_complex(x) for x in self._params)

def get_params[CLS: Any](self, cls: type[CLS] = TensorList) -> CLS:
"""returns all params with `requires_grad = True` as a TensorList."""
return cls(p for p in self._params if p.requires_grad)

def ensure_grad_(self):
@@ -35,16 +37,14 @@ def ensure_grad_(self):
if p.requires_grad and p.grad is None: p.grad = torch.zeros_like(p)

def get_state_key[CLS: MutableSequence](self, key: str, init: _StateInit = torch.zeros_like, params=None, cls: type[CLS] = TensorList) -> CLS:
"""Returns a TensorList with the `key` states of all `params` that currently have grad or require grad,
depending on `mode` passed on `__init__`. Creates the states if they don't exist.
This guarantees that the returned TensorList is the same shape as params, so
models where some parameters don't have a gradient sometimes will still work.
"""Returns a tensorlist of all `key` states of all params with `requires_grad = True`.
Args:
key (str): key to create/access.
init: Initial value if key doesn't exist. Can be `params`, `grad`, or callable such as `torch.zeros_like`.
Defaults to torch.zeros_like.
- params (_type_, optional): optionally pass params if you already created them. Defaults to None.
+ params (optional): optionally pass params if you already created them. Defaults to None.
+ cls (optional): optionally specify any other MutableSequence subclass to use instead of TensorList.
Returns:
TensorList: TensorList with the `key` state. Those tensors are stored in the optimizer, so modify them in-place.
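For orientation, a sketch of how `get_params` and `get_state_key` could be used from a subclass: a hypothetical heavy-ball momentum step written against the signatures shown above (the class name `TensorListOptimizer`, the import path, and the fixed hyperparameters are assumptions for illustration, not code from this commit):

import torch
from torchzero.core.tensorlist_optimizer import TensorListOptimizer  # assumed import path

class HeavyBall(TensorListOptimizer):
    """Hypothetical optimizer used only to illustrate the state API."""
    def __init__(self, params, lr=1e-2, beta=0.9):
        super().__init__(params, dict(lr=lr, beta=beta))

    @torch.no_grad()
    def step(self, closure=None):
        params = self.get_params()                 # params with requires_grad=True, as a TensorList
        velocity = self.get_state_key('velocity')  # per-parameter buffers, zeros_like on first access
        for p, v in zip(params, velocity):
            if p.grad is None:
                continue
            v.mul_(0.9).add_(p.grad)               # buffers live in optimizer state, so update in-place
            p.sub_(v, alpha=1e-2)                  # fixed lr/beta for brevity; real code would read group values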
15 changes: 8 additions & 7 deletions src/torchzero/modules/__init__.py
@@ -3,22 +3,23 @@
r"""
This submodule contains composable optimizer "building blocks".
"""
from typing import TypedDict, Any, get_type_hints, Literal

from collections.abc import Iterable, Sequence
from .gradient_approximation import *
from typing import Any, Literal, TypedDict, get_type_hints

from ..core.module import OptimizerModule
from .adaptive import *
from .gradient_approximation import *
from .line_search import *
from .meta import *
from .momentum import *
from .misc import *
from .momentum import *
from .operations import *
from .optimizers import *
from .quasi_newton import *
from .regularization import *
from .second_order import *
from .optimizers import *
from .smoothing import *

from ..core.module import OptimizerModule
# from .experimental.subspace import *

Modules = OptimizerModule | Sequence[OptimizerModule]
@@ -97,7 +98,7 @@ def _get_baked_in_and_module_lr(lr: float, kwargs: _CommonKwargs):
"""some optimizers like adam have `lr` baked in because it is slightly more efficient than using `LR(lr)` module.
But some modules like update norm require lr to be 1, so an LR(lr) needs to be put after them. Using this basically checks
if any of those modules are being used and if they are, it sets lr to 1 and appends an LR(lr) module.
.. code:: py
lr, lr_module = _get_baked_in_and_module_lr(lr, kwargs)
main: list[OptimizerModule] = [
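A toy sketch of the idea described in that docstring, with a hypothetical helper name (the real function's signature and checks may differ): when some module must see the raw update, the learning rate moves into a trailing `LR(lr)` module instead of being baked into the main update rule.

from torchzero.modules import LR  # assumes LR is re-exported at the package level

def split_lr(lr: float, some_module_needs_unit_lr: bool):
    # hypothetical helper: decide whether lr stays baked in or becomes an LR module
    if some_module_needs_unit_lr:
        # e.g. an update-normalization module must receive the unscaled update,
        # so the main update rule runs with lr=1 and LR(lr) is appended after it
        return 1.0, [LR(lr)]
    return lr, []

baked_in_lr, trailing_modules = split_lr(1e-2, some_module_needs_unit_lr=True)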
2 changes: 1 addition & 1 deletion src/torchzero/modules/line_search/__init__.py
@@ -6,7 +6,7 @@

from ...core import OptimizerModule
from ..meta.chain import Chain
- from ..misc import Normalize
+ from ..regularization import Normalize
from .grid_ls import (ArangeLS, BacktrackingLS, GridLS, LinspaceLS,
MultiplicativeLS)
from .quad_interp import QuadraticInterpolation2Point
6 changes: 3 additions & 3 deletions src/torchzero/modules/line_search/armijo.py
@@ -34,16 +34,16 @@ def __init__(
@torch.no_grad
def _find_best_lr(self, state: OptimizationState, params: tl.TensorList) -> float:
if state.closure is None: raise ValueError("closure is not set")
- if state.ascent is None: raise ValueError("ascent_direction is not set")
+ ascent = state.maybe_use_grad_(params)
grad = state.maybe_compute_grad_(params)
lr = self.get_first_group_key('lr')
if state.fx0 is None: state.fx0 = state.closure(False)

# loss decrease per lr=1 if function was linear
- decrease_per_lr = (grad*state.ascent).total_sum()
+ decrease_per_lr = (grad*ascent).total_sum()

for _ in range(self.max_iter):
- loss = self._evaluate_lr_(lr, state.closure, state.ascent, params)
+ loss = self._evaluate_lr_(lr, state.closure, ascent, params)

# expected decrease
expected_decrease = decrease_per_lr * lr
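To make the quantities above concrete — `decrease_per_lr` is the loss drop a step of size 1 would give if the loss were linear, and the expected decrease at step size `lr` is `decrease_per_lr * lr` — here is a self-contained backtracking sketch of the same sufficient-decrease test in plain PyTorch (an illustration only, not this module's implementation; the constants and names are my own):

import torch

def backtracking_armijo(f, x, grad, ascent, lr=1.0, c=1e-4, shrink=0.5, max_iter=20):
    """Shrink lr until f(x - lr*ascent) drops by at least c * lr * <grad, ascent>."""
    fx0 = f(x)
    decrease_per_lr = (grad * ascent).sum()          # expected decrease per unit step if f were linear
    for _ in range(max_iter):
        loss = f(x - lr * ascent)
        if loss <= fx0 - c * lr * decrease_per_lr:   # sufficient (Armijo) decrease reached
            return lr
        lr *= shrink                                 # otherwise shrink the step and retry
    return lr

# toy usage on a quadratic, using the gradient as the ascent direction
f = lambda x: (x ** 2).sum()
x = torch.tensor([3.0, -2.0])
g = 2 * x
print(backtracking_armijo(f, x, g, ascent=g))        # prints a step size that satisfies the test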
2 changes: 1 addition & 1 deletion src/torchzero/modules/meta/__init__.py
@@ -1,5 +1,5 @@
"""Modules that use other modules."""
# from .chain import Chain, ChainReturn
- from .optimizer_wrapper import OptimizerWrapper
+ from .optimizer_wrapper import Wrap
from .return_overrides import SetGrad, ReturnAscent, ReturnClosure
from .grafting import Grafting, SignGrafting, IntermoduleCautious
4 changes: 2 additions & 2 deletions src/torchzero/modules/meta/optimizer_wrapper.py
@@ -3,9 +3,9 @@

import torch

- from ...core import OptimizerModule, _get_loss, OptimizationState
+ from ...core import OptimizerModule

- class OptimizerWrapper(OptimizerModule):
+ class Wrap(OptimizerModule):
"""
Wraps any torch.optim.Optimizer.
7 changes: 2 additions & 5 deletions src/torchzero/modules/misc/__init__.py
@@ -2,9 +2,6 @@
This module includes various basic operators, notably LR for setting the learning rate,
as well as gradient/update clipping and normalization.
"""
- from .basic import (LR, Add, AddMagnitude, Clone, Div, Identity, Lambda, Mul,
-     NanToNum, Pow, PowMagnitude, Reciprocal, Negate, Sign, sign_grad_, Abs, Grad, Zeros, Fill)
- from .normalization import (Centralize, ClipNorm, ClipValue, Normalize,
-     centralize_grad_, normalize_grad_, clip_grad_norm_, clip_grad_value_)

+ from .basic import LR, Clone, Fill, Grad, Identity, Lambda, Zeros
from .on_increase import NegateOnLossIncrease
- from .multimodule import Sum, Mean, Product, Subtract, Divide, Interpolate
126 changes: 0 additions & 126 deletions src/torchzero/modules/misc/basic.py
@@ -51,132 +51,6 @@ def __init__(self, f: Callable[[TensorList], TensorList]):
@torch.no_grad()
def _update(self, state, ascent): return self.f(ascent)

class Reciprocal(OptimizerModule):
"""Calculates reciprocal of the update (1 / update)."""
def __init__(self,):
super().__init__({})

@torch.no_grad()
def _update(self, state, ascent): return ascent.reciprocal_()

class Negate(OptimizerModule):
"""Negates the update (-update)"""
def __init__(self,):
super().__init__({})

@torch.no_grad()
def _update(self, state, ascent): return ascent.neg_()

class Add(OptimizerModule):
"""Adds `value` to the update."""
def __init__(self, value):
super().__init__({})
self.value = value
@torch.no_grad()
def _update(self, state, ascent): return ascent.add_(self.value)

class AddMagnitude(OptimizerModule):
"""Add `value` multiplied by sign of the ascent, i.e. this adds `value` to the magnitude of the update.
Args:
value (_type_): value to add to magnitude.
add_to_zero (bool, optional): if True, adds `value` to 0s. Otherwise, zeros remain zero. Defaults to True.
"""
def __init__(self, value, add_to_zero=True):
super().__init__({})
self.value = value
self.add_to_zero = add_to_zero
@torch.no_grad()
def _update(self, state, ascent):
if self.add_to_zero: return ascent.add_(ascent.clamp_magnitude(min=1).sign_().mul_(self.value))
return ascent.add_(ascent.sign_().mul_(self.value))

class Mul(OptimizerModule):
"""Multiplies the update by `value`."""
def __init__(self, value):
super().__init__({})
self.value = value
@torch.no_grad()
def _update(self, state, ascent) -> TensorList:return ascent.mul_(self.value)

class Div(OptimizerModule):
"""Divides update by `value`."""
def __init__(self, value):
super().__init__({})
self.value = value
@torch.no_grad()
def _update(self, state, ascent) -> TensorList: return ascent.div_(self.value)


class Pow(OptimizerModule):
"""Raises update to the `value` power."""
def __init__(self, value):
super().__init__({})
self.value = value
@torch.no_grad()
def _update(self, state, ascent): return ascent.pow_(self.value)

class PowMagnitude(OptimizerModule):
"""Raises update to the `value` power, but preserves the sign."""
def __init__(self, value):
super().__init__({})
self.value = value
@torch.no_grad()
def _update(self, state, ascent):
if self.value % 2 == 1: return ascent.pow_(self.value)
return ascent.abs().pow_(self.value) * ascent.sign()

class NanToNum(OptimizerModule):
"""Convert `nan`, `inf` and `-inf` to numbers.
Args:
nan (optional): the value to replace NaNs with. Default is zero.
posinf (optional): if a Number, the value to replace positive infinity values with.
If None, positive infinity values are replaced with the greatest finite value
representable by input's dtype. Default is None.
neginf (optional): if a Number, the value to replace negative infinity values with.
If None, negative infinity values are replaced with the lowest finite value
representable by input's dtype. Default is None.
"""
def __init__(self, nan=None, posinf=None, neginf=None):
super().__init__({})
self.nan = nan
self.posinf = posinf
self.neginf = neginf

@torch.no_grad()
def _update(self, state, ascent): return ascent.nan_to_num_(self.nan, self.posinf, self.neginf)



def sign_grad_(params: Iterable[torch.Tensor]):
"""Apply sign function to gradients of an iterable of parameters.
Args:
params (abc.Iterable[torch.Tensor]): an iterable of Tensors or a single Tensor.
"""
TensorList(params).get_existing_grads().sign_()

class Sign(OptimizerModule):
"""Applies sign function to the update"""
def __init__(self):
super().__init__({})

@torch.no_grad
def _update(self, state, ascent):
ascent.sign_()
return ascent

class Abs(OptimizerModule):
"""Takes absolute value of the update `abs(update)`"""
def __init__(self):
super().__init__({})

@torch.no_grad
def _update(self, state, ascent):
ascent.abs_()
return ascent

class Grad(OptimizerModule):
"""Uses gradient as the update. This is useful for chains."""
def __init__(self):
29 changes: 29 additions & 0 deletions src/torchzero/modules/operations/__init__.py
@@ -0,0 +1,29 @@
from .multi import (
Add,
AddMagnitude,
Div,
Divide,
Interpolate,
Lerp,
Mul,
Pow,
Power,
RDiv,
RPow,
RSub,
Sub,
Subtract,
)
from .reduction import Mean, Product, Sum
from .singular import (
Abs,
Cos,
MagnitudePower,
NanToNum,
Negate,
Operation,
Reciprocal,
Sign,
Sin,
sign_grad_,
)