Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
keras-team
GitHub Repository: keras-team/keras-io
Path: blob/master/examples/generative/cyclegan.py
3507 views
1
"""
2
Title: CycleGAN
3
Author: [A_K_Nain](https://twitter.com/A_K_Nain)
4
Date created: 2020/08/12
5
Last modified: 2024/09/30
6
Description: Implementation of CycleGAN.
7
Accelerator: GPU
8
"""
9
10
"""
11
## CycleGAN
12
13
CycleGAN is a model that aims to solve the image-to-image translation
14
problem. The goal of the image-to-image translation problem is to learn the
15
mapping between an input image and an output image using a training set of
16
aligned image pairs. However, obtaining paired examples isn't always feasible.
17
CycleGAN tries to learn this mapping without requiring paired input-output images,
18
using cycle-consistent adversarial networks.
19
20
- [Paper](https://arxiv.org/abs/1703.10593)
21
- [Original implementation](https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix)
22
"""
23
24
"""
25
## Setup
26
"""
27
28
import os

# Keras reads KERAS_BACKEND when the `keras` package is first imported, so the
# variable has to be set before that import for the setting to take effect.
os.environ["KERAS_BACKEND"] = "tensorflow"

import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
import keras
from keras import layers, ops
import tensorflow_datasets as tfds

tfds.disable_progress_bar()
autotune = tf.data.AUTOTUNE
40
41
"""
42
## Prepare the dataset
43
44
In this example, we will be using the
45
[horse to zebra](https://www.tensorflow.org/datasets/catalog/cycle_gan#cycle_ganhorse2zebra)
46
dataset.
47
"""
48
49
# Fetch the horse2zebra splits through tensorflow-datasets. The dataset info
# object returned alongside the splits is not used.
dataset, _ = tfds.load(
    name="cycle_gan/horse2zebra",
    with_info=True,
    as_supervised=True,
)
train_horses = dataset["trainA"]
train_zebras = dataset["trainB"]
test_horses = dataset["testA"]
test_zebras = dataset["testB"]

# Training images are first resized to this (height, width) ...
orig_img_size = (286, 286)
# ... and then randomly cropped to this shape, which is also the model input shape.
input_img_size = (256, 256, 3)

# Initializer for the convolution kernels.
kernel_init = keras.initializers.RandomNormal(mean=0.0, stddev=0.02)
# Initializer for the gamma (scale) weights of the normalization layers.
gamma_init = keras.initializers.RandomNormal(mean=0.0, stddev=0.02)

# Shuffle buffer size and batch size shared by all four input pipelines.
buffer_size = 256
batch_size = 1
65
66
67
def normalize_img(img):
    """Cast `img` to float32 and rescale it as `img / 127.5 - 1`.

    For pixel values in [0, 255] the result lies in [-1, 1].
    """
    as_float = ops.cast(img, dtype=tf.float32)
    return (as_float / 127.5) - 1.0
71
72
73
def preprocess_train_image(img, label):
    """Augment and normalize one training image.

    The image is randomly mirrored left/right, resized to `orig_img_size`,
    randomly cropped to `input_img_size` and finally passed through
    `normalize_img`. `label` is accepted because the dataset yields
    `(image, label)` pairs; it is not used.
    """
    mirrored = tf.image.random_flip_left_right(img)
    enlarged = ops.image.resize(mirrored, list(orig_img_size))
    cropped = tf.image.random_crop(enlarged, size=list(input_img_size))
    return normalize_img(cropped)
83
84
85
def preprocess_test_image(img, label):
    """Resize one test image to the model input size and normalize it.

    No random augmentation is applied. `label` is accepted because the
    dataset yields `(image, label)` pairs; it is not used.
    """
    target_height, target_width = input_img_size[0], input_img_size[1]
    resized = ops.image.resize(img, [target_height, target_width])
    return normalize_img(resized)
