"""
Title: Fine-tuning Stable Diffusion
Author: [Sayak Paul](https://twitter.com/RisingSayak), [Chansung Park](https://twitter.com/algo_diver)
Date created: 2022/12/28
Last modified: 2023/01/13
Description: Fine-tuning Stable Diffusion using a custom image-caption dataset.
Accelerator: GPU
"""

"""
## Introduction

This tutorial shows how to fine-tune a
[Stable Diffusion model](https://keras.io/guides/keras_cv/generate_images_with_stable_diffusion/)
on a custom dataset of `{image, caption}` pairs. We build on top of the fine-tuning
script provided by Hugging Face
[here](https://github.com/huggingface/diffusers/blob/main/examples/text_to_image/train_text_to_image.py).

We assume that you have a high-level understanding of the Stable Diffusion model.
The following resources can be helpful if you're looking for more information in that regard:

* [High-performance image generation using Stable Diffusion in KerasCV](https://keras.io/guides/keras_cv/generate_images_with_stable_diffusion/)
* [Stable Diffusion with Diffusers](https://huggingface.co/blog/stable_diffusion)

It's highly recommended that you use a GPU with at least 30GB of memory to execute
the code.

By the end of the guide, you'll be able to generate images of interesting Pokémon:

![custom-pokemons](https://i.imgur.com/X4m614M.png)

The tutorial relies on KerasCV 0.6.0. Additionally, we need
at least TensorFlow 2.11 in order to use AdamW with mixed precision.
"""

"""shell
pip install keras-cv==0.6.0 -q
pip install -U tensorflow -q
pip install keras-core -q
"""

"""
## What are we fine-tuning?

A Stable Diffusion model can be decomposed into several key models:

* A text encoder that projects the input prompt to a latent space. (The caption
associated with an image is referred to as the "prompt".)
* A variational autoencoder (VAE) that projects an input image to a latent space acting
as an image vector space.
* A diffusion model that refines a latent vector and produces another latent vector, conditioned
on the encoded text prompt.
* A decoder that generates images given a latent vector from the diffusion model.

It's worth noting that during the process of generating an image from a text prompt, the
image encoder is not typically employed.

However, during the process of fine-tuning, the workflow goes like the following:

1. An input text prompt is projected to a latent space by the text encoder.
2. An input image is projected to a latent space by the image encoder portion of the VAE.
3. A small amount of noise is added to the image latent vector for a given timestep.
4. The diffusion model uses latent vectors from these two spaces along with a timestep embedding
to predict the noise that was added to the image latent.
5. A reconstruction loss is calculated between the predicted noise and the original noise
added in step 3.
6. Finally, the diffusion model parameters are optimized w.r.t. this loss using
gradient descent.

Note that only the diffusion model parameters are updated during fine-tuning, while the
(pre-trained) text and image encoders are kept frozen.

Don't worry if this sounds complicated. The code is much simpler than this! A schematic
sketch of a single training step is shown below.
"""
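
"""
To make the workflow above concrete, here is a toy NumPy illustration of steps 3-5, using
random arrays as stand-ins for the image latents and for the diffusion model's prediction
(the signal rate `alpha` is a stand-in for what the noise scheduler computes per timestep;
the real implementation is the `Trainer` class defined later in this tutorial):

```python
import numpy as np

rng = np.random.default_rng(0)
latents = rng.normal(size=(1, 32, 32, 4))    # stand-in for the VAE image latents (step 2)
noise = rng.normal(size=latents.shape)       # noise to be added (step 3)
alpha = 0.7                                  # stand-in signal rate for one timestep
noisy_latents = alpha * latents + (1.0 - alpha**2) ** 0.5 * noise

noise_pred = rng.normal(size=latents.shape)  # stand-in for the diffusion model output (step 4)
loss = np.mean((noise - noise_pred) ** 2)    # reconstruction loss (step 5)
print(loss)
```
"""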
"""
77
## Imports
78
"""
79
80
from textwrap import wrap
81
import os
82
83
import keras_cv
84
import matplotlib.pyplot as plt
85
import numpy as np
86
import pandas as pd
87
import tensorflow as tf
88
import tensorflow.experimental.numpy as tnp
89
from keras_cv.models.stable_diffusion.clip_tokenizer import SimpleTokenizer
90
from keras_cv.models.stable_diffusion.diffusion_model import DiffusionModel
91
from keras_cv.models.stable_diffusion.image_encoder import ImageEncoder
92
from keras_cv.models.stable_diffusion.noise_scheduler import NoiseScheduler
93
from keras_cv.models.stable_diffusion.text_encoder import TextEncoder
94
from tensorflow import keras
95
96
"""
97
## Data loading
98
99
We use the dataset
100
[Pokémon BLIP captions](https://huggingface.co/datasets/lambdalabs/pokemon-blip-captions).
101
However, we'll use a slightly different version which was derived from the original
102
dataset to fit better with `tf.data`. Refer to
103
[the documentation](https://huggingface.co/datasets/sayakpaul/pokemon-blip-original-version)
104
for more details.
105
"""
106
107

data_path = tf.keras.utils.get_file(
    origin="https://huggingface.co/datasets/sayakpaul/pokemon-blip-original-version/resolve/main/pokemon_dataset.tar.gz",
    untar=True,
)

data_frame = pd.read_csv(os.path.join(data_path, "data.csv"))

data_frame["image_path"] = data_frame["image_path"].apply(
    lambda x: os.path.join(data_path, x)
)
data_frame.head()
"""
120
Since we have only 833 `{image, caption}` pairs, we can precompute the text embeddings from
121
the captions. Moreover, the text encoder will be kept frozen during the course of
122
fine-tuning, so we can save some compute by doing this.
123
124
Before we use the text encoder, we need to tokenize the captions.
125
"""
126
127

# The padding token and maximum prompt length are specific to the text encoder.
# If you're using a different text encoder be sure to change them accordingly.
PADDING_TOKEN = 49407
MAX_PROMPT_LENGTH = 77

# Load the tokenizer.
tokenizer = SimpleTokenizer()


# Method to tokenize and pad the tokens.
def process_text(caption):
    tokens = tokenizer.encode(caption)
    tokens = tokens + [PADDING_TOKEN] * (MAX_PROMPT_LENGTH - len(tokens))
    return np.array(tokens)


# Collate the tokenized captions into an array.
tokenized_texts = np.empty((len(data_frame), MAX_PROMPT_LENGTH))

all_captions = list(data_frame["caption"].values)
for i, caption in enumerate(all_captions):
    tokenized_texts[i] = process_text(caption)
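
"""
As a quick sanity check, we can inspect the first tokenized caption. Every row of
`tokenized_texts` contains exactly `MAX_PROMPT_LENGTH` token ids, with the unused trailing
positions filled by `PADDING_TOKEN`.
"""

# Inspect the first caption and its padded token representation.
first_tokens = tokenized_texts[0].astype(int)
print("Caption:", all_captions[0])
print("Number of tokens:", len(first_tokens))
print("Positions equal to PADDING_TOKEN:", int(np.sum(first_tokens == PADDING_TOKEN)))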
"""
151
## Prepare a `tf.data.Dataset`
152
153
In this section, we'll prepare a `tf.data.Dataset` object from the input image file paths
154
and their corresponding caption tokens. The section will include the following:
155
156
* Pre-computation of the text embeddings from the tokenized captions.
157
* Loading and augmentation of the input images.
158
* Shuffling and batching of the dataset.
159
"""
160
161

RESOLUTION = 256
AUTO = tf.data.AUTOTUNE
POS_IDS = tf.convert_to_tensor([list(range(MAX_PROMPT_LENGTH))], dtype=tf.int32)

augmenter = keras.Sequential(
    layers=[
        keras_cv.layers.CenterCrop(RESOLUTION, RESOLUTION),
        keras_cv.layers.RandomFlip(),
        tf.keras.layers.Rescaling(scale=1.0 / 127.5, offset=-1),
    ]
)
text_encoder = TextEncoder(MAX_PROMPT_LENGTH)


def process_image(image_path, tokenized_text):
    image = tf.io.read_file(image_path)
    image = tf.io.decode_png(image, 3)
    image = tf.image.resize(image, (RESOLUTION, RESOLUTION))
    return image, tokenized_text


def apply_augmentation(image_batch, token_batch):
    return augmenter(image_batch), token_batch


def run_text_encoder(image_batch, token_batch):
    return (
        image_batch,
        token_batch,
        text_encoder([token_batch, POS_IDS], training=False),
    )


def prepare_dict(image_batch, token_batch, encoded_text_batch):
    return {
        "images": image_batch,
        "tokens": token_batch,
        "encoded_text": encoded_text_batch,
    }


def prepare_dataset(image_paths, tokenized_texts, batch_size=1):
    dataset = tf.data.Dataset.from_tensor_slices((image_paths, tokenized_texts))
    dataset = dataset.shuffle(batch_size * 10)
    dataset = dataset.map(process_image, num_parallel_calls=AUTO).batch(batch_size)
    dataset = dataset.map(apply_augmentation, num_parallel_calls=AUTO)
    dataset = dataset.map(run_text_encoder, num_parallel_calls=AUTO)
    dataset = dataset.map(prepare_dict, num_parallel_calls=AUTO)
    return dataset.prefetch(AUTO)
"""
213
The baseline Stable Diffusion model was trained using images with 512x512 resolution. It's
214
unlikely for a model that's trained using higher-resolution images to transfer well to
215
lower-resolution images. However, the current model will lead to OOM if we keep the
216
resolution to 512x512 (without enabling mixed-precision). Therefore, in the interest of
217
interactive demonstrations, we kept the input resolution to 256x256.
218
"""
219
220

# Prepare the dataset.
training_dataset = prepare_dataset(
    np.array(data_frame["image_path"]), tokenized_texts, batch_size=4
)

# Take a sample batch and investigate.
sample_batch = next(iter(training_dataset))

for k in sample_batch:
    print(k, sample_batch[k].shape)
"""
232
We can also take a look at the training images and their corresponding captions.
233
"""
234
235

plt.figure(figsize=(20, 10))

for i in range(3):
    ax = plt.subplot(1, 4, i + 1)
    plt.imshow((sample_batch["images"][i] + 1) / 2)

    text = tokenizer.decode(sample_batch["tokens"][i].numpy().squeeze())
    text = text.replace("<|startoftext|>", "")
    text = text.replace("<|endoftext|>", "")
    text = "\n".join(wrap(text, 12))
    plt.title(text, fontsize=15)

    plt.axis("off")
"""
250
## A trainer class for the fine-tuning loop
251
"""
252
253
254


class Trainer(tf.keras.Model):
    # Reference:
    # https://github.com/huggingface/diffusers/blob/main/examples/text_to_image/train_text_to_image.py

    def __init__(
        self,
        diffusion_model,
        vae,
        noise_scheduler,
        use_mixed_precision=False,
        max_grad_norm=1.0,
        **kwargs
    ):
        super().__init__(**kwargs)

        self.diffusion_model = diffusion_model
        self.vae = vae
        self.noise_scheduler = noise_scheduler
        self.max_grad_norm = max_grad_norm

        self.use_mixed_precision = use_mixed_precision
        self.vae.trainable = False

    def train_step(self, inputs):
        images = inputs["images"]
        encoded_text = inputs["encoded_text"]
        batch_size = tf.shape(images)[0]

        with tf.GradientTape() as tape:
            # Project image into the latent space and sample from it.
            latents = self.sample_from_encoder_outputs(self.vae(images, training=False))
            # Know more about the magic number here:
            # https://keras.io/examples/generative/fine_tune_via_textual_inversion/
            latents = latents * 0.18215

            # Sample noise that we'll add to the latents.
            noise = tf.random.normal(tf.shape(latents))

            # Sample a random timestep for each image.
            timesteps = tnp.random.randint(
                0, self.noise_scheduler.train_timesteps, (batch_size,)
            )

            # Add noise to the latents according to the noise magnitude at each timestep
            # (this is the forward diffusion process).
            noisy_latents = self.noise_scheduler.add_noise(
                tf.cast(latents, noise.dtype), noise, timesteps
            )

            # Get the target for loss depending on the prediction type
            # just the sampled noise for now.
            target = noise  # noise_schedule.predict_epsilon == True

            # Predict the noise residual and compute loss.
            timestep_embedding = tf.map_fn(
                lambda t: self.get_timestep_embedding(t), timesteps, dtype=tf.float32
            )
            timestep_embedding = tf.squeeze(timestep_embedding, 1)
            model_pred = self.diffusion_model(
                [noisy_latents, timestep_embedding, encoded_text], training=True
            )
            loss = self.compiled_loss(target, model_pred)
            if self.use_mixed_precision:
                loss = self.optimizer.get_scaled_loss(loss)

        # Update parameters of the diffusion model.
        trainable_vars = self.diffusion_model.trainable_variables
        gradients = tape.gradient(loss, trainable_vars)
        if self.use_mixed_precision:
            gradients = self.optimizer.get_unscaled_gradients(gradients)
        gradients = [tf.clip_by_norm(g, self.max_grad_norm) for g in gradients]
        self.optimizer.apply_gradients(zip(gradients, trainable_vars))

        return {m.name: m.result() for m in self.metrics}

    def get_timestep_embedding(self, timestep, dim=320, max_period=10000):
        half = dim // 2
        log_max_period = tf.math.log(tf.cast(max_period, tf.float32))
        freqs = tf.math.exp(
            -log_max_period * tf.range(0, half, dtype=tf.float32) / half
        )
        args = tf.convert_to_tensor([timestep], dtype=tf.float32) * freqs
        embedding = tf.concat([tf.math.cos(args), tf.math.sin(args)], 0)
        embedding = tf.reshape(embedding, [1, -1])
        return embedding

    def sample_from_encoder_outputs(self, outputs):
        mean, logvar = tf.split(outputs, 2, axis=-1)
        logvar = tf.clip_by_value(logvar, -30.0, 20.0)
        std = tf.exp(0.5 * logvar)
        sample = tf.random.normal(tf.shape(mean), dtype=mean.dtype)
        return mean + std * sample

    def save_weights(self, filepath, overwrite=True, save_format=None, options=None):
        # Overriding this method will allow us to use the `ModelCheckpoint`
        # callback directly with this trainer class. In this case, it will
        # only checkpoint the `diffusion_model` since that's what we're training
        # during fine-tuning.
        self.diffusion_model.save_weights(
            filepath=filepath,
            overwrite=overwrite,
            save_format=save_format,
            options=options,
        )
"""
361
One important implementation detail to note here: Instead of directly taking
362
the latent vector produced by the image encoder (which is a VAE), we sample from the
363
mean and log-variance predicted by it. This way, we can achieve better sample
364
quality and diversity.
365
366
It's common to add support for mixed-precision training along with exponential
367
moving averaging of model weights for fine-tuning these models. However, in the interest
368
of brevity, we discard those elements. More on this later in the tutorial.
369
"""
370
371
"""
372
## Initialize the trainer and compile it
373
"""
374
375

# Enable mixed-precision training if the underlying GPU has tensor cores.
USE_MP = True
if USE_MP:
    keras.mixed_precision.set_global_policy("mixed_float16")

image_encoder = ImageEncoder()
diffusion_ft_trainer = Trainer(
    diffusion_model=DiffusionModel(RESOLUTION, RESOLUTION, MAX_PROMPT_LENGTH),
    # Remove the top layer from the encoder, which cuts off the variance and only
    # returns the mean.
    vae=tf.keras.Model(
        image_encoder.input,
        image_encoder.layers[-2].output,
    ),
    noise_scheduler=NoiseScheduler(),
    use_mixed_precision=USE_MP,
)

# These hyperparameters come from this tutorial by Hugging Face:
# https://huggingface.co/docs/diffusers/training/text2image
lr = 1e-5
beta_1, beta_2 = 0.9, 0.999
weight_decay = 1e-2
epsilon = 1e-08

optimizer = tf.keras.optimizers.experimental.AdamW(
    learning_rate=lr,
    weight_decay=weight_decay,
    beta_1=beta_1,
    beta_2=beta_2,
    epsilon=epsilon,
)
diffusion_ft_trainer.compile(optimizer=optimizer, loss="mse")
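
"""
If you also want the exponential moving averaging of weights mentioned earlier, one
lightweight option is to enable it on the optimizer itself. This is only a sketch, assuming
TensorFlow 2.11+ where the experimental Keras optimizers expose the `use_ema` and
`ema_momentum` arguments; it is not used in the rest of this tutorial:

```python
# Hypothetical EMA-enabled variant of the optimizer above.
ema_optimizer = tf.keras.optimizers.experimental.AdamW(
    learning_rate=lr,
    weight_decay=weight_decay,
    beta_1=beta_1,
    beta_2=beta_2,
    epsilon=epsilon,
    use_ema=True,  # maintain an exponential moving average of the weights
    ema_momentum=0.9999,
)
```

The [repository accompanying this example](https://github.com/sayakpaul/stable-diffusion-keras-ft)
provides explicit support for EMA of the fine-tuned parameters.
"""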
"""
410
## Fine-tuning
411
412
To keep the runtime of this tutorial short, we just fine-tune for an epoch.
413
"""
414
415

epochs = 1
ckpt_path = "finetuned_stable_diffusion.h5"
ckpt_callback = tf.keras.callbacks.ModelCheckpoint(
    ckpt_path,
    save_weights_only=True,
    monitor="loss",
    mode="min",
)
diffusion_ft_trainer.fit(training_dataset, epochs=epochs, callbacks=[ckpt_callback])
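
"""
Thanks to the `save_weights()` override in the `Trainer` class, the checkpoint written above
contains only the diffusion model weights. If you'd like to try them out immediately (rather
than the longer-trained checkpoint used in the next section), a minimal sketch looks like
this, mirroring what the Inference section below does:

```python
# Sketch: load the one-epoch checkpoint produced above into a fresh pipeline.
quick_test_model = keras_cv.models.StableDiffusion(img_width=256, img_height=256)
quick_test_model.diffusion_model.load_weights(ckpt_path)
```
"""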
"""
426
## Inference
427
428
We fine-tuned the model for 60 epochs on an image resolution of 512x512. To allow
429
training with this resolution, we incorporated mixed-precision support. You can
430
check out
431
[this repository](https://github.com/sayakpaul/stabe-diffusion-keras-ft)
432
for more details. It additionally provides support for exponential moving averaging of
433
the fine-tuned model parameters and model checkpointing.
434
435
436
For this section, we'll use the checkpoint derived after 60 epochs of fine-tuning.
437
"""
438
439

weights_path = tf.keras.utils.get_file(
    origin="https://huggingface.co/sayakpaul/kerascv_sd_pokemon_finetuned/resolve/main/ckpt_epochs_72_res_512_mp_True.h5"
)

img_height = img_width = 512
pokemon_model = keras_cv.models.StableDiffusion(
    img_width=img_width, img_height=img_height
)
# We just reload the weights of the fine-tuned diffusion model.
pokemon_model.diffusion_model.load_weights(weights_path)
"""
451
Now, we can take this model for a test-drive.
452
"""
453
454

prompts = ["Yoda", "Hello Kitty", "A pokemon with red eyes"]
images_to_generate = 3
outputs = {}

for prompt in prompts:
    generated_images = pokemon_model.text_to_image(
        prompt, batch_size=images_to_generate, unconditional_guidance_scale=40
    )
    outputs.update({prompt: generated_images})
"""
465
With 60 epochs of fine-tuning (a good number is about 70), the generated images were not
466
up to the mark. So, we experimented with the number of steps Stable Diffusion takes
467
during the inference time and the `unconditional_guidance_scale` parameter.
468
469
We found the best results with this checkpoint with `unconditional_guidance_scale` set to
470
40.
471
"""
472
473
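
"""
If you want to explore these two knobs yourself, both are exposed by `text_to_image()`. A
minimal sketch (the values below are arbitrary starting points, not tuned recommendations):

```python
# Hypothetical experiment: more sampling steps and a milder guidance scale.
experimental_images = pokemon_model.text_to_image(
    "A pokemon with red eyes",
    batch_size=images_to_generate,
    num_steps=75,
    unconditional_guidance_scale=20,
)
```
"""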


def plot_images(images, title):
    plt.figure(figsize=(20, 20))
    for i in range(len(images)):
        ax = plt.subplot(1, len(images), i + 1)
        plt.imshow(images[i])
        plt.title(title, fontsize=12)
        plt.axis("off")


for prompt in outputs:
    plot_images(outputs[prompt], prompt)
"""
487
We can notice that the model has started adapting to the style of our dataset. You can
488
check the
489
[accompanying repository](https://github.com/sayakpaul/stable-diffusion-keras-ft#results)
490
for more comparisons and commentary. If you're feeling adventurous to try out a demo,
491
you can check out
492
[this resource](https://huggingface.co/spaces/sayakpaul/pokemon-sd-kerascv).
493
"""
494
495
"""
496
## Conclusion and acknowledgements
497
498
We demonstrated how to fine-tune the Stable Diffusion model on a custom dataset. While
499
the results are far from aesthetically pleasing, we believe with more epochs of
500
fine-tuning, they will likely improve. To enable that, having support for gradient
501
accumulation and distributed training is crucial. This can be thought of as the next step
502
in this tutorial.
503
504
There is another interesting way in which Stable Diffusion models can be fine-tuned,
505
called textual inversion. You can refer to
506
[this tutorial](https://keras.io/examples/generative/fine_tune_via_textual_inversion/)
507
to know more about it.
508
509
We'd like to acknowledge the GCP Credit support from ML Developer Programs' team at
510
Google. We'd like to thank the Hugging Face team for providing the
511
[fine-tuning script](https://github.com/huggingface/diffusers/blob/main/examples/text_to_image/train_text_to_image.py)
512
. It's very readable and easy to understand.
513
"""
514
515