sched_del.py
from torch.optim.lr_scheduler import _LRScheduler, CosineAnnealingLR


class DelayerScheduler(_LRScheduler):
    """Keeps a flat lr schedule for N epochs, then applies a scheduler.

    Args:
        optimizer (Optimizer): Wrapped optimizer.
        delay_epochs: number of epochs to keep the initial lr before
            applying the scheduler
        after_scheduler: scheduler to use after delay_epochs
            (e.g. CosineAnnealingLR)
    """

    def __init__(self, optimizer, delay_epochs, after_scheduler):
        self.delay_epochs = delay_epochs
        self.after_scheduler = after_scheduler
        self.finished = False
        super().__init__(optimizer)

    def get_lr(self):
        if self.last_epoch >= self.delay_epochs:
            if not self.finished:
                # Hand the initial lrs over to the wrapped scheduler the
                # first time the delay expires.
                self.after_scheduler.base_lrs = self.base_lrs
                self.finished = True
            return self.after_scheduler.get_lr()
        # During the delay, keep the initial learning rates flat.
        return self.base_lrs

    def step(self, epoch=None):
        if self.finished:
            # Delegate stepping to the wrapped scheduler, shifting the epoch
            # so it starts counting from zero once the delay is over.
            if epoch is None:
                self.after_scheduler.step(None)
            else:
                self.after_scheduler.step(epoch - self.delay_epochs)
        else:
            return super().step(epoch)


def DelayedCosineAnnealingLR(optimizer, delay_epochs, cosine_annealing_epochs):
    """Holds the initial lr flat for delay_epochs, then cosine-anneals it
    over cosine_annealing_epochs."""
    base_scheduler = CosineAnnealingLR(optimizer, cosine_annealing_epochs)
    return DelayerScheduler(optimizer, delay_epochs, base_scheduler)
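

# --- Usage sketch (illustrative, not part of the original file) ---
# A minimal example of wiring the scheduler into a training loop; the toy
# model, optimizer settings, and epoch counts below are assumptions chosen
# for demonstration only.
if __name__ == "__main__":
    import torch

    model = torch.nn.Linear(10, 2)                  # toy model (assumption)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

    # Hold lr at 0.1 for 5 epochs, then cosine-anneal it over 10 epochs.
    scheduler = DelayedCosineAnnealingLR(optimizer, delay_epochs=5,
                                         cosine_annealing_epochs=10)

    for epoch in range(15):
        optimizer.step()                            # training step goes here
        scheduler.step()
        print(epoch, optimizer.param_groups[0]["lr"])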