import torch import torch.nn as nn import torch.nn.functional as F import torch.optim as optim from torch.utils.data import DataLoader, Dataset from torchvision.utils import save_image from torchvision import transforms from PIL import Image import os class SelfAttention(nn.Module): def __init__(self, in_dim): super(SelfAttention, self).__init__() self.query = nn.utils.spectral_norm(nn.Conv2d(in_dim, in_dim // 8, kernel_size=1)) self.key = nn.utils.spectral_norm(nn.Conv2d(in_dim, in_dim // 8, kernel_size=1)) self.value = nn.utils.spectral_norm(nn.Conv2d(in_dim, in_dim, kernel_size=1)) self.gamma = nn.Parameter(torch.zeros(1)) def forward(self, x): batch_size, C, width, height = x.size() proj_query = self.query(x).view(batch_size, -1, width * height).permute(0, 2, 1) proj_key = self.key(x).view(batch_size, -1, width * height) energy = torch.bmm(proj_query, proj_key) attention = F.softmax(energy, dim=-1) proj_value = self.value(x).view(batch_size, -1, width * height) out = torch.bmm(proj_value, attention.permute(0, 2, 1)) out = out.view(batch_size, C, width, height) out = self.gamma * out + x return out class Generator(nn.Module): def __init__(self, noise_dim, label_dim): super(Generator, self).__init__() self.label_dim = label_dim self.fc = nn.Sequential( nn.Linear(noise_dim + label_dim, 1024 * 4 * 4), nn.BatchNorm1d(1024 * 4 * 4), nn.ReLU(True) ) self.deconv_layers = nn.Sequential( nn.utils.spectral_norm(nn.ConvTranspose2d(1024, 512, 4, 2, 1)), # 4x4 -> 8x8 nn.BatchNorm2d(512), nn.ReLU(True), SelfAttention(512), nn.utils.spectral_norm(nn.ConvTranspose2d(512, 256, 4, 2, 1)), # 8x8 -> 16x16 nn.BatchNorm2d(256), nn.ReLU(True), nn.utils.spectral_norm(nn.ConvTranspose2d(256, 128, 4, 2, 1)), # 16x16 -> 32x32 nn.BatchNorm2d(128), nn.ReLU(True), nn.utils.spectral_norm(nn.ConvTranspose2d(128, 64, 4, 2, 1)), # 32x32 -> 64x64 nn.BatchNorm2d(64), nn.ReLU(True), SelfAttention(64), nn.utils.spectral_norm(nn.ConvTranspose2d(64, 3, 4, 2, 1)), # 64x64 -> 128x128 nn.Tanh() ) def 
forward(self, noise, labels): x = torch.cat((noise, labels), dim=1) x = self.fc(x).view(-1, 1024, 4, 4) x = self.deconv_layers(x) return x class Discriminator(nn.Module): def __init__(self, input_channels, label_dim): super(Discriminator, self).__init__() self.label_dim = label_dim self.conv1 = nn.Sequential( nn.utils.spectral_norm(nn.Conv2d(input_channels + label_dim, 64, 4, 2, 1)), nn.LeakyReLU(0.2, inplace=True) ) self.conv2 = nn.Sequential( nn.utils.spectral_norm(nn.Conv2d(64, 128, 4, 2, 1)), nn.LeakyReLU(0.2, inplace=True) ) self.conv3 = nn.Sequential( nn.utils.spectral_norm(nn.Conv2d(128, 256, 4, 2, 1)), nn.LeakyReLU(0.2, inplace=True) ) self.self_attn = SelfAttention(256) self.conv4 = nn.Sequential( nn.utils.spectral_norm(nn.Conv2d(256, 512, 4, 2, 1)), nn.LeakyReLU(0.2, inplace=True) ) self.fc = nn.utils.spectral_norm(nn.Linear(512 * 8 * 8, 1)) def forward(self, x, labels): batch_size = x.size(0) img_size = x.size(2) labels = labels.view(batch_size, self.label_dim, 1, 1) labels = labels.expand(batch_size, self.label_dim, img_size, img_size) x = torch.cat([x, labels], dim=1) x = self.conv1(x) x = self.conv2(x) x = self.conv3(x) x = self.self_attn(x) x = self.conv4(x) x = x.view(batch_size, -1) x = self.fc(x) return x class TrafficSignDataset(Dataset): def __init__(self, root_dir, labels_file, transform=None): self.root_dir = root_dir self.transform = transform self.image_paths = [] self.labels = [] with open(labels_file, 'r') as f: lines = f.readlines() for line in lines: img_name, label = line.strip().split() img_path = os.path.join(root_dir, img_name) self.image_paths.append(img_path) self.labels.append(int(label)) def __len__(self): return len(self.image_paths) def __getitem__(self, idx): img_path = self.image_paths[idx] image = Image.open(img_path).convert('RGB') label = self.labels[idx] if self.transform: image = self.transform(image) return image, label # 设置超参数 noise_dim = 100 # 噪声维度 label_dim = 58 # 标签维度 batch_size =64 # 批大小 lr = 2e-4 num_epochs = 500 
n_critic = 5  # the generator is updated once every n_critic discriminator steps
lambda_gp = 10  # weight of the gradient penalty in the discriminator loss
output_dir = "D:/PycharmProjects/GAN/traffic sign datasets/CTSDB/resized_images/MMSGAN"  # where generated sample grids are saved
os.makedirs(output_dir, exist_ok=True)

# Fall back to CPU instead of crashing when CUDA is not available.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

G = Generator(noise_dim=noise_dim, label_dim=label_dim).to(device)
D = Discriminator(input_channels=3, label_dim=label_dim).to(device)

beta1 = 0.0
beta2 = 0.9
optimizer_G = optim.Adam(G.parameters(), lr=lr, betas=(beta1, beta2))
optimizer_D = optim.Adam(D.parameters(), lr=lr, betas=(beta1, beta2))
# Halve both learning rates every 50 epochs (schedulers are stepped once per epoch).
scheduler_G = optim.lr_scheduler.StepLR(optimizer_G, step_size=50, gamma=0.5)
scheduler_D = optim.lr_scheduler.StepLR(optimizer_D, step_size=50, gamma=0.5)

transform = transforms.Compose([
    transforms.Resize((128, 128)),
    transforms.ToTensor(),
    # Maps [0, 1] pixels to [-1, 1], matching the generator's Tanh output range.
    transforms.Normalize((0.5,), (0.5,)),
])

root_dir = "D:/PycharmProjects/GAN/traffic sign datasets/CTSDB/resized_images/BorderlineSMOTE -insepct"
labels_file = "D:/PycharmProjects/GAN/traffic sign datasets/CTSDB/resized_images/BorderlineSMOTE -insepct/labels.txt"  # labels file path
dataset = TrafficSignDataset(root_dir=root_dir, labels_file=labels_file, transform=transform)
# A trailing batch of exactly one sample would make the generator's BatchNorm1d
# raise in training mode, so drop the remainder only in that case.
dataloader = DataLoader(
    dataset,
    batch_size=batch_size,
    shuffle=True,
    drop_last=len(dataset) % batch_size == 1,
)
if len(dataloader) == 0:
    # Without any batch, d_loss/g_loss would be undefined at the epoch print.
    raise ValueError(f"No training batches: found {len(dataset)} samples in {labels_file}")


def _one_hot(label_idx):
    """Return float one-hot rows of width label_dim for integer class indices."""
    return F.one_hot(label_idx, num_classes=label_dim).float()


def discriminator_hinge_loss(real_outputs, fake_outputs):
    """Hinge loss for D: mean(relu(1 - D(real))) + mean(relu(1 + D(fake)))."""
    real_loss = torch.mean(F.relu(1.0 - real_outputs))
    fake_loss = torch.mean(F.relu(1.0 + fake_outputs))
    return real_loss + fake_loss


def generator_hinge_loss(fake_outputs):
    """Hinge loss for G: -mean(D(G(z)))."""
    return -torch.mean(fake_outputs)


def compute_gradient_penalty(D, real_samples, fake_samples, labels):
    """WGAN-GP penalty mean((||grad D(x_hat, labels)||_2 - 1)^2) on interpolates.

    x_hat = alpha * real + (1 - alpha) * fake, with one alpha ~ U[0, 1) per
    sample. create_graph=True keeps the graph so the penalty itself can be
    back-propagated into D.

    NOTE(review): the caller passes the *real* labels, although the fake samples
    were generated from different labels; confirm this conditioning is intended.
    """
    alpha = torch.rand(real_samples.size(0), 1, 1, 1).to(real_samples.device)
    interpolates = (alpha * real_samples + (1 - alpha) * fake_samples).requires_grad_(True)
    d_interpolates = D(interpolates, labels)
    gradients = torch.autograd.grad(
        outputs=d_interpolates,
        inputs=interpolates,
        grad_outputs=torch.ones_like(d_interpolates),
        create_graph=True,
        retain_graph=True,
    )[0]
    gradients = gradients.reshape(gradients.size(0), -1)
    return ((gradients.norm(2, dim=1) - 1) ** 2).mean()


# Fixed inputs for per-epoch sample grids: class indices 0..label_dim-1, cycled to 64 samples.
fixed_noise = torch.randn(64, noise_dim).to(device)
fixed_labels_idx = torch.arange(0, label_dim).repeat(64 // label_dim + 1)[:64].to(device)
fixed_labels_one_hot = _one_hot(fixed_labels_idx)

for epoch in range(num_epochs):
    for i, (real_images, real_labels_idx) in enumerate(dataloader):
        real_images = real_images.to(device)
        real_labels_idx = real_labels_idx.to(device)
        batch_size_current = real_images.size(0)
        real_labels_one_hot = _one_hot(real_labels_idx)

        # Discriminator step (every iteration).
        optimizer_D.zero_grad()
        noise = torch.randn(batch_size_current, noise_dim).to(device)
        fake_labels_idx = torch.randint(0, label_dim, (batch_size_current,)).to(device)
        fake_labels_one_hot = _one_hot(fake_labels_idx)
        fake_images = G(noise, fake_labels_one_hot)

        real_outputs = D(real_images, real_labels_one_hot)
        fake_outputs = D(fake_images.detach(), fake_labels_one_hot)
        d_loss = discriminator_hinge_loss(real_outputs, fake_outputs)
        gradient_penalty = compute_gradient_penalty(D, real_images, fake_images.detach(), real_labels_one_hot)
        d_loss = d_loss + lambda_gp * gradient_penalty
        d_loss.backward()
        optimizer_D.step()

        # Generator step (every n_critic iterations). Reuses fake_images, whose
        # graph is still attached to G, against the just-updated D.
        if i % n_critic == 0:
            optimizer_G.zero_grad()
            fake_outputs = D(fake_images, fake_labels_one_hot)
            g_loss = generator_hinge_loss(fake_outputs)
            g_loss.backward()
            optimizer_G.step()

    scheduler_G.step()
    scheduler_D.step()

    # Losses reported are from the last D step and the last G step of the epoch.
    print(f"Epoch [{epoch + 1}/{num_epochs}], D Loss: {d_loss.item():.4f}, G Loss: {g_loss.item():.4f}")

    # NOTE(review): G is still in train mode here, so BatchNorm uses batch
    # statistics (and updates running stats) while sampling; consider G.eval().
    with torch.no_grad():
        fake_images = G(fixed_noise, fixed_labels_one_hot)
        save_image(fake_images, os.path.join(output_dir, f"epoch_{epoch + 1}.png"), nrow=8, normalize=True)

torch.save(G.state_dict(), 'generator.pth')
torch.save(D.state_dict(), 'discriminator.pth')