"""
Title: A walk through latent space with Stable Diffusion
Authors: Ian Stenbit, [fchollet](https://twitter.com/fchollet), [lukewood](https://twitter.com/luke_wood_ml)
Date created: 2022/09/28
Last modified: 2022/09/28
Description: Explore the latent manifold of Stable Diffusion.
Accelerator: GPU
"""

"""
## Overview

Generative image models learn a "latent manifold" of the visual world:
a low-dimensional vector space where each point maps to an image.
Going from such a point on the manifold back to a displayable image
is called "decoding" -- in the Stable Diffusion model, this is handled by
the "decoder" model.

![The Stable Diffusion architecture](https://i.imgur.com/2uC8rYJ.png)

This latent manifold of images is continuous and interpolative, meaning that:

1. Moving a little on the manifold only changes the corresponding image a little (continuity).
2. For any two points A and B on the manifold (i.e. any two images), it is possible
to move from A to B via a path where each intermediate point is also on the manifold (i.e.
is also a valid image). Intermediate points are called "interpolations" between
the two starting images (see the sketch below).
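
For example, a minimal sketch of property 2 in NumPy, with hypothetical latent
vectors `z_a` and `z_b` standing in for real encodings:

```
import numpy as np

z_a = np.random.normal(size=(768,))  # one point on the manifold
z_b = np.random.normal(size=(768,))  # another point

# Each intermediate point is a weighted average of the two endpoints.
interpolations = [(1 - t) * z_a + t * z_b for t in np.linspace(0, 1, 5)]
```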

Stable Diffusion isn't just an image model, though; it's also a natural language model.
It has two latent spaces: the image representation space learned by the
encoder used during training, and the prompt latent space,
which is learned using a combination of pretraining and training-time
fine-tuning.

_Latent space walking_, or _latent space exploration_, is the process of
sampling a point in latent space and incrementally changing the latent
representation. Its most common application is generating animations
where each sampled point is fed to the decoder and is stored as a
frame in the final animation.
For high-quality latent representations, this produces coherent-looking
animations. These animations can provide insight into the feature map of the
latent space, and can ultimately lead to improvements in the training
process. One such GIF is displayed below:

![Panda to Plane](/img/examples/generative/random_walks_with_stable_diffusion/panda2plane.gif)

In this guide, we will show how to take advantage of the Stable Diffusion API
in KerasCV to perform prompt interpolation and circular walks through
Stable Diffusion's visual latent manifold, as well as through
the text encoder's latent manifold.

This guide assumes the reader has a
high-level understanding of Stable Diffusion.
If you haven't already, you should start
by reading the [Stable Diffusion Tutorial](https://keras.io/guides/keras_cv/generate_images_with_stable_diffusion/).

To start, we import KerasCV and load up a Stable Diffusion model using the
optimizations discussed in the tutorial
[Generate images with Stable Diffusion](https://keras.io/guides/keras_cv/generate_images_with_stable_diffusion/).
Note that if you are running with an M1 Mac GPU, you should not enable mixed precision.
"""

"""shell
pip install keras-cv --upgrade --quiet
"""

import keras_cv
import keras
import matplotlib.pyplot as plt
from keras import ops
import numpy as np
import math
from PIL import Image

# Enable mixed precision
# (only do this if you have a recent NVIDIA GPU)
keras.mixed_precision.set_global_policy("mixed_float16")

# Instantiate the Stable Diffusion model
model = keras_cv.models.StableDiffusion(jit_compile=True)

"""
## Interpolating between text prompts

In Stable Diffusion, a text prompt is first encoded into a vector,
and that encoding is used to guide the diffusion process.
The latent encoding vector has shape
77x768 (that's huge!), and when we give Stable Diffusion a text prompt, we're
generating images from just one such point on the latent manifold.

To explore more of this manifold, we can interpolate between two text encodings
and generate images at those interpolated points:
"""

prompt_1 = "A watercolor painting of a Golden Retriever at the beach"
prompt_2 = "A still life DSLR photo of a bowl of fruit"
interpolation_steps = 5

encoding_1 = ops.squeeze(model.encode_text(prompt_1))
encoding_2 = ops.squeeze(model.encode_text(prompt_2))

interpolated_encodings = ops.linspace(encoding_1, encoding_2, interpolation_steps)

# Show the size of the latent manifold
print(f"Encoding shape: {encoding_1.shape}")

"""
108
Once we've interpolated the encodings, we can generate images from each point.
109
Note that in order to maintain some stability between the resulting images we
110
keep the diffusion noise constant between images.
111
"""
112
113
seed = 12345
114
noise = keras.random.normal((512 // 8, 512 // 8, 4), seed=seed)
115
116
images = model.generate_image(
    interpolated_encodings,
    batch_size=interpolation_steps,
    diffusion_noise=noise,
)

"""
123
Now that we've generated some interpolated images, let's take a look at them!
124
125
Throughout this tutorial, we're going to export sequences of images as gifs so
126
that they can be easily viewed with some temporal context. For sequences of
127
images where the first and last images don't match conceptually, we rubber-band
128
the gif.
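
As a tiny illustration (with integers standing in for frames), rubber-banding
appends the reversed middle of the sequence so the animation loops back to its
start:

```
frames = [0, 1, 2, 3, 4]
frames += frames[2:-1][::-1]
print(frames)  # [0, 1, 2, 3, 4, 3, 2]
```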

If you're running in Colab, you can view your own GIFs by running:

```
from IPython.display import Image as IImage
IImage("doggo-and-fruit-5.gif")
```
"""


def export_as_gif(filename, images, frames_per_second=10, rubber_band=False):
    if rubber_band:
        # Append the reversed middle frames so the animation loops smoothly.
        images += images[2:-1][::-1]
    images[0].save(
        filename,
        save_all=True,
        append_images=images[1:],
        duration=1000 // frames_per_second,
        loop=0,
    )


export_as_gif(
    "doggo-and-fruit-5.gif",
    [Image.fromarray(img) for img in images],
    frames_per_second=2,
    rubber_band=True,
)

"""
159
![Dog to Fruit 5](https://i.imgur.com/4ZCxZY4.gif)
160
161
The results may seem surprising. Generally, interpolating between prompts
162
produces coherent looking images, and often demonstrates a progressive concept
163
shift between the contents of the two prompts. This is indicative of a high
164
quality representation space, that closely mirrors the natural structure
165
of the visual world.
166
167
To best visualize this, we should do a much more fine-grained interpolation,
168
using hundreds of steps. In order to keep batch size small (so that we don't
169
OOM our GPU), this requires manually batching our interpolated
170
encodings.
171
"""
172
interpolation_steps = 150
batch_size = 3
batches = interpolation_steps // batch_size

interpolated_encodings = ops.linspace(encoding_1, encoding_2, interpolation_steps)
batched_encodings = ops.split(interpolated_encodings, batches)

images = []
for batch in range(batches):
    images += [
        Image.fromarray(img)
        for img in model.generate_image(
            batched_encodings[batch],
            batch_size=batch_size,
            num_steps=25,
            diffusion_noise=noise,
        )
    ]

export_as_gif("doggo-and-fruit-150.gif", images, rubber_band=True)

"""
195
![Dog to Fruit 150](/img/examples/generative/random_walks_with_stable_diffusion/dog2fruit150.gif)
196
197
The resulting gif shows a much clearer and more coherent shift between the two
198
prompts. Try out some prompts of your own and experiment!
199
200
We can even extend this concept for more than one image. For example, we can
201
interpolate between four prompts:
202
"""
203
204
prompt_1 = "A watercolor painting of a Golden Retriever at the beach"
prompt_2 = "A still life DSLR photo of a bowl of fruit"
prompt_3 = "The eiffel tower in the style of starry night"
prompt_4 = "An architectural sketch of a skyscraper"

interpolation_steps = 6
batch_size = 3
batches = (interpolation_steps**2) // batch_size

encoding_1 = ops.squeeze(model.encode_text(prompt_1))
encoding_2 = ops.squeeze(model.encode_text(prompt_2))
encoding_3 = ops.squeeze(model.encode_text(prompt_3))
encoding_4 = ops.squeeze(model.encode_text(prompt_4))

# First interpolate along the 1->2 and 3->4 edges, then interpolate between
# those two edges to fill in a 2D grid of encodings (a bilinear interpolation).
interpolated_encodings = ops.linspace(
    ops.linspace(encoding_1, encoding_2, interpolation_steps),
    ops.linspace(encoding_3, encoding_4, interpolation_steps),
    interpolation_steps,
)
interpolated_encodings = ops.reshape(
    interpolated_encodings, (interpolation_steps**2, 77, 768)
)
batched_encodings = ops.split(interpolated_encodings, batches)

images = []
for batch in range(batches):
    images.append(
        model.generate_image(
            batched_encodings[batch],
            batch_size=batch_size,
            diffusion_noise=noise,
        )
    )


def plot_grid(images, path, grid_size, scale=2):
    fig, axs = plt.subplots(
        grid_size, grid_size, figsize=(grid_size * scale, grid_size * scale)
    )
    fig.tight_layout()
    plt.subplots_adjust(wspace=0, hspace=0)
    plt.axis("off")
    for ax in axs.flat:
        ax.axis("off")

    images = images.astype(int)
    for i in range(min(grid_size * grid_size, len(images))):
        ax = axs.flat[i]
        ax.imshow(images[i].astype("uint8"))
        ax.axis("off")

    # Hide and remove any unused cells in the grid.
    for i in range(len(images), grid_size * grid_size):
        axs.flat[i].axis("off")
        axs.flat[i].remove()

    plt.savefig(
        fname=path,
        pad_inches=0,
        bbox_inches="tight",
        transparent=False,
        dpi=60,
    )


images = np.concatenate(images)
plot_grid(images, "4-way-interpolation.jpg", interpolation_steps)

"""
272
We can also interpolate while allowing diffusion noise to vary by dropping
273
the `diffusion_noise` parameter:
274
"""
275
276
images = []
277
for batch in range(batches):
278
images.append(model.generate_image(batched_encodings[batch], batch_size=batch_size))
279
280
images = np.concatenate(images)
281
plot_grid(images, "4-way-interpolation-varying-noise.jpg", interpolation_steps)
282
283
"""
284
Next up -- let's go for some walks!
285
286
## A walk around a text prompt
287
288
Our next experiment will be to go for a walk around the latent manifold
289
starting from a point produced by a particular prompt.
290
"""
291
292
walk_steps = 150
batch_size = 3
batches = walk_steps // batch_size
step_size = 0.005

encoding = ops.squeeze(
    model.encode_text("The Eiffel Tower in the style of starry night")
)
# Note that (77, 768) is the shape of the text encoding.
delta = ops.ones_like(encoding) * step_size

walked_encodings = []
for step_index in range(walk_steps):
    walked_encodings.append(encoding)
    encoding += delta
walked_encodings = ops.stack(walked_encodings)
batched_encodings = ops.split(walked_encodings, batches)

images = []
for batch in range(batches):
    images += [
        Image.fromarray(img)
        for img in model.generate_image(
            batched_encodings[batch],
            batch_size=batch_size,
            num_steps=25,
            diffusion_noise=noise,
        )
    ]

export_as_gif("eiffel-tower-starry-night.gif", images, rubber_band=True)

"""
325
![Eiffel tower walk gif](https://i.imgur.com/9MMYtal.gif)
326
327
Perhaps unsurprisingly, walking too far from the encoder's latent manifold
328
produces images that look incoherent. Try it for yourself by setting
329
your own prompt, and adjusting `step_size` to increase or decrease the magnitude
330
of the walk. Note that when the magnitude of the walk gets large, the walk often
331
leads into areas which produce extremely noisy images.
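
As a variation, you could take a fresh random step at each iteration instead of
moving in one fixed direction. This is only a sketch, reusing `model`, `ops`,
`keras`, and `walk_steps` from above, with a hypothetical `random_step_size`:

```
encoding = ops.squeeze(
    model.encode_text("The Eiffel Tower in the style of starry night")
)
random_step_size = 0.005  # hypothetical magnitude; tune to taste
walked_encodings = []
for step_index in range(walk_steps):
    walked_encodings.append(encoding)
    # Draw a fresh Gaussian direction at every step for a true random walk.
    encoding += keras.random.normal(encoding.shape) * random_step_size
walked_encodings = ops.stack(walked_encodings)
```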

## A circular walk through the diffusion noise space for a single prompt

Our final experiment is to stick to one prompt and explore the variety of images
that the diffusion model can produce from that prompt. We do this by controlling
the noise that is used to seed the diffusion process.

We create two noise components, `x` and `y`, and walk from 0 to 2π, summing
the cosine of our `x` component and the sine of our `y` component to produce noise.
Using this approach, the end of our walk arrives at the same noise inputs where
we began our walk, so we get a "loopable" result!
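
In symbols, the noise used for walk angle θ is

`noise(θ) = cos(θ) * x + sin(θ) * y`, with θ ranging over [0, 2π]

and since cos²(θ) + sin²(θ) = 1, the combined noise keeps roughly the same
overall scale as a single Gaussian draw at every angle, so each frame is seeded
with comparably-sized noise.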
"""

prompt = "An oil painting of cows in a field next to a windmill in Holland"
encoding = ops.squeeze(model.encode_text(prompt))
walk_steps = 150
batch_size = 3
batches = walk_steps // batch_size

walk_noise_x = keras.random.normal(noise.shape, dtype="float64")
walk_noise_y = keras.random.normal(noise.shape, dtype="float64")

# Sweep the angle from 0 to 2*pi and scale each noise component accordingly.
walk_scale_x = ops.cos(ops.linspace(0, 2, walk_steps) * math.pi)
walk_scale_y = ops.sin(ops.linspace(0, 2, walk_steps) * math.pi)
# tensordot with axes=0 produces one full noise tensor per walk step.
noise_x = ops.tensordot(walk_scale_x, walk_noise_x, axes=0)
noise_y = ops.tensordot(walk_scale_y, walk_noise_y, axes=0)
noise = ops.add(noise_x, noise_y)
batched_noise = ops.split(noise, batches)

images = []
for batch in range(batches):
    images += [
        Image.fromarray(img)
        for img in model.generate_image(
            encoding,
            batch_size=batch_size,
            num_steps=25,
            diffusion_noise=batched_noise[batch],
        )
    ]

export_as_gif("cows.gif", images)

"""
376
![Happy Cows](/img/examples/generative/random_walks_with_stable_diffusion/happycows.gif)
377
378
Experiment with your own prompts and with different values of
379
`unconditional_guidance_scale`!
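
For instance, a sketch assuming the `unconditional_guidance_scale` argument of
`generate_image` (it defaults to 7.5 in KerasCV); higher values push the images
to follow the prompt more literally, at some cost in variety:

```
images = model.generate_image(
    encoding,
    batch_size=3,
    unconditional_guidance_scale=12.5,  # try values above and below 7.5
)
```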

## Conclusion

Stable Diffusion offers a lot more than just single text-to-image generation.
Exploring the latent manifold of the text encoder and the noise space of the
diffusion model are two fun ways to experience the power of this model, and
KerasCV makes it easy!
"""