forward and backward in Torch


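Criterions have their own forward and backward functions. forward(input, target) computes the loss, and backward(input, target) returns the gradient of the loss with respect to the input:
https://github.com/torch/nn/blob/master/doc/criterion.md
Modules also have forward and backward functions. forward(input) computes the output, and backward(input, gradOutput) returns the gradient with respect to the input while accumulating the gradients with respect to the module's parameters:
https://github.com/torch/nn/blob/master/doc/module.md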
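To make the Criterion contract concrete, here is a minimal plain-Lua sketch (no Torch required) of what an MSE criterion's two functions compute. It is modelled on nn.MSECriterion with its default size averaging, which is an assumption about the configuration; the table name `MSE` and the plain Lua arrays are hypothetical stand-ins for the real class and Tensors.

```lua
-- Sketch of the Criterion contract, modelled on nn.MSECriterion with
-- size averaging (assumption). Inputs are plain Lua arrays, not Tensors.
local MSE = {}

-- forward(input, target) -> scalar loss = (1/n) * sum_i (input_i - target_i)^2
function MSE.forward(input, target)
  local n, sum = #input, 0
  for i = 1, n do
    local d = input[i] - target[i]
    sum = sum + d * d
  end
  return sum / n
end

-- backward(input, target) -> gradInput_i = dLoss/dinput_i = (2/n) * (input_i - target_i)
function MSE.backward(input, target)
  local n, grad = #input, {}
  for i = 1, n do
    grad[i] = 2 * (input[i] - target[i]) / n
  end
  return grad
end

local pred, target = {1, 2, 3}, {1, 2, 5}
print(MSE.forward(pred, target))
local g = MSE.backward(pred, target)
print(g[1], g[2], g[3])
```

Only the third element differs from its target, so only the third entry of the gradient is non-zero.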

A Module's forward function is the simplest: given an input, it returns the output.
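As an illustration of "input in, output out", the following sketch spells out what nn.Linear's forward computes for a batch, using nested Lua tables instead of Tensors. The helper `linearForward` is hypothetical; it assumes the weight layout outputSize x inputSize used by nn.Linear.

```lua
-- Sketch of nn.Linear:forward for a batch: output[n] = weight * input[n] + bias.
-- weight is outputSize x inputSize (assumption matching nn.Linear);
-- input is N x inputSize, output is N x outputSize.
local function linearForward(weight, bias, input)
  local output = {}
  for n = 1, #input do            -- each row is one sample
    output[n] = {}
    for j = 1, #weight do         -- each output unit
      local s = bias[j]
      for k = 1, #input[n] do     -- dot product of weight row j with the sample
        s = s + weight[j][k] * input[n][k]
      end
      output[n][j] = s
    end
  end
  return output
end

-- A 1 -> 1 layer, like nn.Linear(1, 1) in the example below: y = 3x + 2
local out = linearForward({{3}}, {2}, {{1}, {2}, {10}})
for n = 1, #out do print(out[n][1]) end
```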

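For a Module's backward, look at this linear regression example. model:backward takes the same input that was passed to forward, together with the gradient of the loss with respect to the module's output (here produced by criterion:backward).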

require 'torch'
require 'nn'
require 'gnuplot'

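-- Training data: month index 1..10 and the price for each month
month = torch.range(1,10)
price = torch.Tensor{28993,29110,29436,30791,33384,36762,39900,39972,40230,40146}

-- One input feature -> one output (y = w*x + b), trained with mean squared error
model = nn.Linear(1, 1)
criterion = nn.MSECriterion()

-- nn.Linear takes a batch of shape N x inputSize, so turn the vectors into 10x1 matrices
month_train = month:reshape(10,1)
price_train = price:reshape(10,1)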

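for i=1,1000 do
  price_predict = model:forward(month_train)                 -- input -> output
  err = criterion:forward(price_predict, price_train)        -- (output, target) -> loss value
  print(i, err)
  model:zeroGradParameters()                                 -- backward accumulates, so reset the gradients first
  gradient = criterion:backward(price_predict, price_train)  -- (output, target) -> dLoss/dOutput
  model:backward(month_train, gradient)                      -- (input, dLoss/dOutput) -> gradInput; accumulates gradWeight/gradBias
  model:updateParameters(0.01)                               -- parameters = parameters - 0.01 * gradParameters
end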

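-- Predict months 1..12; months 11 and 12 lie beyond the training data
month_predict = torch.range(1,12)
local price_predict = model:forward(month_predict:reshape(12,1))
print(price_predict)

-- Plot the data and the fitted line; flatten the 12x1 prediction to a 1-D vector like month and price
gnuplot.pngfigure('plot.png')
gnuplot.plot({month, price}, {month_predict, price_predict:view(12)})
gnuplot.plotflush()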
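The training loop above can be unrolled by hand for the 1 -> 1 case to show what each call contributes. This is a sketch, not Torch's implementation: it uses scalars `w` and `b` instead of Tensors, starts them at 0 (nn.Linear actually initialises them randomly), and assumes MSECriterion's size averaging. The function `step` is hypothetical.

```lua
-- Hand-written version of the example's loop for y = w*x + b (sketch).
-- Assumptions: w = b = 0 at the start, MSE loss averaged over the batch.
local month = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10}
local price = {28993, 29110, 29436, 30791, 33384,
               36762, 39900, 39972, 40230, 40146}
local n = #month
local w, b = 0, 0
local lr = 0.01

local function step()
  -- model:forward and criterion:forward: predictions and mean squared error
  local pred, loss = {}, 0
  for i = 1, n do
    pred[i] = w * month[i] + b
    loss = loss + (pred[i] - price[i]) ^ 2
  end
  loss = loss / n

  -- model:zeroGradParameters(): start the accumulators at zero
  local gradW, gradB = 0, 0

  -- criterion:backward gives dLoss/dpred_i; model:backward then accumulates
  -- gradWeight += gradOutput * input and gradBias += gradOutput over the batch
  -- (gradInput = gradOutput * w would go to an earlier layer; unused here)
  for i = 1, n do
    local gradOutput = 2 * (pred[i] - price[i]) / n
    gradW = gradW + gradOutput * month[i]
    gradB = gradB + gradOutput
  end

  -- model:updateParameters(lr): parameters -= lr * gradParameters
  w = w - lr * gradW
  b = b - lr * gradB
  return loss
end

local firstLoss = step()
local lastLoss
for i = 2, 1000 do lastLoss = step() end
print(firstLoss, lastLoss, w, b)
```

In this sketch the gradients are summed over the batch inside the backward step; the division by n comes only from the criterion, which mirrors how the criterion's size averaging and the module's accumulation split the work.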
