GauGAN for conditional image generation
Authors: Soumik Rakshit, Sayak Paul
Date created: 2021/12/26
Last modified: 2022/01/03
Description: Implementing a GauGAN for conditional image generation.
Introduction
In this example, we present an implementation of the GauGAN architecture proposed in Semantic Image Synthesis with Spatially-Adaptive Normalization. Briefly, GauGAN uses a Generative Adversarial Network (GAN) to generate realistic images that are conditioned on cue images and segmentation maps, as shown below (image source):
The main components of a GauGAN are:
SPADE (aka spatially-adaptive normalization): The authors of GauGAN argue that the more conventional normalization layers (such as Batch Normalization) destroy the semantic information obtained from segmentation maps that are provided as inputs. To address this problem, the authors introduce SPADE, a normalization layer particularly suitable for learning affine parameters (scale and bias) that are spatially adaptive. This is done by learning different sets of scaling and bias parameters for each semantic label; see the sketch after this list.
Variational encoder: Inspired by Variational Autoencoders, GauGAN uses a variational formulation wherein an encoder learns the mean and variance of a normal (Gaussian) distribution from the cue images. This is where GauGAN gets its name from. The generator of GauGAN takes as inputs the latents sampled from the Gaussian distribution as well as the one-hot encoded semantic segmentation label maps. The cue images act as style images that guide the style of the generated outputs. This variational formulation helps GauGAN achieve image diversity as well as fidelity.
Multi-scale patch discriminator: Inspired by the PatchGAN model, GauGAN uses a discriminator that assesses a given image on a patch basis and produces an averaged score.
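To make the SPADE mechanism concrete, here is a minimal, self-contained sketch of a SPADE-style normalization layer in Keras. It illustrates the idea (per-pixel scale and bias predicted from the segmentation mask), but it is not necessarily the exact layer used later in this example; the hidden layer width is an assumption.

```python
import tensorflow as tf
from tensorflow.keras import layers


class SPADE(layers.Layer):
    """Sketch of spatially-adaptive normalization (per-pixel scale and bias)."""

    def __init__(self, filters, epsilon=1e-5, **kwargs):
        super().__init__(**kwargs)
        self.epsilon = epsilon
        # A small conv net maps the segmentation mask to modulation parameters.
        self.conv = layers.Conv2D(128, 3, padding="same", activation="relu")
        self.conv_gamma = layers.Conv2D(filters, 3, padding="same")
        self.conv_beta = layers.Conv2D(filters, 3, padding="same")

    def call(self, x, mask):
        # Resize the (one-hot) mask to the spatial resolution of the features.
        mask = tf.image.resize(mask, tf.shape(x)[1:3], method="nearest")
        hidden = self.conv(mask)
        gamma = self.conv_gamma(hidden)  # spatially-varying scale
        beta = self.conv_beta(hidden)    # spatially-varying bias

        # Normalize the activations (no learned affine parameters here),
        # then modulate them with the mask-conditioned scale and bias.
        mean, variance = tf.nn.moments(x, axes=(0, 1, 2), keepdims=True)
        normalized = (x - mean) / tf.sqrt(variance + self.epsilon)
        return normalized * (1 + gamma) + beta
```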
As we proceed with the example, we will discuss each of the different components in further detail.
For a thorough review of GauGAN, please refer to this article. We also encourage you to check out the official GauGAN website, which has many creative applications of GauGAN. This example assumes that the reader is already familiar with the fundamental concepts of GANs. If you need a refresher, the following resources might be useful:
Chapter on GANs from the Deep Learning with Python book by François Chollet.
GAN implementations on keras.io:
Imports
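A minimal set of imports and hyperparameters for the sketches that follow. The constant values (image size, number of semantic classes, batch size, latent dimensionality) are illustrative assumptions rather than prescribed settings.

```python
import os

import matplotlib.pyplot as plt
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers

# Illustrative hyperparameters used throughout the sketches below.
IMG_HEIGHT = IMG_WIDTH = 256
NUM_CLASSES = 12   # number of semantic labels in the segmentation maps (assumed)
BATCH_SIZE = 4
LATENT_DIM = 256
```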
Data splitting
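Below is a sketch of one possible split. It assumes paired files in a hypothetical facades_data/ directory, where each scene has a photo (.jpg) and a segmentation map (.png) sharing the same filename stem; the directory name and the 80/20 ratio are assumptions.

```python
# Hypothetical data location: each scene in `facades_data/` is assumed to have
# a photo (.jpg) and a segmentation map (.png) with the same filename stem.
data_dir = "facades_data"

image_files = sorted(
    os.path.join(data_dir, f) for f in os.listdir(data_dir) if f.endswith(".jpg")
)
label_files = [f.replace(".jpg", ".png") for f in image_files]

# Hold out the last 20% of the pairs for validation (split ratio is illustrative).
split_index = int(0.8 * len(image_files))
train_images, val_images = image_files[:split_index], image_files[split_index:]
train_labels, val_labels = label_files[:split_index], label_files[split_index:]

print(f"Training pairs: {len(train_images)} - Validation pairs: {len(val_images)}")
```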
Now, let's visualize a few samples from the training set.
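A quick way to do this, assuming the file lists produced by the splitting sketch above:

```python
def load_pair(image_path, label_path):
    # Decode a (photo, segmentation map) pair from disk.
    image = tf.image.decode_jpeg(tf.io.read_file(image_path), channels=3)
    label = tf.image.decode_png(tf.io.read_file(label_path), channels=3)
    return image, label


plt.figure(figsize=(8, 10))
for i in range(3):
    image, label = load_pair(train_images[i], train_labels[i])
    plt.subplot(3, 2, 2 * i + 1)
    plt.imshow(label.numpy())
    plt.title("Segmentation map")
    plt.axis("off")
    plt.subplot(3, 2, 2 * i + 2)
    plt.imshow(image.numpy())
    plt.title("Cue image")
    plt.axis("off")
plt.show()
```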
Next, we implement the downsampling block for the encoder.
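One possible implementation of such a block, assuming the imports above; the use of instance-style GroupNormalization and a LeakyReLU slope of 0.2 are assumptions for this sketch.

```python
def downsample_block(
    channels, kernel_size=3, strides=2, apply_norm=True, apply_dropout=False
):
    """Conv2D -> (normalization) -> LeakyReLU -> (dropout), as a small Sequential."""
    block = keras.Sequential()
    block.add(
        layers.Conv2D(
            channels, kernel_size, strides=strides, padding="same", use_bias=False
        )
    )
    if apply_norm:
        # groups=-1 puts one channel per group, i.e. instance normalization.
        block.add(layers.GroupNormalization(groups=-1))
    block.add(layers.LeakyReLU(0.2))
    if apply_dropout:
        block.add(layers.Dropout(0.5))
    return block
```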
The GauGAN encoder consists of a few downsampling blocks. It outputs the mean and variance of a distribution.
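A sketch of such an encoder, together with a reparameterization-trick sampler used to draw latents from the predicted distribution. The number of blocks and the channel widths are assumptions; the second output is interpreted as a log-variance.

```python
class GaussianSampler(layers.Layer):
    """Reparameterization trick: z = mean + exp(0.5 * log_variance) * epsilon."""

    def call(self, inputs):
        means, variance = inputs
        epsilon = tf.random.normal(shape=tf.shape(means))
        return means + tf.exp(0.5 * variance) * epsilon


def build_encoder(image_shape=(IMG_HEIGHT, IMG_WIDTH, 3), latent_dim=LATENT_DIM):
    # Stack downsampling blocks, then predict the parameters of the latent Gaussian.
    image = keras.Input(shape=image_shape)
    x = downsample_block(64, 3, apply_norm=False)(image)
    x = downsample_block(128, 3)(x)
    x = downsample_block(256, 3)(x)
    x = downsample_block(512, 3)(x)
    x = downsample_block(512, 3)(x)
    x = layers.Flatten()(x)
    mean = layers.Dense(latent_dim, name="mean")(x)
    variance = layers.Dense(latent_dim, name="variance")(x)  # log-variance
    return keras.Model(image, [mean, variance], name="encoder")
```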
Next, we implement the generator, which consists of the modified residual blocks and upsampling blocks. It takes latent vectors and one-hot encoded segmentation labels, and produces new images.
With SPADE, there is no need to feed the segmentation map to the first layer of the generator, since the latent inputs have enough structural information about the style we want the generator to emulate. We also discard the encoder part of the generator, which is commonly used in prior architectures. This results in a more lightweight generator network, which can also take a random vector as input, enabling a simple and natural path to multi-modal synthesis.
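A condensed sketch of a SPADE residual block and of a generator built from such blocks interleaved with upsampling layers. It reuses the SPADE sketch from the introduction; the channel schedule and the initial 4x4 spatial size are assumptions.

```python
class SPADEResBlock(layers.Layer):
    """Residual block in which every normalization is a SPADE layer (sketch)."""

    def __init__(self, filters, **kwargs):
        super().__init__(**kwargs)
        self.filters = filters

    def build(self, input_shape):
        input_filters = input_shape[-1]
        mid_filters = min(input_filters, self.filters)
        self.learned_skip = input_filters != self.filters
        self.spade_1 = SPADE(input_filters)
        self.conv_1 = layers.Conv2D(mid_filters, 3, padding="same")
        self.spade_2 = SPADE(mid_filters)
        self.conv_2 = layers.Conv2D(self.filters, 3, padding="same")
        if self.learned_skip:
            self.spade_skip = SPADE(input_filters)
            self.conv_skip = layers.Conv2D(
                self.filters, 3, padding="same", use_bias=False
            )

    def call(self, x, mask):
        h = self.conv_1(tf.nn.leaky_relu(self.spade_1(x, mask), 0.2))
        h = self.conv_2(tf.nn.leaky_relu(self.spade_2(h, mask), 0.2))
        skip = (
            self.conv_skip(tf.nn.leaky_relu(self.spade_skip(x, mask), 0.2))
            if self.learned_skip
            else x
        )
        return skip + h


def build_generator(
    mask_shape=(IMG_HEIGHT, IMG_WIDTH, NUM_CLASSES), latent_dim=LATENT_DIM
):
    """Project the latent to a 4x4 grid, then alternate SPADE res blocks and upsampling."""
    latent = keras.Input(shape=(latent_dim,))
    mask = keras.Input(shape=mask_shape)

    x = layers.Dense(4 * 4 * 1024)(latent)
    x = layers.Reshape((4, 4, 1024))(x)
    for filters in (1024, 1024, 1024, 512, 256, 128):  # 4x4 -> 256x256
        x = SPADEResBlock(filters)(x, mask)
        x = layers.UpSampling2D((2, 2))(x)
    x = layers.LeakyReLU(0.2)(x)
    output_image = layers.Conv2D(3, 4, padding="same", activation="tanh")(x)
    return keras.Model([latent, mask], output_image, name="generator")
```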
The discriminator takes a segmentation map and an image and concatenates them. It then predicts if patches of the concatenated image are real or fake.
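A sketch of a patch discriminator (a single scale is shown for brevity). Returning the intermediate feature maps alongside the patch scores makes the feature matching loss described below straightforward to compute.

```python
def build_discriminator(image_shape=(IMG_HEIGHT, IMG_WIDTH, 3)):
    """PatchGAN-style discriminator over the concatenated (map, image) pair (sketch)."""
    segmentation_map = keras.Input(shape=image_shape)
    image = keras.Input(shape=image_shape)
    x = layers.Concatenate()([segmentation_map, image])
    x1 = downsample_block(64, 4, apply_norm=False)(x)
    x2 = downsample_block(128, 4)(x1)
    x3 = downsample_block(256, 4)(x2)
    x4 = downsample_block(512, 4, strides=1)(x3)
    # One score per patch; intermediate activations are kept for feature matching.
    patch_scores = layers.Conv2D(1, 4)(x4)
    return keras.Model(
        [segmentation_map, image], [x1, x2, x3, x4, patch_scores], name="discriminator"
    )
```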
Loss functions
GauGAN uses the following loss functions:

Generator:
An adversarial hinge loss computed from the discriminator's predictions on the generated images.
A KL divergence loss on the mean and variance predicted by the encoder, which regularizes the latent space (kl_loss in the logs below).
A discriminator feature matching loss that aligns the intermediate discriminator features of generated and real images (feat_loss).
A VGG-based perceptual loss that encourages the generated images to be perceptually close to the real ones (vgg_loss).
The weighted combination of these terms is reported as gen_loss.

Discriminator:
A hinge loss on its predictions for real and generated images (disc_loss).
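A sketch of these loss terms follows. The VGG19 layer choices and the per-layer weights follow the common pix2pixHD-style setup and are assumptions here.

```python
def generator_loss(fake_preds):
    # Hinge-style generator loss: push the discriminator scores on fakes up.
    return -tf.reduce_mean(fake_preds)


def kl_divergence_loss(mean, variance):
    # KL divergence between N(mean, exp(variance)) and a standard normal,
    # with `variance` interpreted as a log-variance.
    return -0.5 * tf.reduce_sum(1 + variance - tf.square(mean) - tf.exp(variance))


def discriminator_loss(real_preds, fake_preds):
    # Hinge loss on real and generated patches.
    real_loss = tf.reduce_mean(tf.nn.relu(1.0 - real_preds))
    fake_loss = tf.reduce_mean(tf.nn.relu(1.0 + fake_preds))
    return 0.5 * (real_loss + fake_loss)


class FeatureMatchingLoss(keras.losses.Loss):
    """Mean absolute error between real and fake discriminator feature maps."""

    def call(self, real_features, fake_features):
        return tf.reduce_mean(tf.abs(real_features - fake_features))


class VGGFeatureMatchingLoss(keras.losses.Loss):
    """Perceptual loss computed on VGG19 activations of real and fake images."""

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        vgg = keras.applications.VGG19(include_top=False, weights="imagenet")
        layer_names = [
            "block1_conv1", "block2_conv1", "block3_conv1",
            "block4_conv1", "block5_conv1",
        ]
        outputs = [vgg.get_layer(name).output for name in layer_names]
        self.vgg_model = keras.Model(vgg.input, outputs)
        self.layer_weights = [1.0 / 32, 1.0 / 16, 1.0 / 8, 1.0 / 4, 1.0]
        self.mae = keras.losses.MeanAbsoluteError()

    def call(self, real_images, fake_images):
        # Map images from [-1, 1] to [0, 255] before VGG preprocessing.
        real = keras.applications.vgg19.preprocess_input(127.5 * (real_images + 1))
        fake = keras.applications.vgg19.preprocess_input(127.5 * (fake_images + 1))
        real_features = self.vgg_model(real)
        fake_features = self.vgg_model(fake)
        loss = 0.0
        for weight, real_f, fake_f in zip(
            self.layer_weights, real_features, fake_features
        ):
            loss += weight * self.mae(real_f, fake_f)
        return loss
```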
Subclassed GauGAN model
Finally, we put everything together inside a subclassed model (from tf.keras.Model), overriding its train_step() method.
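A condensed sketch of such a model, wiring together the sketches above. The data format (segmentation map, photo, one-hot labels) and the loss weights are assumptions, and metric tracking, test_step(), and call() are omitted for brevity.

```python
class GauGAN(keras.Model):
    """Condensed GauGAN sketch: one discriminator step, then one generator step."""

    def __init__(self, latent_dim=LATENT_DIM, **kwargs):
        super().__init__(**kwargs)
        self.latent_dim = latent_dim
        self.encoder = build_encoder(latent_dim=latent_dim)
        self.generator = build_generator(latent_dim=latent_dim)
        self.discriminator = build_discriminator()
        self.sampler = GaussianSampler()
        self.feature_matching_loss = FeatureMatchingLoss()
        self.vgg_loss = VGGFeatureMatchingLoss()

    def compile(self, gen_optimizer, disc_optimizer, **kwargs):
        super().compile(**kwargs)
        self.gen_optimizer = gen_optimizer
        self.disc_optimizer = disc_optimizer

    def train_step(self, data):
        # Assumed batch structure: (colored segmentation map, photo, one-hot labels).
        segmentation_maps, images, labels = data

        # --- Discriminator step ---
        mean, variance = self.encoder(images)
        latent = self.sampler([mean, variance])
        fake_images = self.generator([latent, labels])
        with tf.GradientTape() as tape:
            fake_preds = self.discriminator([segmentation_maps, fake_images])[-1]
            real_preds = self.discriminator([segmentation_maps, images])[-1]
            disc_loss = discriminator_loss(real_preds, fake_preds)
        grads = tape.gradient(disc_loss, self.discriminator.trainable_variables)
        self.disc_optimizer.apply_gradients(
            zip(grads, self.discriminator.trainable_variables)
        )

        # --- Generator (and encoder) step ---
        with tf.GradientTape() as tape:
            mean, variance = self.encoder(images)
            latent = self.sampler([mean, variance])
            fake_images = self.generator([latent, labels])
            fake_outputs = self.discriminator([segmentation_maps, fake_images])
            real_outputs = self.discriminator([segmentation_maps, images])

            # Adversarial, KL, feature matching and perceptual terms (weights illustrative).
            adv = generator_loss(fake_outputs[-1])
            kl = 0.1 * kl_divergence_loss(mean, variance)
            feat = 10.0 * tf.add_n(
                [
                    self.feature_matching_loss(real_f, fake_f)
                    for real_f, fake_f in zip(real_outputs[:-1], fake_outputs[:-1])
                ]
            )
            vgg = 0.1 * self.vgg_loss(images, fake_images)
            gen_loss = adv + kl + feat + vgg

        trainable = (
            self.encoder.trainable_variables + self.generator.trainable_variables
        )
        grads = tape.gradient(gen_loss, trainable)
        self.gen_optimizer.apply_gradients(zip(grads, trainable))

        return {
            "disc_loss": disc_loss,
            "gen_loss": gen_loss,
            "feat_loss": feat,
            "kl_loss": kl,
            "vgg_loss": vgg,
        }
```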
GauGAN training
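A minimal sketch of compiling and fitting the model. train_dataset is assumed to be a tf.data.Dataset yielding (segmentation_map, image, one_hot_labels) batches; the two learning rates and the 15 epochs mirror a common GauGAN training setup and the run logged below, while reporting the validation losses shown in those logs would additionally require a test_step() override that this sketch omits.

```python
# `train_dataset` is assumed to be a tf.data.Dataset yielding
# (segmentation_map, image, one_hot_labels) batches of size BATCH_SIZE.
gaugan = GauGAN(latent_dim=LATENT_DIM)
gaugan.compile(
    gen_optimizer=keras.optimizers.Adam(1e-4, beta_1=0.0),
    disc_optimizer=keras.optimizers.Adam(4e-4, beta_1=0.0),
)
history = gaugan.fit(train_dataset, epochs=15)
```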
/home/sineeli/anaconda3/envs/kerasv3/lib/python3.10/site-packages/keras/src/optimizers/base_optimizer.py:472: UserWarning: Gradients do not exist for variables ['kernel', 'kernel', 'gamma', 'beta', 'kernel', 'gamma', 'beta', 'kernel', 'gamma', 'beta', 'kernel', 'gamma', 'beta', 'kernel', 'bias', 'kernel', 'bias'] when minimizing the loss. If using model.compile(), did you forget to provide a loss argument? warnings.warn(
WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
I0000 00:00:1705013303.976306 30381 device_compiler.h:186] Compiled cluster using XLA! This line is logged at most once for the lifetime of the process.
W0000 00:00:1705013304.021899 30381 graph_launch.cc:671] Fallback to op-by-op mode because memset node breaks graph update
75/75 ━━━━━━━━━━━━━━━━━━━━ 0s 176ms/step - disc_loss: 1.3079 - feat_loss: 11.2902 - gen_loss: 113.0583 - kl_loss: 83.1424 - vgg_loss: 18.4966
W0000 00:00:1705013326.657730 30384 graph_launch.cc:671] Fallback to op-by-op mode because memset node breaks graph update
1/1 ━━━━━━━━━━━━━━━━━━━━ 3s 3s/step
75/75 ━━━━━━━━━━━━━━━━━━━━ 114s 426ms/step - disc_loss: 1.3051 - feat_loss: 11.2902 - gen_loss: 113.0590 - kl_loss: 83.1493 - vgg_loss: 18.4890 - val_disc_loss: 1.0374 - val_feat_loss: 9.2344 - val_gen_loss: 110.1001 - val_kl_loss: 83.8935 - val_vgg_loss: 16.6412
Epoch 2/15
75/75 ━━━━━━━━━━━━━━━━━━━━ 14s 193ms/step - disc_loss: 0.8257 - feat_loss: 12.6603 - gen_loss: 115.9798 - kl_loss: 84.4545 - vgg_loss: 18.2973 - val_disc_loss: 0.9296 - val_feat_loss: 10.4162 - val_gen_loss: 110.6182 - val_kl_loss: 83.4473 - val_vgg_loss: 16.5499
Epoch 3/15
75/75 ━━━━━━━━━━━━━━━━━━━━ 15s 194ms/step - disc_loss: 0.9126 - feat_loss: 10.4992 - gen_loss: 111.6962 - kl_loss: 83.8692 - vgg_loss: 17.0433 - val_disc_loss: 0.8875 - val_feat_loss: 9.9899 - val_gen_loss: 111.4879 - val_kl_loss: 84.6905 - val_vgg_loss: 16.4510
Epoch 4/15
75/75 ━━━━━━━━━━━━━━━━━━━━ 15s 194ms/step - disc_loss: 0.8975 - feat_loss: 9.9081 - gen_loss: 111.2489 - kl_loss: 84.3098 - vgg_loss: 16.7369 - val_disc_loss: 0.9266 - val_feat_loss: 8.8318 - val_gen_loss: 107.9712 - val_kl_loss: 82.1354 - val_vgg_loss: 16.2676
Epoch 5/15
75/75 ━━━━━━━━━━━━━━━━━━━━ 15s 194ms/step - disc_loss: 0.9378 - feat_loss: 9.1914 - gen_loss: 110.5359 - kl_loss: 84.7988 - vgg_loss: 16.3160 - val_disc_loss: 1.0073 - val_feat_loss: 8.9351 - val_gen_loss: 109.2667 - val_kl_loss: 84.4920 - val_vgg_loss: 16.3844
Epoch 6/15
1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 35ms/step
75/75 ━━━━━━━━━━━━━━━━━━━━ 19s 258ms/step - disc_loss: 0.8982 - feat_loss: 9.2486 - gen_loss: 109.9399 - kl_loss: 83.8095 - vgg_loss: 16.5587 - val_disc_loss: 0.8061 - val_feat_loss: 8.5935 - val_gen_loss: 109.5937 - val_kl_loss: 84.5844 - val_vgg_loss: 15.8794
Epoch 7/15
75/75 ━━━━━━━━━━━━━━━━━━━━ 15s 194ms/step - disc_loss: 0.9048 - feat_loss: 9.1064 - gen_loss: 109.3803 - kl_loss: 83.8245 - vgg_loss: 16.0975 - val_disc_loss: 1.0096 - val_feat_loss: 7.6335 - val_gen_loss: 108.2900 - val_kl_loss: 84.8679 - val_vgg_loss: 15.9580
Epoch 8/15
75/75 ━━━━━━━━━━━━━━━━━━━━ 15s 193ms/step - disc_loss: 0.9075 - feat_loss: 8.0537 - gen_loss: 108.1771 - kl_loss: 83.6673 - vgg_loss: 16.1545 - val_disc_loss: 1.0090 - val_feat_loss: 8.7077 - val_gen_loss: 109.2079 - val_kl_loss: 84.5022 - val_vgg_loss: 16.3814
Epoch 9/15
75/75 ━━━━━━━━━━━━━━━━━━━━ 15s 194ms/step - disc_loss: 0.9053 - feat_loss: 7.7949 - gen_loss: 107.9268 - kl_loss: 83.6504 - vgg_loss: 16.1193 - val_disc_loss: 1.0663 - val_feat_loss: 8.2042 - val_gen_loss: 108.4819 - val_kl_loss: 84.5961 - val_vgg_loss: 16.0834
Epoch 10/15
75/75 ━━━━━━━━━━━━━━━━━━━━ 15s 194ms/step - disc_loss: 0.8905 - feat_loss: 7.7652 - gen_loss: 108.3079 - kl_loss: 83.8574 - vgg_loss: 16.2992 - val_disc_loss: 0.8362 - val_feat_loss: 7.7127 - val_gen_loss: 108.9906 - val_kl_loss: 84.4822 - val_vgg_loss: 16.0521
Epoch 11/15
1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 30ms/step
75/75 ━━━━━━━━━━━━━━━━━━━━ 20s 263ms/step - disc_loss: 0.9047 - feat_loss: 7.5019 - gen_loss: 107.6317 - kl_loss: 83.6812 - vgg_loss: 16.1292 - val_disc_loss: 0.8788 - val_feat_loss: 7.7651 - val_gen_loss: 109.1731 - val_kl_loss: 84.3094 - val_vgg_loss: 16.0356
Epoch 12/15
75/75 ━━━━━━━━━━━━━━━━━━━━ 15s 194ms/step - disc_loss: 0.8899 - feat_loss: 7.5799 - gen_loss: 108.2313 - kl_loss: 84.4031 - vgg_loss: 15.9665 - val_disc_loss: 0.8358 - val_feat_loss: 7.5676 - val_gen_loss: 109.5789 - val_kl_loss: 85.7282 - val_vgg_loss: 16.0442
Epoch 13/15
75/75 ━━━━━━━━━━━━━━━━━━━━ 15s 194ms/step - disc_loss: 0.8542 - feat_loss: 7.3362 - gen_loss: 107.4649 - kl_loss: 83.6942 - vgg_loss: 16.0675 - val_disc_loss: 1.0853 - val_feat_loss: 7.9020 - val_gen_loss: 106.9958 - val_kl_loss: 84.2610 - val_vgg_loss: 15.8510
Epoch 14/15
75/75 ━━━━━━━━━━━━━━━━━━━━ 15s 194ms/step - disc_loss: 0.8631 - feat_loss: 7.6403 - gen_loss: 108.6401 - kl_loss: 84.5304 - vgg_loss: 16.0426 - val_disc_loss: 0.9516 - val_feat_loss: 8.8795 - val_gen_loss: 108.5215 - val_kl_loss: 83.1849 - val_vgg_loss: 16.3289
Epoch 15/15
75/75 ━━━━━━━━━━━━━━━━━━━━ 15s 194ms/step - disc_loss: 0.8939 - feat_loss: 7.5489 - gen_loss: 108.8330 - kl_loss: 85.0358 - vgg_loss: 15.9147 - val_disc_loss: 0.9616 - val_feat_loss: 8.0080 - val_gen_loss: 108.1650 - val_kl_loss: 84.7754 - val_vgg_loss: 15.9561
1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 25ms/step
1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 25ms/step