0
点赞
收藏
分享

微信扫一扫

pytorchAPI实现回归

蓝哆啦呀 2022-01-04 阅读 63
import torch
import torch.nn as nn
import matplotlib.pyplot as plt

#准备数据
x=torch.rand([500,1])
y_true=x*3+0.8

#定义模型
class MyLinear(nn.Module):
    def __init__(self):
        #继承父类
        super(MyLinear,self).__init__()
        self.linear=nn.Linear(1,1) #输入特征数,输出特征数   因为这里x的列数为1

    def forward(self,x):
        out=self.linear(x)
        return out

#实例化,优化器实例化,loss实例化
my_linear=MyLinear()
optimiter=torch.optim.SGD(my_linear.parameters(),lr=0.001) #这里存储着更新的参数
loss_fn=nn.MSELoss()

#循环,梯度下降,参数更新
for i in range(4000):
    #向前传播
    my_predict=my_linear(x)
    #计算损失函数
    loss=loss_fn(my_predict,y_true)
    #每次将梯度置为0
    optimiter.zero_grad()#每次将梯度置为0
    #损失函数反向传播求梯度
    loss.backward()
    #更新参数
    optimiter.step()

    if i%50==0:
        params=list(my_linear.parameters())
        print(loss.item(),params[0].item(),params[1].item())

my_linear.eval()  # 设置评估模式
predict = my_linear(x)
plt.scatter(x.data.numpy(), y_true.data.numpy(), c='r')
plt.plot(x.data.numpy(), predict.data.numpy())
plt.show()

在这里插入图片描述
重点看一下nn.Model模块的详解:
PyTorch 源码解读之 nn.Module:核心网络模块接口详解
【PyTorch 源码阅读】 torch.nn.Module 篇

举报

相关推荐

0 条评论