import torch
import torch.nn as nn
import matplotlib.pyplot as plt
#准备数据
x=torch.rand([500,1])
y_true=x*3+0.8
#定义模型
class MyLinear(nn.Module):
def __init__(self):
#继承父类
super(MyLinear,self).__init__()
self.linear=nn.Linear(1,1) #输入特征数,输出特征数 因为这里x的列数为1
def forward(self,x):
out=self.linear(x)
return out
#实例化,优化器实例化,loss实例化
my_linear=MyLinear()
optimiter=torch.optim.SGD(my_linear.parameters(),lr=0.001) #这里存储着更新的参数
loss_fn=nn.MSELoss()
#循环,梯度下降,参数更新
for i in range(4000):
#向前传播
my_predict=my_linear(x)
#计算损失函数
loss=loss_fn(my_predict,y_true)
#每次将梯度置为0
optimiter.zero_grad()#每次将梯度置为0
#损失函数反向传播求梯度
loss.backward()
#更新参数
optimiter.step()
if i%50==0:
params=list(my_linear.parameters())
print(loss.item(),params[0].item(),params[1].item())
my_linear.eval() # 设置评估模式
predict = my_linear(x)
plt.scatter(x.data.numpy(), y_true.data.numpy(), c='r')
plt.plot(x.data.numpy(), predict.data.numpy())
plt.show()
重点看一下nn.Model模块的详解:
PyTorch 源码解读之 nn.Module:核心网络模块接口详解
【PyTorch 源码阅读】 torch.nn.Module 篇