0
点赞
收藏
分享

微信扫一扫

加载网络权重,去除全连接层的权重

我是小瘦子哟 2022-04-14 阅读 108
python

仅作为记录,大佬请跳过。

感谢老师的示范。

fc_keys = [k for k in state_dict.keys() if "fc" in k]
for k in fc_keys:
    del state_dict[k]

查看设计的网络加载的网络权重的有没有不同的层

def load_from_pretrained(self, ckpt_path):
    print(f"==============> Loading weight {ckpt_path} for fine-tuning......")
    ckpt = torch.load(ckpt_path, map_location='cpu')
    state_dict = ckpt

    fc_keys = [k for k in state_dict.keys() if "fc" in k]
    for k in fc_keys:
        del state_dict[k]

    from pprint import pprint
    missing_keys, unexpected_keys = self.load_state_dict(state_dict, strict=False)
    print('missing_keys = ')
    pprint(missing_keys)
    print('unexpected_keys = ')
    pprint(unexpected_keys)
    print(f"=> loaded successfully '{ckpt_path}'")
    print('ok')

其中,self指设计的网络

举报

相关推荐

0 条评论