0
点赞
收藏
分享

微信扫一扫

paddlepaddle 13 RIFLE:Backpropagation in Depth for Deep Transfer Learning through Re-Initia

云朵里的佛光 2022-02-12 阅读 46

预训练模型作为当下迁移学习的一种主流方法,让“从零开始训练一个新模型”的时代一去不复返。这类在大型数据集上训练好的模型进行微调后,便能够较好地适配类似的新目标任务,可以极大地节约时间和成本。不过,天下没有免费的午餐,这种方法看似“一劳永逸”,也需要付出一定的代价,其中的一大问题便是,由于反向传播带给深层卷积层的更新较小,微调得到的模型往往被“吸引”在预训练模型附近,无法得到充分更新

对此,百度的一篇 ICML 2020工作《RIFLE: Backpropagation in Depth for Deep Transfer Learning through Re-Initializing the Fully-connected LayEr》,提出了一种简单有效的策略RIFLE,通过周期性的重新初始化全连接层,给深层的特征提取网络带来有意义的更新,提升低层次特征的学习,从而显著提升迁移学习的效果。其项目开源地址为https://github.com/AgentMaker/RIFLE_Module/

虽然百度官方开源了应用于迁移学习的RIFLE,但是其对内置模型的迁移学习是不友好的。因为RIFLE的使用需要将微调层进行多次初始化,也就是是训练时的对象必须一直都是paddle.nn.Layer。然而,训练过程中保持的模型是无法再次被加载为paddle.nn.Layer对象的,只能被加载为TranslatedLayer对象,这样的对象是无法获取其输出layer,并用RIFLE重新初始化参数的。

RIFLE其核心思想就是对全连接层进行周期性的初始化,需要克服的是模型参数保存与加载,其次才是全连接层的周期性初始化。

1、模型参数的保存与加载

 模型的初始化:模型的初始化有多种方法,我们可以通过以下方式多次对模型进行初始化,然后加载权重。

1、内置的模型(分类模型)

#===================构造模型和相关训练参数==========================
print('飞桨框架内置模型:', paddle.vision.models.__all__)
model=paddle.vision.resnet18()

2、hub内置的模型(分类模型)

import paddle
from paddle import hub
# PaddleClas
models = hub.list('PaddlePaddle/PaddleClas:develop', source='github', force_reload=False)
print("PaddleClas models:",models)
 
#model = hub.load('PaddlePaddle/PaddleClas:develop', 'densenet121', source='local', force_reload=False, pretrained=True)
model = hub.load('PaddlePaddle/PaddleClas:develop', 'densenet121', source='github' , pretrained=True)

使用以下代码可以获取模型输出层的name 

fc_param=[name for name,atom_layer in model.named_sublayers()]
print(fc_param[-1])#model.parameters()

然后,根据name重新修改模型的输出层

resnet18.fc=nn.Sequential(
    nn.Linear(512, 128),
    nn.ReLU(),
    nn.Dropout(0.2),
    nn.Linear(128, 17)
)

参数的保存: 

paddle.save(model.state_dict(), "m.parameters", protocol=4)

参数的加载:加载参数前其实有一个前提就是模型的初始化

state_dict = paddle.load("m.parameters")
model.set_state_dict(state_dict)

2、全连接层的周期性初始化

周期性初始化全连接层,就是就是替换掉原理的全连接层,代码很简单。下列代码就是每10次重新初始化一次全连接层。

for i in range(Epochs):
    if i%10==0:
        resnet18.fc=nn.Sequential(
            nn.Linear(512, 128),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(128, 17)
        )
举报

相关推荐

0 条评论