0
点赞
收藏
分享

微信扫一扫

图像验证码识别:使用 PyTorch Lightning 实现 CRNN 模型

Gascognya 05-01 18:00 阅读 13

本教程展示如何使用 PyTorch Lightning 实现一个图像验证码识别系统,具备清晰结构、易于扩展、便于训练的优点。

1. 安装依赖

pip install pytorch-lightning torch torchvision pillow captcha

2. 生成验证码图片

from captcha.image import ImageCaptcha
import string, random, os
from PIL import Image

# Alphabet the captcha text is drawn from: digits 0-9 followed by A-Z.
# A character's position in this string is its class index.
characters = string.digits + string.ascii_uppercase

# Captcha image size in pixels and number of characters per captcha.
width = 160
height = 60
captcha_length = 4

def generate_captcha(output_dir="pl_captcha", num=5000):
    """Render ``num`` random captcha images into ``output_dir``.

    Each image is saved as ``<text>_<index>.png`` where ``<text>`` is the
    ``captcha_length``-character string drawn (with replacement) from
    ``characters`` and ``<index>`` is the loop counter, so the label can be
    recovered from the filename later.

    Args:
        output_dir: Directory to write into; created if it does not exist.
        num: Number of images to generate.
    """
    os.makedirs(output_dir, exist_ok=True)
    generator = ImageCaptcha(width, height)
    for index in range(num):
        label = ''.join(random.choices(characters, k=captcha_length))
        # The running index keeps filenames unique even if a label repeats.
        generator.generate_image(label).save(f"{output_dir}/{label}_{index}.png")
更多内容访问ttocr.com或联系1436423940
# Build the dataset with the function's defaults (5000 images in "pl_captcha").
generate_captcha()

3. 定义数据集

from torch.utils.data import Dataset
from torchvision import transforms
import torch

class CaptchaDataset(Dataset):
    """Dataset of captcha PNGs whose label is encoded in the filename.

    Files are expected to be named ``<text>_<anything>.png``; the part before
    the first underscore is the captcha text. Every character of that text
    must appear in the module-level ``characters`` alphabet.

    Args:
        data_dir: Directory containing the ``.png`` files.
    """

    def __init__(self, data_dir):
        self.data_dir = data_dir
        # os.listdir returns entries in arbitrary order; sort so that sample
        # indices (and therefore any seeded split) are reproducible.
        self.files = sorted(f for f in os.listdir(data_dir) if f.endswith('.png'))
        self.transform = transforms.Compose([
            transforms.ToTensor(),
            # Single-value mean/std is applied to every channel: [0, 1] -> [-1, 1].
            transforms.Normalize((0.5,), (0.5,))
        ])
        self.char_to_idx = {c: i for i, c in enumerate(characters)}

    def __len__(self):
        """Return the number of PNG files found in ``data_dir``."""
        return len(self.files)

    def __getitem__(self, idx):
        """Return ``(image_tensor, label)`` for the file at position ``idx``.

        ``image_tensor`` is the transformed RGB image; ``label`` is a 1-D
        ``torch.long`` tensor holding one class index per character.

        Raises:
            KeyError: If the filename contains a character that is not in
                ``characters``.
        """
        filename = self.files[idx]
        path = os.path.join(self.data_dir, filename)
        label_text = filename.split('_')[0]
        label = torch.tensor([self.char_to_idx[c] for c in label_text], dtype=torch.long)
        # convert() returns a fully loaded copy, so the file handle can be
        # released as soon as the with-block exits.
        with Image.open(path) as img:
            image = img.convert('RGB')
        return self.transform(image), label

4. 构建模型(LightningModule)

import pytorch_lightning as pl
import torch.nn as nn
import torch

