0
点赞
收藏
分享

微信扫一扫

自编码器重建 Fashion_mnist 数据集


自编码器

from PIL import Image
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import Sequential, layers
import numpy as np
from matplotlib import pyplot as plt

加载数据集

(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()
x_train, x_test = x_train.astype(np.float32) / 255., x_test.astype(np.float32) / 255.
# we do not need label
train_db = tf.data.Dataset.from_tensor_slices(x_train)
train_db = train_db.shuffle(buffer_size=512).batch(512)
test_db = tf.data.Dataset.from_tensor_slices(x_test)
test_db = test_db.batch(512)

print(x_train.shape, y_train.shape)
print(x_test.shape, y_test.shape)

(60000, 28, 28) (60000,)
(10000, 28, 28) (10000,)

构建网络

class AutoEncoder(keras.Model):

def __init__(self):
super(AutoEncoder, self).__init__()

# Encoders
self.encoder = Sequential([
layers.Dense(256, activation=tf.nn.relu),
layers.Dense(128, activation=tf.nn.relu),
layers.Dense(20)
])

# Decoders
self.decoder = Sequential([
layers.Dense(128, activation=tf.nn.relu),
layers.Dense(256, activation=tf.nn.relu),
layers.Dense(784)
])

# 前向计算
def call(self, inputs, training=None):
# [b, 784] => [b, 10]
h = self.encoder(inputs)
# [b, 10] => [b, 784]
x_hat = self.decoder(h)

return x_hat

网络训练

def save_images(imgs, name):
new_im = Image.new('L', (280, 280))

index = 0
for i in range(0, 280, 28):
for j in range(0, 280, 28):
im = imgs[index]
im = Image.fromarray(im, mode='L')
new_im.paste(im, (i, j))
index += 1

new_im.save(name)

model = AutoEncoder()
model.build(input_shape=(None, 28 * 28))
model.summary()

optimizer = tf.optimizers.Adam(lr=1e-3)

Model: "auto_encoder"
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
sequential (Sequential) multiple 236436
_________________________________________________________________
sequential_1 (Sequential) multiple 237200
=================================================================
Total params: 473,636
Trainable params: 473,636
Non-trainable params: 0
_________________________________________________________________

开始训练

for epoch in range(20):

for step, x in enumerate(train_db):

#[b, 28, 28] => [b, 784]
x = tf.reshape(x, [-1, 28 * 28])
# 构建梯度记录器
with tf.GradientTape() as tape:
# 前向计算
x_rec_logits = model(x)
# 计算损失函数
rec_loss = tf.losses.binary_crossentropy(x, x_rec_logits, from_logits=True)
rec_loss = tf.reduce_mean(rec_loss)
# 自动求导
grads = tape.gradient(rec_loss, model.trainable_variables)
# 更新网络
optimizer.apply_gradients(zip(grads, model.trainable_variables))

# 打印训练误差
print("epoch: ", epoch, "loss: ", float(rec_loss))


# 从测试集采集图片
x = next(iter(test_db))
logits = model(tf.reshape(x, [-1, 784]))
# 讲输出值转化为像素值
x_hat = tf.sigmoid(logits)
# [b, 784] => [b, 28, 28] 恢复原始数据格式
x_hat = tf.reshape(x_hat, [-1, 28, 28])

# [b, 28, 28] => [2b, 28, 28]
# 输入的前 50 张+重建的前 50 张图片合并
x_concat = tf.concat([x[:50], x_hat[:50]], axis=0)
# 恢复为 0-255 的范围
x_concat = x_concat.numpy() * 255.
# 转换为整型
x_concat = x_concat.astype(np.uint8)
save_images(x_concat, 'ae_images/mnist_%d.png'%epoch)

epoch:  0 loss:  0.1876431256532669
epoch: 1 loss: 0.14163847267627716
epoch: 2 loss: 0.12352141737937927
epoch: 3 loss: 0.11942803859710693
epoch: 4 loss: 0.11525192111730576
epoch: 5 loss: 0.10021436214447021
epoch: 6 loss: 0.10526927560567856
epoch: 7 loss: 0.10288294404745102
epoch: 8 loss: 0.10139968246221542
epoch: 9 loss: 0.10215207189321518
epoch: 10 loss: 0.0961870551109314
epoch: 11 loss: 0.091026671230793
epoch: 12 loss: 0.09655070304870605
epoch: 13 loss: 0.09417414665222168
epoch: 14 loss: 0.08978977054357529
epoch: 15 loss: 0.08931374549865723
epoch: 16 loss: 0.08951258659362793
epoch: 17 loss: 0.08937102556228638
epoch: 18 loss: 0.09456444531679153
epoch: 19 loss: 0.08556753396987915

def printImage(images):
plt.figure(figsize=(10, 10))
for i in range(20):
plt.subplot(5,5,i+1)
plt.xticks([])
plt.yticks([])
plt.grid(False)
plt.imshow(images[i], cmap=plt.cm.binary)

x = next(iter(test_db))
logits = model(tf.reshape(x, [-1, 784]))
# 讲输出值转化为像素值
x_hat = tf.sigmoid(logits)
# [b, 784] => [b, 28, 28] 恢复原始数据格式
x_hat = tf.reshape(x_hat, [-1, 28, 28])

# [b, 28, 28] => [2b, 28, 28]
# 输入的前 50 张+重建的前 50 张图片合并
x_concat = tf.concat([x[:10], x_hat[:10]], axis=0)
# 恢复为 0-255 的范围
x_concat = x_concat.numpy() * 255.
# 转换为整型
x_concat = x_concat.astype(np.uint8)
printImage(x_concat)

  • 上面 5 行是原始图片, 下面 5 行是 重建后的图片

保存本地的图片:

第一次 epoch

左边 5 列是原图片,右边 5 列是经过重建后的。可以看到,此时还不是很清楚

自编码器重建 Fashion_mnist 数据集_cv

第十次 epoch

自编码器重建 Fashion_mnist 数据集_深度学习_02


第二十次 epoch

自编码器重建 Fashion_mnist 数据集_神经网络_03


举报

相关推荐

0 条评论