Implementing cosine annealing learning rate with warmup in PyTorch

狗啃月亮_Rachel · 2022-03-12

No fluff, straight to the code. warmup_lr_scheduler.py:

    from torch.optim.lr_scheduler import _LRScheduler
    import warnings
    import math


    class CosineAnnealingLRWarmup(_LRScheduler):
        """Cosine annealing schedule preceded by an exponential warmup phase."""

        def __init__(self, optimizer, T_max, eta_min=1.0e-5, last_epoch=-1, verbose=False,
                     warmup_steps=2, warmup_start_lr=1.0e-5):
            self.T_max = T_max
            self.eta_min = eta_min
            # The parent constructor already calls step()/get_lr() once, before the
            # warmup attributes below exist; the hasattr() check in
            # _get_closed_form_lr() guards that first call.
            super(CosineAnnealingLRWarmup, self).__init__(optimizer, last_epoch, verbose)
            self.warmup_steps = warmup_steps
            self.warmup_start_lr = warmup_start_lr
            if warmup_steps > 0:
                # Per-group factor that ramps the lr exponentially from
                # warmup_start_lr to the group's base lr over warmup_steps epochs.
                self.base_warmup_factors = [
                    (base_lr / warmup_start_lr) ** (1.0 / self.warmup_steps)
                    for base_lr in self.base_lrs
                ]

        def get_lr(self):
            if not self._get_lr_called_within_step:
                warnings.warn("To get the last learning rate computed by the scheduler, "
                              "please use `get_last_lr()`.", UserWarning)
            return self._get_closed_form_lr()

        def _get_closed_form_lr(self):
            if hasattr(self, 'warmup_steps'):
                if self.last_epoch < self.warmup_steps:
                    # Warmup phase: exponential growth from warmup_start_lr to base_lr.
                    return [self.warmup_start_lr * (warmup_factor ** self.last_epoch)
                            for warmup_factor in self.base_warmup_factors]
                else:
                    # Cosine annealing from base_lr down to eta_min over the
                    # remaining T_max - warmup_steps epochs.
                    return [self.eta_min + (base_lr - self.eta_min) *
                            (1 + math.cos(math.pi * (self.last_epoch - self.warmup_steps) /
                                          (self.T_max - self.warmup_steps))) * 0.5
                            for base_lr in self.base_lrs]
            else:
                # Only reached during the parent constructor's initial step():
                # plain cosine annealing without warmup.
                return [self.eta_min + (base_lr - self.eta_min) *
                        (1 + math.cos(math.pi * self.last_epoch / self.T_max)) / 2
                        for base_lr in self.base_lrs]
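
For reference, the closed form this class implements has two phases: during the first warmup_steps epochs the lr grows exponentially from warmup_start_lr toward each group's base lr, lr_t = warmup_start_lr * (base_lr / warmup_start_lr)^(t / warmup_steps); after that it follows the usual cosine annealing curve from base_lr down to eta_min over the remaining T_max - warmup_steps epochs. Below is a minimal, torch-free sketch of the same formula for sanity-checking values (the function name is my own, not part of the class above):

    import math

    # Standalone sketch of the same closed-form schedule, handy for checking
    # the values the scheduler above will produce for a single param group.
    def warmup_cosine_lr(epoch, base_lr, T_max=100, eta_min=1.0e-5,
                         warmup_steps=2, warmup_start_lr=1.0e-5):
        if warmup_steps > 0 and epoch < warmup_steps:
            # warmup: exponential ramp from warmup_start_lr to base_lr
            factor = (base_lr / warmup_start_lr) ** (1.0 / warmup_steps)
            return warmup_start_lr * factor ** epoch
        # cosine annealing: base_lr -> eta_min over the remaining epochs
        progress = (epoch - warmup_steps) / (T_max - warmup_steps)
        return eta_min + (base_lr - eta_min) * (1 + math.cos(math.pi * progress)) * 0.5

    for epoch in (0, 5, 10, 55, 100):
        print(epoch, warmup_cosine_lr(epoch, base_lr=1.0e-2, T_max=100,
                                      eta_min=1.0e-4, warmup_steps=10,
                                      warmup_start_lr=1.0e-5))

Note that the two phases meet smoothly: at epoch = warmup_steps the warmup expression evaluates to exactly base_lr, which is also where the cosine decay starts.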

Usage: the initial learning rate you set for each param group in the optimizer is the peak value that warmup ramps up to, and the starting point of the cosine decay:

    lr_scheduler_warmup = CosineAnnealingLRWarmup(optimizer,
                                                  T_max=100,
                                                  eta_min=1.0e-4,
                                                  last_epoch=-1,
                                                  warmup_steps=10,
                                                  warmup_start_lr=1.0e-5)
    # Dry run: step through the schedule and print the lr for each epoch.
    for i in range(args.epochs):
        lr_scheduler_warmup.step()
        print(i, 'lr: ', lr_scheduler_warmup.get_last_lr())

The printed log shows the learning rate climbing from warmup_start_lr up to the base learning rate over the first 10 epochs, then decaying along the cosine curve toward eta_min.
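
For an actual training run, here is a minimal self-contained sketch (toy linear model and random data of my own; only CosineAnnealingLRWarmup and its constructor arguments come from the code above). The scheduler is stepped once per epoch, after the optimizer has updated the weights:

    import torch
    from torch import nn
    from warmup_lr_scheduler import CosineAnnealingLRWarmup

    # Toy setup purely for illustration: a linear model fit to random data.
    model = nn.Linear(10, 1)
    loss_fn = nn.MSELoss()
    optimizer = torch.optim.SGD(model.parameters(), lr=1.0e-2)
    scheduler = CosineAnnealingLRWarmup(optimizer, T_max=100, eta_min=1.0e-4,
                                        warmup_steps=10, warmup_start_lr=1.0e-5)
    for epoch in range(100):
        inputs, targets = torch.randn(32, 10), torch.randn(32, 1)
        optimizer.zero_grad()
        loss = loss_fn(model(inputs), targets)
        loss.backward()
        optimizer.step()
        scheduler.step()   # advance the warmup/cosine schedule once per epoch
        print(epoch, 'lr:', scheduler.get_last_lr())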
