Building AlexNet with PyTorch

IT程序员 2022-04-25

Goal of this lecture:
  Introduce the workflow for building the AlexNet network with PyTorch.


1. Introduction to AlexNet

  The AlexNet network was introduced in 2012 and won that year's ImageNet competition with a Top-5 error rate of 16.4%.
  Key ideas worth borrowing: the ReLU activation function to speed up training, Dropout to prevent overfitting, and image augmentation to enlarge the effective training set. AlexNet's overall structure resembles LeNet5, but it introduces several important improvements:

  1. It consists of five convolutional layers and three fully connected layers, takes a 224x224x3 input image, and is far larger in scale than LeNet5;
  2. It uses the ReLU activation function instead of Sigmoid or Tanh;
  3. It applies Dropout to prevent overfitting and improve robustness;
  4. It adds several training tricks, including data augmentation, learning-rate decay, and weight decay (L2 regularization); see the sketch after this list;
  5. It was the first network to use GPUs to accelerate training;
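
  To make point 4 concrete, below is a minimal, hypothetical sketch of how data augmentation, learning-rate decay, and weight decay (L2 regularization) can be combined in PyTorch. The transform and scheduler settings are illustrative choices roughly in the spirit of the original paper, not values taken from it, and the import of AlexNet assumes the model.py defined in section 2.1:

import torch.optim as optim
from torchvision import transforms
from model import AlexNet  # the model defined in section 2.1 (model.py)

# data augmentation: random crop and horizontal flip (illustrative parameters)
train_transform = transforms.Compose([
    transforms.RandomResizedCrop(224),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
])

model = AlexNet(num_classes=1000)
# weight decay (L2 regularization) is passed to the optimizer;
# learning-rate decay is handled by a scheduler (settings chosen for illustration)
optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9, weight_decay=5e-4)
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=30, gamma=0.1)

for epoch in range(20):
    # ... run one training epoch here (see train.py in section 2.2) ...
    scheduler.step()  # decay the learning rate after each epoch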

  The network structure of AlexNet is shown in the figure below:
(figure: AlexNet network architecture)
  The convolution stage acts as the feature extractor and follows the CBAPD pattern (sketched below):

Conv2D, BatchNormalization, Activation, Pooling, Dropout
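
As a minimal sketch, the first block (C1) can be written as an nn.Sequential following this CBAPD pattern. The padding value matches the model.py implementation in section 2.1, and the name conv_block1 is purely illustrative:

import torch.nn as nn

# one CBAPD block: C1 of AlexNet (D is None for this block, so no Dropout layer)
conv_block1 = nn.Sequential(
    nn.Conv2d(3, 96, kernel_size=11, stride=4, padding=2),  # C: convolution
    nn.BatchNorm2d(96),                                      # B: BN in place of the paper's LRN
    nn.ReLU(inplace=True),                                   # A: activation
    nn.MaxPool2d(kernel_size=3, stride=2),                   # P: max pooling
)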

  Feature extractor (convolutional layers):
C1:
  C (kernels: 96x11x11, stride: 4, padding: valid)
  B (LRN/BN) the original paper uses LRN for normalization; we use BN instead
  A (relu)
  P (max, kernel: 3x3, stride: 2)
  D (None)
C2:
  C (kernels: 256x5x5, stride: 1, padding: valid)
  B (LRN/BN) the original paper uses LRN for normalization; we use BN instead
  A (relu)
  P (max, kernel: 3x3, stride: 2)
  D (None)
C3:
  C (kernels: 384x3x3, stride: 1, padding: same)
  B (None)
  A (relu)
  P (None)
  D (None)
C4:
  C (kernels: 384x3x3, stride: 1, padding: same)
  B (None)
  A (relu)
  P (None)
  D (None)
C5:
  C (kernels: 256x3x3, stride: 1, padding: same)
  B (None)
  A (relu)
  P (max, kernel: 3x3, stride: 2)
  D (None)
  Classifier (fully connected layers):
D1:
  Dense (neurons: 2048, activation: relu, Dropout: 0.5)
D2:
  Dense (neurons: 2048, activation: relu, Dropout: 0.5)
D3:
  Dense (neurons: 10, activation: softmax)

  Compared with the structurally similar LeNet5, AlexNet has a dramatically larger parameter count and more convolutional layers, which helps it extract features more effectively; the ReLU activation speeds up training, and Dropout improves robustness. Together these advantages give AlexNet much better performance.
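
  As a quick check of the parameter-count claim, the snippet below (assuming the AlexNet class defined in section 2.1) sums the trainable parameters; the exact total depends on num_classes, but it is in the tens of millions, far beyond LeNet5's roughly sixty thousand:

from model import AlexNet  # the model built in section 2.1 (model.py)

model = AlexNet(num_classes=1000)
total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print("total trainable parameters: {:,}".format(total_params))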

2. Building AlexNet with the PyTorch Framework

2.1 Building the model: model.py

import torch
import torch.nn as nn
import torch.nn.functional as F

class AlexNet(nn.Module):
    def __init__(self,num_classes=1000):
        super(AlexNet,self).__init__()
        # N=(W-F+2P)/S+1
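        # Worked example with the formula above (PyTorch floors the division):
        # c1: N = (224 - 11 + 2*2) / 4 + 1 = 55, so a 224x224 input becomes 55x55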
        self.c1=nn.Conv2d(in_channels=3,out_channels=96,kernel_size=11,stride=4,padding=2) #input(3,224,224)->(96,55,55)
        self.b1=nn.BatchNorm2d(96)
        self.a1=nn.ReLU(inplace=True)
        self.p1=nn.MaxPool2d(kernel_size=3,stride=2)   # (96,55,55)->(96,27,27)

        self.c2=nn.Conv2d(96,256,5,stride=1,padding=2) #(96,27,27)->(256,27, 27)
        self.b2=nn.BatchNorm2d(256)
        self.a2=nn.ReLU(inplace=True)
        self.p2=nn.MaxPool2d(kernel_size=3,stride=2)   #(256,27, 27)->(256,13, 13)

        self.c3=nn.Conv2d(256,384,3,stride=1,padding=1)#(256,13, 13)->(384,13, 13)
        self.a3=nn.ReLU(inplace=True)

        self.c4=nn.Conv2d(384,384,3,stride=1,padding=1)#(384,13, 13)->(384,13, 13)
        self.a4=nn.ReLU(inplace=True)

        self.c5=nn.Conv2d(384,256,3,stride=1,padding=1)  #(384,13, 13)->(256,13, 13)
        self.a5=nn.ReLU(inplace=True)
        self.p5 = nn.MaxPool2d(kernel_size=3, stride=2)  # (256,13, 13)->(256,6, 6)
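        # p5 output is (256, 6, 6); flattening gives 256*6*6 = 9216 features for fc1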

        self.fc1_d=nn.Dropout(p=0.5)
        self.fc1=nn.Linear(256*6*6,2048)
        self.fc1_a=nn.ReLU(inplace=True)

        self.fc2_d=nn.Dropout(p=0.5)
        self.fc2=nn.Linear(2048,2048)
        self.fc2_a=nn.ReLU(inplace=True)

        self.fc3=nn.Linear(2048,num_classes)

    def forward(self,x):
        x = self.c1(x)
        x = self.b1(x)
        x = self.a1(x)
        x = self.p1(x)

        x = self.c2(x)
        x = self.b2(x)
        x = self.a2(x)
        x = self.p2(x)

        x = self.c3(x)
        x = self.a3(x)

        x = self.c4(x)
        x = self.a4(x)

        x = self.c5(x)
        x = self.a5(x)
        x = self.p5(x)

        x=torch.flatten(x,start_dim=1)

        x=self.fc1_d(x)
        x = self.fc1(x)
        x = self.fc1_a(x)

        x=self.fc2_d(x)
        x = self.fc2(x)
        x = self.fc2_a(x)

        x=self.fc3(x)
        return x

To verify that the network is defined correctly, run the following code:

input = torch.rand([1, 3, 224, 224])
model = AlexNet(num_classes=10)
output = model(input)
print(model)           # inspect the layer structure
print(output.shape)    # torch.Size([1, 10])

The printed shape is torch.Size([1, 10]), confirming that the model produces a 10-class output.

2.2 Building the training process: train.py

import os
import sys
import json

import torch
import torch.nn as nn
from torchvision import transforms, datasets, utils
import matplotlib.pyplot as plt
import numpy as np
import torch.optim as optim
from tqdm import tqdm

from model import AlexNet



device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print("using {} device.".format(device))

# apply consistent transforms to the training and validation data
data_transform = {
    "train": transforms.Compose([transforms.RandomResizedCrop(224),
                                 transforms.RandomHorizontalFlip(),
                                 transforms.ToTensor(),
                                 transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]),
    "val": transforms.Compose([transforms.Resize((224, 224)),
                               transforms.ToTensor(),
                               transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])}

data_root = os.path.abspath(os.path.join(os.getcwd(), "./data_set"))
image_path = os.path.join(data_root,  "flower_data")
assert os.path.exists(image_path), "{} path does not exist.".format(image_path)

train_dataset = datasets.ImageFolder(root=os.path.join(image_path, "train"),
                                     transform=data_transform["train"])
train_num = len(train_dataset)

# {'daisy':0, 'dandelion':1, 'roses':2, 'sunflower':3, 'tulips':4}
flower_list = train_dataset.class_to_idx
cla_dict = dict((val, key) for key, val in flower_list.items())
# write dict into json file
json_str = json.dumps(cla_dict, indent=4)

with open('class_indices.json', 'w') as json_file:
    json_file.write(json_str)


batch_size = 32
nw = min([os.cpu_count(), batch_size if batch_size > 1 else 0, 8])  # number of workers
print('Using {} dataloader workers per process'.format(nw))

train_loader = torch.utils.data.DataLoader(train_dataset,
                                           batch_size=batch_size,
                                           shuffle=True,
                                           num_workers=nw)

validate_dataset = datasets.ImageFolder(root=os.path.join(image_path, "val"),
                                        transform=data_transform["val"])
val_num = len(validate_dataset)
validate_loader = torch.utils.data.DataLoader(validate_dataset,
                                              batch_size=4,
                                              shuffle=False,
                                              num_workers=nw)

print("using {} images for training, {} images for validation.".format(train_num,
                                                                       val_num))
# test_data_iter = iter(validate_loader)
# test_image, test_label = test_data_iter.next()
#
# def imshow(img):
#     img = img / 2 + 0.5  # unnormalize
#     npimg = img.numpy()
#     plt.imshow(np.transpose(npimg, (1, 2, 0)))
#     plt.show()
#
# print(' '.join('%5s' % cla_dict[test_label[j].item()] for j in range(4)))
# imshow(utils.make_grid(test_image))

net = AlexNet(num_classes=5)

net.to(device)
loss_function = nn.CrossEntropyLoss()
# pata = list(net.parameters())
optimizer = optim.Adam(net.parameters(), lr=0.0002)

epochs = 20
save_path = './AlexNet.pth'
best_acc = 0.0
train_steps = len(train_loader)
for epoch in range(epochs):
    # train
    net.train()
    running_loss = 0.0
    train_bar = tqdm(train_loader, file=sys.stdout)
    for step, data in enumerate(train_bar):
        images, labels = data
        optimizer.zero_grad()
        outputs = net(images.to(device))
        loss = loss_function(outputs, labels.to(device))
        loss.backward()
        optimizer.step()

        # print statistics
        running_loss += loss.item()

        train_bar.desc = "train epoch[{}/{}] loss:{:.3f}".format(epoch + 1,
                                                                 epochs,
                                                                 loss)

    # validate
    net.eval()
    acc = 0.0  # accumulate accurate number / epoch
    with torch.no_grad():
        val_bar = tqdm(validate_loader, file=sys.stdout)
        for val_data in val_bar:
            val_images, val_labels = val_data
            outputs = net(val_images.to(device))
            predict_y = torch.max(outputs, dim=1)[1]
            acc += torch.eq(predict_y, val_labels.to(device)).sum().item()

    val_accurate = acc / val_num
    print('[epoch %d] train_loss: %.3f  val_accuracy: %.3f' %
          (epoch + 1, running_loss / train_steps, val_accurate))

    if val_accurate > best_acc:
        best_acc = val_accurate
        torch.save(net.state_dict(), save_path)

print('Finished Training')

2.3 Building the prediction script: predict.py

import os
import json

import torch
from PIL import Image
from torchvision import transforms
import matplotlib.pyplot as plt

from model import AlexNet


device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

data_transform = transforms.Compose(
    [transforms.Resize((224, 224)),
     transforms.ToTensor(),
     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

# load image
img_path = "./tulip.jpg"
assert os.path.exists(img_path), "file: '{}' does not exist.".format(img_path)
img = Image.open(img_path)

plt.imshow(img)
# [N, C, H, W]
img = data_transform(img)
# expand batch dimension
img = torch.unsqueeze(img, dim=0)

# read class_indict
json_path = './class_indices.json'
assert os.path.exists(json_path), "file: '{}' does not exist.".format(json_path)

with open(json_path, "r") as f:
    class_indict = json.load(f)

# create model
model = AlexNet(num_classes=5).to(device)

# load model weights
weights_path = "./AlexNet.pth"
assert os.path.exists(weights_path), "file: '{}' does not exist.".format(weights_path)
model.load_state_dict(torch.load(weights_path, map_location=device))

model.eval()
with torch.no_grad():
    # predict class
    output = torch.squeeze(model(img.to(device))).cpu()
    predict = torch.softmax(output, dim=0)
    predict_cla = torch.argmax(predict).numpy()

print_res = "class: {}   prob: {:.3}".format(class_indict[str(predict_cla)],
                                             predict[predict_cla].numpy())
plt.title(print_res)
for i in range(len(predict)):
    print("class: {:10}   prob: {:.3}".format(class_indict[str(i)],
                                              predict[i].numpy()))
plt.show()