微调代码
只训练全连接层
# Load the previously saved model object (the whole module, not just weights).
model = torch.load('../model/20220509-pretrain-vgg16-数据增强-5e-05.pth')

# Swap the last classifier layer for a fresh 8-way linear head that keeps the
# same input width as the layer it replaces.
num_ftrs = model.classifier[6].in_features
model.classifier[6] = nn.Linear(num_ftrs, 8)

# Leave only the 'avgpool' and 'classifier' children trainable; every other
# top-level child of the model is frozen.
for name, child in model.named_children():
    is_trainable = name in ['avgpool', 'classifier']
    for param in child.parameters():
        param.requires_grad = is_trainable
完整代码
from glob import glob
import os
import numpy as np
import matplotlib
import matplotlib.pyplot as plt
import shutil
from torchvision import transforms
from torchvision import models
import torch
from torch.autograd import Variable
import torch.nn as nn
from torch.optim import lr_scheduler
from torch import optim
from torchvision.datasets import ImageFolder
from torchvision.utils import make_grid
from sklearn.model_selection import KFold
from mpl_toolkits.axes_grid1 import host_subplot
import time
%matplotlib inline
def try_gpu(i=0):
    """Return the ``cuda:i`` device if it exists, otherwise the CPU device.

    The GPU is treated as present when ``torch.cuda.device_count()`` is at
    least ``i + 1``.
    """
    gpu_present = torch.cuda.device_count() >= i + 1
    return torch.device(f'cuda:{i}') if gpu_present else torch.device('cpu')
# Report how many PNG images each split contains (one sub-folder per class).
for split_name, path in (
    ('train', '../data/data-8-train-valid-test/train/'),
    ('valid', '../data/data-8-train-valid-test/valid/'),
    ('test', '../data/data-8-train-valid-test/test/'),
):
    files = glob(os.path.join(path, '*/*.png'))
    print(f'Total {split_name} of images {len(files)}')
# Input resolution fed to the network and the mini-batch size.
imag_size = 224
batch_size = 16

# Every image is resized to imag_size x imag_size and converted to a tensor.
transform = transforms.Compose([
    transforms.Resize(size=(imag_size, imag_size)),
    transforms.ToTensor(),
])

# One ImageFolder dataset per split, all sharing the same preprocessing.
train_imgs = ImageFolder(
    root='../data/data-8-train-valid-test/train', transform=transform)
valid_imgs = ImageFolder(
    root='../data/data-8-train-valid-test/valid', transform=transform)
test_imgs = ImageFolder(
    root='../data/data-8-train-valid-test/test', transform=transform)

# Training batches are reshuffled on each pass; validation order stays fixed.
train_data = torch.utils.data.DataLoader(
    dataset=train_imgs, batch_size=batch_size, shuffle=True)
valid_data = torch.utils.data.DataLoader(
    dataset=valid_imgs, batch_size=batch_size, shuffle=False)
def train(data, isTrain=True):
    """Run one full pass over ``data`` and return ``(loss, acc)``.

    Uses the module-level ``model``, ``criterion``, ``optimizer`` and
    ``device``.

    Args:
        data: Iterable of ``(inputs, labels)`` tensor batches, e.g. a
            ``DataLoader``.
        isTrain: If true, switch the model to training mode and take an
            optimizer step after every batch. Otherwise switch to evaluation
            mode and only measure, with gradient tracking disabled.

    Returns:
        ``(loss, acc)``: the loss averaged over the samples seen and the
        fraction of samples classified correctly. Both are tensors, so
        callers can keep calling ``.item()`` on them.

    Raises:
        ZeroDivisionError: If ``data`` yields no samples.
    """
    if isTrain:
        model.train()
    else:
        model.eval()

    running_loss = 0.0
    running_corrects = 0
    num_samples = 0

    # Autograd bookkeeping is only needed when weights are being updated.
    with torch.set_grad_enabled(bool(isTrain)):
        for inputs, labels in data:
            inputs, labels = inputs.to(device), labels.to(device)
            if isTrain:
                optimizer.zero_grad()
            outputs = model(inputs)
            _, preds = torch.max(outputs.detach(), 1)
            loss = criterion(outputs, labels)
            if isTrain:
                loss.backward()
                optimizer.step()

            # Count the samples actually seen instead of assuming every batch
            # is full: a short final batch would otherwise deflate both
            # metrics.
            n = labels.size(0)
            # Assumes ``criterion`` returns the batch mean (the default
            # reduction); weighting by ``n`` turns it back into a per-sample
            # sum.
            running_loss += loss.detach() * n
            running_corrects += torch.sum(preds == labels)
            num_samples += n

    loss = running_loss / num_samples
    acc = running_corrects / num_samples
    return loss, acc
