"""
Title: Convolutional autoencoder for image denoising
Author: [Santiago L. Valdarrama](https://twitter.com/svpino)
Date created: 2021/03/01
Last modified: 2021/03/01
Description: How to train a deep convolutional autoencoder for image denoising.
Accelerator: GPU
"""
"""
## Introduction
This example demonstrates how to implement a deep convolutional autoencoder
for image denoising, mapping noisy digit images from the MNIST dataset to
clean digit images. This implementation is based on an original blog post
titled [Building Autoencoders in Keras](https://blog.keras.io/building-autoencoders-in-keras.html)
by [François Chollet](https://twitter.com/fchollet).
"""
"""
## Setup
"""
import numpy as np
import matplotlib.pyplot as plt
from keras import layers
from keras.datasets import mnist
from keras.models import Model
def preprocess(array):
"""Normalizes the supplied array and reshapes it."""
array = array.astype("float32") / 255.0
array = np.reshape(array, (len(array), 28, 28, 1))
return array
def noise(array):
"""Adds random noise to each image in the supplied array."""
noise_factor = 0.4
noisy_array = array + noise_factor * np.random.normal(
loc=0.0, scale=1.0, size=array.shape
)
return np.clip(noisy_array, 0.0, 1.0)
def display(array1, array2):
"""Displays ten random images from each array."""
n = 10
indices = np.random.randint(len(array1), size=n)
images1 = array1[indices, :]
images2 = array2[indices, :]
plt.figure(figsize=(20, 4))
for i, (image1, image2) in enumerate(zip(images1, images2)):
ax = plt.subplot(2, n, i + 1)
plt.imshow(image1.reshape(28, 28))
plt.gray()
ax.get_xaxis().set_visible(False)
ax.get_yaxis().set_visible(False)
ax = plt.subplot(2, n, i + 1 + n)
plt.imshow(image2.reshape(28, 28))
plt.gray()
ax.get_xaxis().set_visible(False)
ax.get_yaxis().set_visible(False)
plt.show()
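"""
Before preparing the real dataset, let's quickly sanity-check the helpers on a small
random array. The `dummy` array below is only an illustration and is not part of the
original example: `preprocess` should return values in `[0, 1]` with a trailing
channel axis, and `noise` should preserve the shape while clipping values to `[0, 1]`.
"""

dummy = np.random.randint(0, 256, size=(4, 28, 28))  # four fake grayscale images
dummy = preprocess(dummy)
noisy_dummy = noise(dummy)

assert dummy.shape == noisy_dummy.shape == (4, 28, 28, 1)
assert 0.0 <= dummy.min() and dummy.max() <= 1.0
assert 0.0 <= noisy_dummy.min() and noisy_dummy.max() <= 1.0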
"""
## Prepare the data
"""
(train_data, _), (test_data, _) = mnist.load_data()
train_data = preprocess(train_data)
test_data = preprocess(test_data)
noisy_train_data = noise(train_data)
noisy_test_data = noise(test_data)
display(train_data, noisy_train_data)
"""
## Build the autoencoder
We are going to use the Functional API to build our convolutional autoencoder.
"""
input = layers.Input(shape=(28, 28, 1))
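# Encoder: two Conv2D + MaxPooling2D blocks downsample 28x28 -> 7x7.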
x = layers.Conv2D(32, (3, 3), activation="relu", padding="same")(input)
x = layers.MaxPooling2D((2, 2), padding="same")(x)
x = layers.Conv2D(32, (3, 3), activation="relu", padding="same")(x)
x = layers.MaxPooling2D((2, 2), padding="same")(x)
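# Decoder: two stride-2 Conv2DTranspose layers upsample 7x7 -> 28x28,
# followed by a sigmoid Conv2D that maps back to a single channel.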
x = layers.Conv2DTranspose(32, (3, 3), strides=2, activation="relu", padding="same")(x)
x = layers.Conv2DTranspose(32, (3, 3), strides=2, activation="relu", padding="same")(x)
x = layers.Conv2D(1, (3, 3), activation="sigmoid", padding="same")(x)
autoencoder = Model(input, x)
autoencoder.compile(optimizer="adam", loss="binary_crossentropy")
autoencoder.summary()
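"""
Because each `MaxPooling2D` layer halves the spatial dimensions (28 -> 14 -> 7) and
each stride-2 `Conv2DTranspose` layer doubles them again (7 -> 14 -> 28), the output
has the same shape as the input. The final sigmoid activation keeps the predicted
pixel values in `[0, 1]`, which matches the binary cross-entropy loss. A minimal
check, not part of the original example:
"""

assert autoencoder.output_shape == (None, 28, 28, 1)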
"""
Now we can train our autoencoder using `train_data` as both our input data
and target. Notice that the validation data uses the same format: `test_data`
serves as both the input and the target.
"""
autoencoder.fit(
x=train_data,
y=train_data,
epochs=50,
batch_size=128,
shuffle=True,
validation_data=(test_data, test_data),
)
"""
Let's predict on our test dataset and display the original images together with
the predictions from our autoencoder.
Notice how the predictions are pretty close to the original images, although
not quite the same.
"""
predictions = autoencoder.predict(test_data)
display(test_data, predictions)
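"""
To put a number on "pretty close," we can also compute the mean absolute
reconstruction error over the test set. This metric is an extra illustration,
not part of the original example.
"""

reconstruction_error = np.mean(np.abs(predictions - test_data))
print(f"Mean absolute reconstruction error: {reconstruction_error:.4f}")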
"""
Now that we know that our autoencoder works, let's retrain it using the noisy
data as our input and the clean data as our target. We want our autoencoder to
learn how to denoise the images.
"""
autoencoder.fit(
x=noisy_train_data,
y=train_data,
epochs=100,
batch_size=128,
shuffle=True,
validation_data=(noisy_test_data, test_data),
)
"""
Let's now predict on the noisy data and display the results of our autoencoder.
Notice how the autoencoder does an impressive job of removing the noise from the
input images.
"""
predictions = autoencoder.predict(noisy_test_data)
display(noisy_test_data, predictions)
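"""
As a rough quantitative check (again an extra illustration, not part of the original
example), we can compare how far the noisy inputs are from the clean images with how
far the denoised predictions are. The denoised outputs should be much closer to the
clean digits than the noisy inputs.
"""

noisy_error = np.mean(np.abs(noisy_test_data - test_data))
denoised_error = np.mean(np.abs(predictions - test_data))
print(f"Mean absolute error, noisy vs. clean:    {noisy_error:.4f}")
print(f"Mean absolute error, denoised vs. clean: {denoised_error:.4f}")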