"""
2
Title: Enhanced Deep Residual Networks for single-image super-resolution
3
Author: Gitesh Chawda
4
Date created: 2022/04/07
5
Last modified: 2024/08/27
6
Description: Training an EDSR model on the DIV2K Dataset.
7
Accelerator: GPU
8
"""
9
10
"""
11
## Introduction
12
13
In this example, we implement
14
[Enhanced Deep Residual Networks for Single Image Super-Resolution (EDSR)](https://arxiv.org/abs/1707.02921)
15
by Bee Lim, Sanghyun Son, Heewon Kim, Seungjun Nah, and Kyoung Mu Lee.
16
17
The EDSR architecture is based on the SRResNet architecture and consists of multiple
18
residual blocks. It uses constant scaling layers instead of batch normalization layers to
19
produce consistent results (input and output have similar distributions, thus
20
normalizing intermediate features may not be desirable). Instead of using a L2 loss (mean squared error),
21
the authors employed an L1 loss (mean absolute error), which performs better empirically.
22
23
Our implementation only includes 16 residual blocks with 64 channels.
24
25
Alternatively, as shown in the Keras example
26
[Image Super-Resolution using an Efficient Sub-Pixel CNN](https://keras.io/examples/vision/super_resolution_sub_pixel/#image-superresolution-using-an-efficient-subpixel-cnn),
27
you can do super-resolution using an ESPCN Model. According to the survey paper, EDSR is one of the top-five
28
best-performing super-resolution methods based on PSNR scores. However, it has more
29
parameters and requires more computational power than other approaches.
30
It has a PSNR value (≈34db) that is slightly higher than ESPCN (≈32db).
31
As per the survey paper, EDSR performs better than ESPCN.
32
33
Paper:
34
[A comprehensive review of deep learning based single image super-resolution](https://arxiv.org/abs/2102.09351)
35
36
Comparison Graph:
37
<img src="https://dfzljdn9uc3pi.cloudfront.net/2021/cs-621/1/fig-11-2x.jpg" width="500" />
38
"""
"""
## Imports
"""

import os

os.environ["KERAS_BACKEND"] = "tensorflow"

import numpy as np
import tensorflow as tf
import tensorflow_datasets as tfds
import matplotlib.pyplot as plt

import keras
from keras import layers
from keras import ops

AUTOTUNE = tf.data.AUTOTUNE
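
"""
As a quick illustration of the L1-versus-L2 point from the introduction (this toy
check is ours, not part of the original example): L2 squares the residuals, so a
single outlier pixel dominates the loss, while L1 penalizes it only linearly.
"""

# Illustrative only: one large error (40) among four pixels
y_true = np.array([[100.0, 100.0, 100.0, 100.0]])
y_pred = np.array([[101.0, 99.0, 100.0, 140.0]])
print(float(keras.losses.MeanAbsoluteError()(y_true, y_pred)))  # L1: 10.5
print(float(keras.losses.MeanSquaredError()(y_true, y_pred)))  # L2: 400.5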
"""
## Download the training dataset

We use the DIV2K Dataset, a prominent single-image super-resolution dataset with 1,000
images of scenes and various degraded variants of each,
divided into 800 images for training, 100 images for validation, and 100
images for testing. We use the 4x bicubic-downsampled images as our "low quality" reference.
"""

# Download DIV2K from TF Datasets
# Using the bicubic 4x degradation type
div2k_data = tfds.image.Div2k(config="bicubic_x4")
div2k_data.download_and_prepare()

# Take the training data from the div2k_data object
train = div2k_data.as_dataset(split="train", as_supervised=True)
train_cache = train.cache()
# Validation data
val = div2k_data.as_dataset(split="validation", as_supervised=True)
val_cache = val.cache()
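
"""
As a quick sanity check (ours, not part of the original example), we can inspect one
training pair: with `as_supervised=True`, each element is a `(lowres, highres)` tuple
of uint8 images, and the high-resolution image is about 4x larger along each spatial axis.
"""

for lowres_img, highres_img in train.take(1):
    print("LR shape:", lowres_img.shape, "HR shape:", highres_img.shape)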
"""
## Flip, crop and resize images
"""


def flip_left_right(lowres_img, highres_img):
    """Randomly flips the image pair left to right."""

    # Draw a random value from a uniform distribution over [0, 1)
    rn = keras.random.uniform(shape=(), maxval=1)
    # If rn is less than 0.5, return the original lowres_img and highres_img;
    # otherwise return the horizontally flipped pair
    return ops.cond(
        rn < 0.5,
        lambda: (lowres_img, highres_img),
        lambda: (
            ops.flip(lowres_img, axis=1),  # flip along the width axis only
            ops.flip(highres_img, axis=1),
        ),
    )


def random_rotate(lowres_img, highres_img):
    """Rotates the image pair by a random multiple of 90 degrees."""

    # Draw a random integer from {0, 1, 2, 3}
    rn = ops.cast(
        keras.random.uniform(shape=(), maxval=4, dtype="float32"), dtype="int32"
    )
    # rn is the number of times the images are rotated by 90 degrees
    return tf.image.rot90(lowres_img, rn), tf.image.rot90(highres_img, rn)


def random_crop(lowres_img, highres_img, hr_crop_size=96, scale=4):
    """Crops spatially-aligned patches from the image pair.

    low resolution patches: 24x24
    high resolution patches: 96x96
    """
    lowres_crop_size = hr_crop_size // scale  # 96//4=24
    lowres_img_shape = ops.shape(lowres_img)[:2]  # (height, width)

    lowres_width = ops.cast(
        keras.random.uniform(
            shape=(), maxval=lowres_img_shape[1] - lowres_crop_size + 1, dtype="float32"
        ),
        dtype="int32",
    )
    lowres_height = ops.cast(
        keras.random.uniform(
            shape=(), maxval=lowres_img_shape[0] - lowres_crop_size + 1, dtype="float32"
        ),
        dtype="int32",
    )

    # The high-resolution crop starts at the corresponding (scaled) coordinates
    highres_width = lowres_width * scale
    highres_height = lowres_height * scale

    lowres_img_cropped = lowres_img[
        lowres_height : lowres_height + lowres_crop_size,
        lowres_width : lowres_width + lowres_crop_size,
    ]  # 24x24
    highres_img_cropped = highres_img[
        highres_height : highres_height + hr_crop_size,
        highres_width : highres_width + hr_crop_size,
    ]  # 96x96

    return lowres_img_cropped, highres_img_cropped
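
"""
A quick shape check (ours, not part of the original example): a paired random crop
should yield a 24x24 low-resolution patch and an aligned 96x96 high-resolution patch.
"""

lr_demo = tf.zeros((100, 120, 3))
hr_demo = tf.zeros((400, 480, 3))
lr_patch, hr_patch = random_crop(lr_demo, hr_demo)
print(lr_patch.shape, hr_patch.shape)  # (24, 24, 3) (96, 96, 3)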
"""
## Prepare a `tf.data.Dataset` object

We augment the training data with random horizontal flips and 90° rotations.

As low-resolution images, we use 24x24 RGB input patches.
"""


def dataset_object(dataset_cache, training=True):
    ds = dataset_cache
    ds = ds.map(
        lambda lowres, highres: random_crop(lowres, highres, scale=4),
        num_parallel_calls=AUTOTUNE,
    )

    if training:
        ds = ds.map(random_rotate, num_parallel_calls=AUTOTUNE)
        ds = ds.map(flip_left_right, num_parallel_calls=AUTOTUNE)
    # Batch the data
    ds = ds.batch(16)

    if training:
        # Repeat the data, so that the cardinality of the dataset becomes infinite
        ds = ds.repeat()
    # Prefetching allows later batches to be prepared while the current batch is being processed
    ds = ds.prefetch(buffer_size=AUTOTUNE)
    return ds


train_ds = dataset_object(train_cache, training=True)
val_ds = dataset_object(val_cache, training=False)
"""
## Visualize the data

Let's visualize a few sample images:
"""

lowres, highres = next(iter(train_ds))

# High-resolution images
plt.figure(figsize=(10, 10))
for i in range(9):
    ax = plt.subplot(3, 3, i + 1)
    plt.imshow(highres[i].numpy().astype("uint8"))
    plt.title(highres[i].shape)
    plt.axis("off")

# Low-resolution images
plt.figure(figsize=(10, 10))
for i in range(9):
    ax = plt.subplot(3, 3, i + 1)
    plt.imshow(lowres[i].numpy().astype("uint8"))
    plt.title(lowres[i].shape)
    plt.axis("off")


def PSNR(super_resolution, high_resolution):
    """Computes the peak signal-to-noise ratio, a measure of image quality."""
    # The maximum pixel value is 255; tf.image.psnr returns one value per image,
    # and we keep the first image's value
    psnr_value = tf.image.psnr(high_resolution, super_resolution, max_val=255)[0]
    return psnr_value
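
"""
For reference, PSNR is defined as `20 * log10(max_val) - 10 * log10(MSE)` between the
two images. A quick NumPy cross-check of `tf.image.psnr` (ours, not part of the
original example):
"""

img_a = np.random.randint(0, 256, size=(1, 96, 96, 3)).astype("float32")
img_b = np.random.randint(0, 256, size=(1, 96, 96, 3)).astype("float32")
mse = np.mean((img_a - img_b) ** 2)
manual_psnr = 20 * np.log10(255.0) - 10 * np.log10(mse)
print(manual_psnr, float(PSNR(img_b, img_a)))  # the two values should match closely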
"""
## Build the model

In the paper, the authors train three models: EDSR, MDSR, and a baseline model. In this code example,
we only train the baseline model.

### Comparison of residual block designs

The residual block design of EDSR differs from that of ResNet. Batch normalization
layers have been removed (together with the final ReLU activation): since batch normalization
layers normalize the features, they restrict the flexibility of the network's output value range,
so it is better to remove them. Removing them also reduces the
amount of GPU RAM required by the model, since a batch normalization layer consumes the same amount of
memory as the preceding convolutional layer.

<img src="https://miro.medium.com/max/1050/1*EPviXGqlGWotVtV2gqVvNg.png" width="500" />
"""


class EDSRModel(keras.Model):
    def train_step(self, data):
        # Unpack the data. Its structure depends on your model and
        # on what you pass to `fit()`.
        x, y = data

        with tf.GradientTape() as tape:
            y_pred = self(x, training=True)  # Forward pass
            # Compute the loss value
            # (the loss function is configured in `compile()`)
            loss = self.compute_loss(y=y, y_pred=y_pred)

        # Compute gradients
        trainable_vars = self.trainable_variables
        gradients = tape.gradient(loss, trainable_vars)
        # Update weights
        self.optimizer.apply_gradients(zip(gradients, trainable_vars))
        # Update metrics (includes the metric that tracks the loss)
        for metric in self.metrics:
            if metric.name == "loss":
                metric.update_state(loss)
            else:
                metric.update_state(y, y_pred)
        # Return a dict mapping metric names to current value
        return {m.name: m.result() for m in self.metrics}

    def predict_step(self, x):
        # Add a batch dimension and convert to float32
        x = ops.cast(tf.expand_dims(x, axis=0), dtype="float32")
        # Pass the low-resolution image to the model
        super_resolution_img = self(x, training=False)
        # Clip the tensor to the valid pixel range [0, 255]
        super_resolution_img = ops.clip(super_resolution_img, 0, 255)
        # Round the values to the nearest integer
        super_resolution_img = ops.round(super_resolution_img)
        # Remove the batch dimension and convert to uint8
        super_resolution_img = ops.squeeze(
            ops.cast(super_resolution_img, dtype="uint8"), axis=0
        )
        return super_resolution_img


# Residual Block
def ResBlock(inputs):
    x = layers.Conv2D(64, 3, padding="same", activation="relu")(inputs)
    x = layers.Conv2D(64, 3, padding="same")(x)
    x = layers.Add()([inputs, x])
    return x
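
"""
The paper stabilizes training of larger EDSR variants by scaling each residual branch
by a small constant before the addition (the "constant scaling layers" mentioned in
the introduction). The baseline model trained here does not need this, but a minimal
sketch of such a block (ours, following the paper's description) could look like this:
"""


# Sketch only: residual block with constant residual scaling (e.g. 0.1),
# as suggested in the paper for larger EDSR models; unused by the baseline.
def ScaledResBlock(inputs, scaling=0.1):
    x = layers.Conv2D(64, 3, padding="same", activation="relu")(inputs)
    x = layers.Conv2D(64, 3, padding="same")(x)
    x = layers.Rescaling(scale=scaling)(x)  # constant scaling layer
    return layers.Add()([inputs, x])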


# Upsampling Block: two successive conv + depth_to_space (pixel shuffle) stages,
# each upscaling spatially by `factor` (so 4x overall with the default factor=2)
def Upsampling(inputs, factor=2, **kwargs):
    x = layers.Conv2D(64 * (factor**2), 3, padding="same", **kwargs)(inputs)
    x = layers.Lambda(lambda x: tf.nn.depth_to_space(x, block_size=factor))(x)
    x = layers.Conv2D(64 * (factor**2), 3, padding="same", **kwargs)(x)
    x = layers.Lambda(lambda x: tf.nn.depth_to_space(x, block_size=factor))(x)
    return x
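
"""
`tf.nn.depth_to_space` rearranges channels into spatial blocks (pixel shuffle): a
tensor of shape `(batch, h, w, c * factor**2)` becomes `(batch, h * factor, w * factor, c)`.
A quick shape check (ours, not part of the original example):
"""

demo = tf.zeros((1, 24, 24, 256))  # 64 * 2**2 = 256 channels
print(tf.nn.depth_to_space(demo, block_size=2).shape)  # (1, 48, 48, 64)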


def make_model(num_filters, num_of_residual_blocks):
    # Flexible inputs: the model accepts images of any spatial size
    input_layer = layers.Input(shape=(None, None, 3))
    # Scale pixel values to [0, 1]
    x = layers.Rescaling(scale=1.0 / 255)(input_layer)
    x = x_new = layers.Conv2D(num_filters, 3, padding="same")(x)

    # Residual blocks (16 in the baseline model)
    for _ in range(num_of_residual_blocks):
        x_new = ResBlock(x_new)

    x_new = layers.Conv2D(num_filters, 3, padding="same")(x_new)
    x = layers.Add()([x, x_new])

    x = Upsampling(x)
    x = layers.Conv2D(3, 3, padding="same")(x)

    output_layer = layers.Rescaling(scale=255)(x)
    return EDSRModel(input_layer, output_layer)


model = make_model(num_filters=64, num_of_residual_blocks=16)

"""
## Train the model
"""

# Use the Adam optimizer with an initial learning rate of 1e-4,
# dropping the learning rate to 5e-5 after 5000 steps
optim_edsr = keras.optimizers.Adam(
    learning_rate=keras.optimizers.schedules.PiecewiseConstantDecay(
        boundaries=[5000], values=[1e-4, 5e-5]
    )
)
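
"""
A `PiecewiseConstantDecay` schedule is callable on the step number, so we can verify
the learning-rate drop directly (a quick check of ours, not part of the original example):
"""

lr_schedule = keras.optimizers.schedules.PiecewiseConstantDecay(
    boundaries=[5000], values=[1e-4, 5e-5]
)
print(float(lr_schedule(0)))  # 1e-4 before step 5000
print(float(lr_schedule(6000)))  # 5e-5 after step 5000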

# Compile the model with mean absolute error (L1 loss) and the PSNR metric
model.compile(optimizer=optim_edsr, loss="mae", metrics=[PSNR])
# Training for more epochs will improve results
model.fit(train_ds, epochs=100, steps_per_epoch=200, validation_data=val_ds)

"""
## Run inference on new images and plot the results
"""


def plot_results(lowres, preds):
    """Displays the low-resolution image next to its super-resolution prediction."""
    plt.figure(figsize=(24, 14))
    plt.subplot(132), plt.imshow(lowres), plt.title("Low resolution")
    plt.subplot(133), plt.imshow(preds), plt.title("Prediction")
    plt.show()


for lowres, highres in val.take(10):
    lowres = tf.image.random_crop(lowres, (150, 150, 3))
    preds = model.predict_step(lowres)
    plot_results(lowres, preds)

"""
## Final remarks

In this example, we implemented the EDSR model (Enhanced Deep Residual Networks for Single Image
Super-Resolution). You could improve the model accuracy by training the model for more epochs, as well as
training the model with a wider variety of inputs with mixed downgrading factors, so as to
be able to handle a greater range of real-world images.

You could also improve on the given baseline EDSR model by implementing EDSR+, or
MDSR (multi-scale super-resolution) and MDSR+, which were proposed in the same paper;
a sketch of the self-ensemble idea behind EDSR+ follows below.
"""
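
"""
For reference, EDSR+ obtains its gains from geometric self-ensembling: predictions over
the eight flip/rotation variants of the input are mapped back to the original
orientation and averaged. A minimal sketch of that idea (ours, assuming the `model`
defined above; not the paper's exact code):
"""


def self_ensemble_predict(model, lowres):
    # Enumerate the 8 geometric variants (4 rotations x optional horizontal flip)
    outputs = []
    for k in range(4):
        rotated = tf.image.rot90(lowres, k)
        for flip in (False, True):
            variant = tf.image.flip_left_right(rotated) if flip else rotated
            pred = ops.cast(model.predict_step(variant), "float32")
            # Undo the transform on the prediction before averaging
            if flip:
                pred = tf.image.flip_left_right(pred)
            pred = tf.image.rot90(pred, (-k) % 4)
            outputs.append(pred)
    return ops.cast(ops.round(ops.mean(ops.stack(outputs), axis=0)), "uint8")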