0
点赞
收藏
分享

微信扫一扫

Pytorch只更新预训练模型的部分参数


Pytorch只更新预训练模型的部分参数

假设有一个训练好的模型,并且我们只想微调部分参数。
比如,这里我们只想更新最后一部分的参数:
可以看到,这里的模块叫b4。

Pytorch只更新预训练模型的部分参数_python


我们可以直接通过获取模块的名字来进行更新:

方法1

def update(model,flag=True):
    for name,p in model.named_parameters():
        if "b4" in name:
            print("update only",name)
            p.requires_grad = flag

也就是说 只要模块名字包含b4 就会让他跟新网络。
对应的optimizer 的设置如下:

optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=lr_)

然后直接训练就行。

方法二

也可以直接 把这些符合条件的 parameters 加入 list中,并传给 optimizer

def update(model,flag=True):
    paras = []
    for name,p in model.named_parameters():
        if "b4" in name:
            print("update only",name)
            p.requires_grad = flag
            paras.append(p)
    return paras

optimizer = torch.optim.Adam(paras, lr=lr_)

直接训练就行。##


举报

相关推荐

0 条评论