Commit
Custom scheduler and optimizer classes
Showing 6 changed files with 159 additions and 104 deletions.
@@ -0,0 +1,51 @@
import torch


class TripleLRSGD:
    def __init__(self, model: torch.nn.Module, params: dict) -> None:
        """TripleLRSGD is a custom optimizer that separates weights into
        batch norm weights, regular weights, and biases.

        @type model: torch.nn.Module
        @param model: The model to be used
        @type params: dict
        @param params: The parameters to be used for the optimizer
        """
        self.model = model
        self.params = params

    def create_optimizer(self):
        batch_norm_weights, regular_weights, biases = [], [], []

        # Sort parameters into three groups: biases, batch norm weights
        # (no weight decay) and regular weights (with weight decay).
        for module in self.model.modules():
            if hasattr(module, "bias") and isinstance(
                module.bias, torch.nn.Parameter
            ):
                biases.append(module.bias)
            if isinstance(module, torch.nn.BatchNorm2d):
                batch_norm_weights.append(module.weight)
            elif hasattr(module, "weight") and isinstance(
                module.weight, torch.nn.Parameter
            ):
                regular_weights.append(module.weight)

        # Only the regular weights receive weight decay; batch norm weights
        # and biases fall back to the top-level SGD defaults.
        optimizer = torch.optim.SGD(
            [
                {
                    "params": batch_norm_weights,
                    "lr": self.params["lr"],
                    "momentum": self.params["momentum"],
                    "nesterov": self.params["nesterov"],
                },
                {
                    "params": regular_weights,
                    "weight_decay": self.params["weight_decay"],
                },
                {"params": biases},
            ],
            lr=self.params["lr"],
            momentum=self.params["momentum"],
            nesterov=self.params["nesterov"],
        )

        return optimizer
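
As a quick illustration (not part of this commit), here is a minimal sketch of how TripleLRSGD might be constructed. The stand-in model and the hyperparameter values are placeholder assumptions; `create_optimizer` returns a plain `torch.optim.SGD` with three param groups.

import torch

# Hypothetical hyperparameter values; the real ones come from the training config.
params = {
    "lr": 0.01,
    "momentum": 0.937,
    "nesterov": True,
    "weight_decay": 0.0005,
}

# A tiny stand-in model containing conv, batch norm, and linear layers.
model = torch.nn.Sequential(
    torch.nn.Conv2d(3, 8, 3, padding=1),
    torch.nn.BatchNorm2d(8),
    torch.nn.ReLU(),
    torch.nn.Flatten(),
    torch.nn.Linear(8 * 32 * 32, 10),
)

optimizer = TripleLRSGD(model, params).create_optimizer()

# Expect three param groups: batch norm weights, regular weights, biases.
print([len(g["params"]) for g in optimizer.param_groups])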
@@ -0,0 +1,79 @@
import math

import numpy as np
import torch


class TripleLRScheduler:
    def __init__(
        self,
        optimizer: torch.optim.Optimizer,
        params: dict,
        epochs: int,
        max_stepnum: int,
    ) -> None:
        """TripleLRScheduler is a custom learning rate scheduler that
        combines a warmup phase with cosine annealing.

        @type optimizer: torch.optim.Optimizer
        @param optimizer: The optimizer to be used
        @type params: dict
        @param params: The parameters to be used for the scheduler
        @type epochs: int
        @param epochs: The number of epochs to train for
        @type max_stepnum: int
        @param max_stepnum: The maximum number of steps per epoch
        """
        self.optimizer = optimizer
        self.params = params
        self.max_stepnum = max_stepnum
        # Warm up for warmup_epochs worth of steps, but at least 1000 steps.
        self.warmup_stepnum = max(
            round(self.params["warmup_epochs"] * self.max_stepnum), 1000
        )
        self.step = 0
        # Ratio of the final learning rate to the initial one.
        self.lrf = self.params["lre"] / self.optimizer.defaults["lr"]
        # Cosine factor decaying from 1 to self.lrf over the full training run.
        self.lf = (
            lambda x: ((1 - math.cos(x * math.pi / epochs)) / 2)
            * (self.lrf - 1)
            + 1
        )

    def create_scheduler(self):
        scheduler = torch.optim.lr_scheduler.LambdaLR(
            self.optimizer, lr_lambda=self.lf
        )
        return scheduler

    def update_learning_rate(self, current_epoch: int) -> None:
        """Update the learning rate based on the current epoch.

        @type current_epoch: int
        @param current_epoch: The current epoch
        """
        self.step = self.step % self.max_stepnum
        curr_step = self.step + self.max_stepnum * current_epoch

        if curr_step <= self.warmup_stepnum:
            for k, param in enumerate(self.optimizer.param_groups):
                # The third param group (biases) warms up from warmup_bias_lr;
                # the other groups ramp up linearly from zero.
                warmup_bias_lr = (
                    self.params["warmup_bias_lr"] if k == 2 else 0.0
                )
                param["lr"] = np.interp(
                    curr_step,
                    [0, self.warmup_stepnum],
                    [
                        warmup_bias_lr,
                        self.optimizer.defaults["lr"] * self.lf(current_epoch),
                    ],
                )
                if "momentum" in param:
                    self.optimizer.defaults["momentum"] = np.interp(
                        curr_step,
                        [0, self.warmup_stepnum],
                        [
                            self.params["warmup_momentum"],
                            self.optimizer.defaults["momentum"],
                        ],
                    )

        self.step += 1
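
For context, a rough sketch of how the two classes could be wired together in a training loop follows; it is not part of this diff. The hyperparameter values, epoch/step counts, and the loop body are illustrative assumptions, while the parameter keys (warmup_epochs, warmup_bias_lr, warmup_momentum, lre) are the ones the scheduler reads.

# Illustrative wiring of TripleLRSGD and TripleLRScheduler (placeholder values).
epochs, max_stepnum = 100, 500  # max_stepnum: batches per epoch

scheduler_params = {
    "warmup_epochs": 3,
    "warmup_bias_lr": 0.1,
    "warmup_momentum": 0.8,
    "lre": 0.0002,
}

lr_handler = TripleLRScheduler(optimizer, scheduler_params, epochs, max_stepnum)
scheduler = lr_handler.create_scheduler()

for epoch in range(epochs):
    for batch in range(max_stepnum):
        # Per-step warmup adjustment of lr (and momentum) early in training.
        lr_handler.update_learning_rate(epoch)
        # ... forward pass, loss.backward(), optimizer.step(), optimizer.zero_grad() ...
    # Cosine annealing is applied once per epoch through the LambdaLR scheduler.
    scheduler.step()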