0
点赞
收藏
分享

微信扫一扫

CNN手写数字识别——使用MSE误差函数

m逆光生长 2022-04-18 阅读 89
python

使用MSE误差函数进行手写字体识别时,需要保证label为one-hot形式

可以通过 target_transform 参数把 label 调整为 one-hot,也可以在学习过程中调整

本文采用后者,代码如下

import torch
from torch import nn, optim
from torch.autograd import Variable
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import mnist
import numpy as np

# CNN model definition
class CNN(nn.Module):
  """Four-conv-stage CNN for 28x28 single-channel MNIST digits.

  The head ends in Softmax(dim=1), so the network emits a (batch, 10)
  probability distribution that can be trained directly against one-hot
  labels with MSELoss.
  """

  def __init__(self):
    super(CNN, self).__init__()
    # Shape comments track (channels, H, W) for a 1x28x28 input.
    self.layer1 = nn.Sequential(
        nn.Conv2d(1, 16, kernel_size=3, stride=1),    # 16 x 26 x 26
        nn.BatchNorm2d(16),
        nn.ReLU(True),
    )
    self.layer2 = nn.Sequential(
        nn.Conv2d(16, 32, kernel_size=3, stride=1),   # 32 x 24 x 24
        nn.BatchNorm2d(32),
        nn.ReLU(True),
        nn.MaxPool2d(2, 2),                           # 32 x 12 x 12
    )
    self.layer3 = nn.Sequential(
        nn.Conv2d(32, 64, kernel_size=3, stride=1),   # 64 x 10 x 10
        nn.BatchNorm2d(64),
        nn.ReLU(True),
    )
    self.layer4 = nn.Sequential(
        nn.Conv2d(64, 128, kernel_size=3, stride=1),  # 128 x 8 x 8
        nn.BatchNorm2d(128),
        nn.ReLU(True),
        nn.MaxPool2d(2, 2),                           # 128 x 4 x 4
    )
    # Classifier head: flatten -> 1024 -> 128 -> 10 probabilities.
    self.fc = nn.Sequential(
        nn.Linear(128 * 4 * 4, 1024),
        nn.ReLU(True),
        nn.Linear(1024, 128),
        nn.ReLU(True),
        nn.Linear(128, 10),
        nn.Softmax(dim=1),  # probabilities, required for MSE vs. one-hot targets
    )

  def forward(self, x):
    """Run the conv stages, flatten, and return (batch, 10) probabilities."""
    for stage in (self.layer1, self.layer2, self.layer3, self.layer4):
      x = stage(x)
    flat = x.view(x.size(0), -1)
    return self.fc(flat)


# Data preprocessing: convert to tensor, then normalize to roughly [-1, 1].
data_tf = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize([0.5], [0.5]),
])

# Hyperparameters
learning_rate = 0.01
num_epoches = 20

# MNIST datasets (downloaded to ./data on first run) and batched loaders.
train_set = mnist.MNIST('./data', train = True, transform = data_tf, download = True)
test_set = mnist.MNIST('./data', train = False, transform = data_tf, download = True)
train_loader = DataLoader(train_set, batch_size = 64, shuffle = True)
test_loader = DataLoader(test_set, batch_size = 128, shuffle = True)


# Training: minimize MSE between the softmax output and one-hot labels.
# NOTE: torch.autograd.Variable is deprecated (no-op since PyTorch 0.4),
# so tensors from the loaders are used directly.
model = CNN()
criterion = nn.MSELoss()
optimizer = optim.SGD(model.parameters(), lr = learning_rate)

# Per-epoch history of train/eval loss and accuracy.
losses = []
acces = []
eval_losses = []
eval_ac = []

for e in range(num_epoches):
  train_loss = 0
  train_acc = 0
  model.train()
  for image, label in train_loader:
    out = model(image)
    # MSELoss needs a float one-hot target matching the (batch, 10) output.
    label_onehot = nn.functional.one_hot(label, num_classes = 10).float()
    loss = criterion(out, label_onehot)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    train_loss = train_loss + loss.item()
    # Predicted class = index of the largest probability.
    _, pred = out.max(1)
    num_correct = (pred == label).sum().item()
    train_acc += num_correct / image.shape[0]

  losses.append(train_loss / len(train_loader))
  acces.append(train_acc / len(train_loader))
  eval_loss = 0
  eval_acc = 0
  model.eval()
  # no_grad: evaluation needs no autograd graph, saving memory and time.
  with torch.no_grad():
    for im, label in test_loader:
      out = model(im)
      label_onehot = nn.functional.one_hot(label, num_classes = 10).float()
      loss = criterion(out, label_onehot)
      eval_loss += loss.item()
      _, pred = out.max(1)
      num_correct = (pred == label).sum().item()
      eval_acc += num_correct / im.shape[0]

  eval_losses.append(eval_loss / len(test_loader))
  eval_ac.append(eval_acc / len(test_loader))
  print('epoch: {}, Train Loss: {:.6f}, Train Acc: {:.6f}, Eval Loss: {:.6f}, Eval Acc: {:.6f}'
          .format(e+1, train_loss / len(train_loader), train_acc / len(train_loader), 
                     eval_loss / len(test_loader), eval_acc / len(test_loader)))
举报

相关推荐

0 条评论