0
点赞
收藏
分享

微信扫一扫

使用PyTorch实现验证码识别

本文通过PyTorch从零实现一个验证码图像识别系统,包括数据生成、模型搭建、训练和测试。

1. 安装必要库

首先安装所需Python库:

pip install torch torchvision pillow captcha numpy

2. 生成验证码数据集

使用captcha库生成包含数字和大写字母的验证码图像。

from captcha.image import ImageCaptcha
import random
import string
import os

# Alphabet the captchas are drawn from: the ten digits followed by A-Z.
characters = "".join((string.digits, string.ascii_uppercase))
# Number of characters in every captcha.
captcha_length = 4
# Image size handed to ImageCaptcha as width / height.
image_width = 160
image_height = 60

def generate_dataset(num_images=10000, save_path="dataset"):
    """Render random captcha images and store them in ``save_path`` as PNGs.

    Every file is named ``<label>_<index>.png`` so that the label text can
    be recovered from the file name alone.

    Args:
        num_images: How many images to generate.
        save_path: Output directory; created when it does not exist yet.
    """
    os.makedirs(save_path, exist_ok=True)
    captcha_factory = ImageCaptcha(width=image_width, height=image_height)
    for index in range(num_images):
        # Draw the label first, then render it, once per image.
        drawn = random.choices(characters, k=captcha_length)
        text = ''.join(drawn)
        target = os.path.join(save_path, f"{text}_{index}.png")
        captcha_factory.generate_image(text).save(target)

# Build the dataset with the defaults: 10000 images in the "dataset" directory.
generate_dataset()

3. 定义数据集类

构建自定义数据集用于读取图像和对应标签。

from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
import torch

