GitHub Repository: keras-team/keras-io
Path: blob/master/examples/generative/neural_style_transfer.py

"""
Title: Neural style transfer
Author: [fchollet](https://twitter.com/fchollet)
Date created: 2016/01/11
Last modified: 2020/05/02
Description: Transferring the style of a reference image to a target image using gradient descent.
Accelerator: GPU
"""
9
10
"""
11
## Introduction
12
13
Style transfer consists in generating an image
14
with the same "content" as a base image, but with the
15
"style" of a different picture (typically artistic).
16
This is achieved through the optimization of a loss function
17
that has 3 components: "style loss", "content loss",
18
and "total variation loss":
19
20
- The total variation loss imposes local spatial continuity between
21
the pixels of the combination image, giving it visual coherence.
22
- The style loss is where the deep learning keeps in --that one is defined
23
using a deep convolutional neural network. Precisely, it consists in a sum of
24
L2 distances between the Gram matrices of the representations of
25
the base image and the style reference image, extracted from
26
different layers of a convnet (trained on ImageNet). The general idea
27
is to capture color/texture information at different spatial
28
scales (fairly large scales --defined by the depth of the layer considered).
29
- The content loss is a L2 distance between the features of the base
30
image (extracted from a deep layer) and the features of the combination image,
31
keeping the generated image close enough to the original one.
32
33
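
Putting these together (a sketch matching the implementation below, using the
loss weights defined in the Setup section), the objective minimized by gradient
descent is:

$$
\mathcal{L} = w_{content} \, \lVert F(\text{base}) - F(\text{comb}) \rVert_2^2
+ \frac{w_{style}}{|S|} \sum_{l \in S}
\frac{\lVert G_l(\text{style}) - G_l(\text{comb}) \rVert_2^2}{4 C^2 M^2}
+ w_{tv} \, TV(\text{comb})
$$

where $F$ denotes the content-layer features, $G_l$ the Gram matrix at style
layer $l$, $S$ the set of style layers, $C = 3$ channels, and $M$ the number of
pixels of the generated image.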

**Reference:** [A Neural Algorithm of Artistic Style](
http://arxiv.org/abs/1508.06576)
"""

"""
## Setup
"""

import os

os.environ["KERAS_BACKEND"] = "tensorflow"

import keras
import numpy as np
import tensorflow as tf
from keras.applications import vgg19

base_image_path = keras.utils.get_file("paris.jpg", "https://i.imgur.com/F28w3Ac.jpg")
style_reference_image_path = keras.utils.get_file(
    "starry_night.jpg", "https://i.imgur.com/9ooB60I.jpg"
)
result_prefix = "paris_generated"

# Weights of the different loss components
total_variation_weight = 1e-6
style_weight = 1e-6
content_weight = 2.5e-8

# Dimensions of the generated picture.
width, height = keras.utils.load_img(base_image_path).size
img_nrows = 400
img_ncols = int(width * img_nrows / height)
"""
66
## Let's take a look at our base (content) image and our style reference image
67
"""
68
69
from IPython.display import Image, display
70
71
display(Image(base_image_path))
72
display(Image(style_reference_image_path))
73
74
"""
75
## Image preprocessing / deprocessing utilities
76
"""
77
78
79
def preprocess_image(image_path):
80
# Util function to open, resize and format pictures into appropriate tensors
81
img = keras.utils.load_img(image_path, target_size=(img_nrows, img_ncols))
82
img = keras.utils.img_to_array(img)
83
img = np.expand_dims(img, axis=0)
84
img = vgg19.preprocess_input(img)
85
return tf.convert_to_tensor(img)
86
87
88
def deprocess_image(x):
89
# Util function to convert a tensor into a valid image
90
x = x.reshape((img_nrows, img_ncols, 3))
91
# Remove zero-center by mean pixel
92
x[:, :, 0] += 103.939
93
x[:, :, 1] += 116.779
94
x[:, :, 2] += 123.68
95
# 'BGR'->'RGB'
96
x = x[:, :, ::-1]
97
x = np.clip(x, 0, 255).astype("uint8")
98
return x
99
100
101
"""
102
## Compute the style transfer loss
103
104
First, we need to define 4 utility functions:
105
106
- `gram_matrix` (used to compute the style loss)
107
- The `style_loss` function, which keeps the generated image close to the local textures
108
of the style reference image
109
- The `content_loss` function, which keeps the high-level representation of the
110
generated image close to that of the base image
111
- The `total_variation_loss` function, a regularization loss which keeps the generated
112
image locally-coherent
113
"""

# The gram matrix of an image tensor (feature-wise outer product)


def gram_matrix(x):
    # Move channels first, then flatten each channel into a row vector
    x = tf.transpose(x, (2, 0, 1))
    features = tf.reshape(x, (tf.shape(x)[0], -1))
    # (channels, channels) matrix of inner products between channel activations
    gram = tf.matmul(features, tf.transpose(features))
    return gram

# The "style loss" is designed to maintain
# the style of the reference image in the generated image.
# It is based on the gram matrices (which capture style) of
# feature maps from the style reference image
# and from the generated image


def style_loss(style, combination):
    S = gram_matrix(style)
    C = gram_matrix(combination)
    channels = 3
    size = img_nrows * img_ncols
    return tf.reduce_sum(tf.square(S - C)) / (4.0 * (channels**2) * (size**2))

# An auxiliary loss function
# designed to maintain the "content" of the
# base image in the generated image


def content_loss(base, combination):
    return tf.reduce_sum(tf.square(combination - base))

# The 3rd loss function, total variation loss,
# designed to keep the generated image locally coherent


def total_variation_loss(x):
    a = tf.square(
        x[:, : img_nrows - 1, : img_ncols - 1, :] - x[:, 1:, : img_ncols - 1, :]
    )
    b = tf.square(
        x[:, : img_nrows - 1, : img_ncols - 1, :] - x[:, : img_nrows - 1, 1:, :]
    )
    return tf.reduce_sum(tf.pow(a + b, 1.25))
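
"""
In formula form (matching the code above), with $i, j$ running over pixels:

$$
TV(x) = \sum_{i,j} \left( (x_{i,j} - x_{i+1,j})^2 + (x_{i,j} - x_{i,j+1})^2 \right)^{1.25}
$$

a slightly stronger-than-quadratic penalty on differences between neighboring pixels.
"""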

"""
Next, let's create a feature extraction model that retrieves the intermediate activations
of VGG19 (as a dict, by name).
"""

# Build a VGG19 model loaded with pre-trained ImageNet weights
model = vgg19.VGG19(weights="imagenet", include_top=False)

# Get the symbolic outputs of each "key" layer (they have unique names).
outputs_dict = {layer.name: layer.output for layer in model.layers}

# Set up a model that returns the activation values for every layer in
# VGG19 (as a dict).
feature_extractor = keras.Model(inputs=model.inputs, outputs=outputs_dict)
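
"""
As an optional sanity check (not part of the original example), we can feed a dummy
batch through the extractor and confirm that it returns a dict of activations keyed
by layer name:
"""

# block1_conv1 has 64 filters and `same` padding, so its spatial
# size matches the input
dummy = tf.zeros((1, img_nrows, img_ncols, 3))
print(feature_extractor(dummy)["block1_conv1"].shape)  # (1, img_nrows, img_ncols, 64)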

"""
Finally, here's the code that computes the style transfer loss.
"""

# List of layers to use for the style loss.
style_layer_names = [
    "block1_conv1",
    "block2_conv1",
    "block3_conv1",
    "block4_conv1",
    "block5_conv1",
]
# The layer to use for the content loss.
content_layer_name = "block5_conv2"


def compute_loss(combination_image, base_image, style_reference_image):
    # Batch the three images together; in each layer's activations,
    # index 0 = base image, 1 = style reference, 2 = combination
    input_tensor = tf.concat(
        [base_image, style_reference_image, combination_image], axis=0
    )
    features = feature_extractor(input_tensor)

    # Initialize the loss
    loss = tf.zeros(shape=())

    # Add content loss
    layer_features = features[content_layer_name]
    base_image_features = layer_features[0, :, :, :]
    combination_features = layer_features[2, :, :, :]
    loss = loss + content_weight * content_loss(
        base_image_features, combination_features
    )
    # Add style loss
    for layer_name in style_layer_names:
        layer_features = features[layer_name]
        style_reference_features = layer_features[1, :, :, :]
        combination_features = layer_features[2, :, :, :]
        sl = style_loss(style_reference_features, combination_features)
        loss += (style_weight / len(style_layer_names)) * sl

    # Add total variation loss
    loss += total_variation_weight * total_variation_loss(combination_image)
    return loss

"""
## Add a tf.function decorator to loss & gradient computation

To compile it, and thus make it fast.
"""


@tf.function
def compute_loss_and_grads(combination_image, base_image, style_reference_image):
    with tf.GradientTape() as tape:
        loss = compute_loss(combination_image, base_image, style_reference_image)
    grads = tape.gradient(loss, combination_image)
    return loss, grads
"""
239
## The training loop
240
241
Repeatedly run vanilla gradient descent steps to minimize the loss, and save the
242
resulting image every 100 iterations.
243
244
We decay the learning rate by 0.96 every 100 steps.
245
"""

optimizer = keras.optimizers.SGD(
    keras.optimizers.schedules.ExponentialDecay(
        initial_learning_rate=100.0, decay_steps=100, decay_rate=0.96
    )
)

base_image = preprocess_image(base_image_path)
style_reference_image = preprocess_image(style_reference_image_path)
combination_image = tf.Variable(preprocess_image(base_image_path))

iterations = 4000
for i in range(1, iterations + 1):
    loss, grads = compute_loss_and_grads(
        combination_image, base_image, style_reference_image
    )
    optimizer.apply_gradients([(grads, combination_image)])
    if i % 100 == 0:
        print("Iteration %d: loss=%.2f" % (i, loss))
        img = deprocess_image(combination_image.numpy())
        fname = result_prefix + "_at_iteration_%d.png" % i
        keras.utils.save_img(fname, img)
"""
270
After 4000 iterations, you get the following result:
271
"""
272
273
display(Image(result_prefix + "_at_iteration_4000.png"))
274
275