自编码器(AutoEncoder)
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
@Time : 2019/11/26 13:47
@Author : 我是天才很好
@Blog : https://blog.csdn.net/weixin_43593330
@Email : 1103540209@qq.com
@File : ae.py
@Software: PyCharm
"""
from torch import nn
class AE(nn.Module):
    """Fully-connected autoencoder for flattened 28x28 (MNIST-style) images.

    Maps [b, 1, 28, 28] -> 20-d latent code -> reconstructed [b, 1, 28, 28].
    The decoder ends in Sigmoid, so reconstructions lie in [0, 1].
    """

    def __init__(self):
        super(AE, self).__init__()
        # Encoder: [b, 784] -> [b, 20] latent code.
        self.encoder = nn.Sequential(
            nn.Linear(784, 256),
            nn.ReLU(),
            nn.Linear(256, 64),
            nn.ReLU(),
            nn.Linear(64, 20),
            nn.ReLU(),
        )
        # Decoder: [b, 20] -> [b, 784], squashed to [0, 1] by Sigmoid.
        self.decoder = nn.Sequential(
            nn.Linear(20, 64),
            nn.ReLU(),
            nn.Linear(64, 256),
            nn.ReLU(),
            nn.Linear(256, 784),
            nn.Sigmoid(),
        )

    def forward(self, x):
        """Encode then decode a batch of images.

        :param x: tensor of shape [b, 1, 28, 28]
        :return: reconstruction of the same shape [b, 1, 28, 28]
        """
        n = x.size(0)
        flat = x.view(n, 784)       # flatten each image to a 784-vector
        code = self.encoder(flat)   # compress to the 20-d latent space
        recon = self.decoder(code)  # expand back to pixel space
        return recon.view(n, 1, 28, 28)
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
@Time : 2019/11/26 13:35
@Author : 我是天才很好
@Blog : https://blog.csdn.net/weixin_43593330
@Email : 1103540209@qq.com
@File : AEmain.py
@Software: PyCharm
"""
import torch
from torch.utils.data import DataLoader
from torch import nn, optim
from torchvision import transforms, datasets
from ae import AE
import visdom
def main():
    """Train the AE on MNIST for 10 epochs and show reconstructions in visdom.

    Side effects: downloads MNIST into ./mnist, prints per-epoch loss, and
    pushes image grids to a running visdom server.
    """
    mnist_train = datasets.MNIST('mnist', True, transform=transforms.Compose([transforms.ToTensor()]), download=True)
    mnist_train = DataLoader(mnist_train, batch_size=32, shuffle=True)
    mnist_test = datasets.MNIST('mnist', False, transform=transforms.Compose([transforms.ToTensor()]), download=True)
    mnist_test = DataLoader(mnist_test, batch_size=32, shuffle=True)

    # Peek at one batch to confirm the expected [b, 1, 28, 28] shape.
    # BUGFIX: iter(...).next() is Python 2 syntax; Python 3 requires next(iter(...)).
    x, _ = next(iter(mnist_train))
    print('x:', x.shape)

    # BUGFIX: fall back to CPU when CUDA is absent (was hard-coded 'cuda').
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = AE().to(device)
    criteon = nn.MSELoss()
    optimizer = optim.Adam(model.parameters(), lr=1e-3)
    print(model)

    viz = visdom.Visdom()

    for epoch in range(10):
        for x, _ in mnist_train:  # labels unused: reconstruction target is x itself
            # [b, 1, 28, 28]
            x = x.to(device)
            x_hat = model(x)
            loss = criteon(x_hat, x)
            # backprop
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        # Loss of the last batch only — a cheap progress indicator, not an epoch mean.
        print(epoch, 'loss:', loss.item())

        # Visualize reconstructions on one held-out test batch.
        x, _ = next(iter(mnist_test))
        x = x.to(device)
        with torch.no_grad():
            x_hat = model(x)
        # BUGFIX: title typo 'origina' -> 'original'.
        viz.images(x, nrow=8, win='original', opts=dict(title='original'))
        viz.images(x_hat, nrow=8, win='AE', opts=dict(title='AE'))


if __name__ == '__main__':
    main()
E:\Anaconda3.5\envs\pytorch\python.exe E:/CQUPT/AI/python/pycharm/深度学习与PyTorch入门实战教程/自编码器/main.py
x: torch.Size([32, 1, 28, 28])
AE(
Setting up a new session...
(encoder): Sequential(
(0): Linear(in_features=784, out_features=256, bias=True)
(1): ReLU()
(2): Linear(in_features=256, out_features=64, bias=True)
(3): ReLU()
(4): Linear(in_features=64, out_features=20, bias=True)
(5): ReLU()
)
(decoder): Sequential(
(0): Linear(in_features=20, out_features=64, bias=True)
(1): ReLU()
(2): Linear(in_features=64, out_features=256, bias=True)
(3): ReLU()
(4): Linear(in_features=256, out_features=784, bias=True)
(5): Sigmoid()
)
)
0 loss: 0.018345536664128304
1 loss: 0.02027115412056446
2 loss: 0.014811892993748188
3 loss: 0.014542722143232822
4 loss: 0.015905287116765976
5 loss: 0.013283027336001396
6 loss: 0.01245952770113945
7 loss: 0.010749070905148983
8 loss: 0.01388397254049778
9 loss: 0.013956984505057335
进程已结束,退出代码0