官方:https://pytorch.org/tutorials/beginner/basics/saveloadrun_tutorial.html
【参考:PyTorch保存和加载模型_正则化的博客-CSDN博客】
保存和加载权重参数
PyTorch 模型将学到的参数存储在内部状态字典中,称为 state _ dict。可以通过 torch.save 方法持久化这些内容:
#----把模型中的参数保存成字典的形式, 不保存网络模型的结构, 官方推荐----
torch.save(model.state_dict(), 'params_name.pth') #保存的文件名后缀一般是.pt或.pth
要加载模型权重,您需要首先创建相同模型的实例,然后使用 load _ state _ dict ()方法加载参数。
#----加载----
model=Model() #定义模型结构
model.load_state_dict(torch.load('params_name.pth')) #加载模型参数
保存和加载带权重参数的模型
#----保存----
torch.save(model, 'model_name.pth')
#----加载----
model = torch.load('model_name.pth')