"""
Title: Denoising Diffusion Implicit Models
Author: [András Béres](https://www.linkedin.com/in/andras-beres-789190210)
Date created: 2022/06/24
Last modified: 2022/06/24
Description: Generating images of flowers with denoising diffusion implicit models.
Accelerator: GPU
"""

"""
## Introduction

### What are diffusion models?

Recently, [denoising diffusion models](https://arxiv.org/abs/2006.11239), including
[score-based generative models](https://arxiv.org/abs/1907.05600), have gained popularity
as a powerful class of generative models that can [rival](https://arxiv.org/abs/2105.05233)
even [generative adversarial networks (GANs)](https://arxiv.org/abs/1406.2661) in image
synthesis quality. They tend to generate more diverse samples, while being stable to
train and easy to scale. Recent large diffusion models, such as
[DALL-E 2](https://openai.com/dall-e-2/) and [Imagen](https://imagen.research.google/),
have shown incredible text-to-image generation capability. One of their drawbacks,
however, is that they are slower to sample from, because generating an image requires
multiple forward passes through the network.

Diffusion refers to the process of turning a structured signal (an image) into noise
step-by-step. By simulating diffusion, we can generate noisy images from our training
images, and can train a neural network to try to denoise them. Using the trained network,
we can simulate the opposite of diffusion, reverse diffusion, which is the process of an
image emerging from noise.

![diffusion process gif](https://i.imgur.com/dipPOfa.gif)

One-sentence summary: **diffusion models are trained to denoise noisy images, and can
generate images by iteratively denoising pure noise.**

### Goal of this example

This code example intends to be a minimal but feature-complete (with a generation quality
metric) implementation of diffusion models, with modest compute requirements and
reasonable performance. My implementation choices and hyperparameter tuning were done
with these goals in mind.

Since the current literature on diffusion models is
[mathematically quite complex](https://arxiv.org/abs/2206.00364),
with multiple theoretical frameworks
([score matching](https://arxiv.org/abs/1907.05600),
[differential equations](https://arxiv.org/abs/2011.13456),
[Markov chains](https://arxiv.org/abs/2006.11239)) and sometimes even
[conflicting notations (see Appendix C.2)](https://arxiv.org/abs/2010.02502),
it can be daunting to try to understand
them. In this example, I view these models as learning to separate a
noisy image into its image and Gaussian noise components.

In this example, I made an effort to break down all long mathematical expressions into
digestible pieces and gave all variables explanatory names. I also included numerous
links to relevant literature to help interested readers dive deeper into the topic, in
the hope that this code example will become a good starting point for practitioners
learning about diffusion models.

In the following sections, we will implement a continuous-time version of
[Denoising Diffusion Implicit Models (DDIMs)](https://arxiv.org/abs/2010.02502)
with deterministic sampling.
"""

"""
## Setup
"""

import os

os.environ["KERAS_BACKEND"] = "tensorflow"

import math
import matplotlib.pyplot as plt
import tensorflow as tf
import tensorflow_datasets as tfds

import keras
from keras import layers
from keras import ops

"""
## Hyperparameters
"""

# data
dataset_name = "oxford_flowers102"
dataset_repetitions = 5
num_epochs = 1  # train for at least 50 epochs for good results
image_size = 64
# KID = Kernel Inception Distance, see related section
kid_image_size = 75
kid_diffusion_steps = 5
plot_diffusion_steps = 20

# sampling
min_signal_rate = 0.02
max_signal_rate = 0.95

# architecture
embedding_dims = 32
embedding_max_frequency = 1000.0
widths = [32, 64, 96, 128]
block_depth = 2

# optimization
batch_size = 64
ema = 0.999
learning_rate = 1e-3
weight_decay = 1e-4

"""
## Data pipeline

We will use the
[Oxford Flowers 102](https://www.tensorflow.org/datasets/catalog/oxford_flowers102)
dataset for
generating images of flowers, which is a diverse natural dataset containing around 8,000
images. Unfortunately, the official splits are imbalanced, as most of the images are
contained in the test split. We create new splits (80% train, 20% validation) using the
[TensorFlow Datasets slicing API](https://www.tensorflow.org/datasets/splits). We apply
center crops as preprocessing, and repeat the dataset multiple times (the reason is given
in the next section).
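
`preprocess_image()` below keeps the largest centered square of each image before resizing it. The following NumPy sketch reproduces the same offset arithmetic. `center_crop` is a hypothetical helper for illustration only, the 500x667 image size is an arbitrary example, and resizing is left out:

```python
import numpy as np


def center_crop(image):
    # image: array of shape (height, width, channels)
    height, width = image.shape[0], image.shape[1]
    crop_size = min(height, width)
    # same offsets as passed to tf.image.crop_to_bounding_box in preprocess_image()
    offset_height = (height - crop_size) // 2
    offset_width = (width - crop_size) // 2
    return image[
        offset_height : offset_height + crop_size,
        offset_width : offset_width + crop_size,
    ]


# a hypothetical landscape-format image of 500x667 pixels
image = np.zeros((500, 667, 3), dtype=np.uint8)
cropped = center_crop(image)
print(cropped.shape)  # (500, 500, 3): 83 columns dropped on the left, 84 on the right
```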
"""


def preprocess_image(data):
    # center crop image
    height = ops.shape(data["image"])[0]
    width = ops.shape(data["image"])[1]
    crop_size = ops.minimum(height, width)
    image = tf.image.crop_to_bounding_box(
        data["image"],
        (height - crop_size) // 2,
        (width - crop_size) // 2,
        crop_size,
        crop_size,
    )

    # resize and clip
    # for image downsampling it is important to turn on antialiasing
    image = tf.image.resize(image, size=[image_size, image_size], antialias=True)
    return ops.clip(image / 255.0, 0.0, 1.0)


def prepare_dataset(split):
    # the validation dataset is shuffled as well, because data order matters
    # for the KID estimation
    return (
        tfds.load(dataset_name, split=split, shuffle_files=True)
        .map(preprocess_image, num_parallel_calls=tf.data.AUTOTUNE)
        .cache()
        .repeat(dataset_repetitions)
        .shuffle(10 * batch_size)
        .batch(batch_size, drop_remainder=True)
        .prefetch(buffer_size=tf.data.AUTOTUNE)
    )


# load dataset
train_dataset = prepare_dataset("train[:80%]+validation[:80%]+test[:80%]")
val_dataset = prepare_dataset("train[80%:]+validation[80%:]+test[80%:]")

"""
## Kernel inception distance

[Kernel Inception Distance (KID)](https://arxiv.org/abs/1801.01401) is an image quality
metric which was proposed as a replacement for the popular
[Fréchet Inception Distance (FID)](https://arxiv.org/abs/1706.08500).
I prefer KID to FID because it is simpler to
implement, can be estimated per-batch, and is computationally lighter. More details
[here](https://keras.io/examples/generative/gan_ada/#kernel-inception-distance).

In this example, the images are evaluated at the minimal possible resolution of the
Inception network (75x75 instead of 299x299), and the metric is only measured on the
validation set for computational efficiency. We also limit the number of sampling steps
at evaluation to 5 for the same reason.

Since the dataset is relatively small, we go over the train and validation splits
multiple times per epoch. The KID estimate is noisy and compute-intensive, so we want to
evaluate it only after many training iterations, but then over many validation
iterations.
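
`KID.update_state()` below computes an unbiased estimate of the squared maximum mean discrepancy (MMD) between two sets of Inception features, using the cubic polynomial kernel `(x . y / d + 1) ** 3`. The NumPy sketch below mirrors that computation. It runs on random feature vectors that stand in for Inception features; the 16 feature dimensions and the two distributions are arbitrary assumptions chosen for illustration:

```python
import numpy as np


def polynomial_kernel(features_1, features_2):
    # k(x, y) = (x . y / d + 1) ** 3, as in KID.polynomial_kernel()
    feature_dimensions = features_1.shape[1]
    return (features_1 @ features_2.T / feature_dimensions + 1.0) ** 3


def kid_estimate(real_features, generated_features):
    n = real_features.shape[0]
    off_diagonal = 1.0 - np.eye(n)
    kernel_real = polynomial_kernel(real_features, real_features)
    kernel_generated = polynomial_kernel(generated_features, generated_features)
    kernel_cross = polynomial_kernel(real_features, generated_features)
    # within-set means skip the k(x, x) diagonal terms to keep the estimate unbiased
    mean_kernel_real = np.sum(kernel_real * off_diagonal) / (n * (n - 1))
    mean_kernel_generated = np.sum(kernel_generated * off_diagonal) / (n * (n - 1))
    mean_kernel_cross = np.mean(kernel_cross)
    return mean_kernel_real + mean_kernel_generated - 2.0 * mean_kernel_cross


rng = np.random.default_rng(0)
real = rng.normal(0.0, 1.0, size=(64, 16))
similar = rng.normal(0.0, 1.0, size=(64, 16))  # same distribution as real
shifted = rng.normal(3.0, 1.0, size=(64, 16))  # clearly different distribution
print(kid_estimate(real, similar))  # close to 0, and may be slightly negative
print(kid_estimate(real, shifted))  # much larger
```

Because the estimator is unbiased, a single batch can yield a small negative value for matching distributions. This is one reason to average the estimate over many batches.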

"""


@keras.saving.register_keras_serializable()
class KID(keras.metrics.Metric):
    def __init__(self, name, **kwargs):
        super().__init__(name=name, **kwargs)

        # KID is estimated per batch and is averaged across batches
        self.kid_tracker = keras.metrics.Mean(name="kid_tracker")

        # a pretrained InceptionV3 is used without its classification layer
        # transform the pixel values to the 0-255 range, then use the same
        # preprocessing as during pretraining
        self.encoder = keras.Sequential(
            [
                keras.Input(shape=(image_size, image_size, 3)),
                layers.Rescaling(255.0),
                layers.Resizing(height=kid_image_size, width=kid_image_size),
                layers.Lambda(keras.applications.inception_v3.preprocess_input),
                keras.applications.InceptionV3(
                    include_top=False,
                    input_shape=(kid_image_size, kid_image_size, 3),
                    weights="imagenet",
                ),
                layers.GlobalAveragePooling2D(),
            ],
            name="inception_encoder",
        )

    def polynomial_kernel(self, features_1, features_2):
        feature_dimensions = ops.cast(ops.shape(features_1)[1], dtype="float32")
        return (
            features_1 @ ops.transpose(features_2) / feature_dimensions + 1.0
        ) ** 3.0

    def update_state(self, real_images, generated_images, sample_weight=None):
        real_features = self.encoder(real_images, training=False)
        generated_features = self.encoder(generated_images, training=False)

        # compute polynomial kernels using the two sets of features
        kernel_real = self.polynomial_kernel(real_features, real_features)
        kernel_generated = self.polynomial_kernel(
            generated_features, generated_features
        )
        kernel_cross = self.polynomial_kernel(real_features, generated_features)

        # estimate the squared maximum mean discrepancy using the average kernel values
        batch_size = real_features.shape[0]
        batch_size_f = ops.cast(batch_size, dtype="float32")
        mean_kernel_real = ops.sum(kernel_real * (1.0 - ops.eye(batch_size))) / (
            batch_size_f * (batch_size_f - 1.0)
        )
        mean_kernel_generated = ops.sum(
            kernel_generated * (1.0 - ops.eye(batch_size))
        ) / (batch_size_f * (batch_size_f - 1.0))
        mean_kernel_cross = ops.mean(kernel_cross)
        kid = mean_kernel_real + mean_kernel_generated - 2.0 * mean_kernel_cross

        # update the average KID estimate
        self.kid_tracker.update_state(kid)

    def result(self):
        return self.kid_tracker.result()

    def reset_state(self):
        self.kid_tracker.reset_state()


"""
## Network architecture

Here we specify the architecture of the neural network that we will use for denoising. We
build a [U-Net](https://arxiv.org/abs/1505.04597) with identical input and output
dimensions. U-Net is a popular semantic segmentation architecture, whose main idea is
that it progressively downsamples and then upsamples its input image, and adds skip
connections between layers having the same resolution. These help with gradient flow and
avoid introducing a representation bottleneck, unlike usual
[autoencoders](https://www.deeplearningbook.org/contents/autoencoders.html). Based on
this, one can view
[diffusion models as denoising autoencoders](https://benanne.github.io/2022/01/31/diffusion.html)
without a bottleneck.

The network takes two inputs: the noisy images and the variances of their noise
components. The latter is required since denoising a signal requires different operations
at different levels of noise. We transform the noise variances using sinusoidal
embeddings, similarly to positional encodings used both in
[transformers](https://arxiv.org/abs/1706.03762) and
[NeRF](https://arxiv.org/abs/2003.08934). This helps the network be
[highly sensitive](https://arxiv.org/abs/2006.10739) to the noise level, which is
crucial for good performance. We implement sinusoidal embeddings using a
[Lambda layer](https://keras.io/api/layers/core_layers/lambda/).
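
To make the embedding concrete, the NumPy sketch below performs the same computation as `sinusoidal_embedding()` on a small batch of scalar noise variances. It assumes this example's defaults: `embedding_dims = 32`, with frequencies spaced geometrically between 1 and `embedding_max_frequency = 1000.0`. It also drops the `(1, 1)` spatial axes that the Keras version keeps:

```python
import math

import numpy as np

embedding_dims = 32
embedding_min_frequency = 1.0
embedding_max_frequency = 1000.0


def sinusoidal_embedding(x):
    # x: array of shape (batch, 1) holding noise variances
    # geometrically spaced frequencies between the min. and max. frequency
    frequencies = np.exp(
        np.linspace(
            math.log(embedding_min_frequency),
            math.log(embedding_max_frequency),
            embedding_dims // 2,
        )
    )
    angular_speeds = 2.0 * math.pi * frequencies
    # first half: sines, second half: cosines, shape (batch, embedding_dims)
    return np.concatenate(
        [np.sin(angular_speeds * x), np.cos(angular_speeds * x)], axis=1
    )


noise_variances = np.array([[0.1], [0.5], [0.9]])
embeddings = sinusoidal_embedding(noise_variances)
print(embeddings.shape)  # (3, 32)
```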

Some other considerations:

* We build the network using the
[Keras Functional API](https://keras.io/guides/functional_api/), and use
[closures](https://twitter.com/fchollet/status/1441927912836321280) to build blocks of
layers in a consistent style.
* [Diffusion models](https://arxiv.org/abs/2006.11239) embed the index of the timestep of
the diffusion process instead of the noise variance, while
[score-based models (Table 1)](https://arxiv.org/abs/2206.00364)
usually use some function of the noise level. I
prefer the latter so that we can change the sampling schedule at inference time, without
retraining the network.
* [Diffusion models](https://arxiv.org/abs/2006.11239) input the embedding to each
convolution block separately. We only input it at the start of the network for
simplicity, which in my experience barely decreases performance, because the skip and
residual connections help the information propagate through the network properly.
* In the literature it is common to use
[attention layers](https://keras.io/api/layers/attention_layers/multi_head_attention/)
at lower resolutions for better global coherence. I omitted them for simplicity.
* We disable the learnable center and scale parameters of the batch normalization layers,
since the following convolution layers make them redundant.
* We initialize the last convolution's kernel to all zeros as a good practice, making the
network predict only zeros after initialization, which is the mean of its targets. This
improves behaviour at the start of training and makes the mean squared error loss
start at approximately 1, the variance of the standard normal noise targets.
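
Here is a quick numerical check of the last point. If the network outputs all zeros while the targets are standard normal noise, the mean squared error equals the noise variance, which is about 1. This example actually trains with mean absolute error, whose starting value is instead about sqrt(2/pi), roughly 0.80. The sketch uses random NumPy arrays in place of real network outputs:

```python
import numpy as np

rng = np.random.default_rng(0)
# standard normal noise targets, shaped like a batch of 64x64 RGB images
noises = rng.standard_normal(size=(64, 64, 64, 3))
# a zero-initialized final convolution makes the network predict zeros
pred_noises = np.zeros_like(noises)

mse = np.mean((noises - pred_noises) ** 2)
mae = np.mean(np.abs(noises - pred_noises))
print(mse)  # close to 1.0
print(mae)  # close to sqrt(2 / pi), about 0.798
```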
"""


@keras.saving.register_keras_serializable()
def sinusoidal_embedding(x):
    embedding_min_frequency = 1.0
    frequencies = ops.exp(
        ops.linspace(
            ops.log(embedding_min_frequency),
            ops.log(embedding_max_frequency),
            embedding_dims // 2,
        )
    )
    angular_speeds = ops.cast(2.0 * math.pi * frequencies, "float32")
    embeddings = ops.concatenate(
        [ops.sin(angular_speeds * x), ops.cos(angular_speeds * x)], axis=3
    )
    return embeddings


def ResidualBlock(width):
    def apply(x):
        input_width = x.shape[3]
        if input_width == width:
            residual = x
        else:
            residual = layers.Conv2D(width, kernel_size=1)(x)
        x = layers.BatchNormalization(center=False, scale=False)(x)
        x = layers.Conv2D(width, kernel_size=3, padding="same", activation="swish")(x)
        x = layers.Conv2D(width, kernel_size=3, padding="same")(x)
        x = layers.Add()([x, residual])
        return x

    return apply


def DownBlock(width, block_depth):
    def apply(x):
        x, skips = x
        for _ in range(block_depth):
            x = ResidualBlock(width)(x)
            skips.append(x)
        x = layers.AveragePooling2D(pool_size=2)(x)
        return x

    return apply


def UpBlock(width, block_depth):
    def apply(x):
        x, skips = x
        x = layers.UpSampling2D(size=2, interpolation="bilinear")(x)
        for _ in range(block_depth):
            x = layers.Concatenate()([x, skips.pop()])
            x = ResidualBlock(width)(x)
        return x

    return apply


def get_network(image_size, widths, block_depth):
    noisy_images = keras.Input(shape=(image_size, image_size, 3))
    noise_variances = keras.Input(shape=(1, 1, 1))

    # the embedding output shape follows embedding_dims instead of a hardcoded 32
    e = layers.Lambda(sinusoidal_embedding, output_shape=(1, 1, embedding_dims))(
        noise_variances
    )
    e = layers.UpSampling2D(size=image_size, interpolation="nearest")(e)

    x = layers.Conv2D(widths[0], kernel_size=1)(noisy_images)
    x = layers.Concatenate()([x, e])

    skips = []
    for width in widths[:-1]:
        x = DownBlock(width, block_depth)([x, skips])

    for _ in range(block_depth):
        x = ResidualBlock(widths[-1])(x)

    for width in reversed(widths[:-1]):
        x = UpBlock(width, block_depth)([x, skips])

    x = layers.Conv2D(3, kernel_size=1, kernel_initializer="zeros")(x)

    return keras.Model([noisy_images, noise_variances], x, name="residual_unet")


"""
This showcases the power of the Functional API. Note how we built a relatively complex
U-Net with skip connections, residual blocks, multiple inputs, and sinusoidal embeddings
in 80 lines of code!
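
The pure-Python sketch below shows how the skip connections pair up. It traces only the spatial resolutions and channel widths through `get_network()`, using this example's hyperparameters. It assumes every layer preserves resolution except `AveragePooling2D` (which halves it) and `UpSampling2D` (which doubles it), and it ignores the 1x1 input convolution and the embedding concatenation:

```python
image_size = 64
widths = [32, 64, 96, 128]
block_depth = 2

resolution = image_size
skips = []  # (resolution, width) of every stored skip connection

# DownBlocks: each residual block's output is stored, then 2x average pooling
for width in widths[:-1]:
    for _ in range(block_depth):
        skips.append((resolution, width))
    resolution //= 2

bottleneck_resolution = resolution  # only residual blocks at widths[-1] here

# UpBlocks: 2x upsampling, then one stored skip is concatenated per residual block
concatenations = []  # (resolution, skip width, block width)
for width in reversed(widths[:-1]):
    resolution *= 2
    for _ in range(block_depth):
        skip_resolution, skip_width = skips.pop()
        # concatenation requires matching spatial sizes
        assert skip_resolution == resolution
        concatenations.append((resolution, skip_width, width))

print(bottleneck_resolution)  # 8
print(resolution)  # 64, the same as the input
print(concatenations[0])  # (16, 96, 96): the deepest skip is consumed first
```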
"""

"""
## Diffusion model

### Diffusion schedule

Let us say that a diffusion process starts at time = 0 and ends at time = 1. This
variable will be called diffusion time, and can be either discrete (common in diffusion
models) or continuous (common in score-based models). I choose the latter, so that the
number of sampling steps can be changed at inference time.

We need a function that tells us, at each point of the diffusion process, the noise
level and signal level of the noisy image corresponding to the current diffusion time.
This will be called the diffusion schedule (see `diffusion_schedule()`).

This schedule outputs two quantities: the `noise_rate` and the `signal_rate`
(corresponding to sqrt(1 - alpha) and sqrt(alpha) in the DDIM paper, respectively). We
generate the noisy image by weighting the random noise and the training image by their
corresponding rates and adding them together.

Since the (standard normal) random noises and the (normalized) images both have zero mean
and unit variance, the noise rate and signal rate can be interpreted as the standard
deviation of their components in the noisy image, while the squares of their rates can be
interpreted as their variance (or their power in the signal processing sense). The rates
will always be set so that their squared sum is 1. Since the noise is sampled
independently of the image, the variances of the two components add up, meaning that the
noisy images will always have unit variance, just like their unscaled components.
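
The NumPy check below tests the unit-variance claim. Random arrays stand in for normalized images, and the rate pair 0.6 and 0.8 (whose squares sum to 1) is an arbitrary illustrative choice, not a value taken from the schedule:

```python
import numpy as np

rng = np.random.default_rng(0)
# stand-ins for normalized images and noises: zero mean, unit variance, independent
images = rng.standard_normal(size=(64, 64, 64, 3))
noises = rng.standard_normal(size=(64, 64, 64, 3))

signal_rate, noise_rate = 0.6, 0.8  # squares sum to 1
noisy_images = signal_rate * images + noise_rate * noises

print(np.var(noisy_images))  # close to 1.0
# the squared rates give each component's share of the variance ("power")
print(np.var(signal_rate * images), np.var(noise_rate * noises))  # about 0.36 and 0.64
```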

We will use a simplified, continuous version of the
[cosine schedule (Section 3.2)](https://arxiv.org/abs/2102.09672),
which is quite commonly used in the literature.
This schedule is symmetric, slow towards the start and end of the diffusion process, and
it also has a nice geometric interpretation, using the
[trigonometric properties of the unit circle](https://en.wikipedia.org/wiki/Unit_circle#/media/File:Circle-trig6.svg):

![diffusion schedule gif](https://i.imgur.com/JW9W0fA.gif)
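
The schedule can be sketched in NumPy as follows, mirroring `DiffusionModel.diffusion_schedule()` defined later. Diffusion times are mapped linearly to angles between `arccos(max_signal_rate)` and `arccos(min_signal_rate)`, and the rates are the cosine and sine of that angle:

```python
import numpy as np

min_signal_rate = 0.02
max_signal_rate = 0.95


def diffusion_schedule(diffusion_times):
    # diffusion times -> angles
    start_angle = np.arccos(max_signal_rate)
    end_angle = np.arccos(min_signal_rate)
    diffusion_angles = start_angle + diffusion_times * (end_angle - start_angle)
    # angles -> rates, as coordinates of a point on the unit circle
    signal_rates = np.cos(diffusion_angles)
    noise_rates = np.sin(diffusion_angles)
    return noise_rates, signal_rates


diffusion_times = np.linspace(0.0, 1.0, 5)
noise_rates, signal_rates = diffusion_schedule(diffusion_times)
print(np.round(signal_rates, 3))  # decreases from 0.95 to 0.02
print(np.round(noise_rates**2 + signal_rates**2, 6))  # all ones
```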

### Training process

The training procedure (see `train_step()` and `denoise()`) of denoising diffusion models
is the following: we sample random diffusion times uniformly, and mix the training images
with random Gaussian noises at rates corresponding to the diffusion times. Then, we train
the model to separate the noisy image into its two components.

Usually, the neural network is trained to predict the unscaled noise component, from
which the predicted image component can be calculated using the signal and noise rates.
In theory, pixelwise
[mean squared error](https://keras.io/api/losses/regression_losses/#mean_squared_error-function)
should be used; however, I recommend using
[mean absolute error](https://keras.io/api/losses/regression_losses/#mean_absolute_error-function)
instead (similarly to
[this](https://github.com/lucidrains/denoising-diffusion-pytorch/blob/master/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py#L371)
implementation), which produces better results on this dataset.
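
The relation between the two predicted components used in `denoise()` can be checked in isolation. The NumPy sketch below replaces the network with a hypothetical "prediction": the true noise plus a small error. It computes the mean absolute error on the noise, as in `train_step()`. It also shows that the recovered image, `(noisy_images - noise_rates * pred_noises) / signal_rates`, inherits the noise error scaled by `noise_rate / signal_rate`. That scaling is large at high noise levels. The batch size of 8 is an arbitrary choice:

```python
import numpy as np

rng = np.random.default_rng(0)
batch_size, image_size = 8, 64
images = rng.standard_normal(size=(batch_size, image_size, image_size, 3))
noises = rng.standard_normal(size=(batch_size, image_size, image_size, 3))

# one diffusion time per image, broadcast over pixels and channels
diffusion_times = rng.uniform(0.0, 1.0, size=(batch_size, 1, 1, 1))
start_angle, end_angle = np.arccos(0.95), np.arccos(0.02)
angles = start_angle + diffusion_times * (end_angle - start_angle)
noise_rates, signal_rates = np.sin(angles), np.cos(angles)
noisy_images = signal_rates * images + noise_rates * noises

# hypothetical network output: the true noise plus a small error
noise_error = 0.01 * rng.standard_normal(size=noises.shape)
pred_noises = noises + noise_error
pred_images = (noisy_images - noise_rates * pred_noises) / signal_rates

noise_loss = np.mean(np.abs(noises - pred_noises))  # mean absolute error
# the image error equals the noise error scaled by -noise_rate / signal_rate
image_error = pred_images - images
expected_image_error = -noise_rates / signal_rates * noise_error
print(noise_loss, np.allclose(image_error, expected_image_error))
```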

### Sampling (reverse diffusion)

When sampling (see `reverse_diffusion()`), at each step we take the previous estimate of
the noisy image and separate it into image and noise using our network. Then we recombine
these components using the signal and noise rates of the following step.

Though a similar view is shown in
[Equation 12 of DDIMs](https://arxiv.org/abs/2010.02502), I believe the above explanation
of the sampling equation is not widely known.
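
The sampling loop of `reverse_diffusion()` can be exercised without a trained network. In the NumPy sketch below, the network is replaced by `oracle_denoise`, a hypothetical closed-form denoiser that is not part of this example. It is exact for one-dimensional Gaussian "images" with mean 2 and standard deviation 0.5. Each step separates the current noisy samples into components and remixes them at the next step's rates:

```python
import numpy as np

min_signal_rate = 0.02
max_signal_rate = 0.95


def diffusion_schedule(diffusion_times):
    start_angle = np.arccos(max_signal_rate)
    end_angle = np.arccos(min_signal_rate)
    angles = start_angle + diffusion_times * (end_angle - start_angle)
    return np.sin(angles), np.cos(angles)  # noise_rates, signal_rates


# hypothetical stand-in for the trained network: the exact posterior-mean
# denoiser for 1-D Gaussian data x0 ~ N(data_mean, data_std**2)
data_mean, data_std = 2.0, 0.5


def oracle_denoise(noisy_images, noise_rates, signal_rates):
    noisy_variance = signal_rates**2 * data_std**2 + noise_rates**2
    pred_images = data_mean + signal_rates * data_std**2 / noisy_variance * (
        noisy_images - signal_rates * data_mean
    )
    pred_noises = (noisy_images - signal_rates * pred_images) / noise_rates
    return pred_noises, pred_images


def reverse_diffusion(initial_noise, diffusion_steps):
    step_size = 1.0 / diffusion_steps
    next_noisy_images = initial_noise
    for step in range(diffusion_steps):
        noisy_images = next_noisy_images
        # separate the current noisy samples into their components
        diffusion_times = 1.0 - step * step_size
        noise_rates, signal_rates = diffusion_schedule(diffusion_times)
        pred_noises, pred_images = oracle_denoise(
            noisy_images, noise_rates, signal_rates
        )
        # remix the components using the next step's rates
        next_noise_rates, next_signal_rates = diffusion_schedule(
            diffusion_times - step_size
        )
        next_noisy_images = (
            next_signal_rates * pred_images + next_noise_rates * pred_noises
        )
    return pred_images


initial_noise = np.random.default_rng(0).standard_normal(2000)
samples = reverse_diffusion(initial_noise, diffusion_steps=20)
print(samples.mean(), samples.std())  # mean near 2.0, spread below 0.5
```

The same initial noise always produces the same samples. The returned values are the image predictions of the last step, made while the signal rate is still below 1. Because they are posterior means, their spread is noticeably smaller than the data's.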

This example only implements the deterministic sampling procedure from DDIM, which
corresponds to *eta = 0* in the paper. One can also use stochastic sampling (in which
case the model becomes a
[Denoising Diffusion Probabilistic Model (DDPM)](https://arxiv.org/abs/2006.11239)),
where a part of the predicted noise is
replaced with the same or larger amount of random noise
([see Equation 16 and below](https://arxiv.org/abs/2010.02502)).

Stochastic sampling can be used without retraining the network (since both models are
trained the same way), and it can improve sample quality, though it usually requires more
sampling steps.
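
For reference, the stochastic variant changes only the remixing step. The NumPy sketch below is a reading of Equations 12 and 16 of the DDIM paper, rewritten with this example's rates (alpha = signal_rate**2). A fraction of the predicted noise, controlled by `eta`, is replaced with fresh random noise so that the total noise variance stays `next_noise_rate**2`. Setting `eta = 0` recovers the deterministic update used in `reverse_diffusion()`. The `stochastic_remix` helper and the rate values are hypothetical and not part of this example:

```python
import numpy as np


def stochastic_remix(
    pred_images,
    pred_noises,
    noise_rates,
    signal_rates,
    next_noise_rates,
    next_signal_rates,
    eta,
    rng,
):
    # standard deviation of the fresh noise (DDIM Eq. 16 with alpha = rate**2)
    sigma = (
        eta
        * (next_noise_rates / noise_rates)
        * np.sqrt(1.0 - signal_rates**2 / next_signal_rates**2)
    )
    # keep part of the predicted noise, replace the rest with random noise
    kept_noise_rates = np.sqrt(next_noise_rates**2 - sigma**2)
    random_noises = rng.standard_normal(size=pred_images.shape)
    return (
        next_signal_rates * pred_images
        + kept_noise_rates * pred_noises
        + sigma * random_noises
    )


rng = np.random.default_rng(0)
pred_images = rng.standard_normal(size=(4, 8, 8, 3))
pred_noises = rng.standard_normal(size=(4, 8, 8, 3))
# rates of the current step and of the next (less noisy) step
noise_rates, signal_rates = np.sin(1.2), np.cos(1.2)
next_noise_rates, next_signal_rates = np.sin(1.1), np.cos(1.1)

deterministic = next_signal_rates * pred_images + next_noise_rates * pred_noises
remixed = stochastic_remix(
    pred_images,
    pred_noises,
    noise_rates,
    signal_rates,
    next_noise_rates,
    next_signal_rates,
    eta=0.0,
    rng=rng,
)
print(np.allclose(remixed, deterministic))  # True: eta = 0 is deterministic DDIM
```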
"""


@keras.saving.register_keras_serializable()
class DiffusionModel(keras.Model):
    def __init__(self, image_size, widths, block_depth):
        super().__init__()

        self.normalizer = layers.Normalization()
        self.network = get_network(image_size, widths, block_depth)
        self.ema_network = keras.models.clone_model(self.network)

    def compile(self, **kwargs):
        super().compile(**kwargs)

        self.noise_loss_tracker = keras.metrics.Mean(name="n_loss")
        self.image_loss_tracker = keras.metrics.Mean(name="i_loss")
        self.kid = KID(name="kid")

    @property
    def metrics(self):
        return [self.noise_loss_tracker, self.image_loss_tracker, self.kid]

    def denormalize(self, images):
        # convert the pixel values back to 0-1 range
        images = self.normalizer.mean + images * self.normalizer.variance**0.5
        return ops.clip(images, 0.0, 1.0)

    def diffusion_schedule(self, diffusion_times):
        # diffusion times -> angles
        start_angle = ops.cast(ops.arccos(max_signal_rate), "float32")
        end_angle = ops.cast(ops.arccos(min_signal_rate), "float32")

        diffusion_angles = start_angle + diffusion_times * (end_angle - start_angle)

        # angles -> signal and noise rates
        signal_rates = ops.cos(diffusion_angles)
        noise_rates = ops.sin(diffusion_angles)
        # note that their squared sum is always: sin^2(x) + cos^2(x) = 1

        return noise_rates, signal_rates

    def denoise(self, noisy_images, noise_rates, signal_rates, training):
        # the exponential moving average weights are used at evaluation
        if training:
            network = self.network
        else:
            network = self.ema_network

        # predict noise component and calculate the image component using it
        pred_noises = network([noisy_images, noise_rates**2], training=training)
        pred_images = (noisy_images - noise_rates * pred_noises) / signal_rates

        return pred_noises, pred_images

    def reverse_diffusion(self, initial_noise, diffusion_steps):
        # reverse diffusion = sampling
        num_images = initial_noise.shape[0]
        step_size = 1.0 / diffusion_steps

        # important line:
        # at the first sampling step, the "noisy image" is pure noise
        # but its signal rate is assumed to be nonzero (min_signal_rate)
        next_noisy_images = initial_noise
        for step in range(diffusion_steps):
            noisy_images = next_noisy_images

            # separate the current noisy image to its components
            diffusion_times = ops.ones((num_images, 1, 1, 1)) - step * step_size
            noise_rates, signal_rates = self.diffusion_schedule(diffusion_times)
            pred_noises, pred_images = self.denoise(
                noisy_images, noise_rates, signal_rates, training=False
            )
            # network used in eval mode

            # remix the predicted components using the next signal and noise rates
            next_diffusion_times = diffusion_times - step_size
            next_noise_rates, next_signal_rates = self.diffusion_schedule(
                next_diffusion_times
            )
            next_noisy_images = (
                next_signal_rates * pred_images + next_noise_rates * pred_noises
            )
            # this new noisy image will be used in the next step

        return pred_images

    def generate(self, num_images, diffusion_steps):
        # noise -> images -> denormalized images
        initial_noise = keras.random.normal(
            shape=(num_images, image_size, image_size, 3)
        )
        generated_images = self.reverse_diffusion(initial_noise, diffusion_steps)
        generated_images = self.denormalize(generated_images)
        return generated_images

    def train_step(self, images):
        # normalize images to have standard deviation of 1, like the noises
        images = self.normalizer(images, training=True)
        noises = keras.random.normal(shape=(batch_size, image_size, image_size, 3))

        # sample uniform random diffusion times
        diffusion_times = keras.random.uniform(
            shape=(batch_size, 1, 1, 1), minval=0.0, maxval=1.0
        )
        noise_rates, signal_rates = self.diffusion_schedule(diffusion_times)
        # mix the images with noises accordingly
        noisy_images = signal_rates * images + noise_rates * noises

        with tf.GradientTape() as tape:
            # train the network to separate noisy images to their components
            pred_noises, pred_images = self.denoise(
                noisy_images, noise_rates, signal_rates, training=True
            )

            noise_loss = self.loss(noises, pred_noises)  # used for training
            image_loss = self.loss(images, pred_images)  # only used as metric

        gradients = tape.gradient(noise_loss, self.network.trainable_weights)
        self.optimizer.apply_gradients(zip(gradients, self.network.trainable_weights))

        self.noise_loss_tracker.update_state(noise_loss)
        self.image_loss_tracker.update_state(image_loss)

        # track the exponential moving averages of weights
        for weight, ema_weight in zip(self.network.weights, self.ema_network.weights):
            ema_weight.assign(ema * ema_weight + (1 - ema) * weight)

        # KID is not measured during the training phase for computational efficiency
        return {m.name: m.result() for m in self.metrics[:-1]}

    def test_step(self, images):
        # normalize images to have standard deviation of 1, like the noises
        images = self.normalizer(images, training=False)
        noises = keras.random.normal(shape=(batch_size, image_size, image_size, 3))

        # sample uniform random diffusion times
        diffusion_times = keras.random.uniform(
            shape=(batch_size, 1, 1, 1), minval=0.0, maxval=1.0
        )
        noise_rates, signal_rates = self.diffusion_schedule(diffusion_times)
        # mix the images with noises accordingly
        noisy_images = signal_rates * images + noise_rates * noises

        # use the network to separate noisy images to their components
        pred_noises, pred_images = self.denoise(
            noisy_images, noise_rates, signal_rates, training=False
        )

        noise_loss = self.loss(noises, pred_noises)
        image_loss = self.loss(images, pred_images)

        self.image_loss_tracker.update_state(image_loss)
        self.noise_loss_tracker.update_state(noise_loss)

        # measure KID between real and generated images
        # this is computationally demanding, kid_diffusion_steps has to be small
        images = self.denormalize(images)
        generated_images = self.generate(
            num_images=batch_size, diffusion_steps=kid_diffusion_steps
        )
        self.kid.update_state(images, generated_images)

        return {m.name: m.result() for m in self.metrics}

    def plot_images(self, epoch=None, logs=None, num_rows=3, num_cols=6):
        # plot random generated images for visual evaluation of generation quality
        generated_images = self.generate(
            num_images=num_rows * num_cols,
            diffusion_steps=plot_diffusion_steps,
        )

        plt.figure(figsize=(num_cols * 2.0, num_rows * 2.0))
        for row in range(num_rows):
            for col in range(num_cols):
                index = row * num_cols + col
                plt.subplot(num_rows, num_cols, index + 1)
                plt.imshow(generated_images[index])
                plt.axis("off")
        plt.tight_layout()
        plt.show()
        plt.close()


"""
## Training
"""

# create and compile the model
model = DiffusionModel(image_size, widths, block_depth)
# below tensorflow 2.9:
# pip install tensorflow_addons
# import tensorflow_addons as tfa
# optimizer=tfa.optimizers.AdamW
model.compile(
    optimizer=keras.optimizers.AdamW(
        learning_rate=learning_rate, weight_decay=weight_decay
    ),
    loss=keras.losses.mean_absolute_error,
)
# pixelwise mean absolute error is used as loss

# save the best model based on the validation KID metric
checkpoint_path = "checkpoints/diffusion_model.weights.h5"
checkpoint_callback = keras.callbacks.ModelCheckpoint(
    filepath=checkpoint_path,
    save_weights_only=True,
    monitor="val_kid",
    mode="min",
    save_best_only=True,
)

# calculate mean and variance of training dataset for normalization
model.normalizer.adapt(train_dataset)

# run training and plot generated images periodically
model.fit(
    train_dataset,
    epochs=num_epochs,
    validation_data=val_dataset,
    callbacks=[
        keras.callbacks.LambdaCallback(on_epoch_end=model.plot_images),
        checkpoint_callback,
    ],
)

"""
## Inference
"""

# load the best model and generate images
model.load_weights(checkpoint_path)
model.plot_images()

"""
## Results

By running the training for at least 50 epochs (takes 2 hours on a T4 GPU and 30 minutes
on an A100 GPU), one can get high-quality image generations using this code example.

The evolution of a batch of images over an 80-epoch training run (color artifacts are due
to GIF compression):

![flowers training gif](https://i.imgur.com/FSCKtZq.gif)

Images generated using between 1 and 20 sampling steps from the same initial noise:

![flowers sampling steps gif](https://i.imgur.com/tM5LyH3.gif)

Interpolation (spherical) between initial noise samples:

![flowers interpolation gif](https://i.imgur.com/hk5Hd5o.gif)
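
This example does not include the interpolation code behind the animation above. A common way to implement spherical linear interpolation (slerp) is sketched below in NumPy; `slerp` is a hypothetical helper. Under the assumption that each frame is produced this way, each interpolated tensor would be passed to `reverse_diffusion()` in place of `initial_noise`. Spherical interpolation keeps the interpolated noise at a norm typical of Gaussian noise. Linear interpolation, in contrast, shrinks it towards the midpoint:

```python
import numpy as np


def slerp(noise_1, noise_2, t):
    # spherical linear interpolation between two noise tensors, for 0 <= t <= 1
    # (undefined if the two tensors are parallel, i.e. sin(angle) == 0)
    flat_1, flat_2 = noise_1.ravel(), noise_2.ravel()
    cos_angle = np.dot(flat_1, flat_2) / (
        np.linalg.norm(flat_1) * np.linalg.norm(flat_2)
    )
    angle = np.arccos(np.clip(cos_angle, -1.0, 1.0))
    # the weights follow the arc between the two tensors instead of the chord
    return (
        np.sin((1.0 - t) * angle) * noise_1 + np.sin(t * angle) * noise_2
    ) / np.sin(angle)


rng = np.random.default_rng(0)
noise_1 = rng.standard_normal(size=(64, 64, 3))
noise_2 = rng.standard_normal(size=(64, 64, 3))

for t in [0.0, 0.5, 1.0]:
    spherical = slerp(noise_1, noise_2, t)
    linear = (1.0 - t) * noise_1 + t * noise_2
    print(t, np.linalg.norm(spherical), np.linalg.norm(linear))
# at t = 0.5, the linear mix is shorter by a factor of about 1/sqrt(2)
```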
718
719
Deterministic sampling process (noisy images on top, predicted images on bottom, 40
720
steps):
721
722
![flowers deterministic generation gif](https://i.imgur.com/wCvzynh.gif)
723
724
Stochastic sampling process (noisy images on top, predicted images on bottom, 80 steps):
725
726
![flowers stochastic generation gif](https://i.imgur.com/kRXOGzd.gif)
727
728
"""
729
730
"""
731
## Lessons learned
732
733
During preparation for this code example I have run numerous experiments using
734
[this repository](https://github.com/beresandras/clear-diffusion-keras).
735
In this section I list
736
the lessons learned and my recommendations in my subjective order of importance.
737
738
### Algorithmic tips
739
740
* **min. and max. signal rates**: I found the min. signal rate to be an important
741
hyperparameter. Setting it too low will make the generated images oversaturated, while
742
setting it too high will make them undersaturated. I recommend tuning it carefully. Also,
743
setting it to 0 will lead to a division by zero error. The max. signal rate can be set to
744
1, but I found that setting it lower slightly improves generation quality.
745
* **loss function**: While large models tend to use mean squared error (MSE) loss, I
746
recommend using mean absolute error (MAE) on this dataset. In my experience MSE loss
747
generates more diverse samples (it also seems to hallucinate more
748
[Section 3](https://arxiv.org/abs/2111.05826)), while MAE loss leads to smoother images.
749
I recommend trying both.
750
* **weight decay**: I did occasionally run into diverged trainings when scaling up the
751
model, and found that weight decay helps in avoiding instabilities at a low performance
752
cost. This is why I use
753
[AdamW](https://www.tensorflow.org/api_docs/python/tf/keras/optimizers/experimental/AdamW)
754
instead of [Adam](https://keras.io/api/optimizers/adam/) in this example.
755
* **exponential moving average of weights**: This helps to reduce the variance of the KID
756
metric, and helps in averaging out short-term changes during training.
757
* **image augmentations**: Though I did not use image augmentations in this example, in
758
my experience adding horizontal flips to the training increases generation performance,
759
while random crops do not. Since we use a supervised denoising loss, overfitting can be
760
an issue, so image augmentations might be important on small datasets. One should also be
761
careful not to use
762
[leaky augmentations](https://keras.io/examples/generative/gan_ada/#invertible-data-augmentation),
763
which can be done following
764
[this method (end of Section 5)](https://arxiv.org/abs/2206.00364) for instance.
765
* **data normalization**: In the literature, the pixel values of images are usually
converted to the -1 to 1 range. For theoretical correctness, I normalize the images to
have zero mean and unit variance instead, exactly like the random noise.
* **noise level input**: I chose to input the noise variance to the network, as it is
symmetrical under our sampling schedule. One could also input the noise rate (similar
performance), the signal rate (lower performance), or even the
[log-signal-to-noise ratio (Appendix B.1)](https://arxiv.org/abs/2107.00630)
(I did not try this, as its range is highly dependent on the min. and max. signal rates,
and would require adjusting the min. embedding frequency accordingly).
* **gradient clipping**: Using global gradient clipping with a value of 1 can help with
training stability for large models, but in my experience it decreased performance
significantly.
* **residual connection downscaling**: For
[deeper models (Appendix B)](https://arxiv.org/abs/2205.11487), scaling the residual
connections by 1/sqrt(2) can be helpful, but it did not help in my case.
* **learning rate**: For me, the [Adam optimizer's](https://keras.io/api/optimizers/adam/)
default learning rate of 1e-3 worked very well, but lower learning rates are more common
in the [literature (Tables 11-13)](https://arxiv.org/abs/2105.05233).
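The min. and max. signal rate tip above can be made concrete with a minimal sketch in plain Python (the example itself uses TensorFlow ops). It assumes a cosine-style schedule that interpolates the diffusion angle linearly between the angles belonging to the max. and min. signal rates. It also shows where the signal rate ends up as a divisor when an image is recovered from a noise prediction. The function names `diffusion_schedule` and `denoise`, and the rate values 0.02 and 0.95, are illustrative assumptions.

```python
import math

# Illustrative values (assumptions); see the tips above on tuning them.
MIN_SIGNAL_RATE = 0.02
MAX_SIGNAL_RATE = 0.95


def diffusion_schedule(diffusion_time, min_signal_rate=MIN_SIGNAL_RATE,
                       max_signal_rate=MAX_SIGNAL_RATE):
    # Time 0 maps to the max. signal rate and time 1 to the min. signal rate.
    # Interpolating an angle keeps signal_rate**2 + noise_rate**2 == 1.
    start_angle = math.acos(max_signal_rate)
    end_angle = math.acos(min_signal_rate)
    angle = start_angle + diffusion_time * (end_angle - start_angle)
    signal_rate, noise_rate = math.cos(angle), math.sin(angle)
    # noise_rate ** 2 is the noise variance, which is the noise level input
    # fed to the network in this example.
    return signal_rate, noise_rate


def denoise(noisy, predicted_noise, signal_rate, noise_rate):
    # Solve noisy = signal_rate * image + noise_rate * noise for the image.
    # A min. signal rate of 0 would make this a division by zero.
    return (noisy - noise_rate * predicted_noise) / signal_rate


# Round trip on a single value: corrupt it, then recover it with the true noise.
image, noise = 0.5, -1.2
signal_rate, noise_rate = diffusion_schedule(0.3)
noisy = signal_rate * image + noise_rate * noise
print(round(denoise(noisy, noise, signal_rate, noise_rate), 6))  # 0.5
```

During sampling, the network's noise prediction takes the place of the true noise. The division then scales any prediction error by 1 / signal_rate, and this factor grows as the min. signal rate shrinks.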

### Architectural tips

* **sinusoidal embedding**: Using sinusoidal embeddings on the noise level input of the
network is crucial for good performance. I recommend setting the min. embedding frequency
to the reciprocal of the range of this input. Since we use the noise variance in this
example, which lies between 0 and 1, it can always be left at 1. The max. embedding
frequency controls the smallest change in the noise variance that the network will be
sensitive to, and the embedding dimensions set the number of frequency components in the
embedding. In my experience, the performance is not too sensitive to these values.
* **skip connections**: Using skip connections in the network architecture is absolutely
critical: without them, the model fails to reach good denoising performance.
* **residual connections**: In my experience, residual connections also significantly
improve performance, but this might be because we only input the noise level
embeddings to the first layer of the network instead of to all of them.
* **normalization**: When scaling up the model, I occasionally encountered diverging
training runs; using normalization layers helped mitigate this issue. In the literature, it
is common to use
[GroupNormalization](https://www.tensorflow.org/addons/api_docs/python/tfa/layers/GroupNormalization)
(with 8 groups, for example) or
[LayerNormalization](https://keras.io/api/layers/normalization_layers/layer_normalization/)
in the network. I, however, chose to use
[BatchNormalization](https://keras.io/api/layers/normalization_layers/batch_normalization/),
as it gave similar benefits in my experiments but was computationally lighter.
* **activations**: The choice of activation functions had a larger effect on generation
quality than I expected. In my experiments, non-monotonic activation functions
outperformed monotonic ones (such as
[ReLU](https://www.tensorflow.org/api_docs/python/tf/keras/activations/relu)), with
[Swish](https://www.tensorflow.org/api_docs/python/tf/keras/activations/swish) performing
the best (this is also what [Imagen uses, page 41](https://arxiv.org/abs/2205.11487)).
* **attention**: As mentioned earlier, it is common in the literature to use
[attention layers](https://keras.io/api/layers/attention_layers/multi_head_attention/) at low
resolutions for better global coherence. I omitted them for simplicity.
* **upsampling**:
[Bilinear and nearest neighbour upsampling](https://keras.io/api/layers/reshaping_layers/up_sampling2d/)
in the network performed similarly; however, I did not try
[transposed convolutions](https://keras.io/api/layers/convolution_layers/convolution2d_transpose/).
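A minimal sketch of the sinusoidal embedding from the tips above, written in plain Python. The frequencies are spaced log-uniformly between the min. and max. embedding frequency, and each frequency contributes one sine and one cosine component, so half of the embedding dimensions are sines and half are cosines. The function name and the defaults (`embedding_dims=32`, `max_frequency=1000.0`) are illustrative assumptions, not necessarily the values this example's network uses.

```python
import math


def sinusoidal_embedding(noise_variance, embedding_dims=32,
                         min_frequency=1.0, max_frequency=1000.0):
    # embedding_dims // 2 frequencies, log-uniformly spaced between the min.
    # and max. embedding frequency (assumes embedding_dims >= 4).
    num_frequencies = embedding_dims // 2
    log_min, log_max = math.log(min_frequency), math.log(max_frequency)
    step = (log_max - log_min) / (num_frequencies - 1)
    frequencies = [math.exp(log_min + i * step) for i in range(num_frequencies)]
    angular_speeds = [2.0 * math.pi * f for f in frequencies]
    # Sines followed by cosines give an embedding of length embedding_dims.
    return ([math.sin(w * noise_variance) for w in angular_speeds]
            + [math.cos(w * noise_variance) for w in angular_speeds])


embedding = sinusoidal_embedding(0.25)
print(len(embedding))  # 32
```

With the min. frequency at 1, the slowest sine/cosine pair completes exactly one period over the 0 to 1 range of the noise variance. Each value in that range therefore maps to a distinct point on this pair, which is one way to read the advice to use the reciprocal of the input range.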

For a similar list about GANs, check out
[this Keras tutorial](https://keras.io/examples/generative/gan_ada/#gan-tips-and-tricks).
"""

"""
## What to try next?

If you would like to dive deeper into the topic, I recommend checking out
[this repository](https://github.com/beresandras/clear-diffusion-keras) that I created in
preparation for this code example, which implements a wider range of features in a
similar style, such as:

* stochastic sampling
* second-order sampling based on the
[differential equation view of DDIMs (Equation 13)](https://arxiv.org/abs/2010.02502)
* more diffusion schedules
* more network output types: predicting image or
[velocity (Appendix D)](https://arxiv.org/abs/2202.00512) instead of noise
* more datasets
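As a starting point for the velocity output type listed above, here is a minimal sketch of how a velocity prediction relates to the image and noise, in plain Python. It assumes the variance-preserving setup used here (noisy = signal_rate * image + noise_rate * noise, with signal_rate**2 + noise_rate**2 == 1). It also assumes the velocity definition v = signal_rate * noise - noise_rate * image from the linked paper. The function names `velocity_target` and `from_velocity` are hypothetical.

```python
import math


def velocity_target(image, noise, signal_rate, noise_rate):
    # v = signal_rate * noise - noise_rate * image
    return signal_rate * noise - noise_rate * image


def from_velocity(noisy, velocity, signal_rate, noise_rate):
    # Valid because signal_rate**2 + noise_rate**2 == 1; unlike the
    # noise-prediction case, no division by the signal rate is needed.
    image = signal_rate * noisy - noise_rate * velocity
    noise = noise_rate * noisy + signal_rate * velocity
    return image, noise


angle = 0.7  # any angle gives rates with signal_rate**2 + noise_rate**2 == 1
signal_rate, noise_rate = math.cos(angle), math.sin(angle)
image, noise = 0.5, -1.2
noisy = signal_rate * image + noise_rate * noise
v = velocity_target(image, noise, signal_rate, noise_rate)
recovered_image, recovered_noise = from_velocity(noisy, v, signal_rate, noise_rate)
print(round(recovered_image, 6), round(recovered_noise, 6))  # 0.5 -1.2
```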
"""

"""
## Related works

* [Score-based generative modeling](https://yang-song.github.io/blog/2021/score/)
(blogpost)
* [What are diffusion models?](https://lilianweng.github.io/posts/2021-07-11-diffusion-models/)
(blogpost)
* [Annotated diffusion model](https://huggingface.co/blog/annotated-diffusion) (blogpost)
* [CVPR 2022 tutorial on diffusion models](https://cvpr2022-tutorial-diffusion-models.github.io/)
(slides available)
* [Elucidating the Design Space of Diffusion-Based Generative Models](https://arxiv.org/abs/2206.00364):
attempts to unify diffusion methods under a common framework
* High-level video overviews: [1](https://www.youtube.com/watch?v=yTAMrHVG1ew),
[2](https://www.youtube.com/watch?v=344w5h24-h8)
* Detailed technical videos: [1](https://www.youtube.com/watch?v=fbLgFrlTnGU),
[2](https://www.youtube.com/watch?v=W-O7AZNzbzQ)
* Score-based generative models: [NCSN](https://arxiv.org/abs/1907.05600),
[NCSN+](https://arxiv.org/abs/2006.09011), [NCSN++](https://arxiv.org/abs/2011.13456)
* Denoising diffusion models: [DDPM](https://arxiv.org/abs/2006.11239),
[DDIM](https://arxiv.org/abs/2010.02502), [DDPM+](https://arxiv.org/abs/2102.09672),
[DDPM++](https://arxiv.org/abs/2105.05233)
* Large diffusion models: [GLIDE](https://arxiv.org/abs/2112.10741),
[DALL-E 2](https://arxiv.org/abs/2204.06125/), [Imagen](https://arxiv.org/abs/2205.11487)

"""