params argument given to the optimizer should be an iterable


import torch
from torch import nn

net = nn.Linear(100, 1)
# Raises: TypeError: params argument given to the optimizer should be an iterable of Tensors or dicts
optimizer_w = torch.optim.SGD(net.weight, lr=0.03, weight_decay=wd)

> The problem occurs because the `params` argument that SGD expects must be an iterable of tensors (or of parameter-group dicts),
> but here `net.weight` is a single plain tensor.
> Solution:
> optimizer_w = torch.optim.SGD([net.weight], lr=0.03, weight_decay=wd)
> Wrapping net.weight in a list makes it iterable.
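
Below is a minimal runnable sketch of the corrected call. The weight-decay value `wd = 0.01` and the dummy forward/backward step are assumptions added only for illustration; they are not part of the original snippet.

```python
import torch
from torch import nn

wd = 0.01                      # assumed weight-decay value, chosen only for illustration
net = nn.Linear(100, 1)

# Wrapping net.weight in a list gives SGD the iterable of tensors it expects.
optimizer_w = torch.optim.SGD([net.weight], lr=0.03, weight_decay=wd)

# A parameter-group dict is also accepted:
# optimizer_w = torch.optim.SGD([{"params": [net.weight], "weight_decay": wd}], lr=0.03)

# Dummy forward/backward step to confirm the optimizer runs (illustrative data only)
X = torch.randn(4, 100)
y = torch.randn(4, 1)
loss = ((net(X) - y) ** 2).mean()
optimizer_w.zero_grad()
loss.backward()
optimizer_w.step()
```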

