0
点赞
收藏
分享

微信扫一扫

BERT书籍阅读笔记(二)PyTorch编程基础(2)

hoohack 2022-02-26 阅读 69

文章目录

优化器

(1)示例

optimizer = toech.optim.Adam([{'params': weight_p, 'weight_decay': 0.001},
		{'params': bias_p, 'weight_decay': 0}], lr= 0.01)

  其中,字典中的params指的是模型中的权重。将具体的权重张量放入优化器,再为参数weight_decay赋值指定权值衰减率,之后便可以对参数进行正则化处理了。
(2)得到参数
  字典中的权重张量weight_p和bias_p可以通过实例化后的模型对象得到,具体代码如下:

weight_p, bias_p = [], []
for name,p in model.named_parameters():#获取模型中所有的参数及参数名字
	if 'bias' in name:#将偏置参数收集起来
		bias_p += [p]
	else:#将权重参数收集起来
		weight_p += [p]

  通过上述代码可将模型中的权重参数和偏置参数分别收集到列表对象weight_p和bias_p中。

保存与载入模型

(1)保存模型
  用state_dict()方法获取模型的全部参数后,用如下代码即可实现保存参数的功能。

torch.save(model.state_dict(), './model.pth')

  执行完该命令以后,会在本地目录生成一个model.pth文件,用来保存模型的参数
(2)载入模型
  使用模型对象的load_state_dict()方法,可以将保存好的模型文件载入模型model里,代码如下:

model.load_state_dict(torch.load('./model.pth'))

  在执行该命令之后,model模型中的值将于model.pth文件中的值保持同步。
注明:这里与Keras不用,Keras保存和读取均是直接为模型本身,而PyTorch保存和读取的为模型参数。

举报

相关推荐

0 条评论