PyTorch has supported Stochastic Weight Averaging (SWA) since version 1.6.0.
That is, after the conventional training of an object detector with an initial learning rate and an ending learning rate, train it for an extra 12 epochs using cyclical learning rates, one cycle per epoch, and then average these 12 checkpoints as the final detection model.
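The averaging step of this recipe is an equal-weight mean of the 12 saved checkpoints. The sketch below assumes each checkpoint is a plain dict mapping parameter names to flat lists of floats (real detector checkpoints hold tensors), and the checkpoint values are made up for illustration. It is not the SWA Object Detection implementation, only the arithmetic of the averaging step.

```python
from typing import Dict, List

# Assumption for illustration: a checkpoint is {parameter_name: flat list of floats}.
Checkpoint = Dict[str, List[float]]


def average_checkpoints(checkpoints: List[Checkpoint]) -> Checkpoint:
    """Equal-weight average of checkpoints that share the same keys and shapes."""
    if not checkpoints:
        raise ValueError("need at least one checkpoint")
    n = len(checkpoints)
    averaged: Checkpoint = {}
    for name in checkpoints[0]:
        # Sum the same parameter across all checkpoints element by element, then divide by n.
        summed = [sum(values) for values in zip(*(ckpt[name] for ckpt in checkpoints))]
        averaged[name] = [s / n for s in summed]
    return averaged


# Hypothetical data: 12 checkpoints, one saved at the end of each extra epoch.
epoch_checkpoints = [{"w": [float(e), 2.0 * e], "b": [1.0]} for e in range(1, 13)]
final_model = average_checkpoints(epoch_checkpoints)
print(final_model)  # {'w': [6.5, 13.0], 'b': [1.0]}
```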
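SWA rests on the observation that averaging several models taken along the SGD optimization trajectory yields a final model that generalizes better, as illustrated in the figure below.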
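SGD tends to converge to flat regions of the loss surface. Because the weight space is high-dimensional, most of the volume of such a flat region lies near its boundary, so SGD usually only reaches the boundary of these regions. By averaging the weights of multiple SGD iterates, SWA can reach the center of the flat region.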
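A one-dimensional toy can make the averaging idea concrete, although it does not capture the high-dimensional boundary effect described above. Assuming plain gradient descent on f(w) = 0.5·w² with a deliberately large constant learning rate (1.8, chosen only for illustration), the iterates bounce from one side of the minimum to the other. Their running mean, computed with the same equal-weight update that `AveragedModel` applies by default, lands closer to the minimum at w = 0 than the last iterate does.

```python
def swa_running_average(iterates):
    """Equal-weight running mean of the iterates, folded in one at a time."""
    avg = None
    for n, w in enumerate(iterates):
        # avg_{n+1} = avg_n + (w - avg_n) / (n + 1), i.e. the mean of the first n + 1 iterates
        avg = w if avg is None else avg + (w - avg) / (n + 1)
    return avg


def sgd_on_quadratic(w0, lr, steps, curvature=1.0):
    """Gradient descent on f(w) = 0.5 * curvature * w**2, whose minimum is at w = 0."""
    w, trajectory = w0, []
    for _ in range(steps):
        w = w - lr * curvature * w  # the gradient of f is curvature * w
        trajectory.append(w)
    return trajectory


# With lr * curvature = 1.8 each step multiplies w by -0.8, so the iterates oscillate around 0.
traj = sgd_on_quadratic(w0=1.0, lr=1.8, steps=10)
last = traj[-1]
averaged = swa_running_average(traj)
print(f"last iterate: {last:.4f}, averaged: {averaged:.4f}")  # last iterate: 0.1074, averaged: -0.0397
```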
Object Detection
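SWA Object Detection tried different numbers of extra epochs on object detection, combined with either a fixed learning rate or a cyclical cosine-annealing learning rate, and found that 12 epochs with a cyclical cosine-annealing learning rate worked best.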
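A cyclical cosine-annealing schedule restarts at its upper bound at the start of every cycle and decays toward its lower bound by the end of it. The sketch below assumes one cycle per epoch, as in the recipe above; the bounds `LR_MAX` and `LR_MIN` and the iteration counts are placeholders, not the values used in the paper.

```python
import math


def cyclic_cosine_lr(iteration, iters_per_epoch, lr_max, lr_min):
    """Cosine-annealed learning rate that restarts at lr_max at the start of every epoch."""
    # Position inside the current cycle (one cycle = one epoch), in [0, 1).
    t = (iteration % iters_per_epoch) / iters_per_epoch
    return lr_min + 0.5 * (lr_max - lr_min) * (1.0 + math.cos(math.pi * t))


# Hypothetical values: 12 extra epochs of 1000 iterations each.
ITERS_PER_EPOCH, EXTRA_EPOCHS = 1000, 12
LR_MAX, LR_MIN = 0.01, 0.0001
schedule = [cyclic_cosine_lr(i, ITERS_PER_EPOCH, LR_MAX, LR_MIN)
            for i in range(ITERS_PER_EPOCH * EXTRA_EPOCHS)]
# A checkpoint would be saved after the last iteration of each epoch, where the lr is close to LR_MIN.
```

In PyTorch, a comparable schedule can be obtained with `torch.optim.lr_scheduler.CosineAnnealingWarmRestarts`, setting `T_0` to the number of iterations per epoch and calling `step()` once per iteration.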
Example PyTorch code:
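```python
import torch
from torch.optim.swa_utils import AveragedModel, SWALR, update_bn

loader, optimizer, model, loss_fn = ...  # placeholders: your data loader, optimizer, model and loss

swa_model = AveragedModel(model)  # keeps a running average of model's weights
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=300)
swa_start = 160  # epoch after which weights start being averaged
swa_scheduler = SWALR(optimizer, swa_lr=0.05)  # learning-rate schedule for the SWA phase

for epoch in range(300):
    for input, target in loader:
        optimizer.zero_grad()
        loss_fn(model(input), target).backward()
        optimizer.step()
    if epoch > swa_start:
        # SWA phase: fold the current weights into the average once per epoch
        swa_model.update_parameters(model)
        swa_scheduler.step()
    else:
        scheduler.step()

# Recompute BatchNorm statistics for the averaged weights with one pass over the data
update_bn(loader, swa_model)
# Use swa_model to make predictions on test data (test_input is a placeholder)
preds = swa_model(test_input)
```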
References
1. Stochastic Weight Averaging (blog)
2. Stochastic Weight Averaging in PyTorch
3. Stochastic Weight Averaging (docs)
4. SWA Object Detection
5. Averaging Weights Leads to Wider Optima and Better Generalization