0
点赞
收藏
分享

微信扫一扫

微调vgg16预训练模型


微调代码

只训练全连接层

# Load the checkpoint as a whole pickled model object (its `.classifier` is
# accessed directly below, so this is not a bare state_dict).
# map_location='cpu' makes the load independent of the device the file was
# saved from; move the model to the target device afterwards.
# NOTE(review): torch.load unpickles arbitrary objects -- only load trusted
# files. On PyTorch >= 2.6 the default is weights_only=True, which rejects a
# fully pickled model; pass weights_only=False there if the file is trusted.
model = torch.load(
    '../model/20220509-pretrain-vgg16-数据增强-5e-05.pth',
    map_location='cpu')

# Replace the last classifier layer with a fresh 8-way output layer.
num_ftrs = model.classifier[6].in_features
model.classifier[6] = nn.Linear(num_ftrs, 8)

# Freeze every top-level child except 'avgpool' and 'classifier'.
# (Presumably 'avgpool' holds no learnable parameters, so in effect only the
# classifier is trained -- confirm against the model definition.)
trainable_children = ('avgpool', 'classifier')
for name, child in model.named_children():
    for param in child.parameters():
        param.requires_grad = name in trainable_children

完整代码

from glob import glob
import os
import numpy as np
import matplotlib
import matplotlib.pyplot as plt
import shutil
from torchvision import transforms
from torchvision import models
import torch
from torch.autograd import Variable
import torch.nn as nn
from torch.optim import lr_scheduler
from torch import optim
from torchvision.datasets import ImageFolder
from torchvision.utils import make_grid
from sklearn.model_selection import KFold
from mpl_toolkits.axes_grid1 import host_subplot
import time
%matplotlib inline


def try_gpu(i=0):
    """Return the i-th CUDA device if it exists, otherwise the CPU device.

    Args:
        i: Zero-based CUDA device index.

    Returns:
        ``torch.device('cuda:i')`` when ``0 <= i < torch.cuda.device_count()``,
        else ``torch.device('cpu')``.
    """
    # The lower bound matters: a negative index would otherwise pass the
    # count check and produce an invalid 'cuda:-1' device string.
    if 0 <= i < torch.cuda.device_count():
        return torch.device(f'cuda:{i}')
    return torch.device('cpu')

# Report how many PNG images each dataset split contains (one sub-directory
# level per class). `path` and `files` are left holding the values of the
# last split processed ('test').
for split_name in ('train', 'valid', 'test'):
    path = '../data/data-8-train-valid-test/' + split_name + '/'
    files = glob(os.path.join(path, '*/*.png'))
    print(f'Total {split_name} of images {len(files)}')

# Input resolution and mini-batch size shared by all loaders.
imag_size = 224
batch_size = 16

# Preprocessing: resize to a fixed square, then convert to a tensor.
_transform_steps = [
    transforms.Resize((imag_size, imag_size)),
    transforms.ToTensor(),
]
transform = transforms.Compose(_transform_steps)

# One ImageFolder dataset per split (class labels come from sub-directory names).
train_imgs = ImageFolder('../data/data-8-train-valid-test/train', transform)
valid_imgs = ImageFolder('../data/data-8-train-valid-test/valid', transform)
test_imgs = ImageFolder('../data/data-8-train-valid-test/test', transform)

# Training batches are shuffled; validation batches keep dataset order.
train_data = torch.utils.data.DataLoader(
    train_imgs, batch_size=batch_size, shuffle=True)
valid_data = torch.utils.data.DataLoader(
    valid_imgs, batch_size=batch_size, shuffle=False)

def train(data, isTrain=True):
    """Run one pass over ``data`` and return the mean loss and accuracy.

    Relies on the module-level ``model``, ``criterion``, ``optimizer`` and
    ``device``.

    Args:
        data: Iterable of ``(inputs, labels)`` batches (e.g. a DataLoader).
        isTrain: If True, put the model in training mode and update its
            weights; if False, put it in eval mode and only measure.

    Returns:
        ``(loss, acc)`` as 0-dim tensors: the per-sample mean loss and the
        fraction of correctly classified samples over the whole pass.
    """
    if isTrain:
        model.train()
    else:
        model.eval()

    running_loss = torch.zeros((), device=device)
    running_corrects = torch.zeros((), dtype=torch.long, device=device)
    sample_count = 0

    # No autograd graph is needed (or wanted) during evaluation.
    with torch.set_grad_enabled(isTrain):
        for inputs, labels in data:
            inputs, labels = inputs.to(device), labels.to(device)

            if isTrain:
                optimizer.zero_grad()

            outputs = model(inputs)
            _, preds = torch.max(outputs, 1)
            loss = criterion(outputs, labels)

            if isTrain:
                loss.backward()
                optimizer.step()

            # Weight each batch by its real size: the final batch may be
            # smaller than the others, and the criterion is assumed to return
            # a per-batch mean (the default for nn.CrossEntropyLoss).
            n = labels.size(0)
            running_loss = running_loss + loss.detach() * n
            running_corrects = running_corrects + torch.sum(preds == labels)
            sample_count += n

    # Guard against an empty iterable instead of dividing by zero.
    denom = max(sample_count, 1)
    loss = running_loss / denom
    acc = running_corrects.float() / denom

    return loss, acc

