Path: blob/master/examples/generative/fine_tune_via_textual_inversion.py
3507 views
"""1Title: Teach StableDiffusion new concepts via Textual Inversion2Authors: Ian Stenbit, [lukewood](https://lukewood.xyz)3Date created: 2022/12/094Last modified: 2022/12/095Description: Learning new visual concepts with KerasCV's StableDiffusion implementation.6"""78"""9## Textual Inversion1011Since its release, StableDiffusion has quickly become a favorite amongst12the generative machine learning community.13The high volume of traffic has led to open source contributed improvements,14heavy prompt engineering, and even the invention of novel algorithms.1516Perhaps the most impressive new algorithm being used is17[Textual Inversion](https://github.com/rinongal/textual_inversion), presented in18[_An Image is Worth One Word: Personalizing Text-to-Image Generation using Textual Inversion_](https://textual-inversion.github.io/).1920Textual Inversion is the process of teaching an image generator a specific visual concept21through the use of fine-tuning. In the diagram below, you can see an22example of this process where the authors teach the model new concepts, calling them23"S_*".24252627Conceptually, textual inversion works by learning a token embedding for a new text28token, keeping the remaining components of StableDiffusion frozen.2930This guide shows you how to fine-tune the StableDiffusion model shipped in KerasCV31using the Textual-Inversion algorithm. By the end of the guide, you will be able to32write the "Gandalf the Gray as a <my-funny-cat-token>".3334353637First, let's import the packages we need, and create a38StableDiffusion instance so we can use some of its subcomponents for fine-tuning.39"""4041"""shell42pip install -q git+https://github.com/keras-team/keras-cv.git43pip install -q tensorflow==2.11.044"""4546import math4748import keras_cv49import numpy as np50import tensorflow as tf51from keras_cv import layers as cv_layers52from keras_cv.models.stable_diffusion import NoiseScheduler53from tensorflow import keras54import matplotlib.pyplot as plt5556stable_diffusion = keras_cv.models.StableDiffusion()5758"""59Next, let's define a visualization utility to show off the generated images:60"""616263def plot_images(images):64plt.figure(figsize=(20, 20))65for i in range(len(images)):66ax = plt.subplot(1, len(images), i + 1)67plt.imshow(images[i])68plt.axis("off")697071"""72## Assembling a text-image pair dataset7374In order to train the embedding of our new token, we first must assemble a dataset75consisting of text-image pairs.76Each sample from the dataset must contain an image of the concept we are teaching77StableDiffusion, as well as a caption accurately representing the content of the image.78In this tutorial, we will teach StableDiffusion the concept of Luke and Ian's GitHub79avatars:80818283First, let's construct an image dataset of cat dolls:84"""858687def assemble_image_dataset(urls):88# Fetch all remote files89files = [tf.keras.utils.get_file(origin=url) for url in urls]9091# Resize images92resize = keras.layers.Resizing(height=512, width=512, crop_to_aspect_ratio=True)93images = [keras.utils.load_img(img) for img in files]94images = [keras.utils.img_to_array(img) for img in images]95images = np.array([resize(img) for img in images])9697# The StableDiffusion image encoder requires images to be normalized to the98# [-1, 1] pixel value range99images = images / 127.5 - 1100101# Create the tf.data.Dataset102image_dataset = tf.data.Dataset.from_tensor_slices(images)103104# Shuffle and introduce random noise105image_dataset = image_dataset.shuffle(50, reshuffle_each_iteration=True)106image_dataset = image_dataset.map(107cv_layers.RandomCropAndResize(108target_size=(512, 512),109crop_area_factor=(0.8, 1.0),110aspect_ratio_factor=(1.0, 1.0),111),112num_parallel_calls=tf.data.AUTOTUNE,113)114image_dataset = image_dataset.map(115cv_layers.RandomFlip(mode="horizontal"),116num_parallel_calls=tf.data.AUTOTUNE,117)118return image_dataset119120121"""122Next, we assemble a text dataset:123"""124125MAX_PROMPT_LENGTH = 77126placeholder_token = "<my-funny-cat-token>"127128129def pad_embedding(embedding):130return embedding + (131[stable_diffusion.tokenizer.end_of_text] * (MAX_PROMPT_LENGTH - len(embedding))132)133134135stable_diffusion.tokenizer.add_tokens(placeholder_token)136137138def assemble_text_dataset(prompts):139prompts = [prompt.format(placeholder_token) for prompt in prompts]140embeddings = [stable_diffusion.tokenizer.encode(prompt) for prompt in prompts]141embeddings = [np.array(pad_embedding(embedding)) for embedding in embeddings]142text_dataset = tf.data.Dataset.from_tensor_slices(embeddings)143text_dataset = text_dataset.shuffle(100, reshuffle_each_iteration=True)144return text_dataset145146147"""148Finally, we zip our datasets together to produce a text-image pair dataset.149"""150151152def assemble_dataset(urls, prompts):153image_dataset = assemble_image_dataset(urls)154text_dataset = assemble_text_dataset(prompts)155# the image dataset is quite short, so we repeat it to match the length of the156# text prompt dataset157image_dataset = image_dataset.repeat()158# we use the text prompt dataset to determine the length of the dataset. Due to159# the fact that there are relatively few prompts we repeat the dataset 5 times.160# we have found that this anecdotally improves results.161text_dataset = text_dataset.repeat(5)162return tf.data.Dataset.zip((image_dataset, text_dataset))163164165"""166In order to ensure our prompts are descriptive, we use extremely generic prompts.167168Let's try this out with some sample images and prompts.169"""170171train_ds = assemble_dataset(172urls=[173"https://i.imgur.com/VIedH1X.jpg",174"https://i.imgur.com/eBw13hE.png",175"https://i.imgur.com/oJ3rSg7.png",176"https://i.imgur.com/5mCL6Df.jpg",177"https://i.imgur.com/4Q6WWyI.jpg",178],179prompts=[180"a photo of a {}",181"a rendering of a {}",182"a cropped photo of the {}",183"the photo of a {}",184"a photo of a clean {}",185"a dark photo of the {}",186"a photo of my {}",187"a photo of the cool {}",188"a close-up photo of a {}",189"a bright photo of the {}",190"a cropped photo of a {}",191"a photo of the {}",192"a good photo of the {}",193"a photo of one {}",194"a close-up photo of the {}",195"a rendition of the {}",196"a photo of the clean {}",197"a rendition of a {}",198"a photo of a nice {}",199"a good photo of a {}",200"a photo of the nice {}",201"a photo of the small {}",202"a photo of the weird {}",203"a photo of the large {}",204"a photo of a cool {}",205"a photo of a small {}",206],207)208209"""210## On the importance of prompt accuracy211212During our first attempt at writing this guide we included images of groups of these cat213dolls in our dataset but continued to use the generic prompts listed above.214Our results were anecdotally poor. For example, here's cat doll gandalf using this method:215216217218It's conceptually close, but it isn't as great as it can be.219220In order to remedy this, we began experimenting with splitting our images into images of221singular cat dolls and groups of cat dolls.222Following this split, we came up with new prompts for the group shots.223224Training on text-image pairs that accurately represent the content boosted the quality225of our results *substantially*. This speaks to the importance of prompt accuracy.226227In addition to separating the images into singular and group images, we also remove some228inaccurate prompts; such as "a dark photo of the {}"229230Keeping this in mind, we assemble our final training dataset below:231"""232233single_ds = assemble_dataset(234urls=[235"https://i.imgur.com/VIedH1X.jpg",236"https://i.imgur.com/eBw13hE.png",237"https://i.imgur.com/oJ3rSg7.png",238"https://i.imgur.com/5mCL6Df.jpg",239"https://i.imgur.com/4Q6WWyI.jpg",240],241prompts=[242"a photo of a {}",243"a rendering of a {}",244"a cropped photo of the {}",245"the photo of a {}",246"a photo of a clean {}",247"a photo of my {}",248"a photo of the cool {}",249"a close-up photo of a {}",250"a bright photo of the {}",251"a cropped photo of a {}",252"a photo of the {}",253"a good photo of the {}",254"a photo of one {}",255"a close-up photo of the {}",256"a rendition of the {}",257"a photo of the clean {}",258"a rendition of a {}",259"a photo of a nice {}",260"a good photo of a {}",261"a photo of the nice {}",262"a photo of the small {}",263"a photo of the weird {}",264"a photo of the large {}",265"a photo of a cool {}",266"a photo of a small {}",267],268)269270"""271272273Looks great!274275Next, we assemble a dataset of groups of our GitHub avatars:276"""277278group_ds = assemble_dataset(279urls=[280"https://i.imgur.com/yVmZ2Qa.jpg",281"https://i.imgur.com/JbyFbZJ.jpg",282"https://i.imgur.com/CCubd3q.jpg",283],284prompts=[285"a photo of a group of {}",286"a rendering of a group of {}",287"a cropped photo of the group of {}",288"the photo of a group of {}",289"a photo of a clean group of {}",290"a photo of my group of {}",291"a photo of a cool group of {}",292"a close-up photo of a group of {}",293"a bright photo of the group of {}",294"a cropped photo of a group of {}",295"a photo of the group of {}",296"a good photo of the group of {}",297"a photo of one group of {}",298"a close-up photo of the group of {}",299"a rendition of the group of {}",300"a photo of the clean group of {}",301"a rendition of a group of {}",302"a photo of a nice group of {}",303"a good photo of a group of {}",304"a photo of the nice group of {}",305"a photo of the small group of {}",306"a photo of the weird group of {}",307"a photo of the large group of {}",308"a photo of a cool group of {}",309"a photo of a small group of {}",310],311)312313"""314315316Finally, we concatenate the two datasets:317"""318319train_ds = single_ds.concatenate(group_ds)320train_ds = train_ds.batch(1).shuffle(321train_ds.cardinality(), reshuffle_each_iteration=True322)323324"""325## Adding a new token to the text encoder326327Next, we create a new text encoder for the StableDiffusion model and add our new328embedding for '<my-funny-cat-token>' into the model.329"""330tokenized_initializer = stable_diffusion.tokenizer.encode("cat")[1]331new_weights = stable_diffusion.text_encoder.layers[2].token_embedding(332tf.constant(tokenized_initializer)333)334335# Get len of .vocab instead of tokenizer336new_vocab_size = len(stable_diffusion.tokenizer.vocab)337338# The embedding layer is the 2nd layer in the text encoder339old_token_weights = stable_diffusion.text_encoder.layers[3402341].token_embedding.get_weights()342old_position_weights = stable_diffusion.text_encoder.layers[3432344].position_embedding.get_weights()345346old_token_weights = old_token_weights[0]347new_weights = np.expand_dims(new_weights, axis=0)348new_weights = np.concatenate([old_token_weights, new_weights], axis=0)349350351"""352Let's construct a new TextEncoder and prepare it.353"""354355# Have to set download_weights False so we can init (otherwise tries to load weights)356new_encoder = keras_cv.models.stable_diffusion.TextEncoder(357keras_cv.models.stable_diffusion.stable_diffusion.MAX_PROMPT_LENGTH,358vocab_size=new_vocab_size,359download_weights=False,360)361for index, layer in enumerate(stable_diffusion.text_encoder.layers):362# Layer 2 is the embedding layer, so we omit it from our weight-copying363if index == 2:364continue365new_encoder.layers[index].set_weights(layer.get_weights())366367368new_encoder.layers[2].token_embedding.set_weights([new_weights])369new_encoder.layers[2].position_embedding.set_weights(old_position_weights)370371stable_diffusion._text_encoder = new_encoder372stable_diffusion._text_encoder.compile(jit_compile=True)373374"""375## Training376377Now we can move on to the exciting part: training!378379In TextualInversion, the only piece of the model that is trained is the embedding vector.380Let's freeze the rest of the model.381"""382383384stable_diffusion.diffusion_model.trainable = False385stable_diffusion.decoder.trainable = False386stable_diffusion.text_encoder.trainable = True387388stable_diffusion.text_encoder.layers[2].trainable = True389390391def traverse_layers(layer):392if hasattr(layer, "layers"):393for layer in layer.layers:394yield layer395if hasattr(layer, "token_embedding"):396yield layer.token_embedding397if hasattr(layer, "position_embedding"):398yield layer.position_embedding399400401for layer in traverse_layers(stable_diffusion.text_encoder):402if isinstance(layer, keras.layers.Embedding) or "clip_embedding" in layer.name:403layer.trainable = True404else:405layer.trainable = False406407new_encoder.layers[2].position_embedding.trainable = False408409"""410Let's confirm the proper weights are set to trainable.411"""412413all_models = [414stable_diffusion.text_encoder,415stable_diffusion.diffusion_model,416stable_diffusion.decoder,417]418print([[w.shape for w in model.trainable_weights] for model in all_models])419420"""421## Training the new embedding422423In order to train the embedding, we need a couple of utilities.424We import a NoiseScheduler from KerasCV, and define the following utilities below:425426- `sample_from_encoder_outputs` is a wrapper around the base StableDiffusion image427encoder which samples from the statistical distribution produced by the image428encoder, rather than taking just the mean (like many other SD applications)429- `get_timestep_embedding` produces an embedding for a specified timestep for the430diffusion model431- `get_position_ids` produces a tensor of position IDs for the text encoder (which is just a432series from `[1, MAX_PROMPT_LENGTH]`)433"""434435436# Remove the top layer from the encoder, which cuts off the variance and only returns437# the mean438training_image_encoder = keras.Model(439stable_diffusion.image_encoder.input,440stable_diffusion.image_encoder.layers[-2].output,441)442443444def sample_from_encoder_outputs(outputs):445mean, logvar = tf.split(outputs, 2, axis=-1)446logvar = tf.clip_by_value(logvar, -30.0, 20.0)447std = tf.exp(0.5 * logvar)448sample = tf.random.normal(tf.shape(mean))449return mean + std * sample450451452def get_timestep_embedding(timestep, dim=320, max_period=10000):453half = dim // 2454freqs = tf.math.exp(455-math.log(max_period) * tf.range(0, half, dtype=tf.float32) / half456)457args = tf.convert_to_tensor([timestep], dtype=tf.float32) * freqs458embedding = tf.concat([tf.math.cos(args), tf.math.sin(args)], 0)459return embedding460461462def get_position_ids():463return tf.convert_to_tensor([list(range(MAX_PROMPT_LENGTH))], dtype=tf.int32)464465466"""467Next, we implement a `StableDiffusionFineTuner`, which is a subclass of `keras.Model`468that overrides `train_step` to train the token embeddings of our text encoder.469This is the core of the Textual Inversion algorithm.470471Abstractly speaking, the train step takes a sample from the output of the frozen SD472image encoder's latent distribution for a training image, adds noise to that sample, and473then passes that noisy sample to the frozen diffusion model.474The hidden state of the diffusion model is the output of the text encoder for the prompt475corresponding to the image.476477Our final goal state is that the diffusion model is able to separate the noise from the478sample using the text encoding as hidden state, so our loss is the mean-squared error of479the noise and the output of the diffusion model (which has, ideally, removed the image480latents from the noise).481482We compute gradients for only the token embeddings of the text encoder, and in the483train step we zero-out the gradients for all tokens other than the token that we're484learning.485486See in-line code comments for more details about the train step.487"""488489490class StableDiffusionFineTuner(keras.Model):491def __init__(self, stable_diffusion, noise_scheduler, **kwargs):492super().__init__(**kwargs)493self.stable_diffusion = stable_diffusion494self.noise_scheduler = noise_scheduler495496def train_step(self, data):497images, embeddings = data498499with tf.GradientTape() as tape:500# Sample from the predicted distribution for the training image501latents = sample_from_encoder_outputs(training_image_encoder(images))502# The latents must be downsampled to match the scale of the latents used503# in the training of StableDiffusion. This number is truly just a "magic"504# constant that they chose when training the model.505latents = latents * 0.18215506507# Produce random noise in the same shape as the latent sample508noise = tf.random.normal(tf.shape(latents))509batch_dim = tf.shape(latents)[0]510511# Pick a random timestep for each sample in the batch512timesteps = tf.random.uniform(513(batch_dim,),514minval=0,515maxval=noise_scheduler.train_timesteps,516dtype=tf.int64,517)518519# Add noise to the latents based on the timestep for each sample520noisy_latents = self.noise_scheduler.add_noise(latents, noise, timesteps)521522# Encode the text in the training samples to use as hidden state in the523# diffusion model524encoder_hidden_state = self.stable_diffusion.text_encoder(525[embeddings, get_position_ids()]526)527528# Compute timestep embeddings for the randomly-selected timesteps for each529# sample in the batch530timestep_embeddings = tf.map_fn(531fn=get_timestep_embedding,532elems=timesteps,533fn_output_signature=tf.float32,534)535536# Call the diffusion model537noise_pred = self.stable_diffusion.diffusion_model(538[noisy_latents, timestep_embeddings, encoder_hidden_state]539)540541# Compute the mean-squared error loss and reduce it.542loss = self.compiled_loss(noise_pred, noise)543loss = tf.reduce_mean(loss, axis=2)544loss = tf.reduce_mean(loss, axis=1)545loss = tf.reduce_mean(loss)546547# Load the trainable weights and compute the gradients for them548trainable_weights = self.stable_diffusion.text_encoder.trainable_weights549grads = tape.gradient(loss, trainable_weights)550551# Gradients are stored in indexed slices, so we have to find the index552# of the slice(s) which contain the placeholder token.553index_of_placeholder_token = tf.reshape(tf.where(grads[0].indices == 49408), ())554condition = grads[0].indices == 49408555condition = tf.expand_dims(condition, axis=-1)556557# Override the gradients, zeroing out the gradients for all slices that558# aren't for the placeholder token, effectively freezing the weights for559# all other tokens.560grads[0] = tf.IndexedSlices(561values=tf.where(condition, grads[0].values, 0),562indices=grads[0].indices,563dense_shape=grads[0].dense_shape,564)565566self.optimizer.apply_gradients(zip(grads, trainable_weights))567return {"loss": loss}568569570"""571Before we start training, let's take a look at what StableDiffusion produces for our572token.573"""574575generated = stable_diffusion.text_to_image(576f"an oil painting of {placeholder_token}", seed=1337, batch_size=3577)578plot_images(generated)579580"""581As you can see, the model still thinks of our token as a cat, as this was the seed token582we used to initialize our custom token.583584Now, to get started with training, we can just `compile()` our model like any other585Keras model. Before doing so, we also instantiate a noise scheduler for training and586configure our training parameters such as learning rate and optimizer.587"""588589noise_scheduler = NoiseScheduler(590beta_start=0.00085,591beta_end=0.012,592beta_schedule="scaled_linear",593train_timesteps=1000,594)595trainer = StableDiffusionFineTuner(stable_diffusion, noise_scheduler, name="trainer")596EPOCHS = 50597learning_rate = keras.optimizers.schedules.CosineDecay(598initial_learning_rate=1e-4, decay_steps=train_ds.cardinality() * EPOCHS599)600optimizer = keras.optimizers.Adam(601weight_decay=0.004, learning_rate=learning_rate, epsilon=1e-8, global_clipnorm=10602)603604trainer.compile(605optimizer=optimizer,606# We are performing reduction manually in our train step, so none is required here.607loss=keras.losses.MeanSquaredError(reduction="none"),608)609610"""611To monitor training, we can produce a `keras.callbacks.Callback` to produce a few images612every epoch using our custom token.613614We create three callbacks with different prompts so that we can see how they progress615over the course of training. We use a fixed seed so that we can easily see the616progression of the learned token.617"""618619620class GenerateImages(keras.callbacks.Callback):621def __init__(622self, stable_diffusion, prompt, steps=50, frequency=10, seed=None, **kwargs623):624super().__init__(**kwargs)625self.stable_diffusion = stable_diffusion626self.prompt = prompt627self.seed = seed628self.frequency = frequency629self.steps = steps630631def on_epoch_end(self, epoch, logs):632if epoch % self.frequency == 0:633images = self.stable_diffusion.text_to_image(634self.prompt, batch_size=3, num_steps=self.steps, seed=self.seed635)636plot_images(637images,638)639640641cbs = [642GenerateImages(643stable_diffusion, prompt=f"an oil painting of {placeholder_token}", seed=1337644),645GenerateImages(646stable_diffusion, prompt=f"gandalf the gray as a {placeholder_token}", seed=1337647),648GenerateImages(649stable_diffusion,650prompt=f"two {placeholder_token} getting married, photorealistic, high quality",651seed=1337,652),653]654655"""656Now, all that is left to do is to call `model.fit()`!657"""658659trainer.fit(660train_ds,661epochs=EPOCHS,662callbacks=cbs,663)664665"""666It's pretty fun to see how the model learns our new token over time. Play around with it667and see how you can tune training parameters and your training dataset to produce the668best images!669"""670671"""672## Taking the Fine Tuned Model for a Spin673674Now for the really fun part. We've learned a token embedding for our custom token, so675now we can generate images with StableDiffusion the same way we would for any other676token!677678Here are some fun example prompts to get you started, with sample outputs from our cat679doll token!680"""681682generated = stable_diffusion.text_to_image(683f"Gandalf as a {placeholder_token} fantasy art drawn by disney concept artists, "684"golden colour, high quality, highly detailed, elegant, sharp focus, concept art, "685"character concepts, digital painting, mystery, adventure",686batch_size=3,687)688plot_images(generated)689690"""691"""692693generated = stable_diffusion.text_to_image(694f"A masterpiece of a {placeholder_token} crying out to the heavens. "695f"Behind the {placeholder_token}, an dark, evil shade looms over it - sucking the "696"life right out of it.",697batch_size=3,698)699plot_images(generated)700701"""702"""703704generated = stable_diffusion.text_to_image(705f"An evil {placeholder_token}.", batch_size=3706)707plot_images(generated)708709"""710"""711712generated = stable_diffusion.text_to_image(713f"A mysterious {placeholder_token} approaches the great pyramids of egypt.",714batch_size=3,715)716plot_images(generated)717718"""719## Conclusions720721Using the Textual Inversion algorithm you can teach StableDiffusion new concepts!722723Some possible next steps to follow:724725- Try out your own prompts726- Teach the model a style727- Gather a dataset of your favorite pet cat or dog and teach the model about it728"""729730731