def test(data):
    """Evaluate the module-level ``model`` on ``data``.

    Uses the module-level ``model`` and ``device``. The model is put in
    evaluation mode and gradient tracking is disabled for the whole pass.

    Args:
        data: Iterable of ``(inputs, labels)`` tensor batches, e.g. a
            ``DataLoader``.

    Returns:
        ``(acc, running_corrects, real_labels, pred_labels)``: the fraction
        of correctly classified samples (tensor), the number of correct
        predictions (tensor), and two flat lists of Python numbers holding
        the true and predicted class indices in iteration order.

    Raises:
        ZeroDivisionError: If ``data`` yields no samples.
    """
    real_labels, pred_labels = [], []
    model.eval()
    running_corrects = 0
    num_samples = 0

    # Pure inference: no autograd graph is needed.
    with torch.no_grad():
        for inputs, labels in data:
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            _, preds = torch.max(outputs, 1)
            real_labels.extend(labels.tolist())
            pred_labels.extend(preds.tolist())
            running_corrects += torch.sum(preds == labels)
            # Count real samples so a short final batch does not deflate the
            # accuracy.
            num_samples += labels.size(0)

    acc = running_corrects / num_samples
    return acc, running_corrects, real_labels, pred_labels
# Load the previously saved model object (the whole module, not just weights).
model = torch.load('../model/20220509-pretrain-vgg16-数据增强-5e-05.pth')

# Swap the last classifier layer for a fresh 8-way linear head that keeps the
# same input width as the layer it replaces.
num_ftrs = model.classifier[6].in_features
model.classifier[6] = nn.Linear(num_ftrs, 8)

# Leave only the 'avgpool' and 'classifier' children trainable; every other
# top-level child of the model is frozen.
for name, child in model.named_children():
    is_trainable = name in ['avgpool', 'classifier']
    for param in child.parameters():
        param.requires_grad = is_trainable
# Set hyperparameters
lr = 5e-4
num_epochs = 30

# Per-epoch history: epoch index plus the loss/accuracy appended by the
# training loop.
train_iterations = []
train_loss = []
test_accuracy = []

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(params=model.parameters(), lr=lr)
device = try_gpu(0)
# Start training
print('训练开始 on', device)
model.to(device)
ssum = time.time()
for epoch in range(num_epochs):
    s = time.time()
    # One updating pass over the training set, then a measuring-only pass
    # over the validation set.
    train_losses, train_acc = train(train_data)
    loss, acc = train(valid_data, isTrain=False)

    # Only the validation metrics are recorded and printed.
    train_iterations.append(epoch)
    train_loss.append(loss.cpu().item())
    test_accuracy.append(acc.cpu().item())

    epoch_report = 'Epoch {}/{} avgLoss: {:.8f} Acc: {:.8f} Time:{:.1f}s'
    print(epoch_report.format(epoch, num_epochs, loss, acc, time.time() - s))

total_report = '训练结束 TotalTime:{:.1f}s'
print(total_report.format(time.time() - ssum))

# Persist the whole fine-tuned model object; the file name embeds the
# learning rate.
torch.save(model, '../model/20220512-vgg16pre' + str(lr) + '.pth')
# Evaluate the model on the test split and report accuracy and hit count.
print('测试 on', device)
model.to(device)
test_data = torch.utils.data.DataLoader(
    dataset=test_imgs, batch_size=batch_size, shuffle=True)
acc, corrects, real, pre = test(test_data)
summary = '准确率: {:.4f} 正确预测个数: {}'
print(summary.format(acc, corrects))