Path: blob/master/examples/generative/ipynb/finetune_stable_diffusion.ipynb
Fine-tuning Stable Diffusion
Authors: Sayak Paul, Chansung Park
Date created: 2022/12/28
Last modified: 2023/01/13
Description: Fine-tuning Stable Diffusion using a custom image-caption dataset.
Introduction
This tutorial shows how to fine-tune a Stable Diffusion model on a custom dataset of {image, caption}
pairs. We build on top of the fine-tuning script provided by Hugging Face here.
We assume that you have a high-level understanding of the Stable Diffusion model; resources such as KerasCV's guide on high-performance image generation with Stable Diffusion can be helpful if you're looking for more background in that regard.
It's highly recommended that you use a GPU with at least 30GB of memory to execute the code.
By the end of the guide, you'll be able to generate images of interesting Pokémon.
The tutorial relies on KerasCV 0.4.0. Additionally, we need at least TensorFlow 2.11 in order to use AdamW with mixed precision.
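If you're running this as a notebook, an install cell along these lines (matching the versions above) should set things up:

```python
# Install the pinned dependencies noted above (notebook-style; run once).
!pip install -q keras-cv==0.4.0 "tensorflow>=2.11.0"
```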
What are we fine-tuning?
A Stable Diffusion model can be decomposed into several key models:
A text encoder that projects the input prompt to a latent space. (The caption associated with an image is referred to as the "prompt".)
A variational autoencoder (VAE) that projects an input image to a latent space acting as an image vector space.
A diffusion model that refines a latent vector and produces another latent vector, conditioned on the encoded text prompt.
A decoder that generates images given a latent vector from the diffusion model.
It's worth noting that during the process of generating an image from a text prompt, the image encoder is not typically employed.
However, during the process of fine-tuning, the workflow goes like the following:
An input text prompt is projected to a latent space by the text encoder.
An input image is projected to a latent space by the image encoder portion of the VAE.
A small amount of noise is added to the image latent vector for a given timestep.
The diffusion model uses latent vectors from these two spaces along with a timestep embedding to predict the noise that was added to the image latent.
A reconstruction loss is calculated between the predicted noise and the original noise added in step 3.
Finally, the diffusion model parameters are optimized with respect to this loss using gradient descent.
Note that only the diffusion model parameters are updated during fine-tuning, while the (pre-trained) text and the image encoders are kept frozen.
Don't worry if this sounds complicated. The code is much simpler than this!
Imports
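For reference, a minimal, representative set of the top-level imports used throughout this example looks like the following; additional KerasCV Stable Diffusion building blocks are pulled in where they are needed:

```python
import os

import keras_cv
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import tensorflow as tf
from tensorflow import keras
```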
Data loading
We use the Pokémon BLIP captions dataset. However, we'll use a slightly different version that was derived from the original dataset to fit better with tf.data. Refer to the documentation for more details.
Since we have only 833 {image, caption} pairs, we can precompute the text embeddings from the captions. Moreover, since the text encoder will be kept frozen during fine-tuning, we can save some compute by doing this.
Before we use the text encoder, we need to tokenize the captions.
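As a rough sketch, tokenization with KerasCV's bundled CLIP tokenizer could look like the following; the SimpleTokenizer import path and the padding-token value are assumptions tied to KerasCV 0.4.0 and may differ in other versions:

```python
# Sketch: tokenize captions and pad them to Stable Diffusion's fixed prompt length.
from keras_cv.models.stable_diffusion.clip_tokenizer import SimpleTokenizer

MAX_PROMPT_LENGTH = 77  # Stable Diffusion's text encoder expects 77 tokens.
PADDING_TOKEN = 49407   # CLIP's end-of-text token, reused here for padding (assumption).

tokenizer = SimpleTokenizer()

def process_text(caption):
    tokens = tokenizer.encode(caption)
    # Pad the token sequence up to the fixed prompt length.
    tokens = tokens + [PADDING_TOKEN] * (MAX_PROMPT_LENGTH - len(tokens))
    return np.array(tokens)
```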
Prepare a tf.data.Dataset
In this section, we'll prepare a tf.data.Dataset object from the input image file paths and their corresponding caption tokens. The section will include the following:
Pre-computation of the text embeddings from the tokenized captions.
Loading and augmentation of the input images.
Shuffling and batching of the dataset.
The baseline Stable Diffusion model was trained using images with 512x512 resolution. It's unlikely for a model that's trained using higher-resolution images to transfer well to lower-resolution images. However, the current setup will run out of memory (OOM) if we keep the resolution at 512x512 (without enabling mixed precision). Therefore, in the interest of interactive demonstrations, we keep the input resolution at 256x256.
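A condensed sketch of this pipeline is shown below. Names like image_paths and tokenized_texts are hypothetical stand-ins for the prepared dataset columns, the augmentation is simplified to standard tf.keras preprocessing layers, and the step that runs the tokenized captions through the frozen text encoder to precompute embeddings is only noted in a comment:

```python
RESOLUTION = 256
AUTO = tf.data.AUTOTUNE
batch_size = 4

# Simplified augmentation: random horizontal flips plus rescaling to [-1, 1].
augmenter = keras.Sequential([
    keras.layers.RandomFlip("horizontal"),
    keras.layers.Rescaling(scale=1.0 / 127.5, offset=-1),
])

def process_image(image_path, tokenized_text):
    image = tf.io.read_file(image_path)
    image = tf.io.decode_png(image, channels=3)
    image = tf.image.resize(image, (RESOLUTION, RESOLUTION))
    return image, tokenized_text

def apply_augmentation(image_batch, token_batch):
    return augmenter(image_batch), token_batch

# image_paths / tokenized_texts are assumed to come from the dataframe prepared above.
# In the full example, the tokenized captions are additionally passed through the
# frozen text encoder once so that the text embeddings are precomputed.
dataset = tf.data.Dataset.from_tensor_slices((image_paths, tokenized_texts))
dataset = dataset.shuffle(batch_size * 10)
dataset = dataset.map(process_image, num_parallel_calls=AUTO).batch(batch_size)
dataset = dataset.map(apply_augmentation, num_parallel_calls=AUTO)
dataset = dataset.prefetch(AUTO)
```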
We can also take a look at the training images and their corresponding captions.
A trainer class for the fine-tuning loop
One important implementation detail to note here: Instead of directly taking the latent vector produced by the image encoder (which is a VAE), we sample from the mean and log-variance predicted by it. This way, we can achieve better sample quality and diversity.
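In code, that sampling step is a standard reparameterization. A minimal sketch, assuming the VAE's image encoder outputs the mean and log-variance concatenated along the last axis:

```python
def sample_from_encoder_outputs(outputs):
    # The encoder output is assumed to carry [mean, log-variance] along the last axis.
    mean, logvar = tf.split(outputs, 2, axis=-1)
    logvar = tf.clip_by_value(logvar, -30.0, 20.0)  # numerical stability
    std = tf.exp(0.5 * logvar)
    sample = tf.random.normal(tf.shape(mean), dtype=mean.dtype)
    return mean + std * sample
```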
It's common to add support for mixed-precision training along with exponential moving averaging of model weights for fine-tuning these models. However, in the interest of brevity, we discard those elements. More on this later in the tutorial.
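To make the loop concrete, here is a heavily simplified train_step sketch. The noise-scheduler and timestep-embedding helpers (noise_scheduler.add_noise, get_timestep_embedding) are assumptions standing in for KerasCV's internals, and mixed precision / EMA are intentionally left out as discussed above:

```python
class Trainer(tf.keras.Model):
    def __init__(self, diffusion_model, vae, noise_scheduler, max_timesteps=1000, **kwargs):
        super().__init__(**kwargs)
        self.diffusion_model = diffusion_model  # trainable
        self.vae = vae                          # frozen image encoder
        self.noise_scheduler = noise_scheduler  # assumed to provide add_noise()
        self.max_timesteps = max_timesteps

    def train_step(self, inputs):
        images, encoded_text = inputs["images"], inputs["encoded_text"]
        batch_size = tf.shape(images)[0]

        with tf.GradientTape() as tape:
            # 1. Project images to latents and sample via the reparameterization above.
            latents = sample_from_encoder_outputs(self.vae(images, training=False))
            latents = latents * 0.18215  # Stable Diffusion's latent scaling factor

            # 2. Add noise at a randomly drawn timestep.
            noise = tf.random.normal(tf.shape(latents))
            timesteps = tf.random.uniform(
                (batch_size,), 0, self.max_timesteps, dtype=tf.int32
            )
            noisy_latents = self.noise_scheduler.add_noise(latents, noise, timesteps)

            # 3. Predict the added noise, conditioned on the text and timestep embeddings.
            t_emb = get_timestep_embedding(timesteps)  # hypothetical helper
            pred = self.diffusion_model(
                [noisy_latents, t_emb, encoded_text], training=True
            )

            # 4. Reconstruction loss between the true and predicted noise.
            loss = self.compiled_loss(noise, pred)

        # 5. Update only the diffusion model's weights; the VAE and text encoder stay frozen.
        grads = tape.gradient(loss, self.diffusion_model.trainable_variables)
        self.optimizer.apply_gradients(zip(grads, self.diffusion_model.trainable_variables))
        return {"loss": loss}
```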
Initialize the trainer and compile it
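A sketch of the compilation step, assuming the trainer above has been instantiated as diffusion_ft_trainer; the AdamW hyperparameters shown are reasonable defaults rather than the exact values used for the released checkpoint:

```python
# AdamW (requires TensorFlow >= 2.11) with an MSE loss on the predicted noise.
optimizer = tf.keras.optimizers.experimental.AdamW(
    learning_rate=1e-5, weight_decay=1e-2, epsilon=1e-8
)
diffusion_ft_trainer.compile(optimizer=optimizer, loss="mse")
```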
Fine-tuning
To keep the runtime of this tutorial short, we just fine-tune for an epoch.
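The actual call is a standard Keras fit; with the dataset prepared earlier (referred to here as training_dataset), a quick run looks like this:

```python
epochs = 1
diffusion_ft_trainer.fit(training_dataset, epochs=epochs)
```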
Inference
We fine-tuned the model for 60 epochs on an image resolution of 512x512. To allow training with this resolution, we incorporated mixed-precision support. You can check out this repository for more details. It additionally provides support for exponential moving averaging of the fine-tuned model parameters and model checkpointing.
For this section, we'll use the checkpoint derived after 60 epochs of fine-tuning.
Now, we can take this model for a test-drive.
With 60 epochs of fine-tuning (a good number is about 70), the generated images were not up to the mark. So, we experimented with the number of steps Stable Diffusion takes during inference and with the unconditional_guidance_scale parameter. We found the best results for this checkpoint with unconditional_guidance_scale set to 40.
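Put together, inference with the fine-tuned weights might look like the sketch below. The weights_path placeholder stands in for the actual 60-epoch checkpoint, the prompts are just examples, and unconditional_guidance_scale is set to 40 as discussed:

```python
# Build a standard KerasCV Stable Diffusion model and swap in the fine-tuned
# diffusion-model weights (weights_path is a placeholder for the checkpoint file).
pokemon_model = keras_cv.models.StableDiffusion(img_width=512, img_height=512)
pokemon_model.diffusion_model.load_weights(weights_path)

prompts = ["Yoda", "Hello Kitty", "A pokemon with red eyes"]
generated = {
    prompt: pokemon_model.text_to_image(
        prompt, batch_size=3, unconditional_guidance_scale=40
    )
    for prompt in prompts
}
```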
We can notice that the model has started adapting to the style of our dataset. You can check the accompanying repository for more comparisons and commentary. If you're feeling adventurous and want to try out a demo, you can check out this resource.
Conclusion and acknowledgements
We demonstrated how to fine-tune the Stable Diffusion model on a custom dataset. While the results are far from aesthetically pleasing, we believe that with more epochs of fine-tuning they are likely to improve. To enable that, having support for gradient accumulation and distributed training is crucial. This can be thought of as the next step in this tutorial.
There is another interesting way in which Stable Diffusion models can be fine-tuned, called textual inversion. You can refer to this tutorial to know more about it.
We'd like to acknowledge the GCP credit support from the ML Developer Programs team at Google. We'd like to thank the Hugging Face team for providing the fine-tuning script. It's very readable and easy to understand.