WGAN-GP overriding Model.train_step
Author: A_K_Nain
Date created: 2020/05/9
Last modified: 2023/08/3
Description: Implementation of Wasserstein GAN with Gradient Penalty.
Wasserstein GAN (WGAN) with Gradient Penalty (GP)
The original Wasserstein GAN leverages the Wasserstein distance to produce a value function that has better theoretical properties than the value function used in the original GAN paper. WGAN requires that the discriminator (aka the critic) lie within the space of 1-Lipschitz functions. The authors proposed the idea of weight clipping to achieve this constraint. Though weight clipping works, it can be a problematic way to enforce the 1-Lipschitz constraint and can cause undesirable behavior, e.g. a very deep WGAN discriminator (critic) often fails to converge.
The WGAN-GP method proposes an alternative to weight clipping to ensure smooth training. Instead of clipping the weights, the authors proposed a "gradient penalty" by adding a loss term that keeps the L2 norm of the discriminator gradients close to 1.
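Concretely, the critic loss from the WGAN-GP paper is the original Wasserstein critic loss plus this penalty term, where $\mathbb{P}_r$ and $\mathbb{P}_g$ are the real and generated distributions, $\hat{x}$ is sampled uniformly along straight lines between pairs of real and generated images, and the coefficient $\lambda$ (10 in the paper) sets the penalty strength:

$$
L = \underbrace{\mathbb{E}_{\tilde{x} \sim \mathbb{P}_g}\big[D(\tilde{x})\big] - \mathbb{E}_{x \sim \mathbb{P}_r}\big[D(x)\big]}_{\text{original critic loss}} + \underbrace{\lambda \, \mathbb{E}_{\hat{x} \sim \mathbb{P}_{\hat{x}}}\Big[\big(\lVert \nabla_{\hat{x}} D(\hat{x}) \rVert_2 - 1\big)^2\Big]}_{\text{gradient penalty}}
$$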
Setup
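A minimal setup sketch, assuming Keras 3 with the TensorFlow backend (the custom training step below uses `tf.GradientTape` directly, so the backend choice matters):

```python
import os

# The custom train_step below relies on tf.GradientTape,
# so this example assumes the TensorFlow backend.
os.environ["KERAS_BACKEND"] = "tensorflow"

import keras
import tensorflow as tf
from keras import layers
```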
Prepare the Fashion-MNIST data
To demonstrate how to train a WGAN-GP, we will use the Fashion-MNIST dataset. Each sample in this dataset is a 28x28 grayscale image associated with a label from 10 classes (e.g. trouser, pullover, sneaker).
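A sketch of the data preparation (`IMG_SHAPE`, `BATCH_SIZE`, and `noise_dim` are illustrative names; a batch size of 512 and a 128-dimensional latent space are assumptions consistent with the model summaries below):

```python
IMG_SHAPE = (28, 28, 1)
BATCH_SIZE = 512
noise_dim = 128  # size of the generator's latent space

# Fashion-MNIST ships with Keras; the labels are unused for GAN training.
fashion_mnist = keras.datasets.fashion_mnist
(train_images, train_labels), (test_images, test_labels) = fashion_mnist.load_data()

# Add a channel axis and scale pixel values to [-1, 1] to match the
# generator's tanh output.
train_images = train_images.reshape(train_images.shape[0], *IMG_SHAPE).astype("float32")
train_images = (train_images - 127.5) / 127.5
```

Create the discriminator (the critic in the original WGAN)

The discriminator maps an image to a single unbounded score; there is no sigmoid at the end, because the critic estimates a Wasserstein score rather than a probability. The sketch below (the kernel sizes and dropout rates are assumptions chosen to reproduce the summary that follows) pads the 28x28 input to 32x32 so that four stride-2 convolutions halve cleanly:

```python
def get_discriminator_model():
    img_input = layers.Input(shape=IMG_SHAPE)
    # 28x28 -> 32x32, so the four stride-2 convolutions divide evenly.
    x = layers.ZeroPadding2D((2, 2))(img_input)
    x = layers.Conv2D(64, 5, strides=2, padding="same")(x)
    x = layers.LeakyReLU(0.2)(x)
    x = layers.Conv2D(128, 5, strides=2, padding="same")(x)
    x = layers.LeakyReLU(0.2)(x)
    x = layers.Dropout(0.3)(x)
    x = layers.Conv2D(256, 5, strides=2, padding="same")(x)
    x = layers.LeakyReLU(0.2)(x)
    x = layers.Dropout(0.3)(x)
    x = layers.Conv2D(512, 5, strides=2, padding="same")(x)
    x = layers.LeakyReLU(0.2)(x)
    x = layers.Flatten()(x)
    x = layers.Dropout(0.2)(x)
    x = layers.Dense(1)(x)  # unbounded critic score, no activation
    return keras.models.Model(img_input, x, name="discriminator")


d_model = get_discriminator_model()
d_model.summary()
```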
Model: "discriminator"
โโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโณโโโโโโโโโโโโโโโโโโโโโโโโโโโโณโโโโโโโโโโโโโ โ Layer (type) โ Output Shape โ Param # โ โกโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโฉ โ input_layer (InputLayer) โ (None, 28, 28, 1) โ 0 โ โโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโผโโโโโโโโโโโโโโโโโโโโโโโโโโโโผโโโโโโโโโโโโโค โ zero_padding2d (ZeroPadding2D) โ (None, 32, 32, 1) โ 0 โ โโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโผโโโโโโโโโโโโโโโโโโโโโโโโโโโโผโโโโโโโโโโโโโค โ conv2d (Conv2D) โ (None, 16, 16, 64) โ 1,664 โ โโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโผโโโโโโโโโโโโโโโโโโโโโโโโโโโโผโโโโโโโโโโโโโค โ leaky_re_lu (LeakyReLU) โ (None, 16, 16, 64) โ 0 โ โโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโผโโโโโโโโโโโโโโโโโโโโโโโโโโโโผโโโโโโโโโโโโโค โ conv2d_1 (Conv2D) โ (None, 8, 8, 128) โ 204,928 โ โโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโผโโโโโโโโโโโโโโโโโโโโโโโโโโโโผโโโโโโโโโโโโโค โ leaky_re_lu_1 (LeakyReLU) โ (None, 8, 8, 128) โ 0 โ โโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโผโโโโโโโโโโโโโโโโโโโโโโโโโโโโผโโโโโโโโโโโโโค โ dropout (Dropout) โ (None, 8, 8, 128) โ 0 โ โโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโผโโโโโโโโโโโโโโโโโโโโโโโโโโโโผโโโโโโโโโโโโโค โ conv2d_2 (Conv2D) โ (None, 4, 4, 256) โ 819,456 โ โโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโผโโโโโโโโโโโโโโโโโโโโโโโโโโโโผโโโโโโโโโโโโโค โ leaky_re_lu_2 (LeakyReLU) โ (None, 4, 4, 256) โ 0 โ โโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโผโโโโโโโโโโโโโโโโโโโโโโโโโโโโผโโโโโโโโโโโโโค โ dropout_1 (Dropout) โ (None, 4, 4, 256) โ 0 โ โโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโผโโโโโโโโโโโโโโโโโโโโโโโโโโโโผโโโโโโโโโโโโโค โ conv2d_3 (Conv2D) โ (None, 2, 2, 512) โ 3,277,312 โ โโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโผโโโโโโโโโโโโโโโโโโโโโโโโโโโโผโโโโโโโโโโโโโค โ leaky_re_lu_3 (LeakyReLU) โ (None, 2, 2, 512) โ 0 โ โโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโผโโโโโโโโโโโโโโโโโโโโโโโโโโโโผโโโโโโโโโโโโโค โ flatten (Flatten) โ (None, 2048) โ 0 โ โโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโผโโโโโโโโโโโโโโโโโโโโโโโโโโโโผโโโโโโโโโโโโโค โ dropout_2 (Dropout) โ (None, 2048) โ 0 โ โโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโผโโโโโโโโโโโโโโโโโโโโโโโโโโโโผโโโโโโโโโโโโโค โ dense (Dense) โ (None, 1) โ 2,049 โ โโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโดโโโโโโโโโโโโโโโโโโโโโโโโโโโโดโโโโโโโโโโโโโ
Total params: 4,305,409 (16.42 MB)
Trainable params: 4,305,409 (16.42 MB)
Non-trainable params: 0 (0.00 B)
Create the generator
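A sketch that matches the summary below: project the latent vector to a 4x4x256 tensor, upsample three times (4 -> 8 -> 16 -> 32), and crop the 32x32 output back to the 28x28 data shape. The bias-free convolutions, batch normalization, and `tanh` output are assumptions consistent with the parameter counts shown:

```python
def get_generator_model():
    noise = layers.Input(shape=(noise_dim,))
    x = layers.Dense(4 * 4 * 256, use_bias=False)(noise)
    x = layers.BatchNormalization()(x)
    x = layers.LeakyReLU(0.2)(x)
    x = layers.Reshape((4, 4, 256))(x)
    # Two upsample blocks: 4x4 -> 8x8 -> 16x16.
    for filters in (128, 64):
        x = layers.UpSampling2D((2, 2))(x)
        x = layers.Conv2D(filters, 3, padding="same", use_bias=False)(x)
        x = layers.BatchNormalization()(x)
        x = layers.LeakyReLU(0.2)(x)
    # Final upsample to 32x32 and a tanh image head.
    x = layers.UpSampling2D((2, 2))(x)
    x = layers.Conv2D(1, 3, padding="same", use_bias=False)(x)
    x = layers.BatchNormalization()(x)
    x = layers.Activation("tanh")(x)
    x = layers.Cropping2D((2, 2))(x)  # 32x32 -> 28x28
    return keras.models.Model(noise, x, name="generator")


g_model = get_generator_model()
g_model.summary()
```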
Model: "generator"
โโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโณโโโโโโโโโโโโโโโโโโโโโโโโโโโโณโโโโโโโโโโโโโ โ Layer (type) โ Output Shape โ Param # โ โกโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโฉ โ input_layer_1 (InputLayer) โ (None, 128) โ 0 โ โโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโผโโโโโโโโโโโโโโโโโโโโโโโโโโโโผโโโโโโโโโโโโโค โ dense_1 (Dense) โ (None, 4096) โ 524,288 โ โโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโผโโโโโโโโโโโโโโโโโโโโโโโโโโโโผโโโโโโโโโโโโโค โ batch_normalization โ (None, 4096) โ 16,384 โ โ (BatchNormalization) โ โ โ โโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโผโโโโโโโโโโโโโโโโโโโโโโโโโโโโผโโโโโโโโโโโโโค โ leaky_re_lu_4 (LeakyReLU) โ (None, 4096) โ 0 โ โโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโผโโโโโโโโโโโโโโโโโโโโโโโโโโโโผโโโโโโโโโโโโโค โ reshape (Reshape) โ (None, 4, 4, 256) โ 0 โ โโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโผโโโโโโโโโโโโโโโโโโโโโโโโโโโโผโโโโโโโโโโโโโค โ up_sampling2d (UpSampling2D) โ (None, 8, 8, 256) โ 0 โ โโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโผโโโโโโโโโโโโโโโโโโโโโโโโโโโโผโโโโโโโโโโโโโค โ conv2d_4 (Conv2D) โ (None, 8, 8, 128) โ 294,912 โ โโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโผโโโโโโโโโโโโโโโโโโโโโโโโโโโโผโโโโโโโโโโโโโค โ batch_normalization_1 โ (None, 8, 8, 128) โ 512 โ โ (BatchNormalization) โ โ โ โโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโผโโโโโโโโโโโโโโโโโโโโโโโโโโโโผโโโโโโโโโโโโโค โ leaky_re_lu_5 (LeakyReLU) โ (None, 8, 8, 128) โ 0 โ โโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโผโโโโโโโโโโโโโโโโโโโโโโโโโโโโผโโโโโโโโโโโโโค โ up_sampling2d_1 (UpSampling2D) โ (None, 16, 16, 128) โ 0 โ โโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโผโโโโโโโโโโโโโโโโโโโโโโโโโโโโผโโโโโโโโโโโโโค โ conv2d_5 (Conv2D) โ (None, 16, 16, 64) โ 73,728 โ โโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโผโโโโโโโโโโโโโโโโโโโโโโโโโโโโผโโโโโโโโโโโโโค โ batch_normalization_2 โ (None, 16, 16, 64) โ 256 โ โ (BatchNormalization) โ โ โ โโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโผโโโโโโโโโโโโโโโโโโโโโโโโโโโโผโโโโโโโโโโโโโค โ leaky_re_lu_6 (LeakyReLU) โ (None, 16, 16, 64) โ 0 โ โโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโผโโโโโโโโโโโโโโโโโโโโโโโโโโโโผโโโโโโโโโโโโโค โ up_sampling2d_2 (UpSampling2D) โ (None, 32, 32, 64) โ 0 โ โโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโผโโโโโโโโโโโโโโโโโโโโโโโโโโโโผโโโโโโโโโโโโโค โ conv2d_6 (Conv2D) โ (None, 32, 32, 1) โ 576 โ โโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโผโโโโโโโโโโโโโโโโโโโโโโโโโโโโผโโโโโโโโโโโโโค โ batch_normalization_3 โ (None, 32, 32, 1) โ 4 โ โ (BatchNormalization) โ โ โ โโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโผโโโโโโโโโโโโโโโโโโโโโโโโโโโโผโโโโโโโโโโโโโค โ activation (Activation) โ (None, 32, 32, 1) โ 0 โ โโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโผโโโโโโโโโโโโโโโโโโโโโโโโโโโโผโโโโโโโโโโโโโค โ cropping2d (Cropping2D) โ (None, 28, 28, 1) โ 0 โ โโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโดโโโโโโโโโโโโโโโโโโโโโโโโโโโโดโโโโโโโโโโโโโ
Total params: 910,660 (3.47 MB)
Trainable params: 902,082 (3.44 MB)
Non-trainable params: 8,578 (33.51 KB)
Create the WGAN-GP model
Now that we have defined our generator and discriminator, it's time to implement the WGAN-GP model. We will also override `train_step` for training.
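The sketch below subclasses `keras.Model` (the `WGAN_GP` class name and argument names are illustrative; three critic steps per generator step and a penalty weight of 10 are the values recommended in the WGAN-GP paper). `compile()` stores separate optimizers and loss functions for the two networks, `gradient_penalty()` measures how far the critic's gradient norm at interpolated images deviates from 1, and `train_step()` runs the critic updates followed by one generator update:

```python
class WGAN_GP(keras.Model):
    def __init__(self, discriminator, generator, latent_dim,
                 discriminator_extra_steps=3, gp_weight=10.0):
        super().__init__()
        self.discriminator = discriminator
        self.generator = generator
        self.latent_dim = latent_dim
        self.d_steps = discriminator_extra_steps
        self.gp_weight = gp_weight

    def compile(self, d_optimizer, g_optimizer, d_loss_fn, g_loss_fn):
        super().compile()
        self.d_optimizer = d_optimizer
        self.g_optimizer = g_optimizer
        self.d_loss_fn = d_loss_fn
        self.g_loss_fn = g_loss_fn

    def gradient_penalty(self, batch_size, real_images, fake_images):
        # Interpolate between real and fake images: x_hat = real + alpha * (fake - real).
        alpha = tf.random.uniform([batch_size, 1, 1, 1], 0.0, 1.0)
        interpolated = real_images + alpha * (fake_images - real_images)
        with tf.GradientTape() as gp_tape:
            gp_tape.watch(interpolated)
            pred = self.discriminator(interpolated, training=True)
        # Penalize the squared deviation of the gradient's L2 norm from 1.
        grads = gp_tape.gradient(pred, [interpolated])[0]
        norm = tf.sqrt(tf.reduce_sum(tf.square(grads), axis=[1, 2, 3]))
        return tf.reduce_mean((norm - 1.0) ** 2)

    def train_step(self, real_images):
        if isinstance(real_images, tuple):
            real_images = real_images[0]
        batch_size = tf.shape(real_images)[0]

        # Train the critic for d_steps extra steps per generator step.
        for _ in range(self.d_steps):
            random_latent_vectors = tf.random.normal((batch_size, self.latent_dim))
            with tf.GradientTape() as tape:
                fake_images = self.generator(random_latent_vectors, training=True)
                fake_logits = self.discriminator(fake_images, training=True)
                real_logits = self.discriminator(real_images, training=True)
                d_cost = self.d_loss_fn(real_img=real_logits, fake_img=fake_logits)
                gp = self.gradient_penalty(batch_size, real_images, fake_images)
                d_loss = d_cost + gp * self.gp_weight
            d_gradient = tape.gradient(d_loss, self.discriminator.trainable_variables)
            self.d_optimizer.apply_gradients(
                zip(d_gradient, self.discriminator.trainable_variables)
            )

        # Train the generator once.
        random_latent_vectors = tf.random.normal((batch_size, self.latent_dim))
        with tf.GradientTape() as tape:
            generated_images = self.generator(random_latent_vectors, training=True)
            gen_img_logits = self.discriminator(generated_images, training=True)
            g_loss = self.g_loss_fn(gen_img_logits)
        gen_gradient = tape.gradient(g_loss, self.generator.trainable_variables)
        self.g_optimizer.apply_gradients(
            zip(gen_gradient, self.generator.trainable_variables)
        )
        return {"d_loss": d_loss, "g_loss": g_loss}
```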
Create a Keras callback that periodically saves generated images
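One way to write such a callback is sketched here (the `GANMonitor` name, image count, and file-naming scheme are assumptions): at the end of each epoch it samples latent vectors, decodes them with the generator, rescales the images from [-1, 1] back to [0, 255], and writes them to disk.

```python
class GANMonitor(keras.callbacks.Callback):
    def __init__(self, num_img=6, latent_dim=128):
        self.num_img = num_img
        self.latent_dim = latent_dim

    def on_epoch_end(self, epoch, logs=None):
        random_latent_vectors = tf.random.normal(shape=(self.num_img, self.latent_dim))
        generated_images = self.model.generator(random_latent_vectors)
        # Undo the [-1, 1] scaling applied during preprocessing.
        generated_images = (generated_images * 127.5) + 127.5
        for i in range(self.num_img):
            img = keras.utils.array_to_img(generated_images[i].numpy())
            img.save(f"generated_img_{i}_{epoch}.png")
```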
Train the end-to-end model
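A sketch of the final wiring. The Wasserstein losses are standard: the critic minimizes `E[D(fake)] - E[D(real)]` (widening the score gap), and the generator minimizes `-E[D(fake)]`. The Adam hyperparameters and epoch count are assumptions (learning rate 2e-4 with `beta_1=0.5`, `beta_2=0.9` is a common WGAN-GP choice):

```python
def discriminator_loss(real_img, fake_img):
    # Critic loss: E[D(fake)] - E[D(real)].
    return tf.reduce_mean(fake_img) - tf.reduce_mean(real_img)


def generator_loss(fake_img):
    # Generator loss: -E[D(fake)].
    return -tf.reduce_mean(fake_img)


generator_optimizer = keras.optimizers.Adam(learning_rate=0.0002, beta_1=0.5, beta_2=0.9)
discriminator_optimizer = keras.optimizers.Adam(learning_rate=0.0002, beta_1=0.5, beta_2=0.9)

epochs = 20
cbk = GANMonitor(num_img=3, latent_dim=noise_dim)

wgan = WGAN_GP(
    discriminator=d_model,
    generator=g_model,
    latent_dim=noise_dim,
    discriminator_extra_steps=3,
)
wgan.compile(
    d_optimizer=discriminator_optimizer,
    g_optimizer=generator_optimizer,
    d_loss_fn=discriminator_loss,
    g_loss_fn=generator_loss,
)
wgan.fit(train_images, batch_size=BATCH_SIZE, epochs=epochs, callbacks=[cbk])
```

`fit()` returns a `History` object, which is the value echoed below.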
```
<keras.src.callbacks.history.History at 0x7fc763a8e950>
```