Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
keras-team
GitHub Repository: keras-team/keras-io
Path: blob/master/examples/generative/pixelcnn.py
3507 views
1
"""
2
Title: PixelCNN
3
Author: [ADMoreau](https://github.com/ADMoreau)
4
Date created: 2020/05/17
5
Last modified: 2020/05/23
6
Description: PixelCNN implemented in Keras.
7
Accelerator: GPU
8
"""
9
10
"""
11
## Introduction
12
13
PixelCNN is a generative model proposed in 2016 by van den Oord et al.
14
(reference: [Conditional Image Generation with PixelCNN Decoders](https://arxiv.org/abs/1606.05328)).
15
It is designed to generate images (or other data types) iteratively
16
from an input vector where the probability distribution of prior elements dictates the
17
probability distribution of later elements. In the following example, images are generated
18
in this fashion, pixel-by-pixel, via a masked convolution kernel that only looks at data
19
from previously generated pixels (origin at the top left) to generate later pixels.
20
During inference, the output of the network is used as a probability distribution
21
from which new pixel values are sampled to generate a new image
22
(here, with MNIST, the pixel values are either black or white).
23
"""
24
25
import numpy as np
26
import keras
27
from keras import layers
28
from keras import ops
29
from tqdm import tqdm
30
31
"""
32
## Getting the Data
33
"""
34
35
# Model / data parameters
num_classes = 10
input_shape = (28, 28, 1)
n_residual_blocks = 5

# Load MNIST; only the images are kept, the labels are discarded.
(x, _), (y, _) = keras.datasets.mnist.load_data()

# Merge the train and test images into one array along the sample axis.
data = np.concatenate([x, y], axis=0)

# Binarize the images: pixels below 33% of 256 become 0.0 and every other
# pixel becomes 1.0, stored as float32.
data = (data >= 0.33 * 256).astype(np.float32)
48
49
"""
50
## Create two classes for the requisite Layers for the model
51
"""
52
53
54
# The first layer is the PixelCNN layer. This layer simply
55
# builds on the 2D convolutional layer, but includes masking.
56
class PixelConvLayer(layers.Layer):
    """2D convolution whose kernel is multiplied by a binary mask before use.

    Args:
        mask_type: When equal to "B" the centre position of the kernel is
            kept; for any other value (the model uses "A") the centre
            position stays masked out.
        **kwargs: Forwarded unchanged to the wrapped `layers.Conv2D`.
    """

    def __init__(self, mask_type, **kwargs):
        super().__init__()
        self.mask_type = mask_type
        self.conv = layers.Conv2D(**kwargs)

    def build(self, input_shape):
        """Build the wrapped convolution and derive the mask from its kernel."""
        # The convolution has to be built first so that its kernel exists.
        self.conv.build(input_shape)
        kernel_shape = ops.shape(self.conv.kernel)
        center_row = kernel_shape[0] // 2
        center_col = kernel_shape[1] // 2

        mask = np.zeros(shape=kernel_shape)
        # Keep everything before the centre index along the first kernel axis.
        mask[:center_row, ...] = 1.0
        # At the centre index of the first axis, keep only the entries that
        # come before the centre index along the second axis.
        mask[center_row, :center_col, ...] = 1.0
        # Mask type "B" additionally keeps the centre position itself.
        if self.mask_type == "B":
            mask[center_row, center_col, ...] = 1.0
        self.mask = mask

    def call(self, inputs):
        """Zero the masked kernel entries in place, then run the convolution."""
        masked_kernel = self.conv.kernel * self.mask
        self.conv.kernel.assign(masked_kernel)
        return self.conv(inputs)
76
77
78
# Next, we build our residual block layer.
79
# This is just a normal residual block, but based on the PixelConvLayer.
80
class ResidualBlock(keras.layers.Layer):
    """Residual block: 1x1 conv -> masked 3x3 conv (type "B") -> 1x1 conv.

    The output of the three-layer stack is added to the block's input.

    Args:
        filters: Number of filters of the two 1x1 convolutions; the masked
            convolution in between uses `filters // 2`.
        **kwargs: Forwarded to the base `keras.layers.Layer` constructor.
    """

    def __init__(self, filters, **kwargs):
        super().__init__(**kwargs)
        self.conv1 = keras.layers.Conv2D(
            filters=filters,
            kernel_size=1,
            activation="relu",
        )
        self.pixel_conv = PixelConvLayer(
            mask_type="B",
            filters=filters // 2,
            kernel_size=3,
            activation="relu",
            padding="same",
        )
        self.conv2 = keras.layers.Conv2D(
            filters=filters,
            kernel_size=1,
            activation="relu",
        )

    def call(self, inputs):
        """Apply the three convolutions and add the result to `inputs`."""
        residual = self.conv2(self.pixel_conv(self.conv1(inputs)))
        return keras.layers.add([inputs, residual])
