PyTorch Introduction 13 || Adversarial Neural Networks (GAN)

kmoon_b426 2022-03-10


Straight to the code:

#coding=utf-8
# A DCGAN-style generator/discriminator pair trained on 96x96 face images
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
from torchvision import datasets, transforms, models
import matplotlib.pyplot as plt
import os

data_transform = transforms.Compose([
        transforms.Resize((96, 96)),                      # resize every training image to 96x96
        transforms.RandomHorizontalFlip(),                # simple data augmentation
        transforms.ToTensor(),
        transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])   # scale pixels to [-1, 1], matching the generator's Tanh output
    ])


trainset = datasets.ImageFolder('./faces', data_transform)      # one sub-folder per class under ./faces
trainloader = torch.utils.data.DataLoader(trainset, batch_size=5, shuffle=True, num_workers=4)
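
# Quick sanity check (a sketch): ImageFolder expects ./faces to contain at least one
# class sub-folder (for example ./faces/0) that holds the images.
# images, labels = next(iter(trainloader))
# print(images.shape)    # expected: torch.Size([5, 3, 96, 96]) with batch_size=5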


def imshow(inputs, picname):
    plt.ion()

    inputs = inputs / 2 + 0.5                       # undo the Normalize([0.5]*3, [0.5]*3) step
    inputs = inputs.numpy().transpose((1, 2, 0))    # CHW -> HWC for matplotlib
    plt.imshow(inputs)
    plt.pause(0.01)

    # save the grid outside ./faces so generated samples are not mixed into the training data
    os.makedirs('output', exist_ok=True)
    plt.savefig(os.path.join('output', picname + ".jpg"))
    plt.close()

def fun1():
    # visualize one batch of real training images as a grid
    inputs, _ = next(iter(trainloader))
    imshow(torchvision.utils.make_grid(inputs), "RealDataSample")

#%%
class D(nn.Module):
    # Discriminator: four stride-2 conv blocks shrink a 96x96 image to 6x6 feature maps,
    # then a linear layer with Sigmoid outputs the probability that the image is real.
    def __init__(self, nc, ndf):        # nc: input channels, ndf: base number of feature maps
        super(D, self).__init__()
        self.layer1 = nn.Sequential(nn.Conv2d(nc,ndf,kernel_size=4,stride=2,padding=1),
                                 nn.BatchNorm2d(ndf),
                                 nn.LeakyReLU(0.2,inplace=True))
        self.layer2 = nn.Sequential(nn.Conv2d(ndf,ndf*2,kernel_size=4,stride=2,padding=1),
                                 nn.BatchNorm2d(ndf*2),
                                 nn.LeakyReLU(0.2,inplace=True))
        self.layer3 = nn.Sequential(nn.Conv2d(ndf*2,ndf*4,kernel_size=4,stride=2,padding=1),
                                 nn.BatchNorm2d(ndf*4),
                                 nn.LeakyReLU(0.2,inplace=True))
        self.layer4 = nn.Sequential(nn.Conv2d(ndf*4,ndf*8,kernel_size=4,stride=2,padding=1),
                                 nn.BatchNorm2d(ndf*8),
                                 nn.LeakyReLU(0.2,inplace=True))
        # with ndf=32, layer4 outputs ndf*8 = 256 channels at 6x6, which matches the hard-coded 256*6*6
        self.fc = nn.Sequential(nn.Linear(256*6*6,1),nn.Sigmoid())

    def forward(self,x):
        out = self.layer1(x)
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.layer4(out)
        out = out.view(-1,256*6*6)      # flatten the 6x6 feature maps for the linear layer
        out = self.fc(out)
        return out

#d = D(3,32)

#print(d(inputs))
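# Shape check (a sketch, assuming 96x96 RGB inputs and ndf=32 as used below): each stride-2
# convolution halves the spatial size, 96 -> 48 -> 24 -> 12 -> 6, so layer4 outputs
# (batch, 256, 6, 6) and the flattened size 256*6*6 used in self.fc is consistent.
# print(D(3, 32)(torch.randn(5, 3, 96, 96)).shape)    # expected: torch.Size([5, 1])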

class G(nn.Module):
    # Generator: projects a noise vector to an nz x 6 x 6 feature map, then upsamples it with
    # four stride-2 transposed convolutions to a 96x96 image in [-1, 1] (because of the final Tanh).
    def __init__(self, nc, ngf, nz, feature_size):   # nc: output channels, ngf: base feature maps, nz: projected channels, feature_size: noise dimension
        super(G, self).__init__()
        self.prj = nn.Linear(feature_size, nz*6*6)   # noise vector -> flattened nz x 6 x 6 feature map
        self.layer1 = nn.Sequential(nn.ConvTranspose2d(nz,ngf*4,kernel_size=4,stride=2,padding=1),
                                 nn.BatchNorm2d(ngf*4),
                                 nn.ReLU())
        self.layer2 = nn.Sequential(nn.ConvTranspose2d(ngf*4,ngf*2,kernel_size=4,stride=2,padding=1),
                                 nn.BatchNorm2d(ngf*2),
                                 nn.ReLU())
        self.layer3 = nn.Sequential(nn.ConvTranspose2d(ngf*2,ngf,kernel_size=4,stride=2,padding=1),
                                 nn.BatchNorm2d(ngf),
                                 nn.ReLU())
        self.layer4 = nn.Sequential(nn.ConvTranspose2d(ngf,nc,kernel_size=4,stride=2,padding=1),
                                 nn.Tanh())
    
    def forward(self,x):
        out = self.prj(x)
        out = out.view(-1,1024,6,6)     # the hard-coded 1024 must match nz
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.layer4(out)
        return out
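
# Shape check (a sketch, assuming nz=1024, ngf=128, feature_size=100 as used below): prj maps
# (batch, 100) to (batch, 1024*6*6), reshaped to (batch, 1024, 6, 6); each transposed
# convolution (kernel 4, stride 2, padding 1) doubles the spatial size, 6 -> 12 -> 24 -> 48 -> 96,
# so the generator emits (batch, 3, 96, 96) images.
# print(G(3, 128, 1024, 100)(torch.randn(5, 100)).shape)    # expected: torch.Size([5, 3, 96, 96])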

d = D(3,32)
g = G(3,128,1024,100)

criterion = nn.BCELoss()     # binary cross-entropy on the discriminator's real/fake probability

d_optimizer = torch.optim.Adam(d.parameters(),lr=0.0003)
g_optimizer = torch.optim.Adam(g.parameters(),lr=0.0003)
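
# Optional (an assumption, not part of the original script): the models could be moved to a GPU
# when one is available; each real batch and noise tensor inside train() would then also need .to(device).
# device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# d, g = d.to(device), g.to(device)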


def train(d,g,criterion,d_optimizer,g_optimizer,epochs=1,show_every=1000,print_every=10):
    iter_count = 0
    for epoch in range(epochs):
        
        for inputs,_ in trainloader:

            real_inputs = inputs

            # ---- train the discriminator ----
            fake_inputs = g(torch.randn(5, 100))            # generate a batch of fake images (noise dim = feature_size = 100)

            real_labels = torch.ones(real_inputs.size(0))   # label 1 for real images
            fake_labels = torch.zeros(5)                    # label 0 for fake images

            real_outputs = d(real_inputs)                   # discriminator scores for the real images
            d_loss_real = criterion(real_outputs.squeeze(), real_labels)   # loss of real images against the "real" label
            # real_scores = real_outputs

            fake_outputs = d(fake_inputs.detach())          # detach so this update does not backpropagate into the generator

            d_loss_fake = criterion(fake_outputs.squeeze(), fake_labels)   # loss of fake images against the "fake" label
            # fake_scores = fake_outputs

            d_loss = d_loss_real + d_loss_fake
            d_optimizer.zero_grad()
            d_loss.backward()
            d_optimizer.step()                              # update the discriminator

            # ---- train the generator ----
            fake_inputs = g(torch.randn(5, 100))            # generate a fresh batch of fake images
            outputs = d(fake_inputs)                        # score them with the freshly updated discriminator
            real_labels = torch.ones(outputs.size(0))       # the generator wants its fakes to be classified as real
            g_loss = criterion(outputs.squeeze(), real_labels)

            g_optimizer.zero_grad()
            g_loss.backward()
            g_optimizer.step()                              # update the generator

            if (iter_count % show_every == 0):
                # periodically save a grid of generated samples
                picname = "Epoch_" + str(epoch) + "Iter_" + str(iter_count)
                imshow(torchvision.utils.make_grid(fake_inputs.data), picname)
                # save_param(d,'d_model.pkl')
                # save_param(g,'g_model.pkl')

            if (iter_count%print_every == 0):
                print('Epoch:{},Iter: {}, D: {:.4}, G:{:.4}'.format(epoch,iter_count, d_loss.item(), g_loss.item()))
            iter_count += 1
        
    print('Finished Training')

def load_param(model, path):
    # restore saved weights if a checkpoint file exists; otherwise keep the current weights
    if os.path.exists(path):
        model.load_state_dict(torch.load(path))

def save_param(model, path):
    torch.save(model.state_dict(), path)

if __name__ == '__main__':
    # if os.path.exists("d_model.pkl"):
    #     load_param(d,'d_model.pkl')
    #     load_param(g,'g_model.pkl')

    train(d,g,criterion,d_optimizer,g_optimizer,epochs=300)

    # save_param(d,'d_model.pkl')
    # save_param(g,'g_model.pkl')
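
Once training has finished, the generator can be used on its own. The snippet below is a minimal sampling sketch, assuming the script above is saved as gan.py (so G, load_param and imshow can be imported) and that g_model.pkl was written by the commented-out save_param call; it draws random noise vectors and saves a grid of generated faces via imshow.

import torch
import torchvision
from gan import G, load_param, imshow

g = G(3, 128, 1024, 100)
load_param(g, 'g_model.pkl')        # keeps random weights if the checkpoint file is missing
g.eval()

with torch.no_grad():               # no gradients are needed for sampling
    fake = g(torch.randn(25, 100))  # 25 noise vectors of dimension feature_size=100
imshow(torchvision.utils.make_grid(fake, nrow=5), "GeneratedSample")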


Dataset link

Anime avatar dataset from konachan.net (machine-learning resources, CSDN library)