90
91
92
"""
93
## Create `Dataset` objects
94
"""
95
96
97
# Apply the preprocessing operations to the training data.
#
# `cache()` is placed *before* the map: everything upstream of a cache is
# computed once and replayed on later epochs. Caching after the map would
# therefore store one fixed flip/crop per image and the random augmentation
# would never change after the first epoch.
train_horses = (
    train_horses.cache()
    .map(preprocess_train_image, num_parallel_calls=autotune)
    .shuffle(buffer_size)
    .batch(batch_size)
)
train_zebras = (
    train_zebras.cache()
    .map(preprocess_train_image, num_parallel_calls=autotune)
    .shuffle(buffer_size)
    .batch(batch_size)
)

# Apply the preprocessing operations to the test data.
# Test preprocessing has no random step, so its output can be cached directly.
test_horses = (
    test_horses.map(preprocess_test_image, num_parallel_calls=autotune)
    .cache()
    .shuffle(buffer_size)
    .batch(batch_size)
)
test_zebras = (
    test_zebras.map(preprocess_test_image, num_parallel_calls=autotune)
    .cache()
    .shuffle(buffer_size)
    .batch(batch_size)
)
124
125
126
"""
127
## Visualize some samples
128
"""
129
130
131
# Show four horse/zebra training samples side by side (horses on the left).
_, ax = plt.subplots(4, 2, figsize=(10, 15))
sample_pairs = zip(train_horses.take(4), train_zebras.take(4))
for row, (horse_batch, zebra_batch) in enumerate(sample_pairs):
    # Take the first image of each batch and undo the `x / 127.5 - 1` scaling
    # so the values can be displayed as 8-bit pixels.
    horse = ((horse_batch[0] * 127.5) + 127.5).numpy().astype(np.uint8)
    zebra = ((zebra_batch[0] * 127.5) + 127.5).numpy().astype(np.uint8)
    ax[row, 0].imshow(horse)
    ax[row, 1].imshow(zebra)
plt.show()
138
139
140
"""
141
## Building blocks used in the CycleGAN generators and discriminators
142
"""
143
144
145
class ReflectionPadding2D(layers.Layer):
    """Layer that reflect-pads the two middle axes of a rank-4 tensor.

    Args:
        padding(tuple): Two padding amounts. The first entry is applied to
            both sides of axis 2 and the second entry to both sides of
            axis 1; axes 0 and 3 are never padded.

    Returns:
        The padded tensor produced by `ops.pad` in reflect mode.
    """

    def __init__(self, padding=(1, 1), **kwargs):
        self.padding = tuple(padding)
        super().__init__(**kwargs)

    def call(self, input_tensor, mask=None):
        # `mask` is accepted but not used.
        pad_w, pad_h = self.padding
        paddings = [[0, 0], [pad_h, pad_h], [pad_w, pad_w], [0, 0]]
        return ops.pad(input_tensor, paddings, mode="REFLECT")
169
170
171
def residual_block(
    x,
    activation,
    kernel_initializer=kernel_init,
    kernel_size=(3, 3),
    strides=(1, 1),
    padding="valid",
    gamma_initializer=gamma_init,
    use_bias=False,
):
    """Append a two-convolution residual block to `x`.

    The branch is: reflection pad -> conv -> instance norm -> `activation`
    -> reflection pad -> conv -> instance norm. Its output is added to the
    unmodified input. Both convolutions keep the channel count of `x`.

    Args:
        x: Input tensor; its last dimension sets the number of conv filters.
        activation: Callable applied after the first normalization.
        kernel_initializer: Initializer for both convolution kernels.
        kernel_size: Kernel size of both convolutions.
        strides: Strides of both convolutions.
        padding: Padding mode of both convolutions.
        gamma_initializer: Initializer for the normalization gamma weights.
        use_bias: Whether the convolutions use a bias term.

    Returns:
        The sum of the input tensor and the branch output.
    """
    dim = x.shape[-1]
    input_tensor = x

    x = ReflectionPadding2D()(input_tensor)
    x = layers.Conv2D(
        dim,
        kernel_size,
        strides=strides,
        kernel_initializer=kernel_initializer,
        padding=padding,
        use_bias=use_bias,
    )(x)
    # groups=-1 gives one group per channel, i.e. instance normalization.
    # groups=1 would normalize across all channels together instead.
    x = keras.layers.GroupNormalization(
        groups=-1, gamma_initializer=gamma_initializer
    )(x)
    x = activation(x)

    x = ReflectionPadding2D()(x)
    x = layers.Conv2D(
        dim,
        kernel_size,
        strides=strides,
        kernel_initializer=kernel_initializer,
        padding=padding,
        use_bias=use_bias,
    )(x)
    x = keras.layers.GroupNormalization(
        groups=-1, gamma_initializer=gamma_initializer
    )(x)
    x = layers.add([input_tensor, x])
    return x
