PyTorch神经网络
神经网络可以通过torch.nn包构建
 pytorch神经网络上基于自动梯度(autograd)来定义模型:
一个nn.Module构建神经网络层
 一个方法forward(input)它会返回输出(output)
数字图片识别网络:
典型的神经网络训练过程包括以下几点:
- 定义一个包含可训练参数的神经网络
 - 迭代整个输入
 - 通过神经网络处理输入
 - 计算损失loss
 - 反向传播梯度到神经网络的参数
 - 更新网络的参数,典型的用一个简单方法:*weight = weight - learning_rate gradient
 
代码实现
定义神经网络:
import torch
import torch.nn as nn
import torch.nn.functional as F
class Net(nn.Module):
    
    def __init__(self):
        super(Net,self).__init__()
# 1 input image channel, 6 output channels, 5x5 square convolution
        self.conv1=nn.Conv2d(1,6,5)
        self.conv2=nn.Conv2d(6, 16, 5)
        # an affine operation: y = Wx + b
        self.fc1=nn.Linear(16*5*5, 120)
        self.fc2=nn.Linear(120,84)
        self.fc3=nn.Linear(84, 10)
        
    def forward(self,x):
         # Max pooling over a (2, 2) window
         x=F.max_pool2d(F.relu(self.conv1(x)),(2,2))
         x=F.max_pool2d(F.relu(self.conv2(x)),2)
         x=x.view(-1,self.num_flat_features(x))
         x=F.relu(self.fc1(x))
         x=F.relu(self.fc2(x))
         x=self.fc3(x)
         
         return x
    def num_flat_features(self,x):
         size=x.size()[1:]
         num_features=1
         for s in size:
             num_features *=s 
         return num_features
     
net=Net()
print(net)迭代输入、通过神经网络处理输入
#前向传播:
input=torch.randn(1,1,32,32)
out=net(input)
print(out)
#反向传播:
#将网络参数重置为0
net.zero_grad()
out.backward(torch.randn(1,10))计算损失loss
output=net(input)
target=torch.randn(10)
target=target.view(1,-1)
criterion=nn.MSELoss()
loss=criterion(output,target)
print(loss)反向传播梯度到神经网络参数
net.zero_grad()
#conv1之前的参数
print(net.conv1.bias.grad)
#反向传播后参数
loss.backward()
print(net.conv1.bias.grad)更新网络的参数
 使用optimizer优化器,可以使用不同的更新规则,类似于SGD,Nesterov-SGD, Adam,RMSProp, 等。
learning_rate = 0.01
for f in net.parameters():
f.data.sub_(f.grad.data * learning_rate)
#使用优化器optimizer实现更新
import torch.optim as optim
optimizer=optim.SGD(net.parameters(),lr=0.01)
optimizer.zero_grad()
output=net(input)
loss=criterion(output,target)
loss.backward()
optimizer.step()                
                










