Path: blob/master/examples/generative/finetune_stable_diffusion.py
3507 views
"""1Title: Fine-tuning Stable Diffusion2Author: [Sayak Paul](https://twitter.com/RisingSayak), [Chansung Park](https://twitter.com/algo_diver)3Date created: 2022/12/284Last modified: 2023/01/135Description: Fine-tuning Stable Diffusion using a custom image-caption dataset.6Accelerator: GPU7"""89"""10## Introduction1112This tutorial shows how to fine-tune a13[Stable Diffusion model](https://keras.io/guides/keras_cv/generate_images_with_stable_diffusion/)14on a custom dataset of `{image, caption}` pairs. We build on top of the fine-tuning15script provided by Hugging Face16[here](https://github.com/huggingface/diffusers/blob/main/examples/text_to_image/train_text_to_image.py).1718We assume that you have a high-level understanding of the Stable Diffusion model.19The following resources can be helpful if you're looking for more information in that regard:2021* [High-performance image generation using Stable Diffusion in KerasCV](https://keras.io/guides/keras_cv/generate_images_with_stable_diffusion/)22* [Stable Diffusion with Diffusers](https://huggingface.co/blog/stable_diffusion)2324It's highly recommended that you use a GPU with at least 30GB of memory to execute25the code.2627By the end of the guide, you'll be able to generate images of interesting Pokémon:28293031The tutorial relies on KerasCV 0.4.0. Additionally, we need32at least TensorFlow 2.11 in order to use AdamW with mixed precision.33"""3435"""shell36pip install keras-cv==0.6.0 -q37pip install -U tensorflow -q38pip install keras-core -q39"""4041"""42## What are we fine-tuning?4344A Stable Diffusion model can be decomposed into several key models:4546* A text encoder that projects the input prompt to a latent space. (The caption47associated with an image is referred to as the "prompt".)48* A variational autoencoder (VAE) that projects an input image to a latent space acting49as an image vector space.50* A diffusion model that refines a latent vector and produces another latent vector, conditioned51on the encoded text prompt52* A decoder that generates images given a latent vector from the diffusion model.5354It's worth noting that during the process of generating an image from a text prompt, the55image encoder is not typically employed.5657However, during the process of fine-tuning, the workflow goes like the following:58591. An input text prompt is projected to a latent space by the text encoder.602. An input image is projected to a latent space by the image encoder portion of the VAE.613. A small amount of noise is added to the image latent vector for a given timestep.624. The diffusion model uses latent vectors from these two spaces along with a timestep embedding63to predict the noise that was added to the image latent.645. A reconstruction loss is calculated between the predicted noise and the original noise65added in step 3.666. Finally, the diffusion model parameters are optimized w.r.t this loss using67gradient descent.6869Note that only the diffusion model parameters are updated during fine-tuning, while the70(pre-trained) text and the image encoders are kept frozen.7172Don't worry if this sounds complicated. The code is much simpler than this!73"""7475"""76## Imports77"""7879from textwrap import wrap80import os8182import keras_cv83import matplotlib.pyplot as plt84import numpy as np85import pandas as pd86import tensorflow as tf87import tensorflow.experimental.numpy as tnp88from keras_cv.models.stable_diffusion.clip_tokenizer import SimpleTokenizer89from keras_cv.models.stable_diffusion.diffusion_model import DiffusionModel90from keras_cv.models.stable_diffusion.image_encoder import ImageEncoder91from keras_cv.models.stable_diffusion.noise_scheduler import NoiseScheduler92from keras_cv.models.stable_diffusion.text_encoder import TextEncoder93from tensorflow import keras9495"""96## Data loading9798We use the dataset99[Pokémon BLIP captions](https://huggingface.co/datasets/lambdalabs/pokemon-blip-captions).100However, we'll use a slightly different version which was derived from the original101dataset to fit better with `tf.data`. Refer to102[the documentation](https://huggingface.co/datasets/sayakpaul/pokemon-blip-original-version)103for more details.104"""105106data_path = tf.keras.utils.get_file(107origin="https://huggingface.co/datasets/sayakpaul/pokemon-blip-original-version/resolve/main/pokemon_dataset.tar.gz",108untar=True,109)110111data_frame = pd.read_csv(os.path.join(data_path, "data.csv"))112113data_frame["image_path"] = data_frame["image_path"].apply(114lambda x: os.path.join(data_path, x)115)116data_frame.head()117118"""119Since we have only 833 `{image, caption}` pairs, we can precompute the text embeddings from120the captions. Moreover, the text encoder will be kept frozen during the course of121fine-tuning, so we can save some compute by doing this.122123Before we use the text encoder, we need to tokenize the captions.124"""125126# The padding token and maximum prompt length are specific to the text encoder.127# If you're using a different text encoder be sure to change them accordingly.128PADDING_TOKEN = 49407129MAX_PROMPT_LENGTH = 77130131# Load the tokenizer.132tokenizer = SimpleTokenizer()133134135# Method to tokenize and pad the tokens.136def process_text(caption):137tokens = tokenizer.encode(caption)138tokens = tokens + [PADDING_TOKEN] * (MAX_PROMPT_LENGTH - len(tokens))139return np.array(tokens)140141142# Collate the tokenized captions into an array.143tokenized_texts = np.empty((len(data_frame), MAX_PROMPT_LENGTH))144145all_captions = list(data_frame["caption"].values)146for i, caption in enumerate(all_captions):147tokenized_texts[i] = process_text(caption)148149"""150## Prepare a `tf.data.Dataset`151152In this section, we'll prepare a `tf.data.Dataset` object from the input image file paths153and their corresponding caption tokens. The section will include the following:154155* Pre-computation of the text embeddings from the tokenized captions.156* Loading and augmentation of the input images.157* Shuffling and batching of the dataset.158"""159160RESOLUTION = 256161AUTO = tf.data.AUTOTUNE162POS_IDS = tf.convert_to_tensor([list(range(MAX_PROMPT_LENGTH))], dtype=tf.int32)163164augmenter = keras.Sequential(165layers=[166keras_cv.layers.CenterCrop(RESOLUTION, RESOLUTION),167keras_cv.layers.RandomFlip(),168tf.keras.layers.Rescaling(scale=1.0 / 127.5, offset=-1),169]170)171text_encoder = TextEncoder(MAX_PROMPT_LENGTH)172173174def process_image(image_path, tokenized_text):175image = tf.io.read_file(image_path)176image = tf.io.decode_png(image, 3)177image = tf.image.resize(image, (RESOLUTION, RESOLUTION))178return image, tokenized_text179180181def apply_augmentation(image_batch, token_batch):182return augmenter(image_batch), token_batch183184185def run_text_encoder(image_batch, token_batch):186return (187image_batch,188token_batch,189text_encoder([token_batch, POS_IDS], training=False),190)191192193def prepare_dict(image_batch, token_batch, encoded_text_batch):194return {195"images": image_batch,196"tokens": token_batch,197"encoded_text": encoded_text_batch,198}199200201def prepare_dataset(image_paths, tokenized_texts, batch_size=1):202dataset = tf.data.Dataset.from_tensor_slices((image_paths, tokenized_texts))203dataset = dataset.shuffle(batch_size * 10)204dataset = dataset.map(process_image, num_parallel_calls=AUTO).batch(batch_size)205dataset = dataset.map(apply_augmentation, num_parallel_calls=AUTO)206dataset = dataset.map(run_text_encoder, num_parallel_calls=AUTO)207dataset = dataset.map(prepare_dict, num_parallel_calls=AUTO)208return dataset.prefetch(AUTO)209210211"""212The baseline Stable Diffusion model was trained using images with 512x512 resolution. It's213unlikely for a model that's trained using higher-resolution images to transfer well to214lower-resolution images. However, the current model will lead to OOM if we keep the215resolution to 512x512 (without enabling mixed-precision). Therefore, in the interest of216interactive demonstrations, we kept the input resolution to 256x256.217"""218219# Prepare the dataset.220training_dataset = prepare_dataset(221np.array(data_frame["image_path"]), tokenized_texts, batch_size=4222)223224# Take a sample batch and investigate.225sample_batch = next(iter(training_dataset))226227for k in sample_batch:228print(k, sample_batch[k].shape)229230"""231We can also take a look at the training images and their corresponding captions.232"""233234plt.figure(figsize=(20, 10))235236for i in range(3):237ax = plt.subplot(1, 4, i + 1)238plt.imshow((sample_batch["images"][i] + 1) / 2)239240text = tokenizer.decode(sample_batch["tokens"][i].numpy().squeeze())241text = text.replace("<|startoftext|>", "")242text = text.replace("<|endoftext|>", "")243text = "\n".join(wrap(text, 12))244plt.title(text, fontsize=15)245246plt.axis("off")247248"""249## A trainer class for the fine-tuning loop250"""251252253class Trainer(tf.keras.Model):254# Reference:255# https://github.com/huggingface/diffusers/blob/main/examples/text_to_image/train_text_to_image.py256257def __init__(258self,259diffusion_model,260vae,261noise_scheduler,262use_mixed_precision=False,263max_grad_norm=1.0,264**kwargs265):266super().__init__(**kwargs)267268self.diffusion_model = diffusion_model269self.vae = vae270self.noise_scheduler = noise_scheduler271self.max_grad_norm = max_grad_norm272273self.use_mixed_precision = use_mixed_precision274self.vae.trainable = False275276def train_step(self, inputs):277images = inputs["images"]278encoded_text = inputs["encoded_text"]279batch_size = tf.shape(images)[0]280281with tf.GradientTape() as tape:282# Project image into the latent space and sample from it.283latents = self.sample_from_encoder_outputs(self.vae(images, training=False))284# Know more about the magic number here:285# https://keras.io/examples/generative/fine_tune_via_textual_inversion/286latents = latents * 0.18215287288# Sample noise that we'll add to the latents.289noise = tf.random.normal(tf.shape(latents))290291# Sample a random timestep for each image.292timesteps = tnp.random.randint(2930, self.noise_scheduler.train_timesteps, (batch_size,)294)295296# Add noise to the latents according to the noise magnitude at each timestep297# (this is the forward diffusion process).298noisy_latents = self.noise_scheduler.add_noise(299tf.cast(latents, noise.dtype), noise, timesteps300)301302# Get the target for loss depending on the prediction type303# just the sampled noise for now.304target = noise # noise_schedule.predict_epsilon == True305306# Predict the noise residual and compute loss.307timestep_embedding = tf.map_fn(308lambda t: self.get_timestep_embedding(t), timesteps, dtype=tf.float32309)310timestep_embedding = tf.squeeze(timestep_embedding, 1)311model_pred = self.diffusion_model(312[noisy_latents, timestep_embedding, encoded_text], training=True313)314loss = self.compiled_loss(target, model_pred)315if self.use_mixed_precision:316loss = self.optimizer.get_scaled_loss(loss)317318# Update parameters of the diffusion model.319trainable_vars = self.diffusion_model.trainable_variables320gradients = tape.gradient(loss, trainable_vars)321if self.use_mixed_precision:322gradients = self.optimizer.get_unscaled_gradients(gradients)323gradients = [tf.clip_by_norm(g, self.max_grad_norm) for g in gradients]324self.optimizer.apply_gradients(zip(gradients, trainable_vars))325326return {m.name: m.result() for m in self.metrics}327328def get_timestep_embedding(self, timestep, dim=320, max_period=10000):329half = dim // 2330log_max_period = tf.math.log(tf.cast(max_period, tf.float32))331freqs = tf.math.exp(332-log_max_period * tf.range(0, half, dtype=tf.float32) / half333)334args = tf.convert_to_tensor([timestep], dtype=tf.float32) * freqs335embedding = tf.concat([tf.math.cos(args), tf.math.sin(args)], 0)336embedding = tf.reshape(embedding, [1, -1])337return embedding338339def sample_from_encoder_outputs(self, outputs):340mean, logvar = tf.split(outputs, 2, axis=-1)341logvar = tf.clip_by_value(logvar, -30.0, 20.0)342std = tf.exp(0.5 * logvar)343sample = tf.random.normal(tf.shape(mean), dtype=mean.dtype)344return mean + std * sample345346def save_weights(self, filepath, overwrite=True, save_format=None, options=None):347# Overriding this method will allow us to use the `ModelCheckpoint`348# callback directly with this trainer class. In this case, it will349# only checkpoint the `diffusion_model` since that's what we're training350# during fine-tuning.351self.diffusion_model.save_weights(352filepath=filepath,353overwrite=overwrite,354save_format=save_format,355options=options,356)357358359"""360One important implementation detail to note here: Instead of directly taking361the latent vector produced by the image encoder (which is a VAE), we sample from the362mean and log-variance predicted by it. This way, we can achieve better sample363quality and diversity.364365It's common to add support for mixed-precision training along with exponential366moving averaging of model weights for fine-tuning these models. However, in the interest367of brevity, we discard those elements. More on this later in the tutorial.368"""369370"""371## Initialize the trainer and compile it372"""373374# Enable mixed-precision training if the underlying GPU has tensor cores.375USE_MP = True376if USE_MP:377keras.mixed_precision.set_global_policy("mixed_float16")378379image_encoder = ImageEncoder()380diffusion_ft_trainer = Trainer(381diffusion_model=DiffusionModel(RESOLUTION, RESOLUTION, MAX_PROMPT_LENGTH),382# Remove the top layer from the encoder, which cuts off the variance and only383# returns the mean.384vae=tf.keras.Model(385image_encoder.input,386image_encoder.layers[-2].output,387),388noise_scheduler=NoiseScheduler(),389use_mixed_precision=USE_MP,390)391392# These hyperparameters come from this tutorial by Hugging Face:393# https://huggingface.co/docs/diffusers/training/text2image394lr = 1e-5395beta_1, beta_2 = 0.9, 0.999396weight_decay = (1e-2,)397epsilon = 1e-08398399optimizer = tf.keras.optimizers.experimental.AdamW(400learning_rate=lr,401weight_decay=weight_decay,402beta_1=beta_1,403beta_2=beta_2,404epsilon=epsilon,405)406diffusion_ft_trainer.compile(optimizer=optimizer, loss="mse")407408"""409## Fine-tuning410411To keep the runtime of this tutorial short, we just fine-tune for an epoch.412"""413414epochs = 1415ckpt_path = "finetuned_stable_diffusion.h5"416ckpt_callback = tf.keras.callbacks.ModelCheckpoint(417ckpt_path,418save_weights_only=True,419monitor="loss",420mode="min",421)422diffusion_ft_trainer.fit(training_dataset, epochs=epochs, callbacks=[ckpt_callback])423424"""425## Inference426427We fine-tuned the model for 60 epochs on an image resolution of 512x512. To allow428training with this resolution, we incorporated mixed-precision support. You can429check out430[this repository](https://github.com/sayakpaul/stabe-diffusion-keras-ft)431for more details. It additionally provides support for exponential moving averaging of432the fine-tuned model parameters and model checkpointing.433434435For this section, we'll use the checkpoint derived after 60 epochs of fine-tuning.436"""437438weights_path = tf.keras.utils.get_file(439origin="https://huggingface.co/sayakpaul/kerascv_sd_pokemon_finetuned/resolve/main/ckpt_epochs_72_res_512_mp_True.h5"440)441442img_height = img_width = 512443pokemon_model = keras_cv.models.StableDiffusion(444img_width=img_width, img_height=img_height445)446# We just reload the weights of the fine-tuned diffusion model.447pokemon_model.diffusion_model.load_weights(weights_path)448449"""450Now, we can take this model for a test-drive.451"""452453prompts = ["Yoda", "Hello Kitty", "A pokemon with red eyes"]454images_to_generate = 3455outputs = {}456457for prompt in prompts:458generated_images = pokemon_model.text_to_image(459prompt, batch_size=images_to_generate, unconditional_guidance_scale=40460)461outputs.update({prompt: generated_images})462463"""464With 60 epochs of fine-tuning (a good number is about 70), the generated images were not465up to the mark. So, we experimented with the number of steps Stable Diffusion takes466during the inference time and the `unconditional_guidance_scale` parameter.467468We found the best results with this checkpoint with `unconditional_guidance_scale` set to46940.470"""471472473def plot_images(images, title):474plt.figure(figsize=(20, 20))475for i in range(len(images)):476ax = plt.subplot(1, len(images), i + 1)477plt.imshow(images[i])478plt.title(title, fontsize=12)479plt.axis("off")480481482for prompt in outputs:483plot_images(outputs[prompt], prompt)484485"""486We can notice that the model has started adapting to the style of our dataset. You can487check the488[accompanying repository](https://github.com/sayakpaul/stable-diffusion-keras-ft#results)489for more comparisons and commentary. If you're feeling adventurous to try out a demo,490you can check out491[this resource](https://huggingface.co/spaces/sayakpaul/pokemon-sd-kerascv).492"""493494"""495## Conclusion and acknowledgements496497We demonstrated how to fine-tune the Stable Diffusion model on a custom dataset. While498the results are far from aesthetically pleasing, we believe with more epochs of499fine-tuning, they will likely improve. To enable that, having support for gradient500accumulation and distributed training is crucial. This can be thought of as the next step501in this tutorial.502503There is another interesting way in which Stable Diffusion models can be fine-tuned,504called textual inversion. You can refer to505[this tutorial](https://keras.io/examples/generative/fine_tune_via_textual_inversion/)506to know more about it.507508We'd like to acknowledge the GCP Credit support from ML Developer Programs' team at509Google. We'd like to thank the Hugging Face team for providing the510[fine-tuning script](https://github.com/huggingface/diffusers/blob/main/examples/text_to_image/train_text_to_image.py)511. It's very readable and easy to understand.512"""513514515