Update the code base with AdamCPR and for PyTorch v2.3.1+. Add figures and more detailed descriptions in README. Tests are not completed.
JoergFranke committed Nov 9, 2024
1 parent 5a9092e commit ecedcd7
Showing 19 changed files with 1,452 additions and 575 deletions.
2 changes: 1 addition & 1 deletion .gitignore
@@ -158,4 +158,4 @@ cython_debug/
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/
venv
venv
109 changes: 70 additions & 39 deletions README.md
@@ -1,8 +1,23 @@

# Constrained Parameter Regularization
# Improving Deep Learning Optimization through Constrained Parameter Regularization

This repository contains the PyTorch implementation of [**Constrained Parameter Regularization**](https://arxiv.org/abs/2311.09058).
This repository contains the PyTorch implementation of [**Constrained Parameter Regularization**](https://arxiv.org/abs/2311.09058) (CPR) with the Adam optimizer.
CPR is an alternative to traditional weight decay. Unlike the uniform application of a single penalty, CPR enforces an upper bound on a statistical measure, such as the L$_2$-norm, of each individual parameter matrix. CPR introduces only a minor runtime overhead and only requires setting an upper bound (or determines it automatically via inflection point detection).
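
Formally (a compact restatement of the above), training becomes the constrained problem

$$\min_{\theta} \; L(\theta) \qquad \text{s.t.} \qquad R(W_j) \le \kappa_j \quad \text{for every regularized parameter matrix } W_j,$$

where $L$ is the training loss, $R$ is the chosen statistical measure (e.g., the L$_2$-norm), and $\kappa_j$ is the upper bound for matrix $W_j$.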

AdamCPR outperforms AdamW on various tasks, such as image classification (CIFAR100 and ImageNet) and language modeling (GPT2/OpenWebText), as shown in the figure below.

<img src="figures/gpt2s_adamw200_300_cprIP.jpg width="390" height="240">

The figure shows a GPT2s model trained on OpenWebText with AdamW for 200k steps (blue) and 300k steps (purple) vs. AdamCPR with inflection point detection (green). The CPR model converges more linearly and achieves a lower validation perplexity, equivalent to training 50% longer with AdamW.
Please find more experiments in our [paper](https://arxiv.org/abs/2311.09058).

## How does it work?

With CPR, learning becomes a constraint optimization problem, which we tackle using an adaptation of the augmented Lagrangian method.
We implement this by adding a Lagrange multiplier $\lambda$ (scalar) and an upper bound $\kappa$ (scalar) for each parameter matrix $W$ in the model and updating both at each optimization step. We introduce four techniques for initializing the upper bound: `'uniform'` with a fixed value, `'dependent'` on the initial parameter norm, `'warm_start'` based on the norm after X training steps, and an `'inflection_point'` detection-based method which doesn't require any additional regularization hyperparameter.
We implement this Lagrange optimization directly in the Adam optimizer, which we call AdamCPR:

<img src="figures/adamcpr.jpg width="852" height="439">



@@ -14,34 +29,37 @@ pip install pytorch-cpr

## Getting started

### Usage of `apply_CPR` Optimizer Wrapper
We implemented CPR with the Adam optimizer in PyTorch (v2.3.1+). To use CPR, simply replace the optimizer in your training script with the AdamCPR optimizer.

The `apply_CPR` function is a wrapper designed to apply CPR (Constrained Parameter Regularization) to a given optimizer by first creating parameter groups and the wrapping the actual optimizer class.
### Example usage

#### Arguments
```python
from pytorch_cpr import AdamCPR

- `model`: The PyTorch model whose parameters are to be optimized.
- `optimizer_cls`: The class of the optimizer to be used (e.g., `torch.optim.Adam`).
- `kappa_init_param`: Initial value for the kappa parameter in CPR depending on tge initialization method.
- `kappa_init_method` (default `'warm_start'`): The method to initialize the kappa parameter. Options include `'warm_start'`, `'uniform'`, and `'dependent'`.
- `reg_function` (default `'l2'`): The regularization function to be applied. Options include `'l2'` or `'std'`.
- `kappa_adapt` (default `False`): Flag to determine if kappa should adapt during training.
- `kappa_update` (default `1.0`): The rate at which kappa is updated in the Lagrangian method.
- `normalization_regularization` (default `False`): Flag to apply regularization to normalization layers.
- `bias_regularization` (default `False`): Flag to apply regularization to bias parameters.
- `embedding_regularization` (default `False`): Flag to apply regularization to embedding parameters.
- `**optimizer_args`: Additional optimizer arguments to pass to the optimizer class.
# for AdamCPR with warm start initialization
optimizer = AdamCPR(model, lr=0.001, kappa_init_param=1000, kappa_init_method='warm_start')
# for AdamCPR with inflection point initialization (no other regularization hyperparameter needed)
optimizer = AdamCPR(model, lr=0.001, kappa_init_method='inflection_point')
```

#### Example usage
### Arguments of AdamCPR

```python
import torch
from pytorch-cpr import apply_CPR
#### Basic Optimizer Parameters
| Parameter | Type | Default | Description |
|-----------|------|---------|-------------|
| `params` | iterable | required | Iterable of parameters to optimize or dicts defining parameter groups |
| `lr` | float | 1e-3 | Learning rate. Note: Tensor lr is only supported with `capturable=True` |
| `betas` | tuple(float, float) | (0.9, 0.999) | Coefficients for computing running averages of gradient and its square |
| `eps` | float | 1e-8 | Term added to denominator for numerical stability |
| `amsgrad` | bool | False | Whether to use the AMSGrad variant from ["On the Convergence of Adam and Beyond"](https://openreview.net/forum?id=ryQu7f-RZ) |

#### CPR-Specific Parameters
| Parameter | Type | Default | Description |
|-----------|------|---------|--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|
| `kappa_init_method` | str | 'inflection_point' | Method to initialize regularization bound. Options:<br>• `'uniform'`: Fixed value initialization<br>• `'warm_start'`: Delayed initialization<br>• `'dependent'`: Parameter-dependent initialization<br>• `'inflection_point'`: Automated inflection point detection-based initialization. |
| `kappa_init_param` | float | 1000.0 | Initial value for the regularization bound; its meaning depends on the initialization method:<br>• `'uniform'`: The value of the upper bound.<br>• `'warm_start'`: The number of steps before setting the upper bound to the current regularization value.<br>• `'dependent'`: The factor applied to the regularization value after initialization.<br>• `'inflection_point'`: No parameter required. |
| `reg_function` | str | 'l2' | Regularization function type. Options:<br>• `'l2'`: L2 norm regularization<br>• `'l1'`: L1 norm regularization<br>• `'std'`: Standard deviation regularization<br>• `'huber'`: Huber norm regularization |

model = YourModel()
optimizer = apply_CPR(model, torch.optim.Adam, kappa_init_param=1000, kappa_init_method='warm_start',
lr=0.001, betas=(0.9, 0.98))
```
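
For a concrete reading of the table above, the sketch below combines the arguments in a few typical configurations. The keyword names match those used in this README and the example script; the specific values (and the assumption that a `model` is already defined) are only for illustration.

```python
from pytorch_cpr import AdamCPR

# Fixed bound: constrain every regularized matrix to R(W) <= 0.8.
optimizer = AdamCPR(model, lr=1e-3, kappa_init_method='uniform',
                    kappa_init_param=0.8, reg_function='l2')

# Bound set relative to each matrix's regularization value at initialization.
optimizer = AdamCPR(model, lr=1e-3, kappa_init_method='dependent',
                    kappa_init_param=1.0)

# Warm start: freeze the bounds to the observed values after 1000 steps;
# also regularize embedding weights and keep the default dual update rate.
optimizer = AdamCPR(model, lr=1e-3, betas=(0.9, 0.98),
                    kappa_init_method='warm_start', kappa_init_param=1000,
                    kappa_update=1.0, reg_embedding=True)
```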


## Run examples
@@ -82,40 +100,53 @@ python examples/train_grokking_task.py --optimizer adamcpr --kappa_init_method d

The CIFAR-100 experiment should run within 20-30 minutes. The results will be saved in the `experiments/cifar100` folder.

#### For AdamW:

#### For AdamCPR with L2 norm as regularization function and kappa initialization depending on the parameter initialization:
```bash
python examples/train_resnet.py --optimizer adamw --lr 0.001 --weight_decay 0.001
python examples/train_cifar100_task.py --optimizer adamcpr --lr 0.001 --kappa_init_method dependent --kappa_init_param 1.0
```

#### For Adam + Rescaling:
#### For AdamCPR with L2 norm as regularization function and kappa initialization with warm start:
```bash
python examples/train_resnet.py --optimizer adamw --lr 0.001 --weight_decay 0 --rescale_alpha 0.8
python examples/train_cifar100_task.py --optimizer adamcpr --lr 0.001 --kappa_init_method warm_start --kappa_init_param 1000
```

#### For AdamCPR with L2 norm as regularization function and kappa initialization depending on the parameter initialization:
#### For AdamAdaCPR with L2 norm as regularization function and kappa initialization with inflection point:
```bash
python examples/train_resnet.py --optimizer adamcpr --lr 0.001 --kappa_init_method dependent --kappa_init_param 0.8
python examples/train_cifar100_task.py --optimizer adamcpr --lr 0.001 --kappa_init_method inflection_point
```

#### For AdamCPR with L2 norm as regularization function and kappa initialization with warm start:
#### For AdamW:
```bash
python examples/train_resnet.py --optimizer adamcpr --lr 0.001 --kappa_init_method warm_start --kappa_init_param 1000
python examples/train_cifar100_task.py --optimizer adamw --lr 0.001 --weight_decay 0.001
```

#### For Adam + Rescaling:
```bash
python examples/train_cifar100_task.py --optimizer adamw --lr 0.001 --weight_decay 0 --rescale_alpha 0.8
```

#### For Adam + AWD:
```bash
python examples/train_cifar100_task.py --optimizer adam_awd --lr 0.001 --weight_decay 0.1
```

#### For Adam + AdaDecay:
```bash
python examples/train_cifar100_task.py --optimizer adam_adadecay --lr 0.001 --weight_decay 0.1
```

## Citation

Please cite our paper if you use this code in your work:
Please cite our paper if you use CPR in your work:

```
@misc{franke2023cpr,
title={Constrained Parameter Regularization},
@misc{franke2024cpr,
title={Improving Deep Learning Optimization through Constrained Parameter Regularization},
author={Jörg K. H. Franke and Michael Hefenbrock and Gregor Koehler and Frank Hutter},
year={2023},
eprint={2311.09058},
archivePrefix={arXiv},
primaryClass={cs.LG}
journal={Advances in Neural Information Processing Systems},
volume={37},
year={2024},
}
```

8 changes: 4 additions & 4 deletions examples/requirements.txt
@@ -1,7 +1,7 @@
torch==2.0.1
torchvision==0.15.2
numpy>=1.19.0
torch>=2.3.1
torchvision>=0.18.1
numpy<2
tqdm>=4.50.0
matplotlib>=3.7.2
pytorch-lightning>=2.0.0
pytorch-lightning==2.0.0
tensorboard>=2.15.1
94 changes: 24 additions & 70 deletions examples/train_resnet.py → examples/train_cifar100_task.py
@@ -9,72 +9,15 @@
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
from pytorch_lightning import Callback
from pytorch_lightning.callbacks import LearningRateMonitor

from pytorch_cpr import apply_CPR

torch.set_float32_matmul_precision('high')

class WeightDecayScheduler(Callback):

def __init__(self, schedule_weight_decay: bool, schedule_type: str, scale: float):
super().__init__()
self.schedule_weight_decay = schedule_weight_decay

self.schedule_type = schedule_type

self.decay = scale

self._step_count = 0

@staticmethod
def get_scheduler(schedule_type, num_warmup_steps, decay_factor, num_training_steps):
def fn_scheduler(current_step: int):
if current_step < num_warmup_steps:
return float(current_step) / float(max(1, num_warmup_steps))
elif schedule_type == 'linear':
return (decay_factor + (1 - decay_factor) *
max(0.0, float(num_training_steps - num_warmup_steps - current_step) / float(
max(1, num_training_steps - num_warmup_steps))))
elif schedule_type == 'cosine':
return (decay_factor + (1 - decay_factor) *
max(0.0, (1 + math.cos(math.pi * (current_step - num_warmup_steps) / float(
max(1, num_training_steps - num_warmup_steps)))) / 2))
elif schedule_type == 'const':
return 1.0

return fn_scheduler
from pytorch_cpr import AdamCPR

def on_fit_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule"):
self.num_training_steps = trainer.max_steps

self.weight_decay = []
for optim in trainer.optimizers:
for group_idx, group in enumerate(optim.param_groups):
if 'weight_decay' in group:
self.weight_decay.append(group['weight_decay'])

num_warmup_steps = 0

self.scheduler = self.get_scheduler(self.schedule_type, num_warmup_steps, self.decay, self.num_training_steps)

def on_before_optimizer_step(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", optimizer):

if self.schedule_weight_decay:
stats = {}
for group_idx, group in enumerate(optimizer.param_groups):
if 'weight_decay' in group:
group['weight_decay'] = self.weight_decay[group_idx] * self.scheduler(self._step_count)
stats[f"weight_decay/rank_{trainer.local_rank}/group_{group_idx}"] = group['weight_decay']

if trainer.loggers is not None:
for logger in trainer.loggers:
logger.log_metrics(stats, step=trainer.global_step)
self._step_count += 1
torch.set_float32_matmul_precision('high')

### Dataset

### Data
def cifar100_task(cache_dir='./data'):
transform_train = transforms.Compose([
transforms.RandomCrop(32, padding=4, padding_mode='reflect'),
@@ -262,11 +205,10 @@ def configure_optimizers(self):
param_groups = wd_group_named_parameters(self.model, weight_decay=self.cfg.weight_decay)
optimizer = torch.optim.AdamW(param_groups, lr=self.cfg.lr, betas=(self.cfg.beta1, self.cfg.beta2))
elif self.cfg.optimizer == 'adamcpr':
optimizer = apply_CPR(self.model, torch.optim.Adam, self.cfg.kappa_init_param, self.cfg.kappa_init_method,
self.cfg.reg_function,
self.cfg.kappa_adapt, self.cfg.kappa_update,
embedding_regularization=True,
lr=self.cfg.lr, betas=(self.cfg.beta1, self.cfg.beta2))
optimizer = AdamCPR(self.model, lr=self.cfg.lr, betas=(self.cfg.beta1, self.cfg.beta2),
kappa_init_param=self.cfg.kappa_init_param,
kappa_init_method=self.cfg.kappa_init_method, reg_function=self.cfg.reg_function,
kappa_update=self.cfg.kappa_update, reg_embedding=True)

if self.cfg.rescale_alpha > 0.0:
with torch.no_grad():
@@ -324,6 +266,19 @@ def training_step(self, batch, batch_idx):
for n, p in self.model.named_parameters():
if n.endswith("weight"):
p.data *= self.rescale_norm / new_norm



for name, param in self.model.named_parameters():
if param.requires_grad and len(param.shape) >= 2:
self.log(f"params/{name}/mean", param.mean().item())
self.log(f"params/{name}/std", param.std().item())
self.log(f"params/{name}/l2", param.pow(2).sum().item())
if self.cfg.optimizer == 'adamcpr':
state = self.trainer.optimizers[0].state[param]
if 'kappa' in state:
self.log(f"cpr/{name}/lagmul", state['lagmul'].item())
self.log(f"cpr/{name}/kappa", state['kappa'].item())
return loss

def test_step(self, batch, batch_idx):
@@ -373,7 +328,6 @@ def train_cifar100_task(config):

callbacks = [
LearningRateMonitor(logging_interval='step'),
WeightDecayScheduler(config.schedule_weight_decay, schedule_type=config.wd_schedule_type, scale=config.wd_scale)
]

trainer = pl.Trainer(devices=devices, accelerator="gpu", max_steps=config.max_train_steps,
@@ -405,21 +359,21 @@ def train_cifar100_task(config):
parser.add_argument("--wd_schedule_type", type=str, default='cosine')
parser.add_argument("--wd_scale", type=float, default=0.1)

parser.add_argument("--lr_warmup_steps", type=int, default=200)
parser.add_argument("--lr_warmup_steps", type=int, default=500)
parser.add_argument("--lr_decay_factor", type=float, default=0.1)
parser.add_argument("--rescale_alpha", type=float, default=0)

parser.add_argument("--kappa_init_param", type=float, default=1000)
parser.add_argument("--kappa_init_method", type=str, default='warm_start')
parser.add_argument("--kappa_init_method", type=str, default='inflection_point')
parser.add_argument("--reg_function", type=str, default='l2')
parser.add_argument("--kappa_update", type=float, default=1.0)
parser.add_argument("--kappa_adapt", action=argparse.BooleanOptionalAction)

parser.add_argument("--start_epoch", type=int, default=1)

parser.add_argument("--log_interval", type=int, default=10)
parser.add_argument("--log_interval", type=int, default=20)
parser.add_argument("--enable_progress_bar", type=bool, default=True)
parser.add_argument("--output_dir", type=str, default='cifar100')
parser.add_argument("--output_dir", type=str, default='experiments/cifar100')
parser.add_argument("--device", type=int, default=0)
args = parser.parse_args()