212
213
214
def downsample(
    x,
    filters,
    activation,
    kernel_initializer=kernel_init,
    kernel_size=(3, 3),
    strides=(2, 2),
    padding="same",
    gamma_initializer=gamma_init,
    use_bias=False,
):
    """Append a conv -> instance norm -> optional activation block to `x`.

    Args:
        x: Input tensor.
        filters: Number of convolution filters.
        activation: Callable applied after normalization; skipped when falsy.
        kernel_initializer: Initializer for the convolution kernel.
        kernel_size: Convolution kernel size.
        strides: Convolution strides (the default of 2 halves each spatial
            dimension when combined with "same" padding).
        padding: Convolution padding mode.
        gamma_initializer: Initializer for the normalization gamma weights.
        use_bias: Whether the convolution uses a bias term.

    Returns:
        The output tensor of the block.
    """
    x = layers.Conv2D(
        filters,
        kernel_size,
        strides=strides,
        kernel_initializer=kernel_initializer,
        padding=padding,
        use_bias=use_bias,
    )(x)
    # groups=-1 gives one group per channel, i.e. instance normalization.
    # groups=1 would normalize across all channels together instead.
    x = keras.layers.GroupNormalization(
        groups=-1, gamma_initializer=gamma_initializer
    )(x)
    if activation:
        x = activation(x)
    return x
239
240
241
def upsample(
    x,
    filters,
    activation,
    kernel_size=(3, 3),
    strides=(2, 2),
    padding="same",
    kernel_initializer=kernel_init,
    gamma_initializer=gamma_init,
    use_bias=False,
):
    """Append a transposed conv -> instance norm -> optional activation block.

    Args:
        x: Input tensor.
        filters: Number of transposed-convolution filters.
        activation: Callable applied after normalization; skipped when falsy.
        kernel_size: Transposed-convolution kernel size.
        strides: Transposed-convolution strides.
        padding: Transposed-convolution padding mode.
        kernel_initializer: Initializer for the kernel.
        gamma_initializer: Initializer for the normalization gamma weights.
        use_bias: Whether the transposed convolution uses a bias term.

    Returns:
        The output tensor of the block.
    """
    x = layers.Conv2DTranspose(
        filters,
        kernel_size,
        strides=strides,
        padding=padding,
        kernel_initializer=kernel_initializer,
        use_bias=use_bias,
    )(x)
    # groups=-1 gives one group per channel, i.e. instance normalization.
    # groups=1 would normalize across all channels together instead.
    x = keras.layers.GroupNormalization(
        groups=-1, gamma_initializer=gamma_initializer
    )(x)
    if activation:
        x = activation(x)
    return x
