"""
Title: Data-efficient GANs with Adaptive Discriminator Augmentation
Author: [András Béres](https://www.linkedin.com/in/andras-beres-789190210)
Date created: 2021/10/28
Last modified: 2025/01/23
Description: Generating images from limited data using the Caltech Birds dataset.
Accelerator: GPU
"""

"""
## Introduction

### GANs

[Generative Adversarial Networks (GANs)](https://arxiv.org/abs/1406.2661) are a popular
class of generative deep learning models, commonly used for image generation. They
consist of a pair of dueling neural networks, called the discriminator and the generator.
The discriminator's task is to distinguish real images from generated (fake) ones, while
the generator network tries to fool the discriminator by generating more and more
realistic images. If the discriminator is, however, too easy or too hard to fool, it
might fail to provide a useful learning signal for the generator, which is why training
GANs is usually considered a difficult task.

### Data augmentation for GANs

Data augmentation, a popular technique in deep learning, is the process of randomly
applying semantics-preserving transformations to the input data to generate multiple
realistic versions of it, thereby effectively multiplying the amount of training data
available. The simplest example is left-right flipping an image, which preserves its
contents while generating a second unique training sample. Data augmentation is commonly
used in supervised learning to prevent overfitting and enhance generalization.

The authors of [StyleGAN2-ADA](https://arxiv.org/abs/2006.06676) show that discriminator
overfitting can be an issue in GANs, especially when only a small amount of training data
is available. They propose Adaptive Discriminator Augmentation to mitigate this issue.

Applying data augmentation to GANs, however, is not straightforward. Since the generator
is updated using the discriminator's gradients, if the generated images are augmented, the
augmentation pipeline has to be differentiable and also has to be GPU-compatible for
computational efficiency. Luckily, the
[Keras image augmentation layers](https://keras.io/api/layers/preprocessing_layers/image_augmentation/)
fulfill both of these requirements, and are therefore very well suited for this task.

### Invertible data augmentation

A possible difficulty when using data augmentation in generative models is the issue of
["leaky augmentations" (section 2.2)](https://arxiv.org/abs/2006.06676), namely when the
model generates images that are already augmented. This would mean that it was not able
to separate the augmentation from the underlying data distribution, which can be caused
by using non-invertible data transformations. For example, if 0, 90, 180 or 270 degree
rotations are performed with equal probability, the original orientation of the images
is impossible to infer, and this information is destroyed.

A simple trick to make data augmentations invertible is to only apply them with some
probability. That way the original version of the images will be more common, and the
data distribution can be inferred. By properly choosing this probability, one can
effectively regularize the discriminator without making the augmentations leaky.
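
As a toy illustration of this trick (a plain-NumPy sketch with a hypothetical
`maybe_rotate` helper, separate from the Keras pipeline used below), applying a 90-degree
rotation only with probability `p` keeps the unaugmented orientation the most frequent
one, so the original distribution remains recoverable:

```python
import numpy as np

rng = np.random.default_rng(0)


def maybe_rotate(image, p=0.25):
    # with probability p, rotate by a random multiple of 90 degrees;
    # otherwise return the image unchanged, so the original orientation
    # stays the most common one in the augmented distribution
    if rng.random() < p:
        return np.rot90(image, k=int(rng.integers(1, 4)))
    return image


image = np.ones((64, 64, 3), dtype=np.float32)
augmented = maybe_rotate(image)
```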
"""
60
61
"""
62
## Setup
63
"""
64
65
import os
66
67
os.environ["KERAS_BACKEND"] = "tensorflow"
68
69
import matplotlib.pyplot as plt
70
import tensorflow as tf
71
import tensorflow_datasets as tfds
72
73
import keras
74
from keras import ops
75
from keras import layers
76

"""
## Hyperparameters
"""

# data
num_epochs = 10  # train for 400 epochs for good results
image_size = 64
# resolution of Kernel Inception Distance measurement, see related section
kid_image_size = 75
padding = 0.25
dataset_name = "caltech_birds2011"

# adaptive discriminator augmentation
max_translation = 0.125
max_rotation = 0.125
max_zoom = 0.25
target_accuracy = 0.85
integration_steps = 1000

# architecture
noise_size = 64
depth = 4
width = 128
leaky_relu_slope = 0.2
dropout_rate = 0.4

# optimization
batch_size = 128
learning_rate = 2e-4
beta_1 = 0.5  # not using the default value of 0.9 is important
ema = 0.99

"""
## Data pipeline

In this example, we will use the
[Caltech Birds (2011)](https://www.tensorflow.org/datasets/catalog/caltech_birds2011) dataset for
generating images of birds, which is a diverse natural dataset containing fewer than 6000
images for training. When working with such a small amount of data, one has to take extra
care to keep the data quality as high as possible. In this example, we use the provided
bounding boxes of the birds to cut them out with square crops while preserving their
aspect ratios when possible.
"""


def round_to_int(float_value):
    return ops.cast(ops.round(float_value), "int32")


def preprocess_image(data):
    # unnormalize bounding box coordinates
    height = ops.cast(ops.shape(data["image"])[0], "float32")
    width = ops.cast(ops.shape(data["image"])[1], "float32")
    bounding_box = data["bbox"] * ops.stack([height, width, height, width])

    # calculate center and length of longer side, add padding
    target_center_y = 0.5 * (bounding_box[0] + bounding_box[2])
    target_center_x = 0.5 * (bounding_box[1] + bounding_box[3])
    target_size = ops.maximum(
        (1.0 + padding) * (bounding_box[2] - bounding_box[0]),
        (1.0 + padding) * (bounding_box[3] - bounding_box[1]),
    )

    # modify crop size to fit into image
    target_height = ops.min(
        [target_size, 2.0 * target_center_y, 2.0 * (height - target_center_y)]
    )
    target_width = ops.min(
        [target_size, 2.0 * target_center_x, 2.0 * (width - target_center_x)]
    )

    # crop image, `ops.image.crop_images` only works with non-tensor croppings
    image = ops.slice(
        data["image"],
        start_indices=(
            round_to_int(target_center_y - 0.5 * target_height),
            round_to_int(target_center_x - 0.5 * target_width),
            0,
        ),
        shape=(round_to_int(target_height), round_to_int(target_width), 3),
    )

    # resize and clip
    image = ops.cast(image, "float32")
    image = ops.image.resize(image, [image_size, image_size])

    return ops.clip(image / 255.0, 0.0, 1.0)


def prepare_dataset(split):
    # the validation dataset is shuffled as well, because data order matters
    # for the KID calculation
    return (
        tfds.load(dataset_name, split=split, shuffle_files=True)
        .map(preprocess_image, num_parallel_calls=tf.data.AUTOTUNE)
        .cache()
        .shuffle(10 * batch_size)
        .batch(batch_size, drop_remainder=True)
        .prefetch(buffer_size=tf.data.AUTOTUNE)
    )


train_dataset = prepare_dataset("train")
val_dataset = prepare_dataset("test")

"""
183
After preprocessing the training images look like the following:
184
![birds dataset](https://i.imgur.com/Ru5HgBM.png)
185
"""
186

"""
## Kernel inception distance

[Kernel Inception Distance (KID)](https://arxiv.org/abs/1801.01401) was proposed as a
replacement for the popular
[Frechet Inception Distance (FID)](https://arxiv.org/abs/1706.08500)
metric for measuring image generation quality.
Both metrics measure the difference between the generated and training distributions in
the representation space of an [InceptionV3](https://keras.io/api/applications/inceptionv3/)
network pretrained on
[ImageNet](https://www.tensorflow.org/datasets/catalog/imagenet2012).

According to the paper, KID was proposed because FID has no unbiased estimator; its
expected value is higher when it is measured on fewer images. KID is more suitable for
small datasets because its expected value does not depend on the number of samples it is
measured on. In my experience it is also computationally lighter, numerically more
stable, and simpler to implement because it can be estimated in a per-batch manner.

In this example, the images are evaluated at the minimal possible resolution of the
Inception network (75x75 instead of 299x299), and the metric is only measured on the
validation set for computational efficiency.
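
For concreteness, here is a minimal NumPy sketch of the per-batch estimate that the `KID`
metric below implements (the helper names `polynomial_kernel` and `kid_per_batch` are
illustrative and not part of the example's code):

```python
import numpy as np


def polynomial_kernel(a, b):
    # k(x, y) = (x . y / d + 1) ** 3, the kernel used for KID
    d = a.shape[1]
    return (a @ b.T / d + 1.0) ** 3.0


def kid_per_batch(real_features, generated_features):
    # squared MMD estimate, excluding the diagonal of the self-kernels
    n = real_features.shape[0]
    k_rr = polynomial_kernel(real_features, real_features)
    k_gg = polynomial_kernel(generated_features, generated_features)
    k_rg = polynomial_kernel(real_features, generated_features)
    mean_rr = (k_rr.sum() - np.trace(k_rr)) / (n * (n - 1))
    mean_gg = (k_gg.sum() - np.trace(k_gg)) / (n * (n - 1))
    return mean_rr + mean_gg - 2.0 * k_rg.mean()
```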
"""
211
212
213
class KID(keras.metrics.Metric):
214
def __init__(self, name="kid", **kwargs):
215
super().__init__(name=name, **kwargs)
216
217
# KID is estimated per batch and is averaged across batches
218
self.kid_tracker = keras.metrics.Mean()
219
220
# a pretrained InceptionV3 is used without its classification layer
221
# transform the pixel values to the 0-255 range, then use the same
222
# preprocessing as during pretraining
223
self.encoder = keras.Sequential(
224
[
225
layers.InputLayer(input_shape=(image_size, image_size, 3)),
226
layers.Rescaling(255.0),
227
layers.Resizing(height=kid_image_size, width=kid_image_size),
228
layers.Lambda(keras.applications.inception_v3.preprocess_input),
229
keras.applications.InceptionV3(
230
include_top=False,
231
input_shape=(kid_image_size, kid_image_size, 3),
232
weights="imagenet",
233
),
234
layers.GlobalAveragePooling2D(),
235
],
236
name="inception_encoder",
237
)
238
239
def polynomial_kernel(self, features_1, features_2):
240
feature_dimensions = ops.cast(ops.shape(features_1)[1], "float32")
241
return (
242
features_1 @ ops.transpose(features_2) / feature_dimensions + 1.0
243
) ** 3.0
244
245
def update_state(self, real_images, generated_images, sample_weight=None):
246
real_features = self.encoder(real_images, training=False)
247
generated_features = self.encoder(generated_images, training=False)
248
249
# compute polynomial kernels using the two sets of features
250
kernel_real = self.polynomial_kernel(real_features, real_features)
251
kernel_generated = self.polynomial_kernel(
252
generated_features, generated_features
253
)
254
kernel_cross = self.polynomial_kernel(real_features, generated_features)
255
256
# estimate the squared maximum mean discrepancy using the average kernel values
257
batch_size = ops.shape(real_features)[0]
258
batch_size_f = ops.cast(batch_size, "float32")
259
mean_kernel_real = ops.sum(kernel_real * (1.0 - ops.eye(batch_size))) / (
260
batch_size_f * (batch_size_f - 1.0)
261
)
262
mean_kernel_generated = ops.sum(
263
kernel_generated * (1.0 - ops.eye(batch_size))
264
) / (batch_size_f * (batch_size_f - 1.0))
265
mean_kernel_cross = ops.mean(kernel_cross)
266
kid = mean_kernel_real + mean_kernel_generated - 2.0 * mean_kernel_cross
267
268
# update the average KID estimate
269
self.kid_tracker.update_state(kid)
270
271
def result(self):
272
return self.kid_tracker.result()
273
274
def reset_state(self):
275
self.kid_tracker.reset_state()
276
277
278
"""
279
280
## Adaptive discriminator augmentation
281
282
The authors of [StyleGAN2-ADA](https://arxiv.org/abs/2006.06676) propose to change the
283
augmentation probability adaptively during training. Though it is explained differently
284
in the paper, they use [integral control](https://en.wikipedia.org/wiki/PID_controller#Integral) on the augmentation
285
probability to keep the discriminator's accuracy on real images close to a target value.
286
Note, that their controlled variable is actually the average sign of the discriminator
287
logits (r_t in the paper), which corresponds to 2 * accuracy - 1.
288
289
This method requires two hyperparameters:
290
291
1. `target_accuracy`: the target value for the discriminator's accuracy on real images. I
292
recommend selecting its value from the 80-90% range.
293
2. [`integration_steps`](https://en.wikipedia.org/wiki/PID_controller#Mathematical_form):
294
the number of update steps required for an accuracy error of 100% to transform into an
295
augmentation probability increase of 100%. To give an intuition, this defines how slowly
296
the augmentation probability is changed. I recommend setting this to a relatively high
297
value (1000 in this case) so that the augmentation strength is only adjusted slowly.
298
299
The main motivation for this procedure is that the optimal value of the target accuracy
300
is similar across different dataset sizes (see [figure 4 and 5 in the paper](https://arxiv.org/abs/2006.06676)),
301
so it does not have to be re-tuned, because the
302
process automatically applies stronger data augmentation when it is needed.
303
304
"""
305
306
307
# "hard sigmoid", useful for binary accuracy calculation from logits
308
def step(values):
309
# negative values -> 0.0, positive values -> 1.0
310
return 0.5 * (1.0 + ops.sign(values))
311
312
313
# augments images with a probability that is dynamically updated during training
314
class AdaptiveAugmenter(keras.Model):
315
def __init__(self):
316
super().__init__()
317
318
# stores the current probability of an image being augmented
319
self.probability = keras.Variable(0.0)
320
self.seed_generator = keras.random.SeedGenerator(42)
321
322
# the corresponding augmentation names from the paper are shown above each layer
323
# the authors show (see figure 4), that the blitting and geometric augmentations
324
# are the most helpful in the low-data regime
325
self.augmenter = keras.Sequential(
326
[
327
layers.InputLayer(input_shape=(image_size, image_size, 3)),
328
# blitting/x-flip:
329
layers.RandomFlip("horizontal"),
330
# blitting/integer translation:
331
layers.RandomTranslation(
332
height_factor=max_translation,
333
width_factor=max_translation,
334
interpolation="nearest",
335
),
336
# geometric/rotation:
337
layers.RandomRotation(factor=max_rotation),
338
# geometric/isotropic and anisotropic scaling:
339
layers.RandomZoom(
340
height_factor=(-max_zoom, 0.0), width_factor=(-max_zoom, 0.0)
341
),
342
],
343
name="adaptive_augmenter",
344
)
345
346
def call(self, images, training):
347
if training:
348
augmented_images = self.augmenter(images, training=training)
349
350
# during training either the original or the augmented images are selected
351
# based on self.probability
352
augmentation_values = keras.random.uniform(
353
shape=(batch_size, 1, 1, 1), seed=self.seed_generator
354
)
355
augmentation_bools = ops.less(augmentation_values, self.probability)
356
357
images = ops.where(augmentation_bools, augmented_images, images)
358
return images
359
360
def update(self, real_logits):
361
current_accuracy = ops.mean(step(real_logits))
362
363
# the augmentation probability is updated based on the discriminator's
364
# accuracy on real images
365
accuracy_error = current_accuracy - target_accuracy
366
self.probability.assign(
367
ops.clip(self.probability + accuracy_error / integration_steps, 0.0, 1.0)
368
)
369
370
371
"""
372
## Network architecture
373
374
Here we specify the architecture of the two networks:
375
376
* generator: maps a random vector to an image, which should be as realistic as possible
377
* discriminator: maps an image to a scalar score, which should be high for real and low
378
for generated images
379
380
GANs tend to be sensitive to the network architecture, I implemented a DCGAN architecture
381
in this example, because it is relatively stable during training while being simple to
382
implement. We use a constant number of filters throughout the network, use a sigmoid
383
instead of tanh in the last layer of the generator, and use default initialization
384
instead of random normal as further simplifications.
385
386
As a good practice, we disable the learnable scale parameter in the batch normalization
387
layers, because on one hand the following relu + convolutional layers make it redundant
388
(as noted in the
389
[documentation](https://keras.io/api/layers/normalization_layers/batch_normalization/)).
390
But also because it should be disabled based on theory when using [spectral normalization
391
(section 4.1)](https://arxiv.org/abs/1802.05957), which is not used here, but is common
392
in GANs. We also disable the bias in the fully connected and convolutional layers, because
393
the following batch normalization makes it redundant.
394
"""
395


# DCGAN generator
def get_generator():
    noise_input = keras.Input(shape=(noise_size,))
    x = layers.Dense(4 * 4 * width, use_bias=False)(noise_input)
    x = layers.BatchNormalization(scale=False)(x)
    x = layers.ReLU()(x)
    x = layers.Reshape(target_shape=(4, 4, width))(x)
    for _ in range(depth - 1):
        x = layers.Conv2DTranspose(
            width,
            kernel_size=4,
            strides=2,
            padding="same",
            use_bias=False,
        )(x)
        x = layers.BatchNormalization(scale=False)(x)
        x = layers.ReLU()(x)
    image_output = layers.Conv2DTranspose(
        3,
        kernel_size=4,
        strides=2,
        padding="same",
        activation="sigmoid",
    )(x)

    return keras.Model(noise_input, image_output, name="generator")


# DCGAN discriminator
def get_discriminator():
    image_input = keras.Input(shape=(image_size, image_size, 3))
    x = image_input
    for _ in range(depth):
        x = layers.Conv2D(
            width,
            kernel_size=4,
            strides=2,
            padding="same",
            use_bias=False,
        )(x)
        x = layers.BatchNormalization(scale=False)(x)
        x = layers.LeakyReLU(alpha=leaky_relu_slope)(x)
    x = layers.Flatten()(x)
    x = layers.Dropout(dropout_rate)(x)
    output_score = layers.Dense(1)(x)

    return keras.Model(image_input, output_score, name="discriminator")

"""
447
## GAN model
448
"""
449
450
451

class GAN_ADA(keras.Model):
    def __init__(self):
        super().__init__()

        self.seed_generator = keras.random.SeedGenerator(seed=42)
        self.augmenter = AdaptiveAugmenter()
        self.generator = get_generator()
        self.ema_generator = keras.models.clone_model(self.generator)
        self.discriminator = get_discriminator()

        self.generator.summary()
        self.discriminator.summary()
        # we have created all layers at this point, so we can mark the model
        # as having been built
        self.built = True

    def compile(self, generator_optimizer, discriminator_optimizer, **kwargs):
        super().compile(**kwargs)

        # separate optimizers for the two networks
        self.generator_optimizer = generator_optimizer
        self.discriminator_optimizer = discriminator_optimizer

        self.generator_loss_tracker = keras.metrics.Mean(name="g_loss")
        self.discriminator_loss_tracker = keras.metrics.Mean(name="d_loss")
        self.real_accuracy = keras.metrics.BinaryAccuracy(name="real_acc")
        self.generated_accuracy = keras.metrics.BinaryAccuracy(name="gen_acc")
        self.augmentation_probability_tracker = keras.metrics.Mean(name="aug_p")
        self.kid = KID()

    @property
    def metrics(self):
        return [
            self.generator_loss_tracker,
            self.discriminator_loss_tracker,
            self.real_accuracy,
            self.generated_accuracy,
            self.augmentation_probability_tracker,
            self.kid,
        ]

    def generate(self, batch_size, training):
        latent_samples = keras.random.normal(
            shape=(batch_size, noise_size), seed=self.seed_generator
        )
        # use ema_generator during inference
        if training:
            generated_images = self.generator(latent_samples, training=training)
        else:
            generated_images = self.ema_generator(latent_samples, training=training)
        return generated_images

    def adversarial_loss(self, real_logits, generated_logits):
        # this is usually called the non-saturating GAN loss

        real_labels = ops.ones(shape=(batch_size, 1))
        generated_labels = ops.zeros(shape=(batch_size, 1))

        # the generator tries to produce images that the discriminator considers as real
        generator_loss = keras.losses.binary_crossentropy(
            real_labels, generated_logits, from_logits=True
        )
        # the discriminator tries to determine if images are real or generated
        discriminator_loss = keras.losses.binary_crossentropy(
            ops.concatenate([real_labels, generated_labels], axis=0),
            ops.concatenate([real_logits, generated_logits], axis=0),
            from_logits=True,
        )

        return ops.mean(generator_loss), ops.mean(discriminator_loss)

    def train_step(self, real_images):
        real_images = self.augmenter(real_images, training=True)

        # use persistent gradient tape because gradients will be calculated twice
        with tf.GradientTape(persistent=True) as tape:
            generated_images = self.generate(batch_size, training=True)
            # gradient is calculated through the image augmentation
            generated_images = self.augmenter(generated_images, training=True)

            # separate forward passes for the real and generated images, meaning
            # that batch normalization is applied separately
            real_logits = self.discriminator(real_images, training=True)
            generated_logits = self.discriminator(generated_images, training=True)

            generator_loss, discriminator_loss = self.adversarial_loss(
                real_logits, generated_logits
            )

        # calculate gradients and update weights
        generator_gradients = tape.gradient(
            generator_loss, self.generator.trainable_weights
        )
        discriminator_gradients = tape.gradient(
            discriminator_loss, self.discriminator.trainable_weights
        )
        self.generator_optimizer.apply_gradients(
            zip(generator_gradients, self.generator.trainable_weights)
        )
        self.discriminator_optimizer.apply_gradients(
            zip(discriminator_gradients, self.discriminator.trainable_weights)
        )

        # update the augmentation probability based on the discriminator's performance
        self.augmenter.update(real_logits)

        self.generator_loss_tracker.update_state(generator_loss)
        self.discriminator_loss_tracker.update_state(discriminator_loss)
        self.real_accuracy.update_state(1.0, step(real_logits))
        self.generated_accuracy.update_state(0.0, step(generated_logits))
        self.augmentation_probability_tracker.update_state(self.augmenter.probability)

        # track the exponential moving average of the generator's weights to decrease
        # variance in the generation quality
        for weight, ema_weight in zip(
            self.generator.weights, self.ema_generator.weights
        ):
            ema_weight.assign(ema * ema_weight + (1 - ema) * weight)

        # KID is not measured during the training phase for computational efficiency
        return {m.name: m.result() for m in self.metrics[:-1]}

    def test_step(self, real_images):
        generated_images = self.generate(batch_size, training=False)

        self.kid.update_state(real_images, generated_images)

        # only KID is measured during the evaluation phase for computational efficiency
        return {self.kid.name: self.kid.result()}

    def plot_images(self, epoch=None, logs=None, num_rows=3, num_cols=6, interval=5):
        # plot random generated images for visual evaluation of generation quality
        if epoch is None or (epoch + 1) % interval == 0:
            num_images = num_rows * num_cols
            generated_images = self.generate(num_images, training=False)

            plt.figure(figsize=(num_cols * 2.0, num_rows * 2.0))
            for row in range(num_rows):
                for col in range(num_cols):
                    index = row * num_cols + col
                    plt.subplot(num_rows, num_cols, index + 1)
                    plt.imshow(generated_images[index])
                    plt.axis("off")
            plt.tight_layout()
            plt.show()
            plt.close()

"""
600
## Training
601
602
One can should see from the metrics during training, that if the real accuracy
603
(discriminator's accuracy on real images) is below the target accuracy, the augmentation
604
probability is increased, and vice versa. In my experience, during a healthy GAN
605
training, the discriminator accuracy should stay in the 80-95% range. Below that, the
606
discriminator is too weak, above that it is too strong.
607
608
Note that we track the exponential moving average of the generator's weights, and use that
609
for image generation and KID evaluation.
610
"""

# create and compile the model
model = GAN_ADA()
model.compile(
    generator_optimizer=keras.optimizers.Adam(learning_rate, beta_1),
    discriminator_optimizer=keras.optimizers.Adam(learning_rate, beta_1),
)

# save the best model based on the validation KID metric
checkpoint_path = "gan_model.weights.h5"
checkpoint_callback = keras.callbacks.ModelCheckpoint(
    filepath=checkpoint_path,
    save_weights_only=True,
    monitor="val_kid",
    mode="min",
    save_best_only=True,
)

# run training and plot generated images periodically
model.fit(
    train_dataset,
    epochs=num_epochs,
    validation_data=val_dataset,
    callbacks=[
        keras.callbacks.LambdaCallback(on_epoch_end=model.plot_images),
        checkpoint_callback,
    ],
)

"""
641
## Inference
642
"""
643
644
# load the best model and generate images
645
model.load_weights(checkpoint_path)
646
model.plot_images()
647
648
"""
649
## Results
650
651
By running the training for 400 epochs (which takes 2-3 hours in a Colab notebook), one
652
can get high quality image generations using this code example.
653
654
The evolution of a random batch of images over a 400 epoch training (ema=0.999 for
655
animation smoothness):
656
![birds evolution gif](https://i.imgur.com/ecGuCcz.gif)
657
658
Latent-space interpolation between a batch of selected images:
659
![birds interpolation gif](https://i.imgur.com/nGvzlsC.gif)
660
661
I also recommend trying out training on other datasets, such as
662
[CelebA](https://www.tensorflow.org/datasets/catalog/celeb_a) for example. In my
663
experience good results can be achieved without changing any hyperparameters (though
664
discriminator augmentation might not be necessary).
665
"""
666
667
"""
668
## GAN tips and tricks
669
670
My goal with this example was to find a good tradeoff between ease of implementation and
671
generation quality for GANs. During preparation, I have run numerous ablations using
672
[this repository](https://github.com/beresandras/gan-flavours-keras).
673
674
In this section I list the lessons learned and my recommendations in my subjective order
675
of importance.
676
677
I recommend checking out the [DCGAN paper](https://arxiv.org/abs/1511.06434), this
678
[NeurIPS talk](https://www.youtube.com/watch?v=myGAju4L7O8), and this
679
[large scale GAN study](https://arxiv.org/abs/1711.10337) for others' takes on this subject.
680
681
### Architectural tips
682
683
* **resolution**: Training GANs at higher resolutions tends to get more difficult, I
684
recommend experimenting at 32x32 or 64x64 resolutions initially.
685
* **initialization**: If you see strong colorful patterns early on in the training, the
686
initialization might be the issue. Set the kernel_initializer parameters of layers to
687
[random normal](https://keras.io/api/layers/initializers/#randomnormal-class), and
688
decrease the standard deviation (recommended value: 0.02, following DCGAN) until the
689
issue disappears.
690
* **upsampling**: There are two main methods for upsampling in the generator.
691
[Transposed convolution](https://keras.io/api/layers/convolution_layers/convolution2d_transpose/)
692
is faster, but can lead to
693
[checkerboard artifacts](https://distill.pub/2016/deconv-checkerboard/), which can be reduced by using
694
a kernel size that is divisible with the stride (recommended kernel size is 4 for a stride of 2).
695
[Upsampling](https://keras.io/api/layers/reshaping_layers/up_sampling2d/) +
696
[standard convolution](https://keras.io/api/layers/convolution_layers/convolution2d/) can have slightly
697
lower quality, but checkerboard artifacts are not an issue. I recommend using nearest-neighbor
698
interpolation over bilinear for it.
699
* **batch normalization in discriminator**: Sometimes has a high impact, I recommend
700
trying out both ways.
701
* **[spectral normalization](https://www.tensorflow.org/addons/api_docs/python/tfa/layers/SpectralNormalization)**:
702
A popular technique for training GANs, can help with stability. I recommend
703
disabling batch normalization's learnable scale parameters along with it.
704
* **[residual connections](https://keras.io/guides/functional_api/#a-toy-resnet-model)**:
705
While residual discriminators behave similarly, residual generators are more difficult to
706
train in my experience. They are however necessary for training large and deep
707
architectures. I recommend starting with non-residual architectures.
708
* **dropout**: Using dropout before the last layer of the discriminator improves
709
generation quality in my experience. Recommended dropout rate is below 0.5.
710
* **[leaky ReLU](https://keras.io/api/layers/activation_layers/leaky_relu/)**: Use leaky
711
ReLU activations in the discriminator to make its gradients less sparse. Recommended
712
slope/alpha is 0.2 following DCGAN.
713
714
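
To make the upsampling and initialization tips concrete, here is a minimal sketch of an
UpSampling2D + Conv2D generator block with nearest-neighbor interpolation and a DCGAN-style
random normal initializer (the `upsample_block` helper is hypothetical and not part of the
generator used in this example):

```python
from keras import initializers, layers


def upsample_block(x, filters):
    # nearest-neighbor upsampling followed by a standard convolution avoids
    # the checkerboard artifacts that transposed convolutions can produce
    x = layers.UpSampling2D(size=2, interpolation="nearest")(x)
    x = layers.Conv2D(
        filters,
        kernel_size=3,
        padding="same",
        use_bias=False,
        # DCGAN-style initialization: zero-mean normal with a small stddev
        kernel_initializer=initializers.RandomNormal(stddev=0.02),
    )(x)
    x = layers.BatchNormalization(scale=False)(x)
    return layers.ReLU()(x)
```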

### Algorithmic tips

* **loss functions**: Numerous losses have been proposed over the years for training
GANs, promising improved performance and stability. I have implemented 5 of them in
[this repository](https://github.com/beresandras/gan-flavours-keras), and my experience is in
line with [this GAN study](https://arxiv.org/abs/1711.10337): no loss seems to
consistently outperform the default non-saturating GAN loss. I recommend using that as a
default.
* **Adam's beta_1 parameter**: The beta_1 parameter in Adam can be interpreted as the
momentum of the mean gradient estimate. Using 0.5 or even 0.0 instead of the default 0.9
value was proposed in DCGAN and is important. This example would not work using its
default value.
* **separate batch normalization for generated and real images**: The forward pass of the
discriminator should be separate for the generated and real images. Doing otherwise can
lead to artifacts (45 degree stripes in my case) and decreased performance.
* **exponential moving average of generator's weights**: This helps to reduce the
variance of the KID measurement, and helps in averaging out the rapid color palette
changes during training.
* **[different learning rate for generator and discriminator](https://arxiv.org/abs/1706.08500)**:
If one has the resources, it can help
to tune the learning rates of the two networks separately. A similar idea is to update
either network's (usually the discriminator's) weights multiple times for each of the
other network's updates. I recommend using the same learning rate of 2e-4 (Adam),
following DCGAN, for both networks, and only updating both of them once as a default.
* **label noise**: [One-sided label smoothing](https://arxiv.org/abs/1606.03498) (using
less than 1.0 for real labels), or adding noise to the labels, can regularize the
discriminator so that it does not become overconfident; however, in my case they did not
improve performance (see the sketch after this list).
* **adaptive data augmentation**: Since it adds another dynamic component to the training
process, disable it as a default, and only enable it when the other components already
work well.
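
For reference, one-sided label smoothing only changes the real-label targets fed into the
discriminator loss. A minimal sketch, assuming the same non-saturating setup as the
`adversarial_loss` method above (the `smoothed_discriminator_loss` helper and the 0.9
smoothing factor are illustrative):

```python
import keras
from keras import ops


def smoothed_discriminator_loss(real_logits, generated_logits, smoothing=0.9):
    # one-sided label smoothing: real targets become 0.9, fake targets stay 0.0
    real_labels = smoothing * ops.ones(shape=(ops.shape(real_logits)[0], 1))
    generated_labels = ops.zeros(shape=(ops.shape(generated_logits)[0], 1))
    loss = keras.losses.binary_crossentropy(
        ops.concatenate([real_labels, generated_labels], axis=0),
        ops.concatenate([real_logits, generated_logits], axis=0),
        from_logits=True,
    )
    return ops.mean(loss)
```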
"""
746
747
"""
748
## Related works
749
750
Other GAN-related Keras code examples:
751
752
* [DCGAN + CelebA](https://keras.io/examples/generative/dcgan_overriding_train_step/)
753
* [WGAN + FashionMNIST](https://keras.io/examples/generative/wgan_gp/)
754
* [WGAN + Molecules](https://keras.io/examples/generative/wgan-graphs/)
755
* [ConditionalGAN + MNIST](https://keras.io/examples/generative/conditional_gan/)
756
* [CycleGAN + Horse2Zebra](https://keras.io/examples/generative/cyclegan/)
757
* [StyleGAN](https://keras.io/examples/generative/stylegan/)
758
759
Modern GAN architecture-lines:
760
761
* [SAGAN](https://arxiv.org/abs/1805.08318), [BigGAN](https://arxiv.org/abs/1809.11096)
762
* [ProgressiveGAN](https://arxiv.org/abs/1710.10196),
763
[StyleGAN](https://arxiv.org/abs/1812.04948),
764
[StyleGAN2](https://arxiv.org/abs/1912.04958),
765
[StyleGAN2-ADA](https://arxiv.org/abs/2006.06676),
766
[AliasFreeGAN](https://arxiv.org/abs/2106.12423)
767
768
Concurrent papers on discriminator data augmentation:
769
[1](https://arxiv.org/abs/2006.02595), [2](https://arxiv.org/abs/2006.05338), [3](https://arxiv.org/abs/2006.10738)
770
771
Recent literature overview on GANs: [talk](https://www.youtube.com/watch?v=3ktD752xq5k)
772
"""
773
774