0
点赞
收藏
分享

微信扫一扫

【Pytorch】保存和加载模型

颜娘娘的碎碎念 2022-05-04 阅读 103

官方:https://pytorch.org/tutorials/beginner/basics/saveloadrun_tutorial.html

【参考:PyTorch保存和加载模型_正则化的博客-CSDN博客】

保存和加载权重参数

PyTorch 模型将学到的参数存储在内部状态字典中,称为 state _ dict。可以通过 torch.save 方法持久化这些内容:

#----把模型中的参数保存成字典的形式, 不保存网络模型的结构, 官方推荐----
torch.save(model.state_dict(), 'params_name.pth') #保存的文件名后缀一般是.pt或.pth

要加载模型权重,您需要首先创建相同模型的实例,然后使用 load _ state _ dict ()方法加载参数。

#----加载----
model=Model() #定义模型结构
model.load_state_dict(torch.load('params_name.pth'))  #加载模型参数

保存和加载带权重参数的模型

#----保存----
torch.save(model, 'model_name.pth')
#----加载----
model = torch.load('model_name.pth')
举报

相关推荐

0 条评论