diff --git a/src/torchzero/core/module.py b/src/torchzero/core/module.py index cda2642..41622b3 100644 --- a/src/torchzero/core/module.py +++ b/src/torchzero/core/module.py @@ -110,7 +110,6 @@ def update_attrs_(self, state: "OptimizationState"): if state.fx0 is not None: self.fx0 = state.fx0 if state.fx0_approx is not None: self.fx0_approx = state.fx0_approx - class OptimizerModule(TensorListOptimizer, ABC): r"""Base class for all modules. @@ -272,6 +271,14 @@ def _step_update_ascent_direction(self, state: OptimizationState) -> ScalarType # peform an update with the ascent direction, or pass it to the child. return self._update_params_or_step_with_next(state, params=params) + def return_ascent(self, state: OptimizationState, params=None) -> TensorList: + if params is None: params = self.get_params() + true_next = self.next_module + self.next_module = _ReturnAscent(params) # type:ignore + ascent: TensorList = self.step(state) # type:ignore + self.next_module = true_next + return ascent + @torch.no_grad def step( # type:ignore # pylint:disable=signature-differs # pylint:disable = arguments-renamed self, @@ -292,4 +299,13 @@ def _update(self, state: OptimizationState, ascent: TensorList) -> TensorList: After generating a new ascent direction with this `_update` method, if this module has no child, ascent direction will be subtracted from params. Otherwise everything is passed to the child.""" - raise NotImplementedError() \ No newline at end of file + raise NotImplementedError() + +class _ReturnAscent: + def __init__(self, params): + self.params = params + + @torch.no_grad + def step(self, state: OptimizationState) -> TensorList: # type:ignore + update = state.maybe_use_grad_(self.params) # this will execute the closure which might be modified + return update diff --git a/src/torchzero/core/tensorlist_optimizer.py b/src/torchzero/core/tensorlist_optimizer.py index 52fc27b..dd84983 100644 --- a/src/torchzero/core/tensorlist_optimizer.py +++ b/src/torchzero/core/tensorlist_optimizer.py @@ -19,11 +19,12 @@ class TensorListOptimizer(torch.optim.Optimizer, ABC): def __init__(self, params: ParamsT, defaults): super().__init__(params, defaults) self._params: list[torch.Tensor] = [param for group in self.param_groups for param in group['params']] + self.has_complex = any(torch.is_complex(x) for x in self._params) def add_param_group(self, param_group: dict[str, Any]) -> None: super().add_param_group(param_group) self._params: list[torch.Tensor] = [param for group in self.param_groups for param in group['params']] - + self.has_complex = any(torch.is_complex(x) for x in self._params) def get_params[CLS: Any](self, cls: type[CLS] = TensorList) -> CLS: return cls(p for p in self._params if p.requires_grad) diff --git a/src/torchzero/modules/__init__.py b/src/torchzero/modules/__init__.py index fff18ab..9f591d0 100644 --- a/src/torchzero/modules/__init__.py +++ b/src/torchzero/modules/__init__.py @@ -1,7 +1,11 @@ +# pylint: disable = singleton-comparison +# ruff: noqa: E712 r""" This submodule contains composable optimizer "building blocks". 
""" +from typing import TypedDict, Any, get_type_hints, Literal +from collections.abc import Iterable, Sequence from .gradient_approximation import * from .adaptive import * from .line_search import * @@ -13,4 +17,112 @@ from .second_order import * from .optimizers import * from .smoothing import * -from .subspace import * \ No newline at end of file + +from ..core.module import OptimizerModule +# from .experimental.subspace import * + +Modules = OptimizerModule | Sequence[OptimizerModule] + +def _ismodules(x): + if isinstance(x, OptimizerModule): return True + if isinstance(x, Sequence) and len(x)>0 and isinstance(x[0], OptimizerModule): return True + return False + +class _CommonKwargs(TypedDict, total=False): + # lr: float | Modules | None + decoupled_l1: float | Modules | None + """test if this works""" + decoupled_l2: float | Modules | None + l1: float | Modules | None + l2: float | Modules | None + grad_clip_norm: float | Modules | None + grad_clip_value: float | Modules | None + grad_norm: bool | Literal['global', 'param', 'channel'] | Modules | None + grad_dropout: float | Modules | None + grad_sign: bool | Modules | None + update_clip_norm: float | Modules | None + update_clip_value: float| Modules | None + update_norm: bool | Modules | None + update_sign: bool | Modules | None + lr_dropout: float | Modules | None + cautious: bool | Modules | None + line_search: LineSearches | Modules | None + momentum: float | Modules | None + dampening: float + nesterov: bool + adam: bool | Modules | None + rmsprop: bool | Modules | None + laplacian_smoothing: float | Modules | None + update_laplacian_smoothing: float | Modules | None + grad_estimator: str | Modules | None + grad_modules: Modules | None + update_modules: Modules | None + main_modules: Modules | None + +def _get_module(module:str, arg: Any, all_kwargs: _CommonKwargs): + skip = {"dampening", "nesterov"} + if arg in skip: return None + if arg is None: return None + if _ismodules(arg): return arg + if module == 'lr': return LR(arg) + if module in ("l1", "decoupled_l1"): return WeightDecay(arg, ord = 1) + if module in ("l2", "decoupled_l2"): return WeightDecay(arg) + if module in ('grad_clip_norm', 'update_clip_norm'): return ClipNorm(arg) + if module in ('grad_clip_value', 'update_clip_value'): return ClipValue(arg) + if module in ('grad_norm', 'update_norm'): + if arg == True: return Normalize() + if arg == False: return None + return Normalize(mode=arg) + if module in ('grad_dropout', 'lr_dropouts'): return Dropout(arg) + if module in ('grad_sign', 'update_sign'): return Sign() if arg == True else None + if module == 'cautious': return Cautious() if arg == True else None + if module == 'line_search': return get_line_search(arg) + if module == 'momentum': + dampening = all_kwargs.get('dampening', 0) + nesterov = all_kwargs.get('nesterov', False) + if nesterov: return NesterovMomentum(arg, dampening) + return HeavyBall(arg, dampening) + if module == 'nesterov': return NesterovMomentum(arg) + if module == 'adam': return Adam() if arg == True else None + if module == 'rmsprop': return RMSProp() if arg == True else None + if module in ('laplacian_smoothing', 'update_laplacian_smoothing'): return LaplacianSmoothing(arg) + if module == 'grad_estimator': raise NotImplementedError(module) + raise ValueError(module) + +def _should_decouple_lr(kwargs: _CommonKwargs): + decoupled_modules = {"update_norm", "update_sign", "update_clip_norm", "update_clip_value"} + return any(m in kwargs for m in decoupled_modules) + +def 
_get_lr_and_lr_module(lr: float, kwargs: _CommonKwargs): + if _should_decouple_lr(kwargs): return 1, lr + return lr, None + +def _make_common_modules(main: OptimizerModule | Iterable[OptimizerModule] | None, lr_module: float | None, kwargs: _CommonKwargs): + """common modules, this is used to add common things to all torchzero.optim optimizers + + l1 and l2 depend on learning rate of the optimizer, and are applied before the update rule. + + Decoupled versions also do not depend on learning rate and are applied after the update rule. + + Update modules, such as update_clip_norm, do not depend on lr either. + """ + from ..python_tools import flatten + order = [ + "grad_estimator", + "l1", "l2", "laplacian_smoothing", "grad_modules", "grad_sign", "grad_clip_norm", "grad_clip_value", "grad_norm", "grad_dropout", + "main", "main_modules", "rmsprop", "adam", "momentum", "cautious", "update_norm", + "update_laplacian_smoothing", "update_sign", "update_clip_norm", "update_clip_value", "lr", + "lr_dropout", "decoupled_l1", "decoupled_l2", "update_modules", "line_search" + ] + + keys = set(get_type_hints(_CommonKwargs).keys()).union({'lr'}).difference({"dampening", "nesterov"}) + order_keys = set(order).difference({'main'}) + assert order_keys == keys, f'missing: {order_keys.difference(keys)}' + + modules_dict = {k: _get_module(k, v, kwargs) for k, v in kwargs.items()} + modules_dict["main"] = main + if lr_module is not None: modules_dict['lr'] = _get_module('lr', lr_module, kwargs) + + modules = [modules_dict[k] for k in order if k in modules_dict] + modules = [i for i in modules if i is not None] + return flatten(modules) \ No newline at end of file diff --git a/src/torchzero/optim/first_order/adamw.py b/src/torchzero/modules/experimental/__init__.py similarity index 100% rename from src/torchzero/optim/first_order/adamw.py rename to src/torchzero/modules/experimental/__init__.py diff --git a/src/torchzero/modules/experimental/gradmin.py b/src/torchzero/modules/experimental/gradmin.py new file mode 100644 index 0000000..243664d --- /dev/null +++ b/src/torchzero/modules/experimental/gradmin.py @@ -0,0 +1,43 @@ +import torch + +from ...core import OptimizerModule +from ...grad.derivatives import jacobian +from ...tensorlist import TensorList + +class GradMin(OptimizerModule): + """An idea. 
+ """ + def __init__(self, add_loss: float = 1, square=False, maximize_grad = False): + super().__init__(dict(add_loss=add_loss)) + self.square = square + self.maximize_grad = maximize_grad + + @torch.no_grad + def step(self, state): + if state.closure is None: raise ValueError() + if state.ascent is not None: + raise ValueError("GradMin doesn't accept ascent_direction") + + params = self.get_params() + add_loss = self.get_group_key('add_loss') + + self.zero_grad() + with torch.enable_grad(): + state.fx0 = state.closure(False) + grads = jacobian([state.fx0], params, create_graph=True, batched=False) # type:ignore + grads = TensorList(grads).squeeze_(0) + if self.square: + grads = grads ** 2 + else: + grads = grads.abs() + + if self.maximize_grad: grads: TensorList = grads - (state.fx0 * add_loss) # type:ignore + else: grads = grads + (state.fx0 * add_loss) + grad_mean = torch.sum(torch.stack(grads.sum())) / grads.total_numel() + grad_mean.backward(retain_graph=False) + + if self.maximize_grad: state.grad = params.ensure_grad_().grad.neg_() + else: state.grad = params.ensure_grad_().grad + + state.maybe_use_grad_(params) + return self._update_params_or_step_with_next(state) diff --git a/src/torchzero/modules/quasi_newton/hv_inv_fdm.py b/src/torchzero/modules/experimental/squared_grad_norm_fdm.py similarity index 69% rename from src/torchzero/modules/quasi_newton/hv_inv_fdm.py rename to src/torchzero/modules/experimental/squared_grad_norm_fdm.py index 6d26669..fa8bf57 100644 --- a/src/torchzero/modules/quasi_newton/hv_inv_fdm.py +++ b/src/torchzero/modules/experimental/squared_grad_norm_fdm.py @@ -2,11 +2,8 @@ from ...core import OptimizerModule -class HvInvFDM(OptimizerModule): +class SquaredGradientNormFDM(OptimizerModule): """Experimental (maybe don't use yet). - This should approximate the hessian via just two backward passes - but it only works if hessian is purely diagonal. - Otherwise I don't really know what happens and I am looking into it. Args: eps (float, optional): finite difference epsilon. Defaults to 1e-3. @@ -18,12 +15,12 @@ def __init__(self, eps=1e-3): def step(self, state): if state.closure is None: raise ValueError() if state.ascent is not None: - raise ValueError("DiagNewtonViaHessianGradientProduct doesn't accept ascent_direction") + raise ValueError("HvInvFDM doesn't accept ascent_direction") params = self.get_params() eps = self.get_group_key('eps') grad_fx0 = state.maybe_compute_grad_(params).clone() - state.grad = grad_fx0 # set state grad to the cloned version, since it will be overwritted + state.grad = grad_fx0 # set state grad to the cloned version, since it will be overwritten params += grad_fx0 * eps with torch.enable_grad(): _ = state.closure(True) diff --git a/src/torchzero/modules/subspace/random_subspace.py b/src/torchzero/modules/experimental/subspace.py similarity index 97% rename from src/torchzero/modules/subspace/random_subspace.py rename to src/torchzero/modules/experimental/subspace.py index a8e59d5..1a09ede 100644 --- a/src/torchzero/modules/subspace/random_subspace.py +++ b/src/torchzero/modules/experimental/subspace.py @@ -4,7 +4,7 @@ import torch -from ... import tl +from ... import tensorlist as tl from ...core import OptimizationState, OptimizerModule from ..meta.chain import Chain # this whole thing can also be implemented via parameter vectors. 
diff --git a/src/torchzero/modules/line_search/__init__.py b/src/torchzero/modules/line_search/__init__.py index 6c3f3f2..ccd1f78 100644 --- a/src/torchzero/modules/line_search/__init__.py +++ b/src/torchzero/modules/line_search/__init__.py @@ -10,19 +10,20 @@ from .grid_ls import (ArangeLS, BacktrackingLS, GridLS, LinspaceLS, MultiplicativeLS) from .quad_interp import QuadraticInterpolation2Point -from .quadratic_ls import MinimizeQuadratic3PointsLS, MinimizeQuadraticLS +from .directional_newton import DirectionalNewton3Points, DirectionalNewton from .scipy_minimize_scalar import ScipyMinimizeScalarLS -LineSearches = T.Literal['backtracking', 'brent', 'brent-exact', 'brent-norm', 'multiplicative'] | OptimizerModule +LineSearches = T.Literal['backtracking', 'brent', 'brent-exact', 'brent-norm', 'multiplicative', 'newton', 'newton-grad'] | OptimizerModule def get_line_search(name:str | OptimizerModule): if isinstance(name, str): name = name.strip().lower() if name == 'backtracking': return BacktrackingLS() - if name == 'multiplicative': return BacktrackingLS() - elif name == 'brent': return ScipyMinimizeScalarLS(maxiter=8) - elif name == 'brent-exact': return ScipyMinimizeScalarLS() - elif name == 'brent-norm': return Chain([Normalize(), ScipyMinimizeScalarLS(maxiter=16)]) - else: raise ValueError(f"Unknown line search method: {name}") - else: - return name \ No newline at end of file + if name == 'multiplicative': return MultiplicativeLS() + if name == 'brent': return ScipyMinimizeScalarLS(maxiter=8) + if name == 'brent-exact': return ScipyMinimizeScalarLS() + if name == 'brent-norm': return [Normalize(), ScipyMinimizeScalarLS(maxiter=16)] + if name == 'newton': return DirectionalNewton3Points(1) + if name == 'newton-grad': return DirectionalNewton(1) + raise ValueError(f"Unknown line search method: {name}") + return name \ No newline at end of file diff --git a/src/torchzero/modules/line_search/quadratic_ls.py b/src/torchzero/modules/line_search/directional_newton.py similarity index 96% rename from src/torchzero/modules/line_search/quadratic_ls.py rename to src/torchzero/modules/line_search/directional_newton.py index 41b1648..4496f6b 100644 --- a/src/torchzero/modules/line_search/quadratic_ls.py +++ b/src/torchzero/modules/line_search/directional_newton.py @@ -18,7 +18,7 @@ def _ensure_float(x): elif isinstance(x, np.ndarray): return x.item() return float(x) -class MinimizeQuadraticLS(LineSearchBase): +class DirectionalNewton(LineSearchBase): """Minimizes a parabola in the direction of the gradient via one additional forward pass, and uses another forward pass to make sure it didn't overstep. So in total this performs three forward passes and one backward. @@ -49,7 +49,7 @@ class MinimizeQuadraticLS(LineSearchBase): Note: While lr scheduling is supported, this uses lr of the first parameter for all parameters. """ - def __init__(self, lr:float=1e-2, max_dist: float | None = 1e4, validate_step = True, log_lrs = False,): + def __init__(self, lr:float=1e-2, max_dist: float | None = 1e5, validate_step = True, log_lrs = False,): super().__init__({"lr": lr}, make_closure=False, maxiter=None, log_lrs=log_lrs) self.max_dist = max_dist @@ -134,7 +134,7 @@ def _newton_step_3points( # xneg is actually x0 return xneg - dx / ddx, ddx -class MinimizeQuadratic3PointsLS(LineSearchBase): +class DirectionalNewton3Points(LineSearchBase): """Minimizes a parabola in the direction of the update via two additional forward passe, and uses another forward pass to make sure it didn't overstep. 
So in total this performs four forward passes. diff --git a/src/torchzero/modules/meta/chain.py b/src/torchzero/modules/meta/chain.py index 19c0c56..46333da 100644 --- a/src/torchzero/modules/meta/chain.py +++ b/src/torchzero/modules/meta/chain.py @@ -1,9 +1,9 @@ -import typing as T from collections import abc import torch from ...core import OptimizerModule, OptimizationState +from ...tensorlist import TensorList from ...python_tools import flatten class Chain(OptimizerModule): @@ -17,7 +17,7 @@ def __init__(self, *modules: OptimizerModule | abc.Iterable[OptimizerModule]): # first module is chain's child, second module is first module's child, etc if len(flat_modules) != 0: - self._set_next_module(flat_modules[0]) + self._set_child_('first', flat_modules[0]) if len(flat_modules) > 1: for i, m in enumerate(flat_modules[:-1]): m._set_next_module(flat_modules[i+1]) @@ -26,4 +26,10 @@ def __init__(self, *modules: OptimizerModule | abc.Iterable[OptimizerModule]): @torch.no_grad def step(self, state: OptimizationState): - return self._update_params_or_step_with_next(state) + # no next module, step with the child + if self.next_module is None: + return self.children['first'].step(state) + + # return ascent and pass it to next module + state.ascent = self.children['first'].return_ascent(state) + return self._update_params_or_step_with_next(state) \ No newline at end of file diff --git a/src/torchzero/modules/meta/grafting.py b/src/torchzero/modules/meta/grafting.py index 748e673..103ecbc 100644 --- a/src/torchzero/modules/meta/grafting.py +++ b/src/torchzero/modules/meta/grafting.py @@ -39,12 +39,8 @@ def __init__( # TODO: channelwise ): super().__init__({}) - - if not isinstance(magnitude, Iterable): magnitude = [magnitude] - if not isinstance(direction, Iterable): direction = [direction] - - self._set_child_('magnitude', Chain([*magnitude, ReturnAscent()])) - self._set_child_('direction', Chain([*direction, ReturnAscent()])) + self._set_child_('magnitude', Chain(magnitude)) + self._set_child_('direction', Chain(direction)) self.ord = ord self.eps = eps self.layerwise = layerwise @@ -53,13 +49,13 @@ def __init__( @torch.no_grad def step(self, state): state_copy = state.copy(clone_ascent=True) - magnitude: TensorList = self.children['magnitude'].step(state_copy) # type:ignore + magnitude = self.children['magnitude'].return_ascent(state_copy) if state_copy.grad is not None: state.grad = state_copy.grad if state_copy.fx0 is not None: state.fx0 = state_copy.fx0 if state_copy.fx0_approx is not None: state.fx0_approx = state_copy.fx0_approx - direction: TensorList = self.children['direction'].step(state) # type:ignore + direction = self.children['direction'].return_ascent(state) if self.layerwise: M = magnitude.norm(self.ord) @@ -95,22 +91,19 @@ def __init__( ): super().__init__({}) - if not isinstance(magnitude, Iterable): magnitude = [magnitude] - if not isinstance(sign, Iterable): sign = [sign] - - self._set_child_('magnitude', Chain([*magnitude, ReturnAscent()])) - self._set_child_('sign', Chain([*sign, ReturnAscent()])) + self._set_child_('magnitude', Chain(magnitude)) + self._set_child_('sign', Chain(sign)) @torch.no_grad def step(self, state): state_copy = state.copy(clone_ascent=True) - magnitude: TensorList = self.children['magnitude'].step(state_copy).abs_() # type:ignore + magnitude = self.children['magnitude'].return_ascent(state_copy).abs_() # make sure to store grad and fx0 if it was calculated state.update_attrs_(state_copy) - sign: TensorList = self.children['sign'].step(state).sign_() 
# type:ignore + sign = self.children['sign'].return_ascent(state).sign_() state.ascent = magnitude.mul_(sign) return self._update_params_or_step_with_next(state) @@ -152,10 +145,10 @@ def __init__( ): super().__init__({}) - self._set_child_('main', Chain(main_module, ReturnAscent())) + self._set_child_('main', Chain(main_module)) if isinstance(compare_module, str): self.compare_mode = compare_module else: - self._set_child_('compare', Chain(compare_module, ReturnAscent())) + self._set_child_('compare', Chain(compare_module)) self.compare_mode = 'module' self.eps = eps self.normalize = normalize @@ -165,10 +158,10 @@ def __init__( def step(self, state): params = None state_copy = state.copy(clone_ascent=True) - ascent: TensorList = self.children['main'].step(state_copy) # type:ignore + ascent = self.children['main'].return_ascent(state_copy) state.update_attrs_(state_copy) - if self.compare_mode == 'module': compare: TensorList = self.children['compare'].step(state) # type:ignore + if self.compare_mode == 'module': compare = self.children['compare'].return_ascent(state) else: params = self.get_params() if self.compare_mode == 'ascent': compare: TensorList = state.maybe_use_grad_(params) diff --git a/src/torchzero/modules/misc/__init__.py b/src/torchzero/modules/misc/__init__.py index 5c495db..e0d8689 100644 --- a/src/torchzero/modules/misc/__init__.py +++ b/src/torchzero/modules/misc/__init__.py @@ -3,7 +3,8 @@ as well as gradient/update clipping and normalization. """ from .basic import (LR, Add, AddMagnitude, Clone, Div, Identity, Lambda, Mul, - NanToNum, Pow, PowMagnitude, Reciprocal, Negate, Sign, sign_grad_) + NanToNum, Pow, PowMagnitude, Reciprocal, Negate, Sign, sign_grad_, Abs, Grad, Zeros, Fill) from .normalization import (Centralize, ClipNorm, ClipValue, Normalize, centralize_grad_, normalize_grad_, clip_grad_norm_, clip_grad_value_) from .on_increase import NegateOnLossIncrease +from .multimodule import Sum, Mean, Product, Subtract, Divide, Interpolate \ No newline at end of file diff --git a/src/torchzero/modules/misc/basic.py b/src/torchzero/modules/misc/basic.py index b65d9bd..6f5b859 100644 --- a/src/torchzero/modules/misc/basic.py +++ b/src/torchzero/modules/misc/basic.py @@ -166,3 +166,39 @@ def __init__(self): def _update(self, state, ascent): ascent.sign_() return ascent + +class Abs(OptimizerModule): + """Takes absolute value of the update `abs(update)`""" + def __init__(self): + super().__init__({}) + + @torch.no_grad + def _update(self, state, ascent): + ascent.abs_() + return ascent + +class Grad(OptimizerModule): + """Uses gradient as the update. 
This is useful for chains.""" + def __init__(self): + super().__init__({}) + + @torch.no_grad + def _update(self, state, ascent): + ascent = state.ascent = state.maybe_compute_grad_(self.get_params()) + return ascent + +class Zeros(OptimizerModule): + def __init__(self): + super().__init__({}) + + @torch.no_grad + def _update(self, state, ascent): + return ascent.zeros_like() + +class Fill(OptimizerModule): + def __init__(self, value): + super().__init__({"value": value}) + + @torch.no_grad + def _update(self, state, ascent): + return ascent.fill(self.get_group_key('value')) \ No newline at end of file diff --git a/src/torchzero/modules/misc/multimodule.py b/src/torchzero/modules/misc/multimodule.py new file mode 100644 index 0000000..5ea4c31 --- /dev/null +++ b/src/torchzero/modules/misc/multimodule.py @@ -0,0 +1,168 @@ +from collections.abc import Callable, Iterable +from typing import cast +import torch + +from ...tensorlist import TensorList + +from ...core import OptimizerModule +from ..meta.chain import Chain + + +class Sum(OptimizerModule): + def __init__( + self, + modules: Iterable[OptimizerModule | Iterable[OptimizerModule]], + ): + super().__init__({}) + modules = list(modules) + for i,module in enumerate(modules): + self._set_child_(i, Chain(module)) + + @torch.no_grad + def step(self, state): + if len(self.children) == 1: + state.ascent = self.children[0].return_ascent(state) + return self._update_params_or_step_with_next(state) + + sum = None + for i, c in sorted(self.children.items(), key=lambda x: x[0]): + if i == len(self.children) - 1: cur_state = state + else: cur_state = state.copy(clone_ascent = True) + + if sum is None: sum = c.return_ascent(cur_state) + else: sum += c.return_ascent(cur_state) + + if i != len(self.children) - 1: state.update_attrs_(cur_state) + + assert sum is not None + state.ascent = sum + return self._update_params_or_step_with_next(state) + +class Mean(OptimizerModule): + def __init__( + self, + modules: Iterable[OptimizerModule | Iterable[OptimizerModule]], + ): + super().__init__({}) + modules = list(modules) + for i,module in enumerate(modules): + self._set_child_(i, Chain(module)) + + @torch.no_grad + def step(self, state): + if len(self.children) == 1: + state.ascent = self.children[0].return_ascent(state) + return self._update_params_or_step_with_next(state) + + sum = None + for i, c in sorted(self.children.items(), key=lambda x: x[0]): + if i == len(self.children) - 1: cur_state = state + else: cur_state = state.copy(clone_ascent = True) + + if sum is None: sum = c.return_ascent(cur_state) + else: sum += c.return_ascent(cur_state) + + if i != len(self.children) - 1: state.update_attrs_(cur_state) + + assert sum is not None + state.ascent = sum.div_(len(self.children)) + return self._update_params_or_step_with_next(state) + + +class Product(OptimizerModule): + def __init__( + self, + modules: Iterable[OptimizerModule | Iterable[OptimizerModule]], + ): + super().__init__({}) + modules = list(modules) + for i,module in enumerate(modules): + self._set_child_(i, Chain(module)) + + @torch.no_grad + def step(self, state): + if len(self.children) == 1: + state.ascent = self.children[0].return_ascent(state) + return self._update_params_or_step_with_next(state) + + prod = None + for i, c in sorted(self.children.items(), key=lambda x: x[0]): + if i == len(self.children) - 1: cur_state = state + else: cur_state = state.copy(clone_ascent = True) + + if prod is None: prod = c.return_ascent(cur_state) + else: prod *= c.return_ascent(cur_state) + + if i != 
len(self.children) - 1: state.update_attrs_(cur_state) + + assert prod is not None + state.ascent = prod + return self._update_params_or_step_with_next(state) + + +class Subtract(OptimizerModule): + """a - b""" + def __init__( + self, + a: OptimizerModule | Iterable[OptimizerModule], + b: OptimizerModule | Iterable[OptimizerModule], + ): + super().__init__({}) + self._set_child_('a', Chain(a)) + self._set_child_('b', Chain(b)) + + @torch.no_grad + def step(self, state): + state_copy = state.copy(clone_ascent = True) + a = self.children['a'].return_ascent(state_copy) + state.update_attrs_(state_copy) + b = self.children['b'].return_ascent(state) + + state.ascent = a.sub_(b) + return self._update_params_or_step_with_next(state) + +class Divide(OptimizerModule): + """numerator / denominator""" + def __init__( + self, + numerator: OptimizerModule | Iterable[OptimizerModule], + denominator: OptimizerModule | Iterable[OptimizerModule], + ): + super().__init__({}) + self._set_child_('numerator', Chain(numerator)) + self._set_child_('denominator', Chain(denominator)) + + @torch.no_grad + def step(self, state): + state_copy = state.copy(clone_ascent = True) + numerator = self.children['numerator'].return_ascent(state_copy) + state.update_attrs_(state_copy) + denominator = self.children['denominator'].return_ascent(state) + + state.ascent = numerator.div_(denominator) + return self._update_params_or_step_with_next(state) + + +class Interpolate(OptimizerModule): + """lerp. `out = self + weight * (tensors1 - self)`.""" + def __init__( + self, + input: OptimizerModule | Iterable[OptimizerModule], + end: OptimizerModule | Iterable[OptimizerModule], + weight: float, + ): + super().__init__({}) + self._set_child_('input', Chain(input)) + self._set_child_('end', Chain(end)) + self.weight = weight + + @torch.no_grad + def step(self, state): + state_copy = state.copy(clone_ascent = True) + input = self.children['input'].return_ascent(state_copy) + state.update_attrs_(state_copy) + end = self.children['end'].return_ascent(state) + + state.ascent = input.lerp_(end, weight = self.weight) + return self._update_params_or_step_with_next(state) + diff --git a/src/torchzero/modules/second_order/newton.py b/src/torchzero/modules/second_order/newton.py index a3622b2..663dab6 100644 --- a/src/torchzero/modules/second_order/newton.py +++ b/src/torchzero/modules/second_order/newton.py @@ -46,8 +46,8 @@ def _fallback_safe_diag(hessian:torch.Tensor, grad:torch.Tensor, lr = 1e-2): def regularize_hessian_(hessian: torch.Tensor, value: float | Literal['eig']): """regularize hessian matrix in-place""" - if value == 'eig': - hessian.add_(torch.eye(hessian.shape[0], device=hessian.device, dtype=hessian.dtype), alpha=torch.linalg.eigvalsh(hessian).min()) # pylint:disable=not-callable + if value == 'eig': + value = torch.linalg.eigvalsh(hessian).min().clamp_(max=0).neg_() # pylint:disable=not-callable elif value != 0: hessian.add_(torch.eye(hessian.shape[0], device=hessian.device,dtype=hessian.dtype), alpha = value) @@ -136,8 +136,6 @@ def step(self, state): gvec = grads.to_vec() hessian = hessian_list_to_mat(hessian) - numel = gvec.numel() - # tikhonov regularization regularize_hessian_(hessian, self.tikhonov) diff --git a/src/torchzero/modules/subspace/__init__.py b/src/torchzero/modules/subspace/__init__.py deleted file mode 100644 index e2a953a..0000000 --- a/src/torchzero/modules/subspace/__init__.py +++ /dev/null @@ -1,8 +0,0 @@ -r""" -Modules that project the parameters into lower dimensional space, -making Newton-like 
methods feasible for large-scale problems. -""" -from .random_subspace import (Proj2Masks, ProjAscent, ProjAscentRay, ProjGrad, - ProjGradAscentDifference, ProjGradRay, - ProjLastAscentDifference, ProjLastGradDifference, - ProjNormalize, ProjRandom, Subspace) diff --git a/src/torchzero/optim/__init__.py b/src/torchzero/optim/__init__.py index 77a74c2..9410da6 100644 --- a/src/torchzero/optim/__init__.py +++ b/src/torchzero/optim/__init__.py @@ -14,4 +14,5 @@ from .quasi_newton import * from .zeroth_order import * from .second_order import * -from .first_order import * \ No newline at end of file +from .first_order import * +from .experimental import * \ No newline at end of file diff --git a/src/torchzero/optim/experimental/__init__.py b/src/torchzero/optim/experimental/__init__.py index e69de29..2cb0dd4 100644 --- a/src/torchzero/optim/experimental/__init__.py +++ b/src/torchzero/optim/experimental/__init__.py @@ -0,0 +1 @@ +from .experimental import ReciprocalSGD, RandomCoordinateMomentum, NestedNesterov, GradMin \ No newline at end of file diff --git a/src/torchzero/optim/experimental/experimental.py b/src/torchzero/optim/experimental/experimental.py new file mode 100644 index 0000000..84f03c6 --- /dev/null +++ b/src/torchzero/optim/experimental/experimental.py @@ -0,0 +1,81 @@ +from collections.abc import Iterable +from typing import Unpack + +from ...modules import AddMagnitude, NanToNum, NesterovMomentum, Normalize, Interpolate +from ...modules import RandomCoordinateMomentum as _RandomCoordinateMomentum +from ...modules import (Reciprocal, _CommonKwargs, _get_lr_and_lr_module, + _make_common_modules) +from ...modules.experimental.gradmin import GradMin as _GradMin +from ...modules.experimental.squared_grad_norm_fdm import \ + SquaredGradientNormFDM as _SquaredGradientNormFDM +from ..modular import Modular + + +class SquaredGradientNormFDM(Modular): + """Experimental (maybe don't use yet).""" + def __init__( + self, + params, + lr: float = 1, + eps: float = 1e-2, + **kwargs: Unpack[_CommonKwargs] + ): + modules = _make_common_modules(_SquaredGradientNormFDM(eps = eps), lr_module = lr, kwargs=kwargs) + super().__init__(params, modules) + + +class ReciprocalSGD(Modular): + def __init__( + self, + params, + lr: float = 1e-2, + eps: float = 1e-2, + **kwargs: Unpack[_CommonKwargs] + ): + lr, lr_module = _get_lr_and_lr_module(lr, kwargs) + main = [AddMagnitude(eps), Reciprocal(), NanToNum(0,0,0), Normalize(lr)] + modules = _make_common_modules(main, lr_module = lr_module, kwargs=kwargs) + super().__init__(params, modules) + + +class NestedNesterov(Modular): + def __init__( + self, + params, + lr: float = 1e-3, + momentums: Iterable[float] = (0.5, 0.5, 0.5), + dampening: float | Iterable[float] = 0, + **kwargs: Unpack[_CommonKwargs] # type:ignore + ): + momentums = list(momentums) + if isinstance(dampening, (int, float)): dampening = [dampening for _ in momentums] + main = [NesterovMomentum(m, d) for m, d in zip(momentums, dampening)] + modules = _make_common_modules(main, lr_module = lr, kwargs=kwargs) + super().__init__(params, modules) + +class RandomCoordinateMomentum(Modular): + def __init__( + self, + params, + lr: float = 1e-3, + p: float = 0.1, + nesterov: bool = True, + **kwargs: Unpack[_CommonKwargs] # type:ignore + ): + main = _RandomCoordinateMomentum(p, nesterov) + modules = _make_common_modules(main, lr_module = lr, kwargs=kwargs) + super().__init__(params, modules) + +class GradMin(Modular): + def __init__( + self, + params, + lr: float = 1e-2, + add_loss: float = 1, + 
square: bool = False, + maximize_grad: bool = False, + **kwargs: Unpack[_CommonKwargs], + ): + main = _GradMin(add_loss, square, maximize_grad) + modules = _make_common_modules(main, lr_module = lr, kwargs=kwargs) + super().__init__(params, modules) diff --git a/src/torchzero/optim/quasi_newton/newton_ray_search.py b/src/torchzero/optim/experimental/ray_search.py similarity index 84% rename from src/torchzero/optim/quasi_newton/newton_ray_search.py rename to src/torchzero/optim/experimental/ray_search.py index 96c847b..c9e27e1 100644 --- a/src/torchzero/optim/quasi_newton/newton_ray_search.py +++ b/src/torchzero/optim/experimental/ray_search.py @@ -1,11 +1,11 @@ -import typing as T -from collections import abc +from typing import Literal, Any import torch from ...core import OptimizerModule -from ...modules import (SGD, LineSearches, NewtonFDM, Subspace, - get_line_search, ProjAscentRay, ProjNormalize, LR, OptimizerWrapper) +from ...modules import (SGD, LineSearches, NewtonFDM, + get_line_search, LR, OptimizerWrapper) +from ...modules.experimental.subspace import Subspace, ProjNormalize, ProjAscentRay from ..modular import Modular @@ -24,7 +24,7 @@ def __init__( line_search: LineSearches | None = 'brent' ): """This is is an experiment and might not work well, maybe don't use yet""" - modules: list[OptimizerModule] = [ + modules: list[Any] = [ SGD(1, momentum=momentum, weight_decay=weight_decay, dampening=dampening, nesterov=nesterov), Subspace(NewtonFDM(eps = eps), ProjNormalize(ProjAscentRay(ray_width, n = n_rays))), ] @@ -53,7 +53,7 @@ def __init__( tolerance_grad: float = 1e-7, tolerance_change: float = 1e-9, history_size: int = 100, - line_search_fn: str | T.Literal['strong_wolfe'] | None = None, + line_search_fn: str | Literal['strong_wolfe'] | None = None, ): """This is is an experiment and might not work well, maybe don't use yet""" lbfgs = OptimizerWrapper( diff --git a/src/torchzero/optim/first_order/__init__.py b/src/torchzero/optim/first_order/__init__.py index 5c29387..6415379 100644 --- a/src/torchzero/optim/first_order/__init__.py +++ b/src/torchzero/optim/first_order/__init__.py @@ -1 +1,2 @@ -from .cautious import CautiousAdam, CautiousSGD \ No newline at end of file +from .cautious import CautiousAdam, CautiousSGD +from .optimizers import GD, SGD, Adagrad, Adam, AdamW, RMSProp, Rprop, SignSGD, NormSGD \ No newline at end of file diff --git a/src/torchzero/optim/first_order/cautious.py b/src/torchzero/optim/first_order/cautious.py index c852c30..c4cd2cd 100644 --- a/src/torchzero/optim/first_order/cautious.py +++ b/src/torchzero/optim/first_order/cautious.py @@ -1,10 +1,7 @@ -import typing -from collections import abc - -import torch +from typing import Literal, Unpack from ...core import OptimizerModule -from ...modules import Cautious, Adam, SGD, LR +from ...modules import Cautious, Adam, SGD, _make_common_modules, _CommonKwargs, _get_lr_and_lr_module from ..modular import Modular @@ -19,13 +16,16 @@ def __init__( amsgrad=False, c_eps = 1e-6, normalize = False, - mode: typing.Literal['zero', 'grad', 'backtrack'] = 'zero' + mode: Literal['zero', 'grad', 'backtrack'] = 'zero', + **kwargs: Unpack[_CommonKwargs], ): - modules: list[OptimizerModule] = [ + lr, lr_module = _get_lr_and_lr_module(lr, kwargs) + + main: list[OptimizerModule] = [ Adam(lr = lr, beta1 = beta1, beta2 = beta2, eps = eps, amsgrad = amsgrad), Cautious(normalize = normalize, eps = c_eps, mode = mode), ] - + modules = _make_common_modules(main, lr_module, kwargs) super().__init__(params, modules) @@ -40,12 
+40,16 @@ def __init__( nesterov: bool = True, c_eps = 1e-6, normalize = True, - mode: typing.Literal['zero', 'grad', 'backtrack'] = 'zero' + mode: Literal['zero', 'grad', 'backtrack'] = 'zero', + **kwargs: Unpack[_CommonKwargs], # type:ignore ): - modules: list[OptimizerModule] = [ + lr, lr_module = _get_lr_and_lr_module(lr, kwargs) + + main: list[OptimizerModule] = [ SGD(lr = lr, momentum = momentum, dampening = dampening, weight_decay = weight_decay, nesterov = nesterov), Cautious(normalize = normalize, eps = c_eps, mode = mode), ] - super().__init__(params, modules) + modules = _make_common_modules(main, lr_module, kwargs) + super().__init__(params, modules) diff --git a/src/torchzero/optim/first_order/optimizers.py b/src/torchzero/optim/first_order/optimizers.py new file mode 100644 index 0000000..e066ca9 --- /dev/null +++ b/src/torchzero/optim/first_order/optimizers.py @@ -0,0 +1,149 @@ +from typing import Unpack + +from ...modules import SGD as _SGD +from ...modules import Adagrad as _Adagrad +from ...modules import Adam as _Adam +from ...modules import LineSearches, Normalize +from ...modules import RMSProp as _RMSProp +from ...modules import Rprop as _Rprop +from ...modules import (Sign, _CommonKwargs, _get_lr_and_lr_module, + _make_common_modules) +from ..modular import Modular + + +class GD(Modular): + def __init__( + self, + params, + lr: float = 1, + line_search: LineSearches | None = 'backtracking', + **kwargs: Unpack[_CommonKwargs], # type:ignore + ): + kwargs['line_search'] = line_search + modules = _make_common_modules(None, lr, kwargs) + super().__init__(params, modules) + +class SGD(Modular): + def __init__( + self, + params, + lr: float = 0.001, + momentum: float = 0, + dampening: float = 0, + weight_decay: float = 0, + nesterov: bool = False, + **kwargs: Unpack[_CommonKwargs], # type:ignore + ): + lr, lr_module = _get_lr_and_lr_module(lr, kwargs) + main = _SGD(lr = lr, momentum = momentum, dampening = dampening, weight_decay = weight_decay, nesterov = nesterov) + modules = _make_common_modules(main, lr_module, kwargs) + super().__init__(params, modules) + + +class SignSGD(Modular): + def __init__( + self, + params, + lr: float = 1e-3, + **kwargs: Unpack[_CommonKwargs], # type:ignore + ): + modules = _make_common_modules(Sign(), lr, kwargs) + super().__init__(params, modules) + + +class NormSGD(Modular): + def __init__( + self, + params, + lr: float = 1e-3, + **kwargs: Unpack[_CommonKwargs], # type:ignore + ): + lr, lr_module = _get_lr_and_lr_module(lr, kwargs) + modules = _make_common_modules(Normalize(lr), lr_module, kwargs) + super().__init__(params, modules) + + +class Adagrad(Modular): + def __init__( + self, + params, + lr: float = 1, lr_decay: float = 0, initial_accumulator_value: float = 0, eps: float = 1e-10, + **kwargs: Unpack[_CommonKwargs], + ): + lr, lr_module = _get_lr_and_lr_module(lr, kwargs) + main = _Adagrad(lr = lr, lr_decay = lr_decay, initial_accumulator_value = initial_accumulator_value, eps = eps) + modules = _make_common_modules(main, lr_module, kwargs) + super().__init__(params, modules) + +class Rprop(Modular): + def __init__( + self, + params, + lr: float = 1, + nplus: float = 1.2, + nminus: float = 0.5, + lb: float | None = 1e-6, + ub: float | None = 50, + backtrack=True, + **kwargs: Unpack[_CommonKwargs], + ): + lr, lr_module = _get_lr_and_lr_module(lr, kwargs) + main = _Rprop(lr = lr, nplus = nplus, nminus = nminus, lb=lb, ub = ub, backtrack=backtrack) + modules = _make_common_modules(main, lr_module, kwargs) + super().__init__(params, 
modules) + +class RMSProp(Modular): + def __init__( + self, + params, + lr: float = 1, alpha: float = 0.99, eps: float = 1e-8, centered: bool = False, + **kwargs: Unpack[_CommonKwargs], # type:ignore + ): + main = _RMSProp(alpha = alpha, eps = eps, centered = centered,) + modules = _make_common_modules(main, lr, kwargs) + super().__init__(params, modules) + +class Adam(Modular): + def __init__( + self, + params, + lr: float = 1, + beta1: float = 0.9, + beta2: float = 0.999, + eps: float = 1e-8, + amsgrad=False, + **kwargs: Unpack[_CommonKwargs], + ): + lr, lr_module = _get_lr_and_lr_module(lr, kwargs) + main = _Adam(lr = lr, beta1 = beta1, beta2 = beta2, eps = eps, amsgrad = amsgrad) + modules = _make_common_modules(main, lr_module, kwargs) + super().__init__(params, modules) + +class AdamW(Modular): + def __init__( + self, + params, + lr: float = 1, + beta1: float = 0.9, + beta2: float = 0.999, + eps: float = 1e-8, + weight_decay: float = 0.01, + amsgrad=False, + **kwargs: Unpack[_CommonKwargs], + ): + """AdamW with truly decoupled weight decay. + + Args: + params (_type_): _description_ + lr (float, optional): _description_. Defaults to 1. + beta1 (float, optional): _description_. Defaults to 0.9. + beta2 (float, optional): _description_. Defaults to 0.999. + eps (float, optional): _description_. Defaults to 1e-8. + weight_decay (float, optional): _description_. Defaults to 0.01. + amsgrad (bool, optional): _description_. Defaults to False. + """ + kwargs['decoupled_l2'] = weight_decay + lr, lr_module = _get_lr_and_lr_module(lr, kwargs) + main = _Adam(lr = lr, beta1 = beta1, beta2 = beta2, eps = eps, amsgrad = amsgrad) + modules = _make_common_modules(main, lr_module, kwargs) + super().__init__(params, modules) \ No newline at end of file diff --git a/src/torchzero/optim/first_order/rprop.py b/src/torchzero/optim/first_order/rprop.py deleted file mode 100644 index e69de29..0000000 diff --git a/src/torchzero/optim/first_order/signsgd.py b/src/torchzero/optim/first_order/signsgd.py deleted file mode 100644 index f3b830d..0000000 --- a/src/torchzero/optim/first_order/signsgd.py +++ /dev/null @@ -1,27 +0,0 @@ -import typing as T -from collections import abc - -import torch - -from ...core import OptimizerModule -from ...modules import SGD, Sign, Adam -from ..modular import Modular - - -class SignSGD(Modular): - def __init__( - self, - params, - lr: float = 1e-3, - momentum: float = 0, - dampening: float = 0, - weight_decay: float = 0, - nesterov: bool = True, - ): - modules: list[OptimizerModule] = [ - Sign(), - SGD(lr, momentum = momentum, dampening = dampening, weight_decay = weight_decay, nesterov = nesterov), - ] - - super().__init__(params, modules) - diff --git a/src/torchzero/optim/quasi_newton/__init__.py b/src/torchzero/optim/quasi_newton/__init__.py index 63c5252..c0e4199 100644 --- a/src/torchzero/optim/quasi_newton/__init__.py +++ b/src/torchzero/optim/quasi_newton/__init__.py @@ -1,2 +1 @@ -# from .hv_inv_fdm import HvInvFDM -# from .newton_ray_search import NewtonFDMRaySearch, LBFGSRaySearch \ No newline at end of file +from .directional_newton import DirectionalNewton \ No newline at end of file diff --git a/src/torchzero/optim/quasi_newton/directional_newton.py b/src/torchzero/optim/quasi_newton/directional_newton.py new file mode 100644 index 0000000..5afb2a7 --- /dev/null +++ b/src/torchzero/optim/quasi_newton/directional_newton.py @@ -0,0 +1,53 @@ +from typing import Literal, Unpack + +from ...core import OptimizerModule +from ...modules import DirectionalNewton as 
_DirectionalNewton +from ...modules import (_CommonKwargs, _get_lr_and_lr_module, + _make_common_modules) +from ..modular import Modular + + +class DirectionalNewton(Modular): + """Minimizes a parabola in the direction of the gradient via one additional forward pass, + and uses another forward pass to make sure it didn't overstep. + So in total this performs three forward passes and one backward. + + This can only be used either as the first module or after FDM, as it requires ascent to + be the gradient. + + First forward and backward pass is used to calculate the value and gradient at initial parameters. + Then a gradient descent step is performed with `lr` learning rate, and loss is recalculated + with new parameters. A quadratic is fitted to two points and gradient, + if it has positive curvature, this makes a step towards the minimum, and checks if lr decreased + with an additional forward pass. + + Args: + params: iterable of parameters to optimize or dicts defining parameter groups. + lr (float, optional): + learning rate. Since you shouldn't put this module after LR(), you have to specify + the learning rate in this argument. Defaults to 1e-2. + max_dist (float | None, optional): + maximum distance to step when minimizing quadratic. + If minimum is further than this distance, minimization is not performed. Defaults to 1e4. + validate_step (bool, optional): + uses an additional forward pass to check + if step towards the minimum actually decreased the loss. Defaults to True. + log_lrs (bool, optional): + saves lrs and losses with them into optimizer._lrs (for debugging). + Defaults to False. + + Note: + While lr scheduling is supported, this uses lr of the first parameter for all parameters. + """ + def __init__( + self, + params, + lr: float = 1e-4, + max_dist: float | None = 1e5, + validate_step: bool = True, + + ): + + modules = _DirectionalNewton(lr, max_dist, validate_step) + super().__init__(params, modules) + diff --git a/src/torchzero/optim/quasi_newton/hv_inv_fdm.py b/src/torchzero/optim/quasi_newton/hv_inv_fdm.py deleted file mode 100644 index de359bd..0000000 --- a/src/torchzero/optim/quasi_newton/hv_inv_fdm.py +++ /dev/null @@ -1,32 +0,0 @@ -import typing as T -from collections import abc - -import torch - -from ...core import OptimizerModule -from ...modules.quasi_newton.hv_inv_fdm import HvInvFDM as _HvInvFDM -from ...modules import get_line_search, LineSearches, LR -from ..modular import Modular - - -class HvInvFDM(Modular): - """Experimental (maybe don't use yet).""" - def __init__( - self, - params, - lr: float = 1, - eps: float = 1e-2, - line_search: LineSearches | None = None, - ): - modules: list[OptimizerModule] = [ - _HvInvFDM(eps = eps), - ] - - if lr != 1: - modules.append(LR(lr)) - - if line_search is not None: - modules.append(get_line_search(line_search)) - - super().__init__(params, modules) - diff --git a/src/torchzero/optim/wrappers/scipy.py b/src/torchzero/optim/wrappers/scipy.py index 00c06e6..8a5c2d2 100644 --- a/src/torchzero/optim/wrappers/scipy.py +++ b/src/torchzero/optim/wrappers/scipy.py @@ -9,9 +9,8 @@ from ...core import ClosureType, TensorListOptimizer from ...grad.derivatives import jacobian, jacobian_list_to_vec, hessian, hessian_list_to_mat, jacobian_and_hessian -from ...modules import (OptimizerWrapper, Proj2Masks, ProjGrad, ProjNormalize, - Subspace) -from ...modules.subspace.random_subspace import Projection +from ...modules import OptimizerWrapper +from ...modules.experimental.subspace import Projection, Proj2Masks, ProjGrad, 
ProjNormalize, Subspace from ...modules.second_order.newton import regularize_hessian_ from ...tensorlist import TensorList from ..modular import Modular @@ -61,7 +60,8 @@ def __init__( self, params, method: str | None = None, - bounds = None, + lb = None, + ub = None, constraints = (), tol: float | None = None, callback = None, @@ -70,9 +70,9 @@ def __init__( hess: Literal['2-point', '3-point', 'cs', 'autograd'] | scipy.optimize.HessianUpdateStrategy = 'autograd', tikhonov: float | Literal['eig'] = 0, ): - super().__init__(params, {}) + defaults = dict(lb=lb, ub=ub) + super().__init__(params, defaults) self.method = method - self.bounds = bounds self.constraints = constraints self.tol = tol self.callback = callback @@ -126,6 +126,12 @@ def step(self, closure: ClosureType): # type:ignore # pylint:disable = signature x0 = params.to_vec().detach().cpu().numpy() + # make bounds + lb, ub = self.get_group_keys(['lb', 'ub'], cls=list) + bounds = [] + for p, l, u in zip(params, lb, ub): + bounds.extend([(l, u)] * p.numel()) + if self.method is not None and (self.method.lower() == 'tnc' or self.method.lower() == 'slsqp'): x0 = x0.astype(np.float64) # those methods error without this @@ -133,7 +139,7 @@ def step(self, closure: ClosureType): # type:ignore # pylint:disable = signature partial(self._objective, params = params, closure = closure), x0 = x0, method=self.method, - bounds=self.bounds, + bounds=bounds, constraints=self.constraints, tol=self.tol, callback=self.callback, @@ -211,7 +217,7 @@ def step(self, closure: ClosureType): # type:ignore # pylint:disable = signature class ScipyRootOptimization(TensorListOptimizer): - """Optimization via finding roots of the gradient with `scipy.optimize.root`. + """Optimization via finding roots of the gradient with `scipy.optimize.root` (experimental). Args: params: iterable of parameters to optimize or dicts defining parameter groups. diff --git a/src/torchzero/optim/zeroth_order/fdm.py b/src/torchzero/optim/zeroth_order/fdm.py index 2dd5232..200f3f3 100644 --- a/src/torchzero/optim/zeroth_order/fdm.py +++ b/src/torchzero/optim/zeroth_order/fdm.py @@ -1,9 +1,8 @@ -import typing as T -from collections import abc +from typing import Literal, Unpack import torch -from ...modules import FDM as _FDM, SGD, OptimizerWrapper +from ...modules import FDM as _FDM, OptimizerWrapper, _make_common_modules, _CommonKwargs from ...modules.gradient_approximation._fd_formulas import _FD_Formulas from ..modular import Modular @@ -21,10 +20,6 @@ class FDM(Modular): params: iterable of parameters to optimize or dicts defining parameter groups. lr (float, optional): learning rate. Defaults to 1e-3. eps (float, optional): finite difference epsilon. Defaults to 1e-3. - momentum (float, optional): momentum factor. Defaults to 0. - weight_decay (float, optional): weight decay (L2 penalty). Defaults to 0. - dampening (float, optional): dampening for momentum. Defaults to 0. - nesterov (bool, optional): enables Nesterov momentum (supports dampening). Defaults to False. formula (_FD_Formulas, optional): finite difference formula. Defaults to "forward". n_points (T.Literal[2, 3], optional): number of points for finite difference formula, 2 or 3. Defaults to 2. 
""" @@ -33,17 +28,12 @@ def __init__( params, lr: float = 1e-3, eps: float = 1e-3, - momentum:float = 0, - weight_decay:float = 0, - dampening: float = 0, - nesterov:bool = False, formula: _FD_Formulas = "forward", - n_points: T.Literal[2, 3] = 2, + n_points: Literal[2, 3] = 2, + **kwargs: Unpack[_CommonKwargs], ): - modules = [ - _FDM(eps = eps, formula=formula, n_points=n_points), - SGD(lr = lr, momentum = momentum, weight_decay = weight_decay, dampening = dampening, nesterov = nesterov) - ] + main = _FDM(eps = eps, formula=formula, n_points=n_points) + modules = _make_common_modules(main, lr, kwargs) super().__init__(params, modules) @@ -72,7 +62,7 @@ def __init__( optimizer: torch.optim.Optimizer, eps: float = 1e-3, formula: _FD_Formulas = "forward", - n_points: T.Literal[2, 3] = 2, + n_points: Literal[2, 3] = 2, ): modules = [ _FDM(eps = eps, formula=formula, n_points=n_points, make_closure=True), diff --git a/src/torchzero/optim/zeroth_order/newton_fdm.py b/src/torchzero/optim/zeroth_order/newton_fdm.py index e18a23a..ad165ae 100644 --- a/src/torchzero/optim/zeroth_order/newton_fdm.py +++ b/src/torchzero/optim/zeroth_order/newton_fdm.py @@ -6,8 +6,8 @@ from ...core import OptimizerModule from ...modules import (LR, BacktrackingLS, FallbackLinearSystemSolvers, LinearSystemSolvers, LineSearches, ClipNorm) -from ...modules import NewtonFDM as _NewtonFDM -from ...modules import Proj2Masks, ProjRandom, Subspace, get_line_search +from ...modules import NewtonFDM as _NewtonFDM, get_line_search +from ...modules.experimental.subspace import Proj2Masks, ProjRandom, Subspace from ..modular import Modular diff --git a/src/torchzero/optim/zeroth_order/rfdm.py b/src/torchzero/optim/zeroth_order/rfdm.py index 0a30ed9..02e6726 100644 --- a/src/torchzero/optim/zeroth_order/rfdm.py +++ b/src/torchzero/optim/zeroth_order/rfdm.py @@ -1,12 +1,12 @@ -import typing as T -from collections import abc +from typing import Literal, Unpack import torch -from ...tensorlist import Distributions from ...modules import SGD, OptimizerWrapper from ...modules import RandomizedFDM as _RandomizedFDM +from ...modules import _CommonKwargs, _make_common_modules from ...modules.gradient_approximation._fd_formulas import _FD_Formulas +from ...tensorlist import Distributions from ..modular import Modular @@ -20,10 +20,6 @@ class RandomizedFDM(Modular): params: iterable of parameters to optimize or dicts defining parameter groups. lr (float, optional): learning rate. Defaults to 1e-3. eps (float, optional): finite difference epsilon. Defaults to 1e-3. - momentum (float, optional): momentum factor. Defaults to 0. - weight_decay (float, optional): weight decay (L2 penalty). Defaults to 0. - dampening (float, optional): dampening for momentum. Defaults to 0. - nesterov (bool, optional): enables Nesterov momentum (supports dampening). Defaults to False. formula (_FD_Formulas, optional): finite difference formula. Defaults to "forward". n_samples (int, optional): number of random gradient approximations that will be averaged. Defaults to 1. distribution (Distributions, optional): distribution for random perturbations. Defaults to "normal". 
@@ -34,25 +30,20 @@ def __init__( params, lr: float = 1e-3, eps: float = 1e-3, - momentum:float = 0, - weight_decay:float = 0, - dampening: float = 0, - nesterov:bool = False, formula: _FD_Formulas = "forward", n_samples: int = 1, distribution: Distributions = "normal", randomize_every: int = 1, + **kwargs: Unpack[_CommonKwargs], ): - modules = [ - _RandomizedFDM( - eps=eps, - formula=formula, - n_samples=n_samples, - distribution=distribution, - randomize_every=randomize_every, - ), - SGD(lr = lr, momentum = momentum, weight_decay = weight_decay, dampening = dampening, nesterov = nesterov) - ] + main = _RandomizedFDM( + eps=eps, + formula=formula, + n_samples=n_samples, + distribution=distribution, + randomize_every=randomize_every, + ) + modules = _make_common_modules(main, lr, kwargs) super().__init__(params, modules) @@ -85,27 +76,21 @@ def __init__( params, lr: float = 1e-3, eps: float = 1e-3, - momentum:float = 0, - weight_decay:float = 0, - dampening: float = 0, - nesterov:bool = False, formula: _FD_Formulas = "central", n_samples: int = 1, distribution: Distributions = 'rademacher', randomize_every: int = 1, + **kwargs: Unpack[_CommonKwargs], ): super().__init__( params = params, lr = lr, eps = eps, - momentum = momentum, - weight_decay = weight_decay, - dampening = dampening, - nesterov = nesterov, formula = formula, n_samples = n_samples, distribution = distribution, randomize_every=randomize_every, + **kwargs, ) @@ -139,27 +124,21 @@ def __init__( params, lr: float = 1e-2, eps: float = 1e-2, - momentum:float = 0, - weight_decay:float = 0, - dampening: float = 0, - nesterov:bool = False, formula: _FD_Formulas = "forward", n_samples: int = 10, distribution: Distributions = 'normal', randomize_every: int = 1, + **kwargs: Unpack[_CommonKwargs], ): super().__init__( params = params, lr = lr, eps = eps, - momentum = momentum, - weight_decay = weight_decay, - dampening = dampening, - nesterov = nesterov, formula = formula, n_samples = n_samples, distribution = distribution, randomize_every=randomize_every, + **kwargs, ) class RandomizedFDMWrapper(Modular): diff --git a/src/torchzero/tensorlist.py b/src/torchzero/tensorlist.py index a8fcc62..6ed08d9 100644 --- a/src/torchzero/tensorlist.py +++ b/src/torchzero/tensorlist.py @@ -102,6 +102,26 @@ def real(self): return self.__class__(i.real for i in self) @property def imag(self): return self.__class__(i.imag for i in self) + def view_as_real(self): return self.__class__(torch.view_as_real(i) for i in self) + def view_as_complex(self): return self.__class__(torch.view_as_complex(i) for i in self) + + def to_real_views(self): + """Turns all complex tensors into real views, ignoring non-complex tensors, and sets an attribute `_tl_is_complex` to True or False, + which `from_real_views` method can use to convert real views back into complex tensors""" + tl = TensorList() + for p in self: + if torch.is_complex(p): + p._tl_is_complex = True # type:ignore + tl.append(torch.view_as_real(p)) + else: + p._tl_is_complex = False # type:ignore + tl.append(p) + return tl + + def from_real_views(self): + """undoes `to_real_views`.""" + return self.__class__(torch.view_as_complex(p) if p._tl_is_complex else p for p in self) # type:ignore + def get_existing_grads(self): """Returns all gradients that are not None.""" return self.__class__(i.grad for i in self if i is not None) diff --git a/tests/modules.py b/tests/modules.py index a6459f7..6233772 100644 --- a/tests/modules.py +++ b/tests/modules.py @@ -6,4 +6,3 @@ def test_cautious(): p = torch.tensor([1., 
1.], requires_grad = True) p.grad = torch.tensor([0.2, 0.2]) - \ No newline at end of file
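To round out the picture, a hedged sketch of the new arithmetic meta-modules from misc/multimodule.py driven through the Modular optimizer. The tz.optim.Modular path and the closure convention are assumed from the rest of this diff, and the particular combination (an Adam direction minus the raw gradient, then scaled by a learning rate) is only an example of composing Subtract, Grad and LR:

import torch
import torchzero as tz
from torchzero.modules import LR, Adam, Grad, Subtract

model = torch.nn.Linear(4, 2)
# both branches are evaluated via return_ascent, then subtracted and scaled by LR
opt = tz.optim.Modular(model.parameters(), [Subtract(Adam(), Grad()), LR(1e-3)])

def closure(backward=True):
    loss = model(torch.randn(16, 4)).pow(2).mean()
    if backward:
        opt.zero_grad()
        loss.backward()
    return loss

for _ in range(10):
    opt.step(closure)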