102
103
104
"""
105
## Build the model based on the original paper
106
"""
107
108
# Model input with a fixed batch size of 128.
inputs = keras.Input(batch_size=128, shape=input_shape)

# The first masked convolution uses mask type "A".
x = PixelConvLayer(
    mask_type="A",
    filters=128,
    kernel_size=7,
    padding="same",
    activation="relu",
)(inputs)

# Stack of residual blocks built on type "B" masked convolutions.
for _ in range(n_residual_blocks):
    x = ResidualBlock(filters=128)(x)

# Two further 1x1 masked convolutions of type "B".
for _ in range(2):
    x = PixelConvLayer(
        mask_type="B",
        filters=128,
        kernel_size=1,
        strides=1,
        padding="valid",
        activation="relu",
    )(x)

# A final 1x1 convolution with a sigmoid produces one value per pixel.
out = keras.layers.Conv2D(
    filters=1,
    kernel_size=1,
    strides=1,
    padding="valid",
    activation="sigmoid",
)(x)

pixel_cnn = keras.Model(inputs, out)
adam = keras.optimizers.Adam(learning_rate=0.0005)
pixel_cnn.compile(loss="binary_crossentropy", optimizer=adam)

pixel_cnn.summary()

# The binarized images serve as both the input and the target.
pixel_cnn.fit(
    x=data,
    y=data,
    batch_size=128,
    epochs=50,
    validation_split=0.1,
    verbose=2,
)
138
139
"""
140
## Demonstration
141
142
The PixelCNN cannot generate the full image at once. Instead, it must generate each pixel in
143
order, append the last generated pixel to the current image, and feed the image back into the
144
model to repeat the process.
145
"""
146
147
from IPython.display import Image, display
148
149
# Start from an all-zero canvas holding a small batch of images.
batch = 4
pixels = np.zeros(shape=(batch, *pixel_cnn.input_shape[1:]))
batch, rows, cols, channels = pixels.shape

# Generation is sequential: every position is sampled in turn, and the
# partially filled canvas is fed back into the model for the next position.
for row in tqdm(range(rows)):
    for col in range(cols):
        for channel in range(channels):
            # Run the model on the whole canvas and keep only the
            # probabilities predicted for the current position.
            predictions = pixel_cnn.predict(pixels, verbose=0)
            probs = predictions[:, row, col, channel]
            # Sample a binary value per image: the ceiling of
            # (probability - random noise) is written back into the canvas.
            noise = keras.random.uniform(probs.shape)
            pixels[:, row, col, channel] = ops.ceil(probs - noise)
166
167
168
def deprocess_image(x):
    """Convert a single-channel image into a uint8 RGB image.

    The input is replicated three times along a new axis at position 2,
    scaled by 255, clipped to [0, 255] and cast to uint8.

    Args:
        x: Array-like image; the generation code passes a 2D array with
            values 0 and 1.

    Returns:
        A uint8 array with a size-3 axis inserted at position 2.
    """
    # Replicate the grayscale values into three identical channels.
    rgb = np.stack([x] * 3, axis=2)
    # Undo the preprocessing that scaled the pixel values down.
    rgb *= 255.0
    # Clip to the valid range before the cast to uint8.
    return np.clip(rgb, 0, 255).astype("uint8")
176
177
178
# Save every generated image to disk as a PNG file and remember its path.
# Each image has its trailing channel axis squeezed away and is converted
# to a uint8 RGB array by `deprocess_image` before being written.
image_paths = []
for i, pic in enumerate(pixels):
    path = "generated_image_{}.png".format(i)
    keras.utils.save_img(path, deprocess_image(np.squeeze(pic, -1)))
    image_paths.append(path)

# Display exactly the files that were written above, in generation order,
# so this stays correct if the number of generated images changes.
for path in image_paths:
    display(Image(path))
188
189