1. PyTorch error message when loading the model:
RuntimeError: Error(s) in loading state_dict for SimpleDLA:
Missing key(s) in state_dict: "base.0.weight", "base.1.weight", "base.1.bias", "base.1.running_mean", "base.1.running_var", "layer1.0.weight", "layer1.1.weight", "layer1.1.bias", "layer1.1.running_mean", "layer1.1.running_var", "layer2.0.weight", "layer2.1.weight", "layer2.1.bias", "layer2.1.running_mean", "layer2.1.running_var", "layer3.root.conv.weight", "layer3.root.bn.weight", "layer3.root.bn.bias", "layer3.root.bn.running_mean", "layer3.root.bn.running_var", "layer3.left_tree.conv1.weight", "layer3.left_tree.bn1.weight", "layer3.left_tree.bn1.bias", "layer3.left_tree.bn1.running_mean", "layer3.left_tree.bn1.running_var", "layer3.left_tree.conv2.weight", "layer3.left_tree.bn2.weight", "layer3.left_tree.bn2.bias", "layer3.left_tree.bn2.running_mean", "layer3.left_tree.bn2.running_var", "layer3.left_tree.shortcut.0.weight", "layer3.left_tree.shortcut.1.weight", "layer3.left_tree.shortcut.1.bias", "layer3.left_tree.shortcut.1.running_mean", "layer3.left_tree.shortcut.1.running_var", "layer3.right_tree.conv1.weight", "layer3.right_tree.bn1.weight", "layer3.right_tree.bn1.bias", "layer3.right_tree.bn1.running_mean", "layer3.right_tree.bn1.running_var", "layer3.right_tree.conv2.weight", "layer3.right_tree.bn2.weight", "layer3.right_tree.bn2.bias", "layer3.right_tree.bn2.running_mean", "layer3.right_tree.bn2.running_var", "layer4.root.conv.weight", "layer4.root.bn.weight", "layer4.root.bn.bias", "layer4.root.bn.running_mean", "layer4.root.bn.running_var", "layer4.left_tree.root.conv.weight", "layer4.left_tree.root.bn.weight", "layer4.left_tree.root.bn.bias", "layer4.left_tree.root.bn.running_mean", "layer4.left_tree.root.bn.running_var", "layer4.left_tree.left_tree.conv1.weight", "layer4.left_tree.left_tree.bn1.weight", "layer4.left_tree.left_tree.bn1.bias", "layer4.left_tree.left_tree.bn1.running_mean", "layer4.left_tree.left_tree.bn1.running_var", "layer4.left_tree.left_tree.conv2.weight", "layer4.left_tree.left_tree.bn2.weight", 
"layer4.left_tree.left_tree.bn2.bias", "layer4.left_tree.left_tree.bn2.running_mean", "layer4.left_tree.left_tree.bn2.running_var", "layer4.left_tree.left_tree.shortcut.0.weight", "layer4.left_tree.left_tree.shortcut.1.weight", "layer4.left_tree.left_tree.shortcut.1.bias", "layer4.left_tree.left_tree.shortcut.1.running_mean", "layer4.left_tree.left_tree.shortcut.1.running_var", "layer4.left_tree.right_tree.conv1.weight", "layer4.left_tree.right_tree.bn1.weight", "layer4.left_tree.right_tree.bn1.bias", "layer4.left_tree.right_tree.bn1.running_mean", "layer4.left_tree.right_tree.bn1.running_var", "layer4.left_tree.right_tree.conv2.weight", "layer4.left_tree.right_tree.bn2.weight", "layer4.left_tree.right_tree.bn2.bias", "layer4.left_tree.right_tree.bn2.running_mean", "layer4.left_tree.right_tree.bn2.running_var", "layer4.right_tree.root.conv.weight", "layer4.right_tree.root.bn.weight", "layer4.right_tree.root.bn.bias", "layer4.right_tree.root.bn.running_mean", "layer4.right_tree.root.bn.running_var", "layer4.right_tree.left_tree.conv1.weight", "layer4.right_tree.left_tree.bn1.weight", "layer4.right_tree.left_tree.bn1.bias", "layer4.right_tree.left_tree.bn1.running_mean", "layer4.right_tree.left_tree.bn1.running_var", "layer4.right_tree.left_tree.conv2.weight", "layer4.right_tree.left_tree.bn2.weight", "layer4.right_tree.left_tree.bn2.bias", "layer4.right_tree.left_tree.bn2.running_mean", "layer4.right_tree.left_tree.bn2.running_var", "layer4.right_tree.right_tree.conv1.weight", "layer4.right_tree.right_tree.bn1.weight", "layer4.right_tree.right_tree.bn1.bias", "layer4.right_tree.right_tree.bn1.running_mean", "layer4.right_tree.right_tree.bn1.running_var", "layer4.right_tree.right_tree.conv2.weight", "layer4.right_tree.right_tree.bn2.weight", "layer4.right_tree.right_tree.bn2.bias", "layer4.right_tree.right_tree.bn2.running_mean", "layer4.right_tree.right_tree.bn2.running_var", "layer5.root.conv.weight", "layer5.root.bn.weight", "layer5.root.bn.bias", 
"layer5.root.bn.running_mean", "layer5.root.bn.running_var", "layer5.left_tree.root.conv.weight", "layer5.left_tree.root.bn.weight", "layer5.left_tree.root.bn.bias", "layer5.left_tree.root.bn.running_mean", "layer5.left_tree.root.bn.running_var", "layer5.left_tree.left_tree.conv1.weight", "layer5.left_tree.left_tree.bn1.weight", "layer5.left_tree.left_tree.bn1.bias", "layer5.left_tree.left_tree.bn1.running_mean", "layer5.left_tree.left_tree.bn1.running_var", "layer5.left_tree.left_tree.conv2.weight", "layer5.left_tree.left_tree.bn2.weight", "layer5.left_tree.left_tree.bn2.bias", "layer5.left_tree.left_tree.bn2.running_mean", "layer5.left_tree.left_tree.bn2.running_var", "layer5.left_tree.left_tree.shortcut.0.weight", "layer5.left_tree.left_tree.shortcut.1.weight", "layer5.left_tree.left_tree.shortcut.1.bias", "layer5.left_tree.left_tree.shortcut.1.running_mean", "layer5.left_tree.left_tree.shortcut.1.running_var", "layer5.left_tree.right_tree.conv1.weight", "layer5.left_tree.right_tree.bn1.weight", "layer5.left_tree.right_tree.bn1.bias", "layer5.left_tree.right_tree.bn1.running_mean", "layer5.left_tree.right_tree.bn1.running_var", "layer5.left_tree.right_tree.conv2.weight", "layer5.left_tree.right_tree.bn2.weight", "layer5.left_tree.right_tree.bn2.bias", "layer5.left_tree.right_tree.bn2.running_mean", "layer5.left_tree.right_tree.bn2.running_var", "layer5.right_tree.root.conv.weight", "layer5.right_tree.root.bn.weight", "layer5.right_tree.root.bn.bias", "layer5.right_tree.root.bn.running_mean", "layer5.right_tree.root.bn.running_var", "layer5.right_tree.left_tree.conv1.weight", "layer5.right_tree.left_tree.bn1.weight", "layer5.right_tree.left_tree.bn1.bias", "layer5.right_tree.left_tree.bn1.running_mean", "layer5.right_tree.left_tree.bn1.running_var", "layer5.right_tree.left_tree.conv2.weight", "layer5.right_tree.left_tree.bn2.weight", "layer5.right_tree.left_tree.bn2.bias", "layer5.right_tree.left_tree.bn2.running_mean", 
"layer5.right_tree.left_tree.bn2.running_var", "layer5.right_tree.right_tree.conv1.weight", "layer5.right_tree.right_tree.bn1.weight", "layer5.right_tree.right_tree.bn1.bias", "layer5.right_tree.right_tree.bn1.running_mean", "layer5.right_tree.right_tree.bn1.running_var", "layer5.right_tree.right_tree.conv2.weight", "layer5.right_tree.right_tree.bn2.weight", "layer5.right_tree.right_tree.bn2.bias", "layer5.right_tree.right_tree.bn2.running_mean", "layer5.right_tree.right_tree.bn2.running_var", "layer6.root.conv.weight", "layer6.root.bn.weight", "layer6.root.bn.bias", "layer6.root.bn.running_mean", "layer6.root.bn.running_var", "layer6.left_tree.conv1.weight", "layer6.left_tree.bn1.weight", "layer6.left_tree.bn1.bias", "layer6.left_tree.bn1.running_mean", "layer6.left_tree.bn1.running_var", "layer6.left_tree.conv2.weight", "layer6.left_tree.bn2.weight", "layer6.left_tree.bn2.bias", "layer6.left_tree.bn2.running_mean", "layer6.left_tree.bn2.running_var", "layer6.left_tree.shortcut.0.weight", "layer6.left_tree.shortcut.1.weight", "layer6.left_tree.shortcut.1.bias", "layer6.left_tree.shortcut.1.running_mean", "layer6.left_tree.shortcut.1.running_var", "layer6.right_tree.conv1.weight", "layer6.right_tree.bn1.weight", "layer6.right_tree.bn1.bias", "layer6.right_tree.bn1.running_mean", "layer6.right_tree.bn1.running_var", "layer6.right_tree.conv2.weight", "layer6.right_tree.bn2.weight", "layer6.right_tree.bn2.bias", "layer6.right_tree.bn2.running_mean", "layer6.right_tree.bn2.running_var", "linear.weight", "linear.bias".
Unexpected key(s) in state_dict: "module.base.0.weight", "module.base.1.weight", "module.base.1.bias", "module.base.1.running_mean", "module.base.1.running_var", "module.base.1
As shown in the screenshot below:
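2. What the error means:
When the model is loaded, the parameter dictionary (state_dict) is missing some keys the model expects, such as "base.0.weight", "base.1.weight" and "base.1.bias", and it contains keys the model does not expect, such as "module.base.0.weight".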
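To confirm that the "missing" and "unexpected" lists differ only by a prefix, you can compare the two key sets directly. The following is a minimal sketch. `diagnose_keys` is a hypothetical helper, and the toy key lists are modeled on the error message above rather than taken from a real SimpleDLA. With a real model you would pass `model.state_dict().keys()` and `checkpoint.keys()`.

```python
def diagnose_keys(model_keys, checkpoint_keys, prefix="module."):
    """Compare the keys a model expects with the keys found in a checkpoint.

    Returns (missing, unexpected, prefix_only). prefix_only is True when keys
    are mismatched but the sets become identical once `prefix` is stripped
    from the checkpoint keys.
    """
    model_keys = set(model_keys)
    checkpoint_keys = set(checkpoint_keys)
    missing = sorted(model_keys - checkpoint_keys)     # reported as "Missing key(s)"
    unexpected = sorted(checkpoint_keys - model_keys)  # reported as "Unexpected key(s)"
    stripped = {k[len(prefix):] if k.startswith(prefix) else k for k in checkpoint_keys}
    prefix_only = bool(missing) and stripped == model_keys
    return missing, unexpected, prefix_only


# Hypothetical toy key lists shaped like the error message above.
model_keys = ["base.0.weight", "base.1.weight", "base.1.bias"]
ckpt_keys = ["module.base.0.weight", "module.base.1.weight", "module.base.1.bias"]
missing, unexpected, prefix_only = diagnose_keys(model_keys, ckpt_keys)
print(missing)      # ['base.0.weight', 'base.1.bias', 'base.1.weight']
print(prefix_only)  # True
```

If `prefix_only` is True, the checkpoint itself is fine, and only the key names need to be fixed.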
3. Cause:
The model was trained in parallel on multiple GPUs, using the following statements:
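model = torch.nn.DataParallel(model)  # wraps the model; the original model becomes model.module
cudnn.benchmark = True                # enables cuDNN autotuning; does not change parameter names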
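As a result, every key in the saved parameter dictionary starts with an extra "module." string. This happens because DataParallel stores the original model as its module attribute, and state_dict key names are built from those attribute names.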
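The prefix is easy to reproduce with a tiny model. This is a sketch with a hypothetical two-layer stand-in, not SimpleDLA. It runs on a CPU-only machine, where DataParallel simply wraps the module.

```python
import torch.nn as nn

# Hypothetical stand-in model (not SimpleDLA).
model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8))
print(list(model.state_dict().keys())[:2])    # ['0.weight', '0.bias']

# DataParallel keeps the original model as its `.module` attribute,
# so every parameter/buffer name gains a "module." prefix.
wrapped = nn.DataParallel(model)
print(list(wrapped.state_dict().keys())[:2])  # ['module.0.weight', 'module.0.bias']

# The un-prefixed names are still available through `.module`.
assert list(wrapped.module.state_dict().keys()) == list(model.state_dict().keys())
```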
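4. Solution: delete the unexpected keys and add the missing ones. In other words, strip the "module." prefix from the key of every (key, value) pair in the dictionary.
The code is as follows: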
import torch

model_cifar = SimpleDLA()  # the same model class that was used for training
checkpoint = torch.load("pytorch_model.pth", map_location="cpu")['net']
print("key:", checkpoint.keys())
for key in list(checkpoint.keys()):  # iterate over a copy, because the dict is modified
    if key.startswith('module.'):
        checkpoint[key[len('module.'):]] = checkpoint[key]  # strip the "module." prefix from the key
        del checkpoint[key]
print("key2:", checkpoint.keys())
model_cifar.load_state_dict(checkpoint)
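The loop above renames keys in place. An equivalent approach builds a new ordered dictionary instead and leaves keys without the prefix untouched, so the same code works for checkpoints saved with or without DataParallel. This is a minimal sketch. `strip_prefix` is a hypothetical helper, and the placeholder integer values stand in for tensors.

```python
from collections import OrderedDict


def strip_prefix(state_dict, prefix="module."):
    """Return a copy of `state_dict` with `prefix` removed from keys that start with it.

    Keys without the prefix are kept unchanged, and the original order is preserved.
    """
    return OrderedDict(
        (k[len(prefix):] if k.startswith(prefix) else k, v)
        for k, v in state_dict.items()
    )


# Placeholder values; a real checkpoint maps each key to a tensor.
ckpt = OrderedDict([("module.base.0.weight", 1), ("module.linear.bias", 2)])
print(list(strip_prefix(ckpt).keys()))  # ['base.0.weight', 'linear.bias']
```

With the code above, this becomes `model_cifar.load_state_dict(strip_prefix(checkpoint))`. Using `startswith` instead of a substring test avoids mangling keys that merely contain "module." in the middle. Recent PyTorch releases also ship `torch.nn.modules.utils.consume_prefix_in_state_dict_if_present(state_dict, prefix)`, which performs the same stripping in place.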
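- Note: after training, only the model's parameters were saved (not the whole model object), in the format shown below. That is why the loading code above uses the "net" key, i.e. state_dict['net'], to get the trained parameters: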
state_dict = {
'net': net.state_dict(),
'acc': acc,
'epoch': epoch,
}
torch.save(state_dict, 'pytorch_model.pth')
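You can also avoid the prefix at save time by saving the state_dict of the wrapped model instead of the wrapper. The following sketch assumes the same checkpoint layout as above. `unwrapped_state_dict` is a hypothetical helper, and the model and metric values are placeholders.

```python
import torch
import torch.nn as nn


def unwrapped_state_dict(net):
    """Return the underlying model's state_dict, without a "module." prefix."""
    if isinstance(net, (nn.DataParallel, nn.parallel.DistributedDataParallel)):
        net = net.module  # the original model wrapped by (Distributed)DataParallel
    return net.state_dict()


# Hypothetical training state for illustration.
net = nn.DataParallel(nn.Linear(4, 2))
acc, epoch = 0.9, 10
state_dict = {
    'net': unwrapped_state_dict(net),
    'acc': acc,
    'epoch': epoch,
}
torch.save(state_dict, 'pytorch_model.pth')
print(list(state_dict['net'].keys()))  # ['weight', 'bias']
```

A checkpoint saved this way loads directly into an unwrapped model, with no key renaming.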
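- Also, posts on CSDN about this error almost uniformly suggest changing load_state_dict(state_dict) to model.load_state_dict(state_dict, False). This does not work at all and is misleading. The second argument is strict, and strict=False only suppresses the error. Every mismatched key is silently skipped, so none of the trained weights are actually loaded.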
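The effect of strict=False can be checked with a small sketch. A DataParallel-wrapped `nn.Linear` stands in for the trained model (hypothetical, not SimpleDLA). `load_state_dict` returns the lists of missing and unexpected keys instead of raising an error.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
trained = nn.DataParallel(nn.Linear(4, 2))  # stands in for the trained model
checkpoint = trained.state_dict()           # keys carry the "module." prefix

fresh = nn.Linear(4, 2)                     # freshly initialized model
result = fresh.load_state_dict(checkpoint, strict=False)  # no exception raised...
print(result.missing_keys)     # 'weight' and 'bias' were never filled
print(result.unexpected_keys)  # the prefixed keys were ignored

# ...but nothing was copied: the fresh weights still differ from the trained ones.
same = torch.equal(fresh.weight.detach(), trained.module.weight.detach().cpu())
print(same)  # False
```

Checking `result.missing_keys` after a non-strict load is a quick way to detect that the weights were not actually loaded.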
- Importantly, I finally found the answer on the PyTorch forums:
https://discuss.pytorch.org/t/missing-keys-unexpected-keys-in-state-dict-when-loading-self-trained-model/22379/13
As shown in the screenshot below.