Teach StableDiffusion new concepts via Textual Inversion
Authors: Ian Stenbit, lukewood
Date created: 2022/12/09
Last modified: 2022/12/09
Description: Learning new visual concepts with KerasCV's StableDiffusion implementation.
Textual Inversion
Since its release, StableDiffusion has quickly become a favorite amongst the generative machine learning community. The high volume of traffic has led to open source contributed improvements, heavy prompt engineering, and even the invention of novel algorithms.
Perhaps the most impressive new algorithm being used is Textual Inversion, presented in An Image is Worth One Word: Personalizing Text-to-Image Generation using Textual Inversion.
Textual Inversion is the process of teaching an image generator a specific visual concept through the use of fine-tuning. In the diagram below, you can see an example of this process where the authors teach the model new concepts, calling them "S_*".
Conceptually, textual inversion works by learning a token embedding for a new text token, keeping the remaining components of StableDiffusion frozen.
This guide shows you how to fine-tune the StableDiffusion model shipped in KerasCV using the Textual Inversion algorithm. By the end of the guide, you will be able to write prompts like "Gandalf the Gray as a <my-funny-cat-token>".
First, let's import the packages we need, and create a StableDiffusion instance so we can use some of its subcomponents for fine-tuning.
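A minimal setup sketch (assuming a KerasCV version that ships the StableDiffusion model):

```python
import math

import keras_cv
import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
from tensorflow import keras

# Create a StableDiffusion instance; we will reuse its tokenizer, text encoder,
# image encoder, and diffusion model for fine-tuning below.
stable_diffusion = keras_cv.models.StableDiffusion()
```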
Next, let's define a visualization utility to show off the generated images:
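A small matplotlib helper is enough here; for example:

```python
def plot_images(images):
    """Display a batch of images in a single row."""
    plt.figure(figsize=(20, 20))
    for i in range(len(images)):
        plt.subplot(1, len(images), i + 1)
        plt.imshow(images[i])
        plt.axis("off")
```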
Assembling a text-image pair dataset
In order to train the embedding of our new token, we first must assemble a dataset consisting of text-image pairs. Each sample from the dataset must contain an image of the concept we are teaching StableDiffusion, as well as a caption accurately representing the content of the image. In this tutorial, we will teach StableDiffusion the concept of Luke and Ian's GitHub avatars:
First, let's construct an image dataset of cat dolls:
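A sketch of an image-dataset helper along these lines: the 512x512 resolution matches StableDiffusion's input size, the [-1, 1] normalization matches its image encoder, and the light augmentation is an optional choice:

```python
def assemble_image_dataset(urls):
    # Download the remote files and load them as 512x512 float arrays.
    files = [tf.keras.utils.get_file(origin=url) for url in urls]
    resize = keras.layers.Resizing(height=512, width=512, crop_to_aspect_ratio=True)
    images = [keras.utils.img_to_array(keras.utils.load_img(f)) for f in files]
    images = np.array([resize(img) for img in images])

    # The StableDiffusion image encoder expects pixel values in [-1, 1].
    images = images / 127.5 - 1

    # Shuffle and lightly augment so each epoch sees slightly different views.
    image_dataset = tf.data.Dataset.from_tensor_slices(images)
    image_dataset = image_dataset.shuffle(50, reshuffle_each_iteration=True)
    image_dataset = image_dataset.map(
        keras.layers.RandomFlip(mode="horizontal"),
        num_parallel_calls=tf.data.AUTOTUNE,
    )
    return image_dataset
```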
Next, we assemble a text dataset:
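A sketch of the text-dataset helper, assuming the KerasCV tokenizer exposes add_tokens, encode, and end_of_text as used here:

```python
MAX_PROMPT_LENGTH = 77
placeholder_token = "<my-funny-cat-token>"

# Register the new token with the tokenizer so prompts can reference it.
stable_diffusion.tokenizer.add_tokens(placeholder_token)


def pad_embedding(embedding):
    # Pad every encoded prompt to the fixed prompt length with end-of-text tokens.
    return embedding + (
        [stable_diffusion.tokenizer.end_of_text] * (MAX_PROMPT_LENGTH - len(embedding))
    )


def assemble_text_dataset(prompts):
    # Fill the "{}" slot in each prompt template with our placeholder token.
    prompts = [prompt.format(placeholder_token) for prompt in prompts]
    embeddings = [stable_diffusion.tokenizer.encode(prompt) for prompt in prompts]
    embeddings = [np.array(pad_embedding(e)) for e in embeddings]
    text_dataset = tf.data.Dataset.from_tensor_slices(embeddings)
    return text_dataset.shuffle(100, reshuffle_each_iteration=True)
```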
Finally, we zip our datasets together to produce a text-image pair dataset.
In order to ensure our prompts are descriptive, we use extremely generic prompts.
Let's try this out with some sample images and prompts.
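A sketch of the zipping helper, plus a hypothetical invocation (the URL and prompt templates below are illustrative, not the ones used for the real dataset):

```python
def assemble_dataset(urls, prompts):
    image_dataset = assemble_image_dataset(urls)
    text_dataset = assemble_text_dataset(prompts)
    # There are far fewer images than prompts, so repeat the image dataset and
    # let the (repeated) text dataset determine the overall dataset length.
    image_dataset = image_dataset.repeat()
    text_dataset = text_dataset.repeat(5)
    return tf.data.Dataset.zip((image_dataset, text_dataset))


# Illustrative only -- the actual guide uses the authors' GitHub avatar URLs and
# a longer list of generic prompt templates.
single_ds = assemble_dataset(
    urls=["https://example.com/cat-doll-0.png"],  # hypothetical URL
    prompts=["a photo of a {}", "a rendering of a {}", "a cropped photo of the {}"],
)
```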
On the importance of prompt accuracy
During our first attempt at writing this guide, we included images of groups of these cat dolls in our dataset, but continued to use the generic prompts listed above. Our results were anecdotally poor. For example, here's cat doll Gandalf using this method:
It's conceptually close, but it isn't as good as it could be.
In order to remedy this, we began experimenting with splitting our images into images of singular cat dolls and groups of cat dolls. Following this split, we came up with new prompts for the group shots.
Training on text-image pairs that accurately represent the content boosted the quality of our results substantially. This speaks to the importance of prompt accuracy.
In addition to separating the images into singular and group images, we also removed some inaccurate prompts, such as "a dark photo of the {}".
Keeping this in mind, we assemble our final training dataset below:
Looks great!
Next, we assemble a dataset of groups of our GitHub avatars:
Finally, we concatenate the two datasets:
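Assuming the two datasets are named single_ds and group_ds, the concatenation and batching might look like:

```python
# Assumed names: `single_ds` holds the single-avatar pairs and `group_ds` the
# group-shot pairs assembled above.
train_ds = single_ds.concatenate(group_ds)
train_ds = train_ds.batch(1).shuffle(
    train_ds.cardinality(), reshuffle_each_iteration=True
)
```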
Adding a new token to the text encoder
Next, we create a new text encoder for the StableDiffusion model and add our new embedding for '<my-funny-cat-token>' into the model.
Let's construct a new TextEncoder and prepare it.
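Below is a sketch of that surgery. The layer index of the embedding block (2 here), the TextEncoder constructor arguments, and the use of the private _text_encoder attribute are assumptions about KerasCV internals:

```python
# Initialize the new token's embedding from the embedding of "cat", the seed concept.
tokenized_initializer = stable_diffusion.tokenizer.encode("cat")[1]
new_weights = stable_diffusion.text_encoder.layers[2].token_embedding(
    tf.constant(tokenized_initializer)
)

# Append the new row to the existing token-embedding matrix.
old_token_weights = stable_diffusion.text_encoder.layers[2].token_embedding.get_weights()[0]
old_position_weights = stable_diffusion.text_encoder.layers[2].position_embedding.get_weights()
new_weights = np.concatenate(
    [old_token_weights, np.expand_dims(new_weights, axis=0)], axis=0
)

# Build a fresh TextEncoder with the enlarged vocabulary and copy the old weights over.
new_vocab_size = len(stable_diffusion.tokenizer.vocab)
new_encoder = keras_cv.models.stable_diffusion.TextEncoder(
    MAX_PROMPT_LENGTH, vocab_size=new_vocab_size, download_weights=False
)
for index, layer in enumerate(stable_diffusion.text_encoder.layers):
    if index == 2:  # skip the embedding block; we set its weights manually below
        continue
    new_encoder.layers[index].set_weights(layer.get_weights())

new_encoder.layers[2].token_embedding.set_weights([new_weights])
new_encoder.layers[2].position_embedding.set_weights(old_position_weights)

stable_diffusion._text_encoder = new_encoder
stable_diffusion._text_encoder.compile(jit_compile=True)
```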
Training
Now we can move on to the exciting part: training!
In Textual Inversion, the only piece of the model that is trained is the embedding vector. Let's freeze the rest of the model.
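A sketch of the freezing logic, again assuming the embedding block sits at layer index 2 of the text encoder:

```python
# Freeze everything except the text encoder...
stable_diffusion.diffusion_model.trainable = False
stable_diffusion.decoder.trainable = False
stable_diffusion.text_encoder.trainable = True

# ...and within the text encoder, keep only the token embedding trainable.
for index, layer in enumerate(stable_diffusion.text_encoder.layers):
    layer.trainable = index == 2  # assumed: layer 2 is the embedding block
stable_diffusion.text_encoder.layers[2].position_embedding.trainable = False
```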
Let's confirm the proper weights are set to trainable.
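One quick way to check is to print the trainable weight shapes of each submodel; only the text encoder should report a single trainable weight (the token embedding matrix):

```python
all_models = [
    stable_diffusion.text_encoder,
    stable_diffusion.diffusion_model,
    stable_diffusion.decoder,
]
print([[w.shape for w in model.trainable_weights] for model in all_models])
```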
Training the new embedding
In order to train the embedding, we need a couple of utilities. We import a NoiseScheduler from KerasCV, and define the following utilities below:
sample_from_encoder_outputs is a wrapper around the base StableDiffusion image encoder which samples from the statistical distribution produced by the image encoder, rather than taking just the mean (like many other SD applications)
get_timestep_embedding produces an embedding for a specified timestep for the diffusion model
get_position_ids produces a tensor of position IDs for the text encoder (which is just a series from [1, MAX_PROMPT_LENGTH])
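A sketch of these utilities, assuming the image encoder's final layer performs the latent sampling (so we cut it off at layers[-2]) and that the diffusion model uses a 320-dimensional sinusoidal timestep embedding:

```python
from keras_cv.models.stable_diffusion import NoiseScheduler

# Drop the image encoder's final sampling layer so we get the raw (mean, logvar)
# outputs and can sample from the latent distribution ourselves.
training_image_encoder = keras.Model(
    stable_diffusion.image_encoder.input,
    stable_diffusion.image_encoder.layers[-2].output,
)


def sample_from_encoder_outputs(outputs):
    mean, logvar = tf.split(outputs, 2, axis=-1)
    logvar = tf.clip_by_value(logvar, -30.0, 20.0)
    std = tf.exp(0.5 * logvar)
    sample = tf.random.normal(tf.shape(mean))
    return mean + std * sample


def get_timestep_embedding(timestep, dim=320, max_period=10000):
    # Sinusoidal embedding of the diffusion timestep.
    half = dim // 2
    freqs = tf.math.exp(
        -math.log(max_period) * tf.range(0, half, dtype=tf.float32) / half
    )
    args = tf.convert_to_tensor([timestep], dtype=tf.float32) * freqs
    return tf.concat([tf.math.cos(args), tf.math.sin(args)], 0)


def get_position_ids():
    return tf.convert_to_tensor([list(range(MAX_PROMPT_LENGTH))], dtype=tf.int32)
```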
Next, we implement a StableDiffusionFineTuner, which is a subclass of keras.Model that overrides train_step to train the token embeddings of our text encoder. This is the core of the Textual Inversion algorithm.
Abstractly speaking, the train step takes a sample from the output of the frozen SD image encoder's latent distribution for a training image, adds noise to that sample, and then passes that noisy sample to the frozen diffusion model. The hidden state of the diffusion model is the output of the text encoder for the prompt corresponding to the image.
Our final goal state is that the diffusion model is able to separate the noise from the sample using the text encoding as hidden state, so our loss is the mean-squared error of the noise and the output of the diffusion model (which has, ideally, removed the image latents from the noise).
We compute gradients for only the token embeddings of the text encoder, and in the train step we zero-out the gradients for all tokens other than the token that we're learning.
See in-line code comments for more details about the train step.
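Here is a condensed sketch of such a fine-tuner. The latent scaling constant (0.18215) is the one StableDiffusion was trained with, and placeholder_token_id (the row of the embedding matrix belonging to our new token) is assumed to be supplied by the caller:

```python
class StableDiffusionFineTuner(keras.Model):
    def __init__(self, stable_diffusion, noise_scheduler, placeholder_token_id, **kwargs):
        super().__init__(**kwargs)
        self.stable_diffusion = stable_diffusion
        self.noise_scheduler = noise_scheduler
        # Index of the new token's row in the token-embedding matrix.
        self.placeholder_token_id = placeholder_token_id

    def train_step(self, data):
        images, embeddings = data

        with tf.GradientTape() as tape:
            # Sample a latent for the training image and apply SD's latent scaling.
            latents = sample_from_encoder_outputs(training_image_encoder(images))
            latents = latents * 0.18215

            # Add noise at a random timestep for each sample in the batch.
            noise = tf.random.normal(tf.shape(latents))
            batch_dim = tf.shape(latents)[0]
            timesteps = tf.random.uniform(
                (batch_dim,),
                minval=0,
                maxval=self.noise_scheduler.train_timesteps,
                dtype=tf.int64,
            )
            noisy_latents = self.noise_scheduler.add_noise(latents, noise, timesteps)

            # The text encoding for the prompt is the diffusion model's hidden state.
            hidden_state = self.stable_diffusion.text_encoder(
                [embeddings, get_position_ids()]
            )
            timestep_embeddings = tf.map_fn(
                fn=get_timestep_embedding,
                elems=timesteps,
                fn_output_signature=tf.float32,
            )

            # Predict the noise and compare it to the noise we actually added.
            noise_pred = self.stable_diffusion.diffusion_model(
                [noisy_latents, timestep_embeddings, hidden_state]
            )
            loss = tf.reduce_mean(self.compiled_loss(noise_pred, noise))

        # Only the text encoder has trainable weights (the token embedding).
        trainable_weights = self.stable_diffusion.text_encoder.trainable_weights
        grads = tape.gradient(loss, trainable_weights)

        # Zero out gradients for every embedding row except the placeholder token's.
        condition = tf.expand_dims(
            grads[0].indices == self.placeholder_token_id, axis=-1
        )
        grads[0] = tf.IndexedSlices(
            values=tf.where(condition, grads[0].values, tf.zeros_like(grads[0].values)),
            indices=grads[0].indices,
            dense_shape=grads[0].dense_shape,
        )
        self.optimizer.apply_gradients(zip(grads, trainable_weights))
        return {"loss": loss}
```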
Before we start training, let's take a look at what StableDiffusion produces for our token.
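For example:

```python
generated = stable_diffusion.text_to_image(
    f"an oil painting of {placeholder_token}", seed=1337, batch_size=3
)
plot_images(generated)
```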
As you can see, the model still thinks of our token as a cat, as this was the seed token we used to initialize our custom token.
Now, to get started with training, we can just compile() our model like any other Keras model. Before doing so, we also instantiate a noise scheduler for training and configure our training parameters such as learning rate and optimizer.
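A plausible training configuration, assuming KerasCV's NoiseScheduler accepts the standard beta-schedule arguments; the learning-rate schedule, epoch count, and clipping values here are illustrative choices rather than prescribed settings:

```python
noise_scheduler = NoiseScheduler(
    beta_start=0.00085,
    beta_end=0.012,
    beta_schedule="scaled_linear",
    train_timesteps=1000,
)

# Assumes the tokenizer encodes the added token to a single id at position 1
# (between the start-of-text and end-of-text markers).
placeholder_token_id = stable_diffusion.tokenizer.encode(placeholder_token)[1]

trainer = StableDiffusionFineTuner(
    stable_diffusion, noise_scheduler, placeholder_token_id, name="trainer"
)

EPOCHS = 50
learning_rate = keras.optimizers.schedules.CosineDecay(
    initial_learning_rate=1e-4, decay_steps=train_ds.cardinality() * EPOCHS
)
optimizer = keras.optimizers.Adam(
    learning_rate=learning_rate, epsilon=1e-8, global_clipnorm=10
)

trainer.compile(
    optimizer=optimizer,
    # Reduction happens manually inside train_step, so use "none" here.
    loss=keras.losses.MeanSquaredError(reduction="none"),
)
```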
To monitor training, we can create a keras.callbacks.Callback that produces a few images every epoch using our custom token.
We create three callbacks with different prompts so that we can see how they progress over the course of training. We use a fixed seed so that we can easily see the progression of the learned token.
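A sketch of such a callback (the GenerateImages name and its parameters are ours, not a KerasCV API):

```python
class GenerateImages(keras.callbacks.Callback):
    def __init__(self, stable_diffusion, prompt, steps=50, frequency=10, seed=None, **kwargs):
        super().__init__(**kwargs)
        self.stable_diffusion = stable_diffusion
        self.prompt = prompt
        self.steps = steps
        self.frequency = frequency
        self.seed = seed

    def on_epoch_end(self, epoch, logs=None):
        # Generate a small batch of images every `frequency` epochs.
        if epoch % self.frequency == 0:
            images = self.stable_diffusion.text_to_image(
                self.prompt, batch_size=3, num_steps=self.steps, seed=self.seed
            )
            plot_images(images)


cbs = [
    GenerateImages(
        stable_diffusion, prompt=f"an oil painting of {placeholder_token}", seed=1337
    ),
    GenerateImages(
        stable_diffusion, prompt=f"gandalf the gray as a {placeholder_token}", seed=1337
    ),
    GenerateImages(
        stable_diffusion,
        prompt=f"two {placeholder_token} getting married, photorealistic, high quality",
        seed=1337,
    ),
]
```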
Now, all that is left to do is to call model.fit()!
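For example:

```python
trainer.fit(train_ds, epochs=EPOCHS, callbacks=cbs)
```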
It's pretty fun to see how the model learns our new token over time. Play around with it and see how you can tune training parameters and your training dataset to produce the best images!
Taking the Fine-Tuned Model for a Spin
Now for the really fun part. We've learned a token embedding for our custom token, so now we can generate images with StableDiffusion the same way we would for any other token!
Here are some fun example prompts to get you started, with sample outputs from our cat doll token!
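For instance, a prompt along these lines (the exact wording is just an illustration):

```python
generated = stable_diffusion.text_to_image(
    f"Gandalf as a {placeholder_token} fantasy art drawn by disney concept artists, "
    "golden colour, high quality, highly detailed, elegant, sharp focus",
    batch_size=3,
)
plot_images(generated)
```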
Conclusions
Using the Textual Inversion algorithm you can teach StableDiffusion new concepts!
Some possible next steps to follow:
Try out your own prompts
Teach the model a style
Gather a dataset of your favorite pet cat or dog and teach the model about it