"""1Title: DreamBooth2Author: [Sayak Paul](https://twitter.com/RisingSayak), [Chansung Park](https://twitter.com/algo_diver)3Date created: 2023/02/014Last modified: 2023/02/055Description: Implementing DreamBooth.6Accelerator: GPU7"""89"""10## Introduction1112In this example, we implement DreamBooth, a fine-tuning technique to teach new visual13concepts to text-conditioned Diffusion models with just 3 - 5 images. DreamBooth was14proposed in15[DreamBooth: Fine Tuning Text-to-Image Diffusion Models for Subject-Driven Generation](https://arxiv.org/abs/2208.12242)16by Ruiz et al.1718DreamBooth, in a sense, is similar to the19[traditional way of fine-tuning a text-conditioned Diffusion model except](https://keras.io/examples/generative/finetune_stable_diffusion/)20for a few gotchas. This example assumes that you have basic familiarity with21Diffusion models and how to fine-tune them. Here are some reference examples that might22help you to get familiarized quickly:2324* [High-performance image generation using Stable Diffusion in KerasCV](https://keras.io/guides/keras_cv/generate_images_with_stable_diffusion/)25* [Teach StableDiffusion new concepts via Textual Inversion](https://keras.io/examples/generative/fine_tune_via_textual_inversion/)26* [Fine-tuning Stable Diffusion](https://keras.io/examples/generative/finetune_stable_diffusion/)2728First, let's install the latest versions of KerasCV and TensorFlow.2930"""3132"""shell33pip install -q -U keras_cv==0.6.034pip install -q -U tensorflow35"""3637"""38If you're running the code, please ensure you're using a GPU with at least 24 GBs of39VRAM.40"""4142"""43## Initial imports44"""4546import math4748import keras_cv49import matplotlib.pyplot as plt50import numpy as np51import tensorflow as tf52from imutils import paths53from tensorflow import keras5455"""56## Usage of DreamBooth5758... is very versatile. By teaching Stable Diffusion about your favorite visual59concepts, you can6061* Recontextualize objects in interesting ways:62636465* Generate artistic renderings of the underlying visual concept:6667686970And many other applications. We welcome you to check out the original71[DreamBooth paper](https://arxiv.org/abs/2208.12242) in this regard.72"""7374"""75## Download the instance and class images7677DreamBooth uses a technique called "prior preservation" to meaningfully guide the78training procedure such that the fine-tuned models can still preserve some of the prior79semantics of the visual concept you're introducing. To know more about the idea of "prior80preservation" refer to [this document](https://dreambooth.github.io/).8182Here, we need to introduce a few key terms specific to DreamBooth:8384* **Unique class**: Examples include "dog", "person", etc. In this example, we use "dog".85* **Unique identifier**: A unique identifier that is prepended to the unique class while86forming the "instance prompts". In this example, we use "sks" as this unique identifier.87* **Instance prompt**: Denotes a prompt that best describes the "instance images". An88example prompt could be - "f"a photo of {unique_id} {unique_class}". So, for our example,89this becomes - "a photo of sks dog".90* **Class prompt**: Denotes a prompt without the unique identifier. This prompt is used91for generating "class images" for prior preservation. For our example, this prompt is -92"a photo of dog".93* **Instance images**: Denote the images that represent the visual concept you're trying94to teach aka the "instance prompt". This number is typically just 3 - 5. 

In code, generating these class images looks quite simple:

```py
from tqdm import tqdm
import numpy as np
import hashlib
import keras_cv
import PIL
import os

class_images_dir = "class-images"
os.makedirs(class_images_dir, exist_ok=True)

model = keras_cv.models.StableDiffusion(img_width=512, img_height=512, jit_compile=True)

class_prompt = "a photo of dog"
num_imgs_to_generate = 200
for i in tqdm(range(num_imgs_to_generate)):
    images = model.text_to_image(
        class_prompt,
        batch_size=3,
    )
    idx = np.random.choice(len(images))
    selected_image = PIL.Image.fromarray(images[idx])
    hash_image = hashlib.sha1(selected_image.tobytes()).hexdigest()
    image_filename = os.path.join(class_images_dir, f"{hash_image}.jpg")
    selected_image.save(image_filename)
```

To keep the runtime of this example short, the authors of this example have gone ahead
and generated some class images using
[this notebook](https://colab.research.google.com/gist/sayakpaul/6b5de345d29cf5860f84b6d04d958692/generate_class_priors.ipynb).

**Note** that prior preservation is an optional technique used in DreamBooth, but it
almost always helps in improving the quality of the generated images.
"""

instance_images_root = tf.keras.utils.get_file(
    origin="https://huggingface.co/datasets/sayakpaul/sample-datasets/resolve/main/instance-images.tar.gz",
    untar=True,
)
class_images_root = tf.keras.utils.get_file(
    origin="https://huggingface.co/datasets/sayakpaul/sample-datasets/resolve/main/class-images.tar.gz",
    untar=True,
)

"""
## Visualize images

First, let's load the image paths.
"""

instance_image_paths = list(paths.list_images(instance_images_root))
class_image_paths = list(paths.list_images(class_images_root))

"""
Then we load the images from the paths.
"""


def load_images(image_paths):
    images = [np.array(keras.utils.load_img(path)) for path in image_paths]
    return images


"""
And then we make use of a utility function to plot the loaded images.
"""


def plot_images(images, title=None):
    plt.figure(figsize=(20, 20))
    for i in range(len(images)):
        ax = plt.subplot(1, len(images), i + 1)
        if title is not None:
            plt.title(title)
        plt.imshow(images[i])
        plt.axis("off")


"""
**Instance images**:
"""

plot_images(load_images(instance_image_paths[:5]))

"""
**Class images**:
"""

plot_images(load_images(class_image_paths[:5]))

"""
## Prepare datasets

Dataset preparation includes two stages: (1) preparing the captions and (2) processing
the images.
"""

"""
### Prepare the captions
"""

# Since we're using prior preservation, we need to match the number
# of instance images we're using. We just repeat the instance image paths
# to do so.
new_instance_image_paths = []
for index in range(len(class_image_paths)):
    instance_image = instance_image_paths[index % len(instance_image_paths)]
    new_instance_image_paths.append(instance_image)

# We just repeat the prompts / captions per image.
unique_id = "sks"
class_label = "dog"

instance_prompt = f"a photo of {unique_id} {class_label}"
instance_prompts = [instance_prompt] * len(new_instance_image_paths)

class_prompt = f"a photo of {class_label}"
class_prompts = [class_prompt] * len(class_image_paths)

"""
Next, we embed the prompts to save some compute.
"""

import itertools

# The padding token and maximum prompt length are specific to the text encoder.
# If you're using a different text encoder be sure to change them accordingly.
# (49407 is the end-of-text token of the tokenizer used by Stable Diffusion.)
padding_token = 49407
max_prompt_length = 77

# Load the tokenizer.
tokenizer = keras_cv.models.stable_diffusion.SimpleTokenizer()


# Method to tokenize and pad the tokens.
def process_text(caption):
    tokens = tokenizer.encode(caption)
    tokens = tokens + [padding_token] * (max_prompt_length - len(tokens))
    return np.array(tokens)


# Collate the tokenized captions into an array.
tokenized_texts = np.empty(
    (len(instance_prompts) + len(class_prompts), max_prompt_length)
)

for i, caption in enumerate(itertools.chain(instance_prompts, class_prompts)):
    tokenized_texts[i] = process_text(caption)


# We also pre-compute the text embeddings to save some memory during training.
POS_IDS = tf.convert_to_tensor([list(range(max_prompt_length))], dtype=tf.int32)
text_encoder = keras_cv.models.stable_diffusion.TextEncoder(max_prompt_length)

gpus = tf.config.list_logical_devices("GPU")

# Ensure the computation takes place on a GPU.
# Note that it's done automatically when there's a GPU present.
# This example just shows how you can do it more explicitly.
with tf.device(gpus[0].name):
    embedded_text = text_encoder(
        [tf.convert_to_tensor(tokenized_texts), POS_IDS], training=False
    ).numpy()

# To ensure text_encoder doesn't occupy any GPU space.
del text_encoder

"""
## Prepare the images
"""

resolution = 512
auto = tf.data.AUTOTUNE

augmenter = keras.Sequential(
    layers=[
        keras_cv.layers.CenterCrop(resolution, resolution),
        keras_cv.layers.RandomFlip(),
        keras.layers.Rescaling(scale=1.0 / 127.5, offset=-1),
    ]
)


def process_image(image_path, tokenized_text):
    image = tf.io.read_file(image_path)
    image = tf.io.decode_png(image, 3)
    image = tf.image.resize(image, (resolution, resolution))
    return image, tokenized_text


def apply_augmentation(image_batch, embedded_tokens):
    return augmenter(image_batch), embedded_tokens


def prepare_dict(instance_only=True):
    def fn(image_batch, embedded_tokens):
        if instance_only:
            batch_dict = {
                "instance_images": image_batch,
                "instance_embedded_texts": embedded_tokens,
            }
            return batch_dict
        else:
            batch_dict = {
                "class_images": image_batch,
                "class_embedded_texts": embedded_tokens,
            }
            return batch_dict

    return fn


def assemble_dataset(image_paths, embedded_texts, instance_only=True, batch_size=1):
    dataset = tf.data.Dataset.from_tensor_slices((image_paths, embedded_texts))
    dataset = dataset.map(process_image, num_parallel_calls=auto)
    dataset = dataset.shuffle(5, reshuffle_each_iteration=True)
    dataset = dataset.batch(batch_size)
    dataset = dataset.map(apply_augmentation, num_parallel_calls=auto)

    prepare_dict_fn = prepare_dict(instance_only=instance_only)
    dataset = dataset.map(prepare_dict_fn, num_parallel_calls=auto)
    return dataset
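

"""
We call `assemble_dataset()` twice below, once for the instance samples and once for the
class samples, and then `zip` the two datasets together. This way, every training batch
carries one instance sub-batch and one class sub-batch; the training step defined later
concatenates the two along the batch axis and splits the predictions again when computing
the prior-preservation loss.
"""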

"""
## Assemble dataset
"""

instance_dataset = assemble_dataset(
    new_instance_image_paths,
    embedded_text[: len(new_instance_image_paths)],
)
class_dataset = assemble_dataset(
    class_image_paths,
    embedded_text[len(new_instance_image_paths) :],
    instance_only=False,
)
train_dataset = tf.data.Dataset.zip((instance_dataset, class_dataset))

"""
## Check shapes

Now that the dataset has been prepared, let's quickly check what's inside it.
"""

sample_batch = next(iter(train_dataset))
print(sample_batch[0].keys(), sample_batch[1].keys())

for k in sample_batch[0]:
    print(k, sample_batch[0][k].shape)

for k in sample_batch[1]:
    print(k, sample_batch[1][k].shape)

"""
During training, we make use of these keys to gather the images and text embeddings and
concat them accordingly.
"""

"""
## DreamBooth training loop

Our DreamBooth training loop is very much inspired by
[this script](https://github.com/huggingface/diffusers/blob/main/examples/dreambooth/train_dreambooth.py)
provided by the Diffusers team at Hugging Face. However, there is an important
difference to note. We only fine-tune the UNet (the model responsible for predicting
noise) and don't fine-tune the text encoder in this example. If you're looking for an
implementation that also performs the additional fine-tuning of the text encoder, refer
to [this repository](https://github.com/sayakpaul/dreambooth-keras/).
"""

import tensorflow.experimental.numpy as tnp


class DreamBoothTrainer(tf.keras.Model):
    # Reference:
    # https://github.com/huggingface/diffusers/blob/main/examples/dreambooth/train_dreambooth.py

    def __init__(
        self,
        diffusion_model,
        vae,
        noise_scheduler,
        use_mixed_precision=False,
        prior_loss_weight=1.0,
        max_grad_norm=1.0,
        **kwargs,
    ):
        super().__init__(**kwargs)

        self.diffusion_model = diffusion_model
        self.vae = vae
        self.noise_scheduler = noise_scheduler
        self.prior_loss_weight = prior_loss_weight
        self.max_grad_norm = max_grad_norm

        self.use_mixed_precision = use_mixed_precision
        self.vae.trainable = False

    def train_step(self, inputs):
        instance_batch = inputs[0]
        class_batch = inputs[1]

        instance_images = instance_batch["instance_images"]
        instance_embedded_text = instance_batch["instance_embedded_texts"]
        class_images = class_batch["class_images"]
        class_embedded_text = class_batch["class_embedded_texts"]

        images = tf.concat([instance_images, class_images], 0)
        embedded_texts = tf.concat([instance_embedded_text, class_embedded_text], 0)
        batch_size = tf.shape(images)[0]

        with tf.GradientTape() as tape:
            # Project image into the latent space and sample from it.
            latents = self.sample_from_encoder_outputs(self.vae(images, training=False))
            # Know more about the magic number here:
            # https://keras.io/examples/generative/fine_tune_via_textual_inversion/
            latents = latents * 0.18215

            # Sample noise that we'll add to the latents.
            noise = tf.random.normal(tf.shape(latents))

            # Sample a random timestep for each image.
            timesteps = tnp.random.randint(
                0, self.noise_scheduler.train_timesteps, (batch_size,)
            )

            # Add noise to the latents according to the noise magnitude at each timestep
            # (this is the forward diffusion process).
            noisy_latents = self.noise_scheduler.add_noise(
                tf.cast(latents, noise.dtype), noise, timesteps
            )
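            # A rough sketch of what `add_noise` computes (the standard forward
            # diffusion step); the exact implementation lives in KerasCV's
            # `NoiseScheduler`:
            #   noisy_latents = sqrt(alphas_cumprod[t]) * latents
            #                   + sqrt(1 - alphas_cumprod[t]) * noise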

            # Get the target for the loss depending on the prediction type;
            # for now it is just the sampled noise.
            target = noise  # noise_schedule.predict_epsilon == True

            # Predict the noise residual and compute the loss.
            timestep_embedding = tf.map_fn(
                lambda t: self.get_timestep_embedding(t), timesteps, dtype=tf.float32
            )
            model_pred = self.diffusion_model(
                [noisy_latents, timestep_embedding, embedded_texts], training=True
            )
            loss = self.compute_loss(target, model_pred)
            if self.use_mixed_precision:
                loss = self.optimizer.get_scaled_loss(loss)

        # Update parameters of the diffusion model.
        trainable_vars = self.diffusion_model.trainable_variables
        gradients = tape.gradient(loss, trainable_vars)
        if self.use_mixed_precision:
            gradients = self.optimizer.get_unscaled_gradients(gradients)
        gradients = [tf.clip_by_norm(g, self.max_grad_norm) for g in gradients]
        self.optimizer.apply_gradients(zip(gradients, trainable_vars))

        return {m.name: m.result() for m in self.metrics}

    def get_timestep_embedding(self, timestep, dim=320, max_period=10000):
        # Standard sinusoidal timestep embedding of the kind the diffusion model
        # expects: [cos(t * freqs), sin(t * freqs)] with `dim // 2` frequencies.
        half = dim // 2
        log_max_period = tf.math.log(tf.cast(max_period, tf.float32))
        freqs = tf.math.exp(
            -log_max_period * tf.range(0, half, dtype=tf.float32) / half
        )
        args = tf.convert_to_tensor([timestep], dtype=tf.float32) * freqs
        embedding = tf.concat([tf.math.cos(args), tf.math.sin(args)], 0)
        return embedding

    def sample_from_encoder_outputs(self, outputs):
        # The encoder outputs are split into a mean and a log-variance; we then
        # sample a latent via the reparameterization trick: mean + std * eps.
        mean, logvar = tf.split(outputs, 2, axis=-1)
        logvar = tf.clip_by_value(logvar, -30.0, 20.0)
        std = tf.exp(0.5 * logvar)
        sample = tf.random.normal(tf.shape(mean), dtype=mean.dtype)
        return mean + std * sample

    def compute_loss(self, target, model_pred):
        # Chunk the noise and model_pred into two parts and compute the loss
        # on each part separately.
        # Since the first half of the inputs has instance samples and the second half
        # has class samples, we do the chunking accordingly.
        model_pred, model_pred_prior = tf.split(
            model_pred, num_or_size_splits=2, axis=0
        )
        target, target_prior = tf.split(target, num_or_size_splits=2, axis=0)

        # Compute instance loss.
        loss = self.compiled_loss(target, model_pred)

        # Compute prior loss.
        prior_loss = self.compiled_loss(target_prior, model_pred_prior)

        # Add the prior loss to the instance loss.
        loss = loss + self.prior_loss_weight * prior_loss
        return loss

    def save_weights(self, filepath, overwrite=True, save_format=None, options=None):
        # Overriding this method will allow us to use the `ModelCheckpoint`
        # callback directly with this trainer class. In this case, it will
        # only checkpoint the `diffusion_model` since that's what we're training
        # during fine-tuning.
        self.diffusion_model.save_weights(
            filepath=filepath,
            overwrite=overwrite,
            save_format=save_format,
            options=options,
        )

    def load_weights(self, filepath, by_name=False, skip_mismatch=False, options=None):
        # Similarly override `load_weights()` so that we can directly call it on
        # the trainer class object.
        self.diffusion_model.load_weights(
            filepath=filepath,
            by_name=by_name,
            skip_mismatch=skip_mismatch,
            options=options,
        )


"""
## Trainer initialization
"""

# Comment this out if you are not using a GPU with tensor cores.
tf.keras.mixed_precision.set_global_policy("mixed_float16")

use_mp = True  # Set it to False if you're not using a GPU with tensor cores.

image_encoder = keras_cv.models.stable_diffusion.ImageEncoder()
dreambooth_trainer = DreamBoothTrainer(
    diffusion_model=keras_cv.models.stable_diffusion.DiffusionModel(
        resolution, resolution, max_prompt_length
    ),
    # Remove the top layer from the encoder, which cuts off the variance and only
    # returns the mean.
    vae=tf.keras.Model(
        image_encoder.input,
        image_encoder.layers[-2].output,
    ),
    noise_scheduler=keras_cv.models.stable_diffusion.NoiseScheduler(),
    use_mixed_precision=use_mp,
)

# These hyperparameters come from this tutorial by Hugging Face:
# https://github.com/huggingface/diffusers/tree/main/examples/dreambooth
learning_rate = 5e-6
beta_1, beta_2 = 0.9, 0.999
weight_decay = 1e-2
epsilon = 1e-08

optimizer = tf.keras.optimizers.experimental.AdamW(
    learning_rate=learning_rate,
    weight_decay=weight_decay,
    beta_1=beta_1,
    beta_2=beta_2,
    epsilon=epsilon,
)
dreambooth_trainer.compile(optimizer=optimizer, loss="mse")

"""
## Train!

We first calculate the number of epochs we need to train for.
"""

num_update_steps_per_epoch = train_dataset.cardinality()
max_train_steps = 800
epochs = math.ceil(max_train_steps / num_update_steps_per_epoch)
print(f"Training for {epochs} epochs.")
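
"""
As a rough sanity check (assuming about 200 class images, the instance paths repeated to
match them, and the default batch size of 1 per dataset), `num_update_steps_per_epoch`
comes out to roughly 200, so 800 maximum train steps corresponds to `ceil(800 / 200) = 4`
epochs. The exact number printed above depends on how many class images you use.
"""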

"""
And then we start training!
"""

ckpt_path = "dreambooth-unet.h5"
ckpt_callback = tf.keras.callbacks.ModelCheckpoint(
    ckpt_path,
    save_weights_only=True,
    monitor="loss",
    mode="min",
)
dreambooth_trainer.fit(train_dataset, epochs=epochs, callbacks=[ckpt_callback])

"""
## Experiments and inference

We ran various experiments with a slightly modified version of this example. Our
experiments are based on
[this repository](https://github.com/sayakpaul/dreambooth-keras/) and are inspired by
[this blog post](https://huggingface.co/blog/dreambooth) from Hugging Face.

First, let's see how we can use the fine-tuned checkpoint for running inference.
"""

# Initialize a new Stable Diffusion model.
dreambooth_model = keras_cv.models.StableDiffusion(
    img_width=resolution, img_height=resolution, jit_compile=True
)
dreambooth_model.diffusion_model.load_weights(ckpt_path)

# Note how the unique identifier and the class have been used in the prompt.
prompt = f"A photo of {unique_id} {class_label} in a bucket"
num_imgs_to_gen = 3

images_dreamboothed = dreambooth_model.text_to_image(prompt, batch_size=num_imgs_to_gen)
plot_images(images_dreamboothed, prompt)

"""
Now, let's load checkpoints from a different experiment we conducted where we also
fine-tuned the text encoder along with the UNet:
"""

unet_weights = tf.keras.utils.get_file(
    origin="https://huggingface.co/chansung/dreambooth-dog/resolve/main/lr%409e-06-max_train_steps%40200-train_text_encoder%40True-unet.h5"
)
text_encoder_weights = tf.keras.utils.get_file(
    origin="https://huggingface.co/chansung/dreambooth-dog/resolve/main/lr%409e-06-max_train_steps%40200-train_text_encoder%40True-text_encoder.h5"
)

dreambooth_model.diffusion_model.load_weights(unet_weights)
dreambooth_model.text_encoder.load_weights(text_encoder_weights)

images_dreamboothed = dreambooth_model.text_to_image(prompt, batch_size=num_imgs_to_gen)
plot_images(images_dreamboothed, prompt)

"""
The default number of steps for generating an image in `text_to_image()`
[is 50](https://github.com/keras-team/keras-cv/blob/3575bc3b944564fe15b46b917e6555aa6a9d7be0/keras_cv/models/stable_diffusion/stable_diffusion.py#L73).
Let's increase it to 100.
"""

images_dreamboothed = dreambooth_model.text_to_image(
    prompt, batch_size=num_imgs_to_gen, num_steps=100
)
plot_images(images_dreamboothed, prompt)

"""
Feel free to experiment with different prompts (don't forget to add the unique identifier
and the class label!) to see how the results change. We welcome you to check out our
codebase and more experimental results
[here](https://github.com/sayakpaul/dreambooth-keras#results). You can also read
[this blog post](https://huggingface.co/blog/dreambooth) to get more ideas.
"""

"""
## Acknowledgements

* Thanks to the
[DreamBooth example script](https://github.com/huggingface/diffusers/blob/main/examples/dreambooth/train_dreambooth.py)
provided by Hugging Face which helped us a lot in getting the initial implementation
ready quickly.
* Getting DreamBooth to work on human faces can be challenging. We have compiled some
general recommendations
[here](https://github.com/sayakpaul/dreambooth-keras#notes-on-preparing-data-for-dreambooth-training-of-faces).
Thanks to
[Abhishek Thakur](https://no.linkedin.com/in/abhi1thakur)
for helping with these.
"""