From 5a4eef9f0ec7cb757af16476df5ff41c5e61f2d6 Mon Sep 17 00:00:00 2001
From: Joerg Franke
Date: Fri, 24 Nov 2023 17:45:57 +0100
Subject: [PATCH] initial commit

---
 .gitignore                      |   3 +-
 LICENSE                         |   2 +-
 README.md                       | 122 ++++++++-
 examples/requirements.txt       |   7 +
 examples/train_grokking_task.py | 372 +++++++++++++++++++++++++++
 examples/train_resnet.py        | 436 ++++++++++++++++++++++++++++++++
 pyproject.toml                  |  26 ++
 pytorch_cpr/__init__.py         |   3 +
 pytorch_cpr/group_parameter.py  |  82 ++++++
 pytorch_cpr/optim_cpr.py        | 165 ++++++++++++
 pytorch_cpr/wrapper.py          |  29 +++
 setup.py                        |  12 +
 tests/__init__.py               |   0
 tests/model.py                  |  28 ++
 tests/test_group_parameter.py   |  58 +++++
 tests/test_optim_cpr.py         |  39 +++
 tests/test_wrapper.py           |  40 +++
 17 files changed, 1420 insertions(+), 4 deletions(-)
 create mode 100644 examples/requirements.txt
 create mode 100644 examples/train_grokking_task.py
 create mode 100644 examples/train_resnet.py
 create mode 100644 pyproject.toml
 create mode 100644 pytorch_cpr/__init__.py
 create mode 100644 pytorch_cpr/group_parameter.py
 create mode 100644 pytorch_cpr/optim_cpr.py
 create mode 100644 pytorch_cpr/wrapper.py
 create mode 100644 setup.py
 create mode 100644 tests/__init__.py
 create mode 100644 tests/model.py
 create mode 100644 tests/test_group_parameter.py
 create mode 100644 tests/test_optim_cpr.py
 create mode 100644 tests/test_wrapper.py

diff --git a/.gitignore b/.gitignore
index 68bc17f..8e95a54 100644
--- a/.gitignore
+++ b/.gitignore
@@ -36,7 +36,7 @@ MANIFEST
 pip-log.txt
 pip-delete-this-directory.txt
 
-# Unit test / coverage reports
+# Unit tests / coverage reports
 htmlcov/
 .tox/
 .nox/
@@ -158,3 +158,4 @@ cython_debug/
 #  and can be added to the global gitignore or merged into this file.  For a more nuclear
 #  option (not recommended) you can uncomment the following to ignore the entire idea folder.
 #.idea/
+venv
\ No newline at end of file
diff --git a/LICENSE b/LICENSE
index 261eeb9..542682c 100644
--- a/LICENSE
+++ b/LICENSE
@@ -186,7 +186,7 @@
       same "printed page" as the copyright notice for easier
       identification within third-party archives.
 
-   Copyright [yyyy] [name of copyright owner]
+   Copyright 2023 Jörg Franke
 
    Licensed under the Apache License, Version 2.0 (the "License");
    you may not use this file except in compliance with the License.
diff --git a/README.md b/README.md
index 7eb87f8..9580d3b 100644
--- a/README.md
+++ b/README.md
@@ -1,2 +1,120 @@
-# CPR
-Constraint Parameter Regularization
+
+# Constrained Parameter Regularization
+
+This repository contains the PyTorch implementation of **Constrained Parameter Regularization**.
+
+
+## Install
+
+```bash
+pip install pytorch-cpr
+```
+
+## Getting started
+
+### Usage of the `apply_CPR` Optimizer Wrapper
+
+The `apply_CPR` function is a wrapper that applies CPR (Constrained Parameter Regularization) to a given optimizer by first creating parameter groups and then wrapping the actual optimizer class.
+
+#### Arguments
+
+- `model`: The PyTorch model whose parameters are to be optimized.
+- `optimizer_cls`: The class of the optimizer to be used (e.g., `torch.optim.Adam`).
+- `kappa_init_param`: Initial value for the kappa parameter in CPR; its meaning depends on the initialization method.
+- `kappa_init_method` (default `'warm_start'`): The method to initialize the kappa parameter. Options include `'warm_start'`, `'uniform'`, and `'dependent'`.
+- `reg_function` (default `'l2'`): The regularization function to be applied. Options include `'l2'` or `'std'`.
+- `kappa_adapt` (default `False`): Flag to determine if kappa should adapt during training.
+- `kappa_update` (default `1.0`): The rate at which kappa is updated in the Lagrangian method.
+- `apply_lr` (default `False`): Flag to apply the learning rate to the regularization update.
+- `normalization_regularization` (default `False`): Flag to apply regularization to normalization layers.
+- `bias_regularization` (default `False`): Flag to apply regularization to bias parameters.
+- `embedding_regularization` (default `False`): Flag to apply regularization to embedding parameters.
+- `**optimizer_args`: Additional arguments to pass to the optimizer class.
+
+#### Example usage
+
+```python
+import torch
+from pytorch_cpr import apply_CPR
+
+model = YourModel()
+optimizer = apply_CPR(model, torch.optim.Adam, kappa_init_param=1000, kappa_init_method='warm_start',
+                      lr=0.001, betas=(0.9, 0.98))
+```
+
+
+## Run examples
+
+We provide scripts to replicate the experiments from our paper. Please use a system with at least one GPU. Install the package and the requirements for the examples:
+
+```bash
+python3 -m venv venv
+source venv/bin/activate
+pip install -r examples/requirements.txt
+pip install pytorch-cpr
+```
+
+
+### Modular Addition / Grokking Experiment
+
+The grokking experiment should run within a few minutes. The results will be saved in the `grokking` folder.
+To replicate the results in the paper, run variations with the following arguments:
+
+#### For AdamW:
+```bash
+python examples/train_grokking_task.py --optimizer adamw --weight_decay 0.1
+```
+
+#### For Adam + Rescaling:
+```bash
+python examples/train_grokking_task.py --optimizer adamw --weight_decay 0.0 --rescale_alpha 0.8
+```
+
+#### For AdamCPR with L2 norm as the regularization function:
+```bash
+python examples/train_grokking_task.py --optimizer adamcpr --kappa_init_method dependent --kappa_init_param 0.8
+```
+
+
+### Image Classification Experiment
+
+The CIFAR-100 experiment should run within 20-30 minutes. The results will be saved in the `cifar100` folder.
+
+#### For AdamW:
+```bash
+python examples/train_resnet.py --optimizer adamw --lr 0.001 --weight_decay 0.001
+```
+
+#### For Adam + Rescaling:
+```bash
+python examples/train_resnet.py --optimizer adamw --lr 0.001 --weight_decay 0 --rescale_alpha 0.8
+```
+
+#### For AdamCPR with L2 norm as the regularization function and kappa initialized from the parameter initialization:
+```bash
+python examples/train_resnet.py --optimizer adamcpr --lr 0.001 --kappa_init_method dependent --kappa_init_param 0.8
+```
+
+#### For AdamCPR with L2 norm as the regularization function and kappa initialized via warm start:
+```bash
+python examples/train_resnet.py --optimizer adamcpr --lr 0.001 --kappa_init_method warm_start --kappa_init_param 1000
+```
+
+
+## Citation
+
+Please cite our paper if you use this code in your own work:
+
+```
+@misc{franke2023new,
+      title={New Horizons in Parameter Regularization: A Constraint Approach},
+      author={Jörg K. H.
Franke and Michael Hefenbrock and Gregor Koehler and Frank Hutter}, + year={2023}, + eprint={2311.09058}, + archivePrefix={arXiv}, + primaryClass={cs.LG} +} +``` + diff --git a/examples/requirements.txt b/examples/requirements.txt new file mode 100644 index 0000000..2113807 --- /dev/null +++ b/examples/requirements.txt @@ -0,0 +1,7 @@ +torch==2.0.1 +torchvision==0.15.2 +numpy>=1.19.0 +tqdm>=4.50.0 +matplotlib>=3.7.2 +pytorch-lightning>=2.0.0 +tensorboard>=2.15.1 diff --git a/examples/train_grokking_task.py b/examples/train_grokking_task.py new file mode 100644 index 0000000..c30a56d --- /dev/null +++ b/examples/train_grokking_task.py @@ -0,0 +1,372 @@ +import os +from argparse import ArgumentParser +from collections import defaultdict +import torch +import torch.nn as nn +import numpy as np +from tqdm.auto import tqdm +import matplotlib.pyplot as plt + +from pytorch_cpr import apply_CPR + +### Data +def modular_addition(p, train_fraction, train_shuffle, device): + equals_token = p + x, y = torch.meshgrid(torch.arange(p), torch.arange(p), indexing='ij') + x = x.flatten() + y = y.flatten() + equals = torch.ones(x.shape, dtype=torch.int64) * equals_token + prompts = torch.stack([x, y, equals], dim=1).to(device) + answers = ((x + y) % p).to(device) + + data = torch.utils.data.TensorDataset(prompts, answers) + train, test = torch.utils.data.random_split(data, + [int(train_fraction * len(data)), + len(data) - int(train_fraction * len(data)) + ]) + + train_loader = torch.utils.data.DataLoader(train, batch_size=512, shuffle=train_shuffle) + test_loader = torch.utils.data.DataLoader(test, batch_size=len(data), shuffle=False) + return train_loader, test_loader + + +### Model +class Block(nn.Module): + def __init__(self, dim, num_heads, use_ln): + super().__init__() + self.use_ln = use_ln + if use_ln: + self.ln_1 = nn.LayerNorm(dim) + self.ln_2 = nn.LayerNorm(dim) + self.attn = nn.MultiheadAttention(dim, num_heads, bias=False) + activation = nn.ReLU() + self.mlp = nn.Sequential(nn.Linear(dim, dim * 4), activation, nn.Linear(dim * 4, dim), ) + + def forward(self, x): + attn_mask = torch.full((len(x), len(x)), -float("Inf"), device=x.device, dtype=x.dtype) + attn_mask = torch.triu(attn_mask, diagonal=1) + if self.use_ln: + x = self.ln_1(x) + a, _ = self.attn(x, x, x, attn_mask=attn_mask, need_weights=False) + x = x + a + if self.use_ln: + x = x + self.mlp(self.ln_2(x)) + else: + x = x + self.mlp(x) + return x + +class TransformerDecoder(nn.Module): + + def __init__(self, dim, num_layers, num_tokens, num_heads=4, seq_len=3, use_ln=False): + super().__init__() + self.token_embeddings = nn.Embedding(num_tokens, dim) + self.position_embeddings = nn.Embedding(seq_len, dim) + self.layers = nn.ModuleList() + for _ in range(num_layers): + self.layers.append(Block(dim, num_heads, use_ln)) + self.use_ln = use_ln + if use_ln: + self.ln_f = nn.LayerNorm(dim) + self.head = nn.Linear(dim, num_tokens, bias=False) + + def forward(self, x): + h = self.token_embeddings(x) + positions = torch.arange(x.shape[0], device=x.device).unsqueeze(-1) + h = h + self.position_embeddings(positions).expand_as(h) + for layer in self.layers: + h = layer(h) + if self.use_ln: + h = self.ln_f(h) + logits = self.head(h) + return logits + + +def init_params(model, model_dim, vocab_dim, init_type='xavier'): + for name, param in model.named_parameters(): + if param.dim() > 1: + if vocab_dim in param.shape: + nn.init.normal_(param, 0, 1 / np.sqrt(vocab_dim)) + else: + if init_type == 'xavier': + nn.init.xavier_normal_(param) + elif init_type 
== 'sqrt_dim': + nn.init.normal_(param, 0, 1 / np.sqrt(model_dim)) + else: + nn.init.constant_(param, 0) + + +def print_param_groups(param_groups): + for param_group in param_groups: + if 'apply_decay' in param_group: + print(f"### PARAM GROUP #### apply_decay: {param_group['apply_decay']}") + else: + print(f"### PARAM GROUP #### weight_decay: {param_group['weight_decay']}") + for name, param in zip(param_group['names'], param_group['params']): + print( + f"{name:60} {param.shape[0]:4} {param.shape[-1]:4} std {param.std():.3f} l2m {param.square().mean():.3f}") + + +### Main +def train_grokking(config): + torch.manual_seed(config.seed) + torch.cuda.manual_seed(config.seed) + + print("Config", config) + + if config.device is not None: + device = config.device + print("starting on device", device) + else: + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + + train_loader, test_loader = modular_addition(config.p, train_fraction=config.train_fraction, + train_shuffle=config.train_shuffle, device=device) + + model = TransformerDecoder( + dim=config.model_dim, num_layers=config.num_layers, num_heads=config.num_heads, num_tokens=config.p + 1, + seq_len=3, use_ln=config.use_ln).to(device) + + init_params(model, config.model_dim, config.p, init_type=config.init_type) + + if config.optimizer == 'adamcpr': + optimizer = apply_CPR(model, torch.optim.Adam, config.kappa_init_param, config.kappa_init_method, + config.reg_function, + config.kappa_adapt, config.kappa_update, config.apply_lr, + normalization_regularization=False, bias_regularization=False, + embedding_regularization=True, + lr=config.lr, betas=(config.beta1, config.beta2)) + param_groups = optimizer.state_dict()['param_groups'] + params = list(model.parameters()) + for param_group in param_groups: + for index, param_id in enumerate(param_group['params']): + param_group['params'][index] = params[param_id] + else: + param_dict = {pn: p for pn, p in model.named_parameters() + if p.requires_grad} + if config.exclude_reg is not None: + exclude_reg = config.exclude_reg.split(",") + param_groups = [{"params": [], "names": [], 'weight_decay': config.weight_decay}, { + "params": [], "names": [], 'weight_decay': 0}] + for k, v in param_dict.items(): + print(k) + if any([reg in k for reg in exclude_reg]): + param_groups[1]["params"].append(v) + param_groups[1]["names"].append(k) + else: + param_groups[0]["params"].append(v) + param_groups[0]["names"].append(k) + + else: + param_groups = model.parameters() + optimizer = torch.optim.AdamW(param_groups, lr=config.lr, betas=(config.beta1, config.beta2)) + + if config.print: + print_param_groups(param_groups) + + if config.rescale_alpha > 0: + with torch.no_grad(): + for n, p in model.named_parameters(): + if n.endswith("weight"): + p.data *= config.rescale_alpha + norm = np.sqrt(sum(p.pow(2).sum().item() for n, p in model.named_parameters() if n.endswith("weight"))) + + stats = defaultdict(list) + steps = 0 + + test_x, test_labels = next(iter(test_loader)) # ther is only one tests batch + test_x, test_labels = test_x.H.to(device), test_labels.to(device) + + for epoch in tqdm(range(config.epochs), disable=not config.print): + + for train_x, train_labels in train_loader: + + model.train(True) + train_x, train_labels = train_x.H.to(device), train_labels.to(device) + + train_logits = model(train_x) + train_loss = torch.nn.functional.cross_entropy(train_logits[-1], train_labels) + + model.zero_grad() + train_loss.backward() + optimizer.step() + + if config.rescale_alpha > 0: + with 
torch.no_grad(): + new_norm = np.sqrt( + sum(p.pow(2).sum().item() for n, p in model.named_parameters() if n.endswith("weight"))) + for n, p in model.named_parameters(): + if n.endswith("weight"): + p.data *= norm / new_norm + + if epoch % config.log_interval == 0: + with torch.no_grad(): + + model.train(False) + test_logits = model(test_x).detach() + test_loss = torch.nn.functional.cross_entropy(test_logits[-1], test_labels) + + train_acc = (train_logits[-1].argmax(-1) == train_labels).float().mean() + test_acc = (test_logits[-1].argmax(-1) == test_labels).float().mean() + + stats['train_loss'].append(train_loss.cpu().numpy()) + stats['val_loss'].append(test_loss.cpu().numpy()) + stats['train_acc'].append(train_acc.cpu().numpy()) + stats['val_acc'].append(test_acc.cpu().numpy()) + stats['total_norm'].append( + torch.sqrt(sum(param.pow(2).sum() for param in model.parameters())).cpu().numpy()) + stats['steps'].append(steps) + + if config.optimizer == "adamcpr": + for group, group_states in zip(optimizer.base_optimizer.param_groups, optimizer.cpr_states): + if 'apply_decay' in group and group['apply_decay'] is True: + for name, state in zip(group['names'], group_states): + lagmul = state['lagmul'] + kappa = state['kappa'] + step = state['step'] + stats[f"cpr/{name}/lambda"].append(lagmul.item()) + stats[f"cpr/{name}/kappa"].append(kappa.item()) + stats[f"cpr/{name}/step"].append(step.item()) + + totalnorm = [] + for param_group in optimizer.param_groups: + for name, param in zip(param_group['names'], param_group['params']): + stats[f"params/{name}/mean"].append(param.mean().item()) + stats[f"params/{name}/std"].append(param.std().item()) + stats[f"params/{name}/l2"].append(param.pow(2).sum().item()) + stats[f"params/{name}/l2m"].append(param.pow(2).mean().item()) + stats[f"params/{name}/l2s"].append(param.pow(2).sum().item()) + totalnorm.append(param.pow(2).sum().item()) + stats[f"params/total_norm"].append(np.sqrt(sum(totalnorm))) + + steps += 1 + + task_name = f"{config.epochs}_{str(int(config.seed))}_p{config.p}_f{config.train_fraction}" + if config.optimizer == "adamcpr": + expt_name = f"{config.optimizer}_p{config.kappa_init_param}_m{config.kappa_init_method}_kf{config.reg_function}_r{config.kappa_update}_l{config.lr}_adapt{config.kappa_adapt}_g{config.apply_lr}" + else: + expt_name = f"{config.optimizer}_w{config.weight_decay}_re{config.rescale_alpha}_l{config.lr}" + + config.output_dir = f"{config.output_dir}/grokking_{task_name}" + os.makedirs(config.output_dir, exist_ok=True) + config_dict = config.__dict__ + if config.print: + print(expt_name, config_dict) + + os.makedirs(config.output_dir + f"/{config.session}_stats", exist_ok=True) + np.save(f"{config.output_dir}/{config.session}_stats/{expt_name}.npy", + {"name": expt_name, 'stats': stats, 'config': config_dict}) + + if config.plot: + os.makedirs(config.output_dir + f"/{config.session}_figures", exist_ok=True) + + if config.plot_norms: + name_constrained_weights = param_groups[0]['names'] + plot_rows = 1 + len(name_constrained_weights) + + fig, ax = plt.subplots(plot_rows, 1, sharex=True, squeeze=True, figsize=(16, 12)) + + ax[0].plot(stats['steps'], stats['train_acc'], color='red', label="train") + ax[0].plot(stats['steps'], stats['val_acc'], color='green', label="val") + ax[0].legend() + ax[0].set_ylabel("Accuracy") + ax[0].set_xlim(8, 2 * config.epochs) + ax[0].set_xscale("log", base=10) + ax[0].set_title(expt_name) + + for idx, name in enumerate(name_constrained_weights): + axr = idx + 1 + ax[axr].plot(stats['steps'], 
stats[f"params/{name}/std"], color='orange', label=f"std {name}") + ax[axr].set_ylabel("STD") + ax2 = ax[axr].twinx() + if f"cpr/{name}/lambda" in stats.keys(): + ax2.plot(stats['steps'], stats[f"cpr/{name}/lambda"], color='purple', label=f"lambda {name}") + ax2.set_ylabel("Lambda", color='purple') + else: + ax2.plot(stats['steps'], stats[f"params/{name}/l2m"], color='purple', label=f"l2m {name}") + ax2.set_ylabel("Weight Norm", color='purple') + ax[axr].set_xlim(8, 2 * config.epochs) + ax[axr].set_xscale("log", base=10) + ax[axr].legend(loc=(0.015, 0.72)) + ax[axr].legend() + if idx < len(name_constrained_weights) - 1: + plt.setp(ax[axr].get_xticklabels(), visible=False) + ax[axr].set_xlabel("Optimization Steps") + fig.subplots_adjust(0.08, 0.1, 0.95, 0.93, 0, 0) + + else: + ax = plt.subplot(1, 1, 1) + plt.plot(stats['steps'], stats['train_acc'], color='red', label="train") + plt.plot(stats['steps'], stats['val_acc'], color='green', label="val") + plt.legend() + plt.xlabel("Optimization Steps") + plt.ylabel("Accuracy") + plt.xlim(8, 2 * config.epochs) + ax2 = ax.twinx() + if f"cpr/{name}/lambda" in stats.keys(): + ax2.plot(stats['steps'], stats[f"cpr/{name}/lambda"], color='purple', label=f"lambda {name}") + ax2.set_ylabel("Lambda", color='purple') + else: + ax2.plot(stats['steps'], stats[f"params/{name}/l2m"], color='purple', label=f"l2m {name}") + ax2.set_ylabel("Weight Norm", color='purple') + ax2.set_ylim(27, 63) + plt.xscale("log", base=10) + plt.legend(loc=(0.015, 0.72)) + plt.tight_layout() + plt.title(expt_name) + + plt.savefig(f"{config.output_dir}/{config.session}_figures/{expt_name}.png", dpi=150) + + if config.show_plot: + plt.show() + + plt.close() + + +if __name__ == "__main__": + parser = ArgumentParser() + parser.add_argument("--session", type=str, default='test_grokking') + parser.add_argument("--seed", type=int, default=1) + parser.add_argument("--epochs", type=int, default=4000) + + parser.add_argument("--p", type=int, default=113) + parser.add_argument("--train_shuffle", type=bool, default=True) + parser.add_argument("--train_fraction", type=float, default=0.3) + + parser.add_argument("--model_dim", type=int, default=128) + parser.add_argument("--num_layers", type=int, default=1) + parser.add_argument("--num_heads", type=int, default=4) + parser.add_argument("--use_ln", type=bool, default=False) + parser.add_argument("--init_type", type=str, default='sqrt_dim') + + parser.add_argument("--optimizer", type=str, default='adamcpr') + parser.add_argument("--lr", type=float, default=0.001) + parser.add_argument("--beta1", type=float, default=0.9) + parser.add_argument("--beta2", type=float, default=0.98) + parser.add_argument("--weight_decay", type=float, default=0.1) + parser.add_argument("--exclude_reg", type=str, default='bias,norm') + + parser.add_argument("--rescale_alpha", type=float, default=0) + + parser.add_argument("--kappa_init_param", type=float, default=0.8) + parser.add_argument("--kappa_init_method", type=str, default='dependent') + parser.add_argument("--reg_function", type=str, default='l2') + parser.add_argument("--kappa_update", type=float, default=1.0) + parser.add_argument("--kappa_adapt", type=bool, default=True) + parser.add_argument("--apply_lr", type=bool, default=False) + + parser.add_argument("--log_interval", type=int, default=5) + parser.add_argument("--output_dir", type=str, default='grokking') + parser.add_argument("--plot", type=bool, default=True) + parser.add_argument("--show_plot", type=bool, default=True) + parser.add_argument("--print", 
type=bool, default=True) + parser.add_argument("--plot_norms", type=bool, default=True) + parser.add_argument("--device", type=int, default=0) + + args = parser.parse_args() + + print(args.__dict__) + + if args.rescale_alpha > 0: + assert args.optimizer == 'adamw' + + train_grokking(args) diff --git a/examples/train_resnet.py b/examples/train_resnet.py new file mode 100644 index 0000000..8a61c61 --- /dev/null +++ b/examples/train_resnet.py @@ -0,0 +1,436 @@ +import pathlib, argparse +from argparse import ArgumentParser +import pytorch_lightning as pl +from pytorch_lightning.loggers import TensorBoardLogger +import torch.nn.functional as F +import math +import numpy as np +import torch +import torch.nn as nn +import torchvision +import torchvision.transforms as transforms +from pytorch_lightning import Callback +from pytorch_lightning.callbacks import LearningRateMonitor + +from pytorch_cpr import apply_CPR + +torch.set_float32_matmul_precision('high') + +class WeightDecayScheduler(Callback): + + def __init__(self, schedule_weight_decay: bool, schedule_type: str, scale: float): + super().__init__() + self.schedule_weight_decay = schedule_weight_decay + + self.schedule_type = schedule_type + + self.decay = scale + + self._step_count = 0 + + @staticmethod + def get_scheduler(schedule_type, num_warmup_steps, decay_factor, num_training_steps): + def fn_scheduler(current_step: int): + if current_step < num_warmup_steps: + return float(current_step) / float(max(1, num_warmup_steps)) + elif schedule_type == 'linear': + return (decay_factor + (1 - decay_factor) * + max(0.0, float(num_training_steps - num_warmup_steps - current_step) / float( + max(1, num_training_steps - num_warmup_steps)))) + elif schedule_type == 'cosine': + return (decay_factor + (1 - decay_factor) * + max(0.0, (1 + math.cos(math.pi * (current_step - num_warmup_steps) / float( + max(1, num_training_steps - num_warmup_steps)))) / 2)) + elif schedule_type == 'const': + return 1.0 + + return fn_scheduler + + def on_fit_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule"): + self.num_training_steps = trainer.max_steps + + self.weight_decay = [] + for optim in trainer.optimizers: + for group_idx, group in enumerate(optim.param_groups): + if 'weight_decay' in group: + self.weight_decay.append(group['weight_decay']) + + num_warmup_steps = 0 + + self.scheduler = self.get_scheduler(self.schedule_type, num_warmup_steps, self.decay, self.num_training_steps) + + def on_before_optimizer_step(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", optimizer): + + if self.schedule_weight_decay: + stats = {} + for group_idx, group in enumerate(optimizer.param_groups): + if 'weight_decay' in group: + group['weight_decay'] = self.weight_decay[group_idx] * self.scheduler(self._step_count) + stats[f"weight_decay/rank_{trainer.local_rank}/group_{group_idx}"] = group['weight_decay'] + + if trainer.loggers is not None: + for logger in trainer.loggers: + logger.log_metrics(stats, step=trainer.global_step) + self._step_count += 1 + + +### Data +def cifar100_task(cache_dir='./data'): + transform_train = transforms.Compose([ + transforms.RandomCrop(32, padding=4, padding_mode='reflect'), + transforms.RandomHorizontalFlip(), + transforms.ToTensor(), + transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), + ]) + + transform_test = transforms.Compose([ + transforms.ToTensor(), + transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), + ]) + + trainset = torchvision.datasets.CIFAR100(root=cache_dir, 
train=True, download=True, transform=transform_train) + testset = torchvision.datasets.CIFAR100(root=cache_dir, train=False, download=True, transform=transform_test) + + return trainset, testset + + +def wd_group_named_parameters(model, weight_decay): + apply_decay = set() + apply_no_decay = set() + special = set() + whitelist_weight_modules = (nn.Linear, nn.Conv2d, nn.Embedding) + blacklist_weight_modules = (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d, + nn.LazyBatchNorm1d, nn.LazyBatchNorm2d, nn.LazyBatchNorm3d, + nn.GroupNorm, nn.SyncBatchNorm, + nn.InstanceNorm1d, nn.InstanceNorm2d, nn.InstanceNorm3d, + nn.LayerNorm, nn.LocalResponseNorm) + + param_dict = {pn: p for pn, p in model.named_parameters() if p.requires_grad} + for mn, m in model.named_modules(): + for pn, p in m.named_parameters(): + fpn = '%s.%s' % (mn, pn) if mn else pn # full param name + if not p.requires_grad or fpn not in param_dict: + continue # frozen weights + if hasattr(p, '_optim'): + special.add(fpn) + elif pn.endswith('bias'): + apply_no_decay.add(fpn) + elif pn.endswith('weight') and isinstance(m, whitelist_weight_modules): + apply_decay.add(fpn) + elif isinstance(m, blacklist_weight_modules): + apply_no_decay.add(fpn) + else: + print("cpr_group_named_parameters: Not using any rule for ", fpn, " in ", type(m)) + + apply_decay |= (param_dict.keys() - apply_no_decay - special) + + # validate that we considered every parameter + inter_params = apply_decay & apply_no_decay + union_params = apply_decay | apply_no_decay + assert len(inter_params) == 0, f"Parameters {str(inter_params)} made it into both apply_decay/apply_no_decay sets!" + assert len( + param_dict.keys() - special - union_params) == 0, f"parameters {str(param_dict.keys() - union_params)} were not separated into either apply_decay/apply_no_decay set!" 
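+    # Build the optimizer parameter groups: whitelisted weights keep the configured weight
+    # decay, while any excluded parameters (biases, normalization layers) form a second
+    # group with weight_decay=0.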
+ + if not apply_no_decay: + param_groups = [{"params": [param_dict[pn] for pn in sorted(list(apply_no_decay | apply_decay))], + "names": [pn for pn in sorted(list(apply_no_decay | apply_decay))], + "weight_decay": weight_decay}] + else: + param_groups = [ + {"params": [param_dict[pn] for pn in sorted(list(apply_decay))], + "names": [pn for pn in sorted(list(apply_decay))], "weight_decay": weight_decay}, + {"params": [param_dict[pn] for pn in sorted(list(apply_no_decay))], + "names": [pn for pn in sorted(list(apply_no_decay))], "weight_decay": 0.0}, + ] + + return param_groups + + +### Model +class BasicBlock(nn.Module): + expansion = 1 + + def __init__(self, in_planes, planes, stride=1): + super(BasicBlock, self).__init__() + self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) + self.bn1 = nn.BatchNorm2d(planes) + self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False) + self.bn2 = nn.BatchNorm2d(planes) + + self.shortcut = nn.Sequential() + if stride != 1 or in_planes != self.expansion * planes: + self.shortcut = nn.Sequential( + nn.Conv2d(in_planes, self.expansion * planes, kernel_size=1, stride=stride, bias=False), + nn.BatchNorm2d(self.expansion * planes) + ) + + def forward(self, x): + out = F.relu(self.bn1(self.conv1(x))) + out = self.bn2(self.conv2(out)) + out += self.shortcut(x) + out = F.relu(out) + return out + + +class Bottleneck(nn.Module): + expansion = 4 + + def __init__(self, in_planes, planes, stride=1): + super(Bottleneck, self).__init__() + self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False) + self.bn1 = nn.BatchNorm2d(planes) + self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) + self.bn2 = nn.BatchNorm2d(planes) + self.conv3 = nn.Conv2d(planes, self.expansion * planes, kernel_size=1, bias=False) + self.bn3 = nn.BatchNorm2d(self.expansion * planes) + + self.shortcut = nn.Sequential() + if stride != 1 or in_planes != self.expansion * planes: + self.shortcut = nn.Sequential( + nn.Conv2d(in_planes, self.expansion * planes, kernel_size=1, stride=stride, bias=False), + nn.BatchNorm2d(self.expansion * planes) + ) + + def forward(self, x): + out = F.relu(self.bn1(self.conv1(x))) + out = F.relu(self.bn2(self.conv2(out))) + out = self.bn3(self.conv3(out)) + out += self.shortcut(x) + out = F.relu(out) + return out + + +class ResNet(nn.Module): + def __init__(self, block, num_blocks, num_classes): + super(ResNet, self).__init__() + self.in_planes = 64 + + self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False) + self.bn1 = nn.BatchNorm2d(64) + self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1) + self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2) + self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2) + self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2) + self.linear = nn.Linear(512 * block.expansion, num_classes) + + def _make_layer(self, block, planes, num_blocks, stride): + strides = [stride] + [1] * (num_blocks - 1) + layers = [] + for stride in strides: + layers.append(block(self.in_planes, planes, stride)) + self.in_planes = planes * block.expansion + return nn.Sequential(*layers) + + def forward(self, x): + out = F.relu(self.bn1(self.conv1(x))) + out = self.layer1(out) + out = self.layer2(out) + out = self.layer3(out) + out = self.layer4(out) + out = F.avg_pool2d(out, 4) + out = out.view(out.size(0), -1) + out = self.linear(out) + return out + + +### Lightning 
Module +class ResNetModule(pl.LightningModule): + + def __init__(self, config): + super().__init__() + + self.cfg = config + + if self.cfg.model_name == "ResNet18": + self.model = ResNet(BasicBlock, [2, 2, 2, 2], num_classes=100) + elif self.cfg.model_name == "ResNet34": + self.model = ResNet(BasicBlock, [3, 4, 6, 3], num_classes=100) + elif self.cfg.model_name == "ResNet50": + self.model = ResNet(Bottleneck, [3, 4, 6, 3], num_classes=100) + elif self.cfg.model_name == "ResNet101": + self.model = ResNet(Bottleneck, [3, 4, 23, 3], num_classes=100) + elif self.cfg.model_name == "ResNet152": + self.model = ResNet(Bottleneck, [3, 8, 36, 3], num_classes=100) + + self.loss = nn.CrossEntropyLoss(label_smoothing=0.1) + + self.test_stats = [] + + def configure_optimizers(self): + + if self.cfg.optimizer == 'adamw': + param_groups = wd_group_named_parameters(self.model, weight_decay=self.cfg.weight_decay) + optimizer = torch.optim.AdamW(param_groups, lr=self.cfg.lr, betas=(self.cfg.beta1, self.cfg.beta2)) + elif self.cfg.optimizer == 'adamcpr': + optimizer = apply_CPR(self.model, torch.optim.Adam, self.cfg.kappa_init_param, self.cfg.kappa_init_method, + self.cfg.reg_function, + self.cfg.kappa_adapt, self.cfg.kappa_update, self.cfg.apply_lr, + embedding_regularization=True, + lr=self.cfg.lr, betas=(self.cfg.beta1, self.cfg.beta2)) + + if self.cfg.rescale_alpha > 0.0: + with torch.no_grad(): + for n, p in self.model.named_parameters(): + if n.endswith("weight"): + p.data *= self.cfg.rescale_alpha + self.rescale_norm = np.sqrt( + sum(p.pow(2).sum().item() for n, p in self.model.named_parameters() if n.endswith("weight"))) + + lr_decay_factor = self.cfg.lr_decay_factor + num_warmup_steps = self.cfg.lr_warmup_steps + + def lr_lambda(current_step: int): + if current_step < num_warmup_steps: + return float(current_step) / float(max(1, num_warmup_steps)) + else: + return lr_decay_factor + (1 - lr_decay_factor) * max(0.0, (1 + math.cos( + math.pi * (current_step - num_warmup_steps) / float( + max(1, self.cfg.max_train_steps - num_warmup_steps)))) / 2) + + lr_scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda, last_epoch=-1) + + return [optimizer], {'scheduler': lr_scheduler, 'interval': 'step'} + + def setup(self, stage: str) -> None: + trainset, testset = cifar100_task() + self.trainset = trainset + self.testset = testset + + def train_dataloader(self): + train_loader = torch.utils.data.DataLoader(self.trainset, batch_size=self.cfg.batch_size, shuffle=True, + num_workers=8) + return train_loader + + def test_dataloader(self): + test_loader = torch.utils.data.DataLoader(self.testset, batch_size=self.cfg.batch_size, shuffle=False, + num_workers=8) + return test_loader + + def _accuracy(self, y_hat, y): + return torch.sum(torch.argmax(y_hat, dim=1) == y).item() / len(y) + + def training_step(self, batch, batch_idx): + + X, y = batch + y_hat = self.model(X) + loss = self.loss(y_hat, y) + + self.log('train_loss', loss) + + if self.cfg.rescale_alpha > 0.0: + with torch.no_grad(): + new_norm = np.sqrt( + sum(p.pow(2).sum().item() for n, p in self.model.named_parameters() if n.endswith("weight"))) + for n, p in self.model.named_parameters(): + if n.endswith("weight"): + p.data *= self.rescale_norm / new_norm + return loss + + def test_step(self, batch, batch_idx): + X, y = batch + y_hat = self.model(X) + loss = self.loss(y_hat, y) + + correct_pred = torch.sum(torch.argmax(y_hat, dim=1) == y).item() + num_samples = len(y) + self.test_stats.append({'loss': loss.item(), 'correct_pred': correct_pred, 
'num_samples': num_samples}) + self.log('test_loss', loss) + + return loss + + def on_test_epoch_end(self): + self.log('lr', self.trainer.optimizers[0].param_groups[0]['lr']) + + valid_loss = np.mean([s['loss'] for s in self.test_stats]) + valid_accuracy = np.sum([s['correct_pred'] for s in self.test_stats]) / np.sum( + [s['num_samples'] for s in self.test_stats]) + self.log('test_loss', valid_loss) + self.log('test_accuracy', valid_accuracy) + self.test_stats = [] + + +def train_cifar100_task(config): + task_name = f"{config.model_name}_seed{config.seed}_steps{config.max_train_steps}" + expt_dir = pathlib.Path(config.output_dir) / config.session / task_name + expt_dir.mkdir(parents=True, exist_ok=True) + + if config.optimizer == "adamcpr": + expt_name = f"{config.optimizer}_p{config.kappa_init_param}_m{config.kappa_init_method}_kf{config.reg_function}_r{config.kappa_update}_l{config.lr}_adapt{config.kappa_adapt}_g{config.apply_lr}" + else: + expt_name = f"{config.optimizer}_l{config.lr}_w{config.weight_decay}_re{config.rescale_alpha}_swd{config.schedule_weight_decay}_swds{config.wd_scale}_t{config.wd_schedule_type}" + + (expt_dir / expt_name).mkdir(parents=True, exist_ok=True) + np.save(expt_dir / expt_name / "config.npy", config.__dict__) + logger = TensorBoardLogger(save_dir=expt_dir, name=expt_name) + pl.seed_everything(config.seed) + + if config.device: + devices = [config.device] + else: + devices = [0] + + model = ResNetModule(config) + + callbacks = [ + LearningRateMonitor(logging_interval='step'), + WeightDecayScheduler(config.schedule_weight_decay, schedule_type=config.wd_schedule_type, scale=config.wd_scale) + ] + + trainer = pl.Trainer(devices=devices, accelerator="gpu", max_steps=config.max_train_steps, + log_every_n_steps=config.log_interval, + enable_progress_bar=config.enable_progress_bar, + logger=logger, callbacks=callbacks) + trainer.fit(model) + + # evaluate model + result = trainer.test(model) + np.save(expt_dir / expt_name / "result.npy", result) + print(result) + + +if __name__ == "__main__": + parser = ArgumentParser() + parser.add_argument("--session", type=str, default='test_resnet') + parser.add_argument("--seed", type=int, default=1) + parser.add_argument("--max_train_steps", type=int, default=20000) + parser.add_argument("--batch_size", type=int, default=128) + parser.add_argument("--model_name", type=str, default="ResNet18") + parser.add_argument("--optimizer", type=str, default='adamcpr') + parser.add_argument("--lr", type=float, default=0.001) + parser.add_argument("--beta1", type=float, default=0.9) + parser.add_argument("--beta2", type=float, default=0.98) + parser.add_argument("--weight_decay", type=float, default=0.1) + + parser.add_argument("--schedule_weight_decay", action=argparse.BooleanOptionalAction) + parser.add_argument("--wd_schedule_type", type=str, default='cosine') + parser.add_argument("--wd_scale", type=float, default=0.1) + + parser.add_argument("--lr_warmup_steps", type=int, default=200) + parser.add_argument("--lr_decay_factor", type=float, default=0.1) + parser.add_argument("--rescale_alpha", type=float, default=0) + + parser.add_argument("--kappa_init_param", type=float, default=1000) + parser.add_argument("--kappa_init_method", type=str, default='warm_start') + parser.add_argument("--reg_function", type=str, default='l2') + parser.add_argument("--kappa_update", type=float, default=1.0) + parser.add_argument("--kappa_adapt", action=argparse.BooleanOptionalAction) + parser.add_argument("--apply_lr", 
action=argparse.BooleanOptionalAction) + + parser.add_argument("--start_epoch", type=int, default=1) + + parser.add_argument("--log_interval", type=int, default=10) + parser.add_argument("--enable_progress_bar", type=bool, default=True) + parser.add_argument("--output_dir", type=str, default='cifar100') + parser.add_argument("--device", type=int, default=0) + args = parser.parse_args() + + args.schedule_weight_decay = args.schedule_weight_decay == 1 + args.kappa_adapt = args.kappa_adapt == 1 + args.apply_lr = args.apply_lr == 1 + + print(args.__dict__) + + if args.rescale_alpha > 0.0: + assert args.optimizer == 'adamw' + + train_cifar100_task(args) diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000..4528a7c --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,26 @@ +[build-system] +requires = ["hatchling"] +build-backend = "hatchling.build" + +[project] +name = "pytorch-cpr" +version = "0.1.0" +authors = [ + { name="Jörg Franke", email="frankej@cs.uni-freiburg.de" }, +] +description = "Constrained Parameter Regularization for PyTorch" +readme = "README.md" +requires-python = ">=3.10" +classifiers = [ + "Programming Language :: Python :: 3", + "License :: OSI Approved :: Apache Software License", + "Operating System :: OS Independent", +] + +[project.optional-dependencies] +full = [ + 'pytorch>=2.0.0', +] + +[project.urls] +Homepage = "https://github.com/automl/CPR" diff --git a/pytorch_cpr/__init__.py b/pytorch_cpr/__init__.py new file mode 100644 index 0000000..843a4d1 --- /dev/null +++ b/pytorch_cpr/__init__.py @@ -0,0 +1,3 @@ +from .group_parameter import cpr_group_named_parameters +from .optim_cpr import CPR +from .wrapper import apply_CPR \ No newline at end of file diff --git a/pytorch_cpr/group_parameter.py b/pytorch_cpr/group_parameter.py new file mode 100644 index 0000000..f37bc5a --- /dev/null +++ b/pytorch_cpr/group_parameter.py @@ -0,0 +1,82 @@ +import torch.nn as nn +import logging + +def cpr_group_named_parameters(model, optim_hps, avoid_keywords=[], + embedding_regularization=False, + bias_regularization=False, + normalization_regularization=False): + if not avoid_keywords: + avoid_keywords = [] + + apply_decay = set() + apply_no_decay = set() + special = set() + whitelist_weight_modules = (nn.Linear, nn.Conv2d) + blacklist_weight_modules = () + if embedding_regularization: + whitelist_weight_modules += (nn.Embedding,) + else: + blacklist_weight_modules += (nn.Embedding,) + + if normalization_regularization: + whitelist_weight_modules += (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d, + nn.LazyBatchNorm1d, nn.LazyBatchNorm2d, nn.LazyBatchNorm3d, + nn.GroupNorm, nn.SyncBatchNorm, + nn.InstanceNorm1d, nn.InstanceNorm2d, nn.InstanceNorm3d, + nn.LayerNorm, nn.LocalResponseNorm) + else: + blacklist_weight_modules += (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d, + nn.LazyBatchNorm1d, nn.LazyBatchNorm2d, nn.LazyBatchNorm3d, + nn.GroupNorm, nn.SyncBatchNorm, + nn.InstanceNorm1d, nn.InstanceNorm2d, nn.InstanceNorm3d, + nn.LayerNorm, nn.LocalResponseNorm) + + + param_dict = {pn: p for pn, p in model.named_parameters() if p.requires_grad} + for mn, m in model.named_modules(): + for pn, p in m.named_parameters(): + fpn = '%s.%s' % (mn, pn) if mn else pn # full param name + # In case of parameter sharing, some parameters show up here but are not in + # param_dict.keys() + if not p.requires_grad or fpn not in param_dict: + continue # frozen weights + if hasattr(p, '_optim'): + special.add(fpn) + elif isinstance(m, blacklist_weight_modules): + 
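+                # Blacklisted modules (embeddings and normalization layers, unless enabled above) are never regularized.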
apply_no_decay.add(fpn) + elif any([keyword in fpn for keyword in avoid_keywords]): + apply_no_decay.add(fpn) + elif not bias_regularization and pn.endswith('bias'): + apply_no_decay.add(fpn) + elif isinstance(m, whitelist_weight_modules): + apply_decay.add(fpn) + else: + logging.debug(f"cpr_group_named_parameters: Not using any rule for {fpn} in {type(m)}") + + apply_decay |= (param_dict.keys() - apply_no_decay - special) + + # validate that we considered every parameter + inter_params = apply_decay & apply_no_decay + union_params = apply_decay | apply_no_decay + assert len(inter_params) == 0, f"Parameters {str(inter_params)} made it into both apply_decay/apply_no_decay sets!" + assert len(param_dict.keys() - special - union_params) == 0, (f"parameters {str(param_dict.keys() - union_params)} " + f" were not separated into either apply_decay/apply_no_decay set!") + + if not apply_no_decay: + param_groups = [{"params": [param_dict[pn] for pn in sorted(apply_decay)], + "names": [pn for pn in sorted(apply_decay)], "apply_decay": True, **optim_hps}] + else: + param_groups = [ + {"params": [param_dict[pn] for pn in sorted(list(apply_decay))], + "names": [pn for pn in sorted(list(apply_decay))], "apply_decay": True, **optim_hps}, + {"params": [param_dict[pn] for pn in sorted(list(apply_no_decay))], + "names": [pn for pn in sorted(list(apply_no_decay))], "apply_decay": False, **optim_hps}, + ] + # Add parameters with special hyperparameters + # Unique dicts + hps = [dict(s) for s in set(frozenset(param_dict[pn]._optim.items()) for pn in special)] + for hp in hps: + params = [param_dict[pn] for pn in sorted(list(special)) if param_dict[pn]._optim == hp] + param_groups.append({"params": params, **hp}) + + return param_groups diff --git a/pytorch_cpr/optim_cpr.py b/pytorch_cpr/optim_cpr.py new file mode 100644 index 0000000..a7403c4 --- /dev/null +++ b/pytorch_cpr/optim_cpr.py @@ -0,0 +1,165 @@ +import torch + + +class CPR(torch.optim.Optimizer): + def __init__(self, optimizer: torch.optim.Optimizer, kappa_init_param: float, kappa_init_method: str = 'warm_start', + reg_function: str = 'l2', kappa_adapt: bool = False, kappa_update: float = 1.0, apply_lr=False): + """ + Args: + optimizer (torch.optim.Optimizer): The original optimizer (e.g., SGD, Adam). + kappa_init_param (float): The initial value of kappa. + kappa_init_method (str): The method to initialize kappa. Options: 'warm_start', 'uniform', 'dependent' + reg_function (str): The function to regularize the parameters. Options: 'l2', 'std' + kappa_adapt (bool): Whether to adapt kappa during training. + kappa_update (float): The update rate of kappa (mu). 
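+            apply_lr (bool): Whether to scale the CPR update by the parameter group's learning rate.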
+ + """ + self.base_optimizer = optimizer + + self.kappa_init_param = kappa_init_param + self.kappa_init_method = kappa_init_method + self.reg_function = reg_function + self.kappa_adapt = kappa_adapt + self.kappa_update = kappa_update + self.apply_lr = apply_lr + + assert self.kappa_init_method in ['warm_start', 'uniform', 'dependent'] + assert self.reg_function in ['l2', 'std'] + + # Ensure internal optimizer's weight decay is set to 0 + for group in self.base_optimizer.param_groups: + if 'weight_decay' in group and group['weight_decay'] != 0: + group['weight_decay'] = 0 + + # Initialize CPR states + self.initilize_CPR_states() + + def initilize_CPR_states(self): + + self.cpr_states = [] + + for group in self.base_optimizer.param_groups: + group_state = [] + if 'weight_decay' in group and group['weight_decay'] != 0: + group['weight_decay'] = 0 + + if 'apply_decay' in group and group['apply_decay'] is True: + + for p in group['params']: + state = {} + state["lagmul"] = torch.tensor(0, dtype=torch.float, device=p.device) + state["step"] = torch.tensor(0, dtype=torch.int32, device=p.device) + + if self.kappa_adapt: + state["adapt_flag"] = torch.tensor(False, dtype=torch.bool, device=p.device) + + if self.kappa_init_method == 'uniform': + state["kappa"] = torch.tensor(self.kappa_init_param, dtype=torch.float, device=p.device) + elif self.kappa_init_method == 'warm_start': + state["kappa"] = torch.tensor(10, dtype=torch.float, device=p.device) + elif self.kappa_init_method == 'dependent': + if self.reg_function == 'std': + state["kappa"] = self.kappa_init_param * torch.std(p).detach() + elif self.reg_function == 'l2': + state["kappa"] = self.kappa_init_param * p.square().mean().detach() + group_state.append(state) + self.cpr_states.append(group_state) + + def zero_grad(self): + """Clears the gradients of all optimized parameters.""" + self.base_optimizer.zero_grad() + + def state_dict(self): + """Returns the state of the optimizer as a dict.""" + state_dict = self.base_optimizer.state_dict() + state_dict['cpr_states'] = self.cpr_states + return state_dict + + def load_state_dict(self, state_dict): + """Loads the optimizer state.""" + if 'cpr_states' in state_dict: + self.cpr_states = state_dict['cpr_states'] + del state_dict['cpr_states'] + self.base_optimizer.load_state_dict(state_dict) + + def __getattr__(self, name): + """Redirect unknown attribute requests to the original optimizer.""" + return getattr(self.base_optimizer, name) + + def step(self, closure=None): + """Performs a single optimization step.""" + self.base_optimizer.step(closure) + + with torch.no_grad(): + + # Apply constrained parameter regularization (CPR) + for group, group_states in zip(self.base_optimizer.param_groups, self.cpr_states): + + if 'apply_decay' in group and group['apply_decay'] is True: + assert len(group['params']) == len(group_states) + for param, state in zip(group['params'], group_states): + + lagmul = state['lagmul'] + kappa = state['kappa'] + step = state['step'] + + if self.reg_function == 'l2': + + n = float(param.numel()) + half_sum_l2norm = param.square().sum() # reg function + + param_specific_lagmul_rate = self.kappa_update / n + param_specific_kappa = kappa * n + + constraint_value = half_sum_l2norm - param_specific_kappa + grad_c = 2 * param + + lagmul.add_(param_specific_lagmul_rate * constraint_value).clip_(min=0.) 
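+                            # Lagrangian update: lambda <- max(0, lambda + (kappa_update / n) * (sum(p^2) - n * kappa)).
+                            # Below, the parameter is moved along -lambda * grad_c = -2 * lambda * p,
+                            # optionally scaled by the group's learning rate.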
+ if self.apply_lr: + param.add_(-grad_c * lagmul * group['lr']) + else: + param.add_(-grad_c * lagmul) + + elif self.reg_function == 'std': + + n = float(param.numel()) + std_dev = param.std() + + constraint_value = std_dev - kappa + + mean = param.mean() + norm_param = param.sub(mean) + grad_std_dev = norm_param.mul_(2).sub_(2 * norm_param.mean()).div_(n - 1) + grad_std_dev.div_(std_dev.mul_(2)) + grad_c = grad_std_dev + + lagmul.add_(self.kappa_update * constraint_value).clip_(min=0.) + if self.apply_lr: + param.add_(-grad_c * lagmul * group['lr']) + else: + param.add_(-grad_c * lagmul) + + if self.kappa_adapt and not ( + self.kappa_init_method == 'warm_start' and self.kappa_init_param >= step): + adapt_flag = state['adapt_flag'] + + if True == adapt_flag and lagmul == 0: + if self.reg_function == 'l2': + new_kappa = param.square().mean() + kappa.clamp_max_(new_kappa) + + elif self.reg_function == 'std': + new_kappa = param.std() + kappa.clamp_max_(new_kappa) + + if lagmul > 0 and False == adapt_flag: + adapt_flag.add_(True) + + if self.kappa_init_method == 'warm_start' and self.kappa_init_param == step: + if self.reg_function == 'std': + new_kappa = param.std() + elif self.reg_function == 'l2': + new_kappa = param.square().mean() + kappa.clamp_max_(new_kappa) + + state['step'] += 1 diff --git a/pytorch_cpr/wrapper.py b/pytorch_cpr/wrapper.py new file mode 100644 index 0000000..04c48e9 --- /dev/null +++ b/pytorch_cpr/wrapper.py @@ -0,0 +1,29 @@ +import inspect + +from pytorch_cpr.optim_cpr import CPR +from pytorch_cpr.group_parameter import cpr_group_named_parameters + +def apply_CPR(model, optimizer_cls, kappa_init_param, kappa_init_method='warm_start', reg_function='l2', + kappa_adapt=False, kappa_update=1.0, apply_lr=False, + normalization_regularization=False, bias_regularization=False, embedding_regularization=False, + **optimizer_args): + + optimizer_args['weight_decay'] = 0 + avoid_keywords = [] + + param_groups = cpr_group_named_parameters(model=model, optim_hps=optimizer_args, avoid_keywords=avoid_keywords, + embedding_regularization=embedding_regularization, + bias_regularization=bias_regularization, + normalization_regularization=normalization_regularization) + + optimizer_keys = inspect.getfullargspec(optimizer_cls).args + for k, v in optimizer_args.items(): + if k not in optimizer_keys: + raise UserWarning(f"apply_CPR: Unknown optimizer argument {k}") + optimizer = optimizer_cls(param_groups, **optimizer_args) + + optimizer = CPR( optimizer=optimizer, kappa_init_param=kappa_init_param, kappa_init_method=kappa_init_method, + reg_function=reg_function, kappa_adapt=kappa_adapt, kappa_update=kappa_update, apply_lr=apply_lr) + + + return optimizer diff --git a/setup.py b/setup.py new file mode 100644 index 0000000..9adaeff --- /dev/null +++ b/setup.py @@ -0,0 +1,12 @@ +from setuptools import setup + +setup( + name='pytorch-cpr', + version='0.1.0', + description='Constrained Parameter Regularization for PyTorch', + url='https://github.com/automl/CPR', + author='Jörg Franke', + license='Apache License 2.0', + packages=['pytorch_cpr'], + install_requires=['pytorch>=2.0.0'], +) diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/model.py b/tests/model.py new file mode 100644 index 0000000..fa0250f --- /dev/null +++ b/tests/model.py @@ -0,0 +1,28 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + +class MockModel(nn.Module): + def __init__(self): + super(MockModel, self).__init__() + # Convolutional 
layers + self.conv1 = nn.Conv2d(1, 20, 1, 1) + self.conv2 = nn.Conv2d(20, 50, 1, 1) + # Linear layers + self.fc1 = nn.Linear(4*4*50, 500) + self.fc2 = nn.Linear(500, 10) + # Batch normalization + self.batch_norm = nn.BatchNorm2d(20) + # Embedding layer + self.embedding = nn.Embedding(10, 10) + + def forward(self, x): + # Convolutional layers with ReLU and max pooling + x = F.relu(self.batch_norm(self.conv1(x))) + x = F.relu(self.conv2(x)) + # Flatten the tensor + x = x.view(-1, 4*4*50) + # Linear layers with ReLU + x = F.relu(self.fc1(x)) + x = self.fc2(x) + return F.log_softmax(x, dim=1) diff --git a/tests/test_group_parameter.py b/tests/test_group_parameter.py new file mode 100644 index 0000000..9168328 --- /dev/null +++ b/tests/test_group_parameter.py @@ -0,0 +1,58 @@ +import pytest +import torch +import torch.nn as nn +from pytorch_cpr.group_parameter import cpr_group_named_parameters +from .model import MockModel # Assuming you have a MockModel defined + +def test_group_named_parameters_all_configs(): + model = MockModel() + optim_hps = {'lr': 0.01, 'momentum': 0.9} # Example hyperparameters + + # Define the different configurations + embedding_regularizations = [True, False] + bias_regularizations = [True, False] + normalization_regularizations = [True, False] + avoid_keyword_options = [[], ['conv'], ['fc']] # Example keywords + + for embedding_reg in embedding_regularizations: + for bias_reg in bias_regularizations: + for norm_reg in normalization_regularizations: + for avoid_keywords in avoid_keyword_options: + param_groups = cpr_group_named_parameters( + model, + optim_hps, + avoid_keywords=avoid_keywords, + embedding_regularization=embedding_reg, + bias_regularization=bias_reg, + normalization_regularization=norm_reg + ) + + for param_group in param_groups: + print("bias_reg", bias_reg, "embedding_reg", embedding_reg, "norm_reg", norm_reg, "avoid", avoid_keywords) + print( + f"### PARAM GROUP #### apply_decay: {param_group['apply_decay']}, lr: {param_group['lr']}") + for name, param in zip(param_group['names'], param_group['params']): + print( + f"{name:60} {param.shape[0]:4} {param.shape[-1]:4} std {param.std():.3f} l2m {param.square().mean():.3f}") + + for group in param_groups: + for param_name in group['names']: + module, param_type = param_name.split('.') if '.' in param_name else (param_name, None) + + if any(key in param_name for key in avoid_keywords): + # Parameters with avoid keywords should be in no_decay group + assert not group.get('apply_decay', False), f"Parameter {param_name} with avoid keyword should be in no_decay group." + elif 'bias' in param_type and not bias_reg: + # Bias parameters without bias regularization should be in no_decay group + assert not group.get('apply_decay', False), f"Bias parameter {param_name} should be in no_decay group." + elif isinstance(model._modules.get(module, None), nn.Embedding) and not embedding_reg: + # Embedding layers without embedding regularization should be in no_decay group + assert not group.get('apply_decay', False), f"Embedding parameter {param_name} should be in no_decay group." + elif isinstance(model._modules.get(module, None), (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d, + nn.GroupNorm, nn.InstanceNorm1d, nn.InstanceNorm2d, + nn.InstanceNorm3d, nn.LayerNorm, nn.LocalResponseNorm)) and not norm_reg: + # Normalization layers without normalization regularization should be in no_decay group + assert not group.get('apply_decay', False), f"Normalization layer parameter {param_name} should be in no_decay group." 
+ else: + # Other parameters should be in decay group + assert group.get('apply_decay', False), f"Parameter {param_name} should be in decay group." diff --git a/tests/test_optim_cpr.py b/tests/test_optim_cpr.py new file mode 100644 index 0000000..8fb4df2 --- /dev/null +++ b/tests/test_optim_cpr.py @@ -0,0 +1,39 @@ +import pytest +import torch +from pytorch_cpr.optim_cpr import CPR +from torch.optim import SGD, Adam +from .model import MockModel + +def test_cpr_with_sgd(): + model = MockModel() + base_optimizer = SGD(model.parameters(), lr=0.01) + cpr_optimizer = CPR(base_optimizer, kappa_init_param=0.1) + + # Test the CPR optimizer functionality with SGD + # Ensure correct initialization of kappa + assert cpr_optimizer.kappa_init_param == 0.1, "Incorrect initialization of kappa_init_param" + + # Perform a single optimization step and check behavior + output = model(torch.randn(1, 1, 4, 4)) # Assuming input size for MockModel + loss = output.mean() + loss.backward() + cpr_optimizer.step() + + # Add more assertions as necessary to verify the behavior + +def test_cpr_with_adam(): + model = MockModel() + base_optimizer = Adam(model.parameters(), lr=0.01) + cpr_optimizer = CPR(base_optimizer, kappa_init_param=0.1) + + # Test the CPR optimizer functionality with Adam + # Ensure correct initialization of kappa + assert cpr_optimizer.kappa_init_param == 0.1, "Incorrect initialization of kappa_init_param" + + # Perform a single optimization step and check behavior + output = model(torch.randn(1, 1, 4, 4)) + loss = output.mean() + loss.backward() + cpr_optimizer.step() + + # Add more assertions as necessary to verify the behavior diff --git a/tests/test_wrapper.py b/tests/test_wrapper.py new file mode 100644 index 0000000..2999af3 --- /dev/null +++ b/tests/test_wrapper.py @@ -0,0 +1,40 @@ +import pytest +import torch +from pytorch_cpr.wrapper import apply_CPR +from torch.optim import SGD, Adam +from .model import MockModel + +def test_apply_cpr_with_sgd(): + model = MockModel() + optimizer_cls = SGD + optimizer_args = {'lr': 0.01} + + cpr_optimizer = apply_CPR( + model, + optimizer_cls, + kappa_init_param=0.1, + **optimizer_args + ) + + # Test the apply_CPR functionality with SGD + # Ensure CPR optimizer is correctly initialized + assert cpr_optimizer.base_optimizer.__class__ == SGD, "Base optimizer is not SGD" + assert cpr_optimizer.kappa_init_param == 0.1, "Incorrect initialization of kappa_init_param" + + +def test_apply_cpr_with_adam(): + model = MockModel() + optimizer_cls = Adam + optimizer_args = {'lr': 0.01} + + cpr_optimizer = apply_CPR( + model, + optimizer_cls, + kappa_init_param=0.1, + **optimizer_args + ) + + # Test the apply_CPR functionality with Adam + # Ensure CPR optimizer is correctly initialized + assert cpr_optimizer.base_optimizer.__class__ == Adam, "Base optimizer is not Adam" + assert cpr_optimizer.kappa_init_param == 0.1, "Incorrect initialization of kappa_init_param"
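The wrapper tests above only check initialization. As an illustrative addition (not part of the original patch), the sketch below shows a possible smoke test in the style of `tests/test_wrapper.py`: it wraps `torch.optim.Adam` via `apply_CPR` with a uniform kappa, runs a single backward/step pass on `MockModel`, and asserts that the parameters actually moved. The test name and hyperparameter values are assumptions chosen for illustration, and the imports mirror those already used by the existing tests.

```python
import torch
from torch.optim import Adam

from pytorch_cpr.wrapper import apply_CPR
from .model import MockModel  # assumes this test lives next to the existing tests


def test_apply_cpr_single_step_updates_parameters():
    # Hypothetical smoke test: wrap Adam with CPR and take one optimization step.
    model = MockModel()
    cpr_optimizer = apply_CPR(model, Adam, kappa_init_param=0.1,
                              kappa_init_method='uniform', lr=0.01)

    # Snapshot parameters before the update.
    before = [p.detach().clone() for p in model.parameters()]

    output = model(torch.randn(1, 1, 4, 4))  # same input shape as the existing tests
    loss = output.mean()
    cpr_optimizer.zero_grad()
    loss.backward()
    cpr_optimizer.step()

    # Adam plus the CPR update should have changed at least one parameter.
    changed = any(not torch.equal(b, p) for b, p in zip(before, model.parameters()))
    assert changed, "Expected parameters to change after one optimization step"
```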