"""1Title: Deep Dream2Author: [fchollet](https://twitter.com/fchollet)3Date created: 2016/01/134Last modified: 2020/05/025Description: Generating Deep Dreams with Keras.6Accelerator: GPU7"""89"""10## Introduction1112"Deep dream" is an image-filtering technique which consists of taking an image13classification model, and running gradient ascent over an input image to14try to maximize the activations of specific layers (and sometimes, specific units in15specific layers) for this input. It produces hallucination-like visuals.1617It was first introduced by Alexander Mordvintsev from Google in July 2015.1819Process:2021- Load the original image.22- Define a number of processing scales ("octaves"),23from smallest to largest.24- Resize the original image to the smallest scale.25- For every scale, starting with the smallest (i.e. current one):26- Run gradient ascent27- Upscale image to the next scale28- Reinject the detail that was lost at upscaling time29- Stop when we are back to the original size.30To obtain the detail lost during upscaling, we simply31take the original image, shrink it down, upscale it,32and compare the result to the (resized) original image.33"""3435"""36## Setup37"""38import os3940os.environ["KERAS_BACKEND"] = "tensorflow"4142import numpy as np43import tensorflow as tf44import keras45from keras.applications import inception_v34647base_image_path = keras.utils.get_file("sky.jpg", "https://i.imgur.com/aGBdQyK.jpg")48result_prefix = "sky_dream"4950# These are the names of the layers51# for which we try to maximize activation,52# as well as their weight in the final loss53# we try to maximize.54# You can tweak these setting to obtain new visual effects.55layer_settings = {56"mixed4": 1.0,57"mixed5": 1.5,58"mixed6": 2.0,59"mixed7": 2.5,60}6162# Playing with these hyperparameters will also allow you to achieve new effects63step = 0.01 # Gradient ascent step size64num_octave = 3 # Number of scales at which to run gradient ascent65octave_scale = 1.4 # Size ratio between scales66iterations = 20 # Number of ascent steps per scale67max_loss = 15.06869"""70This is our base image:71"""7273from IPython.display import Image, display7475display(Image(base_image_path))7677"""78Let's set up some image preprocessing/deprocessing utilities:79"""808182def preprocess_image(image_path):83# Util function to open, resize and format pictures84# into appropriate arrays.85img = keras.utils.load_img(image_path)86img = keras.utils.img_to_array(img)87img = np.expand_dims(img, axis=0)88img = inception_v3.preprocess_input(img)89return img909192def deprocess_image(x):93# Util function to convert a NumPy array into a valid image.94x = x.reshape((x.shape[1], x.shape[2], 3))95# Undo inception v3 preprocessing96x /= 2.097x += 0.598x *= 255.099# Convert to uint8 and clip to the valid range [0, 255]100x = np.clip(x, 0, 255).astype("uint8")101return x102103104"""105## Compute the Deep Dream loss106107First, build a feature extraction model to retrieve the activations of our target layers108given an input image.109"""110111# Build an InceptionV3 model loaded with pre-trained ImageNet weights112model = inception_v3.InceptionV3(weights="imagenet", include_top=False)113114# Get the symbolic outputs of each "key" layer (we gave them unique names).115outputs_dict = dict(116[117(layer.name, layer.output)118for layer in [model.get_layer(name) for name in layer_settings.keys()]119]120)121122# Set up a model that returns the activation values for every target layer123# (as a dict)124feature_extractor = 

"""
The actual loss computation is very simple:
"""


def compute_loss(input_image):
    features = feature_extractor(input_image)
    # Initialize the loss
    loss = tf.zeros(shape=())
    for name in features.keys():
        coeff = layer_settings[name]
        activation = features[name]
        # We avoid border artifacts by only involving non-border pixels in the loss.
        scaling = tf.reduce_prod(tf.cast(tf.shape(activation), "float32"))
        loss += coeff * tf.reduce_sum(tf.square(activation[:, 2:-2, 2:-2, :])) / scaling
    return loss


"""
## Set up the gradient ascent loop for one octave
"""


@tf.function
def gradient_ascent_step(img, learning_rate):
    with tf.GradientTape() as tape:
        tape.watch(img)
        loss = compute_loss(img)
    # Compute gradients.
    grads = tape.gradient(loss, img)
    # Normalize gradients.
    grads /= tf.maximum(tf.reduce_mean(tf.abs(grads)), 1e-6)
    img += learning_rate * grads
    return loss, img


def gradient_ascent_loop(img, iterations, learning_rate, max_loss=None):
    for i in range(iterations):
        loss, img = gradient_ascent_step(img, learning_rate)
        if max_loss is not None and loss > max_loss:
            break
        print("... Loss value at step %d: %.2f" % (i, loss))
    return img


"""
## Run the gradient ascent loop, iterating over different octaves
"""

original_img = preprocess_image(base_image_path)
original_shape = original_img.shape[1:3]

successive_shapes = [original_shape]
for i in range(1, num_octave):
    shape = tuple([int(dim / (octave_scale**i)) for dim in original_shape])
    successive_shapes.append(shape)
successive_shapes = successive_shapes[::-1]
shrunk_original_img = tf.image.resize(original_img, successive_shapes[0])

img = tf.identity(original_img)  # Make a copy
for i, shape in enumerate(successive_shapes):
    print("Processing octave %d with shape %s" % (i, shape))
    img = tf.image.resize(img, shape)
    img = gradient_ascent_loop(
        img, iterations=iterations, learning_rate=step, max_loss=max_loss
    )
    # Reinject the detail that was lost while working at the smaller scale.
    upscaled_shrunk_original_img = tf.image.resize(shrunk_original_img, shape)
    same_size_original = tf.image.resize(original_img, shape)
    lost_detail = same_size_original - upscaled_shrunk_original_img

    img += lost_detail
    shrunk_original_img = tf.image.resize(original_img, shape)

keras.utils.save_img(result_prefix + ".png", deprocess_image(img.numpy()))

"""
Display the result.

You can use the trained model hosted on [Hugging Face Hub](https://huggingface.co/keras-io/deep-dream)
and try the demo on [Hugging Face Spaces](https://huggingface.co/spaces/keras-io/deep-dream).
"""

display(Image(result_prefix + ".png"))
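
"""
The whole pipeline above can be reused on other inputs. Below is a minimal
sketch that repackages the octave loop into a helper function; the name
`run_deep_dream` is ours for illustration, not part of the Keras API:
"""


def run_deep_dream(image_path, output_path):
    # Hypothetical helper: the same octave loop as above, wrapped for reuse.
    original_img = preprocess_image(image_path)
    original_shape = original_img.shape[1:3]
    shapes = [original_shape] + [
        tuple(int(dim / (octave_scale**i)) for dim in original_shape)
        for i in range(1, num_octave)
    ]
    shapes = shapes[::-1]  # Process from the smallest scale to the largest
    shrunk = tf.image.resize(original_img, shapes[0])
    img = tf.identity(original_img)
    for shape in shapes:
        img = tf.image.resize(img, shape)
        img = gradient_ascent_loop(
            img, iterations=iterations, learning_rate=step, max_loss=max_loss
        )
        # Reinject the detail lost while working at the smaller scale.
        img += tf.image.resize(original_img, shape) - tf.image.resize(shrunk, shape)
        shrunk = tf.image.resize(original_img, shape)
    keras.utils.save_img(output_path, deprocess_image(img.numpy()))


# For example: run_deep_dream(base_image_path, "sky_dream_2.png")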