使用MSE误差函数进行手写字体识别时,需要保证label为one-hot形式
可以通过数据集的 target_transform 参数将 label 预先转换为 one-hot 形式,也可以在训练过程中逐批转换
本文采用后者,代码如下
import torch
from torch import nn, optim
from torch.autograd import Variable
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import mnist
import numpy as np
#网络模型
class CNN(nn.Module):
    """Small 4-conv-layer network for MNIST (1x28x28 input, 10-way output).

    The final Softmax(dim=1) makes the output a probability vector, so the
    network can be trained directly against one-hot targets with MSELoss.
    """

    @staticmethod
    def _conv_block(in_ch, out_ch, pool=False):
        # Conv(3x3, stride 1) -> BatchNorm -> ReLU, optionally followed by 2x2 max-pool.
        mods = [
            nn.Conv2d(in_ch, out_ch, kernel_size=3, stride=1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(True),
        ]
        if pool:
            mods.append(nn.MaxPool2d(2, 2))
        return nn.Sequential(*mods)

    def __init__(self):
        super(CNN, self).__init__()
        # Spatial sizes after each stage (input 1x28x28):
        self.layer1 = self._conv_block(1, 16)               # 16 x 26 x 26
        self.layer2 = self._conv_block(16, 32, pool=True)   # 32 x 12 x 12
        self.layer3 = self._conv_block(32, 64)              # 64 x 10 x 10
        self.layer4 = self._conv_block(64, 128, pool=True)  # 128 x 4 x 4
        # Classifier head: 2048 -> 1024 -> 128 -> 10 probabilities.
        self.fc = nn.Sequential(
            nn.Linear(128 * 4 * 4, 1024),
            nn.ReLU(True),
            nn.Linear(1024, 128),
            nn.ReLU(True),
            nn.Linear(128, 10),
            nn.Softmax(dim=1),
        )

    def forward(self, x):
        """Map a (N, 1, 28, 28) batch to (N, 10) class probabilities."""
        for stage in (self.layer1, self.layer2, self.layer3, self.layer4):
            x = stage(x)
        return self.fc(x.flatten(1))
#数据预处理
# --- Preprocessing: PIL image -> tensor in [0,1] -> normalized to roughly [-1,1] ---
data_tf = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize([0.5], [0.5]),
])

# --- Hyperparameters ---
learning_rate = 0.01
num_epoches = 20

# --- MNIST datasets (downloaded to ./data on first run) and batch loaders ---
train_set = mnist.MNIST('./data', train=True, transform=data_tf, download=True)
test_set = mnist.MNIST('./data', train=False, transform=data_tf, download=True)
train_loader = DataLoader(train_set, batch_size=64, shuffle=True)
test_loader = DataLoader(test_set, batch_size=128, shuffle=True)
#开始学习
# --- Training setup ---
model = CNN()
criterion = nn.MSELoss()  # MSE loss requires one-hot targets; encoded per batch below
optimizer = optim.SGD(model.parameters(), lr=learning_rate)

losses = []       # per-epoch mean training loss
acces = []        # per-epoch mean training accuracy
eval_losses = []  # per-epoch mean test loss
eval_ac = []      # per-epoch mean test accuracy

for e in range(num_epoches):
    train_loss = 0
    train_acc = 0
    model.train()
    for image, label in train_loader:
        # NOTE: the deprecated Variable wrappers were removed — since
        # PyTorch 0.4 plain tensors carry autograd state themselves.
        out = model(image)
        # MSE compares element-wise, so the integer labels must be expanded
        # to one-hot float vectors matching the (N, 10) softmax output.
        label_onehot = nn.functional.one_hot(label, num_classes=10).float()
        loss = criterion(out, label_onehot)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        train_loss += loss.item()
        _, pred = out.max(1)
        num_correct = (pred == label).sum().item()
        train_acc += num_correct / image.shape[0]
    losses.append(train_loss / len(train_loader))
    acces.append(train_acc / len(train_loader))

    # --- Evaluation on the test set ---
    eval_loss = 0
    eval_acc = 0
    model.eval()
    with torch.no_grad():  # no autograd graphs needed during evaluation
        for im, label in test_loader:
            out = model(im)
            label_onehot = nn.functional.one_hot(label, num_classes=10).float()
            loss = criterion(out, label_onehot)
            eval_loss += loss.item()
            _, pred = out.max(1)
            num_correct = (pred == label).sum().item()
            eval_acc += num_correct / im.shape[0]
    eval_losses.append(eval_loss / len(test_loader))
    eval_ac.append(eval_acc / len(test_loader))
    print('epoch: {}, Train Loss: {:.6f}, Train Acc: {:.6f}, Eval Loss: {:.6f}, Eval Acc: {:.6f}'
          .format(e+1, train_loss / len(train_loader), train_acc / len(train_loader),
                  eval_loss / len(test_loader), eval_acc / len(test_loader)))