class CRNNModel(pl.LightningModule):
    """CNN + bidirectional LSTM classifier for fixed-length captchas.

    The convolutional stack halves the image height three times and the width
    twice, each remaining column becomes one time step for the LSTM, and a
    linear layer maps every time step to ``num_classes`` logits. Only the
    first ``captcha_len`` time steps are supervised (one per character).

    Args:
        num_classes: Size of the character alphabet.
        captcha_len: Number of characters in each captcha.
    """

    def __init__(self, num_classes=len(characters), captcha_len=4):
        super().__init__()
        self.save_hyperparameters()
        self.captcha_len = captcha_len
        self.model = nn.Sequential(
            nn.Conv2d(3, 32, 3, padding=1), nn.ReLU(), nn.MaxPool2d(2),
            nn.Conv2d(32, 64, 3, padding=1), nn.ReLU(), nn.MaxPool2d(2),
            nn.Conv2d(64, 128, 3, padding=1), nn.ReLU(),
            # Pool height only so the width (sequence length) is preserved.
            nn.MaxPool2d((2, 1))
        )
        # 128 * 7 hard-codes a feature-map height of 7, which is what a
        # 60-pixel-high input yields after the three height poolings
        # (60 -> 30 -> 15 -> 7). Other input heights need a different size.
        self.rnn = nn.LSTM(128 * 7, 128, num_layers=2, bidirectional=True, batch_first=True)
        # Bidirectional LSTM concatenates both directions: 2 * 128 features.
        self.fc = nn.Linear(256, num_classes)

    def forward(self, x):
        """Return per-time-step logits of shape ``[B, W, num_classes]``."""
        x = self.model(x)             # [B, C, H, W]
        x = x.permute(0, 3, 1, 2)     # [B, W, C, H]
        b, w, c, h = x.shape
        x = x.reshape(b, w, c*h)      # [B, W, C*H]
        x, _ = self.rnn(x)
        x = self.fc(x)                # [B, W, num_classes]
        return x

    def _shared_step(self, batch):
        """Compute the loss and whole-captcha accuracy for one batch.

        The loss is the sum over character positions of the batch-mean
        cross-entropy between time step ``i`` and label ``i``.
        """
        images, labels = batch
        logits = self(images)[:, :self.captcha_len, :]  # [B, captcha_len, C]
        loss = sum(
            nn.functional.cross_entropy(logits[:, i, :], labels[:, i])
            for i in range(self.captcha_len)
        )
        preds = logits.argmax(dim=2)
        # A captcha counts as correct only if every character matches.
        acc = (preds == labels[:, :self.captcha_len]).all(dim=1).float().mean()
        return loss, acc

    def training_step(self, batch, batch_idx):
        """Return the training loss for ``batch`` and log it as ``train_loss``."""
        loss, _ = self._shared_step(batch)
        self.log("train_loss", loss)
        return loss

    def validation_step(self, batch, batch_idx):
        """Log ``val_loss`` and ``val_acc`` for ``batch``.

        Without this hook Lightning skips the validation loop even when a
        validation dataloader is passed to ``Trainer.fit``.
        """
        loss, acc = self._shared_step(batch)
        self.log("val_loss", loss, prog_bar=True)
        self.log("val_acc", acc, prog_bar=True)
        return loss

    def configure_optimizers(self):
        """Use Adam over all parameters with a learning rate of 1e-3."""
        return torch.optim.Adam(self.parameters(), lr=1e-3)

5. 准备训练入口

from torch.utils.data import DataLoader, random_split

# Load every generated captcha and hold out 10% of it for validation.
dataset = CaptchaDataset("pl_captcha")
train_size = int(0.9 * len(dataset))
val_size = len(dataset) - train_size
train_ds, val_ds = random_split(dataset, [train_size, val_size])

# Only the training loader is shuffled; validation order does not matter.
train_loader = DataLoader(train_ds, batch_size=64, shuffle=True)
val_loader = DataLoader(val_ds, batch_size=64)

model = CRNNModel()

# accelerator="auto" lets Lightning pick the available device.
trainer = pl.Trainer(max_epochs=10, accelerator="auto")
trainer.fit(model, train_loader, val_loader)

6. 推理函数

def predict(model, image_path):
    """Predict the captcha text shown in the image at ``image_path``.

    The image gets the same preprocessing as the training data (RGB,
    ``ToTensor``, normalisation to [-1, 1]). The model is switched to eval
    mode and left there.

    Args:
        model: Trained network returning logits of shape ``[1, W, C]``.
        image_path: Path to the captcha image file.

    Returns:
        The predicted string, one character per supervised time step.
    """
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,))
    ])
    # Close the file handle deterministically once the pixels are loaded.
    with Image.open(image_path) as img:
        x = transform(img.convert('RGB')).unsqueeze(0)

    # Run on whichever device the model lives on (it may still be on an
    # accelerator after training).
    first_param = next(model.parameters(), None)
    if first_param is not None:
        x = x.to(first_param.device)

    # Training only supervises the first `captcha_len` time steps; decoding
    # every one of the W steps would return W characters, most of them noise.
    seq_len = getattr(model, "captcha_len", captcha_length)

    model.eval()
    with torch.no_grad():
        output = model(x)  # [1, W, C]
        pred = torch.argmax(output, dim=2)[0, :seq_len]
        return ''.join(characters[i] for i in pred.tolist())

# Generated filenames embed random text, so a hard-coded name is unlikely to
# exist; run the demo on the first PNG actually present in the directory.
sample_name = next(f for f in sorted(os.listdir("pl_captcha")) if f.endswith(".png"))
print(predict(model, os.path.join("pl_captcha", sample_name)))

举报

相关推荐

0 条评论