对于大的数据集,小模型往往很难获得较高的精度;知识蒸馏则是使用大模型指导小模型,使小模型学到大模型包含的知识,从而得到更高的精度。原理网上很多,主要是给softmax加温度实现,这里不赘述,直接讲简单实现。
1. 训练教师模型
教师模型用于仅用于指导小模型,不参与部署,因此在条件允许的情况下,可以选择很大的模型,举个例子:
from torch import nn
from torchvision import models
from torchsummary import summary
# 使用预训练的resnet152, 加载预训练权重
model = models.resnet152(pretrained=True)
# 修改全连接层,改为自己的预测类别数
model.fc = nn.Linear(model.fc.in_features, 10)
summary(model, (3, 224, 224), device="cpu")
然后使用上述模型在自己的数据集上进行微调即可。
2. 训练学生模型
此时的学生模型可以选择很小的模型,使用知识蒸馏比常规训练主要多了以下两个步骤:
2.1 定义损失函数
criteon = nn.CrossEntropyLoss()
t_loss = nn.KLDivLoss(reduction="batchmean")
2.2 加载教师模型
if distilled:
t_net = models.resnet152().to(device)
try:
t_net.load_state_dict(torch.load("./model/resnet152.pth"))
print("successful")
except:
print("failed")
2.3 计算loss
# T是温度,用于削弱softmax,可以自己调
if distilled:
loss = 0.7 * criteon(logits, target.to(device)) + 0.3 * t_loss(F.log_softmax(logits/T, dim=1),
F.softmax(t_logits/T, dim=1))
else:
loss = criteon(logits, target.to(device))
虽然看原理觉得很有趣,但用mnist、cifar10、cifar100测试并没有感觉有什么用啊。