266
267
268
"""
269
## Build the generators
270
271
The generator consists of downsampling blocks, nine residual blocks,
272
and upsampling blocks. The structure of the generator is the following:
273
274
```
275
c7s1-64 ==> Conv block with `relu` activation, filter size of 7
276
d128 ====|
277
|-> 2 downsampling blocks
278
d256 ====|
279
R256 ====|
280
R256 |
281
R256 |
282
R256 |
283
R256 |-> 9 residual blocks
284
R256 |
285
R256 |
286
R256 |
287
R256 ====|
288
u128 ====|
289
|-> 2 upsampling blocks
290
u64 ====|
291
c7s1-3 => Last conv block with `tanh` activation, filter size of 7.
292
```
293
"""
294
295
296
def get_resnet_generator(
    filters=64,
    num_downsampling_blocks=2,
    num_residual_blocks=9,
    num_upsample_blocks=2,
    gamma_initializer=gamma_init,
    name=None,
):
    """Build the ResNet-style generator as a functional Keras model.

    Layout: 7x7 conv stem -> `num_downsampling_blocks` downsampling blocks
    (filters doubled each time) -> `num_residual_blocks` residual blocks ->
    `num_upsample_blocks` upsampling blocks (filters halved each time) ->
    7x7 conv to 3 channels with `tanh`.

    Args:
        filters: Number of filters of the stem convolution.
        num_downsampling_blocks: Number of downsampling blocks.
        num_residual_blocks: Number of residual blocks.
        num_upsample_blocks: Number of upsampling blocks.
        gamma_initializer: Initializer for the gamma weights of every
            normalization layer in the generator.
        name: Optional model name. When given, the input layer is named
            `name + "_img_input"`; when `None`, Keras picks the names.

    Returns:
        A `keras.models.Model` taking inputs of shape `input_img_size`.
    """
    # `name` defaults to None, so only derive the input name when one is given.
    input_name = None if name is None else name + "_img_input"
    img_input = layers.Input(shape=input_img_size, name=input_name)
    x = ReflectionPadding2D(padding=(3, 3))(img_input)
    x = layers.Conv2D(filters, (7, 7), kernel_initializer=kernel_init, use_bias=False)(
        x
    )
    # groups=-1 gives one group per channel, i.e. instance normalization.
    x = keras.layers.GroupNormalization(
        groups=-1, gamma_initializer=gamma_initializer
    )(x)
    x = layers.Activation("relu")(x)

    # Downsampling
    for _ in range(num_downsampling_blocks):
        filters *= 2
        x = downsample(
            x,
            filters=filters,
            activation=layers.Activation("relu"),
            gamma_initializer=gamma_initializer,
        )

    # Residual blocks
    for _ in range(num_residual_blocks):
        x = residual_block(
            x,
            activation=layers.Activation("relu"),
            gamma_initializer=gamma_initializer,
        )

    # Upsampling
    for _ in range(num_upsample_blocks):
        filters //= 2
        x = upsample(
            x,
            filters,
            activation=layers.Activation("relu"),
            gamma_initializer=gamma_initializer,
        )

    # Final block
    x = ReflectionPadding2D(padding=(3, 3))(x)
    x = layers.Conv2D(3, (7, 7), padding="valid")(x)
    x = layers.Activation("tanh")(x)

    model = keras.models.Model(img_input, x, name=name)
    return model
335
336
337
"""
338
## Build the discriminators
339
340
The discriminators implement the following architecture:
341
`C64->C128->C256->C512`
342
"""
343
344
345
def get_discriminator(
    filters=64, kernel_initializer=kernel_init, num_downsampling=3, name=None
):
    """Build the convolutional discriminator as a functional Keras model.

    Layout: a stride-2 4x4 conv with LeakyReLU, then `num_downsampling`
    `downsample` blocks that double the filter count each time (all use
    stride 2 except the last, which uses stride 1), then a 4x4 conv with a
    single output channel and no activation.

    Args:
        filters: Number of filters of the first convolution.
        kernel_initializer: Initializer for the first and last convolutions.
        num_downsampling: Number of `downsample` blocks after the first conv.
        name: Optional model name. When given, the input layer is named
            `name + "_img_input"`; when `None`, Keras picks the names.

    Returns:
        A `keras.models.Model` taking inputs of shape `input_img_size`.
    """
    # `name` defaults to None, so only derive the input name when one is given.
    input_name = None if name is None else name + "_img_input"
    img_input = layers.Input(shape=input_img_size, name=input_name)
    x = layers.Conv2D(
        filters,
        (4, 4),
        strides=(2, 2),
        padding="same",
        kernel_initializer=kernel_initializer,
    )(img_input)
    x = layers.LeakyReLU(0.2)(x)

    num_filters = filters
    last_block = num_downsampling - 1
    for num_downsample_block in range(num_downsampling):
        num_filters *= 2
        # Every block halves the resolution except the last, which keeps it.
        strides = (2, 2) if num_downsample_block < last_block else (1, 1)
        x = downsample(
            x,
            filters=num_filters,
            activation=layers.LeakyReLU(0.2),
            kernel_size=(4, 4),
            strides=strides,
        )

    x = layers.Conv2D(
        1, (4, 4), strides=(1, 1), padding="same", kernel_initializer=kernel_initializer
    )(x)

    model = keras.models.Model(inputs=img_input, outputs=x, name=name)
    return model
