Path: blob/master/examples/generative/ipynb/dcgan_overriding_train_step.ipynb
3508 views
Kernel: Python 3
DCGAN to generate face images
Author: fchollet
Date created: 2019/04/29
Last modified: 2023/12/21
Description: A simple DCGAN trained using fit()
by overriding train_step
on CelebA images.
Setup
In [0]:
import keras import tensorflow as tf from keras import layers from keras import ops import matplotlib.pyplot as plt import os import gdown from zipfile import ZipFile
Prepare CelebA data
We'll use face images from the CelebA dataset, resized to 64x64.
In [0]:
os.makedirs("celeba_gan") url = "https://drive.google.com/uc?id=1O7m1010EJjLE5QxLZiM9Fpjs7Oj6e684" output = "celeba_gan/data.zip" gdown.download(url, output, quiet=True) with ZipFile("celeba_gan/data.zip", "r") as zipobj: zipobj.extractall("celeba_gan")
Create a dataset from our folder, and rescale the images to the [0-1] range:
In [0]:
dataset = keras.utils.image_dataset_from_directory( "celeba_gan", label_mode=None, image_size=(64, 64), batch_size=32 ) dataset = dataset.map(lambda x: x / 255.0)
Let's display a sample image:
In [0]:
for x in dataset: plt.axis("off") plt.imshow((x.numpy() * 255).astype("int32")[0]) break
Create the discriminator
It maps a 64x64 image to a binary classification score.
In [0]:
discriminator = keras.Sequential( [ keras.Input(shape=(64, 64, 3)), layers.Conv2D(64, kernel_size=4, strides=2, padding="same"), layers.LeakyReLU(negative_slope=0.2), layers.Conv2D(128, kernel_size=4, strides=2, padding="same"), layers.LeakyReLU(negative_slope=0.2), layers.Conv2D(128, kernel_size=4, strides=2, padding="same"), layers.LeakyReLU(negative_slope=0.2), layers.Flatten(), layers.Dropout(0.2), layers.Dense(1, activation="sigmoid"), ], name="discriminator", ) discriminator.summary()
Create the generator
It mirrors the discriminator, replacing Conv2D
layers with Conv2DTranspose
layers.
In [0]:
latent_dim = 128 generator = keras.Sequential( [ keras.Input(shape=(latent_dim,)), layers.Dense(8 * 8 * 128), layers.Reshape((8, 8, 128)), layers.Conv2DTranspose(128, kernel_size=4, strides=2, padding="same"), layers.LeakyReLU(negative_slope=0.2), layers.Conv2DTranspose(256, kernel_size=4, strides=2, padding="same"), layers.LeakyReLU(negative_slope=0.2), layers.Conv2DTranspose(512, kernel_size=4, strides=2, padding="same"), layers.LeakyReLU(negative_slope=0.2), layers.Conv2D(3, kernel_size=5, padding="same", activation="sigmoid"), ], name="generator", ) generator.summary()
Override train_step
In [0]:
class GAN(keras.Model): def __init__(self, discriminator, generator, latent_dim): super().__init__() self.discriminator = discriminator self.generator = generator self.latent_dim = latent_dim self.seed_generator = keras.random.SeedGenerator(1337) def compile(self, d_optimizer, g_optimizer, loss_fn): super().compile() self.d_optimizer = d_optimizer self.g_optimizer = g_optimizer self.loss_fn = loss_fn self.d_loss_metric = keras.metrics.Mean(name="d_loss") self.g_loss_metric = keras.metrics.Mean(name="g_loss") @property def metrics(self): return [self.d_loss_metric, self.g_loss_metric] def train_step(self, real_images): # Sample random points in the latent space batch_size = ops.shape(real_images)[0] random_latent_vectors = keras.random.normal( shape=(batch_size, self.latent_dim), seed=self.seed_generator ) # Decode them to fake images generated_images = self.generator(random_latent_vectors) # Combine them with real images combined_images = ops.concatenate([generated_images, real_images], axis=0) # Assemble labels discriminating real from fake images labels = ops.concatenate( [ops.ones((batch_size, 1)), ops.zeros((batch_size, 1))], axis=0 ) # Add random noise to the labels - important trick! labels += 0.05 * tf.random.uniform(tf.shape(labels)) # Train the discriminator with tf.GradientTape() as tape: predictions = self.discriminator(combined_images) d_loss = self.loss_fn(labels, predictions) grads = tape.gradient(d_loss, self.discriminator.trainable_weights) self.d_optimizer.apply_gradients( zip(grads, self.discriminator.trainable_weights) ) # Sample random points in the latent space random_latent_vectors = keras.random.normal( shape=(batch_size, self.latent_dim), seed=self.seed_generator ) # Assemble labels that say "all real images" misleading_labels = ops.zeros((batch_size, 1)) # Train the generator (note that we should *not* update the weights # of the discriminator)! with tf.GradientTape() as tape: predictions = self.discriminator(self.generator(random_latent_vectors)) g_loss = self.loss_fn(misleading_labels, predictions) grads = tape.gradient(g_loss, self.generator.trainable_weights) self.g_optimizer.apply_gradients(zip(grads, self.generator.trainable_weights)) # Update metrics self.d_loss_metric.update_state(d_loss) self.g_loss_metric.update_state(g_loss) return { "d_loss": self.d_loss_metric.result(), "g_loss": self.g_loss_metric.result(), }
Create a callback that periodically saves generated images
In [0]:
class GANMonitor(keras.callbacks.Callback): def __init__(self, num_img=3, latent_dim=128): self.num_img = num_img self.latent_dim = latent_dim self.seed_generator = keras.random.SeedGenerator(42) def on_epoch_end(self, epoch, logs=None): random_latent_vectors = keras.random.normal( shape=(self.num_img, self.latent_dim), seed=self.seed_generator ) generated_images = self.model.generator(random_latent_vectors) generated_images *= 255 generated_images.numpy() for i in range(self.num_img): img = keras.utils.array_to_img(generated_images[i]) img.save("generated_img_%03d_%d.png" % (epoch, i))
Train the end-to-end model
In [0]:
epochs = 1 # In practice, use ~100 epochs gan = GAN(discriminator=discriminator, generator=generator, latent_dim=latent_dim) gan.compile( d_optimizer=keras.optimizers.Adam(learning_rate=0.0001), g_optimizer=keras.optimizers.Adam(learning_rate=0.0001), loss_fn=keras.losses.BinaryCrossentropy(), ) gan.fit( dataset, epochs=epochs, callbacks=[GANMonitor(num_img=10, latent_dim=latent_dim)] )
Some of the last generated images around epoch 30 (results keep improving after that):