# GitHub Repository: keras-team/keras-io
# Path: blob/master/examples/generative/fine_tune_via_textual_inversion.py
"""
Title: Teach StableDiffusion new concepts via Textual Inversion
Authors: Ian Stenbit, [lukewood](https://lukewood.xyz)
Date created: 2022/12/09
Last modified: 2022/12/09
Description: Learning new visual concepts with KerasCV's StableDiffusion implementation.
"""

"""
## Textual Inversion

Since its release, StableDiffusion has quickly become a favorite amongst
the generative machine learning community.
The high volume of traffic has led to open source contributed improvements,
heavy prompt engineering, and even the invention of novel algorithms.

Perhaps the most impressive new algorithm being used is
[Textual Inversion](https://github.com/rinongal/textual_inversion), presented in
[_An Image is Worth One Word: Personalizing Text-to-Image Generation using Textual Inversion_](https://textual-inversion.github.io/).

Textual Inversion is the process of teaching an image generator a specific visual concept
through the use of fine-tuning. In the diagram below, you can see an
example of this process where the authors teach the model new concepts, calling them
"S_*".

![https://i.imgur.com/KqEeBsM.jpg](https://i.imgur.com/KqEeBsM.jpg)

Conceptually, textual inversion works by learning a token embedding for a new text
token, keeping the remaining components of StableDiffusion frozen.

This guide shows you how to fine-tune the StableDiffusion model shipped in KerasCV
using the Textual-Inversion algorithm. By the end of the guide, you will be able to
write prompts such as "Gandalf the Gray as a <my-funny-cat-token>".

![https://i.imgur.com/rcb1Yfx.png](https://i.imgur.com/rcb1Yfx.png)


First, let's import the packages we need, and create a
StableDiffusion instance so we can use some of its subcomponents for fine-tuning.
"""

"""shell
pip install -q git+https://github.com/keras-team/keras-cv.git
pip install -q tensorflow==2.11.0
"""

import math

import keras_cv
import numpy as np
import tensorflow as tf
from keras_cv import layers as cv_layers
from keras_cv.models.stable_diffusion import NoiseScheduler
from tensorflow import keras
import matplotlib.pyplot as plt

stable_diffusion = keras_cv.models.StableDiffusion()

"""
Next, let's define a visualization utility to show off the generated images:
"""


def plot_images(images):
    plt.figure(figsize=(20, 20))
    for i in range(len(images)):
        ax = plt.subplot(1, len(images), i + 1)
        plt.imshow(images[i])
        plt.axis("off")


"""
## Assembling a text-image pair dataset

In order to train the embedding of our new token, we first must assemble a dataset
consisting of text-image pairs.
Each sample from the dataset must contain an image of the concept we are teaching
StableDiffusion, as well as a caption accurately representing the content of the image.
In this tutorial, we will teach StableDiffusion the concept of Luke and Ian's GitHub
avatars:

![gh-avatars](https://i.imgur.com/WyEHDIR.jpg)

First, let's construct an image dataset of cat dolls:
"""


def assemble_image_dataset(urls):
    # Fetch all remote files
    files = [tf.keras.utils.get_file(origin=url) for url in urls]

    # Resize images
    resize = keras.layers.Resizing(height=512, width=512, crop_to_aspect_ratio=True)
    images = [keras.utils.load_img(img) for img in files]
    images = [keras.utils.img_to_array(img) for img in images]
    images = np.array([resize(img) for img in images])

    # The StableDiffusion image encoder requires images to be normalized to the
    # [-1, 1] pixel value range
    images = images / 127.5 - 1

    # Create the tf.data.Dataset
    image_dataset = tf.data.Dataset.from_tensor_slices(images)

    # Shuffle and apply random augmentations (random crop-and-resize, then a
    # random horizontal flip)
    image_dataset = image_dataset.shuffle(50, reshuffle_each_iteration=True)
    image_dataset = image_dataset.map(
        cv_layers.RandomCropAndResize(
            target_size=(512, 512),
            crop_area_factor=(0.8, 1.0),
            aspect_ratio_factor=(1.0, 1.0),
        ),
        num_parallel_calls=tf.data.AUTOTUNE,
    )
    image_dataset = image_dataset.map(
        cv_layers.RandomFlip(mode="horizontal"),
        num_parallel_calls=tf.data.AUTOTUNE,
    )
    return image_dataset


"""
Next, we assemble a text dataset:
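
As a rough illustration of the padding applied below, here is a minimal, standalone
sketch. The token IDs and the `END_OF_TEXT` value are made-up stand-ins (assumptions for
illustration only); in the guide they come from `stable_diffusion.tokenizer`.

```python
MAX_PROMPT_LENGTH = 77
END_OF_TEXT = 49407  # assumed end-of-text ID, used here only for illustration


def pad_tokens(token_ids, pad_id=END_OF_TEXT, length=MAX_PROMPT_LENGTH):
    # Right-pad a list of token IDs with the end-of-text ID up to a fixed length.
    if len(token_ids) > length:
        raise ValueError("prompt is longer than the maximum prompt length")
    return token_ids + [pad_id] * (length - len(token_ids))


tokens = [49406, 320, 1125, 539, 320, 49408, 49407]  # hypothetical token IDs
padded = pad_tokens(tokens)
print(len(padded))  # 77
```

Every prompt ends up with the same length, which is what allows the token sequences to
be stacked into a single `tf.data.Dataset`.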
"""

MAX_PROMPT_LENGTH = 77
placeholder_token = "<my-funny-cat-token>"


def pad_embedding(embedding):
    # Pad a list of token IDs with the end-of-text token up to MAX_PROMPT_LENGTH
    return embedding + (
        [stable_diffusion.tokenizer.end_of_text] * (MAX_PROMPT_LENGTH - len(embedding))
    )


stable_diffusion.tokenizer.add_tokens(placeholder_token)


def assemble_text_dataset(prompts):
    prompts = [prompt.format(placeholder_token) for prompt in prompts]
    # Despite the variable name, these are sequences of token IDs produced by the
    # tokenizer, not embedding vectors
    embeddings = [stable_diffusion.tokenizer.encode(prompt) for prompt in prompts]
    embeddings = [np.array(pad_embedding(embedding)) for embedding in embeddings]
    text_dataset = tf.data.Dataset.from_tensor_slices(embeddings)
    text_dataset = text_dataset.shuffle(100, reshuffle_each_iteration=True)
    return text_dataset


"""
Finally, we zip our datasets together to produce a text-image pair dataset.
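
To see why the prompts end up determining the length of the zipped dataset, here is a
plain-Python analogue of repeating and zipping. The lists are hypothetical stand-ins for
the image and prompt datasets, and shuffling is left out for clarity.

```python
import itertools

images = ["img_0", "img_1", "img_2"]  # hypothetical stand-ins for image tensors
prompts = ["prompt_%d" % i for i in range(4)]  # hypothetical stand-ins for token IDs

# Analogue of image_dataset.repeat(): an endless stream of images.
endless_images = itertools.cycle(images)
# Analogue of text_dataset.repeat(5): the prompts, five times over.
repeated_prompts = prompts * 5

# zip stops at the shorter input, so the finite prompt stream sets the length.
pairs = list(zip(endless_images, repeated_prompts))
print(len(pairs))  # 20
```

With 4 prompts repeated 5 times, the zipped result has 20 pairs, regardless of how many
images there are.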
"""


def assemble_dataset(urls, prompts):
    image_dataset = assemble_image_dataset(urls)
    text_dataset = assemble_text_dataset(prompts)
    # The image dataset is quite short, so we repeat it indefinitely; zipping it
    # with the text prompt dataset then truncates it to the prompt dataset's length.
    image_dataset = image_dataset.repeat()
    # We use the text prompt dataset to determine the length of the dataset. Because
    # there are relatively few prompts, we repeat the dataset 5 times.
    # Anecdotally, we have found that this improves results.
    text_dataset = text_dataset.repeat(5)
    return tf.data.Dataset.zip((image_dataset, text_dataset))


"""
To make sure each prompt accurately describes whichever image it is paired with, we use
extremely generic prompts.

Let's try this out with some sample images and prompts.
"""

train_ds = assemble_dataset(
    urls=[
        "https://i.imgur.com/VIedH1X.jpg",
        "https://i.imgur.com/eBw13hE.png",
        "https://i.imgur.com/oJ3rSg7.png",
        "https://i.imgur.com/5mCL6Df.jpg",
        "https://i.imgur.com/4Q6WWyI.jpg",
    ],
    prompts=[
        "a photo of a {}",
        "a rendering of a {}",
        "a cropped photo of the {}",
        "the photo of a {}",
        "a photo of a clean {}",
        "a dark photo of the {}",
        "a photo of my {}",
        "a photo of the cool {}",
        "a close-up photo of a {}",
        "a bright photo of the {}",
        "a cropped photo of a {}",
        "a photo of the {}",
        "a good photo of the {}",
        "a photo of one {}",
        "a close-up photo of the {}",
        "a rendition of the {}",
        "a photo of the clean {}",
        "a rendition of a {}",
        "a photo of a nice {}",
        "a good photo of a {}",
        "a photo of the nice {}",
        "a photo of the small {}",
        "a photo of the weird {}",
        "a photo of the large {}",
        "a photo of a cool {}",
        "a photo of a small {}",
    ],
)

"""
## On the importance of prompt accuracy

During our first attempt at writing this guide we included images of groups of these cat
dolls in our dataset but continued to use the generic prompts listed above.
Our results were anecdotally poor. For example, here's cat doll Gandalf using this method:

![mediocre-wizard](https://i.imgur.com/Thq7XOu.jpg)

It's conceptually close, but it isn't as great as it can be.

In order to remedy this, we began experimenting with splitting our images into images of
singular cat dolls and groups of cat dolls.
Following this split, we came up with new prompts for the group shots.

Training on text-image pairs that accurately represent the content boosted the quality
of our results *substantially*. This speaks to the importance of prompt accuracy.

In addition to separating the images into singular and group images, we also remove some
inaccurate prompts, such as "a dark photo of the {}".

Keeping this in mind, we assemble our final training dataset below:
"""

single_ds = assemble_dataset(
    urls=[
        "https://i.imgur.com/VIedH1X.jpg",
        "https://i.imgur.com/eBw13hE.png",
        "https://i.imgur.com/oJ3rSg7.png",
        "https://i.imgur.com/5mCL6Df.jpg",
        "https://i.imgur.com/4Q6WWyI.jpg",
    ],
    prompts=[
        "a photo of a {}",
        "a rendering of a {}",
        "a cropped photo of the {}",
        "the photo of a {}",
        "a photo of a clean {}",
        "a photo of my {}",
        "a photo of the cool {}",
        "a close-up photo of a {}",
        "a bright photo of the {}",
        "a cropped photo of a {}",
        "a photo of the {}",
        "a good photo of the {}",
        "a photo of one {}",
        "a close-up photo of the {}",
        "a rendition of the {}",
        "a photo of the clean {}",
        "a rendition of a {}",
        "a photo of a nice {}",
        "a good photo of a {}",
        "a photo of the nice {}",
        "a photo of the small {}",
        "a photo of the weird {}",
        "a photo of the large {}",
        "a photo of a cool {}",
        "a photo of a small {}",
    ],
)

"""
![https://i.imgur.com/gQCRjK6.png](https://i.imgur.com/gQCRjK6.png)

Looks great!

Next, we assemble a dataset of groups of our GitHub avatars:
"""

group_ds = assemble_dataset(
    urls=[
        "https://i.imgur.com/yVmZ2Qa.jpg",
        "https://i.imgur.com/JbyFbZJ.jpg",
        "https://i.imgur.com/CCubd3q.jpg",
    ],
    prompts=[
        "a photo of a group of {}",
        "a rendering of a group of {}",
        "a cropped photo of the group of {}",
        "the photo of a group of {}",
        "a photo of a clean group of {}",
        "a photo of my group of {}",
        "a photo of a cool group of {}",
        "a close-up photo of a group of {}",
        "a bright photo of the group of {}",
        "a cropped photo of a group of {}",
        "a photo of the group of {}",
        "a good photo of the group of {}",
        "a photo of one group of {}",
        "a close-up photo of the group of {}",
        "a rendition of the group of {}",
        "a photo of the clean group of {}",
        "a rendition of a group of {}",
        "a photo of a nice group of {}",
        "a good photo of a group of {}",
        "a photo of the nice group of {}",
        "a photo of the small group of {}",
        "a photo of the weird group of {}",
        "a photo of the large group of {}",
        "a photo of a cool group of {}",
        "a photo of a small group of {}",
    ],
)

"""
![https://i.imgur.com/GY9Pf3D.png](https://i.imgur.com/GY9Pf3D.png)

Finally, we concatenate the two datasets:
"""

train_ds = single_ds.concatenate(group_ds)
train_ds = train_ds.batch(1).shuffle(
    train_ds.cardinality(), reshuffle_each_iteration=True
)

"""
## Adding a new token to the text encoder

Next, we create a new text encoder for the StableDiffusion model and add our new
embedding for '<my-funny-cat-token>' into the model.
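
The core idea can be sketched on a toy embedding table in NumPy: append one row to the
table and initialize it from the row of an existing token. The sizes and the
`initializer_id` value here are hypothetical and chosen only to keep the example small.

```python
import numpy as np

rng = np.random.default_rng(0)
vocab_size, embed_dim = 10, 4  # toy sizes (hypothetical); the real table is far larger
old_table = rng.normal(size=(vocab_size, embed_dim))

initializer_id = 3  # hypothetical ID of the token that seeds the new embedding
new_row = old_table[initializer_id][np.newaxis, :]  # shape (1, embed_dim)

# Append the new row to the end of the table.
new_table = np.concatenate([old_table, new_row], axis=0)

# The appended row takes the next free index, which is the old vocabulary size.
new_token_id = vocab_size
print(new_table.shape)  # (11, 4)
```

Because the new row is appended at the end, all existing token IDs keep pointing at
their original embeddings.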
"""
# Index 0 of the encoded prompt is the start-of-text token, so index 1 is "cat"
tokenized_initializer = stable_diffusion.tokenizer.encode("cat")[1]
new_weights = stable_diffusion.text_encoder.layers[2].token_embedding(
    tf.constant(tokenized_initializer)
)

# Get len of .vocab instead of tokenizer
new_vocab_size = len(stable_diffusion.tokenizer.vocab)

# The embedding layer is at index 2 of the text encoder's layers
old_token_weights = stable_diffusion.text_encoder.layers[
    2
].token_embedding.get_weights()
old_position_weights = stable_diffusion.text_encoder.layers[
    2
].position_embedding.get_weights()

# Append the new token's embedding as the last row of the token embedding table
old_token_weights = old_token_weights[0]
new_weights = np.expand_dims(new_weights, axis=0)
new_weights = np.concatenate([old_token_weights, new_weights], axis=0)


"""
Let's construct a new TextEncoder and prepare it.
"""

# Have to set download_weights False so we can init (otherwise tries to load weights)
new_encoder = keras_cv.models.stable_diffusion.TextEncoder(
    keras_cv.models.stable_diffusion.stable_diffusion.MAX_PROMPT_LENGTH,
    vocab_size=new_vocab_size,
    download_weights=False,
)
for index, layer in enumerate(stable_diffusion.text_encoder.layers):
    # Layer 2 is the embedding layer, so we omit it from our weight-copying
    if index == 2:
        continue
    new_encoder.layers[index].set_weights(layer.get_weights())


new_encoder.layers[2].token_embedding.set_weights([new_weights])
new_encoder.layers[2].position_embedding.set_weights(old_position_weights)

stable_diffusion._text_encoder = new_encoder
stable_diffusion._text_encoder.compile(jit_compile=True)

"""
## Training

Now we can move on to the exciting part: training!

In Textual Inversion, the only piece of the model that is trained is the embedding vector.
Let's freeze the rest of the model.
"""


stable_diffusion.diffusion_model.trainable = False
stable_diffusion.decoder.trainable = False
stable_diffusion.text_encoder.trainable = True

stable_diffusion.text_encoder.layers[2].trainable = True


def traverse_layers(layer):
    if hasattr(layer, "layers"):
        for layer in layer.layers:
            yield layer
    if hasattr(layer, "token_embedding"):
        yield layer.token_embedding
    if hasattr(layer, "position_embedding"):
        yield layer.position_embedding


for layer in traverse_layers(stable_diffusion.text_encoder):
    if isinstance(layer, keras.layers.Embedding) or "clip_embedding" in layer.name:
        layer.trainable = True
    else:
        layer.trainable = False

new_encoder.layers[2].position_embedding.trainable = False

"""
Let's confirm the proper weights are set to trainable.
"""

all_models = [
    stable_diffusion.text_encoder,
    stable_diffusion.diffusion_model,
    stable_diffusion.decoder,
]
print([[w.shape for w in model.trainable_weights] for model in all_models])

"""
## Training the new embedding

In order to train the embedding, we need a couple of utilities.
We import a NoiseScheduler from KerasCV, and define the following utilities:

- `sample_from_encoder_outputs` is a wrapper around the base StableDiffusion image
encoder which samples from the statistical distribution produced by the image
encoder, rather than taking just the mean (like many other SD applications)
- `get_timestep_embedding` produces an embedding for a specified timestep for the
diffusion model
- `get_position_ids` produces a tensor of position IDs for the text encoder (which is just
the series `0, 1, ..., MAX_PROMPT_LENGTH - 1`)
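
To build some intuition for these utilities, here is a minimal NumPy sketch of the same
computations. The helper names `sample_latents` and `timestep_embedding` and the tiny
latent shape are hypothetical, picked for illustration; the guide's TensorFlow versions
follow after this section.

```python
import math

import numpy as np


def sample_latents(mean, logvar, rng):
    # Reparameterization: draw a sample from N(mean, exp(logvar)).
    logvar = np.clip(logvar, -30.0, 20.0)
    std = np.exp(0.5 * logvar)
    return mean + std * rng.standard_normal(mean.shape)


def timestep_embedding(timestep, dim=320, max_period=10000):
    # Sinusoidal embedding: cosines followed by sines at geometrically spaced
    # frequencies.
    half = dim // 2
    freqs = np.exp(-math.log(max_period) * np.arange(half, dtype=np.float32) / half)
    args = np.float32(timestep) * freqs
    return np.concatenate([np.cos(args), np.sin(args)])


rng = np.random.default_rng(0)
mean = np.zeros((1, 8, 8, 4), dtype=np.float32)  # toy latent shape (hypothetical)
logvar = np.full_like(mean, -30.0)  # tiny variance, so samples stay close to the mean
sample = sample_latents(mean, logvar, rng)

emb = timestep_embedding(0)
position_ids = np.arange(77)[np.newaxis, :]  # 0, 1, ..., 76
print(sample.shape, emb.shape, position_ids.shape)
```

At timestep 0 every cosine entry is 1 and every sine entry is 0, which makes the layout
of the embedding (cosines first, then sines) easy to check.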
"""


# Remove the top layer from the encoder, which cuts off the variance and only returns
# the mean
training_image_encoder = keras.Model(
    stable_diffusion.image_encoder.input,
    stable_diffusion.image_encoder.layers[-2].output,
)


def sample_from_encoder_outputs(outputs):
    mean, logvar = tf.split(outputs, 2, axis=-1)
    logvar = tf.clip_by_value(logvar, -30.0, 20.0)
    std = tf.exp(0.5 * logvar)
    sample = tf.random.normal(tf.shape(mean))
    return mean + std * sample


def get_timestep_embedding(timestep, dim=320, max_period=10000):
    half = dim // 2
    freqs = tf.math.exp(
        -math.log(max_period) * tf.range(0, half, dtype=tf.float32) / half
    )
    args = tf.convert_to_tensor([timestep], dtype=tf.float32) * freqs
    embedding = tf.concat([tf.math.cos(args), tf.math.sin(args)], 0)
    return embedding


def get_position_ids():
    return tf.convert_to_tensor([list(range(MAX_PROMPT_LENGTH))], dtype=tf.int32)


"""
Next, we implement a `StableDiffusionFineTuner`, which is a subclass of `keras.Model`
that overrides `train_step` to train the token embeddings of our text encoder.
This is the core of the Textual Inversion algorithm.

Abstractly speaking, the train step takes a sample from the output of the frozen SD
image encoder's latent distribution for a training image, adds noise to that sample, and
then passes that noisy sample to the frozen diffusion model.
The hidden state of the diffusion model is the output of the text encoder for the prompt
corresponding to the image.

Our final goal state is that the diffusion model is able to separate the noise from the
sample using the text encoding as hidden state, so our loss is the mean-squared error
between the noise and the output of the diffusion model (which has, ideally, removed the
image latents from the noise).

We compute gradients for only the token embeddings of the text encoder, and in the
train step we zero-out the gradients for all tokens other than the token that we're
learning.

See in-line code comments for more details about the train step.
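
The gradient-masking idea can be sketched in NumPy on a toy embedding table. The table
size, the `placeholder_id` value, and the plain gradient-descent update are assumptions
made for illustration; the guide itself uses the optimizer configured later on.

```python
import numpy as np

placeholder_id = 5  # hypothetical row index of the new token in a toy table

# A sparse gradient: one row of values for each token ID that was looked up.
indices = np.array([0, 2, 5, 1])
values = np.ones((4, 3))

# Keep only the gradient rows that belong to the placeholder token.
keep = (indices == placeholder_id)[:, np.newaxis]
masked_values = np.where(keep, values, 0.0)

# Apply a plain gradient-descent step to a toy embedding table.
table = np.zeros((6, 3))
learning_rate = 0.1
np.subtract.at(table, indices, learning_rate * masked_values)
print(table)
```

Only the row at `placeholder_id` changes; the rows of every other token stay exactly as
they were.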
"""


class StableDiffusionFineTuner(keras.Model):
    def __init__(self, stable_diffusion, noise_scheduler, **kwargs):
        super().__init__(**kwargs)
        self.stable_diffusion = stable_diffusion
        self.noise_scheduler = noise_scheduler

    def train_step(self, data):
        images, embeddings = data

        with tf.GradientTape() as tape:
            # Sample from the predicted distribution for the training image
            latents = sample_from_encoder_outputs(training_image_encoder(images))
            # The latents must be scaled to match the scale of the latents used
            # in the training of StableDiffusion. This number is truly just a "magic"
            # constant that they chose when training the model.
            latents = latents * 0.18215

            # Produce random noise in the same shape as the latent sample
            noise = tf.random.normal(tf.shape(latents))
            batch_dim = tf.shape(latents)[0]

            # Pick a random timestep for each sample in the batch
            timesteps = tf.random.uniform(
                (batch_dim,),
                minval=0,
                maxval=self.noise_scheduler.train_timesteps,
                dtype=tf.int64,
            )

            # Add noise to the latents based on the timestep for each sample
            noisy_latents = self.noise_scheduler.add_noise(latents, noise, timesteps)

            # Encode the text in the training samples to use as hidden state in the
            # diffusion model
            encoder_hidden_state = self.stable_diffusion.text_encoder(
                [embeddings, get_position_ids()]
            )

            # Compute timestep embeddings for the randomly-selected timesteps for each
            # sample in the batch
            timestep_embeddings = tf.map_fn(
                fn=get_timestep_embedding,
                elems=timesteps,
                fn_output_signature=tf.float32,
            )

            # Call the diffusion model
            noise_pred = self.stable_diffusion.diffusion_model(
                [noisy_latents, timestep_embeddings, encoder_hidden_state]
            )

            # Compute the mean-squared error loss (target first, prediction second)
            # and reduce it.
            loss = self.compiled_loss(noise, noise_pred)
            loss = tf.reduce_mean(loss, axis=2)
            loss = tf.reduce_mean(loss, axis=1)
            loss = tf.reduce_mean(loss)

        # Load the trainable weights and compute the gradients for them
        trainable_weights = self.stable_diffusion.text_encoder.trainable_weights
        grads = tape.gradient(loss, trainable_weights)

        # Gradients are stored in indexed slices, so we have to find the slice(s)
        # which belong to the placeholder token. Its ID is 49408 because it was
        # appended to the end of the original 49408-entry vocabulary.
        condition = grads[0].indices == 49408
        condition = tf.expand_dims(condition, axis=-1)

        # Override the gradients, zeroing out the gradients for all slices that
        # aren't for the placeholder token, effectively freezing the weights for
        # all other tokens.
        grads[0] = tf.IndexedSlices(
            values=tf.where(condition, grads[0].values, 0),
            indices=grads[0].indices,
            dense_shape=grads[0].dense_shape,
        )

        self.optimizer.apply_gradients(zip(grads, trainable_weights))
        return {"loss": loss}


"""
Before we start training, let's take a look at what StableDiffusion produces for our
token.
"""

generated = stable_diffusion.text_to_image(
    f"an oil painting of {placeholder_token}", seed=1337, batch_size=3
)
plot_images(generated)

"""
As you can see, the model still thinks of our token as a cat, as this was the seed token
we used to initialize our custom token.

Now, to get started with training, we can just `compile()` our model like any other
Keras model. Before doing so, we also instantiate a noise scheduler for training and
configure our training parameters such as learning rate and optimizer.
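
For intuition about these settings, here is a rough NumPy sketch of a scaled-linear
noise schedule and a cosine learning-rate decay. It assumes the commonly used
formulations (betas linear in their square root, and a half-cosine decay to zero); the
helper names and the `decay_steps=1000` default are hypothetical, and the library
implementations may differ in detail.

```python
import math

import numpy as np

train_timesteps = 1000
beta_start, beta_end = 0.00085, 0.012

# "scaled_linear" (assumed formulation): linear in sqrt(beta), then squared.
betas = np.linspace(beta_start**0.5, beta_end**0.5, train_timesteps) ** 2
alphas_cumprod = np.cumprod(1.0 - betas)


def add_noise(latents, noise, t):
    # Blend signal and noise according to the cumulative alpha at timestep t.
    signal_scale = math.sqrt(alphas_cumprod[t])
    noise_scale = math.sqrt(1.0 - alphas_cumprod[t])
    return signal_scale * latents + noise_scale * noise


def cosine_decay(step, initial_learning_rate=1e-4, decay_steps=1000):
    # Half-cosine decay from the initial learning rate down to zero.
    progress = min(step, decay_steps) / decay_steps
    return initial_learning_rate * 0.5 * (1.0 + math.cos(math.pi * progress))


print(alphas_cumprod[0], alphas_cumprod[-1])
print(cosine_decay(0), cosine_decay(500), cosine_decay(1000))
```

Early timesteps keep the latents almost intact, late timesteps are almost pure noise,
and the learning rate shrinks smoothly toward zero over the course of training.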
"""

noise_scheduler = NoiseScheduler(
    beta_start=0.00085,
    beta_end=0.012,
    beta_schedule="scaled_linear",
    train_timesteps=1000,
)
trainer = StableDiffusionFineTuner(stable_diffusion, noise_scheduler, name="trainer")
EPOCHS = 50
learning_rate = keras.optimizers.schedules.CosineDecay(
    initial_learning_rate=1e-4, decay_steps=train_ds.cardinality() * EPOCHS
)
optimizer = keras.optimizers.Adam(
    weight_decay=0.004, learning_rate=learning_rate, epsilon=1e-8, global_clipnorm=10
)

trainer.compile(
    optimizer=optimizer,
    # We are performing reduction manually in our train step, so none is required here.
    loss=keras.losses.MeanSquaredError(reduction="none"),
)

"""
To monitor training, we can define a `keras.callbacks.Callback` that generates a few
images every `frequency` epochs (every 10 epochs by default) using our custom token.

We create three callbacks with different prompts so that we can see how they progress
over the course of training. We use a fixed seed so that we can easily see the
progression of the learned token.
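
As a quick check of when previews are generated, the `epoch % frequency == 0` test used
by the callback can be evaluated on its own. The values mirror the defaults used in this
guide; Keras passes zero-based epoch indices to `on_epoch_end`.

```python
EPOCHS = 50
frequency = 10

# Epoch indices (zero-based) at which the callback would generate images.
preview_epochs = [epoch for epoch in range(EPOCHS) if epoch % frequency == 0]
print(preview_epochs)  # [0, 10, 20, 30, 40]
```

Note that the final epoch (index 49) does not trigger a preview with these settings.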
"""


class GenerateImages(keras.callbacks.Callback):
    def __init__(
        self, stable_diffusion, prompt, steps=50, frequency=10, seed=None, **kwargs
    ):
        super().__init__(**kwargs)
        self.stable_diffusion = stable_diffusion
        self.prompt = prompt
        self.seed = seed
        self.frequency = frequency
        self.steps = steps

    def on_epoch_end(self, epoch, logs=None):
        # Only generate previews every `frequency` epochs (epoch is zero-based)
        if epoch % self.frequency == 0:
            images = self.stable_diffusion.text_to_image(
                self.prompt, batch_size=3, num_steps=self.steps, seed=self.seed
            )
            plot_images(
                images,
            )


cbs = [
    GenerateImages(
        stable_diffusion, prompt=f"an oil painting of {placeholder_token}", seed=1337
    ),
    GenerateImages(
        stable_diffusion, prompt=f"gandalf the gray as a {placeholder_token}", seed=1337
    ),
    GenerateImages(
        stable_diffusion,
        prompt=f"two {placeholder_token} getting married, photorealistic, high quality",
        seed=1337,
    ),
]

"""
Now, all that is left to do is to call `model.fit()`!
"""

trainer.fit(
    train_ds,
    epochs=EPOCHS,
    callbacks=cbs,
)

"""
It's pretty fun to see how the model learns our new token over time. Play around with it
and see how you can tune training parameters and your training dataset to produce the
best images!
"""

"""
## Taking the Fine Tuned Model for a Spin

Now for the really fun part. We've learned a token embedding for our custom token, so
now we can generate images with StableDiffusion the same way we would for any other
token!

Here are some fun example prompts to get you started, with sample outputs from our cat
doll token!
"""

generated = stable_diffusion.text_to_image(
    f"Gandalf as a {placeholder_token} fantasy art drawn by disney concept artists, "
    "golden colour, high quality, highly detailed, elegant, sharp focus, concept art, "
    "character concepts, digital painting, mystery, adventure",
    batch_size=3,
)
plot_images(generated)

"""
"""

generated = stable_diffusion.text_to_image(
    f"A masterpiece of a {placeholder_token} crying out to the heavens. "
    f"Behind the {placeholder_token}, a dark, evil shade looms over it - sucking the "
    "life right out of it.",
    batch_size=3,
)
plot_images(generated)

"""
"""

generated = stable_diffusion.text_to_image(
    f"An evil {placeholder_token}.", batch_size=3
)
plot_images(generated)

"""
"""

generated = stable_diffusion.text_to_image(
    f"A mysterious {placeholder_token} approaches the great pyramids of egypt.",
    batch_size=3,
)
plot_images(generated)

"""
## Conclusions

Using the Textual Inversion algorithm you can teach StableDiffusion new concepts!

Some possible next steps to follow:

- Try out your own prompts
- Teach the model a style
- Gather a dataset of your favorite pet cat or dog and teach the model about it
"""