"""
2
Title: Low-light image enhancement using MIRNet
3
Author: [Soumik Rakshit](http://github.com/soumik12345)
4
Date created: 2021/09/11
5
Last modified: 2023/07/15
6
Description: Implementing the MIRNet architecture for low-light image enhancement.
7
Accelerator: GPU
8
Converted to Keras 3 by: [Soumik Rakshit](http://github.com/soumik12345)
9
"""
10
11
"""
12
## Introduction
13
14
With the goal of recovering high-quality image content from its degraded version, image
15
restoration enjoys numerous applications, such as in
16
photography, security, medical imaging, and remote sensing. In this example, we implement the
17
**MIRNet** model for low-light image enhancement, a fully-convolutional architecture that
18
learns an enriched set of
19
features that combines contextual information from multiple scales, while
20
simultaneously preserving the high-resolution spatial details.
21
22
### References:
23
24
- [Learning Enriched Features for Real Image Restoration and Enhancement](https://arxiv.org/abs/2003.06792)
25
- [The Retinex Theory of Color Vision](http://www.cnbc.cmu.edu/~tai/cp_papers/E.Land_Retinex_Theory_ScientifcAmerican.pdf)
26
- [Two deterministic half-quadratic regularization algorithms for computed imaging](https://ieeexplore.ieee.org/document/413553)
27
"""

"""
## Downloading LOLDataset

The **LoL Dataset** has been created for low-light image enhancement.
It provides 485 images for training and 15 for testing. Each image pair in the dataset
consists of a low-light input image and its corresponding well-exposed reference image.
"""

import os

os.environ["KERAS_BACKEND"] = "tensorflow"

import random
import numpy as np
from glob import glob
from PIL import Image, ImageOps
import matplotlib.pyplot as plt

import keras
from keras import layers

import tensorflow as tf

"""shell
wget https://huggingface.co/datasets/geekyrakshit/LoL-Dataset/resolve/main/lol_dataset.zip
unzip -q lol_dataset.zip && rm lol_dataset.zip
"""

"""
## Creating a TensorFlow Dataset

We use 300 image pairs from the LoL Dataset's training set for training,
and we use the remaining 185 image pairs for validation.
We generate random crops of size `128 x 128` from the image pairs to be
used for both training and validation.
"""

random.seed(10)

IMAGE_SIZE = 128
BATCH_SIZE = 4
MAX_TRAIN_IMAGES = 300


def read_image(image_path):
    image = tf.io.read_file(image_path)
    image = tf.image.decode_png(image, channels=3)
    image.set_shape([None, None, 3])
    image = tf.cast(image, dtype=tf.float32) / 255.0
    return image


def random_crop(low_image, enhanced_image):
    low_image_shape = tf.shape(low_image)[:2]
    low_w = tf.random.uniform(
        shape=(), maxval=low_image_shape[1] - IMAGE_SIZE + 1, dtype=tf.int32
    )
    low_h = tf.random.uniform(
        shape=(), maxval=low_image_shape[0] - IMAGE_SIZE + 1, dtype=tf.int32
    )
    low_image_cropped = low_image[
        low_h : low_h + IMAGE_SIZE, low_w : low_w + IMAGE_SIZE
    ]
    enhanced_image_cropped = enhanced_image[
        low_h : low_h + IMAGE_SIZE, low_w : low_w + IMAGE_SIZE
    ]
    # in order to avoid `None` during shape inference
    low_image_cropped.set_shape([IMAGE_SIZE, IMAGE_SIZE, 3])
    enhanced_image_cropped.set_shape([IMAGE_SIZE, IMAGE_SIZE, 3])
    return low_image_cropped, enhanced_image_cropped


def load_data(low_light_image_path, enhanced_image_path):
    low_light_image = read_image(low_light_image_path)
    enhanced_image = read_image(enhanced_image_path)
    low_light_image, enhanced_image = random_crop(low_light_image, enhanced_image)
    return low_light_image, enhanced_image


def get_dataset(low_light_images, enhanced_images):
    dataset = tf.data.Dataset.from_tensor_slices((low_light_images, enhanced_images))
    dataset = dataset.map(load_data, num_parallel_calls=tf.data.AUTOTUNE)
    dataset = dataset.batch(BATCH_SIZE, drop_remainder=True)
    return dataset


train_low_light_images = sorted(glob("./lol_dataset/our485/low/*"))[:MAX_TRAIN_IMAGES]
train_enhanced_images = sorted(glob("./lol_dataset/our485/high/*"))[:MAX_TRAIN_IMAGES]

val_low_light_images = sorted(glob("./lol_dataset/our485/low/*"))[MAX_TRAIN_IMAGES:]
val_enhanced_images = sorted(glob("./lol_dataset/our485/high/*"))[MAX_TRAIN_IMAGES:]

test_low_light_images = sorted(glob("./lol_dataset/eval15/low/*"))
test_enhanced_images = sorted(glob("./lol_dataset/eval15/high/*"))


train_dataset = get_dataset(train_low_light_images, train_enhanced_images)
val_dataset = get_dataset(val_low_light_images, val_enhanced_images)


print("Train Dataset:", train_dataset.element_spec)
print("Val Dataset:", val_dataset.element_spec)
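
"""
Before building the model, let's take a quick look at one batch from the training
pipeline. This visualization step is not part of the original example; it simply plots
the first low-light crop and its well-exposed reference side by side as a sanity check
on the data loading code above.
"""

for low_light_batch, enhanced_batch in train_dataset.take(1):
    # Both tensors are float32 crops in the [0, 1] range, shaped (BATCH_SIZE, 128, 128, 3).
    _, axes = plt.subplots(1, 2, figsize=(10, 5))
    axes[0].imshow(low_light_batch[0].numpy())
    axes[0].set_title("Low-light input")
    axes[1].imshow(enhanced_batch[0].numpy())
    axes[1].set_title("Well-exposed reference")
    for ax in axes:
        ax.axis("off")
    plt.show()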
"""
133
## MIRNet Model
134
135
Here are the main features of the MIRNet model:
136
137
- A feature extraction model that computes a complementary set of features across multiple
138
spatial scales, while maintaining the original high-resolution features to preserve
139
precise spatial details.
140
- A regularly repeated mechanism for information exchange, where the features across
141
multi-resolution branches are progressively fused together for improved representation
142
learning.
143
- A new approach to fuse multi-scale features using a selective kernel network
144
that dynamically combines variable receptive fields and faithfully preserves
145
the original feature information at each spatial resolution.
146
- A recursive residual design that progressively breaks down the input signal
147
in order to simplify the overall learning process, and allows the construction
148
of very deep networks.
149
150
151
![](https://raw.githubusercontent.com/soumik12345/MIRNet/master/assets/mirnet_architecture.png)
152
"""

"""
### Selective Kernel Feature Fusion

The Selective Kernel Feature Fusion or SKFF module performs dynamic adjustment of
receptive fields via two operations: **Fuse** and **Select**. The Fuse operator generates
global feature descriptors by combining the information from multi-resolution streams.
The Select operator uses these descriptors to recalibrate the feature maps (of different
streams), followed by their aggregation.

**Fuse**: The SKFF receives inputs from three parallel convolution streams carrying
different scales of information. We first combine these multi-scale features using an
element-wise sum, on which we apply Global Average Pooling (GAP) across the spatial
dimensions. Next, we apply a channel-downscaling convolution layer to generate a compact
feature representation, which passes through three parallel channel-upscaling convolution
layers (one for each resolution stream) and provides us with three feature descriptors.

**Select**: This operator applies the softmax function to the feature descriptors to
obtain the corresponding activations that are used to adaptively recalibrate the
multi-scale feature maps. The aggregated features are defined as the sum of the products
of the corresponding multi-scale features and feature descriptors.

![](https://i.imgur.com/7U6ixF6.png)
"""


def selective_kernel_feature_fusion(
    multi_scale_feature_1, multi_scale_feature_2, multi_scale_feature_3
):
    channels = list(multi_scale_feature_1.shape)[-1]
    combined_feature = layers.Add()(
        [multi_scale_feature_1, multi_scale_feature_2, multi_scale_feature_3]
    )
    gap = layers.GlobalAveragePooling2D()(combined_feature)
    channel_wise_statistics = layers.Reshape((1, 1, channels))(gap)
    compact_feature_representation = layers.Conv2D(
        filters=channels // 8, kernel_size=(1, 1), activation="relu"
    )(channel_wise_statistics)
    feature_descriptor_1 = layers.Conv2D(
        channels, kernel_size=(1, 1), activation="softmax"
    )(compact_feature_representation)
    feature_descriptor_2 = layers.Conv2D(
        channels, kernel_size=(1, 1), activation="softmax"
    )(compact_feature_representation)
    feature_descriptor_3 = layers.Conv2D(
        channels, kernel_size=(1, 1), activation="softmax"
    )(compact_feature_representation)
    feature_1 = multi_scale_feature_1 * feature_descriptor_1
    feature_2 = multi_scale_feature_2 * feature_descriptor_2
    feature_3 = multi_scale_feature_3 * feature_descriptor_3
    aggregated_feature = layers.Add()([feature_1, feature_2, feature_3])
    return aggregated_feature
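

"""
As a small sanity check (not part of the original example), we can verify on symbolic
inputs that SKFF preserves the spatial dimensions and channel count of its three
streams, since it only recalibrates and sums them.
"""

_skff_input = keras.Input(shape=(128, 128, 64))
_skff_output = selective_kernel_feature_fusion(_skff_input, _skff_input, _skff_input)
# Expected: (None, 128, 128, 64) -- the same shape as each input stream.
print(_skff_output.shape)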
"""
208
### Dual Attention Unit
209
210
The Dual Attention Unit or DAU is used to extract features in the convolutional streams.
211
While the SKFF block fuses information across multi-resolution branches, we also need a
212
mechanism to share information within a feature tensor, both along the spatial and the
213
channel dimensions which is done by the DAU block. The DAU suppresses less useful
214
features and only allows more informative ones to pass further. This feature
215
recalibration is achieved by using **Channel Attention** and **Spatial Attention**
216
mechanisms.
217
218
The **Channel Attention** branch exploits the inter-channel relationships of the
219
convolutional feature maps by applying squeeze and excitation operations. Given a feature
220
map, the squeeze operation applies Global Average Pooling across spatial dimensions to
221
encode global context, thus yielding a feature descriptor. The excitation operator passes
222
this feature descriptor through two convolutional layers followed by the sigmoid gating
223
and generates activations. Finally, the output of Channel Attention branch is obtained by
224
rescaling the input feature map with the output activations.
225
226
The **Spatial Attention** branch is designed to exploit the inter-spatial dependencies of
227
convolutional features. The goal of Spatial Attention is to generate a spatial attention
228
map and use it to recalibrate the incoming features. To generate the spatial attention
229
map, the Spatial Attention branch first independently applies Global Average Pooling and
230
Max Pooling operations on input features along the channel dimensions and concatenates
231
the outputs to form a resultant feature map which is then passed through a convolution
232
and sigmoid activation to obtain the spatial attention map. This spatial attention map is
233
then used to rescale the input feature map.
234
235
![](https://i.imgur.com/Dl0IwQs.png)
236
"""


class ChannelPooling(layers.Layer):
    def __init__(self, axis=-1, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.axis = axis
        self.concat = layers.Concatenate(axis=self.axis)

    def call(self, inputs):
        average_pooling = tf.expand_dims(tf.reduce_mean(inputs, axis=-1), axis=-1)
        max_pooling = tf.expand_dims(tf.reduce_max(inputs, axis=-1), axis=-1)
        return self.concat([average_pooling, max_pooling])

    def get_config(self):
        config = super().get_config()
        config.update({"axis": self.axis})
        return config


def spatial_attention_block(input_tensor):
    compressed_feature_map = ChannelPooling(axis=-1)(input_tensor)
    feature_map = layers.Conv2D(1, kernel_size=(1, 1))(compressed_feature_map)
    feature_map = keras.activations.sigmoid(feature_map)
    return input_tensor * feature_map


def channel_attention_block(input_tensor):
    channels = list(input_tensor.shape)[-1]
    average_pooling = layers.GlobalAveragePooling2D()(input_tensor)
    feature_descriptor = layers.Reshape((1, 1, channels))(average_pooling)
    feature_activations = layers.Conv2D(
        filters=channels // 8, kernel_size=(1, 1), activation="relu"
    )(feature_descriptor)
    feature_activations = layers.Conv2D(
        filters=channels, kernel_size=(1, 1), activation="sigmoid"
    )(feature_activations)
    return input_tensor * feature_activations


def dual_attention_unit_block(input_tensor):
    channels = list(input_tensor.shape)[-1]
    feature_map = layers.Conv2D(
        channels, kernel_size=(3, 3), padding="same", activation="relu"
    )(input_tensor)
    feature_map = layers.Conv2D(channels, kernel_size=(3, 3), padding="same")(
        feature_map
    )
    channel_attention = channel_attention_block(feature_map)
    spatial_attention = spatial_attention_block(feature_map)
    concatenation = layers.Concatenate(axis=-1)([channel_attention, spatial_attention])
    concatenation = layers.Conv2D(channels, kernel_size=(1, 1))(concatenation)
    return layers.Add()([input_tensor, concatenation])
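

"""
Like the SKFF check above, the following line is an optional sanity check (not in the
original example): because the DAU ends with a residual addition, its output shape must
match its input shape.
"""

# Expected: (None, 128, 128, 64).
print(dual_attention_unit_block(keras.Input(shape=(128, 128, 64))).shape)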

"""
### Multi-Scale Residual Block

The Multi-Scale Residual Block is capable of generating a spatially-precise output by
maintaining high-resolution representations, while receiving rich contextual information
from the lower resolutions. The MRB consists of multiple (three in the MIRNet paper)
fully-convolutional streams connected in parallel. It allows information exchange across
parallel streams in order to consolidate the high-resolution features with the help of
low-resolution features, and vice versa. MIRNet employs a recursive residual design
(with skip connections) to ease the flow of information during the learning process. In
order to maintain the residual nature of the architecture, residual resizing modules are
used to perform the downsampling and upsampling operations within the Multi-Scale
Residual Block.

![](https://i.imgur.com/wzZKV57.png)
"""

# Recursive Residual Modules


def down_sampling_module(input_tensor):
    channels = list(input_tensor.shape)[-1]
    main_branch = layers.Conv2D(channels, kernel_size=(1, 1), activation="relu")(
        input_tensor
    )
    main_branch = layers.Conv2D(
        channels, kernel_size=(3, 3), padding="same", activation="relu"
    )(main_branch)
    main_branch = layers.MaxPooling2D()(main_branch)
    main_branch = layers.Conv2D(channels * 2, kernel_size=(1, 1))(main_branch)
    skip_branch = layers.MaxPooling2D()(input_tensor)
    skip_branch = layers.Conv2D(channels * 2, kernel_size=(1, 1))(skip_branch)
    return layers.Add()([skip_branch, main_branch])


def up_sampling_module(input_tensor):
    channels = list(input_tensor.shape)[-1]
    main_branch = layers.Conv2D(channels, kernel_size=(1, 1), activation="relu")(
        input_tensor
    )
    main_branch = layers.Conv2D(
        channels, kernel_size=(3, 3), padding="same", activation="relu"
    )(main_branch)
    main_branch = layers.UpSampling2D()(main_branch)
    main_branch = layers.Conv2D(channels // 2, kernel_size=(1, 1))(main_branch)
    skip_branch = layers.UpSampling2D()(input_tensor)
    skip_branch = layers.Conv2D(channels // 2, kernel_size=(1, 1))(skip_branch)
    return layers.Add()([skip_branch, main_branch])


# MRB Block
def multi_scale_residual_block(input_tensor, channels):
    # features
    level1 = input_tensor
    level2 = down_sampling_module(input_tensor)
    level3 = down_sampling_module(level2)
    # DAU
    level1_dau = dual_attention_unit_block(level1)
    level2_dau = dual_attention_unit_block(level2)
    level3_dau = dual_attention_unit_block(level3)
    # SKFF
    level1_skff = selective_kernel_feature_fusion(
        level1_dau,
        up_sampling_module(level2_dau),
        up_sampling_module(up_sampling_module(level3_dau)),
    )
    level2_skff = selective_kernel_feature_fusion(
        down_sampling_module(level1_dau),
        level2_dau,
        up_sampling_module(level3_dau),
    )
    level3_skff = selective_kernel_feature_fusion(
        down_sampling_module(down_sampling_module(level1_dau)),
        down_sampling_module(level2_dau),
        level3_dau,
    )
    # DAU 2
    level1_dau_2 = dual_attention_unit_block(level1_skff)
    level2_dau_2 = up_sampling_module(dual_attention_unit_block(level2_skff))
    level3_dau_2 = up_sampling_module(
        up_sampling_module(dual_attention_unit_block(level3_skff))
    )
    # SKFF 2
    skff_ = selective_kernel_feature_fusion(level1_dau_2, level2_dau_2, level3_dau_2)
    conv = layers.Conv2D(channels, kernel_size=(3, 3), padding="same")(skff_)
    return layers.Add()([input_tensor, conv])
"""
379
### MIRNet Model
380
"""
381
382
383
def recursive_residual_group(input_tensor, num_mrb, channels):
384
conv1 = layers.Conv2D(channels, kernel_size=(3, 3), padding="same")(input_tensor)
385
for _ in range(num_mrb):
386
conv1 = multi_scale_residual_block(conv1, channels)
387
conv2 = layers.Conv2D(channels, kernel_size=(3, 3), padding="same")(conv1)
388
return layers.Add()([conv2, input_tensor])
389
390
391
def mirnet_model(num_rrg, num_mrb, channels):
392
input_tensor = keras.Input(shape=[None, None, 3])
393
x1 = layers.Conv2D(channels, kernel_size=(3, 3), padding="same")(input_tensor)
394
for _ in range(num_rrg):
395
x1 = recursive_residual_group(x1, num_mrb, channels)
396
conv = layers.Conv2D(3, kernel_size=(3, 3), padding="same")(x1)
397
output_tensor = layers.Add()([input_tensor, conv])
398
return keras.Model(input_tensor, output_tensor)
399
400
401
model = mirnet_model(num_rrg=3, num_mrb=2, channels=64)
402
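
"""
Since the model is fully convolutional, it accepts RGB inputs of arbitrary spatial size
(as long as each dimension is divisible by 4, because of the two downsampling stages in
the MRB). The quick check below is not part of the original example; it simply confirms
that the output shape matches the input shape and reports the parameter count.
"""

print("Total parameters:", model.count_params())
# The residual enhancement network produces an output with the same shape as its input.
print("Output shape for a 128x128 input:", model(tf.zeros((1, 128, 128, 3))).shape)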

"""
## Training

- We train MIRNet using **Charbonnier Loss** as the loss function and the **Adam
optimizer** with a learning rate of `1e-4`. Charbonnier loss is a smooth, differentiable
approximation of the L1 loss, computed as `sqrt((y_true - y_pred)^2 + epsilon^2)` with
`epsilon = 1e-3`, matching the implementation below.
- We use **Peak Signal-to-Noise Ratio** (PSNR) as a metric; it expresses the ratio between
the maximum possible value (power) of a signal and the power of the distorting noise that
affects the quality of its representation.
"""


def charbonnier_loss(y_true, y_pred):
    return tf.reduce_mean(tf.sqrt(tf.square(y_true - y_pred) + tf.square(1e-3)))


def peak_signal_noise_ratio(y_true, y_pred):
    return tf.image.psnr(y_pred, y_true, max_val=255.0)


optimizer = keras.optimizers.Adam(learning_rate=1e-4)
model.compile(
    optimizer=optimizer,
    loss=charbonnier_loss,
    metrics=[peak_signal_noise_ratio],
)

history = model.fit(
    train_dataset,
    validation_data=val_dataset,
    epochs=50,
    callbacks=[
        keras.callbacks.ReduceLROnPlateau(
            monitor="val_peak_signal_noise_ratio",
            factor=0.5,
            patience=5,
            verbose=1,
            min_delta=1e-7,
            mode="max",
        )
    ],
)


def plot_history(value, name):
    plt.plot(history.history[value], label=f"train_{name.lower()}")
    plt.plot(history.history[f"val_{value}"], label=f"val_{name.lower()}")
    plt.xlabel("Epochs")
    plt.ylabel(name)
    plt.title(f"Train and Validation {name} Over Epochs", fontsize=14)
    plt.legend()
    plt.grid()
    plt.show()


plot_history("loss", "Loss")
plot_history("peak_signal_noise_ratio", "PSNR")
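
"""
At this point it is convenient to persist the trained weights. The snippet below is an
optional addition (not in the original example); the `mirnet.keras` filename is just a
placeholder. Because the model uses a custom layer and was compiled with custom
functions, they must be passed back when reloading.
"""

model.save("mirnet.keras")
# Reloading requires the custom layer, loss, and metric defined above.
restored_model = keras.models.load_model(
    "mirnet.keras",
    custom_objects={
        "ChannelPooling": ChannelPooling,
        "charbonnier_loss": charbonnier_loss,
        "peak_signal_noise_ratio": peak_signal_noise_ratio,
    },
)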

"""
## Inference
"""


def plot_results(images, titles, figure_size=(12, 12)):
    fig = plt.figure(figsize=figure_size)
    for i in range(len(images)):
        fig.add_subplot(1, len(images), i + 1).set_title(titles[i])
        _ = plt.imshow(images[i])
        plt.axis("off")
    plt.show()


def infer(original_image):
    image = keras.utils.img_to_array(original_image)
    image = image.astype("float32") / 255.0
    image = np.expand_dims(image, axis=0)
    output = model.predict(image, verbose=0)
    output_image = output[0] * 255.0
    output_image = output_image.clip(0, 255)
    output_image = output_image.reshape(
        (np.shape(output_image)[0], np.shape(output_image)[1], 3)
    )
    output_image = Image.fromarray(np.uint8(output_image))
    return output_image
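

"""
As an optional, quantitative check that is not part of the original example, the helper
below (a hypothetical `evaluate_psnr` utility) runs MIRNet over the `eval15` test pairs
and reports the average PSNR against the well-exposed references, with both images kept
in the `[0, 255]` range.
"""


def evaluate_psnr(low_light_paths, reference_paths):
    psnr_values = []
    for low_path, reference_path in zip(low_light_paths, reference_paths):
        # Enhance the low-light image and compare it against its reference.
        enhanced = infer(Image.open(low_path))
        reference = Image.open(reference_path)
        psnr_values.append(
            float(
                tf.image.psnr(
                    np.array(enhanced, dtype=np.float32),
                    np.array(reference, dtype=np.float32),
                    max_val=255.0,
                )
            )
        )
    return np.mean(psnr_values)


print("Average PSNR on eval15:", evaluate_psnr(test_low_light_images, test_enhanced_images))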
"""
490
### Inference on Test Images
491
492
We compare the test images from LOLDataset enhanced by MIRNet with images
493
enhanced via the `PIL.ImageOps.autocontrast()` function.
494
495
You can use the trained model hosted on [Hugging Face Hub](https://huggingface.co/keras-io/lowlight-enhance-mirnet)
496
and try the demo on [Hugging Face Spaces](https://huggingface.co/spaces/keras-io/Enhance_Low_Light_Image).
497
"""
498
499
500
for low_light_image in random.sample(test_low_light_images, 6):
501
original_image = Image.open(low_light_image)
502
enhanced_image = infer(original_image)
503
plot_results(
504
[original_image, ImageOps.autocontrast(original_image), enhanced_image],
505
["Original", "PIL Autocontrast", "MIRNet Enhanced"],
506
(20, 12),
507
)
508
509
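
"""
If you would rather skip training, the pretrained model mentioned above can be pulled
from the Hugging Face Hub. The commented-out snippet below is only a sketch, assuming
the `huggingface_hub` package is installed and that the hosted weights are compatible
with your local Keras/TensorFlow versions; it is not part of the original example.
"""

# from huggingface_hub import from_pretrained_keras
#
# pretrained_model = from_pretrained_keras("keras-io/lowlight-enhance-mirnet")
# pretrained_model.summary()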