Path: blob/master/examples/generative/neural_style_transfer.py
"""1Title: Neural style transfer2Author: [fchollet](https://twitter.com/fchollet)3Date created: 2016/01/114Last modified: 2020/05/025Description: Transferring the style of a reference image to target image using gradient descent.6Accelerator: GPU7"""89"""10## Introduction1112Style transfer consists in generating an image13with the same "content" as a base image, but with the14"style" of a different picture (typically artistic).15This is achieved through the optimization of a loss function16that has 3 components: "style loss", "content loss",17and "total variation loss":1819- The total variation loss imposes local spatial continuity between20the pixels of the combination image, giving it visual coherence.21- The style loss is where the deep learning keeps in --that one is defined22using a deep convolutional neural network. Precisely, it consists in a sum of23L2 distances between the Gram matrices of the representations of24the base image and the style reference image, extracted from25different layers of a convnet (trained on ImageNet). The general idea26is to capture color/texture information at different spatial27scales (fairly large scales --defined by the depth of the layer considered).28- The content loss is a L2 distance between the features of the base29image (extracted from a deep layer) and the features of the combination image,30keeping the generated image close enough to the original one.3132**Reference:** [A Neural Algorithm of Artistic Style](33http://arxiv.org/abs/1508.06576)34"""3536"""37## Setup38"""39import os4041os.environ["KERAS_BACKEND"] = "tensorflow"4243import keras44import numpy as np45import tensorflow as tf46from keras.applications import vgg194748base_image_path = keras.utils.get_file("paris.jpg", "https://i.imgur.com/F28w3Ac.jpg")49style_reference_image_path = keras.utils.get_file(50"starry_night.jpg", "https://i.imgur.com/9ooB60I.jpg"51)52result_prefix = "paris_generated"5354# Weights of the different loss components55total_variation_weight = 1e-656style_weight = 1e-657content_weight = 2.5e-85859# Dimensions of the generated picture.60width, height = keras.utils.load_img(base_image_path).size61img_nrows = 40062img_ncols = int(width * img_nrows / height)6364"""65## Let's take a look at our base (content) image and our style reference image66"""6768from IPython.display import Image, display6970display(Image(base_image_path))71display(Image(style_reference_image_path))7273"""74## Image preprocessing / deprocessing utilities75"""767778def preprocess_image(image_path):79# Util function to open, resize and format pictures into appropriate tensors80img = keras.utils.load_img(image_path, target_size=(img_nrows, img_ncols))81img = keras.utils.img_to_array(img)82img = np.expand_dims(img, axis=0)83img = vgg19.preprocess_input(img)84return tf.convert_to_tensor(img)858687def deprocess_image(x):88# Util function to convert a tensor into a valid image89x = x.reshape((img_nrows, img_ncols, 3))90# Remove zero-center by mean pixel91x[:, :, 0] += 103.93992x[:, :, 1] += 116.77993x[:, :, 2] += 123.6894# 'BGR'->'RGB'95x = x[:, :, ::-1]96x = np.clip(x, 0, 255).astype("uint8")97return x9899100"""101## Compute the style transfer loss102103First, we need to define 4 utility functions:104105- `gram_matrix` (used to compute the style loss)106- The `style_loss` function, which keeps the generated image close to the local textures107of the style reference image108- The `content_loss` function, which keeps the high-level representation of the109generated image close to that of the base 
"""
## Setup
"""

import os

os.environ["KERAS_BACKEND"] = "tensorflow"

import keras
import numpy as np
import tensorflow as tf
from keras.applications import vgg19

base_image_path = keras.utils.get_file("paris.jpg", "https://i.imgur.com/F28w3Ac.jpg")
style_reference_image_path = keras.utils.get_file(
    "starry_night.jpg", "https://i.imgur.com/9ooB60I.jpg"
)
result_prefix = "paris_generated"

# Weights of the different loss components
total_variation_weight = 1e-6
style_weight = 1e-6
content_weight = 2.5e-8

# Dimensions of the generated picture.
width, height = keras.utils.load_img(base_image_path).size
img_nrows = 400
img_ncols = int(width * img_nrows / height)

"""
## Let's take a look at our base (content) image and our style reference image
"""

from IPython.display import Image, display

display(Image(base_image_path))
display(Image(style_reference_image_path))

"""
## Image preprocessing / deprocessing utilities
"""


def preprocess_image(image_path):
    # Util function to open, resize and format pictures into appropriate tensors
    img = keras.utils.load_img(image_path, target_size=(img_nrows, img_ncols))
    img = keras.utils.img_to_array(img)
    img = np.expand_dims(img, axis=0)
    img = vgg19.preprocess_input(img)
    return tf.convert_to_tensor(img)


def deprocess_image(x):
    # Util function to convert a tensor into a valid image
    x = x.reshape((img_nrows, img_ncols, 3))
    # Remove zero-center by mean pixel
    x[:, :, 0] += 103.939
    x[:, :, 1] += 116.779
    x[:, :, 2] += 123.68
    # 'BGR'->'RGB'
    x = x[:, :, ::-1]
    x = np.clip(x, 0, 255).astype("uint8")
    return x


"""
## Compute the style transfer loss

First, we need to define 4 utility functions:

- `gram_matrix` (used to compute the style loss)
- The `style_loss` function, which keeps the generated image close to the local textures
of the style reference image
- The `content_loss` function, which keeps the high-level representation of the
generated image close to that of the base image
- The `total_variation_loss` function, a regularization loss which keeps the generated
image locally-coherent
"""

# The gram matrix of an image tensor (feature-wise outer product)


def gram_matrix(x):
    x = tf.transpose(x, (2, 0, 1))
    features = tf.reshape(x, (tf.shape(x)[0], -1))
    gram = tf.matmul(features, tf.transpose(features))
    return gram


# The "style loss" is designed to maintain
# the style of the reference image in the generated image.
# It is based on the gram matrices (which capture style) of
# feature maps from the style reference image
# and from the generated image


def style_loss(style, combination):
    S = gram_matrix(style)
    C = gram_matrix(combination)
    channels = 3
    size = img_nrows * img_ncols
    return tf.reduce_sum(tf.square(S - C)) / (4.0 * (channels**2) * (size**2))


# An auxiliary loss function
# designed to maintain the "content" of the
# base image in the generated image


def content_loss(base, combination):
    return tf.reduce_sum(tf.square(combination - base))


# The 3rd loss function, total variation loss,
# designed to keep the generated image locally coherent


def total_variation_loss(x):
    a = tf.square(
        x[:, : img_nrows - 1, : img_ncols - 1, :] - x[:, 1:, : img_ncols - 1, :]
    )
    b = tf.square(
        x[:, : img_nrows - 1, : img_ncols - 1, :] - x[:, : img_nrows - 1, 1:, :]
    )
    return tf.reduce_sum(tf.pow(a + b, 1.25))
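"""
As a quick sanity check (an illustrative addition, not part of the original
example), we can exercise `gram_matrix` and `style_loss` on a random feature map:
the Gram matrix of a `(height, width, channels)` tensor has shape
`(channels, channels)`, and the style loss of a feature map against itself is zero.
"""

# `random_features` is a throwaway tensor introduced here for illustration only.
random_features = tf.random.uniform((img_nrows, img_ncols, 3))
print(gram_matrix(random_features).shape)  # TensorShape([3, 3])
print(float(style_loss(random_features, random_features)))  # 0.0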
"""
Next, let's create a feature extraction model that retrieves the intermediate activations
of VGG19 (as a dict, by name).
"""

# Build a VGG19 model loaded with pre-trained ImageNet weights
model = vgg19.VGG19(weights="imagenet", include_top=False)

# Get the symbolic outputs of each "key" layer (we gave them unique names).
outputs_dict = dict([(layer.name, layer.output) for layer in model.layers])

# Set up a model that returns the activation values for every layer in
# VGG19 (as a dict).
feature_extractor = keras.Model(inputs=model.inputs, outputs=outputs_dict)

"""
Finally, here's the code that computes the style transfer loss.
"""

# List of layers to use for the style loss.
style_layer_names = [
    "block1_conv1",
    "block2_conv1",
    "block3_conv1",
    "block4_conv1",
    "block5_conv1",
]
# The layer to use for the content loss.
content_layer_name = "block5_conv2"


def compute_loss(combination_image, base_image, style_reference_image):
    input_tensor = tf.concat(
        [base_image, style_reference_image, combination_image], axis=0
    )
    features = feature_extractor(input_tensor)

    # Initialize the loss
    loss = tf.zeros(shape=())

    # Add content loss
    layer_features = features[content_layer_name]
    base_image_features = layer_features[0, :, :, :]
    combination_features = layer_features[2, :, :, :]
    loss = loss + content_weight * content_loss(
        base_image_features, combination_features
    )
    # Add style loss
    for layer_name in style_layer_names:
        layer_features = features[layer_name]
        style_reference_features = layer_features[1, :, :, :]
        combination_features = layer_features[2, :, :, :]
        sl = style_loss(style_reference_features, combination_features)
        loss += (style_weight / len(style_layer_names)) * sl

    # Add total variation loss
    loss += total_variation_weight * total_variation_loss(combination_image)
    return loss


"""
## Add a tf.function decorator to loss & gradient computation

Compiling it makes it fast.
"""


@tf.function
def compute_loss_and_grads(combination_image, base_image, style_reference_image):
    with tf.GradientTape() as tape:
        loss = compute_loss(combination_image, base_image, style_reference_image)
    grads = tape.gradient(loss, combination_image)
    return loss, grads


"""
## The training loop

Repeatedly run vanilla gradient descent steps to minimize the loss, and save the
resulting image every 100 iterations.

We decay the learning rate by 0.96 every 100 steps (the `ExponentialDecay` schedule
below applies the decay smoothly, scaling the initial rate of 100.0 by
`0.96 ** (step / 100)`).
"""

optimizer = keras.optimizers.SGD(
    keras.optimizers.schedules.ExponentialDecay(
        initial_learning_rate=100.0, decay_steps=100, decay_rate=0.96
    )
)

base_image = preprocess_image(base_image_path)
style_reference_image = preprocess_image(style_reference_image_path)
combination_image = tf.Variable(preprocess_image(base_image_path))

iterations = 4000
for i in range(1, iterations + 1):
    loss, grads = compute_loss_and_grads(
        combination_image, base_image, style_reference_image
    )
    optimizer.apply_gradients([(grads, combination_image)])
    if i % 100 == 0:
        print("Iteration %d: loss=%.2f" % (i, loss))
        img = deprocess_image(combination_image.numpy())
        fname = result_prefix + "_at_iteration_%d.png" % i
        keras.utils.save_img(fname, img)

"""
After 4000 iterations, you get the following result:
"""

display(Image(result_prefix + "_at_iteration_4000.png"))
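"""
Since the training loop saves a snapshot every 100 iterations, you can also
display intermediate checkpoints to watch the stylization progress (an
illustrative addition, assuming the loop above has already written these files):
"""

# Show the combination image early, midway, and at the end of optimization.
for step in (100, 1000, 4000):
    display(Image(result_prefix + "_at_iteration_%d.png" % step))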