"""
Title: WGAN-GP overriding `Model.train_step`
Author: [A_K_Nain](https://twitter.com/A_K_Nain)
Date created: 2020/05/9
Last modified: 2023/08/3
Description: Implementation of Wasserstein GAN with Gradient Penalty.
Accelerator: GPU
"""

"""
## Wasserstein GAN (WGAN) with Gradient Penalty (GP)

The original [Wasserstein GAN](https://arxiv.org/abs/1701.07875) leverages the
Wasserstein distance to produce a value function that has better theoretical
properties than the value function used in the original GAN paper. WGAN requires
that the discriminator (aka the critic) lie within the space of 1-Lipschitz
functions. The authors proposed the idea of weight clipping to achieve this
constraint. Though weight clipping works, it can be a problematic way to enforce
the 1-Lipschitz constraint and can cause undesirable behavior, e.g. a very deep
WGAN discriminator (critic) often fails to converge.

The [WGAN-GP](https://arxiv.org/abs/1704.00028) method proposes an
alternative to weight clipping to ensure smooth training. Instead of clipping
the weights, the authors proposed a "gradient penalty" by adding a loss term
that keeps the L2 norm of the discriminator gradients close to 1.
"""

"""
## Setup
"""

import os

os.environ["KERAS_BACKEND"] = "tensorflow"

import keras
import tensorflow as tf
from keras import layers


"""
## Prepare the Fashion-MNIST data

To demonstrate how to train WGAN-GP, we will be using the
[Fashion-MNIST](https://github.com/zalandoresearch/fashion-mnist) dataset. Each
sample in this dataset is a 28x28 grayscale image associated with a label from
10 classes (e.g. trouser, pullover, sneaker, etc.).
"""

IMG_SHAPE = (28, 28, 1)
BATCH_SIZE = 512

# Size of the noise vector
noise_dim = 128

fashion_mnist = keras.datasets.fashion_mnist
(train_images, train_labels), (test_images, test_labels) = fashion_mnist.load_data()
print(f"Number of examples: {len(train_images)}")
print(f"Shape of the images in the dataset: {train_images.shape[1:]}")

# Reshape each sample to (28, 28, 1) and normalize the pixel values to the [-1, 1] range
train_images = train_images.reshape(train_images.shape[0], *IMG_SHAPE).astype("float32")
train_images = (train_images - 127.5) / 127.5
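
# Quick sanity check (illustration only): after scaling, pixel values should
# lie in [-1, 1], matching the `tanh` output range of the generator below.
print(f"Pixel value range: [{train_images.min()}, {train_images.max()}]")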

"""
## Create the discriminator (the critic in the original WGAN)

The samples in the dataset have a (28, 28, 1) shape. Because we will be
using strided convolutions, this can result in a shape with odd dimensions.
For example,
`(28, 28) -> Conv_s2 -> (14, 14) -> Conv_s2 -> (7, 7) -> Conv_s2 -> (3, 3)`.

While performing upsampling in the generator part of the network, we won't get
the same input shape as the original images if we aren't careful. To avoid this,
we will do something much simpler:
- In the discriminator: "zero pad" the input to change the shape to `(32, 32, 1)`
for each sample; and
- In the generator: crop the final output to match the input shape (see the
shape check right after this block).
"""


def conv_block(
    x,
    filters,
    activation,
    kernel_size=(3, 3),
    strides=(1, 1),
    padding="same",
    use_bias=True,
    use_bn=False,
    use_dropout=False,
    drop_value=0.5,
):
    x = layers.Conv2D(
        filters, kernel_size, strides=strides, padding=padding, use_bias=use_bias
    )(x)
    if use_bn:
        x = layers.BatchNormalization()(x)
    x = activation(x)
    if use_dropout:
        x = layers.Dropout(drop_value)(x)
    return x


def get_discriminator_model():
    img_input = layers.Input(shape=IMG_SHAPE)
    # Zero pad the input to make each image (32, 32, 1) in size.
    x = layers.ZeroPadding2D((2, 2))(img_input)
    x = conv_block(
        x,
        64,
        kernel_size=(5, 5),
        strides=(2, 2),
        use_bn=False,
        use_bias=True,
        activation=layers.LeakyReLU(0.2),
        use_dropout=False,
        drop_value=0.3,
    )
    x = conv_block(
        x,
        128,
        kernel_size=(5, 5),
        strides=(2, 2),
        use_bn=False,
        activation=layers.LeakyReLU(0.2),
        use_bias=True,
        use_dropout=True,
        drop_value=0.3,
    )
    x = conv_block(
        x,
        256,
        kernel_size=(5, 5),
        strides=(2, 2),
        use_bn=False,
        activation=layers.LeakyReLU(0.2),
        use_bias=True,
        use_dropout=True,
        drop_value=0.3,
    )
    x = conv_block(
        x,
        512,
        kernel_size=(5, 5),
        strides=(2, 2),
        use_bn=False,
        activation=layers.LeakyReLU(0.2),
        use_bias=True,
        use_dropout=False,
        drop_value=0.3,
    )

    x = layers.Flatten()(x)
    x = layers.Dropout(0.2)(x)
    x = layers.Dense(1)(x)

    d_model = keras.models.Model(img_input, x, name="discriminator")
    return d_model


d_model = get_discriminator_model()
d_model.summary()
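
# Quick check (illustration only): the critic maps a batch of images to one
# unconstrained scalar score per image. Note there is no sigmoid at the output;
# a Wasserstein critic outputs raw scores rather than probabilities.
print(d_model(tf.zeros((1, *IMG_SHAPE)), training=False).shape)  # (1, 1)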

"""
## Create the generator
"""


def upsample_block(
    x,
    filters,
    activation,
    kernel_size=(3, 3),
    strides=(1, 1),
    up_size=(2, 2),
    padding="same",
    use_bn=False,
    use_bias=True,
    use_dropout=False,
    drop_value=0.3,
):
    x = layers.UpSampling2D(up_size)(x)
    x = layers.Conv2D(
        filters, kernel_size, strides=strides, padding=padding, use_bias=use_bias
    )(x)

    if use_bn:
        x = layers.BatchNormalization()(x)

    if activation:
        x = activation(x)
    if use_dropout:
        x = layers.Dropout(drop_value)(x)
    return x


def get_generator_model():
    noise = layers.Input(shape=(noise_dim,))
    x = layers.Dense(4 * 4 * 256, use_bias=False)(noise)
    x = layers.BatchNormalization()(x)
    x = layers.LeakyReLU(0.2)(x)

    x = layers.Reshape((4, 4, 256))(x)
    x = upsample_block(
        x,
        128,
        layers.LeakyReLU(0.2),
        strides=(1, 1),
        use_bias=False,
        use_bn=True,
        padding="same",
        use_dropout=False,
    )
    x = upsample_block(
        x,
        64,
        layers.LeakyReLU(0.2),
        strides=(1, 1),
        use_bias=False,
        use_bn=True,
        padding="same",
        use_dropout=False,
    )
    x = upsample_block(
        x, 1, layers.Activation("tanh"), strides=(1, 1), use_bias=False, use_bn=True
    )
    # At this point, the output has a (32, 32, 1) shape, matching the
    # zero-padded discriminator input. We use a Cropping2D layer to make
    # it (28, 28, 1).
    x = layers.Cropping2D((2, 2))(x)

    g_model = keras.models.Model(noise, x, name="generator")
    return g_model


g_model = get_generator_model()
g_model.summary()
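
# Quick check (illustration only): the generator should map a batch of noise
# vectors to (28, 28, 1) images with `tanh`-bounded values in [-1, 1].
print(g_model(tf.random.normal((2, noise_dim)), training=False).shape)  # (2, 28, 28, 1)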

"""
## Create the WGAN-GP model

Now that we have defined our generator and discriminator, it's time to implement
the WGAN-GP model. We will also override `train_step` for training.
"""


class WGAN(keras.Model):
    def __init__(
        self,
        discriminator,
        generator,
        latent_dim,
        discriminator_extra_steps=3,
        gp_weight=10.0,
    ):
        super().__init__()
        self.discriminator = discriminator
        self.generator = generator
        self.latent_dim = latent_dim
        self.d_steps = discriminator_extra_steps
        self.gp_weight = gp_weight

    def compile(self, d_optimizer, g_optimizer, d_loss_fn, g_loss_fn):
        super().compile()
        self.d_optimizer = d_optimizer
        self.g_optimizer = g_optimizer
        self.d_loss_fn = d_loss_fn
        self.g_loss_fn = g_loss_fn

    def gradient_penalty(self, batch_size, real_images, fake_images):
        """Calculates the gradient penalty.

        This loss is calculated on an interpolated image
        and added to the discriminator loss.
        """
        # Get the interpolated image
        alpha = tf.random.uniform([batch_size, 1, 1, 1], 0.0, 1.0)
        diff = fake_images - real_images
        interpolated = real_images + alpha * diff

        with tf.GradientTape() as gp_tape:
            gp_tape.watch(interpolated)
            # 1. Get the discriminator output for this interpolated image.
            pred = self.discriminator(interpolated, training=True)

        # 2. Calculate the gradients w.r.t. this interpolated image.
        grads = gp_tape.gradient(pred, [interpolated])[0]
        # 3. Calculate the norm of the gradients.
        norm = tf.sqrt(tf.reduce_sum(tf.square(grads), axis=[1, 2, 3]))
        gp = tf.reduce_mean((norm - 1.0) ** 2)
        return gp

    def train_step(self, real_images):
        if isinstance(real_images, tuple):
            real_images = real_images[0]

        # Get the batch size
        batch_size = tf.shape(real_images)[0]

        # For each batch, we are going to perform the
        # following steps as laid out in the original paper:
        # 1. Train the discriminator and get the discriminator loss
        # 2. Calculate the gradient penalty
        # 3. Multiply this gradient penalty with a constant weight factor
        # 4. Add the gradient penalty to the discriminator loss
        # 5. Train the generator and get the generator loss
        # 6. Return the generator and discriminator losses as a loss dictionary

        # Train the discriminator first. The original paper recommends training
        # the discriminator for `n_critic` steps (typically 5) per single step
        # of the generator. Here we use 3 discriminator steps per generator
        # step instead of 5 to reduce the training time.
        for i in range(self.d_steps):
            # Get the latent vector
            random_latent_vectors = tf.random.normal(
                shape=(batch_size, self.latent_dim)
            )
            with tf.GradientTape() as tape:
                # Generate fake images from the latent vector
                fake_images = self.generator(random_latent_vectors, training=True)
                # Get the logits for the fake images
                fake_logits = self.discriminator(fake_images, training=True)
                # Get the logits for the real images
                real_logits = self.discriminator(real_images, training=True)

                # Calculate the discriminator loss using the fake and real image logits
                d_cost = self.d_loss_fn(real_img=real_logits, fake_img=fake_logits)
                # Calculate the gradient penalty
                gp = self.gradient_penalty(batch_size, real_images, fake_images)
                # Add the gradient penalty to the original discriminator loss
                d_loss = d_cost + gp * self.gp_weight

            # Get the gradients of the discriminator loss w.r.t. the discriminator weights
            d_gradient = tape.gradient(d_loss, self.discriminator.trainable_variables)
            # Update the weights of the discriminator using the discriminator optimizer
            self.d_optimizer.apply_gradients(
                zip(d_gradient, self.discriminator.trainable_variables)
            )

        # Train the generator
        # Get the latent vector
        random_latent_vectors = tf.random.normal(shape=(batch_size, self.latent_dim))
        with tf.GradientTape() as tape:
            # Generate fake images using the generator
            generated_images = self.generator(random_latent_vectors, training=True)
            # Get the discriminator logits for fake images
            gen_img_logits = self.discriminator(generated_images, training=True)
            # Calculate the generator loss
            g_loss = self.g_loss_fn(gen_img_logits)

        # Get the gradients of the generator loss w.r.t. the generator weights
        gen_gradient = tape.gradient(g_loss, self.generator.trainable_variables)
        # Update the weights of the generator using the generator optimizer
        self.g_optimizer.apply_gradients(
            zip(gen_gradient, self.generator.trainable_variables)
        )
        return {"d_loss": d_loss, "g_loss": g_loss}


"""
## Create a Keras callback that periodically saves generated images
"""


class GANMonitor(keras.callbacks.Callback):
    def __init__(self, num_img=6, latent_dim=128):
        super().__init__()
        self.num_img = num_img
        self.latent_dim = latent_dim

    def on_epoch_end(self, epoch, logs=None):
        random_latent_vectors = tf.random.normal(shape=(self.num_img, self.latent_dim))
        generated_images = self.model.generator(random_latent_vectors)
        # Rescale the images from [-1, 1] back to [0, 255]
        generated_images = (generated_images * 127.5) + 127.5

        for i in range(self.num_img):
            img = generated_images[i].numpy()
            img = keras.utils.array_to_img(img)
            img.save("generated_img_{i}_{epoch}.png".format(i=i, epoch=epoch))


"""
## Train the end-to-end model
"""

# Instantiate the optimizer for both networks
# (learning_rate=0.0002, beta_1=0.5 are recommended)
generator_optimizer = keras.optimizers.Adam(
    learning_rate=0.0002, beta_1=0.5, beta_2=0.9
)
discriminator_optimizer = keras.optimizers.Adam(
    learning_rate=0.0002, beta_1=0.5, beta_2=0.9
)


# Define the loss function for the discriminator,
# which should be (fake_loss - real_loss).
# We will add the gradient penalty later to this loss function.
def discriminator_loss(real_img, fake_img):
    real_loss = tf.reduce_mean(real_img)
    fake_loss = tf.reduce_mean(fake_img)
    return fake_loss - real_loss


# Define the loss function for the generator.
def generator_loss(fake_img):
    return -tf.reduce_mean(fake_img)
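
# A tiny numeric check (illustration only): if the critic scores a real image
# at 2.0 and a fake image at -1.0, the discriminator loss is
# (-1.0) - 2.0 = -3.0 and the generator loss is -(-1.0) = 1.0.
print(discriminator_loss(tf.constant([2.0]), tf.constant([-1.0])))  # -3.0
print(generator_loss(tf.constant([-1.0])))  # 1.0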


# Set the number of epochs for training.
epochs = 20

# Instantiate the custom `GANMonitor` Keras callback.
cbk = GANMonitor(num_img=3, latent_dim=noise_dim)

# Instantiate the WGAN model.
wgan = WGAN(
    discriminator=d_model,
    generator=g_model,
    latent_dim=noise_dim,
    discriminator_extra_steps=3,
)

# Compile the WGAN model.
wgan.compile(
    d_optimizer=discriminator_optimizer,
    g_optimizer=generator_optimizer,
    g_loss_fn=generator_loss,
    d_loss_fn=discriminator_loss,
)

# Start training
wgan.fit(train_images, batch_size=BATCH_SIZE, epochs=epochs, callbacks=[cbk])

"""
Display the last generated images:
"""

from IPython.display import Image, display

display(Image("generated_img_0_19.png"))
display(Image("generated_img_1_19.png"))
display(Image("generated_img_2_19.png"))