PyTorch如何训练完后在另外一个数据集训练
在深度学习中,我们经常需要在一个数据集上训练模型,并希望在另外一个数据集上进行微调或继续训练。这篇文章将介绍如何使用PyTorch在一个数据集上训练完后,在另外一个数据集上进行训练。
问题描述
假设我们有一个预训练好的图像分类模型,该模型在ImageNet数据集上进行了训练。我们现在想将该模型应用于一个新的任务,该任务是将手写数字进行分类。我们希望能够利用预训练模型的知识,加速在手写数字数据集上的训练,并提高模型的准确率。
方案
我们将采用迁移学习的思想,即利用预训练模型的特征提取能力和权重初始化,然后在新的数据集上进行微调。
步骤一:加载预训练模型
首先,我们需要加载预训练模型,并将其拆分为特征提取器和分类器。
import torch
import torchvision.models as models
# 加载预训练模型
model = models.resnet18(pretrained=True)
# 拆分模型为特征提取器和分类器
feature_extractor = torch.nn.Sequential(*list(model.children())[:-1])
classifier = torch.nn.Linear(model.fc.in_features, num_classes)
# 冻结特征提取器的参数
for param in feature_extractor.parameters():
param.requires_grad = False
步骤二:准备新数据集
接下来,我们需要准备新的手写数字数据集,包括训练集和验证集。这些数据集应该与预训练模型所用的数据集具有相似的数据分布。
import torchvision.datasets as datasets
import torchvision.transforms as transforms
# 定义数据转换
transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.5,), (0.5,))
])
# 加载手写数字数据集
train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
val_dataset = datasets.MNIST(root='./data', train=False, download=True, transform=transform)
# 定义数据加载器
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
步骤三:微调模型
现在,我们可以使用新的数据集训练模型,同时保持预训练模型的特征提取器部分不变,并只训练分类器部分。
import torch.optim as optim
# 定义优化器和损失函数
optimizer = optim.Adam(classifier.parameters(), lr=learning_rate)
criterion = torch.nn.CrossEntropyLoss()
# 训练模型
for epoch in range(num_epochs):
for images, labels in train_loader:
# 前向传播
features = feature_extractor(images)
features = features.view(features.size(0), -1)
outputs = classifier(features)
# 计算损失
loss = criterion(outputs, labels)
# 反向传播和优化
optimizer.zero_grad()
loss.backward()
optimizer.step()
# 在验证集上计算准确率
correct = 0
total = 0
with torch.no_grad():
for images, labels in val_loader:
features = feature_extractor(images)
features = features.view(features.size(0), -1)
outputs = classifier(features)
_, predicted = torch.max(outputs.data, 1)
total += labels.size(0)
correct += (predicted == labels).sum().item()
accuracy = 100 * correct / total
# 打印训练过程中的准确率
print(f'Epoch [{epoch+1}/{num_epochs}], Accuracy: {accuracy:.2f}%')
序列图
下面是一个使用迁移学习进行微调的训练过程的序列图。
sequenceDiagram
participant Model
participant FeatureExtractor
participant Classifier
participant DataPreparation
participant Optimizer
participant DataLoader
participant LossFunction
participant Validation
Model->>FeatureExtractor: 提取特征