
Commit

initial commit
JoergFranke committed Nov 24, 2023
1 parent 0ff4c74 commit 5a4eef9
Showing 17 changed files with 1,420 additions and 4 deletions.
3 changes: 2 additions & 1 deletion .gitignore
@@ -36,7 +36,7 @@ MANIFEST
pip-log.txt
pip-delete-this-directory.txt

# Unit tests / coverage reports
htmlcov/
.tox/
.nox/
@@ -158,3 +158,4 @@ cython_debug/
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/
venv
2 changes: 1 addition & 1 deletion LICENSE
@@ -186,7 +186,7 @@
same "printed page" as the copyright notice for easier
identification within third-party archives.

Copyright 2023 Jörg Franke

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
122 changes: 120 additions & 2 deletions README.md
@@ -1,2 +1,120 @@

# Constrained Parameter Regularization

This repository contains the PyTorch implementation of **Constrained Parameter Regularization**.
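
In brief, and paraphrasing the paper's framing rather than quoting its exact update rule: instead of adding a penalty such as weight decay to the loss, CPR enforces an upper bound kappa on a regularization function R (e.g., the L2 norm) of each regularized parameter group, using a Lagrangian-style multiplier update. A rough sketch, where reading the `kappa_update` argument below as the update rate μ is our assumption:

```math
\min_{\theta} \; L(\theta) \quad \text{s.t.} \quad R\big(\theta^{(j)}\big) \le \kappa^{(j)} \;\; \forall j,
\qquad
\lambda^{(j)} \leftarrow \max\!\Big(0,\ \lambda^{(j)} + \mu \big(R(\theta^{(j)}) - \kappa^{(j)}\big)\Big)
```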


## Install

```bash
pip install pytorch-cpr
```

## Getting started

### Using the `apply_CPR` optimizer wrapper

The `apply_CPR` function is a wrapper that applies CPR (Constrained Parameter Regularization) to a given optimizer by first creating parameter groups and then wrapping the actual optimizer class.

#### Arguments

- `model`: The PyTorch model whose parameters are to be optimized.
- `optimizer_cls`: The class of the optimizer to be used (e.g., `torch.optim.Adam`).
- `kappa_init_param`: Initial value for the kappa parameter in CPR; its interpretation depends on the initialization method.
- `kappa_init_method` (default `'warm_start'`): The method to initialize the kappa parameter. Options include `'warm_start'`, `'uniform'`, and `'dependent'`; see the sketch after this list.
- `reg_function` (default `'l2'`): The regularization function to be applied. Options include `'l2'` or `'std'`.
- `kappa_adapt` (default `False`): Flag to determine if kappa should adapt during training.
- `kappa_update` (default `1.0`): The rate at which kappa is updated in the Lagrangian method.
- `apply_lr` (default `False`): Flag to apply the learning rate to the regularization update.
- `normalization_regularization` (default `False`): Flag to apply regularization to normalization layers.
- `bias_regularization` (default `False`): Flag to apply regularization to bias parameters.
- `embedding_regularization` (default `False`): Flag to apply regularization to embedding parameters.
- `**optimizer_args`: Additional optimizer arguments to pass to the optimizer class.
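
How `kappa_init_param` is interpreted depends on `kappa_init_method`. The sketch below is our reading of the argument descriptions above, with values borrowed from the example commands later in this README; it is not a verbatim API reference, and the module name `pytorch_cpr` is assumed from the pip package name:

```python
import torch
from pytorch_cpr import apply_CPR  # module name assumed from the pip package name

model = torch.nn.Linear(10, 10)  # stand-in model for illustration

# 'warm_start': kappa_init_param is a step count; kappa is presumably taken
# from the parameter state after that many warm-up steps (1000, as in the
# image-classification example below).
optimizer = apply_CPR(model, torch.optim.Adam, kappa_init_param=1000,
                      kappa_init_method='warm_start', lr=1e-3)

# 'dependent': kappa_init_param presumably acts as a factor on each group's
# initial regularization value (0.8, as in the example commands below).
optimizer = apply_CPR(model, torch.optim.Adam, kappa_init_param=0.8,
                      kappa_init_method='dependent', lr=1e-3)

# 'uniform': one shared kappa bound for all regularized groups
# (0.1 here is a hypothetical value).
optimizer = apply_CPR(model, torch.optim.Adam, kappa_init_param=0.1,
                      kappa_init_method='uniform', lr=1e-3)
```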

#### Example usage

```python
import torch
from pytorch_cpr import apply_CPR

model = YourModel()
optimizer = apply_CPR(model, torch.optim.Adam, kappa_init_param=1000, kappa_init_method='warm_start',
                      lr=0.001, betas=(0.9, 0.98))
```
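
The wrapped optimizer is then used like any other `torch.optim` optimizer. A minimal training-loop sketch, assuming the wrapper preserves the standard `zero_grad`/`step` interface (`dataloader` and the loss choice are hypothetical):

```python
import torch.nn.functional as F

for inputs, targets in dataloader:  # hypothetical DataLoader over your dataset
    optimizer.zero_grad()
    loss = F.cross_entropy(model(inputs), targets)
    loss.backward()
    optimizer.step()  # the CPR regularization update happens inside the wrapped step
```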


## Run examples

We provide scripts to replicate the experiments from our paper. Please use a system with at least one GPU. Install the package and the requirements for the examples:

```bash
python3 -m venv venv
source venv/bin/activate
pip install -r examples/requirements.txt
pip install pytorch-cpr
```


### Modular Addition / Grokking Experiment

The grokking experiment should run within a few minutes. The results will be saved in the `grokking` folder.
To replicate the results in the paper, run variations with the following arguments:

#### For AdamW:
```bash
python examples/train_grokking_task.py --optimizer adamw --weight_decay 0.1
```

#### For Adam + Rescaling:
```bash
python examples/train_grokking_task.py --optimizer adamw --weight_decay 0.0 --rescale 0.8
```

#### For AdamCPR with L2 norm as regularization function:
```bash
python examples/train_grokking_task.py --optimizer adamcpr --kappa_init_method dependent --kappa_init_param 0.8
```



### Image Classification Experiment

The CIFAR-100 experiment should run within 20-30 minutes. The results will be saved in the `cifar100` folder.

#### For AdamW:
```bash
python examples/train_resnet.py --optimizer adamw --lr 0.001 --weight_decay 0.001
```

#### For Adam + Rescaling:
```bash
python examples/train_resnet.py --optimizer adamw --lr 0.001 --weight_decay 0 --rescale_alpha 0.8
```

#### For AdamCPR with L2 norm as regularization function and kappa initialization depending on the parameter initialization:
```bash
python examples/train_resnet.py --optimizer adamcpr --lr 0.001 --kappa_init_method dependent --kappa_init_param 0.8
```

#### For AdamCPR with L2 norm as regularization function and kappa initialization with warm start:
```bash
python examples/train_resnet.py --optimizer adamcpr --lr 0.001 --kappa_init_method warm_start --kappa_init_param 1000
```



## Citation

Please cite our paper if you use this code in your own work:

```bibtex
@misc{franke2023new,
  title={New Horizons in Parameter Regularization: A Constraint Approach},
  author={Jörg K. H. Franke and Michael Hefenbrock and Gregor Koehler and Frank Hutter},
  year={2023},
  eprint={2311.09058},
  archivePrefix={arXiv},
  primaryClass={cs.LG}
}
```

7 changes: 7 additions & 0 deletions examples/requirements.txt
@@ -0,0 +1,7 @@
torch==2.0.1
torchvision==0.15.2
numpy>=1.19.0
tqdm>=4.50.0
matplotlib>=3.7.2
pytorch-lightning>=2.0.0
tensorboard>=2.15.1
