本教程展示如何使用 PyTorch Lightning 实现一个图像验证码识别系统，具备结构清晰、易于扩展、便于训练的优点。
1. 安装依赖
pip install pytorch-lightning torch torchvision pillow captcha
2. 生成验证码图片
from captcha.image import ImageCaptcha
import string, random, os
from PIL import Image
# Alphabet the captcha text is drawn from: digits followed by uppercase ASCII
# letters. A character's position in this string is its class index.
characters = string.digits + string.ascii_uppercase
# Width and height handed to ImageCaptcha, and the number of characters
# drawn for each captcha text.
width, height, captcha_length = 160, 60, 4
def generate_captcha(output_dir="pl_captcha", num=5000):
    """Render ``num`` random captcha images and save them as PNG files.

    The target directory is created when missing. Every file is named
    ``<text>_<index>.png`` so the label text can be recovered from the file
    name later on.

    Args:
        output_dir: Directory the images are written to.
        num: Number of images to generate.
    """
    os.makedirs(output_dir, exist_ok=True)
    renderer = ImageCaptcha(width, height)
    for index in range(num):
        # Draw the label first, then render and store the matching picture.
        letters = random.choices(characters, k=captcha_length)
        label = ''.join(letters)
        picture = renderer.generate_image(label)
        target = f"{output_dir}/{label}_{index}.png"
        picture.save(target)
更多内容访问ttocr.com或联系1436423940
# Build the dataset on disk using the defaults (5000 images in "pl_captcha").
generate_captcha()
3. 定义数据集
from torch.utils.data import Dataset
from torchvision import transforms
import torch
class CaptchaDataset(Dataset):
    """Dataset of captcha PNG images whose labels are encoded in the file name.

    Files are expected to be named ``<text>_<suffix>.png``: everything before
    the first underscore is the label text. Each label character is mapped to
    its index in the module-level ``characters`` alphabet.
    """

    def __init__(self, data_dir):
        """Index every ``.png`` file located directly inside ``data_dir``.

        Args:
            data_dir: Directory containing the captcha images.
        """
        self.data_dir = data_dir
        # os.listdir returns entries in arbitrary order; sorting keeps sample
        # indices (and any seeded split built on top of them) reproducible.
        self.files = sorted(f for f in os.listdir(data_dir) if f.endswith('.png'))
        # The single-element mean/std is applied to every channel.
        self.transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.5,), (0.5,)),
        ])
        self.char_to_idx = {c: i for i, c in enumerate(characters)}

    def __len__(self):
        """Return the number of indexed image files."""
        return len(self.files)

    def __getitem__(self, idx):
        """Load one sample.

        Args:
            idx: Position of the sample in the sorted file list.

        Returns:
            A ``(image, label)`` pair: the transformed RGB image tensor and a
            ``torch.long`` tensor holding one class index per label character.

        Raises:
            KeyError: If the label text contains a character that is not in
                ``characters``.
        """
        filename = self.files[idx]
        path = os.path.join(self.data_dir, filename)
        # convert() returns an independent image, so the file handle can be
        # released right away instead of waiting for garbage collection.
        with Image.open(path) as img:
            image = img.convert('RGB')
        label_text = filename.split('_')[0]
        label = torch.tensor(
            [self.char_to_idx[c] for c in label_text], dtype=torch.long
        )
        return self.transform(image), label
