From e694859a5d565578851cd431b9dea75db8566bd6 Mon Sep 17 00:00:00 2001
From: Pablo Monteagudo Lago
Date: Mon, 13 Jan 2025 16:59:34 +0000
Subject: [PATCH] Add Cailey SGD optimizer with test

---
 src/brevitas/optim/cailey_sgd.py        | 198 ++++++++++++++++++++++++
 tests/brevitas/optim/test_cailey_sgd.py | 127 +++++++++++++++
 2 files changed, 325 insertions(+)
 create mode 100644 src/brevitas/optim/cailey_sgd.py
 create mode 100644 tests/brevitas/optim/test_cailey_sgd.py

diff --git a/src/brevitas/optim/cailey_sgd.py b/src/brevitas/optim/cailey_sgd.py
new file mode 100644
index 000000000..2e2426fee
--- /dev/null
+++ b/src/brevitas/optim/cailey_sgd.py
@@ -0,0 +1,198 @@
+# Copyright (C) 2024, Advanced Micro Devices, Inc. All rights reserved.
+# SPDX-License-Identifier: BSD-3-Clause
+
+# This file was adapted from a file in a repository licensed under the
+# Creative Commons Attribution-NonCommercial 4.0 International (CC BY-NC 4.0) license.
+# Original author: (c) Meta Platforms, Inc. and affiliates.
+# Source: https://github.com/facebookresearch/SpinQuant/blob/main/train_utils/optimizer.py
+
+# That code is in turn originally from:
+# https://github.com/JunLi-Galios/Optimization-on-Stiefel-Manifold-via-Cayley-Transform/blob/master/stiefel_optimizer.py
+
+import random
+
+import torch
+from torch.optim.optimizer import Optimizer
+
+
+def unit(v, dim: int = 1, eps: float = 1e-8):
+    # Normalize v along dim; returns the normalized tensor and the norms.
+    vnorm = norm(v, dim)
+    return v / vnorm.add(eps), vnorm
+
+
+def norm(v, dim: int = 1):
+    assert len(v.size()) == 2
+    return v.norm(p=2, dim=dim, keepdim=True)
+
+
+def matrix_norm_one(W):
+    # Induced 1-norm: the maximum absolute column sum.
+    out = torch.abs(W)
+    out = torch.sum(out, dim=0)
+    out = torch.max(out)
+    return out
+
+
+def Cayley_loop(X, W, tan_vec, t):
+    # Fixed-point iteration Y <- X + t * W @ (X + Y) / 2 approximating the
+    # Cayley retraction Y(t) = (I - t/2 W)^{-1} (I + t/2 W) X; the tangent
+    # vector tan_vec provides the initial guess.
+    [n, p] = X.size()
+    Y = X + t * tan_vec
+    for i in range(5):
+        Y = X + t * torch.matmul(W, 0.5 * (X + Y))
+
+    return Y.t()
+
+
+def qr_retraction(tan_vec):  # tan_vec, p-by-n, p <= n
+    [p, n] = tan_vec.size()
+    tan_vec.t_()
+    q, r = torch.linalg.qr(tan_vec)
+    # Fix the signs of the columns so the QR decomposition is unique.
+    d = torch.diag(r, 0)
+    ph = d.sign()
+    q *= ph.expand_as(q)
+    q.t_()
+
+    return q
+
+
+epsilon = 1e-8
+
+
+class CaileySGD(Optimizer):
+    r"""This optimizer updates variables with one of two routines, selected
+    by the boolean flag 'stiefel'.
+
+    If stiefel is True, the variables are updated with SGD-G, which keeps the
+    weight matrix row-orthonormal by optimizing on the Stiefel manifold via
+    the Cayley transform.
+
+    If stiefel is False, the variables are updated with plain SGD. This
+    routine was taken from
+    https://github.com/pytorch/pytorch/blob/master/torch/optim/sgd.py.
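+
+    For reference, the SGD-G update implemented below follows Li et al.,
+    "Efficient Riemannian Optimization on the Stiefel Manifold via the Cayley
+    Transform" (ICLR 2020): writing X for the transposed weight (so X has
+    orthonormal columns) and M for the momentum buffer, it forms
+    W_hat = M X^T - 1/2 X (X^T M X^T), takes the skew-symmetric part
+    W = W_hat - W_hat^T, and moves along the curve
+    Y(t) = (I - t/2 W)^{-1} (I + t/2 W) X, approximated by Cayley_loop.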
+
+    Args:
+        params (iterable): iterable of parameters to optimize or dicts
+            defining parameter groups
+
+        -- common parameters
+        lr (float): learning rate
+        momentum (float, optional): momentum factor (default: 0)
+        stiefel (bool, optional): whether to use SGD-G (default: False)
+
+        -- parameters used when stiefel is False
+        weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
+        dampening (float, optional): dampening for momentum (default: 0)
+        nesterov (bool, optional): enables Nesterov momentum (default: False)
+
+        -- parameters used when stiefel is True
+        omega (float, optional): orthogonality regularization factor (default: 0)
+        grad_clip (float, optional): threshold for gradient norm clipping (default: None)
+    """
+
+    def __init__(
+        self,
+        params,
+        lr: float = 1e-3,
+        momentum: float = 0,
+        dampening: float = 0,
+        weight_decay: float = 0,
+        nesterov: bool = False,
+        stiefel: bool = False,
+        omega: float = 0,
+        grad_clip=None,
+    ) -> None:
+        defaults = dict(
+            lr=lr,
+            momentum=momentum,
+            dampening=dampening,
+            weight_decay=weight_decay,
+            nesterov=nesterov,
+            stiefel=stiefel,
+            omega=omega,
+            grad_clip=grad_clip,
+        )
+        if nesterov and (momentum <= 0 or dampening != 0):
+            raise ValueError("Nesterov momentum requires a momentum and zero dampening")
+        super(CaileySGD, self).__init__(params, defaults)
+
+    def __setstate__(self, state) -> None:
+        super(CaileySGD, self).__setstate__(state)
+        for group in self.param_groups:
+            group.setdefault("nesterov", False)
+
+    def step(self, closure=None):
+        """Performs a single optimization step.
+
+        Arguments:
+            closure (callable, optional): A closure that reevaluates the model
+                and returns the loss.
+        """
+        loss = None
+        if closure is not None:
+            loss = closure()
+
+        for group in self.param_groups:
+            momentum = group["momentum"]
+            stiefel = group["stiefel"]
+
+            for p in group["params"]:
+                if p.grad is None:
+                    continue
+
+                unity, _ = unit(p.data.view(p.size()[0], -1))
+                if stiefel and unity.size()[0] <= unity.size()[1]:
+                    weight_decay = group["weight_decay"]
+                    dampening = group["dampening"]
+                    nesterov = group["nesterov"]
+
+                    # Occasionally (with probability 1/101) re-project the iterate
+                    # onto the manifold via QR to correct accumulated numerical drift.
+                    rand_num = random.randint(1, 101)
+                    if rand_num == 1:
+                        unity = qr_retraction(unity)
+
+                    g = p.grad.data.view(p.size()[0], -1)
+
+                    lr = group["lr"]
+
+                    param_state = self.state[p]
+                    if "momentum_buffer" not in param_state:
+                        param_state["momentum_buffer"] = torch.zeros_like(g.t())
+
+                    # Update the buffer in place so the stored state actually advances.
+                    V = param_state["momentum_buffer"]
+                    V.mul_(momentum).sub_(g.t())
+                    MX = torch.mm(V, unity)
+                    XMX = torch.mm(unity, MX)
+                    XXMX = torch.mm(unity.t(), XMX)
+                    W_hat = MX - 0.5 * XXMX
+                    W = W_hat - W_hat.t()
+                    # Safeguarded step size keeping the fixed-point iteration convergent.
+                    t = 1.0 / (matrix_norm_one(W) + epsilon)
+                    alpha = min(t, lr)
+
+                    p_new = Cayley_loop(unity.t(), W, V, alpha)
+                    V_new = torch.mm(W, unity.t())  # n-by-p
+                    p.data.copy_(p_new.view(p.size()))
+                    V.copy_(V_new)
+
+                else:
+                    d_p = p.grad.data
+                    weight_decay = group["weight_decay"]
+                    dampening = group["dampening"]
+                    nesterov = group["nesterov"]
+                    if weight_decay != 0:
+                        d_p.add_(p.data, alpha=weight_decay)
+                    if momentum != 0:
+                        param_state = self.state[p]
+                        if "momentum_buffer" not in param_state:
+                            buf = param_state["momentum_buffer"] = d_p.clone()
+                        else:
+                            buf = param_state["momentum_buffer"]
+                            buf.mul_(momentum).add_(d_p, alpha=1 - dampening)
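+                        # Nesterov momentum takes a "look-ahead" step: the update
+                        # direction becomes the gradient plus momentum * buf,
+                        # rather than the buffer alone.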
+                        if nesterov:
+                            d_p = d_p.add(buf, alpha=momentum)
+                        else:
+                            d_p = buf
+
+                    p.data.add_(d_p, alpha=-group["lr"])
+
+        return loss
diff --git a/tests/brevitas/optim/test_cailey_sgd.py b/tests/brevitas/optim/test_cailey_sgd.py
new file mode 100644
index 000000000..92de8ae5a
--- /dev/null
+++ b/tests/brevitas/optim/test_cailey_sgd.py
@@ -0,0 +1,127 @@
+"""
+Copyright (C) 2024, Advanced Micro Devices, Inc.
+Copyright (c) 2016- Facebook, Inc (Adam Paszke)
+Copyright (c) 2014- Facebook, Inc (Soumith Chintala)
+Copyright (c) 2011-2014 Idiap Research Institute (Ronan Collobert)
+Copyright (c) 2012-2014 Deepmind Technologies (Koray Kavukcuoglu)
+Copyright (c) 2011-2012 NEC Laboratories America (Koray Kavukcuoglu)
+Copyright (c) 2011-2013 NYU (Clement Farabet)
+Copyright (c) 2006-2010 NEC Laboratories America (Ronan Collobert, Leon Bottou, Iain Melvin, Jason Weston)
+Copyright (c) 2006 Idiap Research Institute (Samy Bengio)
+Copyright (c) 2001-2004 Idiap Research Institute (Ronan Collobert, Samy Bengio, Johnny Mariethoz)
+
+All rights reserved.
+
+Redistribution and use in source and binary forms, with or without
+modification, are permitted provided that the following conditions are met:
+
+1. Redistributions of source code must retain the above copyright
+   notice, this list of conditions and the following disclaimer.
+
+2. Redistributions in binary form must reproduce the above copyright
+   notice, this list of conditions and the following disclaimer in the
+   documentation and/or other materials provided with the distribution.
+
+3. Neither the names of AMD, Facebook, Deepmind Technologies, NYU,
+   NEC Laboratories America and IDIAP Research Institute nor the names
+   of its contributors may be used to endorse or promote products derived
+   from this software without specific prior written permission.
+
+THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
+ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE
+LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
+CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
+SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
+INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
+CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
+ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
+POSSIBILITY OF SUCH DAMAGE.
+"""
+
+from copy import deepcopy
+from itertools import product
+
+import numpy as np
+import pytest_cases
+from scipy.stats import ortho_group
+import torch
+from torch.nn import Parameter
+from torch.optim.lr_scheduler import LinearLR
+
+from brevitas.optim.cailey_sgd import CaileySGD
+from tests.conftest import SEED
+
+torch.manual_seed(SEED)
+
+OPTIMIZER_KWARGS = [
+    {"stiefel": True},
+    {"stiefel": True, "lr": 1e-2},
+    {"stiefel": True, "lr": torch.tensor(0.001)},
+]
+LR_SCHEDULER_ARGS = [
+    None,
+    (LinearLR, {"start_factor": 1.0, "end_factor": 0.0, "total_iters": 20}),
+]
+DEVICES = ["cpu", "cuda"] if torch.cuda.is_available() else ["cpu"]
+DTYPES = [torch.float32]
+
+device_dtype_parametrize = pytest_cases.parametrize("device, dtype", list(product(DEVICES, DTYPES)))
+
+
+class TestCaileySGD:
+
+    @device_dtype_parametrize
+    @pytest_cases.parametrize("optimizer_kwargs", OPTIMIZER_KWARGS)
+    @pytest_cases.parametrize("lr_scheduler_args", LR_SCHEDULER_ARGS)
+    def test_forloop_goes_right_direction(self, device, dtype, optimizer_kwargs, lr_scheduler_args):
+        optim_cls = CaileySGD
+        # Generate a random orthogonal matrix of size NxN; its columns are
+        # orthonormal vectors in R^N.
+        N = 5
+        P = 3
+        weight_orthogonal = ortho_group(dim=N, seed=SEED).rvs()
+        # Normalizing the columns is a safeguard; an orthogonal matrix already
+        # has unit-norm columns.
+        weight_orthonormal = weight_orthogonal / np.linalg.norm(weight_orthogonal, ord=2, axis=0)
+        # Verify that the matrix is orthonormal
+        assert np.allclose(np.matmul(weight_orthonormal.T, weight_orthonormal), np.eye(N))
+        # Initialize the weight. The CaileySGD optimizer expects a matrix of size PxN,
+        # given the condition unity.size()[0] <= unity.size()[1]
+        weight = Parameter(
+            torch.from_numpy(weight_orthonormal[:, :P].T).to(device=device, dtype=dtype))
+
+        optimizer = optim_cls([weight], **deepcopy(optimizer_kwargs))
+        scheduler = None if lr_scheduler_args is None else lr_scheduler_args[0](
+            optimizer, **lr_scheduler_args[1])
+
+        def closure():
+            optimizer.zero_grad()
+            # Sum of squared errors between the weights and a fixed set of
+            # orthonormal vectors
+            loss = (weight - torch.eye(N, P, device=device, dtype=dtype).t()).pow(2).sum()
+            loss.backward()
+            return loss
+
+        initial_value = closure().item()
+        for _ in range(20):
+            closure()
+            optimizer.step()
+            if scheduler is not None:
+                scheduler.step()
+
+        # Verify that the iterates stay on the Stiefel manifold, i.e. the rows
+        # of the weight matrix remain orthonormal.
+        assert torch.allclose(
+            weight.detach().cpu() @ weight.detach().cpu().t(),
+            torch.eye(P, dtype=dtype),
+            atol=1e-5,
+            rtol=1e-6)
+
+        # The loss is minimized, so the final value must improve on the initial one.
+        assert closure().item() < initial_value
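
Usage sketch (illustrative, not part of the patch): one way to drive CaileySGD once
this lands. The 64x128 shape, the toy objective, and the loop length are assumptions
for demonstration; the Stiefel branch requires rows <= columns.

    import torch
    from torch.nn import Parameter

    from brevitas.optim.cailey_sgd import CaileySGD

    # Hypothetical 64x128 weight whose 64 rows start (and should stay) orthonormal.
    weight = Parameter(torch.eye(64, 128))
    optimizer = CaileySGD([weight], lr=1e-3, stiefel=True)

    for _ in range(10):
        optimizer.zero_grad()
        loss = (weight * torch.randn_like(weight)).sum()  # placeholder objective
        loss.backward()
        optimizer.step()

    # The rows remain (approximately) orthonormal after the updates.
    assert torch.allclose(weight @ weight.t(), torch.eye(64), atol=1e-5)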