"""
Title: A walk through latent space with Stable Diffusion 3
Authors: [Hongyu Chiu](https://github.com/james77777778), Ian Stenbit, [fchollet](https://twitter.com/fchollet), [lukewood](https://twitter.com/luke_wood_ml)
Date created: 2024/11/11
Last modified: 2024/11/11
Description: Explore the latent manifold of Stable Diffusion 3.
Accelerator: GPU
"""

"""
## Overview

Generative image models learn a "latent manifold" of the visual world: a
low-dimensional vector space where each point maps to an image. Going from such
a point on the manifold back to a displayable image is called "decoding" -- in
the Stable Diffusion model, this is handled by the "decoder" model.

![Stable Diffusion 3 Medium Architecture](/img/examples/generative/random_walks_with_stable_diffusion_3/mmdit.png)

This latent manifold of images is continuous and interpolative, meaning that:

1. Moving a little on the manifold only changes the corresponding image a
little (continuity).
2. For any two points A and B on the manifold (i.e. any two images), it is
possible to move from A to B via a path where each intermediate point is also on
the manifold (i.e. is also a valid image). Intermediate points would be called
"interpolations" between the two starting images.

Stable Diffusion isn't just an image model, though; it's also a natural language
model. It has two latent spaces: the image representation space learned by the
encoder used during training, and the prompt latent space, which is learned using
a combination of pretraining and training-time fine-tuning.

_Latent space walking_, or _latent space exploration_, is the process of
sampling a point in latent space and incrementally changing the latent
representation. Its most common application is generating animations where each
sampled point is fed to the decoder and is stored as a frame in the final
animation.
For high-quality latent representations, this produces coherent-looking
animations. These animations can provide insight into the feature map of the
latent space, and can ultimately lead to improvements in the training process.
One such GIF is displayed below:

![dog_to_cat_64.gif](/img/examples/generative/random_walks_with_stable_diffusion_3/dog_to_cat_64.gif)

In this guide, we will show how to take advantage of the TextToImage API in
KerasHub to perform prompt interpolation and circular walks through Stable
Diffusion 3's visual latent manifold, as well as through the text encoder's
latent manifold.

This guide assumes the reader has a high-level understanding of Stable
Diffusion 3. If you haven't already, you should start by reading the
[Stable Diffusion 3 in KerasHub](
https://keras.io/guides/keras_hub/stable_diffusion_3_in_keras_hub/).

It is also worth noting that the preset "stable_diffusion_3_medium" excludes the
T5XXL text encoder, as it requires significantly more GPU memory. The performance
degradation is negligible in most cases. The weights, including T5XXL, will be
available on KerasHub soon.
"""

"""shell
# Use the latest version of KerasHub
!pip install -Uq git+https://github.com/keras-team/keras-hub.git
"""

import math

import keras
import keras_hub
import matplotlib.pyplot as plt
from keras import ops
from keras import random
from PIL import Image

height, width = 512, 512  # Output image resolution
num_steps = 28  # Number of diffusion steps
guidance_scale = 7.0  # Classifier-free guidance scale
dtype = "float16"

# Instantiate the Stable Diffusion 3 model and the preprocessor
backbone = keras_hub.models.StableDiffusion3Backbone.from_preset(
    "stable_diffusion_3_medium", image_shape=(height, width, 3), dtype=dtype
)
preprocessor = keras_hub.models.StableDiffusion3TextToImagePreprocessor.from_preset(
    "stable_diffusion_3_medium"
)

"""
Let's define some helper functions for this example.
"""


def get_text_embeddings(prompt):
    """Get the text embeddings for a given prompt."""
    token_ids = preprocessor.generate_preprocess([prompt])
    negative_token_ids = preprocessor.generate_preprocess([""])
    (
        positive_embeddings,
        negative_embeddings,
        positive_pooled_embeddings,
        negative_pooled_embeddings,
    ) = backbone.encode_text_step(token_ids, negative_token_ids)
    return (
        positive_embeddings,
        negative_embeddings,
        positive_pooled_embeddings,
        negative_pooled_embeddings,
    )


def decode_to_images(x, height, width):
    """Concatenate and normalize the images to uint8 dtype."""
    x = ops.concatenate(x, axis=0)
    x = ops.reshape(x, (-1, height, width, 3))
    # Map the pixel values from [-1, 1] to [0, 1], then to [0, 255] uint8.
    x = ops.clip(ops.divide(ops.add(x, 1.0), 2.0), 0.0, 1.0)
    return ops.cast(ops.round(ops.multiply(x, 255.0)), "uint8")


def generate_with_latents_and_embeddings(
    latents, embeddings, num_steps, guidance_scale
):
    """Generate images from latents and text embeddings."""

    def body_fun(step, latents):
        return backbone.denoise_step(
            latents,
            embeddings,
            step,
            num_steps,
            guidance_scale,
        )

    # Run the denoising loop for `num_steps` steps, then decode the final
    # latents into images.
    latents = ops.fori_loop(0, num_steps, body_fun, latents)
    return backbone.decode_step(latents)


def export_as_gif(filename, images, frames_per_second=10, no_rubber_band=False):
    if not no_rubber_band:
        images += images[2:-1][::-1]  # Makes a rubber band: A->B->A
    images[0].save(
        filename,
        save_all=True,
        append_images=images[1:],
        duration=1000 // frames_per_second,
        loop=0,
    )


"""
We are going to generate images using custom latents and embeddings, so we
implemented the `generate_with_latents_and_embeddings` function above.
Additionally, it is important to compile this function to speed up the
generation process.
"""

if keras.config.backend() == "torch":
    import torch

    @torch.no_grad()
    def wrapped_function(*args, **kwargs):
        return generate_with_latents_and_embeddings(*args, **kwargs)

    generate_function = wrapped_function
elif keras.config.backend() == "tensorflow":
    import tensorflow as tf

    generate_function = tf.function(
        generate_with_latents_and_embeddings, jit_compile=True
    )
elif keras.config.backend() == "jax":
    import itertools

    import jax

    @jax.jit
    def compiled_function(state, *args, **kwargs):
        (trainable_variables, non_trainable_variables) = state
        mapping = itertools.chain(
            zip(backbone.trainable_variables, trainable_variables),
            zip(backbone.non_trainable_variables, non_trainable_variables),
        )
        with keras.StatelessScope(state_mapping=mapping):
            return generate_with_latents_and_embeddings(*args, **kwargs)

    def wrapped_function(*args, **kwargs):
        state = (
            [v.value for v in backbone.trainable_variables],
            [v.value for v in backbone.non_trainable_variables],
        )
        return compiled_function(state, *args, **kwargs)

    generate_function = wrapped_function
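
"""
Whichever backend is selected, `generate_function` is called the same way: a
batch of latents, a 4-tuple of text embeddings, and the step count and guidance
scale as tensors. A minimal sketch of a call (assuming `latents` and an
`encoding` tuple from `get_text_embeddings` are already defined):

```
images = generate_function(
    latents,
    (encoding[0], encoding[1], encoding[2], encoding[3]),
    ops.convert_to_tensor(num_steps),
    ops.convert_to_tensor(guidance_scale),
)
```

Passing `num_steps` and `guidance_scale` as tensors helps avoid retracing or
recompilation when these values change.
"""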


"""
## Interpolating between text prompts

In Stable Diffusion 3, a text prompt is encoded into multiple vectors, which are
then used to guide the diffusion process. These latent encoding vectors have
shapes of 154x4096 and 2048 for both the positive and negative prompts -- quite
large! When we input a text prompt into Stable Diffusion 3, we generate images
from a single point on this latent manifold.

To explore more of this manifold, we can interpolate between two text encodings
and generate images at those interpolated points:
"""

prompt_1 = "A cute dog in a beautiful field of lavender colorful flowers "
prompt_1 += "everywhere, perfect lighting, leica summicron 35mm f2.0, kodak "
prompt_1 += "portra 400, film grain"
prompt_2 = prompt_1.replace("dog", "cat")
interpolation_steps = 5

encoding_1 = get_text_embeddings(prompt_1)
encoding_2 = get_text_embeddings(prompt_2)


# Show the size of the latent manifold
print(f"Positive embeddings shape: {encoding_1[0].shape}")
print(f"Negative embeddings shape: {encoding_1[1].shape}")
print(f"Positive pooled embeddings shape: {encoding_1[2].shape}")
print(f"Negative pooled embeddings shape: {encoding_1[3].shape}")


"""
In this example, we want to use Spherical Linear Interpolation (slerp) instead
of simple linear interpolation. Slerp is commonly used in computer graphics to
animate rotations smoothly and can also be applied to interpolate between
high-dimensional data points, such as latent vectors used in generative models.

The implementation below is adapted from Andrej Karpathy's gist:
[https://gist.github.com/karpathy/00103b0037c5aaea32fe1da1af553355](https://gist.github.com/karpathy/00103b0037c5aaea32fe1da1af553355).

A more detailed explanation of this method can be found at:
[https://en.wikipedia.org/wiki/Slerp](https://en.wikipedia.org/wiki/Slerp).
"""


def slerp(v1, v2, num):
    ori_dtype = v1.dtype
    # Cast to float32 for numerical stability.
    v1 = ops.cast(v1, "float32")
    v2 = ops.cast(v2, "float32")

    def interpolation(t, v1, v2, dot_threshold=0.9995):
        """helper function to spherically interpolate two arrays."""
        dot = ops.sum(
            v1 * v2 / (ops.linalg.norm(ops.ravel(v1)) * ops.linalg.norm(ops.ravel(v2)))
        )
        if ops.abs(dot) > dot_threshold:
            # The vectors are nearly parallel, so fall back to plain linear
            # interpolation to avoid dividing by a tiny sin(theta_0).
            v2 = (1 - t) * v1 + t * v2
        else:
            theta_0 = ops.arccos(dot)
            sin_theta_0 = ops.sin(theta_0)
            theta_t = theta_0 * t
            sin_theta_t = ops.sin(theta_t)
            s0 = ops.sin(theta_0 - theta_t) / sin_theta_0
            s1 = sin_theta_t / sin_theta_0
            v2 = s0 * v1 + s1 * v2
        return v2

    t = ops.linspace(0, 1, num)
    interpolated = ops.stack([interpolation(t[i], v1, v2) for i in range(num)], axis=0)
    return ops.cast(interpolated, ori_dtype)
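
"""
As a quick, illustrative sanity check, slerp keeps intermediate vectors on the
arc between its inputs, so interpolating between two orthogonal unit vectors
yields unit-length midpoints, unlike plain linear interpolation:

```
v_a = ops.convert_to_tensor([1.0, 0.0])
v_b = ops.convert_to_tensor([0.0, 1.0])
midpoint = slerp(v_a, v_b, 3)[1]
print(ops.linalg.norm(midpoint))  # ~1.0; a lerp midpoint would have norm ~0.707
```
"""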


interpolated_positive_embeddings = slerp(
    encoding_1[0], encoding_2[0], interpolation_steps
)
interpolated_positive_pooled_embeddings = slerp(
    encoding_1[2], encoding_2[2], interpolation_steps
)
# We don't use negative prompts in this example, so there's no need to
# interpolate them.
negative_embeddings = encoding_1[1]
negative_pooled_embeddings = encoding_1[3]


"""
Once we've interpolated the encodings, we can generate images from each point.
Note that in order to maintain some stability between the resulting images, we
keep the diffusion latents constant from one image to the next.
"""

# SD3's latent space is 8x smaller than the image, with 16 channels.
latents = random.normal((1, height // 8, width // 8, 16), seed=42)

images = []
progbar = keras.utils.Progbar(interpolation_steps)
for i in range(interpolation_steps):
    images.append(
        generate_function(
            latents,
            (
                interpolated_positive_embeddings[i],
                negative_embeddings,
                interpolated_positive_pooled_embeddings[i],
                negative_pooled_embeddings,
            ),
            ops.convert_to_tensor(num_steps),
            ops.convert_to_tensor(guidance_scale),
        )
    )
    progbar.update(i + 1, finalize=i == interpolation_steps - 1)

"""
Now that we've generated some interpolated images, let's take a look at them!

Throughout this tutorial, we're going to export sequences of images as gifs so
that they can be easily viewed with some temporal context. For sequences of
images where the first and last images don't match conceptually, we rubber-band
the gif.

If you're running in Colab, you can view your own GIFs by running:

```
from IPython.display import Image as IImage
IImage("dog_to_cat_5.gif")
```
"""

images = ops.convert_to_numpy(decode_to_images(images, height, width))
export_as_gif(
    "dog_to_cat_5.gif",
    [Image.fromarray(image) for image in images],
    frames_per_second=2,
)

"""
The results may seem surprising. Generally, interpolating between prompts
produces coherent-looking images, and often demonstrates a progressive concept
shift between the contents of the two prompts. This is indicative of a
high-quality representation space that closely mirrors the natural structure of
the visual world.

To best visualize this, we should do a much more fine-grained interpolation,
using more steps.
"""

interpolation_steps = 64
batch_size = 4
batches = interpolation_steps // batch_size

interpolated_positive_embeddings = slerp(
    encoding_1[0], encoding_2[0], interpolation_steps
)
interpolated_positive_pooled_embeddings = slerp(
    encoding_1[2], encoding_2[2], interpolation_steps
)
positive_embeddings_shape = ops.shape(encoding_1[0])
positive_pooled_embeddings_shape = ops.shape(encoding_1[2])
# Reshape the interpolated embeddings into batches for batched generation.
interpolated_positive_embeddings = ops.reshape(
    interpolated_positive_embeddings,
    (
        batches,
        batch_size,
        positive_embeddings_shape[-2],
        positive_embeddings_shape[-1],
    ),
)
interpolated_positive_pooled_embeddings = ops.reshape(
    interpolated_positive_pooled_embeddings,
    (batches, batch_size, positive_pooled_embeddings_shape[-1]),
)
negative_embeddings = ops.tile(encoding_1[1], (batch_size, 1, 1))
negative_pooled_embeddings = ops.tile(encoding_1[3], (batch_size, 1))

latents = random.normal((1, height // 8, width // 8, 16), seed=42)
latents = ops.tile(latents, (batch_size, 1, 1, 1))

images = []
progbar = keras.utils.Progbar(batches)
for i in range(batches):
    images.append(
        generate_function(
            latents,
            (
                interpolated_positive_embeddings[i],
                negative_embeddings,
                interpolated_positive_pooled_embeddings[i],
                negative_pooled_embeddings,
            ),
            ops.convert_to_tensor(num_steps),
            ops.convert_to_tensor(guidance_scale),
        )
    )
    progbar.update(i + 1, finalize=i == batches - 1)

images = ops.convert_to_numpy(decode_to_images(images, height, width))
export_as_gif(
    "dog_to_cat_64.gif",
    [Image.fromarray(image) for image in images],
    frames_per_second=2,
)

"""
The resulting gif shows a much clearer and more coherent shift between the two
prompts. Try out some prompts of your own and experiment!

We can even extend this concept to more than two prompts. For example, we can
interpolate between four prompts:
"""

prompt_1 = "A watercolor painting of a Golden Retriever at the beach"
prompt_2 = "A still life DSLR photo of a bowl of fruit"
prompt_3 = "The eiffel tower in the style of starry night"
prompt_4 = "An architectural sketch of a skyscraper"

interpolation_steps = 8
batch_size = 4
batches = (interpolation_steps**2) // batch_size

encoding_1 = get_text_embeddings(prompt_1)
encoding_2 = get_text_embeddings(prompt_2)
encoding_3 = get_text_embeddings(prompt_3)
encoding_4 = get_text_embeddings(prompt_4)

positive_embeddings_shape = ops.shape(encoding_1[0])
positive_pooled_embeddings_shape = ops.shape(encoding_1[2])
interpolated_positive_embeddings_12 = slerp(
    encoding_1[0], encoding_2[0], interpolation_steps
)
interpolated_positive_embeddings_34 = slerp(
    encoding_3[0], encoding_4[0], interpolation_steps
)
interpolated_positive_embeddings = slerp(
    interpolated_positive_embeddings_12,
    interpolated_positive_embeddings_34,
    interpolation_steps,
)
interpolated_positive_embeddings = ops.reshape(
    interpolated_positive_embeddings,
    (
        batches,
        batch_size,
        positive_embeddings_shape[-2],
        positive_embeddings_shape[-1],
    ),
)
interpolated_positive_pooled_embeddings_12 = slerp(
    encoding_1[2], encoding_2[2], interpolation_steps
)
interpolated_positive_pooled_embeddings_34 = slerp(
    encoding_3[2], encoding_4[2], interpolation_steps
)
interpolated_positive_pooled_embeddings = slerp(
    interpolated_positive_pooled_embeddings_12,
    interpolated_positive_pooled_embeddings_34,
    interpolation_steps,
)
interpolated_positive_pooled_embeddings = ops.reshape(
    interpolated_positive_pooled_embeddings,
    (batches, batch_size, positive_pooled_embeddings_shape[-1]),
)
negative_embeddings = ops.tile(encoding_1[1], (batch_size, 1, 1))
negative_pooled_embeddings = ops.tile(encoding_1[3], (batch_size, 1))

latents = random.normal((1, height // 8, width // 8, 16), seed=42)
latents = ops.tile(latents, (batch_size, 1, 1, 1))

images = []
progbar = keras.utils.Progbar(batches)
for i in range(batches):
    images.append(
        generate_function(
            latents,
            (
                interpolated_positive_embeddings[i],
                negative_embeddings,
                interpolated_positive_pooled_embeddings[i],
                negative_pooled_embeddings,
            ),
            ops.convert_to_tensor(num_steps),
            ops.convert_to_tensor(guidance_scale),
        )
    )
    progbar.update(i + 1, finalize=i == batches - 1)


"""
Let's display the resulting images in a grid to make them easier to interpret.
"""


def plot_grid(images, path, grid_size, scale=2):
    fig, axs = plt.subplots(
        grid_size, grid_size, figsize=(grid_size * scale, grid_size * scale)
    )
    fig.tight_layout()
    plt.subplots_adjust(wspace=0, hspace=0)
    plt.axis("off")
    for ax in axs.flat:
        ax.axis("off")

    for i in range(min(grid_size * grid_size, len(images))):
        ax = axs.flat[i]
        ax.imshow(images[i])
        ax.axis("off")

    # Remove any unused axes in the grid.
    for i in range(len(images), grid_size * grid_size):
        axs.flat[i].axis("off")
        axs.flat[i].remove()

    plt.savefig(
        fname=path,
        pad_inches=0,
        bbox_inches="tight",
        transparent=False,
        dpi=60,
    )


images = ops.convert_to_numpy(decode_to_images(images, height, width))
plot_grid(images, "4-way-interpolation.jpg", interpolation_steps)

"""
We can also interpolate while allowing diffusion latents to vary by dropping
the `seed` parameter:
"""

images = []
progbar = keras.utils.Progbar(batches)
for i in range(batches):
    # Vary diffusion latents for each input.
    latents = random.normal((batch_size, height // 8, width // 8, 16))
    images.append(
        generate_function(
            latents,
            (
                interpolated_positive_embeddings[i],
                negative_embeddings,
                interpolated_positive_pooled_embeddings[i],
                negative_pooled_embeddings,
            ),
            ops.convert_to_tensor(num_steps),
            ops.convert_to_tensor(guidance_scale),
        )
    )
    progbar.update(i + 1, finalize=i == batches - 1)

images = ops.convert_to_numpy(decode_to_images(images, height, width))
plot_grid(images, "4-way-interpolation-varying-latent.jpg", interpolation_steps)

"""
Next up -- let's go for some walks!

## A walk around a text prompt

Our next experiment will be to go for a walk around the latent manifold
starting from a point produced by a particular prompt.
"""

walk_steps = 64
batch_size = 4
batches = walk_steps // batch_size
step_size = 0.01
prompt = "The eiffel tower in the style of starry night"
encoding = get_text_embeddings(prompt)

positive_embeddings = encoding[0]
positive_pooled_embeddings = encoding[2]
negative_embeddings = encoding[1]
negative_pooled_embeddings = encoding[3]

# The shape of `positive_embeddings`: (1, 154, 4096)
# The shape of `positive_pooled_embeddings`: (1, 2048)
positive_embeddings_delta = ops.ones_like(positive_embeddings) * step_size
positive_pooled_embeddings_delta = ops.ones_like(positive_pooled_embeddings) * step_size
positive_embeddings_shape = ops.shape(positive_embeddings)
positive_pooled_embeddings_shape = ops.shape(positive_pooled_embeddings)
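
# Note: each step adds `step_size` to every embedding coordinate, i.e. the walk
# moves along the all-ones direction; after `walk_steps` steps each coordinate
# has shifted by walk_steps * step_size = 0.64.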

walked_positive_embeddings = []
walked_positive_pooled_embeddings = []
for step_index in range(walk_steps):
    walked_positive_embeddings.append(positive_embeddings)
    walked_positive_pooled_embeddings.append(positive_pooled_embeddings)
    positive_embeddings += positive_embeddings_delta
    positive_pooled_embeddings += positive_pooled_embeddings_delta
walked_positive_embeddings = ops.stack(walked_positive_embeddings, axis=0)
walked_positive_pooled_embeddings = ops.stack(walked_positive_pooled_embeddings, axis=0)
walked_positive_embeddings = ops.reshape(
    walked_positive_embeddings,
    (
        batches,
        batch_size,
        positive_embeddings_shape[-2],
        positive_embeddings_shape[-1],
    ),
)
walked_positive_pooled_embeddings = ops.reshape(
    walked_positive_pooled_embeddings,
    (batches, batch_size, positive_pooled_embeddings_shape[-1]),
)
negative_embeddings = ops.tile(encoding[1], (batch_size, 1, 1))
negative_pooled_embeddings = ops.tile(encoding[3], (batch_size, 1))

latents = random.normal((1, height // 8, width // 8, 16), seed=42)
latents = ops.tile(latents, (batch_size, 1, 1, 1))

images = []
progbar = keras.utils.Progbar(batches)
for i in range(batches):
    images.append(
        generate_function(
            latents,
            (
                walked_positive_embeddings[i],
                negative_embeddings,
                walked_positive_pooled_embeddings[i],
                negative_pooled_embeddings,
            ),
            ops.convert_to_tensor(num_steps),
            ops.convert_to_tensor(guidance_scale),
        )
    )
    progbar.update(i + 1, finalize=i == batches - 1)

images = ops.convert_to_numpy(decode_to_images(images, height, width))
export_as_gif(
    "eiffel-tower-starry-night.gif",
    [Image.fromarray(image) for image in images],
    frames_per_second=2,
)

"""
Perhaps unsurprisingly, walking too far from the encoder's latent manifold
produces images that look incoherent. Try it for yourself by setting your own
prompt, and adjusting `step_size` to increase or decrease the magnitude of the
walk. Note that when the magnitude of the walk gets large, the walk often leads
into areas which produce extremely noisy images.

## A circular walk through the diffusion latent space for a single prompt

Our final experiment is to stick to one prompt and explore the variety of images
that the diffusion model can produce from that prompt. We do this by controlling
the noise that is used to seed the diffusion process.

We create two noise components, `x` and `y`, and do a walk from 0 to 2π, scaling
our `x` component by the cosine of the walk angle and our `y` component by the
sine, then summing the two to produce noise. With this approach, the end of our
walk arrives at the same noise inputs where we began, so we get a "loopable"
result!
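
Concretely, for a walk fraction `t` in `[0, 1]`, each frame uses the noise
`cos(2 * pi * t) * x + sin(2 * pi * t) * y`, which returns to the starting
noise `x` at `t = 1`.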
"""

walk_steps = 64
batch_size = 4
batches = walk_steps // batch_size
prompt = "An oil painting of cows in a field next to a windmill in Holland"
encoding = get_text_embeddings(prompt)

walk_latent_x = random.normal((1, height // 8, width // 8, 16))
walk_latent_y = random.normal((1, height // 8, width // 8, 16))
walk_scale_x = ops.cos(ops.linspace(0.0, 2.0, walk_steps) * math.pi)
walk_scale_y = ops.sin(ops.linspace(0.0, 2.0, walk_steps) * math.pi)
latent_x = ops.tensordot(walk_scale_x, walk_latent_x, axes=0)
latent_y = ops.tensordot(walk_scale_y, walk_latent_y, axes=0)
latents = ops.add(latent_x, latent_y)
latents = ops.reshape(latents, (batches, batch_size, height // 8, width // 8, 16))

images = []
progbar = keras.utils.Progbar(batches)
for i in range(batches):
    images.append(
        generate_function(
            latents[i],
            (
                ops.tile(encoding[0], (batch_size, 1, 1)),
                ops.tile(encoding[1], (batch_size, 1, 1)),
                ops.tile(encoding[2], (batch_size, 1)),
                ops.tile(encoding[3], (batch_size, 1)),
            ),
            ops.convert_to_tensor(num_steps),
            ops.convert_to_tensor(guidance_scale),
        )
    )
    progbar.update(i + 1, finalize=i == batches - 1)

images = ops.convert_to_numpy(decode_to_images(images, height, width))
export_as_gif(
    "cows.gif",
    [Image.fromarray(image) for image in images],
    frames_per_second=4,
    no_rubber_band=True,
)


"""
Experiment with your own prompts and with different values of the parameters!

## Conclusion

Stable Diffusion 3 offers a lot more than just single text-to-image generation.
Exploring the latent manifold of the text encoder and the latent space of the
diffusion model are two fun ways to experience the power of this model, and
KerasHub makes it easy!
"""