4. 构建模型(LightningModule)
import pytorch_lightning as pl
import torch.nn as nn
import torch
class CRNNModel(pl.LightningModule):
    """CNN + bidirectional LSTM classifier for fixed-length captchas.

    The convolutional stack turns the image into a feature map whose columns
    are fed to the LSTM as a sequence. Only the first ``captcha_len`` sequence
    positions are supervised: position ``i`` is trained to predict the
    ``i``-th label character.
    """

    def __init__(self, num_classes=len(characters), captcha_len=4):
        """Build the network.

        Args:
            num_classes: Size of the output alphabet.
            captcha_len: Number of characters per captcha, i.e. how many
                leading sequence positions carry a prediction.
        """
        super().__init__()
        self.save_hyperparameters()
        self.captcha_len = captcha_len
        self.model = nn.Sequential(
            nn.Conv2d(3, 32, 3, padding=1), nn.ReLU(), nn.MaxPool2d(2),
            nn.Conv2d(32, 64, 3, padding=1), nn.ReLU(), nn.MaxPool2d(2),
            nn.Conv2d(64, 128, 3, padding=1), nn.ReLU(),
            nn.MaxPool2d((2, 1)),  # halves the height only, keeps the width
        )
        # LSTM input size is channels * feature-map height = 128 * 7. The 7
        # holds for 60-px-high inputs (60 -> 30 -> 15 -> 7); other heights
        # need a different value here.
        self.rnn = nn.LSTM(128 * 7, 128, num_layers=2, bidirectional=True, batch_first=True)
        self.fc = nn.Linear(256, num_classes)  # 256 = 2 directions * 128 hidden

    def forward(self, x):
        """Map an image batch to per-column class scores.

        Args:
            x: Image batch laid out as ``[B, 3, H, W]``.

        Returns:
            Logits of shape ``[B, W', num_classes]`` where ``W'`` is the
            feature-map width (input width divided by 4).
        """
        x = self.model(x)            # [B, C, H, W]
        x = x.permute(0, 3, 1, 2)    # [B, W, C, H]
        b, w, c, h = x.shape
        x = x.reshape(b, w, c * h)   # [B, W, C*H]
        x, _ = self.rnn(x)
        x = self.fc(x)               # [B, W, num_classes]
        return x

    def _shared_step(self, batch):
        """Compute the loss shared by the training and validation steps.

        Returns:
            ``(loss, logits, targets)`` where ``logits`` and ``targets`` are
            already restricted to the first ``captcha_len`` positions.
        """
        images, labels = batch
        logits = self(images)[:, :self.captcha_len, :]   # [B, L, C]
        targets = labels[:, :self.captcha_len]           # [B, L]
        # cross_entropy expects the class dimension second: [B, C, L].
        per_char = nn.functional.cross_entropy(
            logits.transpose(1, 2), targets, reduction="none"
        )                                                # [B, L]
        # Sum over positions of the per-position batch mean, i.e. the same
        # quantity as adding one mean cross-entropy per character.
        loss = per_char.mean(dim=0).sum()
        return loss, logits, targets

    def training_step(self, batch, batch_idx):
        """Return the summed per-character cross-entropy for one batch."""
        loss, _, _ = self._shared_step(batch)
        self.log("train_loss", loss)
        return loss

    def validation_step(self, batch, batch_idx):
        """Log validation loss and whole-captcha accuracy for one batch.

        Without this hook Lightning skips the validation loop even when a
        validation dataloader is passed to ``Trainer.fit``.
        """
        loss, logits, targets = self._shared_step(batch)
        preds = logits.argmax(dim=2)
        # A captcha counts as correct only if every character matches.
        acc = (preds == targets).all(dim=1).float().mean()
        self.log("val_loss", loss, prog_bar=True)
        self.log("val_acc", acc, prog_bar=True)
        return loss

    def configure_optimizers(self):
        """Use Adam with a fixed learning rate of 1e-3."""
        return torch.optim.Adam(self.parameters(), lr=1e-3)
5. 准备训练入口
from torch.utils.data import DataLoader, random_split
# Load the generated images and hold out 10% of them for validation.
dataset = CaptchaDataset("pl_captcha")
num_samples = len(dataset)
train_size = int(0.9 * num_samples)
val_size = num_samples - train_size
train_ds, val_ds = random_split(dataset, [train_size, val_size])

# Only the training loader reshuffles its samples every epoch.
train_loader = DataLoader(train_ds, shuffle=True, batch_size=64)
val_loader = DataLoader(val_ds, batch_size=64)

# Train for 10 epochs on whatever accelerator Lightning selects.
model = CRNNModel()
trainer = pl.Trainer(accelerator="auto", max_epochs=10)
trainer.fit(model, train_loader, val_loader)
6. 推理函数
def predict(model, image_path):
    """Recognise the captcha text in a single image file.

    The model is switched to evaluation mode and left in that mode.

    Args:
        model: Trained network returning logits of shape ``[1, W, C]``. If it
            exposes a ``captcha_len`` attribute, only that many leading
            sequence positions are decoded.
        image_path: Path of the image to recognise.

    Returns:
        The predicted text, one character of ``characters`` per decoded
        position.
    """
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,)),
    ])
    # Close the file handle as soon as the pixels have been read.
    with Image.open(image_path) as img:
        x = transform(img.convert('RGB')).unsqueeze(0)
    # Run on whichever device the model currently lives on.
    x = x.to(next(model.parameters()).device)
    model.eval()
    with torch.no_grad():
        output = model(x)  # [1, W, C]
    pred = torch.argmax(output, dim=2)[0]
    # The network emits one prediction per feature-map column, but only the
    # first captcha_len positions are trained; decoding all of them would
    # append meaningless characters to the result.
    captcha_len = getattr(model, "captcha_len", None)
    if captcha_len is not None:
        pred = pred[:captcha_len]
    return ''.join(characters[i] for i in pred.tolist())
# Captcha file names contain random text, so a hard-coded path such as
# "pl_captcha/A8Z5_1.png" almost never exists. Use the first generated image.
sample_name = sorted(f for f in os.listdir("pl_captcha") if f.endswith(".png"))[0]
print(predict(model, os.path.join("pl_captcha", sample_name)))