Commit d47103f
Fix inflection point, add regularize embedding, fix typos. V0.3.1
JoergFranke committed Nov 22, 2024
1 parent 82843ea commit d47103f
Showing 5 changed files with 42 additions and 42 deletions.
9 changes: 2 additions & 7 deletions README.md
@@ -84,11 +84,6 @@ To replicate the results in the paper, run variations with the following arguments:
python examples/train_grokking_task.py --optimizer adamw --weight_decay 0.1
```

-#### For Adam + Rescaling:
-```bash
-python examples/train_grokking_task.py --optimizer adamw --weight_decay 0.0 --rescale 0.8
-```

#### For AdamCPR with L2 norm as regularization function:
```bash
python examples/train_grokking_task.py --optimizer adamcpr --kappa_init_method dependent --kappa_init_param 0.8
@@ -143,9 +138,9 @@ Please cite our paper if you use CPR in your work:
```
@misc{franke2024cpr,
title={Improving Deep Learning Optimization through Constrained Parameter Regularization},
-author={Jörg K. H. Franke and Michael Hefenbrock and Gregor Koehler and Frank Hutter},
+author={Jörg K. H. Franke and Michael Hefenbrock and Gregor Köhler and Frank Hutter},
journal={Advances in Neural Information Processing Systems},
-volume={37},
+volume={38},
year={2024},
}
```
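For context, a minimal sketch of using the optimizer that the README examples above configure. Assumptions: `AdamCPR` is importable from `pytorch_cpr` (it is listed in `__all__` in the `adamcpr.py` diff below), and the constructor accepts the keyword arguments shown in that diff:

```python
import torch
import torch.nn as nn
from pytorch_cpr import AdamCPR  # assumed package-level export (see __all__ below)

model = nn.Linear(16, 4)
# Mirrors the README example: kappa bounds set dependent on the initial weights, scaled by 0.8.
optimizer = AdamCPR(model.parameters(), lr=1e-3,
                    kappa_init_method='dependent', kappa_init_param=0.8)

x, y = torch.randn(8, 16), torch.randn(8, 4)
loss = nn.functional.mse_loss(model(x), y)
loss.backward()
optimizer.step()
optimizer.zero_grad()
```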
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -4,7 +4,7 @@ build-backend = "hatchling.build"

[project]
name = "pytorch-cpr"
version = "0.3.0"
version = "0.3.1"
authors = [
{ name="Jörg Franke", email="[email protected]" },
]
69 changes: 36 additions & 33 deletions pytorch_cpr/adamcpr.py
@@ -88,6 +88,7 @@ def single_initilize_kappa(kappa, param, reg_function):


__all__ = ["AdamCPR", "adamcpr"]
+HIGHKAPPA=1e6


class AdamCPR(Optimizer):
@@ -102,7 +103,7 @@ def __init__(
reg_function: str = 'l2',
kappa_update: float = 1.0,
reg_step_size: int = 200,
-reg_ema_decay: float = 0.9,
+reg_ema_decay: float = 0.99,
reg_embedding: bool = False,
reg_by_lr: bool = False,
amsgrad: bool = False,
@@ -286,23 +287,23 @@ def _init_group(
state['prev_reg'] = torch.tensor([0.0], dtype=torch.float, device=p.device)
state['prev_reg_gradient'] = torch.tensor([0.0], dtype=torch.float, device=p.device)
state['inflection_point_emas'] = torch.tensor([0.0], dtype=torch.float, device=p.device)
state['lagmul'] = torch.tensor([0.0], dtype=torch.float, device=p.device)

if self.reg_function == 'std':
-state['kappa_update'] = torch.tensor([self.kappa_update], dtype=torch.float, device=p.device)
+state['kappa_update'] = torch.tensor(self.kappa_update, dtype=torch.float, device=p.device)
else:
-state['kappa_update'] = torch.tensor([self.kappa_update / p.numel()], dtype=torch.float,
+state['kappa_update'] = torch.tensor(self.kappa_update / p.numel(), dtype=torch.float,
device=p.device)

if self.kappa_init_method == 'uniform':
state["kappa"] = torch.tensor([self.kappa_init_param], dtype=torch.float, device=p.device)
state["kappa"] = torch.tensor(self.kappa_init_param, dtype=torch.float, device=p.device)
elif self.kappa_init_method == 'warm_start':
state["kappa"] = torch.tensor([0.0], dtype=torch.float, device=p.device)
state["kappa"] = torch.tensor(0.0, dtype=torch.float, device=p.device)
elif self.kappa_init_method == 'inflection_point':
-# Initialize kappa with 1000, it doesn't have any effect and 1000 bill be used to identify un-set kappa bounds
-state["kappa"] = torch.tensor([1000], dtype=torch.float, device=p.device)
+# Initialize kappa with HIGHKAPPA, it doesn't have any effect on the training before the inflection
+# point and HIGHKAPPA will be used to identify un-set kappa bounds
+state["kappa"] = torch.tensor(HIGHKAPPA, dtype=torch.float, device=p.device)
elif self.kappa_init_method == 'dependent':
-kappa = torch.tensor([0.0], dtype=torch.float, device=p.device)
+kappa = torch.tensor(0.0, dtype=torch.float, device=p.device)
single_initilize_kappa(kappa, p, self.reg_function)
state["kappa"] = self.kappa_init_param * kappa.detach()

@@ -421,7 +422,7 @@ def step(self, closure=None):


@torch.jit.script
-def l2_update(param, lagmul, kappa, kappa_update, reg_by_lr, lr):
+def l2_update(param, lagmul, kappa, kappa_update, reg_by_lr: bool, lr: float):
sum_l2norm = param.square().sum()
constraint_value = sum_l2norm - kappa
lagmul.add_(kappa_update * constraint_value).clip_(min=0.)
@@ -432,7 +433,7 @@ def l2_update(param, lagmul, kappa, kappa_update, reg_by_lr, lr):


@torch.jit.script
-def l1_update(param, lagmul, kappa, kappa_update, reg_by_lr, lr):
+def l1_update(param, lagmul, kappa, kappa_update, reg_by_lr: bool, lr: float):
sum_l1norm = param.abs().sum()
constraint_value = sum_l1norm - kappa
lagmul.add_(kappa_update * constraint_value).clip_(min=0.)
@@ -443,7 +444,7 @@ def l1_update(param, lagmul, kappa, kappa_update, reg_by_lr, lr):


@torch.jit.script
-def std_update(param, lagmul, kappa, kappa_update, reg_by_lr, lr):
+def std_update(param, lagmul, kappa, kappa_update, reg_by_lr: bool, lr: float):
n = param.numel()
std_dev = param.std()
constraint_value = std_dev - kappa
@@ -460,7 +461,7 @@ def std_update(param, lagmul, kappa, kappa_update, reg_by_lr, lr):


@torch.jit.script
-def huber_update(param, lagmul, kappa, kappa_update, reg_by_lr, lr):
+def huber_update(param, lagmul, kappa, kappa_update, reg_by_lr: bool, lr: float):
param_abs = param.abs()
huber_idx = param_abs < 1
huber_loss = torch.where(huber_idx, 0.5 * param.square(), param_abs - 0.5)
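All four jitted kernels share one Lagrangian pattern: ascend the multiplier on the constraint violation, clip it at zero, then push the parameters down the constraint gradient. A sketch of that shared shape — the multiplier line matches the hunks above; the parameter step is an assumption, since the bodies are truncated here, and scaling by `lr` is inferred from the `reg_by_lr` flag's name:

```python
import torch

def cpr_update_sketch(param: torch.Tensor, lagmul: torch.Tensor,
                      constraint_value: torch.Tensor, constraint_grad: torch.Tensor,
                      kappa_update: torch.Tensor, reg_by_lr: bool, lr: float) -> None:
    # Gradient ascent on the Lagrange multiplier; the clip keeps a satisfied
    # constraint (negative violation) from turning into a reward.
    lagmul.add_(kappa_update * constraint_value).clip_(min=0.)
    # Gradient descent on the parameters along the constraint gradient,
    # optionally scaled by the learning rate when reg_by_lr is set.
    scale = lr if reg_by_lr else 1.0
    param.sub_(scale * lagmul * constraint_grad)
```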
@@ -606,8 +607,8 @@ def _single_tensor_adamcpr(
if regularize:
if kappa_init_method == 'inflection_point' and kappa == 1000:
current_l2m = param.square().sum()
-inflection_point_ema = reg_ema_decay * inflection_point_ema + (1 - reg_ema_decay) * current_l2m
-if step > reg_step_size * 1:
+inflection_point_ema.mul_(reg_ema_decay).add_(current_l2m, alpha=1 - reg_ema_decay)
+if step > reg_step_size * 1 and step % reg_step_size == 0:
current_reg_gradient = inflection_point_ema - prev_reg
# Peak detection for gradient
if step > reg_step_size * 3 and prev_reg_gradient > current_reg_gradient:
@@ -857,10 +858,7 @@ def _multi_tensor_adamcpr(
if regularize:
if device_state_steps[0] > warm_start:
if reg_function == 'l2':
-square_params = torch._foreach_pow(device_params, 2)
-square_sum_params = []
-for square_param in square_params:
-square_sum_params.append(square_param.sum().unsqueeze(0))
+square_sum_params = torch._foreach_pow(torch._foreach_norm(device_params), 2)
torch._foreach_sub_(square_sum_params, device_kappas)
torch._foreach_mul_(square_sum_params, device_kappa_updates)
torch._foreach_add_(device_lagmuls, square_sum_params)
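The rewritten `l2` branch above relies on `torch._foreach_norm` returning each tensor's L2 norm, so squaring it reproduces the per-tensor sum of squares the removed Python loop computed. A quick equivalence check (sketch):

```python
import torch

params = [torch.randn(3, 4), torch.randn(5)]
fused = torch._foreach_pow(torch._foreach_norm(params), 2)  # fused per-tensor ||p||^2
looped = [p.square().sum() for p in params]                 # the replaced loop
assert all(torch.allclose(a, b) for a, b in zip(fused, looped))
```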
@@ -874,7 +872,7 @@
elif reg_function == 'std':
std_params, ns = [], []
for device_param in device_params:
-std_params.append(device_param.str().unsqueeze(0))
+std_params.append(device_param.std().unsqueeze(0))
ns.append(device_param.numel() - 1)
mean_params = [device_param.mean() for device_param in device_params]
norm_params = torch._foreach_sub(device_params, mean_params)
@@ -914,7 +912,7 @@
abs_params = torch._foreach_abs(device_params)
square_params = torch._foreach_pow(device_params, 2)
huber_loss_params, huber_loss_grads = [], []
-for param_abs, square_params, device_param in zip(abs_params, square_params, device_params):
+for param_abs, square_param, device_param in zip(abs_params, square_params, device_params):
huber_loss_params.append(torch.where(param_abs < 1, 0.5 * square_param, param_abs - 0.5).sum())
huber_loss_grads.append(torch.where(param_abs < 1, device_param, device_param.sign()))
torch._foreach_sub_(huber_loss_params, device_kappas)
@@ -931,33 +929,38 @@
raise ValueError(f"Unsupported regularization function: {reg_function}")

if (kappa_init_method == 'inflection_point'
-and any([device_kappa == 1000 for device_kappa in device_kappas])
-and device_state_steps[0] % reg_step_size == 0):
+and any([device_kappa == HIGHKAPPA for device_kappa in device_kappas])):

square_sum_params = torch._foreach_norm(device_params)

-for i in range(len(device_params)):
-device_inflection_point_emas[i] = reg_ema_decay * device_inflection_point_emas[i] + (
-1 - reg_ema_decay) * square_sum_params[i]
+if device_state_steps[0] >= reg_step_size:
+torch._foreach_mul_(device_inflection_point_emas, reg_ema_decay)
+torch._foreach_add_(device_inflection_point_emas, square_sum_params, alpha=(1 - reg_ema_decay))
+else:
+torch._foreach_copy_(device_inflection_point_emas, square_sum_params)

-if device_state_steps[0] > reg_step_size * 1:
+if device_state_steps[0] > reg_step_size * 1 and device_state_steps[0] % reg_step_size == 0:

current_gradients = torch._foreach_sub(device_inflection_point_emas, device_prev_regs)
+current_gradients = torch._foreach_div(current_gradients, device_inflection_point_emas)
+# Update previous values for next iteration
+torch._foreach_copy_(device_prev_regs, device_inflection_point_emas)

# Peak detection for gradient
if device_state_steps[0] > reg_step_size * 3:
for i in range(len(device_params)):
-if device_prev_reg_gradients[i] > current_gradients[i] and device_kappas[i] == 1000:
+if device_prev_reg_gradients[i] > current_gradients[i] > 0.01 \
+and device_kappas[i] == HIGHKAPPA:
single_initilize_kappa(device_kappas[i], device_params[i], reg_function)

-torch._foreach_copy_(device_prev_regs, device_inflection_point_emas)

if device_state_steps[0] > reg_step_size * 2:
torch._foreach_copy_(device_prev_reg_gradients, current_gradients)

elif kappa_init_method == 'warm_start' and device_state_steps[0] == warm_start:

if reg_function == 'l2':
-square_params = torch._foreach_pow(device_params, 2)
-new_kappas = [square_param.sum() for square_param in square_params]
-torch._foreach_add_(device_kappas, new_kappas)
+square_sum_params = torch._foreach_pow(torch._foreach_norm(device_params), 2)
+torch._foreach_add_(device_kappas, square_sum_params)

elif reg_function == 'std':
new_kappas = [device_param.std() for device_param in device_params]
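Taken together, the inflection-point changes in this file amount to the following single-parameter sketch (illustrative names; the EMA warm-up copy, the sampling every `reg_step_size` steps, the relative gradient, and the `> 0.01` floor are the behavior this commit introduces):

```python
import torch

def find_inflection_kappa(param: torch.Tensor, steps: int, K: int = 200,
                          decay: float = 0.99, high: float = 1e6) -> float:
    kappa = high                        # sentinel: bound not yet set
    ema = prev = prev_grad = 0.0
    for step in range(1, steps + 1):
        # ... an optimizer step would update `param` here ...
        reg = float(param.square().sum())
        ema = reg if step < K else decay * ema + (1 - decay) * reg
        if step > K and step % K == 0:  # sample the curve once per window
            grad = (ema - prev) / ema   # relative change over one window
            if step > 3 * K and prev_grad > grad > 0.01 and kappa == high:
                kappa = reg             # change has peaked: bind kappa here
            prev = ema
            if step > 2 * K:
                prev_grad = grad
    return kappa
```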
2 changes: 2 additions & 0 deletions pytorch_cpr/group_parameter.py
@@ -27,6 +27,8 @@ def group_parameters_for_cpr_optimizer(model, regularize_embedding=False):
no_decay.add(fpn)
elif pn.endswith('bias'):
no_decay.add(fpn)
+elif not regularize_embedding and "embed" in fpn:
+no_decay.add(fpn)
elif pn.endswith('weight') and isinstance(m, whitelist_weight_modules):
decay.add(fpn)
elif isinstance(m, blacklist_weight_modules):
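The new branch means embedding weights now fall into the no-regularization group by default; a short sketch of the effect (assumption: the helper is importable from `pytorch_cpr`):

```python
import torch.nn as nn
from pytorch_cpr import group_parameters_for_cpr_optimizer  # assumed import path

class TinyLM(nn.Module):
    def __init__(self):
        super().__init__()
        self.embedding = nn.Embedding(100, 32)  # fpn "embedding.weight" contains "embed"
        self.head = nn.Linear(32, 100)

model = TinyLM()
# Default: parameters whose full name contains "embed" join the no-decay group.
groups = group_parameters_for_cpr_optimizer(model)
# The new flag opts embeddings back into CPR regularization:
groups_reg = group_parameters_for_cpr_optimizer(model, regularize_embedding=True)
```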
2 changes: 1 addition & 1 deletion setup.py
@@ -2,7 +2,7 @@

setup(
name='pytorch-cpr',
-version='0.3.0',
+version='0.3.1',
description='Constrained Parameter Regularization for PyTorch',
url='https://github.com/automl/CPR',
author='Joerg Franke',
