"""
2
Title: DreamBooth
3
Author: [Sayak Paul](https://twitter.com/RisingSayak), [Chansung Park](https://twitter.com/algo_diver)
4
Date created: 2023/02/01
5
Last modified: 2023/02/05
6
Description: Implementing DreamBooth.
7
Accelerator: GPU
8
"""

"""
## Introduction

In this example, we implement DreamBooth, a fine-tuning technique to teach new visual
concepts to text-conditioned Diffusion models with just 3 - 5 images. DreamBooth was
proposed in
[DreamBooth: Fine Tuning Text-to-Image Diffusion Models for Subject-Driven Generation](https://arxiv.org/abs/2208.12242)
by Ruiz et al.

DreamBooth, in a sense, is similar to the
[traditional way of fine-tuning a text-conditioned Diffusion model](https://keras.io/examples/generative/finetune_stable_diffusion/),
except for a few gotchas. This example assumes that you have basic familiarity with
Diffusion models and how to fine-tune them. Here are some reference examples that might
help you get up to speed quickly:

* [High-performance image generation using Stable Diffusion in KerasCV](https://keras.io/guides/keras_cv/generate_images_with_stable_diffusion/)
* [Teach StableDiffusion new concepts via Textual Inversion](https://keras.io/examples/generative/fine_tune_via_textual_inversion/)
* [Fine-tuning Stable Diffusion](https://keras.io/examples/generative/finetune_stable_diffusion/)

First, let's install the latest versions of KerasCV and TensorFlow.
"""

"""shell
pip install -q -U keras_cv==0.6.0
pip install -q -U tensorflow
"""

"""
If you're running the code, please ensure you're using a GPU with at least 24 GB of
VRAM.
"""

"""
## Initial imports
"""

import math

import keras_cv
import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
from imutils import paths
from tensorflow import keras

"""
## Usage of DreamBooth

... is very versatile. By teaching Stable Diffusion about your favorite visual
concepts, you can

* Recontextualize objects in interesting ways:

![](https://i.imgur.com/4Da9ozw.png)

* Generate artistic renderings of the underlying visual concept:

![](https://i.imgur.com/nI2N8bI.png)

And many other applications. We welcome you to check out the original
[DreamBooth paper](https://arxiv.org/abs/2208.12242) in this regard.
"""

"""
## Download the instance and class images

DreamBooth uses a technique called "prior preservation" to meaningfully guide the
training procedure such that the fine-tuned models can still preserve some of the prior
semantics of the visual concept you're introducing. To know more about the idea of
"prior preservation", refer to [this document](https://dreambooth.github.io/).

Here, we need to introduce a few key terms specific to DreamBooth:

* **Unique class**: Examples include "dog", "person", etc. In this example, we use "dog".
* **Unique identifier**: A unique identifier that is prepended to the unique class while
forming the "instance prompts". In this example, we use "sks" as this unique identifier.
* **Instance prompt**: Denotes a prompt that best describes the "instance images". An
example prompt could be `f"a photo of {unique_id} {unique_class}"`. So, for our example,
this becomes "a photo of sks dog" (see the short sketch after this list).
* **Class prompt**: Denotes a prompt without the unique identifier. This prompt is used
for generating "class images" for prior preservation. For our example, this prompt is
"a photo of dog".
* **Instance images**: Denote the images that represent the visual concept you're trying
to teach, a.k.a. the "instance prompt". This number is typically just 3 - 5. We typically
gather these images ourselves.
* **Class images**: Denote the images generated using the "class prompt", used for prior
preservation in DreamBooth training. We leverage the pre-trained model before fine-tuning
it to generate these class images. Typically, 200 - 300 class images are enough.
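
To make these terms concrete, here is a tiny sketch (using the exact values we assume in
the rest of this example) of how the unique identifier and unique class combine into the
two prompts:

```py
unique_id = "sks"
unique_class = "dog"

instance_prompt = f"a photo of {unique_id} {unique_class}"  # -> "a photo of sks dog"
class_prompt = f"a photo of {unique_class}"  # -> "a photo of dog"
```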

In code, generating these class images looks like this:

```py
from tqdm import tqdm
import numpy as np
import hashlib
import keras_cv
import PIL
import os

class_images_dir = "class-images"
os.makedirs(class_images_dir, exist_ok=True)

model = keras_cv.models.StableDiffusion(img_width=512, img_height=512, jit_compile=True)

class_prompt = "a photo of dog"
num_imgs_to_generate = 200
for i in tqdm(range(num_imgs_to_generate)):
    images = model.text_to_image(
        class_prompt,
        batch_size=3,
    )
    idx = np.random.choice(len(images))
    selected_image = PIL.Image.fromarray(images[idx])
    hash_image = hashlib.sha1(selected_image.tobytes()).hexdigest()
    image_filename = os.path.join(class_images_dir, f"{hash_image}.jpg")
    selected_image.save(image_filename)
```

To keep the runtime of this example short, we have gone ahead
and generated some class images using
[this notebook](https://colab.research.google.com/gist/sayakpaul/6b5de345d29cf5860f84b6d04d958692/generate_class_priors.ipynb).

**Note** that prior preservation is an optional technique used in DreamBooth, but it
almost always helps in improving the quality of the generated images.
"""

instance_images_root = tf.keras.utils.get_file(
    origin="https://huggingface.co/datasets/sayakpaul/sample-datasets/resolve/main/instance-images.tar.gz",
    untar=True,
)
class_images_root = tf.keras.utils.get_file(
    origin="https://huggingface.co/datasets/sayakpaul/sample-datasets/resolve/main/class-images.tar.gz",
    untar=True,
)

"""
## Visualize images

First, let's load the image paths.
"""

instance_image_paths = list(paths.list_images(instance_images_root))
class_image_paths = list(paths.list_images(class_images_root))

"""
Then we load the images from the paths.
"""


def load_images(image_paths):
    images = [np.array(keras.utils.load_img(path)) for path in image_paths]
    return images
"""
166
And then we make use a utility function to plot the loaded images.
167
"""
168
169
170
def plot_images(images, title=None):
171
plt.figure(figsize=(20, 20))
172
for i in range(len(images)):
173
ax = plt.subplot(1, len(images), i + 1)
174
if title is not None:
175
plt.title(title)
176
plt.imshow(images[i])
177
plt.axis("off")
178
179
180
"""
181
**Instance images**:
182
"""
183
184
plot_images(load_images(instance_image_paths[:5]))
185
186
"""
187
**Class images**:
188
"""
189
190
plot_images(load_images(class_image_paths[:5]))
191

"""
## Prepare datasets

Dataset preparation includes two stages: (1) preparing the captions and (2) processing
the images.
"""

"""
### Prepare the captions
"""

# Since we're using prior preservation, we need to match the number of instance images
# to the number of class images. We just repeat the instance image paths to do so.
new_instance_image_paths = []
for index in range(len(class_image_paths)):
    instance_image = instance_image_paths[index % len(instance_image_paths)]
    new_instance_image_paths.append(instance_image)

# We just repeat the prompts / captions per image.
unique_id = "sks"
class_label = "dog"

instance_prompt = f"a photo of {unique_id} {class_label}"
instance_prompts = [instance_prompt] * len(new_instance_image_paths)

class_prompt = f"a photo of {class_label}"
class_prompts = [class_prompt] * len(class_image_paths)
"""
222
Next, we embed the prompts to save some compute.
223
"""
224
225
import itertools
226
227
# The padding token and maximum prompt length are specific to the text encoder.
228
# If you're using a different text encoder be sure to change them accordingly.
229
padding_token = 49407
230
max_prompt_length = 77
231
232
# Load the tokenizer.
233
tokenizer = keras_cv.models.stable_diffusion.SimpleTokenizer()
234
235
236
# Method to tokenize and pad the tokens.
237
def process_text(caption):
238
tokens = tokenizer.encode(caption)
239
tokens = tokens + [padding_token] * (max_prompt_length - len(tokens))
240
return np.array(tokens)
241
242
243
# Collate the tokenized captions into an array.
244
tokenized_texts = np.empty(
245
(len(instance_prompts) + len(class_prompts), max_prompt_length)
246
)
247
248
for i, caption in enumerate(itertools.chain(instance_prompts, class_prompts)):
249
tokenized_texts[i] = process_text(caption)
250
251
252
# We also pre-compute the text embeddings to save some memory during training.
253
POS_IDS = tf.convert_to_tensor([list(range(max_prompt_length))], dtype=tf.int32)
254
text_encoder = keras_cv.models.stable_diffusion.TextEncoder(max_prompt_length)
255
256
gpus = tf.config.list_logical_devices("GPU")
257
258
# Ensure the computation takes place on a GPU.
259
# Note that it's done automatically when there's a GPU present.
260
# This example just attempts at showing how you can do it
261
# more explicitly.
262
with tf.device(gpus[0].name):
263
embedded_text = text_encoder(
264
[tf.convert_to_tensor(tokenized_texts), POS_IDS], training=False
265
).numpy()
266
267
# To ensure text_encoder doesn't occupy any GPU space.
268
del text_encoder
269
270
"""
271
## Prepare the images
272
"""
273
274
resolution = 512
275
auto = tf.data.AUTOTUNE
276
277
augmenter = keras.Sequential(
278
layers=[
279
keras_cv.layers.CenterCrop(resolution, resolution),
280
keras_cv.layers.RandomFlip(),
281
keras.layers.Rescaling(scale=1.0 / 127.5, offset=-1),
282
]
283
)
284
285
286
def process_image(image_path, tokenized_text):
287
image = tf.io.read_file(image_path)
288
image = tf.io.decode_png(image, 3)
289
image = tf.image.resize(image, (resolution, resolution))
290
return image, tokenized_text
291
292
293
def apply_augmentation(image_batch, embedded_tokens):
294
return augmenter(image_batch), embedded_tokens
295
296
297
def prepare_dict(instance_only=True):
298
def fn(image_batch, embedded_tokens):
299
if instance_only:
300
batch_dict = {
301
"instance_images": image_batch,
302
"instance_embedded_texts": embedded_tokens,
303
}
304
return batch_dict
305
else:
306
batch_dict = {
307
"class_images": image_batch,
308
"class_embedded_texts": embedded_tokens,
309
}
310
return batch_dict
311
312
return fn
313
314
315
def assemble_dataset(image_paths, embedded_texts, instance_only=True, batch_size=1):
316
dataset = tf.data.Dataset.from_tensor_slices((image_paths, embedded_texts))
317
dataset = dataset.map(process_image, num_parallel_calls=auto)
318
dataset = dataset.shuffle(5, reshuffle_each_iteration=True)
319
dataset = dataset.batch(batch_size)
320
dataset = dataset.map(apply_augmentation, num_parallel_calls=auto)
321
322
prepare_dict_fn = prepare_dict(instance_only=instance_only)
323
dataset = dataset.map(prepare_dict_fn, num_parallel_calls=auto)
324
return dataset
325
326


"""
## Assemble dataset
"""

instance_dataset = assemble_dataset(
    new_instance_image_paths,
    embedded_text[: len(new_instance_image_paths)],
)
class_dataset = assemble_dataset(
    class_image_paths,
    embedded_text[len(new_instance_image_paths) :],
    instance_only=False,
)
train_dataset = tf.data.Dataset.zip((instance_dataset, class_dataset))

"""
## Check shapes

Now that the dataset has been prepared, let's quickly check what's inside it.
"""

sample_batch = next(iter(train_dataset))
print(sample_batch[0].keys(), sample_batch[1].keys())

for k in sample_batch[0]:
    print(k, sample_batch[0][k].shape)

for k in sample_batch[1]:
    print(k, sample_batch[1][k].shape)

"""
During training, we make use of these keys to gather the images and text embeddings and
concatenate them accordingly.
"""

"""
## DreamBooth training loop

Our DreamBooth training loop is very much inspired by
[this script](https://github.com/huggingface/diffusers/blob/main/examples/dreambooth/train_dreambooth.py)
provided by the Diffusers team at Hugging Face. However, there is an important
difference to note. We only fine-tune the UNet (the model responsible for predicting
noise) and don't fine-tune the text encoder in this example. If you're looking for an
implementation that also performs the additional fine-tuning of the text encoder, refer
to [this repository](https://github.com/sayakpaul/dreambooth-keras/).
"""

import tensorflow.experimental.numpy as tnp


class DreamBoothTrainer(tf.keras.Model):
    # Reference:
    # https://github.com/huggingface/diffusers/blob/main/examples/dreambooth/train_dreambooth.py

    def __init__(
        self,
        diffusion_model,
        vae,
        noise_scheduler,
        use_mixed_precision=False,
        prior_loss_weight=1.0,
        max_grad_norm=1.0,
        **kwargs,
    ):
        super().__init__(**kwargs)

        self.diffusion_model = diffusion_model
        self.vae = vae
        self.noise_scheduler = noise_scheduler
        self.prior_loss_weight = prior_loss_weight
        self.max_grad_norm = max_grad_norm

        self.use_mixed_precision = use_mixed_precision
        self.vae.trainable = False

    def train_step(self, inputs):
        instance_batch = inputs[0]
        class_batch = inputs[1]

        instance_images = instance_batch["instance_images"]
        instance_embedded_text = instance_batch["instance_embedded_texts"]
        class_images = class_batch["class_images"]
        class_embedded_text = class_batch["class_embedded_texts"]

        images = tf.concat([instance_images, class_images], 0)
        embedded_texts = tf.concat([instance_embedded_text, class_embedded_text], 0)
        batch_size = tf.shape(images)[0]

        with tf.GradientTape() as tape:
            # Project image into the latent space and sample from it.
            latents = self.sample_from_encoder_outputs(self.vae(images, training=False))
            # Know more about the magic number here:
            # https://keras.io/examples/generative/fine_tune_via_textual_inversion/
            latents = latents * 0.18215

            # Sample noise that we'll add to the latents.
            noise = tf.random.normal(tf.shape(latents))

            # Sample a random timestep for each image.
            timesteps = tnp.random.randint(
                0, self.noise_scheduler.train_timesteps, (batch_size,)
            )

            # Add noise to the latents according to the noise magnitude at each timestep
            # (this is the forward diffusion process).
            noisy_latents = self.noise_scheduler.add_noise(
                tf.cast(latents, noise.dtype), noise, timesteps
            )

            # Get the target for loss depending on the prediction type;
            # just the sampled noise for now.
            target = noise  # noise_schedule.predict_epsilon == True

            # Predict the noise residual and compute loss.
            timestep_embedding = tf.map_fn(
                lambda t: self.get_timestep_embedding(t), timesteps, dtype=tf.float32
            )
            model_pred = self.diffusion_model(
                [noisy_latents, timestep_embedding, embedded_texts], training=True
            )
            loss = self.compute_loss(target, model_pred)
            if self.use_mixed_precision:
                loss = self.optimizer.get_scaled_loss(loss)

        # Update parameters of the diffusion model.
        trainable_vars = self.diffusion_model.trainable_variables
        gradients = tape.gradient(loss, trainable_vars)
        if self.use_mixed_precision:
            gradients = self.optimizer.get_unscaled_gradients(gradients)
        gradients = [tf.clip_by_norm(g, self.max_grad_norm) for g in gradients]
        self.optimizer.apply_gradients(zip(gradients, trainable_vars))

        return {m.name: m.result() for m in self.metrics}

    def get_timestep_embedding(self, timestep, dim=320, max_period=10000):
        half = dim // 2
        log_max_period = tf.math.log(tf.cast(max_period, tf.float32))
        freqs = tf.math.exp(
            -log_max_period * tf.range(0, half, dtype=tf.float32) / half
        )
        args = tf.convert_to_tensor([timestep], dtype=tf.float32) * freqs
        embedding = tf.concat([tf.math.cos(args), tf.math.sin(args)], 0)
        return embedding

    def sample_from_encoder_outputs(self, outputs):
        mean, logvar = tf.split(outputs, 2, axis=-1)
        logvar = tf.clip_by_value(logvar, -30.0, 20.0)
        std = tf.exp(0.5 * logvar)
        sample = tf.random.normal(tf.shape(mean), dtype=mean.dtype)
        return mean + std * sample

    def compute_loss(self, target, model_pred):
        # Chunk the noise and model_pred into two parts and compute the loss
        # on each part separately.
        # Since the first half of the inputs has instance samples and the second half
        # has class samples, we do the chunking accordingly.
        model_pred, model_pred_prior = tf.split(
            model_pred, num_or_size_splits=2, axis=0
        )
        target, target_prior = tf.split(target, num_or_size_splits=2, axis=0)

        # Compute instance loss.
        loss = self.compiled_loss(target, model_pred)

        # Compute prior loss.
        prior_loss = self.compiled_loss(target_prior, model_pred_prior)

        # Add the prior loss to the instance loss.
        loss = loss + self.prior_loss_weight * prior_loss
        return loss

    def save_weights(self, filepath, overwrite=True, save_format=None, options=None):
        # Overriding this method will allow us to use the `ModelCheckpoint`
        # callback directly with this trainer class. In this case, it will
        # only checkpoint the `diffusion_model` since that's what we're training
        # during fine-tuning.
        self.diffusion_model.save_weights(
            filepath=filepath,
            overwrite=overwrite,
            save_format=save_format,
            options=options,
        )

    def load_weights(self, filepath, by_name=False, skip_mismatch=False, options=None):
        # Similarly override `load_weights()` so that we can directly call it on
        # the trainer class object.
        self.diffusion_model.load_weights(
            filepath=filepath,
            by_name=by_name,
            skip_mismatch=skip_mismatch,
            options=options,
        )


"""
## Trainer initialization
"""

# Comment this out if you are not using a GPU with tensor cores.
tf.keras.mixed_precision.set_global_policy("mixed_float16")

use_mp = True  # Set it to False if you're not using a GPU with tensor cores.

image_encoder = keras_cv.models.stable_diffusion.ImageEncoder()
dreambooth_trainer = DreamBoothTrainer(
    diffusion_model=keras_cv.models.stable_diffusion.DiffusionModel(
        resolution, resolution, max_prompt_length
    ),
    # Remove the top layer from the encoder, which cuts off the variance and only
    # returns the mean.
    vae=tf.keras.Model(
        image_encoder.input,
        image_encoder.layers[-2].output,
    ),
    noise_scheduler=keras_cv.models.stable_diffusion.NoiseScheduler(),
    use_mixed_precision=use_mp,
)

# These hyperparameters come from this tutorial by Hugging Face:
# https://github.com/huggingface/diffusers/tree/main/examples/dreambooth
learning_rate = 5e-6
beta_1, beta_2 = 0.9, 0.999
weight_decay = 1e-2
epsilon = 1e-08

optimizer = tf.keras.optimizers.experimental.AdamW(
    learning_rate=learning_rate,
    weight_decay=weight_decay,
    beta_1=beta_1,
    beta_2=beta_2,
    epsilon=epsilon,
)
dreambooth_trainer.compile(optimizer=optimizer, loss="mse")

"""
## Train!

We first calculate the number of epochs we need to train for.
"""

num_update_steps_per_epoch = train_dataset.cardinality()
max_train_steps = 800
epochs = math.ceil(max_train_steps / num_update_steps_per_epoch)
print(f"Training for {epochs} epochs.")

"""
And then we start training!
"""

ckpt_path = "dreambooth-unet.h5"
ckpt_callback = tf.keras.callbacks.ModelCheckpoint(
    ckpt_path,
    save_weights_only=True,
    monitor="loss",
    mode="min",
)
dreambooth_trainer.fit(train_dataset, epochs=epochs, callbacks=[ckpt_callback])

"""
## Experiments and inference

We ran various experiments with a slightly modified version of this example. Our
experiments are based on
[this repository](https://github.com/sayakpaul/dreambooth-keras/) and are inspired by
[this blog post](https://huggingface.co/blog/dreambooth) from Hugging Face.

First, let's see how we can use the fine-tuned checkpoint for running inference.
"""

# Initialize a new Stable Diffusion model.
dreambooth_model = keras_cv.models.StableDiffusion(
    img_width=resolution, img_height=resolution, jit_compile=True
)
dreambooth_model.diffusion_model.load_weights(ckpt_path)

# Note how the unique identifier and the class have been used in the prompt.
prompt = f"A photo of {unique_id} {class_label} in a bucket"
num_imgs_to_gen = 3

images_dreamboothed = dreambooth_model.text_to_image(prompt, batch_size=num_imgs_to_gen)
plot_images(images_dreamboothed, prompt)

"""
Now, let's load checkpoints from a different experiment we conducted where we also
fine-tuned the text encoder along with the UNet:
"""

unet_weights = tf.keras.utils.get_file(
    origin="https://huggingface.co/chansung/dreambooth-dog/resolve/main/lr%409e-06-max_train_steps%40200-train_text_encoder%40True-unet.h5"
)
text_encoder_weights = tf.keras.utils.get_file(
    origin="https://huggingface.co/chansung/dreambooth-dog/resolve/main/lr%409e-06-max_train_steps%40200-train_text_encoder%40True-text_encoder.h5"
)

dreambooth_model.diffusion_model.load_weights(unet_weights)
dreambooth_model.text_encoder.load_weights(text_encoder_weights)

images_dreamboothed = dreambooth_model.text_to_image(prompt, batch_size=num_imgs_to_gen)
plot_images(images_dreamboothed, prompt)

"""
The default number of steps for generating an image in `text_to_image()`
[is 50](https://github.com/keras-team/keras-cv/blob/3575bc3b944564fe15b46b917e6555aa6a9d7be0/keras_cv/models/stable_diffusion/stable_diffusion.py#L73).
Let's increase it to 100.
"""

images_dreamboothed = dreambooth_model.text_to_image(
    prompt, batch_size=num_imgs_to_gen, num_steps=100
)
plot_images(images_dreamboothed, prompt)

"""
Feel free to experiment with different prompts (don't forget to add the unique identifier
and the class label!) to see how the results change. We welcome you to check out our
codebase and more experimental results
[here](https://github.com/sayakpaul/dreambooth-keras#results). You can also read
[this blog post](https://huggingface.co/blog/dreambooth) to get more ideas.
"""

"""
## Acknowledgements

* Thanks to the
[DreamBooth example script](https://github.com/huggingface/diffusers/blob/main/examples/dreambooth/train_dreambooth.py)
provided by Hugging Face, which helped us a lot in getting the initial implementation
ready quickly.
* Getting DreamBooth to work on human faces can be challenging. We have compiled some
general recommendations
[here](https://github.com/sayakpaul/dreambooth-keras#notes-on-preparing-data-for-dreambooth-training-of-faces).
Thanks to
[Abhishek Thakur](https://no.linkedin.com/in/abhi1thakur)
for helping with these.
"""