RuntimeError: Error(s) in loading state_dict: Missing key(s) in state_dict, Unexpected key(s)

残北 2022-04-13, 119 views
1. Error message when loading the model in PyTorch:
RuntimeError: Error(s) in loading state_dict for SimpleDLA:
	Missing key(s) in state_dict: "base.0.weight", "base.1.weight", "base.1.bias", "base.1.running_mean", "base.1.running_var", "layer1.0.weight", "layer1.1.weight", "layer1.1.bias", "layer1.1.running_mean", "layer1.1.running_var", "layer2.0.weight", "layer2.1.weight", "layer2.1.bias", "layer2.1.running_mean", "layer2.1.running_var", "layer3.root.conv.weight", "layer3.root.bn.weight", "layer3.root.bn.bias", "layer3.root.bn.running_mean", "layer3.root.bn.running_var", "layer3.left_tree.conv1.weight", "layer3.left_tree.bn1.weight", "layer3.left_tree.bn1.bias", "layer3.left_tree.bn1.running_mean", "layer3.left_tree.bn1.running_var", "layer3.left_tree.conv2.weight", "layer3.left_tree.bn2.weight", "layer3.left_tree.bn2.bias", "layer3.left_tree.bn2.running_mean", "layer3.left_tree.bn2.running_var", "layer3.left_tree.shortcut.0.weight", "layer3.left_tree.shortcut.1.weight", "layer3.left_tree.shortcut.1.bias", "layer3.left_tree.shortcut.1.running_mean", "layer3.left_tree.shortcut.1.running_var", "layer3.right_tree.conv1.weight", "layer3.right_tree.bn1.weight", "layer3.right_tree.bn1.bias", "layer3.right_tree.bn1.running_mean", "layer3.right_tree.bn1.running_var", "layer3.right_tree.conv2.weight", "layer3.right_tree.bn2.weight", "layer3.right_tree.bn2.bias", "layer3.right_tree.bn2.running_mean", "layer3.right_tree.bn2.running_var", "layer4.root.conv.weight", "layer4.root.bn.weight", "layer4.root.bn.bias", "layer4.root.bn.running_mean", "layer4.root.bn.running_var", "layer4.left_tree.root.conv.weight", "layer4.left_tree.root.bn.weight", "layer4.left_tree.root.bn.bias", "layer4.left_tree.root.bn.running_mean", "layer4.left_tree.root.bn.running_var", "layer4.left_tree.left_tree.conv1.weight", "layer4.left_tree.left_tree.bn1.weight", "layer4.left_tree.left_tree.bn1.bias", "layer4.left_tree.left_tree.bn1.running_mean", "layer4.left_tree.left_tree.bn1.running_var", "layer4.left_tree.left_tree.conv2.weight", "layer4.left_tree.left_tree.bn2.weight", 
"layer4.left_tree.left_tree.bn2.bias", "layer4.left_tree.left_tree.bn2.running_mean", "layer4.left_tree.left_tree.bn2.running_var", "layer4.left_tree.left_tree.shortcut.0.weight", "layer4.left_tree.left_tree.shortcut.1.weight", "layer4.left_tree.left_tree.shortcut.1.bias", "layer4.left_tree.left_tree.shortcut.1.running_mean", "layer4.left_tree.left_tree.shortcut.1.running_var", "layer4.left_tree.right_tree.conv1.weight", "layer4.left_tree.right_tree.bn1.weight", "layer4.left_tree.right_tree.bn1.bias", "layer4.left_tree.right_tree.bn1.running_mean", "layer4.left_tree.right_tree.bn1.running_var", "layer4.left_tree.right_tree.conv2.weight", "layer4.left_tree.right_tree.bn2.weight", "layer4.left_tree.right_tree.bn2.bias", "layer4.left_tree.right_tree.bn2.running_mean", "layer4.left_tree.right_tree.bn2.running_var", "layer4.right_tree.root.conv.weight", "layer4.right_tree.root.bn.weight", "layer4.right_tree.root.bn.bias", "layer4.right_tree.root.bn.running_mean", "layer4.right_tree.root.bn.running_var", "layer4.right_tree.left_tree.conv1.weight", "layer4.right_tree.left_tree.bn1.weight", "layer4.right_tree.left_tree.bn1.bias", "layer4.right_tree.left_tree.bn1.running_mean", "layer4.right_tree.left_tree.bn1.running_var", "layer4.right_tree.left_tree.conv2.weight", "layer4.right_tree.left_tree.bn2.weight", "layer4.right_tree.left_tree.bn2.bias", "layer4.right_tree.left_tree.bn2.running_mean", "layer4.right_tree.left_tree.bn2.running_var", "layer4.right_tree.right_tree.conv1.weight", "layer4.right_tree.right_tree.bn1.weight", "layer4.right_tree.right_tree.bn1.bias", "layer4.right_tree.right_tree.bn1.running_mean", "layer4.right_tree.right_tree.bn1.running_var", "layer4.right_tree.right_tree.conv2.weight", "layer4.right_tree.right_tree.bn2.weight", "layer4.right_tree.right_tree.bn2.bias", "layer4.right_tree.right_tree.bn2.running_mean", "layer4.right_tree.right_tree.bn2.running_var", "layer5.root.conv.weight", "layer5.root.bn.weight", "layer5.root.bn.bias", 
"layer5.root.bn.running_mean", "layer5.root.bn.running_var", "layer5.left_tree.root.conv.weight", "layer5.left_tree.root.bn.weight", "layer5.left_tree.root.bn.bias", "layer5.left_tree.root.bn.running_mean", "layer5.left_tree.root.bn.running_var", "layer5.left_tree.left_tree.conv1.weight", "layer5.left_tree.left_tree.bn1.weight", "layer5.left_tree.left_tree.bn1.bias", "layer5.left_tree.left_tree.bn1.running_mean", "layer5.left_tree.left_tree.bn1.running_var", "layer5.left_tree.left_tree.conv2.weight", "layer5.left_tree.left_tree.bn2.weight", "layer5.left_tree.left_tree.bn2.bias", "layer5.left_tree.left_tree.bn2.running_mean", "layer5.left_tree.left_tree.bn2.running_var", "layer5.left_tree.left_tree.shortcut.0.weight", "layer5.left_tree.left_tree.shortcut.1.weight", "layer5.left_tree.left_tree.shortcut.1.bias", "layer5.left_tree.left_tree.shortcut.1.running_mean", "layer5.left_tree.left_tree.shortcut.1.running_var", "layer5.left_tree.right_tree.conv1.weight", "layer5.left_tree.right_tree.bn1.weight", "layer5.left_tree.right_tree.bn1.bias", "layer5.left_tree.right_tree.bn1.running_mean", "layer5.left_tree.right_tree.bn1.running_var", "layer5.left_tree.right_tree.conv2.weight", "layer5.left_tree.right_tree.bn2.weight", "layer5.left_tree.right_tree.bn2.bias", "layer5.left_tree.right_tree.bn2.running_mean", "layer5.left_tree.right_tree.bn2.running_var", "layer5.right_tree.root.conv.weight", "layer5.right_tree.root.bn.weight", "layer5.right_tree.root.bn.bias", "layer5.right_tree.root.bn.running_mean", "layer5.right_tree.root.bn.running_var", "layer5.right_tree.left_tree.conv1.weight", "layer5.right_tree.left_tree.bn1.weight", "layer5.right_tree.left_tree.bn1.bias", "layer5.right_tree.left_tree.bn1.running_mean", "layer5.right_tree.left_tree.bn1.running_var", "layer5.right_tree.left_tree.conv2.weight", "layer5.right_tree.left_tree.bn2.weight", "layer5.right_tree.left_tree.bn2.bias", "layer5.right_tree.left_tree.bn2.running_mean", 
"layer5.right_tree.left_tree.bn2.running_var", "layer5.right_tree.right_tree.conv1.weight", "layer5.right_tree.right_tree.bn1.weight", "layer5.right_tree.right_tree.bn1.bias", "layer5.right_tree.right_tree.bn1.running_mean", "layer5.right_tree.right_tree.bn1.running_var", "layer5.right_tree.right_tree.conv2.weight", "layer5.right_tree.right_tree.bn2.weight", "layer5.right_tree.right_tree.bn2.bias", "layer5.right_tree.right_tree.bn2.running_mean", "layer5.right_tree.right_tree.bn2.running_var", "layer6.root.conv.weight", "layer6.root.bn.weight", "layer6.root.bn.bias", "layer6.root.bn.running_mean", "layer6.root.bn.running_var", "layer6.left_tree.conv1.weight", "layer6.left_tree.bn1.weight", "layer6.left_tree.bn1.bias", "layer6.left_tree.bn1.running_mean", "layer6.left_tree.bn1.running_var", "layer6.left_tree.conv2.weight", "layer6.left_tree.bn2.weight", "layer6.left_tree.bn2.bias", "layer6.left_tree.bn2.running_mean", "layer6.left_tree.bn2.running_var", "layer6.left_tree.shortcut.0.weight", "layer6.left_tree.shortcut.1.weight", "layer6.left_tree.shortcut.1.bias", "layer6.left_tree.shortcut.1.running_mean", "layer6.left_tree.shortcut.1.running_var", "layer6.right_tree.conv1.weight", "layer6.right_tree.bn1.weight", "layer6.right_tree.bn1.bias", "layer6.right_tree.bn1.running_mean", "layer6.right_tree.bn1.running_var", "layer6.right_tree.conv2.weight", "layer6.right_tree.bn2.weight", "layer6.right_tree.bn2.bias", "layer6.right_tree.bn2.running_mean", "layer6.right_tree.bn2.running_var", "linear.weight", "linear.bias". 
	Unexpected key(s) in state_dict: "module.base.0.weight", "module.base.1.weight", "module.base.1.bias", "module.base.1.running_mean", "module.base.1.running_var", "module.base.1

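2. What the error means:
While loading the model, the state_dict is missing keys that the model expects, such as "base.0.weight", "base.1.weight", and "base.1.bias", and it contains keys that the model does not expect, such as "module.base.0.weight".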
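To see exactly which keys disagree before changing anything, the two key sets can be compared directly. The sketch below is only an illustration: `diff_keys` is a hypothetical helper name, and the short key lists stand in for `model.state_dict().keys()` and `checkpoint.keys()`.

```python
def diff_keys(model_keys, checkpoint_keys):
    """Return (missing, unexpected) key lists, in their original order."""
    model_keys = list(model_keys)
    checkpoint_keys = list(checkpoint_keys)
    model_set, checkpoint_set = set(model_keys), set(checkpoint_keys)
    # Missing: the model expects the key, but the checkpoint does not have it.
    missing = [k for k in model_keys if k not in checkpoint_set]
    # Unexpected: the checkpoint has the key, but the model does not expect it.
    unexpected = [k for k in checkpoint_keys if k not in model_set]
    return missing, unexpected


# Stand-ins for model.state_dict().keys() and checkpoint.keys().
model_keys = ["base.0.weight", "base.1.weight", "base.1.bias"]
checkpoint_keys = [
    "module.base.0.weight",
    "module.base.1.weight",
    "module.base.1.bias",
]

missing, unexpected = diff_keys(model_keys, checkpoint_keys)
print("missing:", missing)
print("unexpected:", unexpected)
```

If every unexpected key is a missing key with the same prefix in front, the weights themselves are fine and only the key names need to be repaired.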
3. Cause:
The model was trained on multiple GPUs in parallel, using statements like these:

    model = torch.nn.DataParallel(model)
    cudnn.benchmark = True
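DataParallel stores the original model in its `module` attribute, so every key in the saved parameter dictionary starts with the string "module.".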
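The prefix is easy to reproduce on a toy model. The following sketch assumes PyTorch is installed; the small `nn.Sequential` network is a hypothetical stand-in for the real model. It also shows one way to avoid the problem at save time, namely taking the state_dict from `wrapped.module` rather than from the wrapper.

```python
import torch.nn as nn

# Toy stand-in for the real network (hypothetical, for illustration only).
net = nn.Sequential(nn.Conv2d(3, 4, kernel_size=3), nn.BatchNorm2d(4))
plain_keys = list(net.state_dict().keys())

# DataParallel keeps the original network in its `module` attribute,
# so every key in the wrapper's state_dict is prefixed with "module.".
wrapped = nn.DataParallel(net)
wrapped_keys = list(wrapped.state_dict().keys())

# Going through `.module` yields the unprefixed keys again.
unwrapped_keys = list(wrapped.module.state_dict().keys())

print(plain_keys[:2])
print(wrapped_keys[:2])
print(unwrapped_keys[:2])
```

Saving `wrapped.module.state_dict()` during training would produce a checkpoint that loads into an unwrapped model without any key renaming.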

4. Solution: remove the unexpected keys and add the missing ones. In other words, strip the "module." prefix from every key in the dict[key, value] pairs.
The code is as follows:

import torch

model_cifar = SimpleDLA()  # SimpleDLA is the model class that was trained
checkpoint = torch.load("pytorch_model.pth", map_location="cpu")['net']
print("key:", checkpoint.keys())
for key in list(checkpoint.keys()):
    if key.startswith('module.'):
        # Strip the "module." prefix from the key and drop the old entry
        checkpoint[key[len('module.'):]] = checkpoint.pop(key)
print("key2:", checkpoint.keys())
model_cifar.load_state_dict(checkpoint)
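The same fix can be wrapped in a small reusable function that leaves the loaded checkpoint untouched and returns a new dictionary. This is a sketch rather than part of the original fix: `strip_prefix` is a hypothetical name, and plain floats stand in for tensors so that the snippet runs without a checkpoint file.

```python
from collections import OrderedDict


def strip_prefix(state_dict, prefix="module."):
    """Return a copy of state_dict with `prefix` removed from matching keys."""
    cleaned = OrderedDict()
    for key, value in state_dict.items():
        if key.startswith(prefix):
            key = key[len(prefix):]
        cleaned[key] = value
    return cleaned


# Floats stand in for the tensors a real checkpoint would hold.
checkpoint = OrderedDict([
    ("module.base.0.weight", 0.1),
    ("module.base.1.bias", 0.2),
    ("linear.weight", 0.3),  # already unprefixed, kept as-is
])
cleaned = strip_prefix(checkpoint)
print(list(cleaned.keys()))
```

With a real checkpoint the call would look like `model_cifar.load_state_dict(strip_prefix(checkpoint))`. Recent PyTorch releases also include `torch.nn.modules.utils.consume_prefix_in_state_dict_if_present(state_dict, "module.")`, which strips the prefix in place, if the installed version provides it.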
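5. Note: after training, only the model's parameters were saved (not the whole model object), in the format below. That is why the loading code above uses the "net" key, state_dict['net'], to get the trained parameters.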
    state_dict = {
        'net': net.state_dict(),
        'acc': acc,
        'epoch': epoch,
    }
    torch.save(state_dict, 'pytorch_model.pth')
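Some checkpoints wrap the weights in a dictionary like this and others store the bare state_dict, so the loading code can check for the wrapper first. A minimal sketch, assuming the wrapper key is "net" as in this post; `extract_weights` is a hypothetical helper, and the nested dict with made-up values stands in for the result of `torch.load`.

```python
def extract_weights(checkpoint, key="net"):
    """Return checkpoint[key] if the weights are wrapped, else the checkpoint."""
    if isinstance(checkpoint, dict) and key in checkpoint:
        return checkpoint[key]
    return checkpoint


# Stand-ins for torch.load("pytorch_model.pth", map_location="cpu");
# the numbers are made up for illustration.
wrapped = {"net": {"module.linear.weight": 0.5}, "acc": 90.0, "epoch": 10}
bare = {"module.linear.weight": 0.5}

print(extract_weights(wrapped))
print(extract_weights(bare))
```

Either way the result is the parameter dictionary itself, which can then have its "module." prefix removed before being passed to `load_state_dict`.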
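6. Also, the posts on CSDN about this error all give the same advice: change load_state_dict(state_dict) to model.load_state_dict(state_dict, False). This does not solve the problem at all and is misleading. With strict=False the mismatched keys are simply ignored, so the error disappears but none of the trained weights are loaded.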
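The effect of strict=False can be checked on a toy layer. The sketch below assumes PyTorch is installed and uses a hypothetical `nn.Linear` in place of SimpleDLA. Loading a "module."-prefixed state_dict with strict=False raises no error, but the returned report lists every key as missing or unexpected, and the layer keeps its initial weights.

```python
import torch
import torch.nn as nn

model = nn.Linear(2, 1)  # toy stand-in for the real model
before = model.weight.detach().clone()

# A checkpoint whose keys carry the DataParallel prefix.
prefixed = {
    "module.weight": torch.full((1, 2), 7.0),
    "module.bias": torch.full((1,), 7.0),
}

# strict=False suppresses the RuntimeError and returns a report instead.
result = model.load_state_dict(prefixed, strict=False)
print("missing:", result.missing_keys)
print("unexpected:", result.unexpected_keys)

# No key matched, so the weights are still the initial ones.
unchanged = torch.equal(model.weight, before)
print("weights unchanged:", unchanged)
```

Checking the returned `missing_keys` and `unexpected_keys` after any strict=False load is a cheap way to catch this kind of silent failure.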
7. In the end, I found the answer on the PyTorch forum:
    https://discuss.pytorch.org/t/missing-keys-unexpected-keys-in-state-dict-when-loading-self-trained-model/22379/13