0
点赞
收藏
分享

微信扫一扫

(一)LeNet复现


文章目录

  • ​​1.model​​
  • ​​2.train​​
  • ​​3.predict​​


LeNet是最早的卷积神经网络之一

1.model

(一)LeNet复现_pytorch

# Coding by ajupyter
# 日期:2022/6/2 17:18

import torch.nn as nn
import torch.nn.functional as F


class LeNet(nn.Module):
    """LeNet-style CNN: two conv/pool stages followed by three linear layers.

    Shape comments below use [N, C, H, W] and assume a 3x32x32 input; the
    first linear layer is fixed at 16*5*5 input features, so other spatial
    sizes are rejected in ``forward``.

    Args:
        num_classes: Number of output logits. Defaults to 10.
    """

    def __init__(self, num_classes: int = 10):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=6, kernel_size=5)
        self.pool1 = nn.MaxPool2d(kernel_size=(2, 2), stride=2)
        # (14 - k) / s + 1 = 10  ==>  k = 5 with the default stride s = 1
        self.conv2 = nn.Conv2d(in_channels=6, out_channels=16, kernel_size=5)
        # (10 - 2) / 2 + 1 = 5
        self.pool2 = nn.MaxPool2d(2, 2)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, num_classes)

    def forward(self, x):
        """Return the raw (un-normalised) class scores for a batch of images.

        Args:
            x: Tensor of shape [N, 3, 32, 32].

        Returns:
            Tensor of shape [N, num_classes].
        """
        x = F.relu(self.conv1(x))  # 3x32x32 -> 6x28x28
        x = self.pool1(x)          # -> 6x14x14
        x = F.relu(self.conv2(x))  # -> 16x10x10
        x = self.pool2(x)          # -> 16x5x5
        # Flatten everything except the batch dimension. Unlike
        # view(-1, 16 * 5 * 5), this never folds samples into each other when
        # the spatial size is unexpected; fc1 raises a shape error instead.
        x = x.flatten(start_dim=1)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        return self.fc3(x)

2.train

CIFAR10 一共包括10个类别的RGB彩色图片,图片尺寸为32*32,共有50000张训练图片和10000张测试图片
(一)LeNet复现_pytorch_02

# Coding by ajupyter
# 日期:2022/6/2 18:03

import torch
import torchvision
from torch.utils.data import DataLoader
import torch.nn as nn
from model import LeNet
import torch.optim as optim
import torchvision.transforms as transforms
import matplotlib.pyplot as plt


def main():
    """Train LeNet on CIFAR-10 for 5 epochs and save the weights to ./LeNet.pth.

    The dataset is read from ``./data/`` and is not downloaded. Every 500
    training steps the mean training loss over those steps and the accuracy
    on the first test batch are printed.
    """
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    print(device)

    # Preprocessing: convert to a tensor, then normalise every channel with
    # mean 0.5 and std 0.5.
    transforms_tool = transforms.Compose(
        [
            transforms.ToTensor(),
            transforms.Normalize((.5, .5, .5), (.5, .5, .5)),
        ]
    )

    # 50000 training images
    train_set = torchvision.datasets.CIFAR10(root='./data/', train=True,
                                             download=False, transform=transforms_tool)
    train_loader = DataLoader(dataset=train_set, batch_size=36, shuffle=True, num_workers=0)

    # 10000 test images, loaded as one batch
    test_set = torchvision.datasets.CIFAR10(root='./data/', train=False,
                                            download=False, transform=transforms_tool)
    test_loader = DataLoader(dataset=test_set, batch_size=10000, shuffle=False)

    # DataLoader iterators have no .next() method in current PyTorch; use the
    # builtin next(). Move the evaluation batch to the device once, up front,
    # instead of on every evaluation.
    test_image, test_label = next(iter(test_loader))
    test_image, test_label = test_image.to(device), test_label.to(device)

    # classes = ('plane', 'car', 'bird', 'cat',
    #            'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

    net = LeNet()
    net.to(device)

    loss_function = nn.CrossEntropyLoss()
    loss_function.to(device)

    optimizer = optim.Adam(net.parameters(), lr=0.001)

    log_every = 500  # steps per reporting window
    for epoch in range(5):
        loss_all = 0.0
        for step, data in enumerate(train_loader, start=0):
            inputs, labels = data
            inputs, labels = inputs.to(device), labels.to(device)

            # Reset gradients, then forward + backward + parameter update.
            optimizer.zero_grad()
            outputs = net(inputs)
            loss = loss_function(outputs, labels)
            loss.backward()
            optimizer.step()

            loss_all += loss.item()
            # Report only after a full window has accumulated, so that
            # loss_all / log_every is a real mean (reporting at step 0 would
            # divide a single batch's loss by 500).
            if step % log_every == log_every - 1:
                with torch.no_grad():
                    test_outputs = net(test_image)
                    predict_y = torch.max(test_outputs, dim=1)[1]
                    accuracy = torch.eq(predict_y, test_label).sum().item() / test_label.size(0)

                print('[%d, %5d] train_loss: %.3f test_accuracy: %.3f' %
                      (epoch + 1, step + 1, loss_all / log_every, accuracy))
                loss_all = 0.0
    print('Finished Training')

    save_path = './LeNet.pth'
    torch.save(net.state_dict(), save_path)


if __name__ == '__main__':
    main()

3.predict

# Coding by ajupyter
# 日期:2022/6/2 19:50

import torch
import torchvision.transforms as transforms
from PIL import Image
from model import LeNet


def main():
    """Classify ./bird.jpg with the weights in LeNet.pth and print the class name."""
    transforms_tool = transforms.Compose(
        [
            transforms.Resize((32, 32)),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ]
    )
    classes = ('plane', 'car', 'bird', 'cat',
               'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

    net = LeNet()
    # map_location lets a checkpoint saved from a CUDA model load on a
    # machine without a GPU.
    net.load_state_dict(torch.load('LeNet.pth', map_location='cpu'))
    net.eval()

    with Image.open('./bird.jpg') as image_file:
        # Force three channels: grayscale, palette or RGBA files would
        # otherwise not match the 3-channel Normalize and first conv layer.
        img = transforms_tool(image_file.convert('RGB'))  # [C, H, W]
    img = torch.unsqueeze(img, dim=0)  # add the batch dimension: [N, C, H, W]

    with torch.no_grad():
        outputs = net(img)

    # torch.max(outputs, dim=1) returns both the values and their indices,
    #   e.g. values=tensor([1.7637]), indices=tensor([2]).
    # torch.argmax(outputs, dim=1) returns only the indices, e.g. tensor([2]).
    # .item() yields a plain Python int without going through a NumPy array.
    predict = torch.argmax(outputs, dim=1).item()
    print(classes[predict])


if __name__ == '__main__':
    main()


举报

相关推荐

0 条评论