384
385
386
# Two generators with identical architecture, created in the order G, F.
gen_G, gen_F = (
    get_resnet_generator(name=model_name)
    for model_name in ("generator_G", "generator_F")
)

# Two discriminators with identical architecture, created in the order X, Y.
disc_X, disc_Y = (
    get_discriminator(name=model_name)
    for model_name in ("discriminator_X", "discriminator_Y")
)
393
394
395
"""
396
## Build the CycleGAN model
397
398
We will override the `train_step()` method of the `Model` class
399
for training via `fit()`.
400
"""
401
402
403
class CycleGan(keras.Model):
    """Keras model bundling two generators and two discriminators.

    `gen_G` translates domain X to domain Y and `gen_F` translates Y to X;
    `disc_X` and `disc_Y` score images of the respective domain. The custom
    `train_step` updates all four networks from one `(real_x, real_y)` batch.

    Args:
        generator_G: Generator mapping X -> Y.
        generator_F: Generator mapping Y -> X.
        discriminator_X: Discriminator for domain X.
        discriminator_Y: Discriminator for domain Y.
        lambda_cycle: Weight of the cycle-consistency loss (also a factor of
            the identity loss).
        lambda_identity: Additional weight of the identity loss.
    """

    def __init__(
        self,
        generator_G,
        generator_F,
        discriminator_X,
        discriminator_Y,
        lambda_cycle=10.0,
        lambda_identity=0.5,
    ):
        super().__init__()
        self.gen_G = generator_G
        self.gen_F = generator_F
        self.disc_X = discriminator_X
        self.disc_Y = discriminator_Y
        self.lambda_cycle = lambda_cycle
        self.lambda_identity = lambda_identity

    def call(self, inputs):
        """Run `inputs` through all four sub-models.

        Returns a tuple `(disc_X, disc_Y, gen_G, gen_F)` of their outputs.
        """
        score_x = self.disc_X(inputs)
        score_y = self.disc_Y(inputs)
        translated_by_g = self.gen_G(inputs)
        translated_by_f = self.gen_F(inputs)
        return (score_x, score_y, translated_by_g, translated_by_f)

    def compile(
        self,
        gen_G_optimizer,
        gen_F_optimizer,
        disc_X_optimizer,
        disc_Y_optimizer,
        gen_loss_fn,
        disc_loss_fn,
    ):
        """Store one optimizer per network and the adversarial loss functions.

        The cycle and identity losses are always mean absolute error.
        """
        super().compile()
        self.gen_G_optimizer = gen_G_optimizer
        self.gen_F_optimizer = gen_F_optimizer
        self.disc_X_optimizer = disc_X_optimizer
        self.disc_Y_optimizer = disc_Y_optimizer
        self.generator_loss_fn = gen_loss_fn
        self.discriminator_loss_fn = disc_loss_fn
        self.cycle_loss_fn = keras.losses.MeanAbsoluteError()
        self.identity_loss_fn = keras.losses.MeanAbsoluteError()

    def train_step(self, batch_data):
        """Run one optimization step on a `(real_x, real_y)` batch.

        Inside a single persistent gradient tape the step:
          * translates each real batch to the other domain,
          * translates the results back (cycle),
          * feeds each real batch to the generator that targets its own
            domain (identity),
          * scores real and translated images with both discriminators,
          * builds generator losses (adversarial + cycle + identity) and
            discriminator losses.
        All four gradients are computed first; only then are the optimizer
        updates applied, generators before discriminators.

        Returns:
            Dict with keys "G_loss", "F_loss", "D_X_loss" and "D_Y_loss".
        """
        # With the horse/zebra data used in this file, x is horse and y is zebra.
        real_x, real_y = batch_data

        # The tape is persistent because four separate gradients are taken.
        with tf.GradientTape(persistent=True) as tape:
            # Cross-domain translations: x -> y and y -> x.
            x_to_y = self.gen_G(real_x, training=True)
            y_to_x = self.gen_F(real_y, training=True)

            # Round trips: x -> y -> x and y -> x -> y.
            x_round_trip = self.gen_F(x_to_y, training=True)
            y_round_trip = self.gen_G(y_to_x, training=True)

            # Identity mapping: each generator sees an image of its target domain.
            x_identity = self.gen_F(real_x, training=True)
            y_identity = self.gen_G(real_y, training=True)

            # Discriminator scores for real and translated images.
            score_real_x = self.disc_X(real_x, training=True)
            score_fake_x = self.disc_X(y_to_x, training=True)

            score_real_y = self.disc_Y(real_y, training=True)
            score_fake_y = self.disc_Y(x_to_y, training=True)

            # Adversarial terms.
            adv_G = self.generator_loss_fn(score_fake_y)
            adv_F = self.generator_loss_fn(score_fake_x)

            # Cycle-consistency terms.
            cyc_G = self.cycle_loss_fn(real_y, y_round_trip) * self.lambda_cycle
            cyc_F = self.cycle_loss_fn(real_x, x_round_trip) * self.lambda_cycle

            # Identity terms, weighted by lambda_cycle and then lambda_identity.
            idt_G = (
                self.identity_loss_fn(real_y, y_identity)
                * self.lambda_cycle
                * self.lambda_identity
            )
            idt_F = (
                self.identity_loss_fn(real_x, x_identity)
                * self.lambda_cycle
                * self.lambda_identity
            )

            # Totals per generator.
            total_loss_G = adv_G + cyc_G + idt_G
            total_loss_F = adv_F + cyc_F + idt_F

            # Discriminator losses.
            disc_X_loss = self.discriminator_loss_fn(score_real_x, score_fake_x)
            disc_Y_loss = self.discriminator_loss_fn(score_real_y, score_fake_y)

        # Compute every gradient before any weight is updated.
        grads_G = tape.gradient(total_loss_G, self.gen_G.trainable_variables)
        grads_F = tape.gradient(total_loss_F, self.gen_F.trainable_variables)
        grads_disc_X = tape.gradient(disc_X_loss, self.disc_X.trainable_variables)
        grads_disc_Y = tape.gradient(disc_Y_loss, self.disc_Y.trainable_variables)

        # Apply the updates: generators first, then discriminators.
        updates = (
            (self.gen_G_optimizer, grads_G, self.gen_G),
            (self.gen_F_optimizer, grads_F, self.gen_F),
            (self.disc_X_optimizer, grads_disc_X, self.disc_X),
            (self.disc_Y_optimizer, grads_disc_Y, self.disc_Y),
        )
        for optimizer, grads, network in updates:
            optimizer.apply_gradients(zip(grads, network.trainable_variables))

        return {
            "G_loss": total_loss_G,
            "F_loss": total_loss_F,
            "D_X_loss": disc_X_loss,
            "D_Y_loss": disc_Y_loss,
        }
