Path: examples/generative/conditional_gan.py
"""1Title: Conditional GAN2Author: [Sayak Paul](https://twitter.com/RisingSayak)3Date created: 2021/07/134Last modified: 2024/01/025Description: Training a GAN conditioned on class labels to generate handwritten digits.6Accelerator: GPU7"""89"""10Generative Adversarial Networks (GANs) let us generate novel image data, video data,11or audio data from a random input. Typically, the random input is sampled12from a normal distribution, before going through a series of transformations that turn13it into something plausible (image, video, audio, etc.).1415However, a simple [DCGAN](https://arxiv.org/abs/1511.06434) doesn't let us control16the appearance (e.g. class) of the samples we're generating. For instance,17with a GAN that generates MNIST handwritten digits, a simple DCGAN wouldn't let us18choose the class of digits we're generating.19To be able to control what we generate, we need to _condition_ the GAN output20on a semantic input, such as the class of an image.2122In this example, we'll build a **Conditional GAN** that can generate MNIST handwritten23digits conditioned on a given class. Such a model can have various useful applications:2425* let's say you are dealing with an26[imbalanced image dataset](https://developers.google.com/machine-learning/data-prep/construct/sampling-splitting/imbalanced-data),27and you'd like to gather more examples for the skewed class to balance the dataset.28Data collection can be a costly process on its own. You could instead train a Conditional GAN and use29it to generate novel images for the class that needs balancing.30* Since the generator learns to associate the generated samples with the class labels,31its representations can also be used for [other downstream tasks](https://arxiv.org/abs/1809.11096).3233Following are the references used for developing this example:3435* [Conditional Generative Adversarial Nets](https://arxiv.org/abs/1411.1784)36* [Lecture on Conditional Generation from Coursera](https://www.coursera.org/lecture/build-basic-generative-adversarial-networks-gans/conditional-generation-inputs-2OPrG)3738If you need a refresher on GANs, you can refer to the "Generative adversarial networks"39section of40[this resource](https://livebook.manning.com/book/deep-learning-with-python-second-edition/chapter-12/r-3/232).4142This example requires TensorFlow 2.5 or higher, as well as TensorFlow Docs, which can be43installed using the following command:44"""4546"""shell47pip install -q git+https://github.com/tensorflow/docs48"""4950"""51## Imports52"""5354import keras5556from keras import layers57from keras import ops58from tensorflow_docs.vis import embed59import tensorflow as tf60import numpy as np61import imageio6263"""64## Constants and hyperparameters65"""6667batch_size = 6468num_channels = 169num_classes = 1070image_size = 2871latent_dim = 1287273"""74## Loading the MNIST dataset and preprocessing it75"""7677# We'll use all the available examples from both the training and test78# sets.79(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()80all_digits = np.concatenate([x_train, x_test])81all_labels = np.concatenate([y_train, y_test])8283# Scale the pixel values to [0, 1] range, add a channel dimension to84# the images, and one-hot encode the labels.85all_digits = all_digits.astype("float32") / 255.086all_digits = np.reshape(all_digits, (-1, 28, 28, 1))87all_labels = keras.utils.to_categorical(all_labels, 10)8889# Create tf.data.Dataset.90dataset = tf.data.Dataset.from_tensor_slices((all_digits, all_labels))91dataset = 
"""
## Calculating the number of input channels for the generator and discriminator

In a regular (unconditional) GAN, we start by sampling noise (of some fixed
dimension) from a normal distribution. In our case, we also need to account
for the class labels. We will have to add the number of classes to
the input channels of the generator (noise input) as well as the discriminator
(generated image input). With `latent_dim = 128` and `num_classes = 10`, the
generator therefore takes 138-dimensional inputs, while the discriminator sees
11-channel images (1 image channel plus 10 label channels).
"""

generator_in_channels = latent_dim + num_classes
discriminator_in_channels = num_channels + num_classes
print(generator_in_channels, discriminator_in_channels)

"""
## Creating the discriminator and generator

The model definitions (`discriminator`, `generator`, and `ConditionalGAN`) have been
adapted from [this example](https://keras.io/guides/customizing_what_happens_in_fit/).
"""

# Create the discriminator.
discriminator = keras.Sequential(
    [
        keras.layers.InputLayer((28, 28, discriminator_in_channels)),
        layers.Conv2D(64, (3, 3), strides=(2, 2), padding="same"),
        layers.LeakyReLU(negative_slope=0.2),
        layers.Conv2D(128, (3, 3), strides=(2, 2), padding="same"),
        layers.LeakyReLU(negative_slope=0.2),
        layers.GlobalMaxPooling2D(),
        layers.Dense(1),
    ],
    name="discriminator",
)

# Create the generator.
generator = keras.Sequential(
    [
        keras.layers.InputLayer((generator_in_channels,)),
        # We want to generate 128 + num_classes coefficients to reshape into a
        # 7x7x(128 + num_classes) map.
        layers.Dense(7 * 7 * generator_in_channels),
        layers.LeakyReLU(negative_slope=0.2),
        layers.Reshape((7, 7, generator_in_channels)),
        layers.Conv2DTranspose(128, (4, 4), strides=(2, 2), padding="same"),
        layers.LeakyReLU(negative_slope=0.2),
        layers.Conv2DTranspose(128, (4, 4), strides=(2, 2), padding="same"),
        layers.LeakyReLU(negative_slope=0.2),
        layers.Conv2D(1, (7, 7), padding="same", activation="sigmoid"),
    ],
    name="generator",
)
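"""
The generator upsamples its 7x7 starting map to 14x14 and then to 28x28 through the
two stride-2 `Conv2DTranspose` layers, ending in a single-channel sigmoid output that
matches the MNIST image format. Optionally, we can verify both models on dummy
batches (a minimal sanity check, not needed for training):
"""

# Four conditioned latent vectors should decode to four grayscale 28x28 images,
# and the discriminator should map label-augmented images to one logit each.
dummy_latent = keras.random.normal(shape=(4, generator_in_channels))
print(generator(dummy_latent).shape)  # (4, 28, 28, 1)
dummy_images = keras.random.normal(shape=(4, 28, 28, discriminator_in_channels))
print(discriminator(dummy_images).shape)  # (4, 1)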
"""
## Creating a `ConditionalGAN` model
"""


class ConditionalGAN(keras.Model):
    def __init__(self, discriminator, generator, latent_dim):
        super().__init__()
        self.discriminator = discriminator
        self.generator = generator
        self.latent_dim = latent_dim
        self.seed_generator = keras.random.SeedGenerator(1337)
        self.gen_loss_tracker = keras.metrics.Mean(name="generator_loss")
        self.disc_loss_tracker = keras.metrics.Mean(name="discriminator_loss")

    @property
    def metrics(self):
        return [self.gen_loss_tracker, self.disc_loss_tracker]

    def compile(self, d_optimizer, g_optimizer, loss_fn):
        super().compile()
        self.d_optimizer = d_optimizer
        self.g_optimizer = g_optimizer
        self.loss_fn = loss_fn

    def train_step(self, data):
        # Unpack the data.
        real_images, one_hot_labels = data

        # Add dummy dimensions to the labels so that they can be concatenated with
        # the images. This is for the discriminator.
        image_one_hot_labels = one_hot_labels[:, :, None, None]
        image_one_hot_labels = ops.repeat(
            image_one_hot_labels, repeats=[image_size * image_size]
        )
        image_one_hot_labels = ops.reshape(
            image_one_hot_labels, (-1, image_size, image_size, num_classes)
        )

        # Sample random points in the latent space and concatenate the labels.
        # This is for the generator.
        batch_size = ops.shape(real_images)[0]
        random_latent_vectors = keras.random.normal(
            shape=(batch_size, self.latent_dim), seed=self.seed_generator
        )
        random_vector_labels = ops.concatenate(
            [random_latent_vectors, one_hot_labels], axis=1
        )

        # Decode the noise (guided by labels) to fake images.
        generated_images = self.generator(random_vector_labels)

        # Combine them with real images. Note that we are concatenating the labels
        # with these images here.
        fake_image_and_labels = ops.concatenate(
            [generated_images, image_one_hot_labels], -1
        )
        real_image_and_labels = ops.concatenate([real_images, image_one_hot_labels], -1)
        combined_images = ops.concatenate(
            [fake_image_and_labels, real_image_and_labels], axis=0
        )

        # Assemble labels discriminating real from fake images
        # (note the convention used here: fake = 1, real = 0).
        labels = ops.concatenate(
            [ops.ones((batch_size, 1)), ops.zeros((batch_size, 1))], axis=0
        )

        # Train the discriminator.
        with tf.GradientTape() as tape:
            predictions = self.discriminator(combined_images)
            d_loss = self.loss_fn(labels, predictions)
        grads = tape.gradient(d_loss, self.discriminator.trainable_weights)
        self.d_optimizer.apply_gradients(
            zip(grads, self.discriminator.trainable_weights)
        )

        # Sample random points in the latent space.
        random_latent_vectors = keras.random.normal(
            shape=(batch_size, self.latent_dim), seed=self.seed_generator
        )
        random_vector_labels = ops.concatenate(
            [random_latent_vectors, one_hot_labels], axis=1
        )

        # Assemble labels that say "all real images".
        misleading_labels = ops.zeros((batch_size, 1))

        # Train the generator (note that we should *not* update the weights
        # of the discriminator)!
        with tf.GradientTape() as tape:
            fake_images = self.generator(random_vector_labels)
            fake_image_and_labels = ops.concatenate(
                [fake_images, image_one_hot_labels], -1
            )
            predictions = self.discriminator(fake_image_and_labels)
            g_loss = self.loss_fn(misleading_labels, predictions)
        grads = tape.gradient(g_loss, self.generator.trainable_weights)
        self.g_optimizer.apply_gradients(zip(grads, self.generator.trainable_weights))

        # Monitor loss.
        self.gen_loss_tracker.update_state(g_loss)
        self.disc_loss_tracker.update_state(d_loss)
        return {
            "g_loss": self.gen_loss_tracker.result(),
            "d_loss": self.disc_loss_tracker.result(),
        }


"""
## Training the Conditional GAN
"""

cond_gan = ConditionalGAN(
    discriminator=discriminator, generator=generator, latent_dim=latent_dim
)
cond_gan.compile(
    d_optimizer=keras.optimizers.Adam(learning_rate=0.0003),
    g_optimizer=keras.optimizers.Adam(learning_rate=0.0003),
    loss_fn=keras.losses.BinaryCrossentropy(from_logits=True),
)

cond_gan.fit(dataset, epochs=20)
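"""
Before moving on to interpolation, we can condition the trained generator on a
single class directly (a minimal illustrative check): concatenating noise with a
one-hot label, exactly as in `train_step`, should yield an image of that digit.
"""

# Generate one sample of the digit 7. The label is cast to float32 so it can be
# concatenated with the float noise vector.
sample_noise = keras.random.normal(shape=(1, latent_dim))
sample_label = ops.cast(keras.utils.to_categorical([7], num_classes), "float32")
sample_image = cond_gan.generator.predict(
    ops.concatenate([sample_noise, sample_label], axis=1)
)
print(sample_image.shape)  # (1, 28, 28, 1)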
"""
## Interpolating between classes with the trained generator
"""

# We first extract the trained generator from our Conditional GAN.
trained_gen = cond_gan.generator

# Choose the total number of images to generate along the interpolation path,
# including the start and end images.
num_interpolation = 9  # @param {type:"integer"}

# Sample noise for the interpolation.
interpolation_noise = keras.random.normal(shape=(1, latent_dim))
interpolation_noise = ops.repeat(interpolation_noise, repeats=num_interpolation)
interpolation_noise = ops.reshape(interpolation_noise, (num_interpolation, latent_dim))


def interpolate_class(first_number, second_number):
    # Convert the start and end labels to one-hot encoded vectors.
    first_label = keras.utils.to_categorical([first_number], num_classes)
    second_label = keras.utils.to_categorical([second_number], num_classes)
    first_label = ops.cast(first_label, "float32")
    second_label = ops.cast(second_label, "float32")

    # Calculate the interpolation vector between the two labels.
    percent_second_label = ops.linspace(0, 1, num_interpolation)[:, None]
    percent_second_label = ops.cast(percent_second_label, "float32")
    interpolation_labels = (
        first_label * (1 - percent_second_label) + second_label * percent_second_label
    )

    # Combine the noise and the labels and run inference with the generator.
    noise_and_labels = ops.concatenate([interpolation_noise, interpolation_labels], 1)
    fake = trained_gen.predict(noise_and_labels)
    return fake


start_class = 2  # @param {type:"slider", min:0, max:9, step:1}
end_class = 6  # @param {type:"slider", min:0, max:9, step:1}

fake_images = interpolate_class(start_class, end_class)

"""
Here, we first sample noise from a normal distribution, repeat it
`num_interpolation` times, and reshape the result into a batch of
`num_interpolation` latent vectors. Each of these is then paired with a label
vector that blends the two class identities, with the mixing proportion shifting
uniformly from the start class to the end class.
"""

fake_images *= 255.0
converted_images = fake_images.astype(np.uint8)
converted_images = ops.image.resize(converted_images, (96, 96)).numpy().astype(np.uint8)
imageio.mimsave("animation.gif", converted_images[:, :, :, 0], fps=1)
embed.embed_file("animation.gif")

"""
We can further improve the performance of this model with recipes like
[WGAN-GP](https://keras.io/examples/generative/wgan_gp).
Conditional generation is also widely used in many modern image generation
architectures like [VQ-GANs](https://arxiv.org/abs/2012.09841) and
[DALL-E](https://openai.com/blog/dall-e/).

You can use the trained model hosted on [Hugging Face Hub](https://huggingface.co/keras-io/conditional-gan) and try the demo on [Hugging Face Spaces](https://huggingface.co/spaces/keras-io/conditional-GAN).
"""