# GitHub repository: keras-team/keras-io
# Path: blob/master/examples/vision/autoencoder.py
"""
Title: Convolutional autoencoder for image denoising
Author: [Santiago L. Valdarrama](https://twitter.com/svpino)
Date created: 2021/03/01
Last modified: 2021/03/01
Description: How to train a deep convolutional autoencoder for image denoising.
Accelerator: GPU
"""

"""
## Introduction

This example demonstrates how to implement a deep convolutional autoencoder
for image denoising, mapping noisy digit images from the MNIST dataset to
clean digit images. This implementation is based on an original blog post
titled [Building Autoencoders in Keras](https://blog.keras.io/building-autoencoders-in-keras.html)
by [François Chollet](https://twitter.com/fchollet).
"""

"""
## Setup
"""

import numpy as np
import matplotlib.pyplot as plt

from keras import layers
from keras.datasets import mnist
from keras.models import Model


def preprocess(array):
    """Normalizes the supplied array to [0, 1] and reshapes it to (n, 28, 28, 1)."""
    array = array.astype("float32") / 255.0
    array = np.reshape(array, (len(array), 28, 28, 1))
    return array


def noise(array):
    """Adds Gaussian noise to each image in the supplied array, clipped to [0, 1]."""
    noise_factor = 0.4
    noisy_array = array + noise_factor * np.random.normal(
        loc=0.0, scale=1.0, size=array.shape
    )

    return np.clip(noisy_array, 0.0, 1.0)


def display(array1, array2):
    """Displays the same ten random images from each array (array1 top, array2 bottom)."""
    n = 10
    indices = np.random.randint(len(array1), size=n)
    images1 = array1[indices, :]
    images2 = array2[indices, :]

    plt.figure(figsize=(20, 4))
    for i, (image1, image2) in enumerate(zip(images1, images2)):
        # Top row: image from the first array.
        ax = plt.subplot(2, n, i + 1)
        plt.imshow(image1.reshape(28, 28))
        plt.gray()
        ax.get_xaxis().set_visible(False)
        ax.get_yaxis().set_visible(False)

        # Bottom row: the matching image from the second array.
        ax = plt.subplot(2, n, i + 1 + n)
        plt.imshow(image2.reshape(28, 28))
        plt.gray()
        ax.get_xaxis().set_visible(False)
        ax.get_yaxis().set_visible(False)

    plt.show()


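"""
The `noise()` helper adds zero-mean Gaussian noise scaled by `noise_factor = 0.4`.
It then clips the result back to the valid `[0, 1]` pixel range.

The sketch below is a seeded, self-contained variant of that helper. The name
`add_gaussian_noise` and its `seed` parameter are hypothetical, added only for
illustration. The sketch shows one consequence of the clipping: on a black
(all-zero) background, roughly half of the noisy pixels fall below zero and are
clipped back to exactly 0.

```python
import numpy as np


def add_gaussian_noise(array, noise_factor=0.4, seed=None):
    # Hypothetical seeded variant of the example's noise() helper.
    rng = np.random.default_rng(seed)
    noisy = array + noise_factor * rng.normal(loc=0.0, scale=1.0, size=array.shape)
    # Clip so every pixel stays in the valid [0, 1] range.
    return np.clip(noisy, 0.0, 1.0)


# A synthetic all-black batch of 10 images shaped like the preprocessed data.
black = np.zeros((10, 28, 28, 1), dtype="float32")
noisy = add_gaussian_noise(black, seed=0)

print(noisy.shape)  # (10, 28, 28, 1)
print(noisy.min() >= 0.0, noisy.max() <= 1.0)  # True True
print(np.mean(noisy == 0.0))  # roughly 0.5: negative samples were clipped to 0
```

Passing the same `seed` twice yields identical noisy copies. This can help when
comparing runs, whereas the original helper draws fresh noise from NumPy's
global random state on every call.
"""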
"""
## Prepare the data
"""

# Since we only need images from the dataset to encode and decode, we
# won't use the labels.
(train_data, _), (test_data, _) = mnist.load_data()

# Normalize and reshape the data
train_data = preprocess(train_data)
test_data = preprocess(test_data)

# Create a copy of the data with added noise
noisy_train_data = noise(train_data)
noisy_test_data = noise(test_data)

# Display the train data and a version of it with added noise
display(train_data, noisy_train_data)

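"""
`mnist.load_data()` returns `uint8` arrays with pixel values from 0 to 255. The
training images have shape `(60000, 28, 28)` and the test images `(10000, 28, 28)`.
`preprocess()` rescales them to `float32` values in `[0, 1]` and appends a channel
axis, because `Conv2D` expects inputs shaped `(batch, height, width, channels)`
under Keras's default channels-last layout.

The sketch below checks that transformation on a small synthetic `uint8` batch so
that it runs without downloading MNIST. The random data is an assumption made for
illustration only.

```python
import numpy as np


def preprocess(array):
    # Same logic as the example's helper: scale to [0, 1], add a channel axis.
    array = array.astype("float32") / 255.0
    return np.reshape(array, (len(array), 28, 28, 1))


# Synthetic stand-in for MNIST: 4 random uint8 images of 28x28 pixels.
rng = np.random.default_rng(0)
raw = rng.integers(0, 256, size=(4, 28, 28), dtype=np.uint8)
raw[0, 0, 0] = 255  # guarantee at least one white pixel

processed = preprocess(raw)
print(processed.shape, processed.dtype)  # (4, 28, 28, 1) float32
print(processed.min() >= 0.0, processed.max())  # True 1.0
```
"""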
"""
## Build the autoencoder

We are going to use the Functional API to build our convolutional autoencoder.
The encoder downsamples the 28x28 input to 7x7 feature maps with two
convolution and max-pooling blocks. The decoder upsamples them back to 28x28
with two strided transposed convolutions, and a final sigmoid convolution
produces a single output channel with values in [0, 1].
"""

# Named `inputs` to avoid shadowing Python's built-in `input()`.
inputs = layers.Input(shape=(28, 28, 1))

# Encoder
x = layers.Conv2D(32, (3, 3), activation="relu", padding="same")(inputs)
x = layers.MaxPooling2D((2, 2), padding="same")(x)
x = layers.Conv2D(32, (3, 3), activation="relu", padding="same")(x)
x = layers.MaxPooling2D((2, 2), padding="same")(x)

# Decoder
x = layers.Conv2DTranspose(32, (3, 3), strides=2, activation="relu", padding="same")(x)
x = layers.Conv2DTranspose(32, (3, 3), strides=2, activation="relu", padding="same")(x)
x = layers.Conv2D(1, (3, 3), activation="sigmoid", padding="same")(x)

# Autoencoder
autoencoder = Model(inputs, x)
autoencoder.compile(optimizer="adam", loss="binary_crossentropy")
autoencoder.summary()

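"""
With `padding="same"`, the layers change the feature-map size as follows:

- Each 3x3 `Conv2D` keeps the spatial size.
- Each `MaxPooling2D((2, 2))` halves it, rounding up.
- Each `Conv2DTranspose` with `strides=2` doubles it.

The sketch below traces the feature-map size through the network and counts
parameters by hand. Each layer has kernel height x kernel width x input channels
x output channels weights, plus one bias per output channel. The helper names are
hypothetical, and the total is expected to match the `autoencoder.summary()`
output.

```python
import math


def conv_params(kernel_size, in_channels, out_channels):
    # Weights (kh * kw * in * out) plus one bias per output channel.
    # The same count applies to Conv2D and Conv2DTranspose.
    return kernel_size * kernel_size * in_channels * out_channels + out_channels


def trace_spatial_size(size=28):
    sizes = [size]
    for _ in range(2):  # Encoder: MaxPooling2D((2, 2), padding="same")
        size = math.ceil(size / 2)
        sizes.append(size)
    for _ in range(2):  # Decoder: Conv2DTranspose(strides=2, padding="same")
        size = size * 2
        sizes.append(size)
    return sizes


# (in_channels, out_channels) for every convolutional layer, in order.
layers_channels = [(1, 32), (32, 32), (32, 32), (32, 32), (32, 1)]
per_layer = [conv_params(3, c_in, c_out) for c_in, c_out in layers_channels]

print(trace_spatial_size())  # [28, 14, 7, 14, 28]
print(per_layer)  # [320, 9248, 9248, 9248, 289]
print(sum(per_layer))  # 28353
print(7 * 7 * 32)  # 1568 values in the bottleneck vs. 784 input pixels
```

The last line shows that the 7x7x32 bottleneck holds more values than the 784
input pixels. This encoder therefore reduces spatial resolution rather than the
total number of values.
"""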
"""
Now we can train our autoencoder using `train_data` as both the input data
and the target. Notice that the validation data follows the same format:
`test_data` serves as both input and target.
"""

autoencoder.fit(
    x=train_data,
    y=train_data,
    epochs=50,
    batch_size=128,
    shuffle=True,
    validation_data=(test_data, test_data),
)

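"""
The model is compiled with `loss="binary_crossentropy"`. This loss compares each
sigmoid output with the corresponding target pixel in `[0, 1]` and averages the
result over pixels.

The NumPy sketch below reimplements that per-pixel formula for illustration. The
function name and the `eps` clipping constant are assumptions, not Keras
internals. The sketch also shows why the reported loss does not drop to zero:
MNIST targets include intermediate gray values, and for such a target the loss is
minimized, but not zero, when the prediction equals the target.

```python
import numpy as np


def binary_crossentropy(y_true, y_pred, eps=1e-7):
    # Clip predictions away from 0 and 1 so the logarithms stay finite.
    y_pred = np.clip(y_pred, eps, 1.0 - eps)
    per_pixel = -(y_true * np.log(y_pred) + (1.0 - y_true) * np.log(1.0 - y_pred))
    return per_pixel.mean()


target = np.array([0.0, 0.5, 1.0])

# A perfect reconstruction still has a nonzero loss because of the 0.5 pixel.
print(binary_crossentropy(target, target))  # about 0.231 (= ln(2) / 3)

# Any other prediction for the 0.5 pixel scores worse.
print(binary_crossentropy(target, np.array([0.0, 0.4, 1.0])))
```
"""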
"""
Let's predict on our test dataset and display the original images (top row)
together with the autoencoder's predictions (bottom row).

Notice how the predictions are close to the original images, although not
identical.
"""

predictions = autoencoder.predict(test_data)
display(test_data, predictions)

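"""
To quantify how close the predictions are to the originals, one option is to
compute the mean squared error between each test image and its reconstruction.

The sketch below implements this as a hypothetical helper, `per_image_mse`, which
is not part of the example, and demonstrates it on small synthetic arrays. In the
script it could be called as `per_image_mse(test_data, predictions)`.

```python
import numpy as np


def per_image_mse(originals, reconstructions):
    # Average the squared pixel error over height, width and channels,
    # leaving one score per image.
    squared_error = (originals - reconstructions) ** 2
    return squared_error.reshape(len(originals), -1).mean(axis=1)


# Synthetic stand-ins shaped like the preprocessed MNIST data.
originals = np.zeros((3, 28, 28, 1), dtype="float32")
reconstructions = originals.copy()
reconstructions[1] += 0.1  # second image is off by 0.1 everywhere
reconstructions[2, :14] += 0.2  # third image is off by 0.2 on its top half

print(per_image_mse(originals, reconstructions))  # approximately [0, 0.01, 0.02]
```

Sorting the scores can help find the test digits the autoencoder reconstructs
worst.
"""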
"""
Now that we know that our autoencoder works, let's continue training it, this
time using the noisy data as our input and the clean data as our target. Because
we call `fit()` on the same model again, training starts from the weights it has
already learned. We want our autoencoder to learn how to denoise the images.
"""

autoencoder.fit(
    x=noisy_train_data,
    y=train_data,
    epochs=100,
    batch_size=128,
    shuffle=True,
    validation_data=(noisy_test_data, test_data),
)

"""
Let's now predict on the noisy data and display the results of our autoencoder.

Notice how well the autoencoder removes the noise: the top row shows the noisy
inputs and the bottom row shows the denoised outputs.
"""

predictions = autoencoder.predict(noisy_test_data)
display(noisy_test_data, predictions)
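"""
One common way to quantify the denoising is peak signal-to-noise ratio (PSNR),
measured against the clean image. For pixels in `[0, 1]` it is
`10 * log10(1 / MSE)`, and higher values are better. PSNR is an added metric here,
not part of this example.

The sketch below applies a hypothetical `psnr` helper to synthetic arrays and shows
that halving the per-pixel error gains about 6 dB. In the script, comparing
`psnr(test_data, noisy_test_data)` with `psnr(test_data, predictions)` would
indicate how much the autoencoder improved on its input.

```python
import numpy as np


def psnr(clean, candidate, max_value=1.0):
    # Peak signal-to-noise ratio in decibels, computed from the mean squared
    # error over the whole batch.
    mse = np.mean((clean - candidate) ** 2)
    return 10.0 * np.log10(max_value**2 / mse)


clean = np.full((2, 28, 28, 1), 0.5)
noisy = clean + 0.2  # every pixel off by 0.2
denoised = clean + 0.1  # every pixel off by 0.1

print(round(psnr(clean, noisy), 2))  # 13.98
print(round(psnr(clean, denoised), 2))  # 20.0
```
"""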