"""1Title: DCGAN to generate face images2Author: [fchollet](https://twitter.com/fchollet)3Date created: 2019/04/294Last modified: 2023/12/215Description: A simple DCGAN trained using `fit()` by overriding `train_step` on CelebA images.6Accelerator: GPU7"""89"""10## Setup11"""1213import keras14import tensorflow as tf1516from keras import layers17from keras import ops18import matplotlib.pyplot as plt19import os20import gdown21from zipfile import ZipFile222324"""25## Prepare CelebA data2627We'll use face images from the CelebA dataset, resized to 64x64.28"""2930os.makedirs("celeba_gan")3132url = "https://drive.google.com/uc?id=1O7m1010EJjLE5QxLZiM9Fpjs7Oj6e684"33output = "celeba_gan/data.zip"34gdown.download(url, output, quiet=True)3536with ZipFile("celeba_gan/data.zip", "r") as zipobj:37zipobj.extractall("celeba_gan")3839"""40Create a dataset from our folder, and rescale the images to the [0-1] range:41"""4243dataset = keras.utils.image_dataset_from_directory(44"celeba_gan", label_mode=None, image_size=(64, 64), batch_size=3245)46dataset = dataset.map(lambda x: x / 255.0)474849"""50Let's display a sample image:51"""525354for x in dataset:55plt.axis("off")56plt.imshow((x.numpy() * 255).astype("int32")[0])57break585960"""61## Create the discriminator6263It maps a 64x64 image to a binary classification score.64"""6566discriminator = keras.Sequential(67[68keras.Input(shape=(64, 64, 3)),69layers.Conv2D(64, kernel_size=4, strides=2, padding="same"),70layers.LeakyReLU(negative_slope=0.2),71layers.Conv2D(128, kernel_size=4, strides=2, padding="same"),72layers.LeakyReLU(negative_slope=0.2),73layers.Conv2D(128, kernel_size=4, strides=2, padding="same"),74layers.LeakyReLU(negative_slope=0.2),75layers.Flatten(),76layers.Dropout(0.2),77layers.Dense(1, activation="sigmoid"),78],79name="discriminator",80)81discriminator.summary()8283"""84## Create the generator8586It mirrors the discriminator, replacing `Conv2D` layers with `Conv2DTranspose` layers.87"""8889latent_dim = 1289091generator = keras.Sequential(92[93keras.Input(shape=(latent_dim,)),94layers.Dense(8 * 8 * 128),95layers.Reshape((8, 8, 128)),96layers.Conv2DTranspose(128, kernel_size=4, strides=2, padding="same"),97layers.LeakyReLU(negative_slope=0.2),98layers.Conv2DTranspose(256, kernel_size=4, strides=2, padding="same"),99layers.LeakyReLU(negative_slope=0.2),100layers.Conv2DTranspose(512, kernel_size=4, strides=2, padding="same"),101layers.LeakyReLU(negative_slope=0.2),102layers.Conv2D(3, kernel_size=5, padding="same", activation="sigmoid"),103],104name="generator",105)106generator.summary()107108"""109## Override `train_step`110"""111112113class GAN(keras.Model):114def __init__(self, discriminator, generator, latent_dim):115super().__init__()116self.discriminator = discriminator117self.generator = generator118self.latent_dim = latent_dim119self.seed_generator = keras.random.SeedGenerator(1337)120121def compile(self, d_optimizer, g_optimizer, loss_fn):122super().compile()123self.d_optimizer = d_optimizer124self.g_optimizer = g_optimizer125self.loss_fn = loss_fn126self.d_loss_metric = keras.metrics.Mean(name="d_loss")127self.g_loss_metric = keras.metrics.Mean(name="g_loss")128129@property130def metrics(self):131return [self.d_loss_metric, self.g_loss_metric]132133def train_step(self, real_images):134# Sample random points in the latent space135batch_size = ops.shape(real_images)[0]136random_latent_vectors = keras.random.normal(137shape=(batch_size, self.latent_dim), seed=self.seed_generator138)139140# Decode them to fake images141generated_images = 
self.generator(random_latent_vectors)142143# Combine them with real images144combined_images = ops.concatenate([generated_images, real_images], axis=0)145146# Assemble labels discriminating real from fake images147labels = ops.concatenate(148[ops.ones((batch_size, 1)), ops.zeros((batch_size, 1))], axis=0149)150# Add random noise to the labels - important trick!151labels += 0.05 * tf.random.uniform(tf.shape(labels))152153# Train the discriminator154with tf.GradientTape() as tape:155predictions = self.discriminator(combined_images)156d_loss = self.loss_fn(labels, predictions)157grads = tape.gradient(d_loss, self.discriminator.trainable_weights)158self.d_optimizer.apply_gradients(159zip(grads, self.discriminator.trainable_weights)160)161162# Sample random points in the latent space163random_latent_vectors = keras.random.normal(164shape=(batch_size, self.latent_dim), seed=self.seed_generator165)166167# Assemble labels that say "all real images"168misleading_labels = ops.zeros((batch_size, 1))169170# Train the generator (note that we should *not* update the weights171# of the discriminator)!172with tf.GradientTape() as tape:173predictions = self.discriminator(self.generator(random_latent_vectors))174g_loss = self.loss_fn(misleading_labels, predictions)175grads = tape.gradient(g_loss, self.generator.trainable_weights)176self.g_optimizer.apply_gradients(zip(grads, self.generator.trainable_weights))177178# Update metrics179self.d_loss_metric.update_state(d_loss)180self.g_loss_metric.update_state(g_loss)181return {182"d_loss": self.d_loss_metric.result(),183"g_loss": self.g_loss_metric.result(),184}185186187"""188## Create a callback that periodically saves generated images189"""190191192class GANMonitor(keras.callbacks.Callback):193def __init__(self, num_img=3, latent_dim=128):194self.num_img = num_img195self.latent_dim = latent_dim196self.seed_generator = keras.random.SeedGenerator(42)197198def on_epoch_end(self, epoch, logs=None):199random_latent_vectors = keras.random.normal(200shape=(self.num_img, self.latent_dim), seed=self.seed_generator201)202generated_images = self.model.generator(random_latent_vectors)203generated_images *= 255204generated_images.numpy()205for i in range(self.num_img):206img = keras.utils.array_to_img(generated_images[i])207img.save("generated_img_%03d_%d.png" % (epoch, i))208209210"""211## Train the end-to-end model212"""213214epochs = 1 # In practice, use ~100 epochs215216gan = GAN(discriminator=discriminator, generator=generator, latent_dim=latent_dim)217gan.compile(218d_optimizer=keras.optimizers.Adam(learning_rate=0.0001),219g_optimizer=keras.optimizers.Adam(learning_rate=0.0001),220loss_fn=keras.losses.BinaryCrossentropy(),221)222223gan.fit(224dataset, epochs=epochs, callbacks=[GANMonitor(num_img=10, latent_dim=latent_dim)]225)226227"""228Some of the last generated images around epoch 30229(results keep improving after that):230231232"""233234235
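"""
As an optional closing step (not part of the original example), here is a minimal
sketch of how you might sample new faces directly from the trained generator once
`fit()` has finished. It only reuses objects defined above (`generator`,
`latent_dim`, `keras.random`, `ops`, and `matplotlib`); the sample count and the
3x3 grid layout are arbitrary choices for a quick visual check.
"""

num_samples = 9  # arbitrary; just enough to fill a small grid
random_latent_vectors = keras.random.normal(shape=(num_samples, latent_dim))
# The generator ends in a sigmoid, so its outputs are already in the [0, 1] range
generated_images = generator(random_latent_vectors, training=False)
generated_images = ops.convert_to_numpy(generated_images)

plt.figure(figsize=(6, 6))
for i in range(num_samples):
    plt.subplot(3, 3, i + 1)
    plt.imshow(generated_images[i])
    plt.axis("off")
plt.show()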