"""
2
Title: Neural Style Transfer with AdaIN
3
Author: [Aritra Roy Gosthipaty](https://twitter.com/arig23498), [Ritwik Raha](https://twitter.com/ritwik_raha)
4
Date created: 2021/11/08
5
Last modified: 2021/11/08
6
Description: Neural Style Transfer with Adaptive Instance Normalization.
7
Accelerator: GPU
8
"""

"""
# Introduction

[Neural Style Transfer](https://www.tensorflow.org/tutorials/generative/style_transfer)
is the process of transferring the style of one image onto the content
of another. This was first introduced in the seminal paper
["A Neural Algorithm of Artistic Style"](https://arxiv.org/abs/1508.06576)
by Gatys et al. A major limitation of this technique is its runtime,
as the algorithm uses a slow iterative optimization process.

Follow-up papers that introduced
[Batch Normalization](https://arxiv.org/abs/1502.03167),
[Instance Normalization](https://arxiv.org/abs/1701.02096) and
[Conditional Instance Normalization](https://arxiv.org/abs/1610.07629)
allowed Style Transfer to be performed in new ways, no longer
requiring a slow iterative process.

Following these papers, the authors Xun Huang and Serge
Belongie propose
[Adaptive Instance Normalization](https://arxiv.org/abs/1703.06868) (AdaIN),
which allows arbitrary style transfer in real time.

In this example we implement Adaptive Instance Normalization
for Neural Style Transfer. The figure below shows the output of our
AdaIN model trained for only **30 epochs**.

![Style transfer sample gallery](https://i.imgur.com/zDjDuea.png)

You can also try out the model with your own images with this
[Hugging Face demo](https://huggingface.co/spaces/ariG23498/nst).
"""

"""
# Setup

We begin by importing the necessary packages. The global variables are
hyperparameters that we can change as we like.
"""

import os
import numpy as np
import tensorflow as tf
from tensorflow import keras
import matplotlib.pyplot as plt
import tensorflow_datasets as tfds
from tensorflow.keras import layers

# Defining the global variables.
IMAGE_SIZE = (224, 224)
BATCH_SIZE = 64
# Training for a single epoch due to time constraints.
# Please use at least 30 epochs to see good results.
EPOCHS = 1
AUTOTUNE = tf.data.AUTOTUNE

"""
69
## Style transfer sample gallery
70
71
For Neural Style Transfer we need style images and content images. In
72
this example we will use the
73
[Best Artworks of All Time](https://www.kaggle.com/ikarus777/best-artworks-of-all-time)
74
as our style dataset and
75
[Pascal VOC](https://www.tensorflow.org/datasets/catalog/voc)
76
as our content dataset.
77
78
This is a deviation from the original paper implementation by the
79
authors, where they use
80
[WIKI-Art](https://paperswithcode.com/dataset/wikiart) as style and
81
[MSCOCO](https://cocodataset.org/#home) as content datasets
82
respectively. We do this to create a minimal yet reproducible example.
83
84
## Downloading the dataset from Kaggle
85
86
The [Best Artworks of All Time](https://www.kaggle.com/ikarus777/best-artworks-of-all-time)
87
dataset is hosted on Kaggle and one can easily download it in Colab by
88
following these steps:
89
90
- Follow the instructions [here](https://github.com/Kaggle/kaggle-api)
91
in order to obtain your Kaggle API keys in case you don't have them.
92
- Use the following command to upload the Kaggle API keys.
93
94
```python
95
from google.colab import files
96
files.upload()
97
```
98
99
- Use the following commands to move the API keys to the proper
100
directory and download the dataset.
101
102
```shell
103
$ mkdir ~/.kaggle
104
$ cp kaggle.json ~/.kaggle/
105
$ chmod 600 ~/.kaggle/kaggle.json
106
$ kaggle datasets download ikarus777/best-artworks-of-all-time
107
$ unzip -qq best-artworks-of-all-time.zip
108
$ rm -rf images
109
$ mv resized artwork
110
$ rm best-artworks-of-all-time.zip artists.csv
111
```
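
If you prefer not to upload `kaggle.json`, the Kaggle API can also read your
credentials from the `KAGGLE_USERNAME` and `KAGGLE_KEY` environment variables.
A minimal sketch, assuming you set them before invoking the `kaggle` CLI:

```python
import os

# Replace the placeholders with your own Kaggle credentials.
os.environ["KAGGLE_USERNAME"] = "your-kaggle-username"
os.environ["KAGGLE_KEY"] = "your-kaggle-api-key"
```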
"""

"""
## `tf.data` pipeline

In this section, we will build the `tf.data` pipeline for the project.
For the style dataset, we decode, convert and resize the images from
the folder. For the content images, we already get a `tf.data` dataset
since we use the `tfds` module to load the Pascal VOC dataset.

After we have our style and content data pipelines ready, we zip the
two together to obtain the data pipeline that our model will consume.
"""


def decode_and_resize(image_path):
    """Decodes and resizes an image from the image file path.

    Args:
        image_path: The image file path.

    Returns:
        A resized image.
    """
    image = tf.io.read_file(image_path)
    image = tf.image.decode_jpeg(image, channels=3)
    image = tf.image.convert_image_dtype(image, dtype="float32")
    image = tf.image.resize(image, IMAGE_SIZE)
    return image


def extract_image_from_voc(element):
    """Extracts image from the PascalVOC dataset.

    Args:
        element: A dictionary of data.

    Returns:
        A resized image.
    """
    image = element["image"]
    image = tf.image.convert_image_dtype(image, dtype="float32")
    image = tf.image.resize(image, IMAGE_SIZE)
    return image


# Get the image file paths for the style images.
style_images = os.listdir("artwork/resized")
style_images = [os.path.join("artwork/resized", path) for path in style_images]

# Split the style images into train, validation and test sets.
total_style_images = len(style_images)
train_style = style_images[: int(0.8 * total_style_images)]
val_style = style_images[int(0.8 * total_style_images) : int(0.9 * total_style_images)]
test_style = style_images[int(0.9 * total_style_images) :]

# Build the style and content tf.data datasets.
train_style_ds = (
    tf.data.Dataset.from_tensor_slices(train_style)
    .map(decode_and_resize, num_parallel_calls=AUTOTUNE)
    .repeat()
)
train_content_ds = tfds.load("voc", split="train").map(extract_image_from_voc).repeat()

val_style_ds = (
    tf.data.Dataset.from_tensor_slices(val_style)
    .map(decode_and_resize, num_parallel_calls=AUTOTUNE)
    .repeat()
)
val_content_ds = (
    tfds.load("voc", split="validation").map(extract_image_from_voc).repeat()
)

test_style_ds = (
    tf.data.Dataset.from_tensor_slices(test_style)
    .map(decode_and_resize, num_parallel_calls=AUTOTUNE)
    .repeat()
)
test_content_ds = (
    tfds.load("voc", split="test")
    .map(extract_image_from_voc, num_parallel_calls=AUTOTUNE)
    .repeat()
)

# Zipping the style and content datasets.
train_ds = (
    tf.data.Dataset.zip((train_style_ds, train_content_ds))
    .shuffle(BATCH_SIZE * 2)
    .batch(BATCH_SIZE)
    .prefetch(AUTOTUNE)
)

val_ds = (
    tf.data.Dataset.zip((val_style_ds, val_content_ds))
    .shuffle(BATCH_SIZE * 2)
    .batch(BATCH_SIZE)
    .prefetch(AUTOTUNE)
)

test_ds = (
    tf.data.Dataset.zip((test_style_ds, test_content_ds))
    .shuffle(BATCH_SIZE * 2)
    .batch(BATCH_SIZE)
    .prefetch(AUTOTUNE)
)

"""
219
## Visualizing the data
220
221
It is always better to visualize the data before training. To ensure
222
the correctness of our preprocessing pipeline, we visualize 10 samples
223
from our dataset.
224
"""
225
226
style, content = next(iter(train_ds))
227
fig, axes = plt.subplots(nrows=10, ncols=2, figsize=(5, 30))
228
[ax.axis("off") for ax in np.ravel(axes)]
229
230
for axis, style_image, content_image in zip(axes, style[0:10], content[0:10]):
231
(ax_style, ax_content) = axis
232
ax_style.imshow(style_image)
233
ax_style.set_title("Style Image")
234
235
ax_content.imshow(content_image)
236
ax_content.set_title("Content Image")
237
238
"""
239
## Architecture
240
241
The style transfer network takes a content image and a style image as
242
inputs and outputs the style transferred image. The authors of AdaIN
243
propose a simple encoder-decoder structure for achieving this.
244
245
![AdaIN architecture](https://i.imgur.com/JbIfoyE.png)
246
247
The content image (`C`) and the style image (`S`) are both fed to the
248
encoder networks. The output from these encoder networks (feature maps)
249
are then fed to the AdaIN layer. The AdaIN layer computes a combined
250
feature map. This feature map is then fed into a randomly initialized
251
decoder network that serves as the generator for the neural style
252
transferred image.
253
254
![AdaIn equation](https://i.imgur.com/hqhcBQS.png)
255
256
The style feature map (`fs`) and the content feature map (`fc`) are
257
fed to the AdaIN layer. This layer produced the combined feature map
258
`t`. The function `g` represents the decoder (generator) network.
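
In symbols, the AdaIN layer produces `t` from the two feature maps and the
decoder `g` generates the stylized output from it:

$$
t = \textrm{AdaIN}(f_c, f_s), \qquad \textrm{output} = g(t)
$$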
"""

"""
### Encoder

The encoder is part of a VGG19 model pretrained on
[ImageNet](https://www.image-net.org/). We slice the
model at the `block4_conv1` layer, which is the output layer suggested
by the authors in their paper.
"""


def get_encoder():
    vgg19 = keras.applications.VGG19(
        include_top=False,
        weights="imagenet",
        input_shape=(*IMAGE_SIZE, 3),
    )
    vgg19.trainable = False
    mini_vgg19 = keras.Model(vgg19.input, vgg19.get_layer("block4_conv1").output)

    inputs = layers.Input([*IMAGE_SIZE, 3])
    mini_vgg19_out = mini_vgg19(inputs)
    return keras.Model(inputs, mini_vgg19_out, name="mini_vgg19")


"""
286
### Adaptive Instance Normalization
287
288
The AdaIN layer takes in the features
289
of the content and style image. The layer can be defined via the
290
following equation:
291
292
![AdaIn formula](https://i.imgur.com/tWq3VKP.png)
293
294
where `sigma` is the standard deviation and `mu` is the mean for the
295
concerned variable. In the above equation the mean and variance of the
296
content feature map `fc` is aligned with the mean and variance of the
297
style feature maps `fs`.
298
299
It is important to note that the AdaIN layer proposed by the authors
300
uses no other parameters apart from mean and variance. The layer also
301
does not have any trainable parameters. This is why we use a
302
*Python function* instead of using a *Keras layer*. The function takes
303
style and content feature maps, computes the mean and standard deviation
304
of the images and returns the adaptive instance normalized feature map.
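
Concretely, the operation implemented in `ada_in` below is:

$$
t = \textrm{AdaIN}(f_c, f_s) = \sigma(f_s)\left(\frac{f_c - \mu(f_c)}{\sigma(f_c)}\right) + \mu(f_s)
$$

with `mu` and `sigma` computed over the spatial axes of each feature map.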
"""


def get_mean_std(x, epsilon=1e-5):
    axes = [1, 2]

    # Compute the mean and standard deviation of a tensor.
    mean, variance = tf.nn.moments(x, axes=axes, keepdims=True)
    standard_deviation = tf.sqrt(variance + epsilon)
    return mean, standard_deviation


def ada_in(style, content):
    """Computes the AdaIN feature map.

    Args:
        style: The style feature map.
        content: The content feature map.

    Returns:
        The AdaIN feature map.
    """
    content_mean, content_std = get_mean_std(content)
    style_mean, style_std = get_mean_std(style)
    t = style_std * (content - content_mean) / content_std + style_mean
    return t


"""
334
### Decoder
335
336
The authors specify that the decoder network must mirror the encoder
337
network. We have symmetrically inverted the encoder to build our
338
decoder. We have used `UpSampling2D` layers to increase the spatial
339
resolution of the feature maps.
340
341
Note that the authors warn against using any normalization layer
342
in the decoder network, and do indeed go on to show that including
343
batch normalization or instance normalization hurts the performance
344
of the overall network.
345
346
This is the only portion of the entire architecture that is trainable.
347
"""
348
349
350
def get_decoder():
351
config = {"kernel_size": 3, "strides": 1, "padding": "same", "activation": "relu"}
352
decoder = keras.Sequential(
353
[
354
layers.InputLayer((None, None, 512)),
355
layers.Conv2D(filters=512, **config),
356
layers.UpSampling2D(),
357
layers.Conv2D(filters=256, **config),
358
layers.Conv2D(filters=256, **config),
359
layers.Conv2D(filters=256, **config),
360
layers.Conv2D(filters=256, **config),
361
layers.UpSampling2D(),
362
layers.Conv2D(filters=128, **config),
363
layers.Conv2D(filters=128, **config),
364
layers.UpSampling2D(),
365
layers.Conv2D(filters=64, **config),
366
layers.Conv2D(
367
filters=3,
368
kernel_size=3,
369
strides=1,
370
padding="same",
371
activation="sigmoid",
372
),
373
]
374
)
375
return decoder
376
377
378
"""
379
### Loss functions
380
381
Here we build the loss functions for the neural style transfer model.
382
The authors propose to use a pretrained VGG-19 to compute the loss
383
function of the network. It is important to keep in mind that this
384
will be used for training only the decoder network. The total
385
loss (`Lt`) is a weighted combination of content loss (`Lc`) and style
386
loss (`Ls`). The `lambda` term is used to vary the amount of style
387
transferred.
388
389
![The total loss](https://i.imgur.com/Q5y1jUM.png)
390
391
### Content Loss
392
393
This is the Euclidean distance between the content image features
394
and the features of the neural style transferred image.
395
396
![The content loss](https://i.imgur.com/dZ0uD0N.png)
397
398
Here the authors propose to use the output from the AdaIn layer `t` as
399
the content target rather than using features of the original image as
400
target. This is done to speed up convergence.
401
402
### Style Loss
403
404
Rather than using the more commonly used
405
[Gram Matrix](https://mathworld.wolfram.com/GramMatrix.html),
406
the authors propose to compute the difference between the statistical features
407
(mean and variance) which makes it conceptually cleaner. This can be
408
easily visualized via the following equation:
409
410
![The style loss](https://i.imgur.com/Ctclhn3.png)
411
412
where `theta` denotes the layers in VGG-19 used to compute the loss.
413
In this case this corresponds to:
414
415
- `block1_conv1`
416
- `block1_conv2`
417
- `block1_conv3`
418
- `block1_conv4`
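
Putting these together (with `theta_i` denoting the four layers above, `s` the
style image, and `g(t)` the decoded image), the losses computed in `train_step`
below are:

$$
\mathcal{L}_c = \lVert \theta_4(g(t)) - t \rVert_2, \qquad
\mathcal{L}_s = \sum_{i=1}^{4} \Big( \lVert \mu(\theta_i(g(t))) - \mu(\theta_i(s)) \rVert_2
+ \lVert \sigma(\theta_i(g(t))) - \sigma(\theta_i(s)) \rVert_2 \Big)
$$

$$
\mathcal{L}_t = \mathcal{L}_c + \lambda \mathcal{L}_s
$$

In the code, the distances are computed with mean squared error (`loss_fn`), and
`lambda` corresponds to the `style_weight` argument of the trainer.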

"""


def get_loss_net():
    vgg19 = keras.applications.VGG19(
        include_top=False, weights="imagenet", input_shape=(*IMAGE_SIZE, 3)
    )
    vgg19.trainable = False
    layer_names = ["block1_conv1", "block2_conv1", "block3_conv1", "block4_conv1"]
    outputs = [vgg19.get_layer(name).output for name in layer_names]
    mini_vgg19 = keras.Model(vgg19.input, outputs)

    inputs = layers.Input([*IMAGE_SIZE, 3])
    mini_vgg19_out = mini_vgg19(inputs)
    return keras.Model(inputs, mini_vgg19_out, name="loss_net")


"""
438
## Neural Style Transfer
439
440
This is the trainer module. We wrap the encoder and decoder inside
441
a `tf.keras.Model` subclass. This allows us to customize what happens
442
in the `model.fit()` loop.
443
"""
444
445
446
class NeuralStyleTransfer(tf.keras.Model):
    def __init__(self, encoder, decoder, loss_net, style_weight, **kwargs):
        super().__init__(**kwargs)
        self.encoder = encoder
        self.decoder = decoder
        self.loss_net = loss_net
        self.style_weight = style_weight

    def compile(self, optimizer, loss_fn):
        super().compile()
        self.optimizer = optimizer
        self.loss_fn = loss_fn
        self.style_loss_tracker = keras.metrics.Mean(name="style_loss")
        self.content_loss_tracker = keras.metrics.Mean(name="content_loss")
        self.total_loss_tracker = keras.metrics.Mean(name="total_loss")

    def train_step(self, inputs):
        style, content = inputs

        # Initialize the content and style loss.
        loss_content = 0.0
        loss_style = 0.0

        with tf.GradientTape() as tape:
            # Encode the style and content image.
            style_encoded = self.encoder(style)
            content_encoded = self.encoder(content)

            # Compute the AdaIN target feature maps.
            t = ada_in(style=style_encoded, content=content_encoded)

            # Generate the neural style transferred image.
            reconstructed_image = self.decoder(t)

            # Compute the losses.
            reconstructed_vgg_features = self.loss_net(reconstructed_image)
            style_vgg_features = self.loss_net(style)
            loss_content = self.loss_fn(t, reconstructed_vgg_features[-1])
            for inp, out in zip(style_vgg_features, reconstructed_vgg_features):
                mean_inp, std_inp = get_mean_std(inp)
                mean_out, std_out = get_mean_std(out)
                loss_style += self.loss_fn(mean_inp, mean_out) + self.loss_fn(
                    std_inp, std_out
                )
            loss_style = self.style_weight * loss_style
            total_loss = loss_content + loss_style

        # Compute gradients and optimize the decoder.
        trainable_vars = self.decoder.trainable_variables
        gradients = tape.gradient(total_loss, trainable_vars)
        self.optimizer.apply_gradients(zip(gradients, trainable_vars))

        # Update the trackers.
        self.style_loss_tracker.update_state(loss_style)
        self.content_loss_tracker.update_state(loss_content)
        self.total_loss_tracker.update_state(total_loss)
        return {
            "style_loss": self.style_loss_tracker.result(),
            "content_loss": self.content_loss_tracker.result(),
            "total_loss": self.total_loss_tracker.result(),
        }

    def test_step(self, inputs):
        style, content = inputs

        # Initialize the content and style loss.
        loss_content = 0.0
        loss_style = 0.0

        # Encode the style and content image.
        style_encoded = self.encoder(style)
        content_encoded = self.encoder(content)

        # Compute the AdaIN target feature maps.
        t = ada_in(style=style_encoded, content=content_encoded)

        # Generate the neural style transferred image.
        reconstructed_image = self.decoder(t)

        # Compute the losses.
        recons_vgg_features = self.loss_net(reconstructed_image)
        style_vgg_features = self.loss_net(style)
        loss_content = self.loss_fn(t, recons_vgg_features[-1])
        for inp, out in zip(style_vgg_features, recons_vgg_features):
            mean_inp, std_inp = get_mean_std(inp)
            mean_out, std_out = get_mean_std(out)
            loss_style += self.loss_fn(mean_inp, mean_out) + self.loss_fn(
                std_inp, std_out
            )
        loss_style = self.style_weight * loss_style
        total_loss = loss_content + loss_style

        # Update the trackers.
        self.style_loss_tracker.update_state(loss_style)
        self.content_loss_tracker.update_state(loss_content)
        self.total_loss_tracker.update_state(total_loss)
        return {
            "style_loss": self.style_loss_tracker.result(),
            "content_loss": self.content_loss_tracker.result(),
            "total_loss": self.total_loss_tracker.result(),
        }

    @property
    def metrics(self):
        return [
            self.style_loss_tracker,
            self.content_loss_tracker,
            self.total_loss_tracker,
        ]


"""
558
## Train Monitor callback
559
560
This callback is used to visualize the style transfer output of
561
the model at the end of each epoch. The objective of style transfer cannot be
562
quantified properly, and is to be subjectively evaluated by an audience.
563
For this reason, visualization is a key aspect of evaluating the model.
564
"""
565
566
test_style, test_content = next(iter(test_ds))
567
568
569
class TrainMonitor(tf.keras.callbacks.Callback):
    def on_epoch_end(self, epoch, logs=None):
        # Encode the style and content image.
        test_style_encoded = self.model.encoder(test_style)
        test_content_encoded = self.model.encoder(test_content)

        # Compute the AdaIN features.
        test_t = ada_in(style=test_style_encoded, content=test_content_encoded)
        test_reconstructed_image = self.model.decoder(test_t)

        # Plot the Style, Content and the NST image.
        fig, ax = plt.subplots(nrows=1, ncols=3, figsize=(20, 5))
        ax[0].imshow(tf.keras.utils.array_to_img(test_style[0]))
        ax[0].set_title(f"Style: {epoch:03d}")

        ax[1].imshow(tf.keras.utils.array_to_img(test_content[0]))
        ax[1].set_title(f"Content: {epoch:03d}")

        ax[2].imshow(tf.keras.utils.array_to_img(test_reconstructed_image[0]))
        ax[2].set_title(f"NST: {epoch:03d}")

        plt.show()
        plt.close()


"""
595
## Train the model
596
597
In this section, we define the optimizer, the loss function, and the
598
trainer module. We compile the trainer module with the optimizer and
599
the loss function and then train it.
600
601
*Note*: We train the model for a single epoch for time constraints,
602
but we will need to train is for atleast 30 epochs to see good results.
603
"""
604
605
optimizer = keras.optimizers.Adam(learning_rate=1e-5)
606
loss_fn = keras.losses.MeanSquaredError()
607
608
encoder = get_encoder()
609
loss_net = get_loss_net()
610
decoder = get_decoder()
611
612
model = NeuralStyleTransfer(
613
encoder=encoder, decoder=decoder, loss_net=loss_net, style_weight=4.0
614
)
615
616
model.compile(optimizer=optimizer, loss_fn=loss_fn)
617
618
history = model.fit(
619
train_ds,
620
epochs=EPOCHS,
621
steps_per_epoch=50,
622
validation_data=val_ds,
623
validation_steps=50,
624
callbacks=[TrainMonitor()],
625
)
626
627
"""
628
## Inference
629
630
After we train the model, we now need to run inference with it. We will
631
pass arbitrary content and style images from the test dataset and take a look at
632
the output images.
633
634
*NOTE*: To try out the model on your own images, you can use this
635
[Hugging Face demo](https://huggingface.co/spaces/ariG23498/nst).
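
You can also reuse the building blocks defined above directly. A minimal sketch
(the image paths are placeholders you would replace with your own JPEG files):

```python
content = decode_and_resize("my_content.jpg")[tf.newaxis, ...]
style = decode_and_resize("my_style.jpg")[tf.newaxis, ...]

# Encode both images, combine the feature maps with AdaIN, and decode the result.
t = ada_in(style=model.encoder(style), content=model.encoder(content))
stylized = model.decoder(t)
plt.imshow(tf.keras.utils.array_to_img(stylized[0]))
```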
"""

for style, content in test_ds.take(1):
    style_encoded = model.encoder(style)
    content_encoded = model.encoder(content)
    t = ada_in(style=style_encoded, content=content_encoded)
    reconstructed_image = model.decoder(t)
    fig, axes = plt.subplots(nrows=10, ncols=3, figsize=(10, 30))
    [ax.axis("off") for ax in np.ravel(axes)]

    for axis, style_image, content_image, reconstructed_image in zip(
        axes, style[0:10], content[0:10], reconstructed_image[0:10]
    ):
        (ax_style, ax_content, ax_reconstructed) = axis
        ax_style.imshow(style_image)
        ax_style.set_title("Style Image")
        ax_content.imshow(content_image)
        ax_content.set_title("Content Image")
        ax_reconstructed.imshow(reconstructed_image)
        ax_reconstructed.set_title("NST Image")

"""
658
## Conclusion
659
660
Adaptive Instance Normalization allows arbitrary style transfer in
661
real time. It is also important to note that the novel proposition of
662
the authors is to achieve this only by aligning the statistical
663
features (mean and standard deviation) of the style and the content
664
images.
665
666
*Note*: AdaIN also serves as the base for
667
[Style-GANs](https://arxiv.org/abs/1812.04948).
668
669
## Reference
670
671
- [TF implementation](https://github.com/ftokarev/tf-adain)
672
673
## Acknowledgement
674
675
We thank [Luke Wood](https://lukewood.xyz) for his
676
detailed review.
677
"""
678
679