Conditional GAN
Author: Sayak Paul
Date created: 2021/07/13
Last modified: 2024/01/02
Description: Training a GAN conditioned on class labels to generate handwritten digits.
Generative Adversarial Networks (GANs) let us generate novel image data, video data, or audio data from a random input. Typically, the random input is sampled from a normal distribution, before going through a series of transformations that turn it into something plausible (image, video, audio, etc.).
However, a simple DCGAN doesn't let us control the appearance (e.g. class) of the samples we're generating. For instance, with a GAN that generates MNIST handwritten digits, a simple DCGAN wouldn't let us choose the class of digits we're generating. To be able to control what we generate, we need to condition the GAN output on a semantic input, such as the class of an image.
In this example, we'll build a Conditional GAN that can generate MNIST handwritten digits conditioned on a given class. Such a model can have various useful applications:
- Let's say you are dealing with an imbalanced image dataset, and you'd like to gather more examples for the skewed class to balance the dataset. Data collection can be a costly process on its own. You could instead train a Conditional GAN and use it to generate novel images for the class that needs balancing.
- Since the generator learns to associate the generated samples with the class labels, its representations can also be used for other downstream tasks.
Following are the references used for developing this example:
If you need a refresher on GANs, you can refer to the "Generative adversarial networks" section of this resource.
This example requires TensorFlow 2.5 or higher, as well as TensorFlow Docs, which can be installed using the following command:
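The original install command was not preserved in this copy; a typical invocation (assuming a pip environment, and that TensorFlow Docs is installed from its GitHub repository) would be:

```shell
pip install -q "tensorflow>=2.5" git+https://github.com/tensorflow/docs
```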
Imports
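The import cell was not preserved here; a minimal set covering the code in this example would look like the following (the original additionally pulls in TensorFlow Docs utilities for exporting GIFs):

```python
import numpy as np
import tensorflow as tf
import keras
from keras import layers
```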
Constants and hyperparameters
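The hyperparameter cell was stripped from this copy; the values below are the ones typically used for this MNIST example (in particular, a 128-dimensional latent space):

```python
batch_size = 64
num_channels = 1   # MNIST images are grayscale
num_classes = 10   # digits 0-9
image_size = 28
latent_dim = 128   # dimensionality of the noise vector
```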
Loading the MNIST dataset and preprocessing it
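The data-loading cell is missing from this copy; a sketch of the usual preprocessing for this example follows. Since a generative model has no held-out evaluation, both MNIST splits can be merged into one training set:

```python
import numpy as np
import tensorflow as tf
import keras

batch_size = 64  # assumed value, matching the hyperparameters above

# Use both the train and test splits -- all 70k digits go into training.
(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()
all_digits = np.concatenate([x_train, x_test])
all_labels = np.concatenate([y_train, y_test])

# Scale pixels to [0, 1], add a channel axis, one-hot encode the labels.
all_digits = all_digits.astype("float32") / 255.0
all_digits = np.reshape(all_digits, (-1, 28, 28, 1))
all_labels = keras.utils.to_categorical(all_labels, 10)

# Build a shuffled, batched tf.data pipeline of (image, label) pairs.
dataset = tf.data.Dataset.from_tensor_slices((all_digits, all_labels))
dataset = dataset.shuffle(buffer_size=1024).batch(batch_size)
```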
Calculating the number of input channels for the generator and discriminator
In a regular (unconditional) GAN, we start by sampling noise (of some fixed dimension) from a normal distribution. In our case, we also need to account for the class labels. We will have to add the number of classes to the input channels of the generator (noise input) as well as the discriminator (generated image input).
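Concretely, this bookkeeping amounts to two additions (using the hyperparameter values assumed above):

```python
# Assumed values from the hyperparameters above.
latent_dim = 128
num_classes = 10
num_channels = 1

# The generator consumes noise plus a one-hot label vector; the
# discriminator sees the image channels plus one label "map" per class.
generator_in_channels = latent_dim + num_classes
discriminator_in_channels = num_channels + num_classes
print(generator_in_channels, discriminator_in_channels)  # 138 11
```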
Creating the discriminator and generator
The model definitions (`discriminator`, `generator`, and `ConditionalGAN`) have been adapted from this example.
Creating a `ConditionalGAN` model
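The `ConditionalGAN` cell itself is missing from this copy; the sketch below follows the standard Keras custom-`train_step` GAN pattern. The key conditional-GAN detail is that the one-hot labels are tiled into image-sized "label maps" before being concatenated to the images as extra channels:

```python
import tensorflow as tf
import keras


class ConditionalGAN(keras.Model):
    def __init__(self, discriminator, generator, latent_dim, image_size=28):
        super().__init__()
        self.discriminator = discriminator
        self.generator = generator
        self.latent_dim = latent_dim
        self.image_size = image_size

    def compile(self, d_optimizer, g_optimizer, loss_fn):
        super().compile()
        self.d_optimizer = d_optimizer
        self.g_optimizer = g_optimizer
        self.loss_fn = loss_fn

    def train_step(self, data):
        real_images, one_hot_labels = data
        batch_size = tf.shape(real_images)[0]

        # Tile the one-hot labels into image-sized "label maps" so they
        # can be concatenated to the images as extra channels.
        image_one_hot_labels = one_hot_labels[:, None, None, :]
        image_one_hot_labels = tf.tile(
            image_one_hot_labels, [1, self.image_size, self.image_size, 1]
        )

        # Sample noise and append the labels for the generator input.
        noise = tf.random.normal(shape=(batch_size, self.latent_dim))
        gen_input = tf.concat([noise, one_hot_labels], axis=1)

        # Discriminator step: fake images are labeled 1, real ones 0.
        fake_images = self.generator(gen_input)
        fake_pairs = tf.concat([fake_images, image_one_hot_labels], axis=-1)
        real_pairs = tf.concat([real_images, image_one_hot_labels], axis=-1)
        combined = tf.concat([fake_pairs, real_pairs], axis=0)
        labels = tf.concat(
            [tf.ones((batch_size, 1)), tf.zeros((batch_size, 1))], axis=0
        )
        with tf.GradientTape() as tape:
            predictions = self.discriminator(combined)
            d_loss = self.loss_fn(labels, predictions)
        grads = tape.gradient(d_loss, self.discriminator.trainable_weights)
        self.d_optimizer.apply_gradients(
            zip(grads, self.discriminator.trainable_weights)
        )

        # Generator step: try to make the discriminator output "real" (0)
        # for fresh fakes. Only the generator's weights get gradients here.
        noise = tf.random.normal(shape=(batch_size, self.latent_dim))
        gen_input = tf.concat([noise, one_hot_labels], axis=1)
        misleading_labels = tf.zeros((batch_size, 1))
        with tf.GradientTape() as tape:
            fakes = self.generator(gen_input)
            fake_pairs = tf.concat([fakes, image_one_hot_labels], axis=-1)
            predictions = self.discriminator(fake_pairs)
            g_loss = self.loss_fn(misleading_labels, predictions)
        grads = tape.gradient(g_loss, self.generator.trainable_weights)
        self.g_optimizer.apply_gradients(
            zip(grads, self.generator.trainable_weights)
        )
        return {"g_loss": g_loss, "d_loss": d_loss}
```

Once defined, the model is compiled with two optimizers (one per sub-network) and a `from_logits` binary cross-entropy loss, then trained with a plain `fit()` call on the labeled dataset.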
Training the Conditional GAN
Interpolating between classes with the trained generator
Here, we first sample noise from a normal distribution and repeat it `num_interpolation` times, reshaping the result so that every interpolation step shares the same noise vector. We then blend the one-hot labels of the two chosen classes linearly, so that each step's conditioning label mixes the two class identities in a different proportion.
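A self-contained sketch of this interpolation follows. The `trained_gen` model here is an untrained stand-in with the right input/output shapes; in the actual example you would substitute the generator from the trained `ConditionalGAN`:

```python
import tensorflow as tf
import keras

latent_dim, num_classes, num_interpolation = 128, 10, 9

# Hypothetical stand-in for the trained generator (same interface:
# noise + one-hot label in, 28x28x1 image out).
trained_gen = keras.Sequential(
    [
        keras.Input(shape=(latent_dim + num_classes,)),
        keras.layers.Dense(28 * 28, activation="sigmoid"),
        keras.layers.Reshape((28, 28, 1)),
    ]
)

# One noise vector, repeated so every interpolation step shares it.
noise = tf.random.normal(shape=(1, latent_dim))
noise = tf.repeat(noise, repeats=num_interpolation, axis=0)


def interpolate_class(first_number, second_number):
    # One-hot endpoints, then a linear blend between the two labels.
    first = tf.one_hot([first_number], num_classes)
    second = tf.one_hot([second_number], num_classes)
    percent = tf.linspace(0.0, 1.0, num_interpolation)[:, None]
    labels = first * (1 - percent) + second * percent
    return trained_gen(tf.concat([noise, labels], axis=1))


fake_images = interpolate_class(0, 9)  # shape: (9, 28, 28, 1)
```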
We can further improve the performance of this model with recipes like WGAN-GP. Conditional generation is also widely used in many modern image generation architectures like VQ-GANs, DALL-E, etc.
You can use the trained model hosted on Hugging Face Hub and try the demo on Hugging Face Spaces.