547
548
549
"""
550
## Create a callback that periodically saves generated images
551
"""
552
553
554
class GANMonitor(keras.callbacks.Callback):
    """A callback to generate and save images after each epoch.

    At the end of every epoch, `num_img` batches are taken from the
    module-level `test_horses` dataset, translated with the model's `gen_G`
    generator, plotted next to their inputs, and written to
    `generated_img_{i}_{epoch}.png` (epoch numbers start at 1).

    Args:
        num_img: Number of test batches to translate and plot per epoch.
    """

    def __init__(self, num_img=4):
        # Let the base class set up its own state before adding ours.
        super().__init__()
        self.num_img = num_img

    def on_epoch_end(self, epoch, logs=None):
        # One row per image; squeeze=False keeps `ax` two-dimensional even
        # when `num_img` is 1, so `ax[i, j]` indexing always works.
        _, ax = plt.subplots(self.num_img, 2, figsize=(12, 12), squeeze=False)
        for i, img in enumerate(test_horses.take(self.num_img)):
            prediction = self.model.gen_G(img)[0].numpy()
            # Undo the `x / 127.5 - 1` scaling to get displayable 8-bit pixels.
            prediction = (prediction * 127.5 + 127.5).astype(np.uint8)
            img = (img[0] * 127.5 + 127.5).numpy().astype(np.uint8)

            ax[i, 0].imshow(img)
            ax[i, 1].imshow(prediction)
            ax[i, 0].set_title("Input image")
            ax[i, 1].set_title("Translated image")
            ax[i, 0].axis("off")
            ax[i, 1].axis("off")

            prediction = keras.utils.array_to_img(prediction)
            prediction.save(
                "generated_img_{i}_{epoch}.png".format(i=i, epoch=epoch + 1)
            )
        plt.show()
        plt.close()
