No fluff, straight to the code, warmup_lr_scheduler.py:
from torch.optim.lr_scheduler import _LRScheduler
import warnings
import math


class CosineAnnealingLRWarmup(_LRScheduler):
    def __init__(self, optimizer, T_max, eta_min=1.0e-5, last_epoch=-1, verbose=False,
                 warmup_steps=2, warmup_start_lr=1.0e-5):
        self.T_max = T_max
        self.eta_min = eta_min
        # The parent __init__ calls self.step() once, which lands in
        # _get_closed_form_lr() before the warmup attributes exist; the
        # hasattr() check below routes that first call to the plain cosine branch.
        super(CosineAnnealingLRWarmup, self).__init__(optimizer, last_epoch, verbose)
        self.warmup_steps = warmup_steps
        self.warmup_start_lr = warmup_start_lr
        if warmup_steps > 0:
            # Per-group factor for a geometric ramp from warmup_start_lr
            # up to that group's base_lr over warmup_steps epochs.
            self.base_warmup_factors = [
                (base_lr / warmup_start_lr) ** (1.0 / self.warmup_steps)
                for base_lr in self.base_lrs
            ]

    def get_lr(self):
        if not self._get_lr_called_within_step:
            warnings.warn("To get the last learning rate computed by the scheduler, "
                          "please use `get_last_lr()`.", UserWarning)
        return self._get_closed_form_lr()

    def _get_closed_form_lr(self):
        if hasattr(self, 'warmup_steps'):
            if self.last_epoch < self.warmup_steps:
                # Warmup phase: exponential ramp towards base_lr.
                return [self.warmup_start_lr * (warmup_factor ** self.last_epoch)
                        for warmup_factor in self.base_warmup_factors]
            else:
                # Cosine annealing, shifted so it starts at base_lr the
                # moment warmup ends and reaches eta_min at T_max.
                return [self.eta_min + (base_lr - self.eta_min) *
                        (1 + math.cos(math.pi * (self.last_epoch - self.warmup_steps) /
                                      (self.T_max - self.warmup_steps))) * 0.5
                        for base_lr in self.base_lrs]
        else:
            # Only reached during super().__init__: plain cosine annealing.
            return [self.eta_min + (base_lr - self.eta_min) *
                    (1 + math.cos(math.pi * self.last_epoch / self.T_max)) / 2
                    for base_lr in self.base_lrs]
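Two phases fall out of the closed form: for epoch t < warmup_steps the learning rate climbs geometrically, lr_t = warmup_start_lr * (base_lr / warmup_start_lr)^(t / warmup_steps), and from warmup_steps onward it is ordinary cosine annealing from base_lr down to eta_min over the remaining T_max - warmup_steps epochs. One side effect of setting the warmup attributes after super().__init__() is that the LR actually installed at construction (epoch 0) is base_lr, not warmup_start_lr. Below is a quick sanity check of the endpoints; the dummy parameter and SGD optimizer are my own scaffolding, not part of the original snippet:

import torch

param = torch.nn.Parameter(torch.zeros(1))          # dummy parameter, just to build an optimizer
optimizer = torch.optim.SGD([param], lr=0.1)        # base_lr = 0.1
scheduler = CosineAnnealingLRWarmup(optimizer, T_max=100, eta_min=1.0e-4,
                                    warmup_steps=10, warmup_start_lr=1.0e-5)

lrs = [scheduler.get_last_lr()[0]]                  # LR installed at epoch 0
for _ in range(100):
    scheduler.step()                                # PyTorch warns about stepping without
    lrs.append(scheduler.get_last_lr()[0])          # optimizer.step(); harmless for inspection

print(lrs[0])    # 0.1      (epoch-0 quirk: base_lr, see note above)
print(lrs[1])    # ~2.51e-5 (first warmup step above warmup_start_lr)
print(lrs[10])   # 0.1      (warmup has ramped up to base_lr)
print(lrs[100])  # 1e-4     (annealed down to eta_min at T_max)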
Usage: the learning rate set for each param group in the optimizer is taken as that group's base_lr, i.e. the peak value that warmup ramps up to:
lr_scheduler_warmup = CosineAnnealingLRWarmup(optimizer,
                                              T_max=100,
                                              eta_min=1.0e-4,
                                              last_epoch=-1,
                                              warmup_steps=10,
                                              warmup_start_lr=1.0e-5)
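In an actual training run, scheduler.step() is called once per epoch, after the optimizer has stepped through that epoch's batches. Here is a minimal self-contained skeleton; the linear model, MSE loss, and toy batches are stand-ins of my own, not from the original post:

import torch
import torch.nn as nn

model = nn.Linear(4, 1)                                       # placeholder model
criterion = nn.MSELoss()                                      # placeholder loss
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
lr_scheduler_warmup = CosineAnnealingLRWarmup(optimizer, T_max=100, eta_min=1.0e-4,
                                              warmup_steps=10, warmup_start_lr=1.0e-5)
train_loader = [(torch.randn(8, 4), torch.randn(8, 1))] * 5   # toy batches

for epoch in range(100):                                      # e.g. args.epochs
    for x, y in train_loader:
        optimizer.zero_grad()
        loss = criterion(model(x), y)
        loss.backward()
        optimizer.step()
    lr_scheduler_warmup.step()   # advance the schedule once per epoch, after optimizer.step()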
To just inspect the schedule, you can also step the scheduler on its own and print each epoch's learning rate:

for i in range(args.epochs):
    lr_scheduler_warmup.step()   # step() runs before the print, so line i shows the epoch-(i+1) LR
    print(i, 'lr: ', lr_scheduler_warmup.get_last_lr())
The printed log: