GitHub Repository: pytorch/tutorials
Path: blob/main/advanced_source/neural_style_tutorial.py
"""
Neural Transfer Using PyTorch
=============================


**Author**: `Alexis Jacq <https://alexis-jacq.github.io>`_

**Edited by**: `Winston Herring <https://github.com/winston6>`_

Introduction
------------

This tutorial explains how to implement the `Neural-Style algorithm <https://arxiv.org/abs/1508.06576>`__
developed by Leon A. Gatys, Alexander S. Ecker and Matthias Bethge.
Neural-Style, or Neural-Transfer, allows you to take an image and
reproduce it with a new artistic style. The algorithm takes three images
(an input image, a content image, and a style image) and changes the input
so that it resembles the content of the content image and the artistic style of the style image.


.. figure:: /_static/img/neural-style/neuralstyle.png
   :alt: content1
"""

######################################################################
# Underlying Principle
# --------------------
#
# The principle is simple: we define two distances, one for the content
# (:math:`D_C`) and one for the style (:math:`D_S`). :math:`D_C` measures how different the content
# is between two images, while :math:`D_S` measures how different the style is
# between two images. Then, we take a third image, the input, and
# transform it to minimize both its content distance to the
# content image and its style distance to the style image. Now we can
# import the necessary packages and begin the neural transfer.
#
# Importing Packages and Selecting a Device
# -----------------------------------------
# Below is a list of the packages needed to implement the neural transfer.
#
# -  ``torch``, ``torch.nn``, ``torch.nn.functional``, ``numpy`` (indispensable
#    packages for neural networks with PyTorch)
# -  ``torch.optim`` (gradient-based optimizers)
# -  ``PIL``, ``PIL.Image``, ``matplotlib.pyplot`` (load and display
#    images)
# -  ``torchvision.transforms`` (transform PIL images into tensors)
# -  ``torchvision.models`` (train or load pretrained models)
# -  ``copy`` (to deep copy the models; standard-library module)

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

from PIL import Image
import matplotlib.pyplot as plt

import torchvision.transforms as transforms
from torchvision.models import vgg19, VGG19_Weights

import copy

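######################################################################
# The toy sketch below makes the principle concrete before any images are
# involved. It is not part of the algorithm. The "images" are 4-element
# tensors, and both distances are plain mean squared errors against a
# hypothetical content target ``demo_c`` and style target ``demo_s``.
# Gradient descent is applied to the input alone, and it settles on a
# compromise between the two targets set by the arbitrary weights
# ``demo_w_c`` and ``demo_w_s``. For this quadratic toy objective the
# minimizer is the weighted average :math:`(w_C c + w_S s) / (w_C + w_S)`.
# The real :math:`D_C` and :math:`D_S`, defined below, compare CNN
# feature maps instead.

import torch
import torch.nn.functional as F

# Hypothetical targets and weights, chosen only for illustration.
demo_c = torch.tensor([0.0, 0.2, 0.4, 0.6])   # toy "content"
demo_s = torch.tensor([1.0, 1.0, 0.0, 0.0])   # toy "style"
demo_w_c, demo_w_s = 1.0, 3.0

demo_x = demo_c.clone().requires_grad_(True)   # start from a copy of the content
demo_opt = torch.optim.SGD([demo_x], lr=0.1)   # only the input is optimized
for _ in range(200):
    demo_opt.zero_grad()
    demo_loss = (demo_w_c * F.mse_loss(demo_x, demo_c)
                 + demo_w_s * F.mse_loss(demo_x, demo_s))
    demo_loss.backward()
    demo_opt.step()

demo_expected = (demo_w_c * demo_c + demo_w_s * demo_s) / (demo_w_c + demo_w_s)
print(demo_x.detach(), demo_expected)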
######################################################################
# Next, we need to choose which device to run the network on and import the
# content and style images. Running the neural transfer algorithm on large
# images takes a long time, and it runs much faster on a GPU. We can
# use ``torch.cuda.is_available()`` to detect whether a GPU is available.
# We then create a ``torch.device`` and pass it to ``torch.set_default_device``
# so that newly created tensors are allocated on that device by default.
# The ``.to(device)`` method moves existing tensors or modules to a desired
# device.

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
torch.set_default_device(device)

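######################################################################
# The short sketch below shows ``.to(device)`` in isolation. It moves a
# tensor that was explicitly allocated on the CPU, and a small module, to the
# selected device. The names ``demo_device``, ``demo_cpu_tensor`` and
# ``demo_layer`` are illustrative only and are not used elsewhere in the
# tutorial.

import torch
import torch.nn as nn

demo_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# An explicit ``device`` argument overrides any default device.
demo_cpu_tensor = torch.zeros(2, 3, device="cpu")
# ``Tensor.to`` returns a tensor on the target device (a copy if it had to move).
demo_moved = demo_cpu_tensor.to(demo_device)

# ``Module.to`` moves parameters and buffers and returns the module itself.
demo_layer = nn.Linear(3, 1).to(demo_device)

print(demo_cpu_tensor.device, demo_moved.device, next(demo_layer.parameters()).device)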
######################################################################
# Loading the Images
# ------------------
#
# Now we will import the style and content images. The original PIL images have values between 0 and 255, but when
# transformed into torch tensors, their values are converted to be between
# 0 and 1. The images also need to be resized to have the same dimensions.
# An important detail to note is that neural networks from the
# torch library are trained with tensor values ranging from 0 to 1. If you
# try to feed the networks with 0 to 255 tensor images, then the activated
# feature maps will be unable to sense the intended content and style.
# However, pretrained networks from the Caffe library are trained with 0
# to 255 tensor images.
#
#
# .. note::
#     Here are links to download the images required to run the tutorial:
#     `picasso.jpg <https://pytorch.org/tutorials/_static/img/neural-style/picasso.jpg>`__ and
#     `dancing.jpg <https://pytorch.org/tutorials/_static/img/neural-style/dancing.jpg>`__.
#     Download these two images and add them to a directory
#     named ``data/images/neural-style`` in your current working directory,
#     which is where the code below loads them from.

# desired size of the output image
imsize = 512 if torch.cuda.is_available() else 128  # use small size if no GPU

loader = transforms.Compose([
    transforms.Resize(imsize),  # scale imported image
    transforms.ToTensor()])  # transform it into a torch tensor


def image_loader(image_name):
    # convert to RGB so grayscale or RGBA files also yield 3 channels
    image = Image.open(image_name).convert("RGB")
    # fake batch dimension required to fit network's input dimensions
    image = loader(image).unsqueeze(0)
    return image.to(device, torch.float)


style_img = image_loader("./data/images/neural-style/picasso.jpg")
content_img = image_loader("./data/images/neural-style/dancing.jpg")

assert style_img.size() == content_img.size(), \
    "we need to import style and content images of the same size"

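######################################################################
# The sketch below checks two of the points above on small synthetic images
# instead of the downloaded files. First, ``ToTensor`` maps ``uint8``
# pixels in [0, 255] to floats in [0, 1]. Second, ``Resize`` with a single
# integer only matches the *shorter* edge, so two images with different
# aspect ratios still end up with different shapes. Passing a
# ``(height, width)`` tuple forces an exact size. The image sizes and the
# ``demo_*`` names are arbitrary choices for this illustration.

import numpy as np
from PIL import Image
import torchvision.transforms as transforms

# A 4x6 (H x W) white image and a 4x4 gray image, both uint8 RGB.
demo_wide = Image.fromarray(np.full((4, 6, 3), 255, dtype=np.uint8))
demo_square = Image.fromarray(np.full((4, 4, 3), 128, dtype=np.uint8))

demo_to_tensor = transforms.ToTensor()
print(demo_to_tensor(demo_wide).max().item())  # 255 maps to 1.0

# Resize(int) scales the shorter edge to 2 and keeps the aspect ratio.
demo_short_edge = transforms.Compose([transforms.Resize(2), transforms.ToTensor()])
print(demo_short_edge(demo_wide).shape, demo_short_edge(demo_square).shape)

# Resize((h, w)) produces exactly the requested size for both images.
demo_exact = transforms.Compose([transforms.Resize((2, 2)), transforms.ToTensor()])
print(demo_exact(demo_wide).shape, demo_exact(demo_square).shape)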
######################################################################
# Now, let's create a function that displays an image by converting a
# copy of it back to PIL format and displaying the copy using
# ``plt.imshow``. We will try displaying the content and style images
# to ensure they were imported correctly.

unloader = transforms.ToPILImage()  # reconvert into PIL image

plt.ion()

def imshow(tensor, title=None):
    # detach from autograd, move to CPU and clone so the original is not modified
    image = tensor.detach().cpu().clone()
    image = image.squeeze(0)      # remove the fake batch dimension
    image = unloader(image)
    plt.imshow(image)
    if title is not None:
        plt.title(title)
    plt.pause(0.001)  # pause a bit so that plots are updated


plt.figure()
imshow(style_img, title='Style Image')

plt.figure()
imshow(content_img, title='Content Image')

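######################################################################
# ``ToPILImage`` undoes ``ToTensor`` only up to 8-bit quantization. A float
# tensor in [0, 1] is multiplied by 255 and cast to ``uint8``, so a round
# trip can shift each value by up to 1/255. The sketch below checks this on a
# random tensor of an arbitrary size, 1x3x4x5. It also squeezes the fake
# batch dimension first, as ``imshow`` does. The ``demo_*`` names are
# illustrative only.

import torch
import torchvision.transforms as transforms

demo_batch = torch.rand(1, 3, 4, 5, device="cpu")   # [B x C x H x W] with B = 1
demo_pil = transforms.ToPILImage()(demo_batch.squeeze(0))  # drop the batch dim first
demo_back = transforms.ToTensor()(demo_pil)

print(demo_pil.size)  # PIL reports (width, height)
print((demo_back - demo_batch.squeeze(0)).abs().max().item())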
######################################################################
# Loss Functions
# --------------
# Content Loss
# ~~~~~~~~~~~~
#
# The content loss is a function that represents a weighted version of the
# content distance for an individual layer. The function takes the feature
# maps :math:`F_{XL}` of a layer :math:`L` in a network processing input :math:`X` and returns the
# weighted content distance :math:`w_{CL} \cdot D_C^L(X,C)` between the image :math:`X` and the
# content image :math:`C`. The feature maps of the content image (:math:`F_{CL}`) must be
# known by the function in order to calculate the content distance. We
# implement this function as a torch module with a constructor that takes
# :math:`F_{CL}` as an input. The distance is the squared difference
# :math:`\|F_{XL} - F_{CL}\|^2` averaged over all elements, that is, the mean squared error
# between the two sets of feature maps. It can be computed using ``F.mse_loss``
# (the functional form of ``nn.MSELoss``).
#
# We will add this content loss module directly after the convolution
# layer(s) that are being used to compute the content distance. This way,
# each time the network is fed an input image, the content losses will be
# computed at the desired layers and, thanks to autograd, all the
# gradients will be computed. Now, in order to make the content loss layer
# transparent, we must define a ``forward`` method that computes the content
# loss and then returns the layer's input. The computed loss is saved as an
# attribute (``self.loss``) of the module.
#

class ContentLoss(nn.Module):

    def __init__(self, target):
        super(ContentLoss, self).__init__()
        # we 'detach' the target content from the tree used
        # to dynamically compute the gradient: this is a fixed value,
        # not a variable, so no gradient should flow into it.
        self.target = target.detach()

    def forward(self, input):
        self.loss = F.mse_loss(input, self.target)
        return input

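######################################################################
# Returning ``input`` unchanged is what makes the module transparent. It also
# means the next layer receives the very tensor that ``F.mse_loss`` saved for
# the backward pass. The sketch below uses a hypothetical stand-in,
# ``DemoTransparentLoss``, and a toy 4-element "feature map". It shows why the
# model-building code later replaces VGG's in-place ``ReLU`` layers. An
# in-place operation after the loss module modifies that saved tensor, and
# autograd then refuses to run ``backward``.

import torch
import torch.nn as nn
import torch.nn.functional as F


class DemoTransparentLoss(nn.Module):
    """Hypothetical stand-in for ``ContentLoss``: records an MSE, returns its input."""

    def __init__(self, target):
        super().__init__()
        self.target = target.detach()

    def forward(self, input):
        self.loss = F.mse_loss(input, self.target)
        return input  # the same tensor object is handed to the next layer


def demo_backward_status(inplace):
    x = torch.tensor([-1.0, 0.5, 2.0, -0.3], requires_grad=True)
    probe = DemoTransparentLoss(torch.zeros(4))
    features = x * 2.0                       # stands in for a convolution output
    nn.ReLU(inplace=inplace)(probe(features))
    try:
        probe.loss.backward()
        return "ok"
    except RuntimeError:                     # "modified by an inplace operation"
        return "error"


print("out-of-place ReLU:", demo_backward_status(False))
print("in-place ReLU:", demo_backward_status(True))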
######################################################################
# .. note::
#    **Important detail**: although this module is named ``ContentLoss``, it
#    is not a true PyTorch Loss function. If you want to define your content
#    loss as a PyTorch Loss function, you have to create a PyTorch autograd function
#    to recompute/implement the gradient manually in the ``backward``
#    method.

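######################################################################
# The note above says that a true loss function with a hand-written gradient
# needs a ``torch.autograd.Function``. The sketch below shows one way to do
# that for the mean squared error used by ``ContentLoss``. The class name
# ``DemoMSEFunction`` is made up for this illustration, and the tutorial
# itself does not use it. ``torch.autograd.gradcheck`` compares the manual
# ``backward`` against finite differences and requires double-precision
# inputs.

import torch
import torch.nn.functional as F


class DemoMSEFunction(torch.autograd.Function):
    """Hypothetical mean squared error with an explicitly implemented gradient."""

    @staticmethod
    def forward(ctx, input, target):
        ctx.save_for_backward(input, target)
        return ((input - target) ** 2).mean()

    @staticmethod
    def backward(ctx, grad_output):
        input, target = ctx.saved_tensors
        # d/d(input) of mean((input - target)^2) is 2 * (input - target) / numel
        grad_input = grad_output * 2.0 * (input - target) / input.numel()
        return grad_input, None  # no gradient for the fixed target


demo_input = torch.randn(2, 3, dtype=torch.double, requires_grad=True)
demo_target = torch.randn(2, 3, dtype=torch.double)
print(torch.autograd.gradcheck(DemoMSEFunction.apply, (demo_input, demo_target)))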
######################################################################
# Style Loss
# ~~~~~~~~~~
#
# The style loss module is implemented similarly to the content loss
# module. It will act as a transparent layer in a
# network that computes the style loss of that layer. In order to
# calculate the style loss, we need to compute the Gram matrix :math:`G_{XL}`. A Gram
# matrix is the result of multiplying a given matrix by its transpose.
# In this application the given matrix is a reshaped version of
# the feature maps :math:`F_{XL}` of a layer :math:`L`. :math:`F_{XL}` is reshaped to form :math:`\hat{F}_{XL}`, a :math:`K \times N`
# matrix, where :math:`K` is the number of feature maps at layer :math:`L` and :math:`N` is the
# length of any vectorized feature map :math:`F_{XL}^k`. For example, the first row
# of :math:`\hat{F}_{XL}` corresponds to the first vectorized feature map :math:`F_{XL}^1`.
# Each entry of :math:`G_{XL} = \hat{F}_{XL} \hat{F}_{XL}^T` is therefore the inner
# product of two vectorized feature maps.
#
# Finally, the Gram matrix must be normalized by dividing each element by
# the total number of elements in the feature maps (:math:`K N` for a batch of one).
# This normalization counteracts the fact that :math:`\hat{F}_{XL}` matrices with a large :math:`N` dimension yield
# larger values in the Gram matrix. These larger values would cause the
# first layers (before pooling layers) to have a larger impact during the
# gradient descent. Style features tend to be in the deeper layers of the
# network, so this normalization step is crucial.
#

def gram_matrix(input):
    a, b, c, d = input.size()  # a=batch size(=1)
    # b=number of feature maps
    # (c,d)=dimensions of a f. map (N=c*d)

    features = input.view(a * b, c * d)  # resize F_XL into \hat F_XL

    G = torch.mm(features, features.t())  # compute the gram product

    # we 'normalize' the values of the gram matrix
    # by dividing by the total number of elements in the feature maps.
    return G.div(a * b * c * d)

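######################################################################
# Here is a small worked example of the ``gram_matrix`` computation. It uses
# a hypothetical feature tensor with one batch element, :math:`K = 2`
# feature maps and :math:`N = 1 \times 2 = 2` positions. The maps are
# :math:`F^1 = (1, 2)` and :math:`F^2 = (3, 4)`, so the unnormalized
# product is ``[[5, 11], [11, 25]]``. Dividing by the :math:`K N = 4`
# elements gives ``[[1.25, 2.75], [2.75, 6.25]]``.
#
# The second part shows why the normalization matters. For all-ones maps,
# unnormalized Gram entries grow with :math:`N`, while normalized entries
# equal :math:`1/K` whatever the map size. The helper ``demo_gram``
# re-implements the same formula so that this cell runs on its own.

import torch


def demo_gram(feats, normalize=True):
    a, b, c, d = feats.size()
    flat = feats.view(a * b, c * d)   # rows are vectorized feature maps
    G = flat @ flat.t()               # inner products between maps
    return G / (a * b * c * d) if normalize else G


demo_feats = torch.tensor([[[[1.0, 2.0]], [[3.0, 4.0]]]])  # shape [1, 2, 1, 2]
print(demo_gram(demo_feats))

for demo_side in (2, 8):              # N = 4, then N = 64
    demo_ones = torch.ones(1, 2, demo_side, demo_side)
    print(demo_side * demo_side,
          demo_gram(demo_ones, normalize=False)[0, 0].item(),
          demo_gram(demo_ones)[0, 0].item())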
######################################################################
# Now the style loss module looks almost exactly like the content loss
# module. The style distance is also computed using the mean square
# error between :math:`G_{XL}` and :math:`G_{SL}`.
#

class StyleLoss(nn.Module):

    def __init__(self, target_feature):
        super(StyleLoss, self).__init__()
        self.target = gram_matrix(target_feature).detach()

    def forward(self, input):
        G = gram_matrix(input)
        self.loss = F.mse_loss(G, self.target)
        return input

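######################################################################
# The Gram matrix sums over spatial positions, so it ignores *where* a
# feature occurs. The sketch below illustrates the consequence on random
# "feature maps", which stand in for real VGG activations. Permuting the
# spatial positions leaves the style distance at (numerically) zero. The
# content distance, which compares maps position by position, becomes large.
# The helper ``demo_style_gram`` repeats the normalized Gram formula so the
# cell is self-contained. The tensor sizes and the seed are arbitrary.

import torch
import torch.nn.functional as F


def demo_style_gram(feats):
    a, b, c, d = feats.size()
    flat = feats.reshape(a * b, c * d)
    return flat @ flat.t() / (a * b * c * d)


demo_gen = torch.Generator(device="cpu").manual_seed(0)
demo_maps = torch.rand(1, 4, 6, 6, generator=demo_gen, device="cpu")

# Apply the same random permutation of the 36 positions to every channel.
demo_perm = torch.randperm(36, generator=demo_gen, device="cpu")
demo_shuffled = demo_maps.reshape(1, 4, 36)[:, :, demo_perm].reshape(1, 4, 6, 6)

demo_style_dist = F.mse_loss(demo_style_gram(demo_shuffled), demo_style_gram(demo_maps))
demo_content_dist = F.mse_loss(demo_shuffled, demo_maps)
print(demo_style_dist.item(), demo_content_dist.item())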
######################################################################
# Importing the Model
# -------------------
#
# Now we need to import a pretrained neural network. We will use a 19-layer
# VGG network like the one used in the paper.
#
# PyTorch's implementation of VGG is a module divided into two child
# ``Sequential`` modules: ``features`` (containing convolution and pooling layers),
# and ``classifier`` (containing fully connected layers). We will use the
# ``features`` module because we need the output of the individual
# convolution layers to measure content and style loss. Some layers have
# different behavior during training than evaluation, so we must set the
# network to evaluation mode using ``.eval()``.
#

cnn = vgg19(weights=VGG19_Weights.DEFAULT).features.eval()

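######################################################################
# The sketch below inspects the ``features``/``classifier`` split and the
# layer sequence that the model-building code walks through. It builds the
# same architecture with ``weights=None``, which gives random weights and
# downloads nothing. It still allocates the full network, about 144 million
# parameters. It then counts the layer types in ``features`` and checks
# whether the ReLUs are in-place. The ``demo_*`` names are illustrative only.

from collections import Counter

from torchvision.models import vgg19

demo_vgg = vgg19(weights=None)   # random weights; only the structure matters here
print(type(demo_vgg.features).__name__, type(demo_vgg.classifier).__name__)

demo_counts = Counter(type(m).__name__ for m in demo_vgg.features.children())
print(demo_counts)               # convolution, ReLU and max-pooling layers

demo_relu_inplace = demo_vgg.features[1].inplace   # features[1] is the first ReLU
demo_is_training = demo_vgg.features.eval().training
print(demo_relu_inplace, demo_is_training)
del demo_vgg                     # free the randomly initialized copy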
######################################################################
# Additionally, VGG networks are trained on images with each channel
# normalized by mean=[0.485, 0.456, 0.406] and std=[0.229, 0.224, 0.225].
# We will use them to normalize the image before sending it into the network.
#

cnn_normalization_mean = torch.tensor([0.485, 0.456, 0.406])
cnn_normalization_std = torch.tensor([0.229, 0.224, 0.225])

# create a module to normalize input image so we can easily put it in a
# ``nn.Sequential``
class Normalization(nn.Module):
    def __init__(self, mean, std):
        super(Normalization, self).__init__()
        # .view the mean and std to make them [C x 1 x 1] so that they can
        # directly work with image Tensor of shape [B x C x H x W].
        # B is batch size. C is number of channels. H is height and W is width.
        # ``torch.as_tensor`` accepts lists or tensors and, unlike
        # ``torch.tensor``, does not warn when given an existing tensor.
        self.mean = torch.as_tensor(mean).view(-1, 1, 1)
        self.std = torch.as_tensor(std).view(-1, 1, 1)

    def forward(self, img):
        # normalize ``img``
        return (img - self.mean) / self.std

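######################################################################
# The ``.view(-1, 1, 1)`` reshape is what lets a 3-element mean and std
# broadcast against a ``[B x C x H x W]`` batch. The sketch below checks this
# with an arbitrary 2x3x4x5 batch in which every pixel equals its channel
# mean, so the whole batch should normalize to exactly zero. It reuses the
# tutorial's mean and std values but works on its own local tensors.

import torch

demo_mean = torch.tensor([0.485, 0.456, 0.406]).view(-1, 1, 1)  # [3, 1, 1]
demo_std = torch.tensor([0.229, 0.224, 0.225]).view(-1, 1, 1)

demo_norm_batch = demo_mean.expand(2, 3, 4, 5).clone()  # every pixel = channel mean
demo_norm_out = (demo_norm_batch - demo_mean) / demo_std  # broadcasts over B, H, W
print(demo_norm_out.shape, demo_norm_out.abs().max().item())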
######################################################################
# A ``Sequential`` module contains an ordered list of child modules. For
# instance, ``vgg19.features`` contains a sequence (``Conv2d``, ``ReLU``, ``MaxPool2d``,
# ``Conv2d``, ``ReLU``…) in order of depth. We need to add our
# content loss and style loss layers immediately after the convolution
# layer they are detecting. To do this we must create a new ``Sequential``
# module that has content loss and style loss modules correctly inserted.
#

# desired depth layers to compute style/content losses :
content_layers_default = ['conv_4']
style_layers_default = ['conv_1', 'conv_2', 'conv_3', 'conv_4', 'conv_5']

def get_style_model_and_losses(cnn, normalization_mean, normalization_std,
                               style_img, content_img,
                               content_layers=content_layers_default,
                               style_layers=style_layers_default):
    # normalization module
    normalization = Normalization(normalization_mean, normalization_std)

    # lists that give easy access to the content/style loss modules
    content_losses = []
    style_losses = []

    # assuming that ``cnn`` is a ``nn.Sequential``, so we make a new ``nn.Sequential``
    # to put in modules that are supposed to be activated sequentially
    model = nn.Sequential(normalization)

    i = 0  # increment every time we see a conv
    for layer in cnn.children():
        if isinstance(layer, nn.Conv2d):
            i += 1
            name = 'conv_{}'.format(i)
        elif isinstance(layer, nn.ReLU):
            name = 'relu_{}'.format(i)
            # The in-place version would overwrite the tensor that the preceding
            # ``ContentLoss``/``StyleLoss`` returned unchanged and used for its
            # loss, which breaks the backward pass. So we replace it with an
            # out-of-place ReLU here.
            layer = nn.ReLU(inplace=False)
        elif isinstance(layer, nn.MaxPool2d):
            name = 'pool_{}'.format(i)
        elif isinstance(layer, nn.BatchNorm2d):
            name = 'bn_{}'.format(i)
        else:
            raise RuntimeError('Unrecognized layer: {}'.format(layer.__class__.__name__))

        model.add_module(name, layer)

        if name in content_layers:
            # add content loss:
            target = model(content_img).detach()
            content_loss = ContentLoss(target)
            model.add_module("content_loss_{}".format(i), content_loss)
            content_losses.append(content_loss)

        if name in style_layers:
            # add style loss:
            target_feature = model(style_img).detach()
            style_loss = StyleLoss(target_feature)
            model.add_module("style_loss_{}".format(i), style_loss)
            style_losses.append(style_loss)

    # now we trim off the layers after the last content and style losses
    for i in range(len(model) - 1, -1, -1):
        if isinstance(model[i], (ContentLoss, StyleLoss)):
            break

    model = model[:(i + 1)]

    return model, style_losses, content_losses

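######################################################################
# The stripped-down sketch below runs the same walk on a tiny, randomly
# initialized network, so the naming, insertion and trimming steps can be
# inspected in isolation. ``DemoProbe`` is a hypothetical pass-through module
# that stands in for ``ContentLoss``/``StyleLoss``. The layer sizes and the
# probed layer name are arbitrary.

import torch.nn as nn


class DemoProbe(nn.Module):
    """Hypothetical transparent module: passes activations through unchanged."""

    def forward(self, x):
        return x


demo_cnn = nn.Sequential(
    nn.Conv2d(3, 4, 3, padding=1), nn.ReLU(inplace=True), nn.MaxPool2d(2),
    nn.Conv2d(4, 4, 3, padding=1), nn.ReLU(inplace=True),
    nn.Conv2d(4, 4, 3, padding=1), nn.ReLU(inplace=True),
)
demo_probe_after = {"conv_2"}  # hypothetical choice of probed layers

demo_model = nn.Sequential()
demo_i = 0
for demo_child in demo_cnn.children():
    if isinstance(demo_child, nn.Conv2d):
        demo_i += 1
        demo_name = f"conv_{demo_i}"
    elif isinstance(demo_child, nn.ReLU):
        demo_name = f"relu_{demo_i}"
        demo_child = nn.ReLU(inplace=False)   # same out-of-place swap as above
    else:                                     # only MaxPool2d remains in this toy net
        demo_name = f"pool_{demo_i}"
    demo_model.add_module(demo_name, demo_child)
    if demo_name in demo_probe_after:
        demo_model.add_module(f"probe_{demo_i}", DemoProbe())

# Trim everything after the last probe, as get_style_model_and_losses does.
demo_last = max(j for j, m in enumerate(demo_model) if isinstance(m, DemoProbe))
demo_model = demo_model[:demo_last + 1]
print([name for name, _ in demo_model.named_children()])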
######################################################################
# Next, we select the input image. You can use a copy of the content image
# or white noise.
#

input_img = content_img.clone()
# To start from white noise instead, use the following code:
#
# .. code-block:: python
#
#    input_img = torch.randn_like(content_img)

# add the original input image to the figure:
plt.figure()
imshow(input_img, title='Input Image')

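######################################################################
# If you try the white-noise start, note that ``torch.randn`` draws from a
# standard normal distribution, so many values fall outside the [0, 1]
# image range. The clamping step in the optimization loop below pulls them
# back. The sketch below shows this on a small stand-in tensor; the shape
# and seed are arbitrary. It also shows ``torch.rand_like``, which draws
# uniform noise already in [0, 1). That is an alternative the tutorial
# itself does not use.

import torch

demo_gen = torch.Generator(device="cpu").manual_seed(0)
demo_like = torch.zeros(1, 3, 8, 8, device="cpu")   # stand-in for content_img
demo_noise = torch.randn(demo_like.shape, generator=demo_gen, device="cpu")
print(demo_noise.min().item(), demo_noise.max().item())      # spills outside [0, 1]
print(demo_noise.clamp(0, 1).min().item(), demo_noise.clamp(0, 1).max().item())

demo_uniform = torch.rand_like(demo_like)            # uniform noise in [0, 1)
print(demo_uniform.min().item(), demo_uniform.max().item())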
######################################################################
# Gradient Descent
# ----------------
#
# As Leon Gatys, the first author of the algorithm's paper, suggested `here <https://discuss.pytorch.org/t/pytorch-tutorial-for-neural-transfert-of-artistic-style/336/20?u=alexis-jacq>`__, we will use
# the L-BFGS algorithm to run our gradient descent. Unlike training a network,
# we want to train the input image in order to minimize the content/style
# losses. We will create a PyTorch L-BFGS optimizer ``optim.LBFGS`` and pass
# our image to it as the tensor to optimize.
#

def get_input_optimizer(input_img):
    # pass the input image, not the model parameters, as the tensor to optimize
    # (``run_style_transfer`` sets ``requires_grad`` on it before calling this)
    optimizer = optim.LBFGS([input_img])
    return optimizer

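######################################################################
# What differs from ordinary training is that the tensor handed to the
# optimizer is the input itself. The toy sketch below does the same thing.
# ``optim.LBFGS`` adjusts a 3-element tensor, standing in for the image, to
# match a made-up target. L-BFGS requires the loss to be computed inside a
# closure, which the optimizer may call several times per ``step``. The
# ``demo_*`` names and the target values are illustrative only.

import torch
import torch.nn.functional as F

demo_lbfgs_target = torch.tensor([0.2, 0.5, 0.9])   # hypothetical target
demo_img = torch.zeros(3, requires_grad=True)       # the "image" being optimized
demo_lbfgs = torch.optim.LBFGS([demo_img])          # optimizer over the input only


def demo_closure():
    demo_lbfgs.zero_grad()
    loss = F.mse_loss(demo_img, demo_lbfgs_target)
    loss.backward()
    return loss


demo_lbfgs.step(demo_closure)
print(demo_img.detach())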
######################################################################
# Finally, we must define a function that performs the neural transfer. At
# each iteration, the network is fed the updated input and computes
# new losses. We call ``backward`` on the weighted sum of the losses stored by
# the loss modules to compute the gradient with respect to the input. The
# L-BFGS optimizer requires a "closure" function, which reevaluates the model
# and returns the loss.
#
# We still have one final constraint to address. The optimizer may push the
# input to values outside the 0 to 1 range of the image tensor. We address
# this by clamping the input values to the range 0 to 1 each time the
# network is run.
#

def run_style_transfer(cnn, normalization_mean, normalization_std,
                       content_img, style_img, input_img, num_steps=300,
                       style_weight=1000000, content_weight=1):
    """Run the style transfer."""
    print('Building the style transfer model..')
    model, style_losses, content_losses = get_style_model_and_losses(cnn,
        normalization_mean, normalization_std, style_img, content_img)

    # We want to optimize the input and not the model parameters so we
    # update all the requires_grad fields accordingly
    input_img.requires_grad_(True)
    # We also put the model in evaluation mode, so that specific layers
    # such as dropout or batch normalization layers behave correctly.
    model.eval()
    model.requires_grad_(False)

    optimizer = get_input_optimizer(input_img)

    print('Optimizing..')
    run = [0]  # a list, so the nested closure can update the counter
    while run[0] <= num_steps:

        def closure():
            # correct the values of updated input image
            with torch.no_grad():
                input_img.clamp_(0, 1)

            optimizer.zero_grad()
            model(input_img)
            style_score = 0
            content_score = 0

            for sl in style_losses:
                style_score += sl.loss
            for cl in content_losses:
                content_score += cl.loss

            style_score *= style_weight
            content_score *= content_weight

            loss = style_score + content_score
            loss.backward()

            run[0] += 1
            if run[0] % 50 == 0:
                print("run {}:".format(run[0]))
                print('Style Loss : {:.4f} Content Loss: {:.4f}'.format(
                    style_score.item(), content_score.item()))
                print()

            return loss

        optimizer.step(closure)

    # a last correction...
    with torch.no_grad():
        input_img.clamp_(0, 1)

    return input_img

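######################################################################
# One consequence of the closure design is worth spelling out. L-BFGS may
# evaluate the closure several times inside a single ``optimizer.step``
# call. The ``run`` counter above therefore counts closure evaluations
# (forward/backward passes), not optimizer steps, and ``num_steps`` bounds
# the former. The toy sketch below counts evaluations during one step. It
# uses a made-up 3-element "image" whose target partly lies outside
# [0, 1], and the same clamping pattern as ``run_style_transfer``. The
# ``demo_*`` names are illustrative only.

import torch
import torch.nn.functional as F

demo_goal = torch.tensor([-0.5, 0.3, 1.5])         # hypothetical, partly out of range
demo_pixels = torch.full((3,), 0.5, requires_grad=True)
demo_opt2 = torch.optim.LBFGS([demo_pixels])
demo_evals = [0]                                   # mutable counter, like ``run``


def demo_counting_closure():
    with torch.no_grad():
        demo_pixels.clamp_(0, 1)                   # keep values in the image range
    demo_opt2.zero_grad()
    loss = F.mse_loss(demo_pixels, demo_goal)
    loss.backward()
    demo_evals[0] += 1
    return loss


demo_opt2.step(demo_counting_closure)
with torch.no_grad():
    demo_pixels.clamp_(0, 1)                       # final correction, as above

print("closure evaluations in one step:", demo_evals[0])
print(demo_pixels.detach())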
######################################################################
# Finally, we can run the algorithm.
#

output = run_style_transfer(cnn, cnn_normalization_mean, cnn_normalization_std,
                            content_img, style_img, input_img)

plt.figure()
imshow(output, title='Output Image')

# sphinx_gallery_thumbnail_number = 4
plt.ioff()
plt.show()