Path: blob/master/examples/generative/random_walks_with_stable_diffusion_3.py
3507 views
"""1Title: A walk through latent space with Stable Diffusion 32Authors: [Hongyu Chiu](https://github.com/james77777778), Ian Stenbit, [fchollet](https://twitter.com/fchollet), [lukewood](https://twitter.com/luke_wood_ml)3Date created: 2024/11/114Last modified: 2024/11/115Description: Explore the latent manifold of Stable Diffusion 3.6Accelerator: GPU7"""89"""10## Overview1112Generative image models learn a "latent manifold" of the visual world: a13low-dimensional vector space where each point maps to an image. Going from such14a point on the manifold back to a displayable image is called "decoding" -- in15the Stable Diffusion model, this is handled by the "decoder" model.16171819This latent manifold of images is continuous and interpolative, meaning that:20211. Moving a little on the manifold only changes the corresponding image a22little (continuity).232. For any two points A and B on the manifold (i.e. any two images), it is24possible to move from A to B via a path where each intermediate point is also on25the manifold (i.e. is also a valid image). Intermediate points would be called26"interpolations" between the two starting images.2728Stable Diffusion isn't just an image model, though, it's also a natural language29model. It has two latent spaces: the image representation space learned by the30encoder used during training, and the prompt latent space which is learned using31a combination of pretraining and training-time fine-tuning.3233_Latent space walking_, or _latent space exploration_, is the process of34sampling a point in latent space and incrementally changing the latent35representation. Its most common application is generating animations where each36sampled point is fed to the decoder and is stored as a frame in the final37animation.38For high-quality latent representations, this produces coherent-looking39animations. These animations can provide insight into the feature map of the40latent space, and can ultimately lead to improvements in the training process.41One such GIF is displayed below:42434445In this guide, we will show how to take advantage of the TextToImage API in46KerasHub to perform prompt interpolation and circular walks through Stable47Diffusion 3's visual latent manifold, as well as through the text encoder's48latent manifold.4950This guide assumes the reader has a high-level understanding of Stable51Diffusion 3. If you haven't already, you should start by reading the52[Stable Diffusion 3 in KerasHub](53https://keras.io/guides/keras_hub/stable_diffusion_3_in_keras_hub/).5455It is also worth noting that the preset "stable_diffusion_3_medium" excludes the56T5XXL text encoder, as it requires significantly more GPU memory. The performace57degradation is negligible in most cases. The weights, including T5XXL, will be58available on KerasHub soon.59"""6061"""shell62# Use the latest version of KerasHub63!pip install -Uq git+https://github.com/keras-team/keras-hub.git64"""6566import math6768import keras69import keras_hub70import matplotlib.pyplot as plt71from keras import ops72from keras import random73from PIL import Image7475height, width = 512, 51276num_steps = 2877guidance_scale = 7.078dtype = "float16"7980# Instantiate the Stable Diffusion 3 model and the preprocessor81backbone = keras_hub.models.StableDiffusion3Backbone.from_preset(82"stable_diffusion_3_medium", image_shape=(height, width, 3), dtype=dtype83)84preprocessor = keras_hub.models.StableDiffusion3TextToImagePreprocessor.from_preset(85"stable_diffusion_3_medium"86)8788"""89Let's define some helper functions for this example.90"""919293def get_text_embeddings(prompt):94"""Get the text embeddings for a given prompt."""95token_ids = preprocessor.generate_preprocess([prompt])96negative_token_ids = preprocessor.generate_preprocess([""])97(98positive_embeddings,99negative_embeddings,100positive_pooled_embeddings,101negative_pooled_embeddings,102) = backbone.encode_text_step(token_ids, negative_token_ids)103return (104positive_embeddings,105negative_embeddings,106positive_pooled_embeddings,107negative_pooled_embeddings,108)109110111def decode_to_images(x, height, width):112"""Concatenate and normalize the images to uint8 dtype."""113x = ops.concatenate(x, axis=0)114x = ops.reshape(x, (-1, height, width, 3))115x = ops.clip(ops.divide(ops.add(x, 1.0), 2.0), 0.0, 1.0)116return ops.cast(ops.round(ops.multiply(x, 255.0)), "uint8")117118119def generate_with_latents_and_embeddings(120latents, embeddings, num_steps, guidance_scale121):122"""Generate images from latents and text embeddings."""123124def body_fun(step, latents):125return backbone.denoise_step(126latents,127embeddings,128step,129num_steps,130guidance_scale,131)132133latents = ops.fori_loop(0, num_steps, body_fun, latents)134return backbone.decode_step(latents)135136137def export_as_gif(filename, images, frames_per_second=10, no_rubber_band=False):138if not no_rubber_band:139images += images[2:-1][::-1] # Makes a rubber band: A->B->A140images[0].save(141filename,142save_all=True,143append_images=images[1:],144duration=1000 // frames_per_second,145loop=0,146)147148149"""150We are going to generate images using custom latents and embeddings, so we need151to implement the `generate_with_latents_and_embeddings` function. Additionally,152it is important to compile this function to speed up the generation process.153"""154155if keras.config.backend() == "torch":156import torch157158@torch.no_grad()159def wrapped_function(*args, **kwargs):160return generate_with_latents_and_embeddings(*args, **kwargs)161162generate_function = wrapped_function163elif keras.config.backend() == "tensorflow":164import tensorflow as tf165166generate_function = tf.function(167generate_with_latents_and_embeddings, jit_compile=True168)169elif keras.config.backend() == "jax":170import itertools171172import jax173174@jax.jit175def compiled_function(state, *args, **kwargs):176(trainable_variables, non_trainable_variables) = state177mapping = itertools.chain(178zip(backbone.trainable_variables, trainable_variables),179zip(backbone.non_trainable_variables, non_trainable_variables),180)181with keras.StatelessScope(state_mapping=mapping):182return generate_with_latents_and_embeddings(*args, **kwargs)183184def wrapped_function(*args, **kwargs):185state = (186[v.value for v in backbone.trainable_variables],187[v.value for v in backbone.non_trainable_variables],188)189return compiled_function(state, *args, **kwargs)190191generate_function = wrapped_function192193194"""195## Interpolating between text prompts196197In Stable Diffusion 3, a text prompt is encoded into multiple vectors, which are198then used to guide the diffusion process. These latent encoding vectors have199shapes of 154x4096 and 2048 for both the positive and negative prompts - quite200large! When we input a text prompt into Stable Diffusion 3, we generate images201from a single point on this latent manifold.202203To explore more of this manifold, we can interpolate between two text encodings204and generate images at those interpolated points:205"""206207prompt_1 = "A cute dog in a beautiful field of lavander colorful flowers "208prompt_1 += "everywhere, perfect lighting, leica summicron 35mm f2.0, kodak "209prompt_1 += "portra 400, film grain"210prompt_2 = prompt_1.replace("dog", "cat")211interpolation_steps = 5212213encoding_1 = get_text_embeddings(prompt_1)214encoding_2 = get_text_embeddings(prompt_2)215216217# Show the size of the latent manifold218print(f"Positive embeddings shape: {encoding_1[0].shape}")219print(f"Negative embeddings shape: {encoding_1[1].shape}")220print(f"Positive pooled embeddings shape: {encoding_1[2].shape}")221print(f"Negative pooled embeddings shape: {encoding_1[3].shape}")222223224"""225In this example, we want to use Spherical Linear Interpolation (slerp) instead226of simple linear interpolation. Slerp is commonly used in computer graphics to227animate rotations smoothly and can also be applied to interpolate between228high-dimensional data points, such as latent vectors used in generative models.229230The source is from Andrej Karpathy's gist:231[https://gist.github.com/karpathy/00103b0037c5aaea32fe1da1af553355](https://gist.github.com/karpathy/00103b0037c5aaea32fe1da1af553355).232233A more detailed explanation of this method can be found at:234[https://en.wikipedia.org/wiki/Slerp](https://en.wikipedia.org/wiki/Slerp).235"""236237238def slerp(v1, v2, num):239ori_dtype = v1.dtype240# Cast to float32 for numerical stability.241v1 = ops.cast(v1, "float32")242v2 = ops.cast(v2, "float32")243244def interpolation(t, v1, v2, dot_threshold=0.9995):245"""helper function to spherically interpolate two arrays."""246dot = ops.sum(247v1 * v2 / (ops.linalg.norm(ops.ravel(v1)) * ops.linalg.norm(ops.ravel(v2)))248)249if ops.abs(dot) > dot_threshold:250v2 = (1 - t) * v1 + t * v2251else:252theta_0 = ops.arccos(dot)253sin_theta_0 = ops.sin(theta_0)254theta_t = theta_0 * t255sin_theta_t = ops.sin(theta_t)256s0 = ops.sin(theta_0 - theta_t) / sin_theta_0257s1 = sin_theta_t / sin_theta_0258v2 = s0 * v1 + s1 * v2259return v2260261t = ops.linspace(0, 1, num)262interpolated = ops.stack([interpolation(t[i], v1, v2) for i in range(num)], axis=0)263return ops.cast(interpolated, ori_dtype)264265266interpolated_positive_embeddings = slerp(267encoding_1[0], encoding_2[0], interpolation_steps268)269interpolated_positive_pooled_embeddings = slerp(270encoding_1[2], encoding_2[2], interpolation_steps271)272# We don't use negative prompts in this example, so there’s no need to273# interpolate them.274negative_embeddings = encoding_1[1]275negative_pooled_embeddings = encoding_1[3]276277278"""279Once we've interpolated the encodings, we can generate images from each point.280Note that in order to maintain some stability between the resulting images we281keep the diffusion latents constant between images.282"""283284latents = random.normal((1, height // 8, width // 8, 16), seed=42)285286images = []287progbar = keras.utils.Progbar(interpolation_steps)288for i in range(interpolation_steps):289images.append(290generate_function(291latents,292(293interpolated_positive_embeddings[i],294negative_embeddings,295interpolated_positive_pooled_embeddings[i],296negative_pooled_embeddings,297),298ops.convert_to_tensor(num_steps),299ops.convert_to_tensor(guidance_scale),300)301)302progbar.update(i + 1, finalize=i == interpolation_steps - 1)303304"""305Now that we've generated some interpolated images, let's take a look at them!306307Throughout this tutorial, we're going to export sequences of images as gifs so308that they can be easily viewed with some temporal context. For sequences of309images where the first and last images don't match conceptually, we rubber-band310the gif.311312If you're running in Colab, you can view your own GIFs by running:313314```315from IPython.display import Image as IImage316IImage("dog_to_cat_5.gif")317```318"""319320images = ops.convert_to_numpy(decode_to_images(images, height, width))321export_as_gif(322"dog_to_cat_5.gif",323[Image.fromarray(image) for image in images],324frames_per_second=2,325)326327"""328The results may seem surprising. Generally, interpolating between prompts329produces coherent looking images, and often demonstrates a progressive concept330shift between the contents of the two prompts. This is indicative of a high331quality representation space, that closely mirrors the natural structure of the332visual world.333334To best visualize this, we should do a much more fine-grained interpolation,335using more steps.336"""337338interpolation_steps = 64339batch_size = 4340batches = interpolation_steps // batch_size341342interpolated_positive_embeddings = slerp(343encoding_1[0], encoding_2[0], interpolation_steps344)345interpolated_positive_pooled_embeddings = slerp(346encoding_1[2], encoding_2[2], interpolation_steps347)348positive_embeddings_shape = ops.shape(encoding_1[0])349positive_pooled_embeddings_shape = ops.shape(encoding_1[2])350interpolated_positive_embeddings = ops.reshape(351interpolated_positive_embeddings,352(353batches,354batch_size,355positive_embeddings_shape[-2],356positive_embeddings_shape[-1],357),358)359interpolated_positive_pooled_embeddings = ops.reshape(360interpolated_positive_pooled_embeddings,361(batches, batch_size, positive_pooled_embeddings_shape[-1]),362)363negative_embeddings = ops.tile(encoding_1[1], (batch_size, 1, 1))364negative_pooled_embeddings = ops.tile(encoding_1[3], (batch_size, 1))365366latents = random.normal((1, height // 8, width // 8, 16), seed=42)367latents = ops.tile(latents, (batch_size, 1, 1, 1))368369images = []370progbar = keras.utils.Progbar(batches)371for i in range(batches):372images.append(373generate_function(374latents,375(376interpolated_positive_embeddings[i],377negative_embeddings,378interpolated_positive_pooled_embeddings[i],379negative_pooled_embeddings,380),381ops.convert_to_tensor(num_steps),382ops.convert_to_tensor(guidance_scale),383)384)385progbar.update(i + 1, finalize=i == batches - 1)386387images = ops.convert_to_numpy(decode_to_images(images, height, width))388export_as_gif(389"dog_to_cat_64.gif",390[Image.fromarray(image) for image in images],391frames_per_second=2,392)393394"""395The resulting gif shows a much clearer and more coherent shift between the two396prompts. Try out some prompts of your own and experiment!397398We can even extend this concept for more than one image. For example, we can399interpolate between four prompts:400"""401402prompt_1 = "A watercolor painting of a Golden Retriever at the beach"403prompt_2 = "A still life DSLR photo of a bowl of fruit"404prompt_3 = "The eiffel tower in the style of starry night"405prompt_4 = "An architectural sketch of a skyscraper"406407interpolation_steps = 8408batch_size = 4409batches = (interpolation_steps**2) // batch_size410411encoding_1 = get_text_embeddings(prompt_1)412encoding_2 = get_text_embeddings(prompt_2)413encoding_3 = get_text_embeddings(prompt_3)414encoding_4 = get_text_embeddings(prompt_4)415416positive_embeddings_shape = ops.shape(encoding_1[0])417positive_pooled_embeddings_shape = ops.shape(encoding_1[2])418interpolated_positive_embeddings_12 = slerp(419encoding_1[0], encoding_2[0], interpolation_steps420)421interpolated_positive_embeddings_34 = slerp(422encoding_3[0], encoding_4[0], interpolation_steps423)424interpolated_positive_embeddings = slerp(425interpolated_positive_embeddings_12,426interpolated_positive_embeddings_34,427interpolation_steps,428)429interpolated_positive_embeddings = ops.reshape(430interpolated_positive_embeddings,431(432batches,433batch_size,434positive_embeddings_shape[-2],435positive_embeddings_shape[-1],436),437)438interpolated_positive_pooled_embeddings_12 = slerp(439encoding_1[2], encoding_2[2], interpolation_steps440)441interpolated_positive_pooled_embeddings_34 = slerp(442encoding_3[2], encoding_4[2], interpolation_steps443)444interpolated_positive_pooled_embeddings = slerp(445interpolated_positive_pooled_embeddings_12,446interpolated_positive_pooled_embeddings_34,447interpolation_steps,448)449interpolated_positive_pooled_embeddings = ops.reshape(450interpolated_positive_pooled_embeddings,451(batches, batch_size, positive_pooled_embeddings_shape[-1]),452)453negative_embeddings = ops.tile(encoding_1[1], (batch_size, 1, 1))454negative_pooled_embeddings = ops.tile(encoding_1[3], (batch_size, 1))455456latents = random.normal((1, height // 8, width // 8, 16), seed=42)457latents = ops.tile(latents, (batch_size, 1, 1, 1))458459images = []460progbar = keras.utils.Progbar(batches)461for i in range(batches):462images.append(463generate_function(464latents,465(466interpolated_positive_embeddings[i],467negative_embeddings,468interpolated_positive_pooled_embeddings[i],469negative_pooled_embeddings,470),471ops.convert_to_tensor(num_steps),472ops.convert_to_tensor(guidance_scale),473)474)475progbar.update(i + 1, finalize=i == batches - 1)476477478"""479Let's display the resulting images in a grid to make them easier to interpret.480"""481482483def plot_grid(images, path, grid_size, scale=2):484fig, axs = plt.subplots(485grid_size, grid_size, figsize=(grid_size * scale, grid_size * scale)486)487fig.tight_layout()488plt.subplots_adjust(wspace=0, hspace=0)489plt.axis("off")490for ax in axs.flat:491ax.axis("off")492493for i in range(min(grid_size * grid_size, len(images))):494ax = axs.flat[i]495ax.imshow(images[i])496ax.axis("off")497498for i in range(len(images), grid_size * grid_size):499axs.flat[i].axis("off")500axs.flat[i].remove()501502plt.savefig(503fname=path,504pad_inches=0,505bbox_inches="tight",506transparent=False,507dpi=60,508)509510511images = ops.convert_to_numpy(decode_to_images(images, height, width))512plot_grid(images, "4-way-interpolation.jpg", interpolation_steps)513514"""515We can also interpolate while allowing diffusion latents to vary by dropping516the `seed` parameter:517"""518519images = []520progbar = keras.utils.Progbar(batches)521for i in range(batches):522# Vary diffusion latents for each input.523latents = random.normal((batch_size, height // 8, width // 8, 16))524images.append(525generate_function(526latents,527(528interpolated_positive_embeddings[i],529negative_embeddings,530interpolated_positive_pooled_embeddings[i],531negative_pooled_embeddings,532),533ops.convert_to_tensor(num_steps),534ops.convert_to_tensor(guidance_scale),535)536)537progbar.update(i + 1, finalize=i == batches - 1)538539images = ops.convert_to_numpy(decode_to_images(images, height, width))540plot_grid(images, "4-way-interpolation-varying-latent.jpg", interpolation_steps)541542"""543Next up -- let's go for some walks!544545## A walk around a text prompt546547Our next experiment will be to go for a walk around the latent manifold548starting from a point produced by a particular prompt.549"""550551walk_steps = 64552batch_size = 4553batches = walk_steps // batch_size554step_size = 0.01555prompt = "The eiffel tower in the style of starry night"556encoding = get_text_embeddings(prompt)557558positive_embeddings = encoding[0]559positive_pooled_embeddings = encoding[2]560negative_embeddings = encoding[1]561negative_pooled_embeddings = encoding[3]562563# The shape of `positive_embeddings`: (1, 154, 4096)564# The shape of `positive_pooled_embeddings`: (1, 2048)565positive_embeddings_delta = ops.ones_like(positive_embeddings) * step_size566positive_pooled_embeddings_delta = ops.ones_like(positive_pooled_embeddings) * step_size567positive_embeddings_shape = ops.shape(positive_embeddings)568positive_pooled_embeddings_shape = ops.shape(positive_pooled_embeddings)569570walked_positive_embeddings = []571walked_positive_pooled_embeddings = []572for step_index in range(walk_steps):573walked_positive_embeddings.append(positive_embeddings)574walked_positive_pooled_embeddings.append(positive_pooled_embeddings)575positive_embeddings += positive_embeddings_delta576positive_pooled_embeddings += positive_pooled_embeddings_delta577walked_positive_embeddings = ops.stack(walked_positive_embeddings, axis=0)578walked_positive_pooled_embeddings = ops.stack(walked_positive_pooled_embeddings, axis=0)579walked_positive_embeddings = ops.reshape(580walked_positive_embeddings,581(582batches,583batch_size,584positive_embeddings_shape[-2],585positive_embeddings_shape[-1],586),587)588walked_positive_pooled_embeddings = ops.reshape(589walked_positive_pooled_embeddings,590(batches, batch_size, positive_pooled_embeddings_shape[-1]),591)592negative_embeddings = ops.tile(encoding_1[1], (batch_size, 1, 1))593negative_pooled_embeddings = ops.tile(encoding_1[3], (batch_size, 1))594595latents = random.normal((1, height // 8, width // 8, 16), seed=42)596latents = ops.tile(latents, (batch_size, 1, 1, 1))597598images = []599progbar = keras.utils.Progbar(batches)600for i in range(batches):601images.append(602generate_function(603latents,604(605walked_positive_embeddings[i],606negative_embeddings,607walked_positive_pooled_embeddings[i],608negative_pooled_embeddings,609),610ops.convert_to_tensor(num_steps),611ops.convert_to_tensor(guidance_scale),612)613)614progbar.update(i + 1, finalize=i == batches - 1)615616images = ops.convert_to_numpy(decode_to_images(images, height, width))617export_as_gif(618"eiffel-tower-starry-night.gif",619[Image.fromarray(image) for image in images],620frames_per_second=2,621)622623"""624Perhaps unsurprisingly, walking too far from the encoder's latent manifold625produces images that look incoherent. Try it for yourself by setting your own626prompt, and adjusting `step_size` to increase or decrease the magnitude627of the walk. Note that when the magnitude of the walk gets large, the walk often628leads into areas which produce extremely noisy images.629630## A circular walk through the diffusion latent space for a single prompt631632Our final experiment is to stick to one prompt and explore the variety of images633that the diffusion model can produce from that prompt. We do this by controlling634the noise that is used to seed the diffusion process.635636We create two noise components, `x` and `y`, and do a walk from 0 to 2π, summing637the cosine of our `x` component and the sin of our `y` component to produce638noise. Using this approach, the end of our walk arrives at the same noise inputs639where we began our walk, so we get a "loopable" result!640"""641642walk_steps = 64643batch_size = 4644batches = walk_steps // batch_size645prompt = "An oil paintings of cows in a field next to a windmill in Holland"646encoding = get_text_embeddings(prompt)647648walk_latent_x = random.normal((1, height // 8, width // 8, 16))649walk_latent_y = random.normal((1, height // 8, width // 8, 16))650walk_scale_x = ops.cos(ops.linspace(0.0, 2.0, walk_steps) * math.pi)651walk_scale_y = ops.sin(ops.linspace(0.0, 2.0, walk_steps) * math.pi)652latent_x = ops.tensordot(walk_scale_x, walk_latent_x, axes=0)653latent_y = ops.tensordot(walk_scale_y, walk_latent_y, axes=0)654latents = ops.add(latent_x, latent_y)655latents = ops.reshape(latents, (batches, batch_size, height // 8, width // 8, 16))656657images = []658progbar = keras.utils.Progbar(batches)659for i in range(batches):660images.append(661generate_function(662latents[i],663(664ops.tile(encoding[0], (batch_size, 1, 1)),665ops.tile(encoding[1], (batch_size, 1, 1)),666ops.tile(encoding[2], (batch_size, 1)),667ops.tile(encoding[3], (batch_size, 1)),668),669ops.convert_to_tensor(num_steps),670ops.convert_to_tensor(guidance_scale),671)672)673progbar.update(i + 1, finalize=i == batches - 1)674675images = ops.convert_to_numpy(decode_to_images(images, height, width))676export_as_gif(677"cows.gif",678[Image.fromarray(image) for image in images],679frames_per_second=4,680no_rubber_band=True,681)682683"""684Experiment with your own prompts and with different values of the parameters!685686## Conclusion687688Stable Diffusion 3 offers a lot more than just single text-to-image generation.689Exploring the latent manifold of the text encoder and the latent space of the690diffusion model are two fun ways to experience the power of this model, and691KerasHub makes it easy!692"""693694695