Commit
premade optimizers, operations, bunch of stuff
inikishev committed Dec 25, 2024
1 parent bf8d2c9 commit e8579a8
Showing 34 changed files with 791 additions and 203 deletions.
20 changes: 18 additions & 2 deletions src/torchzero/core/module.py
@@ -110,7 +110,6 @@ def update_attrs_(self, state: "OptimizationState"):
if state.fx0 is not None: self.fx0 = state.fx0
if state.fx0_approx is not None: self.fx0_approx = state.fx0_approx


class OptimizerModule(TensorListOptimizer, ABC):
r"""Base class for all modules.
@@ -272,6 +271,14 @@ def _step_update_ascent_direction(self, state: OptimizationState) -> ScalarType
# perform an update with the ascent direction, or pass it to the child.
return self._update_params_or_step_with_next(state, params=params)

def return_ascent(self, state: OptimizationState, params=None) -> TensorList:
if params is None: params = self.get_params()
true_next = self.next_module
self.next_module = _ReturnAscent(params) # type:ignore
ascent: TensorList = self.step(state) # type:ignore
self.next_module = true_next
return ascent

@torch.no_grad
def step( # type:ignore # pylint:disable=signature-differs # pylint:disable = arguments-renamed
self,
@@ -292,4 +299,13 @@ def _update(self, state: OptimizationState, ascent: TensorList) -> TensorList:
After generating a new ascent direction with this `_update` method,
if this module has no child, the ascent direction will be subtracted from the params.
Otherwise everything is passed to the child."""
raise NotImplementedError()
raise NotImplementedError()

class _ReturnAscent:
def __init__(self, params):
self.params = params

@torch.no_grad
def step(self, state: OptimizationState) -> TensorList: # type:ignore
update = state.maybe_use_grad_(self.params) # this will execute the closure, which may have been modified
return update
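
The temporary next_module swap that return_ascent performs can be illustrated with a small self-contained toy (hypothetical ToyModule and _ReturnUpdate classes, not part of torchzero; this is only a sketch of the pattern):

# Sketch of the swap pattern: replace `next_module` with a sentinel whose
# `step` hands the update back instead of applying it, then restore it.
class _ReturnUpdate:
    def step(self, update):
        return update                            # just return the update, do not apply it

class ToyModule:
    def __init__(self):
        self.next_module = None

    def step(self, update):
        update = [u * 0.5 for u in update]       # some transformation
        if self.next_module is None:
            return None                          # normally: apply the update to params
        return self.next_module.step(update)     # normally: pass the update to the child

    def return_update(self, update):
        true_next = self.next_module
        self.next_module = _ReturnUpdate()       # temporary swap
        out = self.step(update)
        self.next_module = true_next             # restore the real child
        return out

print(ToyModule().return_update([1.0, 2.0]))     # -> [0.5, 1.0], nothing applied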
3 changes: 2 additions & 1 deletion src/torchzero/core/tensorlist_optimizer.py
@@ -19,11 +19,12 @@ class TensorListOptimizer(torch.optim.Optimizer, ABC):
def __init__(self, params: ParamsT, defaults):
super().__init__(params, defaults)
self._params: list[torch.Tensor] = [param for group in self.param_groups for param in group['params']]
self.has_complex = any(torch.is_complex(x) for x in self._params)

def add_param_group(self, param_group: dict[str, Any]) -> None:
super().add_param_group(param_group)
self._params: list[torch.Tensor] = [param for group in self.param_groups for param in group['params']]

self.has_complex = any(torch.is_complex(x) for x in self._params)

def get_params[CLS: Any](self, cls: type[CLS] = TensorList) -> CLS:
return cls(p for p in self._params if p.requires_grad)
114 changes: 113 additions & 1 deletion src/torchzero/modules/__init__.py
@@ -1,7 +1,11 @@
# pylint: disable = singleton-comparison
# ruff: noqa: E712
r"""
This submodule contains composable optimizer "building blocks".
"""
from typing import TypedDict, Any, get_type_hints, Literal

from collections.abc import Iterable, Sequence
from .gradient_approximation import *
from .adaptive import *
from .line_search import *
@@ -13,4 +17,112 @@
from .second_order import *
from .optimizers import *
from .smoothing import *
from .subspace import *

from ..core.module import OptimizerModule
# from .experimental.subspace import *

Modules = OptimizerModule | Sequence[OptimizerModule]

def _ismodules(x):
if isinstance(x, OptimizerModule): return True
if isinstance(x, Sequence) and len(x)>0 and isinstance(x[0], OptimizerModule): return True
return False

class _CommonKwargs(TypedDict, total=False):
# lr: float | Modules | None
decoupled_l1: float | Modules | None
"""test if this works"""
decoupled_l2: float | Modules | None
l1: float | Modules | None
l2: float | Modules | None
grad_clip_norm: float | Modules | None
grad_clip_value: float | Modules | None
grad_norm: bool | Literal['global', 'param', 'channel'] | Modules | None
grad_dropout: float | Modules | None
grad_sign: bool | Modules | None
update_clip_norm: float | Modules | None
update_clip_value: float| Modules | None
update_norm: bool | Modules | None
update_sign: bool | Modules | None
lr_dropout: float | Modules | None
cautious: bool | Modules | None
line_search: LineSearches | Modules | None
momentum: float | Modules | None
dampening: float
nesterov: bool
adam: bool | Modules | None
rmsprop: bool | Modules | None
laplacian_smoothing: float | Modules | None
update_laplacian_smoothing: float | Modules | None
grad_estimator: str | Modules | None
grad_modules: Modules | None
update_modules: Modules | None
main_modules: Modules | None

def _get_module(module:str, arg: Any, all_kwargs: _CommonKwargs):
skip = {"dampening", "nesterov"}
if module in skip: return None
if arg is None: return None
if _ismodules(arg): return arg
if module == 'lr': return LR(arg)
if module in ("l1", "decoupled_l1"): return WeightDecay(arg, ord = 1)
if module in ("l2", "decoupled_l2"): return WeightDecay(arg)
if module in ('grad_clip_norm', 'update_clip_norm'): return ClipNorm(arg)
if module in ('grad_clip_value', 'update_clip_value'): return ClipValue(arg)
if module in ('grad_norm', 'update_norm'):
if arg == True: return Normalize()
if arg == False: return None
return Normalize(mode=arg)
if module in ('grad_dropout', 'lr_dropout'): return Dropout(arg)
if module in ('grad_sign', 'update_sign'): return Sign() if arg == True else None
if module == 'cautious': return Cautious() if arg == True else None
if module == 'line_search': return get_line_search(arg)
if module == 'momentum':
dampening = all_kwargs.get('dampening', 0)
nesterov = all_kwargs.get('nesterov', False)
if nesterov: return NesterovMomentum(arg, dampening)
return HeavyBall(arg, dampening)
if module == 'nesterov': return NesterovMomentum(arg)
if module == 'adam': return Adam() if arg == True else None
if module == 'rmsprop': return RMSProp() if arg == True else None
if module in ('laplacian_smoothing', 'update_laplacian_smoothing'): return LaplacianSmoothing(arg)
if module == 'grad_estimator': raise NotImplementedError(module)
raise ValueError(module)

def _should_decouple_lr(kwargs: _CommonKwargs):
decoupled_modules = {"update_norm", "update_sign", "update_clip_norm", "update_clip_value"}
return any(m in kwargs for m in decoupled_modules)

def _get_lr_and_lr_module(lr: float, kwargs: _CommonKwargs):
if _should_decouple_lr(kwargs): return 1, lr
return lr, None

def _make_common_modules(main: OptimizerModule | Iterable[OptimizerModule] | None, lr_module: float | None, kwargs: _CommonKwargs):
"""common modules, this is used to add common things to all torchzero.optim optimizers
l1 and l2 depend on learning rate of the optimizer, and are applied before the update rule.
Decoupled versions also do not depend on learning rate and are applied after the update rule.
Update modules, such as update_clip_norm, do not depend on lr either.
"""
from ..python_tools import flatten
order = [
"grad_estimator",
"l1", "l2", "laplacian_smoothing", "grad_modules", "grad_sign", "grad_clip_norm", "grad_clip_value", "grad_norm", "grad_dropout",
"main", "main_modules", "rmsprop", "adam", "momentum", "cautious", "update_norm",
"update_laplacian_smoothing", "update_sign", "update_clip_norm", "update_clip_value", "lr",
"lr_dropout", "decoupled_l1", "decoupled_l2", "update_modules", "line_search"
]

keys = set(get_type_hints(_CommonKwargs).keys()).union({'lr'}).difference({"dampening", "nesterov"})
order_keys = set(order).difference({'main'})
assert order_keys == keys, f'missing: {order_keys.difference(keys)}'

modules_dict = {k: _get_module(k, v, kwargs) for k, v in kwargs.items()}
modules_dict["main"] = main
if lr_module is not None: modules_dict['lr'] = _get_module('lr', lr_module, kwargs)

modules = [modules_dict[k] for k in order if k in modules_dict]
modules = [i for i in modules if i is not None]
return flatten(modules)
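
A hypothetical usage sketch of the helpers above, only to illustrate the ordering and the lr decoupling (the actual torchzero.optim optimizers may invoke them differently; it assumes WeightDecay, HeavyBall, ClipNorm and LR keep the constructors referenced in _get_module):

# Hypothetical sketch: build the common wrapper modules around a main module.
kwargs: _CommonKwargs = {
    "l2": 1e-4,               # scaled by lr, applied before the update rule
    "momentum": 0.9,
    "update_clip_norm": 1.0,  # an "update_*" module, so lr gets decoupled
}
lr, lr_module = _get_lr_and_lr_module(1e-2, kwargs)   # -> (1, 1e-2): lr is moved after the clipping
modules = _make_common_modules(None, lr_module, kwargs)
# per `order`: [WeightDecay(1e-4), HeavyBall(0.9, 0), ClipNorm(1.0), LR(1e-2)]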
File renamed without changes.
43 changes: 43 additions & 0 deletions src/torchzero/modules/experimental/gradmin.py
@@ -0,0 +1,43 @@
import torch

from ...core import OptimizerModule
from ...grad.derivatives import jacobian
from ...tensorlist import TensorList

class GradMin(OptimizerModule):
"""An idea.
"""
def __init__(self, add_loss: float = 1, square=False, maximize_grad = False):
super().__init__(dict(add_loss=add_loss))
self.square = square
self.maximize_grad = maximize_grad

@torch.no_grad
def step(self, state):
if state.closure is None: raise ValueError()
if state.ascent is not None:
raise ValueError("GradMin doesn't accept ascent_direction")

params = self.get_params()
add_loss = self.get_group_key('add_loss')

self.zero_grad()
with torch.enable_grad():
state.fx0 = state.closure(False)
grads = jacobian([state.fx0], params, create_graph=True, batched=False) # type:ignore
grads = TensorList(grads).squeeze_(0)
if self.square:
grads = grads ** 2
else:
grads = grads.abs()

if self.maximize_grad: grads: TensorList = grads - (state.fx0 * add_loss) # type:ignore
else: grads = grads + (state.fx0 * add_loss)
grad_mean = torch.sum(torch.stack(grads.sum())) / grads.total_numel()
grad_mean.backward(retain_graph=False)

if self.maximize_grad: state.grad = params.ensure_grad_().grad.neg_()
else: state.grad = params.ensure_grad_().grad

state.maybe_use_grad_(params)
return self._update_params_or_step_with_next(state)
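
The trick GradMin relies on, differentiating through the gradient itself via create_graph=True, can be shown with plain PyTorch. This is a simplified self-contained sketch of the idea (descend on |grad| + loss), not the module above:

import torch

x = torch.tensor([3.0, -2.0], requires_grad=True)

def loss_fn(p):
    return (p ** 4).sum()                        # toy objective

for _ in range(5):
    loss = loss_fn(x)
    (grad,) = torch.autograd.grad(loss, x, create_graph=True)  # keep the graph for double backward
    objective = grad.abs().sum() + loss          # |grad| plus the loss itself
    objective.backward()                         # second backward pass
    with torch.no_grad():
        x -= 1e-2 * x.grad                       # plain gradient step on the combined objective
        x.grad = None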
@@ -2,11 +2,8 @@

from ...core import OptimizerModule

class HvInvFDM(OptimizerModule):
class SquaredGradientNormFDM(OptimizerModule):
"""Experimental (maybe don't use yet).
This should approximate the hessian via just two backward passes,
but it only works if the hessian is purely diagonal.
Otherwise I don't really know what happens and I am looking into it.
Args:
eps (float, optional): finite difference epsilon. Defaults to 1e-3.
@@ -18,12 +15,12 @@ def __init__(self, eps=1e-3):
def step(self, state):
if state.closure is None: raise ValueError()
if state.ascent is not None:
raise ValueError("DiagNewtonViaHessianGradientProduct doesn't accept ascent_direction")
raise ValueError("HvInvFDM doesn't accept ascent_direction")

params = self.get_params()
eps = self.get_group_key('eps')
grad_fx0 = state.maybe_compute_grad_(params).clone()
state.grad = grad_fx0 # set state grad to the cloned version, since it will be overwritted
state.grad = grad_fx0 # set state grad to the cloned version, since it will be overwritten

params += grad_fx0 * eps
with torch.enable_grad(): _ = state.closure(True)
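The finite-difference identity the module above leans on is H v ≈ (∇f(x + εv) − ∇f(x)) / ε, with v taken to be the gradient itself; on a quadratic it is exact up to floating-point error. A quick self-contained check in plain PyTorch (a sketch, independent of the module):

import torch

torch.manual_seed(0)
A = torch.randn(4, 4)
H = A @ A.T + 4 * torch.eye(4)                 # SPD Hessian of f(x) = 0.5 * x^T H x

def grad_f(x):
    return H @ x                               # gradient of the quadratic

x = torch.randn(4)
eps = 1e-3
g = grad_f(x)
hvp_fd = (grad_f(x + eps * g) - grad_f(x)) / eps   # finite-difference estimate of H @ g
print(torch.allclose(hvp_fd, H @ g, atol=1e-4))    # True
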
@@ -4,7 +4,7 @@

import torch

from ... import tl
from ... import tensorlist as tl
from ...core import OptimizationState, OptimizerModule
from ..meta.chain import Chain
# this whole thing can also be implemented via parameter vectors.
19 changes: 10 additions & 9 deletions src/torchzero/modules/line_search/__init__.py
@@ -10,19 +10,20 @@
from .grid_ls import (ArangeLS, BacktrackingLS, GridLS, LinspaceLS,
MultiplicativeLS)
from .quad_interp import QuadraticInterpolation2Point
from .quadratic_ls import MinimizeQuadratic3PointsLS, MinimizeQuadraticLS
from .directional_newton import DirectionalNewton3Points, DirectionalNewton
from .scipy_minimize_scalar import ScipyMinimizeScalarLS

LineSearches = T.Literal['backtracking', 'brent', 'brent-exact', 'brent-norm', 'multiplicative'] | OptimizerModule
LineSearches = T.Literal['backtracking', 'brent', 'brent-exact', 'brent-norm', 'multiplicative', 'newton', 'newton-grad'] | OptimizerModule

def get_line_search(name:str | OptimizerModule):
if isinstance(name, str):
name = name.strip().lower()
if name == 'backtracking': return BacktrackingLS()
if name == 'multiplicative': return BacktrackingLS()
elif name == 'brent': return ScipyMinimizeScalarLS(maxiter=8)
elif name == 'brent-exact': return ScipyMinimizeScalarLS()
elif name == 'brent-norm': return Chain([Normalize(), ScipyMinimizeScalarLS(maxiter=16)])
else: raise ValueError(f"Unknown line search method: {name}")
else:
return name
if name == 'multiplicative': return MultiplicativeLS()
if name == 'brent': return ScipyMinimizeScalarLS(maxiter=8)
if name == 'brent-exact': return ScipyMinimizeScalarLS()
if name == 'brent-norm': return [Normalize(), ScipyMinimizeScalarLS(maxiter=16)]
if name == 'newton': return DirectionalNewton3Points(1)
if name == 'newton-grad': return DirectionalNewton(1)
raise ValueError(f"Unknown line search method: {name}")
return name
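
A small usage note, reading directly off the mapping above (it assumes these names are imported from torchzero.modules.line_search; note that 'brent-norm' now returns a plain list of modules rather than a Chain):

ls = get_line_search('brent-norm')      # -> [Normalize(), ScipyMinimizeScalarLS(maxiter=16)]
ls = get_line_search('newton')          # -> DirectionalNewton3Points(1)
ls = get_line_search(BacktrackingLS())  # module instances pass through unchanged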
@@ -18,7 +18,7 @@ def _ensure_float(x):
elif isinstance(x, np.ndarray): return x.item()
return float(x)

class MinimizeQuadraticLS(LineSearchBase):
class DirectionalNewton(LineSearchBase):
"""Minimizes a parabola in the direction of the gradient via one additional forward pass,
and uses another forward pass to make sure it didn't overstep.
So in total this performs three forward passes and one backward.
@@ -49,7 +49,7 @@ class MinimizeQuadraticLS(LineSearchBase):
Note:
While lr scheduling is supported, this uses lr of the first parameter for all parameters.
"""
def __init__(self, lr:float=1e-2, max_dist: float | None = 1e4, validate_step = True, log_lrs = False,):
def __init__(self, lr:float=1e-2, max_dist: float | None = 1e5, validate_step = True, log_lrs = False,):
super().__init__({"lr": lr}, make_closure=False, maxiter=None, log_lrs=log_lrs)

self.max_dist = max_dist
@@ -134,7 +134,7 @@ def _newton_step_3points(
# xneg is actually x0
return xneg - dx / ddx, ddx

class MinimizeQuadratic3PointsLS(LineSearchBase):
class DirectionalNewton3Points(LineSearchBase):
"""Minimizes a parabola in the direction of the update via two additional forward passe,
and uses another forward pass to make sure it didn't overstep.
So in total this performs four forward passes.
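
The three-point construction behind _newton_step_3points is the standard central-difference Newton step: from f(x−h), f(x), f(x+h), estimate f′ ≈ (f(x+h) − f(x−h)) / (2h) and f″ ≈ (f(x+h) − 2f(x) + f(x−h)) / h², then step to x − f′/f″, the minimizer of the fitted parabola. A quick numeric check (a sketch; the classes above may parametrize the three points differently):

def f(t):
    return (t - 3.0) ** 2 + 1.0                      # exact parabola with minimum at t = 3

t, h = 0.0, 0.5
d1 = (f(t + h) - f(t - h)) / (2 * h)                 # central first-derivative estimate
d2 = (f(t + h) - 2 * f(t) + f(t - h)) / h ** 2       # central second-derivative estimate
print(t - d1 / d2)                                   # 3.0, the parabola's minimum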
12 changes: 9 additions & 3 deletions src/torchzero/modules/meta/chain.py
@@ -1,9 +1,9 @@
import typing as T
from collections import abc

import torch

from ...core import OptimizerModule, OptimizationState
from ...tensorlist import TensorList
from ...python_tools import flatten

class Chain(OptimizerModule):
@@ -17,7 +17,7 @@ def __init__(self, *modules: OptimizerModule | abc.Iterable[OptimizerModule]):

# first module is chain's child, second module is first module's child, etc
if len(flat_modules) != 0:
self._set_next_module(flat_modules[0])
self._set_child_('first', flat_modules[0])
if len(flat_modules) > 1:
for i, m in enumerate(flat_modules[:-1]):
m._set_next_module(flat_modules[i+1])
@@ -26,4 +26,10 @@ def __init__(self, *modules: OptimizerModule | abc.Iterable[OptimizerModule]):

@torch.no_grad
def step(self, state: OptimizationState):
return self._update_params_or_step_with_next(state)
# no next module, step with the child
if self.next_module is None:
return self.children['first'].step(state)

# return ascent and pass it to next module
state.ascent = self.children['first'].return_ascent(state)
return self._update_params_or_step_with_next(state)
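
Per the constructor above, Chain makes the first module its child and links each module's next_module to the following one; step then either delegates straight to the child or collects the child's ascent via return_ascent and forwards it. A hypothetical composition, reusing classes referenced elsewhere in this commit (a sketch, not a recommended configuration):

# Hypothetical sketch: normalize the update, then pick a step size with a
# backtracking line search along the normalized direction.
chain = Chain(Normalize(), BacktrackingLS())
# wiring: chain --child--> Normalize() --next_module--> BacktrackingLS()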