def test(data):
    """Evaluate the module-level ``model`` on ``data``.

    Relies on the module-level ``model`` and ``device``.

    Args:
        data: Iterable of ``(inputs, labels)`` batches (e.g. a DataLoader).

    Returns:
        ``(acc, running_corrects, real_lables, pred_lables)`` where ``acc`` is
        the fraction of correct predictions (0-dim tensor), ``running_corrects``
        is the number of correct predictions (0-dim tensor), and the two lists
        hold the true and predicted class indices in iteration order.
    """
    real_lables, pred_lables = [], []
    model.eval()
    running_corrects = torch.zeros((), dtype=torch.long, device=device)
    sample_count = 0

    # Inference only: skip building autograd graphs.
    with torch.no_grad():
        for inputs, labels in data:
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            _, preds = torch.max(outputs, 1)

            real_lables.extend(labels.tolist())
            pred_lables.extend(preds.tolist())
            running_corrects = running_corrects + torch.sum(preds == labels)
            sample_count += labels.size(0)

    # Divide by the number of samples actually seen; the last batch may be
    # smaller than the others.
    acc = running_corrects.float() / max(sample_count, 1)
    return acc, running_corrects, real_lables, pred_lables


# Load the checkpoint as a whole pickled model object (its `.classifier` is
# accessed directly below, so this is not a bare state_dict).
# map_location='cpu' makes the load independent of the device the file was
# saved from; the model is moved to the training device later via model.to().
# NOTE(review): torch.load unpickles arbitrary objects -- only load trusted
# files. On PyTorch >= 2.6 the default is weights_only=True, which rejects a
# fully pickled model; pass weights_only=False there if the file is trusted.
model = torch.load(
    '../model/20220509-pretrain-vgg16-数据增强-5e-05.pth',
    map_location='cpu')

# Replace the last classifier layer with a fresh 8-way output layer.
num_ftrs = model.classifier[6].in_features
model.classifier[6] = nn.Linear(num_ftrs, 8)

# Freeze every top-level child except 'avgpool' and 'classifier'.
# (Presumably 'avgpool' holds no learnable parameters, so in effect only the
# classifier is trained -- confirm against the model definition.)
trainable_children = ('avgpool', 'classifier')
for name, child in model.named_children():
    for param in child.parameters():
        param.requires_grad = name in trainable_children

# Hyperparameters and per-epoch history lists.
train_iterations, train_loss, test_accuracy = [], [], []

lr, num_epochs = 5e-4, 30
criterion = nn.CrossEntropyLoss()
# Give the optimizer only the parameters that are still trainable; frozen
# ones never receive gradients, so tracking them is pure overhead.
optimizer = optim.Adam(
    [p for p in model.parameters() if p.requires_grad], lr=lr)
device = try_gpu(0)

# Training loop: one training pass and one validation pass per epoch.
print('训练开始 on', device)
model.to(device)
ssum = time.time()
for epoch in range(num_epochs):
    s = time.time()

    train_losses, train_acc = train(train_data)
    loss, acc = train(valid_data, False)

    # History: the epoch index, the *training* loss and the validation
    # accuracy. float() accepts 0-dim tensors on any device as well as
    # plain Python numbers.
    train_iterations.append(epoch)
    train_loss.append(float(train_losses))
    test_accuracy.append(float(acc))
    # The progress line reports the validation loss and accuracy.
    print('Epoch {}/{} avgLoss: {:.8f} Acc: {:.8f} Time:{:.1f}s'.format(
        epoch, num_epochs, float(loss), float(acc), time.time() - s))
print('训练结束 TotalTime:{:.1f}s'.format(time.time() - ssum))
# Persist the whole model object; the learning rate is part of the file name.
torch.save(model, '../model/20220512-vgg16pre' + str(lr) + '.pth')


# Final evaluation on the held-out test split.
print('测试 on', device)
model.to(device)
# Shuffling buys nothing at evaluation time; keeping dataset order makes the
# returned `real` / `pre` label lists reproducible and matches valid_data.
test_data = torch.utils.data.DataLoader(
    test_imgs, shuffle=False, batch_size=batch_size)
acc, corrects, real, pre = test(test_data)
print('准确率: {:.4f} 正确预测个数: {}'.format(acc, corrects))

举报

相关推荐

0 条评论