GitHub Repository: keras-team/keras-io
Path: blob/master/examples/generative/conditional_gan.py
"""
Title: Conditional GAN
Author: [Sayak Paul](https://twitter.com/RisingSayak)
Date created: 2021/07/13
Last modified: 2024/01/02
Description: Training a GAN conditioned on class labels to generate handwritten digits.
Accelerator: GPU
"""

"""
Generative Adversarial Networks (GANs) let us generate novel image data, video data,
or audio data from a random input. Typically, the random input is sampled
from a normal distribution, before going through a series of transformations that turn
it into something plausible (image, video, audio, etc.).

However, a simple [DCGAN](https://arxiv.org/abs/1511.06434) doesn't let us control
the appearance (e.g. class) of the samples we're generating. For instance,
with a GAN that generates MNIST handwritten digits, a simple DCGAN wouldn't let us
choose the class of digits we're generating.
To be able to control what we generate, we need to _condition_ the GAN output
on a semantic input, such as the class of an image.

In this example, we'll build a **Conditional GAN** that can generate MNIST handwritten
digits conditioned on a given class. Such a model can have various useful applications:

* let's say you are dealing with an
[imbalanced image dataset](https://developers.google.com/machine-learning/data-prep/construct/sampling-splitting/imbalanced-data),
and you'd like to gather more examples for the skewed class to balance the dataset.
Data collection can be a costly process on its own. You could instead train a Conditional GAN and use
it to generate novel images for the class that needs balancing.
* Since the generator learns to associate the generated samples with the class labels,
its representations can also be used for [other downstream tasks](https://arxiv.org/abs/1809.11096).

The following references were used to develop this example:

* [Conditional Generative Adversarial Nets](https://arxiv.org/abs/1411.1784)
* [Lecture on Conditional Generation from Coursera](https://www.coursera.org/lecture/build-basic-generative-adversarial-networks-gans/conditional-generation-inputs-2OPrG)

If you need a refresher on GANs, you can refer to the "Generative adversarial networks"
section of
[this resource](https://livebook.manning.com/book/deep-learning-with-python-second-edition/chapter-12/r-3/232).

This example requires TensorFlow 2.5 or higher, as well as TensorFlow Docs, which can be
installed using the following command:
"""

"""shell
pip install -q git+https://github.com/tensorflow/docs
"""

"""
## Imports
"""

import keras

from keras import layers
from keras import ops
from tensorflow_docs.vis import embed
import tensorflow as tf
import numpy as np
import imageio

"""
## Constants and hyperparameters
"""

batch_size = 64
num_channels = 1
num_classes = 10
image_size = 28
latent_dim = 128

"""
## Loading the MNIST dataset and preprocessing it
"""

# We'll use all the available examples from both the training and test
# sets.
(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()
all_digits = np.concatenate([x_train, x_test])
all_labels = np.concatenate([y_train, y_test])

# Scale the pixel values to [0, 1] range, add a channel dimension to
# the images, and one-hot encode the labels.
all_digits = all_digits.astype("float32") / 255.0
all_digits = np.reshape(all_digits, (-1, 28, 28, 1))
all_labels = keras.utils.to_categorical(all_labels, 10)

# Create tf.data.Dataset.
dataset = tf.data.Dataset.from_tensor_slices((all_digits, all_labels))
dataset = dataset.shuffle(buffer_size=1024).batch(batch_size)

print(f"Shape of training images: {all_digits.shape}")
print(f"Shape of training labels: {all_labels.shape}")

"""
## Calculating the number of input channels for the generator and discriminator

In a regular (unconditional) GAN, we start by sampling noise (of some fixed
dimension) from a normal distribution. In our case, we also need to account
for the class labels. We will have to add the number of classes to
the input channels of the generator (noise input) as well as the discriminator
(generated image input).
"""

generator_in_channels = latent_dim + num_classes
discriminator_in_channels = num_channels + num_classes
print(generator_in_channels, discriminator_in_channels)
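# With latent_dim = 128 and num_classes = 10, this prints: 138 11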

"""
## Creating the discriminator and generator

The model definitions (`discriminator`, `generator`, and `ConditionalGAN`) have been
adapted from [this example](https://keras.io/guides/customizing_what_happens_in_fit/).
"""

# Create the discriminator.
discriminator = keras.Sequential(
    [
        keras.layers.InputLayer((28, 28, discriminator_in_channels)),
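        # Two stride-2 convolutions downsample the 28x28 input to 7x7
        # before global pooling.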
        layers.Conv2D(64, (3, 3), strides=(2, 2), padding="same"),
        layers.LeakyReLU(negative_slope=0.2),
        layers.Conv2D(128, (3, 3), strides=(2, 2), padding="same"),
        layers.LeakyReLU(negative_slope=0.2),
        layers.GlobalMaxPooling2D(),
        layers.Dense(1),
    ],
    name="discriminator",
)

# Create the generator.
generator = keras.Sequential(
    [
        keras.layers.InputLayer((generator_in_channels,)),
        # We want to generate 128 + num_classes coefficients to reshape into a
        # 7x7x(128 + num_classes) map.
        layers.Dense(7 * 7 * generator_in_channels),
        layers.LeakyReLU(negative_slope=0.2),
        layers.Reshape((7, 7, generator_in_channels)),
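        # Two stride-2 transposed convolutions upsample the 7x7 map to 14x14
        # and then to 28x28, matching the MNIST image size.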
        layers.Conv2DTranspose(128, (4, 4), strides=(2, 2), padding="same"),
        layers.LeakyReLU(negative_slope=0.2),
        layers.Conv2DTranspose(128, (4, 4), strides=(2, 2), padding="same"),
        layers.LeakyReLU(negative_slope=0.2),
        layers.Conv2D(1, (7, 7), padding="same", activation="sigmoid"),
    ],
    name="generator",
)

"""
## Creating a `ConditionalGAN` model
"""


class ConditionalGAN(keras.Model):
    def __init__(self, discriminator, generator, latent_dim):
        super().__init__()
        self.discriminator = discriminator
        self.generator = generator
        self.latent_dim = latent_dim
        self.seed_generator = keras.random.SeedGenerator(1337)
        self.gen_loss_tracker = keras.metrics.Mean(name="generator_loss")
        self.disc_loss_tracker = keras.metrics.Mean(name="discriminator_loss")

    @property
    def metrics(self):
        return [self.gen_loss_tracker, self.disc_loss_tracker]

    def compile(self, d_optimizer, g_optimizer, loss_fn):
        super().compile()
        self.d_optimizer = d_optimizer
        self.g_optimizer = g_optimizer
        self.loss_fn = loss_fn

    def train_step(self, data):
        # Unpack the data.
        real_images, one_hot_labels = data

        # Add dummy dimensions to the labels so that they can be concatenated with
        # the images. This is for the discriminator.
        image_one_hot_labels = one_hot_labels[:, :, None, None]
        image_one_hot_labels = ops.repeat(
            image_one_hot_labels, repeats=[image_size * image_size]
        )
        image_one_hot_labels = ops.reshape(
            image_one_hot_labels, (-1, image_size, image_size, num_classes)
        )
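        # image_one_hot_labels now has shape (batch_size, 28, 28, num_classes),
        # so it can be concatenated with the images along the channel axis.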

        # Sample random points in the latent space and concatenate the labels.
        # This is for the generator.
        batch_size = ops.shape(real_images)[0]
        random_latent_vectors = keras.random.normal(
            shape=(batch_size, self.latent_dim), seed=self.seed_generator
        )
        random_vector_labels = ops.concatenate(
            [random_latent_vectors, one_hot_labels], axis=1
        )

        # Decode the noise (guided by labels) to fake images.
        generated_images = self.generator(random_vector_labels)

        # Combine them with real images. Note that we are concatenating the labels
        # with these images here.
        fake_image_and_labels = ops.concatenate(
            [generated_images, image_one_hot_labels], -1
        )
        real_image_and_labels = ops.concatenate([real_images, image_one_hot_labels], -1)
        combined_images = ops.concatenate(
            [fake_image_and_labels, real_image_and_labels], axis=0
        )

        # Assemble labels discriminating real from fake images.
        labels = ops.concatenate(
            [ops.ones((batch_size, 1)), ops.zeros((batch_size, 1))], axis=0
        )

        # Train the discriminator.
        with tf.GradientTape() as tape:
            predictions = self.discriminator(combined_images)
            d_loss = self.loss_fn(labels, predictions)
        grads = tape.gradient(d_loss, self.discriminator.trainable_weights)
        self.d_optimizer.apply_gradients(
            zip(grads, self.discriminator.trainable_weights)
        )

        # Sample random points in the latent space.
        random_latent_vectors = keras.random.normal(
            shape=(batch_size, self.latent_dim), seed=self.seed_generator
        )
        random_vector_labels = ops.concatenate(
            [random_latent_vectors, one_hot_labels], axis=1
        )

        # Assemble labels that say "all real images".
        misleading_labels = ops.zeros((batch_size, 1))
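        # (Real images were labeled 0 when training the discriminator above,
        # so "all real" here means all zeros.)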

        # Train the generator (note that we should *not* update the weights
        # of the discriminator)!
        with tf.GradientTape() as tape:
            fake_images = self.generator(random_vector_labels)
            fake_image_and_labels = ops.concatenate(
                [fake_images, image_one_hot_labels], -1
            )
            predictions = self.discriminator(fake_image_and_labels)
            g_loss = self.loss_fn(misleading_labels, predictions)
        grads = tape.gradient(g_loss, self.generator.trainable_weights)
        self.g_optimizer.apply_gradients(zip(grads, self.generator.trainable_weights))

        # Monitor loss.
        self.gen_loss_tracker.update_state(g_loss)
        self.disc_loss_tracker.update_state(d_loss)
        return {
            "g_loss": self.gen_loss_tracker.result(),
            "d_loss": self.disc_loss_tracker.result(),
        }


"""
## Training the Conditional GAN
"""

cond_gan = ConditionalGAN(
    discriminator=discriminator, generator=generator, latent_dim=latent_dim
)
cond_gan.compile(
    d_optimizer=keras.optimizers.Adam(learning_rate=0.0003),
    g_optimizer=keras.optimizers.Adam(learning_rate=0.0003),
    loss_fn=keras.losses.BinaryCrossentropy(from_logits=True),
)

cond_gan.fit(dataset, epochs=20)
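
"""
Once training is done, we can sample digits of a single chosen class by
concatenating noise with that class's one-hot label, mirroring what `train_step`
does. The following is a minimal sketch: `sample_digit` is a hypothetical helper
(not defined elsewhere in this example), and it assumes the `cond_gan` above has
finished training.
"""


def sample_digit(digit, n=5):
    # Hypothetical helper: sample noise and pair every row with the same
    # one-hot class label.
    noise = keras.random.normal(shape=(n, latent_dim))
    labels = ops.cast(keras.utils.to_categorical([digit] * n, num_classes), "float32")
    noise_and_labels = ops.concatenate([noise, labels], 1)
    # The generator decodes the label-conditioned noise into (n, 28, 28, 1) images.
    return cond_gan.generator.predict(noise_and_labels)


# For example, generate five images of the digit 3.
sampled_threes = sample_digit(3)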

"""
## Interpolating between classes with the trained generator
"""

# We first extract the trained generator from our Conditional GAN.
trained_gen = cond_gan.generator

# Choose the number of images to generate along the interpolation path,
# including the start and end images.
num_interpolation = 9  # @param {type:"integer"}

# Sample noise for the interpolation.
interpolation_noise = keras.random.normal(shape=(1, latent_dim))
interpolation_noise = ops.repeat(interpolation_noise, repeats=num_interpolation)
interpolation_noise = ops.reshape(interpolation_noise, (num_interpolation, latent_dim))


def interpolate_class(first_number, second_number):
    # Convert the start and end labels to one-hot encoded vectors.
    first_label = keras.utils.to_categorical([first_number], num_classes)
    second_label = keras.utils.to_categorical([second_number], num_classes)
    first_label = ops.cast(first_label, "float32")
    second_label = ops.cast(second_label, "float32")

    # Calculate the interpolation vector between the two labels.
    percent_second_label = ops.linspace(0, 1, num_interpolation)[:, None]
    percent_second_label = ops.cast(percent_second_label, "float32")
    interpolation_labels = (
        first_label * (1 - percent_second_label) + second_label * percent_second_label
    )
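    # With num_interpolation = 9, percent_second_label runs from 0.0 to 1.0 in
    # steps of 0.125, so the labels morph linearly from the first class to the second.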

    # Combine the noise and the labels and run inference with the generator.
    noise_and_labels = ops.concatenate([interpolation_noise, interpolation_labels], 1)
    fake = trained_gen.predict(noise_and_labels)
    return fake


start_class = 2  # @param {type:"slider", min:0, max:9, step:1}
end_class = 6  # @param {type:"slider", min:0, max:9, step:1}

fake_images = interpolate_class(start_class, end_class)

"""
Here, we first sample noise from a normal distribution, repeat it
`num_interpolation` times, and reshape the result so that every interpolation
step shares the same noise vector. We then vary the label linearly across the
`num_interpolation` steps, so that each generated image mixes the two class
identities in a different proportion.
"""

fake_images *= 255.0
converted_images = fake_images.astype(np.uint8)
converted_images = ops.image.resize(converted_images, (96, 96)).numpy().astype(np.uint8)
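# resize returns float values, so we cast back to uint8 before writing the GIF.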
imageio.mimsave("animation.gif", converted_images[:, :, :, 0], fps=1)
embed.embed_file("animation.gif")

"""
We can further improve the performance of this model with recipes like
[WGAN-GP](https://keras.io/examples/generative/wgan_gp).
Conditional generation is also widely used in many modern image generation architectures like
[VQ-GANs](https://arxiv.org/abs/2012.09841), [DALL-E](https://openai.com/blog/dall-e/),
etc.

You can use the trained model hosted on [Hugging Face Hub](https://huggingface.co/keras-io/conditional-gan) and try the demo on [Hugging Face Spaces](https://huggingface.co/spaces/keras-io/conditional-GAN).
"""