"""
Title: Denoising Diffusion Probabilistic Model
Author: [A_K_Nain](https://twitter.com/A_K_Nain)
Date created: 2022/11/30
Last modified: 2022/12/07
Description: Generating images of flowers with denoising diffusion probabilistic models.
"""

"""
## Introduction

Generative modeling has experienced tremendous growth over the last five years. Models like
VAEs, GANs, and flow-based models have proven very successful at generating
high-quality content, especially images. Diffusion models are a newer type of generative
model that has proven capable of generating images of higher quality than previous approaches.

Diffusion models are inspired by non-equilibrium thermodynamics, and they learn to
generate by denoising. Learning by denoising consists of two processes,
each of which is a Markov chain. These are:

1. The forward process: In the forward process, we slowly add random noise to the data
in a series of time steps `(t1, t2, ..., tn)`. The sample at the current time step is
drawn from a Gaussian distribution whose mean is conditioned
on the sample at the previous time step, and whose variance follows
a fixed schedule. At the end of the forward process, the samples are
indistinguishable from pure noise (see the short sketch after this list).

2. The reverse process: During the reverse process, we try to undo the noise added at
every time step. We start with the pure noise distribution (the last step of the
forward process) and try to denoise the samples in the backward direction
`(tn, tn-1, ..., t1)`.
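
To make the forward process concrete, here is a minimal NumPy sketch of the
closed-form jump `q(x_t | x_0)` that this example uses later in `q_sample` (the
schedule values and the placeholder sample below are illustrative, not the ones used
in the training pipeline):

```python
import numpy as np

# Illustrative linear variance schedule and a placeholder "clean" sample.
betas = np.linspace(1e-4, 0.02, 1000)
alphas_cumprod = np.cumprod(1.0 - betas)

x_0 = np.random.uniform(-1.0, 1.0, size=(64, 64, 3))
t = 500
noise = np.random.normal(size=x_0.shape)

# Jump directly from x_0 to the noisy sample x_t in a single step.
x_t = np.sqrt(alphas_cumprod[t]) * x_0 + np.sqrt(1.0 - alphas_cumprod[t]) * noise
```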

We implement the [Denoising Diffusion Probabilistic Models](https://arxiv.org/abs/2006.11239)
paper, or DDPM for short, in this code example. It was the first paper demonstrating
the use of diffusion models for generating high-quality images. The authors proved
that a certain parameterization of diffusion models reveals an equivalence with
denoising score matching over multiple noise levels during training and with annealed
Langevin dynamics during sampling, and that this parameterization produces the best
quality results.

This paper implements both Markov chains (the forward process and the reverse process)
involved in the diffusion process, but for images. The forward process is fixed and
gradually adds Gaussian noise to the images according to a fixed variance schedule
denoted by beta in the paper. This is what the diffusion process looks like in the case
of images: (image -> noise::noise -> image)

![diffusion process gif](https://imgur.com/Yn7tho9.gif)


The paper describes two algorithms, one for training the model, and the other for
sampling from the trained model. Training is performed by optimizing the usual
variational bound on the negative log-likelihood. The objective function is further
simplified, and the network is treated as a noise prediction network. Once optimized,
we can sample from the network to generate new images from noise samples. Here is an
overview of both algorithms as presented in the paper:

![ddpms](https://i.imgur.com/S7KH5hZ.png)


**Note:** DDPM is just one way of implementing a diffusion model. Also, the sampling
algorithm in the DDPM replicates the complete Markov chain. Hence, it's slow in
generating new samples compared to other generative models like GANs. Much research
effort has gone into addressing this issue. One such example is Denoising Diffusion
Implicit Models, or DDIM for short, where the authors replaced the Markov chain with a
non-Markovian process to sample faster. You can find the code example for DDIM
[here](https://keras.io/examples/generative/ddim/).

Implementing a DDPM model is simple. We define a model that takes
two inputs: images and randomly sampled time steps. At each training step, we
perform the following operations to train our model:

1. Sample random noise to be added to the inputs.
2. Apply the forward process to diffuse the inputs with the sampled noise.
3. The model takes these noisy samples as inputs and outputs the noise
prediction for each time step.
4. Given the true noise and the predicted noise, we calculate the loss values.
5. We then calculate the gradients and update the model weights
(a minimal sketch of this loop follows below).
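
Below is a minimal, self-contained sketch of this training loop. The tiny
single-layer `model` is a stand-in used purely for illustration; the real
time-conditioned U-Net and the full `train_step` appear later in this example.

```python
import tensorflow as tf
from tensorflow import keras

# Stand-in network: predicts per-pixel "noise" from the noisy images only.
model = keras.Sequential([keras.layers.Conv2D(3, 3, padding="same")])
optimizer = keras.optimizers.Adam(2e-4)

images = tf.random.uniform((8, 64, 64, 3), minval=-1.0, maxval=1.0)
betas = tf.cast(tf.linspace(1e-4, 0.02, 1000), tf.float32)
alphas_cumprod = tf.math.cumprod(1.0 - betas)

t = tf.random.uniform((8,), minval=0, maxval=1000, dtype=tf.int32)
noise = tf.random.normal(tf.shape(images))                                 # step 1
coeff = tf.reshape(tf.gather(alphas_cumprod, t), (-1, 1, 1, 1))
noisy_images = tf.sqrt(coeff) * images + tf.sqrt(1.0 - coeff) * noise      # step 2

with tf.GradientTape() as tape:
    pred_noise = model(noisy_images, training=True)                        # step 3
    loss = tf.reduce_mean(tf.square(noise - pred_noise))                   # step 4

grads = tape.gradient(loss, model.trainable_weights)                       # step 5
optimizer.apply_gradients(zip(grads, model.trainable_weights))
```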

Given that our model knows how to denoise a noisy sample at a given time step,
we can leverage this idea to generate new samples, starting from a pure noise
distribution.
"""

"""
## Setup
"""

import math
import numpy as np
import matplotlib.pyplot as plt

# Requires TensorFlow >=2.11 for the GroupNormalization layer.
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
import tensorflow_datasets as tfds

"""
## Hyperparameters
"""

batch_size = 32
num_epochs = 1  # Just for the sake of demonstration
total_timesteps = 1000
norm_groups = 8  # Number of groups used in GroupNormalization layer
learning_rate = 2e-4

img_size = 64
img_channels = 3
clip_min = -1.0
clip_max = 1.0

first_conv_channels = 64
channel_multiplier = [1, 2, 4, 8]
widths = [first_conv_channels * mult for mult in channel_multiplier]
has_attention = [False, False, True, True]
num_res_blocks = 2  # Number of residual blocks

dataset_name = "oxford_flowers102"
splits = ["train"]


"""
## Dataset

We use the [Oxford Flowers 102](https://www.tensorflow.org/datasets/catalog/oxford_flowers102)
dataset for generating images of flowers. In terms of preprocessing, we center-crop
the images to the desired image size, and we rescale the pixel
values to the range `[-1.0, 1.0]`. This matches the pixel-value range
used by the authors of the [DDPM paper](https://arxiv.org/abs/2006.11239). To
augment the training data, we randomly flip the images left/right.
"""


# Load the dataset
(ds,) = tfds.load(dataset_name, split=splits, with_info=False, shuffle_files=True)


def augment(img):
    """Flips an image left/right randomly."""
    return tf.image.random_flip_left_right(img)


def resize_and_rescale(img, size):
    """Center-crop and resize the image to the desired size, then
    rescale the pixel values to the range [-1.0, 1.0].

    Args:
        img: Image tensor
        size: Desired image size for resizing
    Returns:
        Resized and rescaled image tensor
    """

    height = tf.shape(img)[0]
    width = tf.shape(img)[1]
    crop_size = tf.minimum(height, width)

    img = tf.image.crop_to_bounding_box(
        img,
        (height - crop_size) // 2,
        (width - crop_size) // 2,
        crop_size,
        crop_size,
    )

    # Resize
    img = tf.cast(img, dtype=tf.float32)
    img = tf.image.resize(img, size=size, antialias=True)

    # Rescale the pixel values
    img = img / 127.5 - 1.0
    img = tf.clip_by_value(img, clip_min, clip_max)
    return img


def train_preprocessing(x):
    img = x["image"]
    img = resize_and_rescale(img, size=(img_size, img_size))
    img = augment(img)
    return img


train_ds = (
    ds.map(train_preprocessing, num_parallel_calls=tf.data.AUTOTUNE)
    .batch(batch_size, drop_remainder=True)
    .shuffle(batch_size * 2)
    .prefetch(tf.data.AUTOTUNE)
)


"""
## Gaussian diffusion utilities

We define the forward process and the reverse process
as a separate utility. Most of the code in this utility has been borrowed
from the original implementation with some slight modifications.
"""


class GaussianDiffusion:
    """Gaussian diffusion utility.

    Args:
        beta_start: Start value of the scheduled variance
        beta_end: End value of the scheduled variance
        timesteps: Number of time steps in the forward process
    """

    def __init__(
        self,
        beta_start=1e-4,
        beta_end=0.02,
        timesteps=1000,
        clip_min=-1.0,
        clip_max=1.0,
    ):
        self.beta_start = beta_start
        self.beta_end = beta_end
        self.timesteps = timesteps
        self.clip_min = clip_min
        self.clip_max = clip_max

        # Define the linear variance schedule
        self.betas = betas = np.linspace(
            beta_start,
            beta_end,
            timesteps,
            dtype=np.float64,  # Using float64 for better precision
        )
        self.num_timesteps = int(timesteps)

        alphas = 1.0 - betas
        alphas_cumprod = np.cumprod(alphas, axis=0)
        alphas_cumprod_prev = np.append(1.0, alphas_cumprod[:-1])

        self.betas = tf.constant(betas, dtype=tf.float32)
        self.alphas_cumprod = tf.constant(alphas_cumprod, dtype=tf.float32)
        self.alphas_cumprod_prev = tf.constant(alphas_cumprod_prev, dtype=tf.float32)

        # Calculations for diffusion q(x_t | x_{t-1}) and others
        self.sqrt_alphas_cumprod = tf.constant(
            np.sqrt(alphas_cumprod), dtype=tf.float32
        )

        self.sqrt_one_minus_alphas_cumprod = tf.constant(
            np.sqrt(1.0 - alphas_cumprod), dtype=tf.float32
        )

        self.log_one_minus_alphas_cumprod = tf.constant(
            np.log(1.0 - alphas_cumprod), dtype=tf.float32
        )

        self.sqrt_recip_alphas_cumprod = tf.constant(
            np.sqrt(1.0 / alphas_cumprod), dtype=tf.float32
        )
        self.sqrt_recipm1_alphas_cumprod = tf.constant(
            np.sqrt(1.0 / alphas_cumprod - 1), dtype=tf.float32
        )

        # Calculations for posterior q(x_{t-1} | x_t, x_0)
        posterior_variance = (
            betas * (1.0 - alphas_cumprod_prev) / (1.0 - alphas_cumprod)
        )
        self.posterior_variance = tf.constant(posterior_variance, dtype=tf.float32)

        # Log calculation clipped because the posterior variance is 0 at the beginning
        # of the diffusion chain
        self.posterior_log_variance_clipped = tf.constant(
            np.log(np.maximum(posterior_variance, 1e-20)), dtype=tf.float32
        )

        self.posterior_mean_coef1 = tf.constant(
            betas * np.sqrt(alphas_cumprod_prev) / (1.0 - alphas_cumprod),
            dtype=tf.float32,
        )

        self.posterior_mean_coef2 = tf.constant(
            (1.0 - alphas_cumprod_prev) * np.sqrt(alphas) / (1.0 - alphas_cumprod),
            dtype=tf.float32,
        )

    def _extract(self, a, t, x_shape):
        """Extracts coefficients at the specified timesteps,
        then reshapes them to [batch_size, 1, 1, 1] for broadcasting.

        Args:
            a: Tensor to extract from
            t: Timesteps for which the coefficients are to be extracted
            x_shape: Shape of the current batched samples
        """
        batch_size = x_shape[0]
        out = tf.gather(a, t)
        return tf.reshape(out, [batch_size, 1, 1, 1])

    def q_mean_variance(self, x_start, t):
        """Extracts the mean and the variance at the current timestep.

        Args:
            x_start: Initial sample (before the first diffusion step)
            t: Current timestep
        """
        x_start_shape = tf.shape(x_start)
        mean = self._extract(self.sqrt_alphas_cumprod, t, x_start_shape) * x_start
        variance = self._extract(1.0 - self.alphas_cumprod, t, x_start_shape)
        log_variance = self._extract(
            self.log_one_minus_alphas_cumprod, t, x_start_shape
        )
        return mean, variance, log_variance

    def q_sample(self, x_start, t, noise):
        """Diffuse the data.

        Args:
            x_start: Initial sample (before the first diffusion step)
            t: Current timestep
            noise: Gaussian noise to be added at the current timestep
        Returns:
            Diffused samples at timestep `t`
        """
        x_start_shape = tf.shape(x_start)
        return (
            self._extract(self.sqrt_alphas_cumprod, t, x_start_shape) * x_start
            + self._extract(self.sqrt_one_minus_alphas_cumprod, t, x_start_shape)
            * noise
        )

    def predict_start_from_noise(self, x_t, t, noise):
        x_t_shape = tf.shape(x_t)
        return (
            self._extract(self.sqrt_recip_alphas_cumprod, t, x_t_shape) * x_t
            - self._extract(self.sqrt_recipm1_alphas_cumprod, t, x_t_shape) * noise
        )

    def q_posterior(self, x_start, x_t, t):
        """Compute the mean and variance of the diffusion
        posterior q(x_{t-1} | x_t, x_0).

        Args:
            x_start: Starting point (sample) for the posterior computation
            x_t: Sample at timestep `t`
            t: Current timestep
        Returns:
            Posterior mean and variance at the current timestep
        """

        x_t_shape = tf.shape(x_t)
        posterior_mean = (
            self._extract(self.posterior_mean_coef1, t, x_t_shape) * x_start
            + self._extract(self.posterior_mean_coef2, t, x_t_shape) * x_t
        )
        posterior_variance = self._extract(self.posterior_variance, t, x_t_shape)
        posterior_log_variance_clipped = self._extract(
            self.posterior_log_variance_clipped, t, x_t_shape
        )
        return posterior_mean, posterior_variance, posterior_log_variance_clipped

    def p_mean_variance(self, pred_noise, x, t, clip_denoised=True):
        x_recon = self.predict_start_from_noise(x, t=t, noise=pred_noise)
        if clip_denoised:
            x_recon = tf.clip_by_value(x_recon, self.clip_min, self.clip_max)

        model_mean, posterior_variance, posterior_log_variance = self.q_posterior(
            x_start=x_recon, x_t=x, t=t
        )
        return model_mean, posterior_variance, posterior_log_variance

    def p_sample(self, pred_noise, x, t, clip_denoised=True):
        """Sample from the diffusion model.

        Args:
            pred_noise: Noise predicted by the diffusion model
            x: Samples at a given timestep for which the noise was predicted
            t: Current timestep
            clip_denoised (bool): Whether to clip the reconstructed sample
                to the specified range before computing the posterior mean.
        """
        model_mean, _, model_log_variance = self.p_mean_variance(
            pred_noise, x=x, t=t, clip_denoised=clip_denoised
        )
        noise = tf.random.normal(shape=x.shape, dtype=x.dtype)
        # No noise when t == 0
        nonzero_mask = tf.reshape(
            1 - tf.cast(tf.equal(t, 0), tf.float32), [tf.shape(x)[0], 1, 1, 1]
        )
        return model_mean + nonzero_mask * tf.exp(0.5 * model_log_variance) * noise
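
"""
Before wiring this utility into training, here is a small, hedged usage sketch (not
part of the pipeline below). It shows how `q_sample` progressively destroys a batch of
samples as `t` grows, and how `p_sample` performs a single, purely mechanical reverse
step when given some (here random, stand-in) predicted noise:

```python
demo_gdf = GaussianDiffusion(timesteps=total_timesteps)
demo_images = tf.random.uniform((4, img_size, img_size, img_channels), -1.0, 1.0)

for step in [0, 250, 500, 999]:
    tt = tf.fill([4], tf.constant(step, dtype=tf.int64))
    noise = tf.random.normal(tf.shape(demo_images))
    noisy = demo_gdf.q_sample(demo_images, tt, noise)
    # The standard deviation drifts toward ~1.0, i.e. pure Gaussian noise.
    print(step, float(tf.math.reduce_std(noisy)))

# One mechanical reverse step with stand-in "predicted" noise.
prev_sample = demo_gdf.p_sample(tf.random.normal(tf.shape(demo_images)), noisy, tt)
print(prev_sample.shape)
```
"""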


"""
## Network architecture

U-Net, originally developed for semantic segmentation, is an architecture that is
widely used for implementing diffusion models, but with some slight modifications:

1. The network accepts two inputs: the image and the time step.
2. Self-attention is applied between the convolution blocks once we reach a specific
resolution (16x16 in the paper).
3. Group Normalization is used instead of weight normalization.

We implement most of the design choices used in the original paper. We use the
`swish` activation function throughout the network, and the variance-scaling
kernel initializer.

The only difference here is the number of groups used for the
`GroupNormalization` layer. For the flowers dataset,
we found that a value of `groups=8` produces better results
compared to the default value of `groups=32`. Dropout is optional and should be
used where the chance of overfitting is high. In the paper, the authors used dropout
only when training on CIFAR10.
"""


# Kernel initializer to use
def kernel_init(scale):
    scale = max(scale, 1e-10)
    return keras.initializers.VarianceScaling(
        scale, mode="fan_avg", distribution="uniform"
    )


class AttentionBlock(layers.Layer):
    """Applies self-attention.

    Args:
        units: Number of units in the dense layers
        groups: Number of groups to be used for GroupNormalization layer
    """

    def __init__(self, units, groups=8, **kwargs):
        self.units = units
        self.groups = groups
        super().__init__(**kwargs)

        self.norm = layers.GroupNormalization(groups=groups)
        self.query = layers.Dense(units, kernel_initializer=kernel_init(1.0))
        self.key = layers.Dense(units, kernel_initializer=kernel_init(1.0))
        self.value = layers.Dense(units, kernel_initializer=kernel_init(1.0))
        self.proj = layers.Dense(units, kernel_initializer=kernel_init(0.0))

    def call(self, inputs):
        batch_size = tf.shape(inputs)[0]
        height = tf.shape(inputs)[1]
        width = tf.shape(inputs)[2]
        scale = tf.cast(self.units, tf.float32) ** (-0.5)

        inputs = self.norm(inputs)
        q = self.query(inputs)
        k = self.key(inputs)
        v = self.value(inputs)

        # Scaled dot-product attention over all spatial positions
        attn_score = tf.einsum("bhwc, bHWc->bhwHW", q, k) * scale
        attn_score = tf.reshape(attn_score, [batch_size, height, width, height * width])

        attn_score = tf.nn.softmax(attn_score, -1)
        attn_score = tf.reshape(attn_score, [batch_size, height, width, height, width])

        proj = tf.einsum("bhwHW,bHWc->bhwc", attn_score, v)
        proj = self.proj(proj)
        return inputs + proj


class TimeEmbedding(layers.Layer):
    """Sinusoidal time embedding (Transformer-style positional encoding)."""

    def __init__(self, dim, **kwargs):
        super().__init__(**kwargs)
        self.dim = dim
        self.half_dim = dim // 2
        self.emb = math.log(10000) / (self.half_dim - 1)
        self.emb = tf.exp(tf.range(self.half_dim, dtype=tf.float32) * -self.emb)

    def call(self, inputs):
        inputs = tf.cast(inputs, dtype=tf.float32)
        emb = inputs[:, None] * self.emb[None, :]
        emb = tf.concat([tf.sin(emb), tf.cos(emb)], axis=-1)
        return emb
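
"""
As a quick, optional sanity check (purely illustrative), the layer maps a batch of
integer time steps to embedding vectors of length `dim`, with sine features in the
first half and cosine features in the second half:

```python
temb_layer = TimeEmbedding(dim=first_conv_channels * 4)
temb = temb_layer(tf.constant([0, 500, 999], dtype=tf.int64))
print(temb.shape)  # (3, 256)
```
"""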


def ResidualBlock(width, groups=8, activation_fn=keras.activations.swish):
    def apply(inputs):
        x, t = inputs
        input_width = x.shape[3]

        if input_width == width:
            residual = x
        else:
            residual = layers.Conv2D(
                width, kernel_size=1, kernel_initializer=kernel_init(1.0)
            )(x)

        temb = activation_fn(t)
        temb = layers.Dense(width, kernel_initializer=kernel_init(1.0))(temb)[
            :, None, None, :
        ]

        x = layers.GroupNormalization(groups=groups)(x)
        x = activation_fn(x)
        x = layers.Conv2D(
            width, kernel_size=3, padding="same", kernel_initializer=kernel_init(1.0)
        )(x)

        x = layers.Add()([x, temb])
        x = layers.GroupNormalization(groups=groups)(x)
        x = activation_fn(x)

        x = layers.Conv2D(
            width, kernel_size=3, padding="same", kernel_initializer=kernel_init(0.0)
        )(x)
        x = layers.Add()([x, residual])
        return x

    return apply


def DownSample(width):
    def apply(x):
        x = layers.Conv2D(
            width,
            kernel_size=3,
            strides=2,
            padding="same",
            kernel_initializer=kernel_init(1.0),
        )(x)
        return x

    return apply


def UpSample(width, interpolation="nearest"):
    def apply(x):
        x = layers.UpSampling2D(size=2, interpolation=interpolation)(x)
        x = layers.Conv2D(
            width, kernel_size=3, padding="same", kernel_initializer=kernel_init(1.0)
        )(x)
        return x

    return apply


def TimeMLP(units, activation_fn=keras.activations.swish):
    def apply(inputs):
        temb = layers.Dense(
            units, activation=activation_fn, kernel_initializer=kernel_init(1.0)
        )(inputs)
        temb = layers.Dense(units, kernel_initializer=kernel_init(1.0))(temb)
        return temb

    return apply


def build_model(
    img_size,
    img_channels,
    widths,
    has_attention,
    num_res_blocks=2,
    norm_groups=8,
    interpolation="nearest",
    activation_fn=keras.activations.swish,
):
    image_input = layers.Input(
        shape=(img_size, img_size, img_channels), name="image_input"
    )
    time_input = keras.Input(shape=(), dtype=tf.int64, name="time_input")

    x = layers.Conv2D(
        first_conv_channels,
        kernel_size=(3, 3),
        padding="same",
        kernel_initializer=kernel_init(1.0),
    )(image_input)

    temb = TimeEmbedding(dim=first_conv_channels * 4)(time_input)
    temb = TimeMLP(units=first_conv_channels * 4, activation_fn=activation_fn)(temb)

    skips = [x]

    # DownBlock
    for i in range(len(widths)):
        for _ in range(num_res_blocks):
            x = ResidualBlock(
                widths[i], groups=norm_groups, activation_fn=activation_fn
            )([x, temb])
            if has_attention[i]:
                x = AttentionBlock(widths[i], groups=norm_groups)(x)
            skips.append(x)

        if widths[i] != widths[-1]:
            x = DownSample(widths[i])(x)
            skips.append(x)

    # MiddleBlock
    x = ResidualBlock(widths[-1], groups=norm_groups, activation_fn=activation_fn)(
        [x, temb]
    )
    x = AttentionBlock(widths[-1], groups=norm_groups)(x)
    x = ResidualBlock(widths[-1], groups=norm_groups, activation_fn=activation_fn)(
        [x, temb]
    )

    # UpBlock
    for i in reversed(range(len(widths))):
        for _ in range(num_res_blocks + 1):
            x = layers.Concatenate(axis=-1)([x, skips.pop()])
            x = ResidualBlock(
                widths[i], groups=norm_groups, activation_fn=activation_fn
            )([x, temb])
            if has_attention[i]:
                x = AttentionBlock(widths[i], groups=norm_groups)(x)

        if i != 0:
            x = UpSample(widths[i], interpolation=interpolation)(x)

    # End block
    x = layers.GroupNormalization(groups=norm_groups)(x)
    x = activation_fn(x)
    x = layers.Conv2D(3, (3, 3), padding="same", kernel_initializer=kernel_init(0.0))(x)
    return keras.Model([image_input, time_input], x, name="unet")
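
"""
As a small, optional sanity check (assuming the hyperparameters defined earlier), the
network takes a batch of images plus integer time steps and returns an image-shaped
noise prediction:

```python
check_net = build_model(
    img_size=img_size,
    img_channels=img_channels,
    widths=widths,
    has_attention=has_attention,
    num_res_blocks=num_res_blocks,
    norm_groups=norm_groups,
)
dummy_images = tf.random.normal((2, img_size, img_size, img_channels))
dummy_t = tf.constant([10, 900], dtype=tf.int64)
print(check_net([dummy_images, dummy_t]).shape)  # (2, 64, 64, 3)
```
"""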


"""
## Training

We follow the same setup for training the diffusion model as described
in the paper. We use the `Adam` optimizer with a learning rate of `2e-4`.
We use EMA on the model parameters with a decay factor of 0.999. We
treat our model as a noise prediction network, i.e., at every training step, we
input a batch of images and the corresponding time steps to our UNet,
and the network outputs the noise as its prediction.

The only difference is that we aren't using the Kernel Inception Distance (KID)
or Frechet Inception Distance (FID) to evaluate the quality of generated
samples during training, because both of these metrics are compute-heavy
and are skipped for brevity of implementation.

**Note:** We are using mean squared error as the loss function, which is aligned with
the paper and theoretically makes sense. In practice, though, it is also common to
use mean absolute error or Huber loss as the loss function.
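
For example, if you want to experiment with a different objective, the only change
needed is the loss passed at compile time (a hedged sketch, referring to the
`model.compile(...)` call further below):

```python
model.compile(
    loss=keras.losses.Huber(delta=1.0),  # or keras.losses.MeanAbsoluteError()
    optimizer=keras.optimizers.Adam(learning_rate=learning_rate),
)
```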
"""


class DiffusionModel(keras.Model):
    def __init__(self, network, ema_network, timesteps, gdf_util, ema=0.999):
        super().__init__()
        self.network = network
        self.ema_network = ema_network
        self.timesteps = timesteps
        self.gdf_util = gdf_util
        self.ema = ema

    def train_step(self, images):
        # 1. Get the batch size
        batch_size = tf.shape(images)[0]

        # 2. Sample timesteps uniformly
        t = tf.random.uniform(
            minval=0, maxval=self.timesteps, shape=(batch_size,), dtype=tf.int64
        )

        with tf.GradientTape() as tape:
            # 3. Sample random noise to be added to the images in the batch
            noise = tf.random.normal(shape=tf.shape(images), dtype=images.dtype)

            # 4. Diffuse the images with noise
            images_t = self.gdf_util.q_sample(images, t, noise)

            # 5. Pass the diffused images and time steps to the network
            pred_noise = self.network([images_t, t], training=True)

            # 6. Calculate the loss
            loss = self.loss(noise, pred_noise)

        # 7. Get the gradients
        gradients = tape.gradient(loss, self.network.trainable_weights)

        # 8. Update the weights of the network
        self.optimizer.apply_gradients(zip(gradients, self.network.trainable_weights))

        # 9. Update the weights of the EMA network with an exponential moving average
        for weight, ema_weight in zip(self.network.weights, self.ema_network.weights):
            ema_weight.assign(self.ema * ema_weight + (1 - self.ema) * weight)

        # 10. Return loss values
        return {"loss": loss}

    def generate_images(self, num_images=16):
        # 1. Randomly sample noise (starting point for reverse process)
        samples = tf.random.normal(
            shape=(num_images, img_size, img_size, img_channels), dtype=tf.float32
        )
        # 2. Sample from the model iteratively
        for t in reversed(range(0, self.timesteps)):
            tt = tf.cast(tf.fill(num_images, t), dtype=tf.int64)
            pred_noise = self.ema_network.predict(
                [samples, tt], verbose=0, batch_size=num_images
            )
            samples = self.gdf_util.p_sample(
                pred_noise, samples, tt, clip_denoised=True
            )
        # 3. Return generated samples
        return samples

    def plot_images(
        self, epoch=None, logs=None, num_rows=2, num_cols=8, figsize=(12, 5)
    ):
        """Utility to plot images using the diffusion model during training."""
        generated_samples = self.generate_images(num_images=num_rows * num_cols)
        generated_samples = (
            tf.clip_by_value(generated_samples * 127.5 + 127.5, 0.0, 255.0)
            .numpy()
            .astype(np.uint8)
        )

        _, ax = plt.subplots(num_rows, num_cols, figsize=figsize)
        for i, image in enumerate(generated_samples):
            if num_rows == 1:
                ax[i].imshow(image)
                ax[i].axis("off")
            else:
                ax[i // num_cols, i % num_cols].imshow(image)
                ax[i // num_cols, i % num_cols].axis("off")

        plt.tight_layout()
        plt.show()


# Build the unet model
network = build_model(
    img_size=img_size,
    img_channels=img_channels,
    widths=widths,
    has_attention=has_attention,
    num_res_blocks=num_res_blocks,
    norm_groups=norm_groups,
    activation_fn=keras.activations.swish,
)
ema_network = build_model(
    img_size=img_size,
    img_channels=img_channels,
    widths=widths,
    has_attention=has_attention,
    num_res_blocks=num_res_blocks,
    norm_groups=norm_groups,
    activation_fn=keras.activations.swish,
)
ema_network.set_weights(network.get_weights())  # Initially the weights are the same

# Get an instance of the Gaussian Diffusion utilities
gdf_util = GaussianDiffusion(timesteps=total_timesteps)

# Get the model
model = DiffusionModel(
    network=network,
    ema_network=ema_network,
    gdf_util=gdf_util,
    timesteps=total_timesteps,
)

# Compile the model
model.compile(
    loss=keras.losses.MeanSquaredError(),
    optimizer=keras.optimizers.Adam(learning_rate=learning_rate),
)

# Train the model
model.fit(
    train_ds,
    epochs=num_epochs,
    batch_size=batch_size,
    callbacks=[keras.callbacks.LambdaCallback(on_epoch_end=model.plot_images)],
)


"""
771
## Results
772
773
We trained this model for 800 epochs on a V100 GPU,
774
and each epoch took almost 8 seconds to finish. We load those weights
775
here, and we generate a few samples starting from pure noise.
776
"""
777
778
"""shell
779
curl -LO https://github.com/AakashKumarNain/ddpms/releases/download/v3.0.0/checkpoints.zip
780
unzip -qq checkpoints.zip
781
"""
782
783
# Load the model weights
784
model.ema_network.load_weights("checkpoints/diffusion_model_checkpoint")
785
786
# Generate and plot some samples
787
model.plot_images(num_rows=4, num_cols=8)
788


"""
## Conclusion

We successfully implemented and trained a diffusion model in the same
fashion as described by the authors of the DDPM paper. You can find the
original implementation [here](https://github.com/hojonathanho/diffusion).

There are a few things that you can try to improve the model:

1. Increase the width of each block. A bigger model can learn to denoise
in fewer epochs, though you may have to take care of overfitting.
2. We implemented the linear schedule for the variance. You can implement
other schemes, like cosine scheduling, and compare the performance
(a hedged sketch of a cosine schedule follows below).
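
As a starting point for the second idea, here is a hedged NumPy sketch of the cosine
schedule proposed in [Improved DDPM](https://arxiv.org/abs/2102.09672); betas computed
this way could replace the `np.linspace` call inside `GaussianDiffusion.__init__`
(illustrative only, not used by the example above):

```python
import numpy as np


def cosine_beta_schedule(timesteps, s=0.008, max_beta=0.999):
    # Cosine variance schedule (Nichol & Dhariwal, 2021).
    steps = np.arange(timesteps + 1, dtype=np.float64)
    f = np.cos(((steps / timesteps) + s) / (1 + s) * np.pi / 2) ** 2
    alphas_cumprod = f / f[0]
    betas = 1.0 - alphas_cumprod[1:] / alphas_cumprod[:-1]
    return np.clip(betas, 0.0, max_beta)


betas = cosine_beta_schedule(total_timesteps)
```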
"""


"""
## References

1. [Denoising Diffusion Probabilistic Models](https://arxiv.org/abs/2006.11239)
2. [Author's implementation](https://github.com/hojonathanho/diffusion)
3. [A deep dive into DDPMs](https://magic-with-latents.github.io/latent/posts/ddpms/part3/)
4. [Denoising Diffusion Implicit Models](https://keras.io/examples/generative/ddim/)
5. [Annotated Diffusion Model](https://huggingface.co/blog/annotated-diffusion)
6. [AIAIART](https://www.youtube.com/watch?v=XTs7M6TSK9I&t=14s)
"""