class CaptchaDataset(Dataset):
    """Dataset of captcha images whose labels are encoded in the file names.

    Files are expected to be named ``<label>_<anything>.png``; the text
    before the first underscore is used as the label.

    Args:
        root_dir: Directory that contains the captcha PNG files.
    """

    def __init__(self, root_dir):
        self.root_dir = root_dir
        # os.listdir returns entries in arbitrary order and may include
        # unrelated files; keep only PNGs and sort them so the
        # index -> file mapping is reproducible.
        self.image_list = sorted(
            name for name in os.listdir(root_dir)
            if name.lower().endswith(".png")
        )
        # ToTensor scales pixels to [0, 1]; Normalize with mean 0.5 and
        # std 0.5 then maps them to [-1, 1].
        self.transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.5,), (0.5,)),
        ])
        # Maps each character of the alphabet to its class index.
        self.char_to_idx = {char: idx for idx, char in enumerate(characters)}

    def __len__(self):
        """Return the number of images in the dataset."""
        return len(self.image_list)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for the file at position ``idx``.

        ``image`` is the transformed RGB image tensor and ``label`` is a
        ``torch.long`` tensor with one class index per label character.

        Raises:
            KeyError: If the label contains a character that is not in
                ``characters``.
        """
        file_name = self.image_list[idx]
        label_text = file_name.split('_')[0]
        # Close the file handle as soon as the pixel data has been copied.
        with Image.open(os.path.join(self.root_dir, file_name)) as raw:
            image = raw.convert('RGB')
        image = self.transform(image)
        label = torch.tensor(
            [self.char_to_idx[c] for c in label_text], dtype=torch.long
        )
        return image, label

# Read the captcha images from the "dataset" directory.
dataset = CaptchaDataset(root_dir="dataset")
# Mini-batches of 64 samples, reshuffled at every epoch.
dataloader = DataLoader(dataset=dataset, batch_size=64, shuffle=True)

4. 构建模型

模型先用卷积层提取特征，再用双向LSTM建模序列关系。对于160×60的输入，特征图宽度为40，因此模型输出形状为 [batch, 40, 36]；训练时只对前4个时间步计算损失，分别对应验证码的4个字符。

import torch.nn as nn

class CaptchaModel(nn.Module):
    """CNN feature extractor followed by a bidirectional LSTM classifier.

    The CNN halves the height three times and the width twice. Every
    column of the resulting feature map becomes one step of the sequence
    fed to the LSTM, and a linear layer scores each step against the
    character alphabet.
    """

    def __init__(self):
        super().__init__()
        self.cnn = nn.Sequential(
            nn.Conv2d(3, 32, 3, padding=1), nn.ReLU(), nn.MaxPool2d(2),
            nn.Conv2d(32, 64, 3, padding=1), nn.ReLU(), nn.MaxPool2d(2),
            nn.Conv2d(64, 128, 3, padding=1), nn.ReLU(),
            # Pool the height only, so the width (sequence length) is kept.
            nn.MaxPool2d((2, 1)),
        )
        # 128 channels * 7 rows per column; the 7 assumes a 60-pixel-high
        # input (60 -> 30 -> 15 -> 7 after the three height poolings).
        self.rnn = nn.LSTM(128 * 7, 128, num_layers=2, bidirectional=True, batch_first=True)
        # 256 = 2 directions * 128 hidden units.
        self.fc = nn.Linear(256, len(characters))

    def forward(self, x):
        """Score every feature-map column against the alphabet.

        Args:
            x: Image batch of shape ``[batch, 3, height, width]``.

        Returns:
            Logits of shape ``[batch, width // 4, len(characters)]``.
        """
        x = self.cnn(x)
        x = x.permute(0, 3, 1, 2)  # [batch, width, channels, height]
        B, W, C, H = x.shape
        # reshape instead of view: a permuted tensor is not guaranteed to
        # be viewable (e.g. with channels_last memory format), and reshape
        # copies only when it has to.
        x = x.reshape(B, W, C * H)
        x, _ = self.rnn(x)
        x = self.fc(x)
        return x

5. 训练模型

设置损失函数、优化器并开始训练。

# Train on the GPU when one is available, otherwise on the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = CaptchaModel().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
criterion = nn.CrossEntropyLoss()

for epoch in range(20):
    model.train()
    total_loss = 0
    for images, labels in dataloader:
        images = images.to(device)
        labels = labels.to(device)

        optimizer.zero_grad()
        outputs = model(images)
        # One cross-entropy term per character position: output step i is
        # trained to predict the i-th character of the label.
        position_losses = [
            criterion(outputs[:, pos], labels[:, pos])
            for pos in range(captcha_length)
        ]
        loss = sum(position_losses)
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
    # Reported value is the sum of the batch losses over the epoch.
    print(f"Epoch {epoch+1}, Loss: {total_loss:.4f}")

6. 测试模型

加载一张图片并进行预测。

def predict(model, img_path):
    """Predict the captcha text shown in the image at ``img_path``.

    Args:
        model: Trained model returning logits of shape
            ``[batch, steps, len(characters)]``.
        img_path: Path of the image file to decode.

    Returns:
        The predicted text, ``captcha_length`` characters long.
    """
    model.eval()
    # Close the file handle as soon as the pixel data has been copied.
    with Image.open(img_path) as raw:
        image = raw.convert('RGB')
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,))
    ])
    image = transform(image).unsqueeze(0).to(device)
    with torch.no_grad():
        output = model(image)
    # The model emits more steps than there are characters, and training
    # only supervises the first captcha_length of them; decoding every
    # step would append meaningless characters to the result.
    pred = output[0, :captcha_length].argmax(dim=1)
    pred_text = ''.join(characters[i] for i in pred.tolist())
    return pred_text

# Labels are drawn at random when the dataset is generated, so a hard-coded
# file name is very unlikely to exist; use a file that is actually there.
sample_files = sorted(
    name for name in os.listdir("dataset") if name.lower().endswith(".png")
)
if not sample_files:
    raise FileNotFoundError(
        'no PNG images found in "dataset"; run generate_dataset() first'
    )
test_img = os.path.join("dataset", sample_files[0])
print("Predicted:", predict(model, test_img))

举报

相关推荐

0 条评论