580
581
582
"""
583
## Train the end-to-end model
584
"""
585
586
587
# Mean-squared-error criterion shared by every adversarial loss term below.
adv_loss_fn = keras.losses.MeanSquaredError()


def generator_loss_fn(fake):
    """Adversarial loss of a generator.

    `fake` holds the discriminator's output for generated images; the loss is
    the mean squared error between that output and a tensor of ones.
    """
    target = ops.ones_like(fake)
    return adv_loss_fn(target, fake)
596
597
598
def discriminator_loss_fn(real, fake):
    """Adversarial loss of a discriminator.

    `real` and `fake` are the discriminator's outputs for real and generated
    images. Real outputs are compared against ones and fake outputs against
    zeros; the two losses are summed and multiplied by 0.5.
    """
    loss_on_real = adv_loss_fn(ops.ones_like(real), real)
    loss_on_fake = adv_loss_fn(ops.zeros_like(fake), fake)
    return (loss_on_real + loss_on_fake) * 0.5
603
604
605
# Wrap the four networks in a single trainable model.
cycle_gan_model = CycleGan(
    generator_G=gen_G,
    generator_F=gen_F,
    discriminator_X=disc_X,
    discriminator_Y=disc_Y,
)

# Every network gets its own Adam instance; all four share these settings.
adam_settings = {"learning_rate": 2e-4, "beta_1": 0.5}
cycle_gan_model.compile(
    gen_G_optimizer=keras.optimizers.Adam(**adam_settings),
    gen_F_optimizer=keras.optimizers.Adam(**adam_settings),
    disc_X_optimizer=keras.optimizers.Adam(**adam_settings),
    disc_Y_optimizer=keras.optimizers.Adam(**adam_settings),
    gen_loss_fn=generator_loss_fn,
    disc_loss_fn=discriminator_loss_fn,
)

# Callbacks: plot/save sample translations and checkpoint the weights.
plotter = GANMonitor()
checkpoint_filepath = "./model_checkpoints/cyclegan_checkpoints.weights.h5"
model_checkpoint_callback = keras.callbacks.ModelCheckpoint(
    filepath=checkpoint_filepath,
    save_weights_only=True,
)

# Train for 90 epochs on paired-up horse and zebra batches. The example's own
# estimate is around 7 minutes per epoch on a single P100 backed machine.
paired_batches = tf.data.Dataset.zip((train_horses, train_zebras))
cycle_gan_model.fit(
    paired_batches,
    epochs=90,
    callbacks=[plotter, model_checkpoint_callback],
)
633
634
"""
635
Test the performance of the model.
636
"""
637
638
639
# Reload the weights written by the checkpoint callback, then translate a few
# test samples to check the model's performance.
cycle_gan_model.load_weights(checkpoint_filepath)
print("Weights loaded successfully")

_, ax = plt.subplots(4, 2, figsize=(10, 15))
for i, img in enumerate(test_horses.take(4)):
    prediction = cycle_gan_model.gen_G(img, training=False)[0].numpy()
    # Undo the `x / 127.5 - 1` scaling to get displayable 8-bit pixels.
    prediction = (prediction * 127.5 + 127.5).astype(np.uint8)
    img = (img[0] * 127.5 + 127.5).numpy().astype(np.uint8)

    ax[i, 0].imshow(img)
    ax[i, 1].imshow(prediction)
    ax[i, 0].set_title("Input image")
    ax[i, 1].set_title("Translated image")
    ax[i, 0].axis("off")
    ax[i, 1].axis("off")

    prediction = keras.utils.array_to_img(prediction)
    prediction.save("predicted_img_{i}.png".format(i=i))
plt.tight_layout()
plt.show()
664
665