"""
Title: GauGAN for conditional image generation
Author: [Soumik Rakshit](https://github.com/soumik12345), [Sayak Paul](https://twitter.com/RisingSayak)
Date created: 2021/12/26
Last modified: 2022/01/03
Description: Implementing a GauGAN for conditional image generation.
Accelerator: GPU
"""

"""
## Introduction

In this example, we present an implementation of the GauGAN architecture proposed in
[Semantic Image Synthesis with Spatially-Adaptive Normalization](https://arxiv.org/abs/1903.07291).
Briefly, GauGAN uses a Generative Adversarial Network (GAN) to generate realistic images
that are conditioned on cue images and segmentation maps, as shown below
([image source](https://nvlabs.github.io/SPADE/)):

![](https://i.ibb.co/p305dzv/image.png)

The main components of a GauGAN are:

- **SPADE (aka spatially-adaptive normalization)**: The authors of GauGAN argue that the
more conventional normalization layers (such as
[Batch Normalization](https://arxiv.org/abs/1502.03167))
destroy the semantic information obtained from segmentation maps that
are provided as inputs. To address this problem, the authors introduce SPADE, a
normalization layer particularly suitable for learning affine parameters (scale and bias)
that are spatially adaptive. This is done by learning different sets of scaling and
bias parameters for each semantic label.
- **Variational encoder**: Inspired by
[Variational Autoencoders](https://arxiv.org/abs/1312.6114), GauGAN uses a
variational formulation wherein an encoder learns the mean and variance of a
normal (Gaussian) distribution from the cue images. This is where GauGAN gets its name
from. The generator of GauGAN takes as inputs the latents sampled from the Gaussian
distribution as well as the one-hot encoded semantic segmentation label maps. The cue
images act as style images that guide the generator to stylistic generation. This
variational formulation helps GauGAN achieve image diversity as well as fidelity.
- **Multi-scale patch discriminator**: Inspired by the
[PatchGAN](https://paperswithcode.com/method/patchgan) model,
GauGAN uses a discriminator that assesses a given image on a patch basis
and produces an averaged score.

As we proceed with the example, we will discuss each of the different
components in further detail.

For a thorough review of GauGAN, please refer to
[this article](https://blog.paperspace.com/nvidia-gaugan-introduction/).
We also encourage you to check out
[the official GauGAN website](https://nvlabs.github.io/SPADE/), which
has many creative applications of GauGAN. This example assumes that the reader is already
familiar with the fundamental concepts of GANs. If you need a refresher, the following
resources might be useful:

* [Chapter on GANs](https://livebook.manning.com/book/deep-learning-with-python/chapter-8)
from the Deep Learning with Python book by François Chollet.
* GAN implementations on keras.io:

    * [Data efficient GANs](https://keras.io/examples/generative/gan_ada)
    * [CycleGAN](https://keras.io/examples/generative/cyclegan)
    * [Conditional GAN](https://keras.io/examples/generative/conditional_gan)
"""

"""
## Data collection

We will be using the
[Facades dataset](https://cmp.felk.cvut.cz/~tylecr1/facade/)
for training our GauGAN model. Let's first download it.
"""

"""shell
wget https://drive.google.com/uc?id=1q4FEjQg1YSb4mPx2VdxL7LXKYu3voTMj -O facades_data.zip
unzip -q facades_data.zip
"""

"""
## Imports
"""

import os

os.environ["KERAS_BACKEND"] = "tensorflow"


import numpy as np
import matplotlib.pyplot as plt

import tensorflow as tf
import keras
from keras import ops
from keras import layers

from glob import glob

"""
## Data splitting
"""

PATH = "./facades_data/"
SPLIT = 0.2

files = glob(PATH + "*.jpg")
np.random.shuffle(files)

split_index = int(len(files) * (1 - SPLIT))
train_files = files[:split_index]
val_files = files[split_index:]

print(f"Total samples: {len(files)}.")
print(f"Total training samples: {len(train_files)}.")
print(f"Total validation samples: {len(val_files)}.")

"""
## Data loader
"""

BATCH_SIZE = 4
IMG_HEIGHT = IMG_WIDTH = 256
NUM_CLASSES = 12
AUTOTUNE = tf.data.AUTOTUNE


def load(image_files, batch_size, is_train=True):
    def _random_crop(
        segmentation_map,
        image,
        labels,
        crop_size=(IMG_HEIGHT, IMG_WIDTH),
    ):
        crop_size = tf.convert_to_tensor(crop_size)
        image_shape = tf.shape(image)[:2]
        margins = image_shape - crop_size
        y1 = tf.random.uniform(shape=(), maxval=margins[0], dtype=tf.int32)
        x1 = tf.random.uniform(shape=(), maxval=margins[1], dtype=tf.int32)
        y2 = y1 + crop_size[0]
        x2 = x1 + crop_size[1]

        cropped_images = []
        images = [segmentation_map, image, labels]
        for img in images:
            cropped_images.append(img[y1:y2, x1:x2])
        return cropped_images

    def _load_data_tf(image_file, segmentation_map_file, label_file):
        image = tf.image.decode_png(tf.io.read_file(image_file), channels=3)
        segmentation_map = tf.image.decode_png(
            tf.io.read_file(segmentation_map_file), channels=3
        )
        labels = tf.image.decode_bmp(tf.io.read_file(label_file), channels=0)
        labels = tf.squeeze(labels)

        image = tf.cast(image, tf.float32) / 127.5 - 1
        segmentation_map = tf.cast(segmentation_map, tf.float32) / 127.5 - 1
        return segmentation_map, image, labels

    def _one_hot(segmentation_maps, real_images, labels):
        labels = tf.one_hot(labels, NUM_CLASSES)
        labels.set_shape((None, None, NUM_CLASSES))
        return segmentation_maps, real_images, labels

    segmentation_map_files = [
        image_file.replace("images", "segmentation_map").replace("jpg", "png")
        for image_file in image_files
    ]
    label_files = [
        image_file.replace("images", "segmentation_labels").replace("jpg", "bmp")
        for image_file in image_files
    ]
    dataset = tf.data.Dataset.from_tensor_slices(
        (image_files, segmentation_map_files, label_files)
    )

    dataset = dataset.shuffle(batch_size * 10) if is_train else dataset
    dataset = dataset.map(_load_data_tf, num_parallel_calls=AUTOTUNE)
    dataset = dataset.map(_random_crop, num_parallel_calls=AUTOTUNE)
    dataset = dataset.map(_one_hot, num_parallel_calls=AUTOTUNE)
    dataset = dataset.batch(batch_size, drop_remainder=True)
    return dataset


train_dataset = load(train_files, batch_size=BATCH_SIZE, is_train=True)
val_dataset = load(val_files, batch_size=BATCH_SIZE, is_train=False)

"""
Now, let's visualize a few samples from the training set.
"""

sample_train_batch = next(iter(train_dataset))
print(f"Segmentation map batch shape: {sample_train_batch[0].shape}.")
print(f"Image batch shape: {sample_train_batch[1].shape}.")
print(f"One-hot encoded label map shape: {sample_train_batch[2].shape}.")

# Plot a few samples from the training set.
for segmentation_map, real_image in zip(sample_train_batch[0], sample_train_batch[1]):
    fig = plt.figure(figsize=(10, 10))
    fig.add_subplot(1, 2, 1).set_title("Segmentation Map")
    plt.imshow((segmentation_map + 1) / 2)
    fig.add_subplot(1, 2, 2).set_title("Real Image")
    plt.imshow((real_image + 1) / 2)
    plt.show()
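
"""
As an additional sanity check, we can look at which of the `NUM_CLASSES` semantic labels
actually occur in this batch by reversing the one-hot encoding (purely illustrative):
"""

# Indices of the classes present in the sampled batch of one-hot label maps.
present_classes = np.unique(np.argmax(sample_train_batch[2].numpy(), axis=-1))
print(f"Classes present in this batch: {present_classes}")
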

"""
Note that in the rest of this example, we use a couple of figures from the
[original GauGAN paper](https://arxiv.org/abs/1903.07291) for convenience.
"""

"""
## Custom layers

In this section, we implement the following layers:

* SPADE
* Residual block including SPADE
* Gaussian sampler
"""

"""
### Some more notes on SPADE

![](https://i.imgur.com/DgMWrrs.png)

**SPatially-Adaptive (DE)normalization** or **SPADE** is a simple but effective layer
for synthesizing photorealistic images given an input semantic layout. Previous methods
for conditional image generation from semantic inputs, such as
Pix2Pix ([Isola et al.](https://arxiv.org/abs/1611.07004))
or Pix2PixHD ([Wang et al.](https://arxiv.org/abs/1711.11585)),
directly feed the semantic layout as input to the deep network, which is then processed
through stacks of convolution, normalization, and nonlinearity layers. This is often
suboptimal as the normalization layers have a tendency to wash away semantic information.

In SPADE, the segmentation mask is first projected onto an embedding space, and then
convolved to produce the modulation parameters `γ` and `β`. Unlike prior conditional
normalization methods, `γ` and `β` are not vectors, but tensors with spatial dimensions.
The produced `γ` and `β` are multiplied and added to the normalized activation
element-wise. As the modulation parameters are adaptive to the input segmentation mask,
SPADE is better suited for semantic image synthesis.
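
Concretely, if `x` is an activation whose `mean` and `std` are computed per channel over
the batch and spatial dimensions, SPADE computes
`output = gamma(mask) * (x - mean) / std + beta(mask)`, where `gamma(mask)` and `beta(mask)`
are produced by small convolutions applied to the (resized) segmentation mask. The `SPADE`
layer implemented below follows this recipe.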
"""


class SPADE(layers.Layer):
    def __init__(self, filters, epsilon=1e-5, **kwargs):
        super().__init__(**kwargs)
        self.epsilon = epsilon
        self.conv = layers.Conv2D(128, 3, padding="same", activation="relu")
        self.conv_gamma = layers.Conv2D(filters, 3, padding="same")
        self.conv_beta = layers.Conv2D(filters, 3, padding="same")

    def build(self, input_shape):
        self.resize_shape = input_shape[1:3]

    def call(self, input_tensor, raw_mask):
        mask = ops.image.resize(raw_mask, self.resize_shape, interpolation="nearest")
        x = self.conv(mask)
        gamma = self.conv_gamma(x)
        beta = self.conv_beta(x)
        mean, var = ops.moments(input_tensor, axes=(0, 1, 2), keepdims=True)
        std = ops.sqrt(var + self.epsilon)
        normalized = (input_tensor - mean) / std
        output = gamma * normalized + beta
        return output


class ResBlock(layers.Layer):
    def __init__(self, filters, **kwargs):
        super().__init__(**kwargs)
        self.filters = filters

    def build(self, input_shape):
        input_filter = input_shape[-1]
        self.spade_1 = SPADE(input_filter)
        self.spade_2 = SPADE(self.filters)
        self.conv_1 = layers.Conv2D(self.filters, 3, padding="same")
        self.conv_2 = layers.Conv2D(self.filters, 3, padding="same")
        self.learned_skip = False

        if self.filters != input_filter:
            self.learned_skip = True
            self.spade_3 = SPADE(input_filter)
            self.conv_3 = layers.Conv2D(self.filters, 3, padding="same")

    def call(self, input_tensor, mask):
        x = self.spade_1(input_tensor, mask)
        x = self.conv_1(keras.activations.leaky_relu(x, 0.2))
        x = self.spade_2(x, mask)
        x = self.conv_2(keras.activations.leaky_relu(x, 0.2))
        skip = (
            self.conv_3(
                keras.activations.leaky_relu(self.spade_3(input_tensor, mask), 0.2)
            )
            if self.learned_skip
            else input_tensor
        )
        output = skip + x
        return output


class GaussianSampler(layers.Layer):
    def __init__(self, batch_size, latent_dim, **kwargs):
        super().__init__(**kwargs)
        self.batch_size = batch_size
        self.latent_dim = latent_dim
        self.seed_generator = keras.random.SeedGenerator(1337)

    def call(self, inputs):
        means, variance = inputs
        epsilon = keras.random.normal(
            shape=(self.batch_size, self.latent_dim),
            mean=0.0,
            stddev=1.0,
            seed=self.seed_generator,
        )
        samples = means + ops.exp(0.5 * variance) * epsilon
        return samples
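
"""
A note on the sampler: it implements the VAE reparameterization trick,
`sample = mean + exp(0.5 * log_variance) * epsilon`, with `epsilon` drawn from a standard
normal distribution. In other words, the encoder's `variance` output is treated as a
log-variance, which is also the convention used by the KL-divergence loss defined later.
"""
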

"""
Next, we implement the downsampling block for the encoder.
"""


def downsample(
    channels,
    kernels,
    strides=2,
    apply_norm=True,
    apply_activation=True,
    apply_dropout=False,
):
    block = keras.Sequential()
    block.add(
        layers.Conv2D(
            channels,
            kernels,
            strides=strides,
            padding="same",
            use_bias=False,
            kernel_initializer=keras.initializers.GlorotNormal(),
        )
    )
    if apply_norm:
        block.add(layers.GroupNormalization(groups=-1))
    if apply_activation:
        block.add(layers.LeakyReLU(0.2))
    if apply_dropout:
        block.add(layers.Dropout(0.5))
    return block


"""
The GauGAN encoder consists of a few downsampling blocks. It outputs the mean and
variance of a distribution.

![](https://i.imgur.com/JgAv1EW.png)
"""


def build_encoder(image_shape, encoder_downsample_factor=64, latent_dim=256):
    input_image = keras.Input(shape=image_shape)
    x = downsample(encoder_downsample_factor, 3, apply_norm=False)(input_image)
    x = downsample(2 * encoder_downsample_factor, 3)(x)
    x = downsample(4 * encoder_downsample_factor, 3)(x)
    x = downsample(8 * encoder_downsample_factor, 3)(x)
    x = downsample(8 * encoder_downsample_factor, 3)(x)
    x = layers.Flatten()(x)
    mean = layers.Dense(latent_dim, name="mean")(x)
    variance = layers.Dense(latent_dim, name="variance")(x)
    return keras.Model(input_image, [mean, variance], name="encoder")
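
"""
For 256x256 inputs, the five stride-2 downsampling blocks reduce the spatial resolution to
8x8 before the `Flatten` layer, and the two `Dense` heads each emit a `latent_dim`-sized
vector. As an optional check, you can instantiate the encoder and inspect its output shapes:
"""

# Optional sanity check of the encoder's output shapes (not used for training).
tmp_encoder = build_encoder((IMG_HEIGHT, IMG_WIDTH, 3))
print([tuple(output.shape) for output in tmp_encoder.outputs])
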

"""
Next, we implement the generator, which consists of the modified residual blocks and
upsampling blocks. It takes latent vectors and one-hot encoded segmentation labels, and
produces new images.

![](https://i.imgur.com/9iP1TsB.png)

With SPADE, there is no need to feed the segmentation map to the first layer of the
generator, since the latent inputs have enough structural information about the style we
want the generator to emulate. We also discard the encoder part of the generator, which is
commonly used in prior architectures. This results in a more lightweight
generator network, which can also take a random vector as input, enabling a simple and
natural path to multi-modal synthesis.
"""


def build_generator(mask_shape, latent_dim=256):
    latent = keras.Input(shape=(latent_dim,))
    mask = keras.Input(shape=mask_shape)
    x = layers.Dense(16384)(latent)
    x = layers.Reshape((4, 4, 1024))(x)
    x = ResBlock(filters=1024)(x, mask)
    x = layers.UpSampling2D((2, 2))(x)
    x = ResBlock(filters=1024)(x, mask)
    x = layers.UpSampling2D((2, 2))(x)
    x = ResBlock(filters=1024)(x, mask)
    x = layers.UpSampling2D((2, 2))(x)
    x = ResBlock(filters=512)(x, mask)
    x = layers.UpSampling2D((2, 2))(x)
    x = ResBlock(filters=256)(x, mask)
    x = layers.UpSampling2D((2, 2))(x)
    x = ResBlock(filters=128)(x, mask)
    x = layers.UpSampling2D((2, 2))(x)
    x = keras.activations.leaky_relu(x, 0.2)
    output_image = keras.activations.tanh(layers.Conv2D(3, 4, padding="same")(x))
    return keras.Model([latent, mask], output_image, name="generator")


"""
The discriminator takes a segmentation map and an image and concatenates them. It
then predicts if patches of the concatenated image are real or fake.

![](https://i.imgur.com/rn71PlM.png)
"""


def build_discriminator(image_shape, downsample_factor=64):
    input_image_A = keras.Input(shape=image_shape, name="discriminator_image_A")
    input_image_B = keras.Input(shape=image_shape, name="discriminator_image_B")
    x = layers.Concatenate()([input_image_A, input_image_B])
    x1 = downsample(downsample_factor, 4, apply_norm=False)(x)
    x2 = downsample(2 * downsample_factor, 4)(x1)
    x3 = downsample(4 * downsample_factor, 4)(x2)
    x4 = downsample(8 * downsample_factor, 4, strides=1)(x3)
    x5 = layers.Conv2D(1, 4)(x4)
    outputs = [x1, x2, x3, x4, x5]
    return keras.Model([input_image_A, input_image_B], outputs)
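
"""
The discriminator returns its intermediate feature maps (`x1` through `x4`) in addition to
the final patch predictions (`x5`); the feature matching loss defined below relies on those
intermediate outputs. As an optional check, you can inspect the output shapes:
"""

# Optional sanity check of the discriminator's multi-scale outputs (not used for training).
tmp_discriminator = build_discriminator((IMG_HEIGHT, IMG_WIDTH, 3))
print([tuple(output.shape) for output in tmp_discriminator.outputs])
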

"""
## Loss functions

GauGAN uses the following loss functions:

* Generator:

    * Adversarial (hinge-style) loss: the negative mean of the discriminator's patch
    predictions on the generated images.
    * [KL divergence](https://en.wikipedia.org/wiki/Kullback%E2%80%93Leibler_divergence)
    for learning the mean and variance predicted by the encoder.
    * Feature matching loss: the mean absolute difference between the discriminator's
    intermediate feature maps for real and generated images, which aligns the generator
    with the statistics of real images and stabilizes training.
    * [Perceptual loss](https://arxiv.org/abs/1603.08155) for encouraging the generated
    images to be perceptually close to the real ones.

* Discriminator:

    * [Hinge loss](https://en.wikipedia.org/wiki/Hinge_loss).
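
For the discriminator, using `keras.losses.Hinge()` with targets of `+1` for real and `-1`
for fake patches amounts to `mean(maximum(1 - t * y, 0))`, where `t` is the target and `y`
is the patch prediction, which pushes real patches to score above `+1` and fake patches
below `-1`.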
"""


def generator_loss(y):
    return -ops.mean(y)


def kl_divergence_loss(mean, variance):
    return -0.5 * ops.sum(1 + variance - ops.square(mean) - ops.exp(variance))


class FeatureMatchingLoss(keras.losses.Loss):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.mae = keras.losses.MeanAbsoluteError()

    def call(self, y_true, y_pred):
        loss = 0
        for i in range(len(y_true) - 1):
            loss += self.mae(y_true[i], y_pred[i])
        return loss


class VGGFeatureMatchingLoss(keras.losses.Loss):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.encoder_layers = [
            "block1_conv1",
            "block2_conv1",
            "block3_conv1",
            "block4_conv1",
            "block5_conv1",
        ]
        self.weights = [1.0 / 32, 1.0 / 16, 1.0 / 8, 1.0 / 4, 1.0]
        vgg = keras.applications.VGG19(include_top=False, weights="imagenet")
        layer_outputs = [vgg.get_layer(x).output for x in self.encoder_layers]
        self.vgg_model = keras.Model(vgg.input, layer_outputs, name="VGG")
        self.mae = keras.losses.MeanAbsoluteError()

    def call(self, y_true, y_pred):
        y_true = keras.applications.vgg19.preprocess_input(127.5 * (y_true + 1))
        y_pred = keras.applications.vgg19.preprocess_input(127.5 * (y_pred + 1))
        real_features = self.vgg_model(y_true)
        fake_features = self.vgg_model(y_pred)
        loss = 0
        for i in range(len(real_features)):
            loss += self.weights[i] * self.mae(real_features[i], fake_features[i])
        return loss


class DiscriminatorLoss(keras.losses.Loss):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.hinge_loss = keras.losses.Hinge()

    def call(self, y, is_real):
        return self.hinge_loss(is_real, y)


"""
## GAN monitor callback

Next, we implement a callback to monitor the results of GauGAN while it is training.
"""


class GanMonitor(keras.callbacks.Callback):
    def __init__(self, val_dataset, n_samples, epoch_interval=5):
        self.val_images = next(iter(val_dataset))
        self.n_samples = n_samples
        self.epoch_interval = epoch_interval
        self.seed_generator = keras.random.SeedGenerator(42)

    def infer(self):
        latent_vector = keras.random.normal(
            shape=(self.model.batch_size, self.model.latent_dim),
            mean=0.0,
            stddev=2.0,
            seed=self.seed_generator,
        )
        return self.model.predict([latent_vector, self.val_images[2]])

    def on_epoch_end(self, epoch, logs=None):
        if epoch % self.epoch_interval == 0:
            generated_images = self.infer()
            for _ in range(self.n_samples):
                grid_row = min(generated_images.shape[0], 3)
                f, axarr = plt.subplots(grid_row, 3, figsize=(18, grid_row * 6))
                for row in range(grid_row):
                    ax = axarr if grid_row == 1 else axarr[row]
                    ax[0].imshow((self.val_images[0][row] + 1) / 2)
                    ax[0].axis("off")
                    ax[0].set_title("Mask", fontsize=20)
                    ax[1].imshow((self.val_images[1][row] + 1) / 2)
                    ax[1].axis("off")
                    ax[1].set_title("Ground Truth", fontsize=20)
                    ax[2].imshow((generated_images[row] + 1) / 2)
                    ax[2].axis("off")
                    ax[2].set_title("Generated", fontsize=20)
                plt.show()


"""
## Subclassed GauGAN model

Finally, we put everything together inside a subclassed model (derived from `keras.Model`),
overriding its `train_step()` method.
"""


class GauGAN(keras.Model):
    def __init__(
        self,
        image_size,
        num_classes,
        batch_size,
        latent_dim,
        feature_loss_coeff=10,
        vgg_feature_loss_coeff=0.1,
        kl_divergence_loss_coeff=0.1,
        **kwargs,
    ):
        super().__init__(**kwargs)

        self.image_size = image_size
        self.latent_dim = latent_dim
        self.batch_size = batch_size
        self.num_classes = num_classes
        self.image_shape = (image_size, image_size, 3)
        self.mask_shape = (image_size, image_size, num_classes)
        self.feature_loss_coeff = feature_loss_coeff
        self.vgg_feature_loss_coeff = vgg_feature_loss_coeff
        self.kl_divergence_loss_coeff = kl_divergence_loss_coeff

        self.discriminator = build_discriminator(self.image_shape)
        self.generator = build_generator(self.mask_shape)
        self.encoder = build_encoder(self.image_shape)
        self.sampler = GaussianSampler(batch_size, latent_dim)
        self.patch_size, self.combined_model = self.build_combined_generator()

        self.disc_loss_tracker = keras.metrics.Mean(name="disc_loss")
        self.gen_loss_tracker = keras.metrics.Mean(name="gen_loss")
        self.feat_loss_tracker = keras.metrics.Mean(name="feat_loss")
        self.vgg_loss_tracker = keras.metrics.Mean(name="vgg_loss")
        self.kl_loss_tracker = keras.metrics.Mean(name="kl_loss")

    @property
    def metrics(self):
        return [
            self.disc_loss_tracker,
            self.gen_loss_tracker,
            self.feat_loss_tracker,
            self.vgg_loss_tracker,
            self.kl_loss_tracker,
        ]

    def build_combined_generator(self):
        # This method builds a model that takes as inputs the following:
        # latent vector, one-hot encoded segmentation label map, and
        # a segmentation map. It then (i) generates an image with the generator,
        # (ii) passes the generated images and segmentation map to the discriminator.
        # Finally, the model produces the following outputs: (a) discriminator outputs,
        # (b) generated image.
        # We will be using this model to simplify the implementation.
        self.discriminator.trainable = False
        mask_input = keras.Input(shape=self.mask_shape, name="mask")
        image_input = keras.Input(shape=self.image_shape, name="image")
        latent_input = keras.Input(shape=(self.latent_dim,), name="latent")
        generated_image = self.generator([latent_input, mask_input])
        discriminator_output = self.discriminator([image_input, generated_image])
        combined_outputs = discriminator_output + [generated_image]
        patch_size = discriminator_output[-1].shape[1]
        combined_model = keras.Model(
            [latent_input, mask_input, image_input], combined_outputs
        )
        return patch_size, combined_model

    def compile(self, gen_lr=1e-4, disc_lr=4e-4, **kwargs):
        super().compile(**kwargs)
        self.generator_optimizer = keras.optimizers.Adam(
            gen_lr, beta_1=0.0, beta_2=0.999
        )
        self.discriminator_optimizer = keras.optimizers.Adam(
            disc_lr, beta_1=0.0, beta_2=0.999
        )
        self.discriminator_loss = DiscriminatorLoss()
        self.feature_matching_loss = FeatureMatchingLoss()
        self.vgg_loss = VGGFeatureMatchingLoss()

    def train_discriminator(self, latent_vector, segmentation_map, real_image, labels):
        fake_images = self.generator([latent_vector, labels])
        with tf.GradientTape() as gradient_tape:
            pred_fake = self.discriminator([segmentation_map, fake_images])[-1]
            pred_real = self.discriminator([segmentation_map, real_image])[-1]
            loss_fake = self.discriminator_loss(pred_fake, -1.0)
            loss_real = self.discriminator_loss(pred_real, 1.0)
            total_loss = 0.5 * (loss_fake + loss_real)

        self.discriminator.trainable = True
        gradients = gradient_tape.gradient(
            total_loss, self.discriminator.trainable_variables
        )
        self.discriminator_optimizer.apply_gradients(
            zip(gradients, self.discriminator.trainable_variables)
        )
        return total_loss

    def train_generator(
        self, latent_vector, segmentation_map, labels, image, mean, variance
    ):
        # Generator learns through the signal provided by the discriminator. During
        # backpropagation, we only update the generator parameters.
        self.discriminator.trainable = False
        with tf.GradientTape() as tape:
            real_d_output = self.discriminator([segmentation_map, image])
            combined_outputs = self.combined_model(
                [latent_vector, labels, segmentation_map]
            )
            fake_d_output, fake_image = combined_outputs[:-1], combined_outputs[-1]
            pred = fake_d_output[-1]

            # Compute generator losses.
            g_loss = generator_loss(pred)
            kl_loss = self.kl_divergence_loss_coeff * kl_divergence_loss(mean, variance)
            vgg_loss = self.vgg_feature_loss_coeff * self.vgg_loss(image, fake_image)
            feature_loss = self.feature_loss_coeff * self.feature_matching_loss(
                real_d_output, fake_d_output
            )
            total_loss = g_loss + kl_loss + vgg_loss + feature_loss

        all_trainable_variables = (
            self.combined_model.trainable_variables + self.encoder.trainable_variables
        )

        gradients = tape.gradient(total_loss, all_trainable_variables)
        self.generator_optimizer.apply_gradients(
            zip(gradients, all_trainable_variables)
        )
        return total_loss, feature_loss, vgg_loss, kl_loss

    def train_step(self, data):
        segmentation_map, image, labels = data
        mean, variance = self.encoder(image)
        latent_vector = self.sampler([mean, variance])
        discriminator_loss = self.train_discriminator(
            latent_vector, segmentation_map, image, labels
        )
        (generator_loss, feature_loss, vgg_loss, kl_loss) = self.train_generator(
            latent_vector, segmentation_map, labels, image, mean, variance
        )

        # Report progress.
        self.disc_loss_tracker.update_state(discriminator_loss)
        self.gen_loss_tracker.update_state(generator_loss)
        self.feat_loss_tracker.update_state(feature_loss)
        self.vgg_loss_tracker.update_state(vgg_loss)
        self.kl_loss_tracker.update_state(kl_loss)
        results = {m.name: m.result() for m in self.metrics}
        return results

    def test_step(self, data):
        segmentation_map, image, labels = data
        # Obtain the learned moments of the real image distribution.
        mean, variance = self.encoder(image)

        # Sample a latent from the distribution defined by the learned moments.
        latent_vector = self.sampler([mean, variance])

        # Generate the fake images.
        fake_images = self.generator([latent_vector, labels])

        # Calculate the losses.
        pred_fake = self.discriminator([segmentation_map, fake_images])[-1]
        pred_real = self.discriminator([segmentation_map, image])[-1]
        loss_fake = self.discriminator_loss(pred_fake, -1.0)
        loss_real = self.discriminator_loss(pred_real, 1.0)
        total_discriminator_loss = 0.5 * (loss_fake + loss_real)
        real_d_output = self.discriminator([segmentation_map, image])
        combined_outputs = self.combined_model(
            [latent_vector, labels, segmentation_map]
        )
        fake_d_output, fake_image = combined_outputs[:-1], combined_outputs[-1]
        pred = fake_d_output[-1]
        g_loss = generator_loss(pred)
        kl_loss = self.kl_divergence_loss_coeff * kl_divergence_loss(mean, variance)
        vgg_loss = self.vgg_feature_loss_coeff * self.vgg_loss(image, fake_image)
        feature_loss = self.feature_loss_coeff * self.feature_matching_loss(
            real_d_output, fake_d_output
        )
        total_generator_loss = g_loss + kl_loss + vgg_loss + feature_loss

        # Report progress.
        self.disc_loss_tracker.update_state(total_discriminator_loss)
        self.gen_loss_tracker.update_state(total_generator_loss)
        self.feat_loss_tracker.update_state(feature_loss)
        self.vgg_loss_tracker.update_state(vgg_loss)
        self.kl_loss_tracker.update_state(kl_loss)
        results = {m.name: m.result() for m in self.metrics}
        return results

    def call(self, inputs):
        latent_vectors, labels = inputs
        return self.generator([latent_vectors, labels])


"""
## GauGAN training
"""

gaugan = GauGAN(IMG_HEIGHT, NUM_CLASSES, BATCH_SIZE, latent_dim=256)
gaugan.compile()
history = gaugan.fit(
    train_dataset,
    validation_data=val_dataset,
    epochs=15,
    callbacks=[GanMonitor(val_dataset, BATCH_SIZE)],
)


def plot_history(item):
    plt.plot(history.history[item], label=item)
    plt.plot(history.history["val_" + item], label="val_" + item)
    plt.xlabel("Epochs")
    plt.ylabel(item)
    plt.title("Train and Validation {} Over Epochs".format(item), fontsize=14)
    plt.legend()
    plt.grid()
    plt.show()


plot_history("disc_loss")
plot_history("gen_loss")
plot_history("feat_loss")
plot_history("vgg_loss")
plot_history("kl_loss")

"""
## Inference
"""

val_iterator = iter(val_dataset)

for _ in range(5):
    val_images = next(val_iterator)
    # Sample latent from a normal distribution.
    latent_vector = keras.random.normal(
        shape=(gaugan.batch_size, gaugan.latent_dim), mean=0.0, stddev=2.0
    )
    # Generate fake images.
    fake_images = gaugan.predict([latent_vector, val_images[2]])

    real_images = val_images
    grid_row = min(fake_images.shape[0], 3)
    grid_col = 3
    f, axarr = plt.subplots(grid_row, grid_col, figsize=(grid_col * 6, grid_row * 6))
    for row in range(grid_row):
        ax = axarr if grid_row == 1 else axarr[row]
        ax[0].imshow((real_images[0][row] + 1) / 2)
        ax[0].axis("off")
        ax[0].set_title("Mask", fontsize=20)
        ax[1].imshow((real_images[1][row] + 1) / 2)
        ax[1].axis("off")
        ax[1].set_title("Ground Truth", fontsize=20)
        ax[2].imshow((fake_images[row] + 1) / 2)
        ax[2].axis("off")
        ax[2].set_title("Generated", fontsize=20)
    plt.show()
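
"""
If you want to reuse the trained generator outside this script (for instance, in a demo
application), you can save its weights. This is a minimal sketch; the filename is arbitrary:
"""

# Persist the generator's weights for later reuse (the `.weights.h5` suffix is required by Keras 3).
gaugan.generator.save_weights("gaugan_generator.weights.h5")
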

"""
## Final words

* The dataset we used in this example is a small one. To obtain even better results,
we recommend using a bigger dataset. GauGAN results were demonstrated with the
[COCO-Stuff](https://github.com/nightrome/cocostuff) and
[CityScapes](https://www.cityscapes-dataset.com/) datasets.
* This example was inspired by Chapter 6 of
[Hands-On Image Generation with TensorFlow](https://www.packtpub.com/product/hands-on-image-generation-with-tensorflow/9781838826789)
by [Soon-Yau Cheong](https://www.linkedin.com/in/soonyau/) and
[Implementing SPADE using fastai](https://towardsdatascience.com/implementing-spade-using-fastai-6ad86b94030a) by
[Divyansh Jha](https://medium.com/@divyanshj.16).
* If you found this example interesting and exciting, you might want to check out
[our repository](https://github.com/soumik12345/tf2_gans), which we are
currently building. It will include reimplementations of popular GANs and pretrained
models. Our focus will be on readability and making the code as accessible as possible.
Our plan is to first train our implementation of GauGAN (following the code of
this example) on a bigger dataset and then make the repository public. We welcome
contributions!
* Recently, GauGAN2 was also released. You can check it out
[here](https://blogs.nvidia.com/blog/2021/11/22/gaugan2-ai-art-demo/).
"""

"""
This example is available on Hugging Face.

| Trained Model | Demo |
| :--: | :--: |
| [![Generic badge](https://img.shields.io/badge/%F0%9F%A4%97%20Model-GauGAN%20Image%20Generation-black.svg)](https://huggingface.co/keras-io/GauGAN-Image-generation) | [![Generic badge](https://img.shields.io/badge/%F0%9F%A4%97%20Spaces-GauGAN%20Image%20Generation-black.svg)](https://huggingface.co/spaces/keras-io/GauGAN_Conditional_Image_Generation) |
"""