import tensorflow as tf
from tensorflow.keras.layers import Dense, Reshape, Flatten, Conv2D, Conv2DTranspose, LeakyReLU
from tensorflow.keras.models import Sequential
import numpy as np
import matplotlib.pyplot as plt
def build_generator(latent_dim=100):
    """Build the DCGAN generator: latent vector -> 28x28x1 image.

    Upsamples a dense 7x7x128 seed twice (7 -> 14 -> 28) with strided
    transposed convolutions, then maps to a single channel with tanh so
    outputs lie in [-1, 1] (matching the [-1, 1]-scaled training data).

    Args:
        latent_dim: Size of the input noise vector. Defaults to 100,
            preserving the original hard-coded behavior.

    Returns:
        An uncompiled tf.keras Sequential model.
    """
    model = Sequential([
        # Project the latent vector to a 7x7 feature map with 128 channels.
        Dense(128 * 7 * 7, input_dim=latent_dim),
        LeakyReLU(alpha=0.2),
        Reshape((7, 7, 128)),
        # 7x7 -> 14x14
        Conv2DTranspose(128, kernel_size=4, strides=2, padding='same'),
        LeakyReLU(alpha=0.2),
        # 14x14 -> 28x28
        Conv2DTranspose(64, kernel_size=4, strides=2, padding='same'),
        LeakyReLU(alpha=0.2),
        # Final 1-channel image; tanh bounds pixels to [-1, 1].
        Conv2D(1, kernel_size=7, padding='same', activation='tanh')
    ])
    return model
def build_discriminator():
    """Build the DCGAN discriminator: 28x28x1 image -> real/fake probability.

    Two strided convolutions downsample the input (28 -> 14 -> 7) before a
    single sigmoid unit scores the image as real (1) or fake (0).

    Returns:
        An uncompiled tf.keras Sequential classifier.
    """
    net = Sequential()
    net.add(Conv2D(64, kernel_size=3, strides=2, padding='same',
                   input_shape=(28, 28, 1)))
    net.add(LeakyReLU(alpha=0.2))
    net.add(Conv2D(128, kernel_size=3, strides=2, padding='same'))
    net.add(LeakyReLU(alpha=0.2))
    net.add(Flatten())
    net.add(Dense(1, activation='sigmoid'))
    return net
def build_gan(generator, discriminator):
    """Stack the generator and a frozen discriminator for generator training.

    Freezes the discriminator (so only the generator's weights receive
    gradients when the combined model is trained) and chains the two
    models: noise -> generator -> discriminator -> probability.

    Args:
        generator: The generator model (latent vector -> image).
        discriminator: The discriminator model (image -> probability).

    Returns:
        An uncompiled tf.keras Sequential combined model.
    """
    discriminator.trainable = False
    combined = Sequential()
    combined.add(generator)
    combined.add(discriminator)
    return combined
# Hyperparameters
lr = 0.0002       # Adam learning rate, shared by both optimizers
batch_size = 64
epochs = 10

# Load MNIST and scale pixels from [0, 255] to [-1, 1] so real images
# match the range of the generator's tanh output.
(x_train, _), (_, _) = tf.keras.datasets.mnist.load_data()
x_train = np.expand_dims(x_train, axis=-1).astype('float32')
x_train = (x_train - 127.5) / 127.5

# Build and compile models.
generator = build_generator()
discriminator = build_discriminator()
# BUG FIX: compile the discriminator BEFORE build_gan() sets
# discriminator.trainable = False. Keras captures the trainable flag at
# compile time, so compiling after freezing would make
# discriminator.train_on_batch a no-op (weights never update).
discriminator.compile(optimizer=tf.keras.optimizers.Adam(lr),
                      loss='binary_crossentropy', metrics=['accuracy'])
gan = build_gan(generator, discriminator)
# The combined model is compiled with the discriminator frozen, so only
# the generator is updated by gan.train_on_batch.
gan.compile(optimizer=tf.keras.optimizers.Adam(lr), loss='binary_crossentropy')
# Training loop
for epoch in range(epochs):
    for _ in range(len(x_train) // batch_size):
        # --- Discriminator step: one fake batch, one real batch ---
        noise = np.random.randn(batch_size, 100)
        generated = generator.predict(noise)
        sample_idx = np.random.randint(0, x_train.shape[0], batch_size)
        authentic = x_train[sample_idx]

        ones = np.ones((batch_size, 1))    # label 1 = real
        zeros = np.zeros((batch_size, 1))  # label 0 = fake

        loss_on_real = discriminator.train_on_batch(authentic, ones)
        loss_on_fake = discriminator.train_on_batch(generated, zeros)
        disc_metrics = 0.5 * np.add(loss_on_real, loss_on_fake)

        # --- Generator step: fresh noise, labeled "real" so the generator
        # is rewarded for fooling the (frozen) discriminator ---
        noise = np.random.randn(batch_size, 100)
        gen_loss = gan.train_on_batch(noise, ones)

    print(f'Epoch [{epoch+1}/{epochs}], D Loss: {disc_metrics[0]}, D Acc: {disc_metrics[1]}, G Loss: {gen_loss}')

    # Every 5 epochs, save a 5x5 grid of generated samples.
    if (epoch + 1) % 5 == 0:
        samples = generator.predict(np.random.randn(25, 100))
        samples = (samples + 1) / 2.0  # tanh range [-1,1] -> display range [0,1]
        plt.figure(figsize=(5, 5))
        for cell, image in enumerate(samples):
            plt.subplot(5, 5, cell + 1)
            plt.imshow(image.reshape(28, 28), cmap='gray')
            plt.axis('off')
        plt.savefig(f'stylegan_images_{epoch+1}.png')
        plt.close()