PixelCNN
Author: ADMoreau
Date created: 2020/05/17
Last modified: 2020/05/23
Description: PixelCNN implemented in Keras.
Introduction
PixelCNN is a generative model proposed in 2016 by van den Oord et al. (reference: Conditional Image Generation with PixelCNN Decoders). It generates images (or other data types) iteratively from an input vector, so that the probability distribution of each element is conditioned on the elements that precede it. In the following example, images are generated pixel-by-pixel in this fashion: a masked convolution kernel looks only at previously generated pixels (starting from the top left) when predicting each later pixel. During inference, the output of the network is treated as a probability distribution from which new pixel values are sampled to generate a new image (here, with MNIST, each pixel value is either black or white).
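The pixel-by-pixel sampling described above can be sketched in plain NumPy. This is only an illustration under stated assumptions: `sample_images` and `toy_predict_probs` are hypothetical names, and the toy function stands in for a trained network by returning a probability of "white" that depends only on the pixel to the left.

```python
import numpy as np


def sample_images(predict_probs, batch, height, width, rng):
    """Fill a blank canvas pixel by pixel in raster order (top-left first)."""
    pixels = np.zeros((batch, height, width, 1), dtype=np.float32)
    for row in range(height):
        for col in range(width):
            # Re-evaluate the model on the partially filled canvas.
            probs = predict_probs(pixels)[:, row, col, 0]
            # Sample each pixel from a Bernoulli distribution with P(white) = probs.
            pixels[:, row, col, 0] = (rng.random(batch) < probs).astype(np.float32)
    return pixels


def toy_predict_probs(pixels):
    """Hypothetical stand-in for a trained PixelCNN (assumption, not the real model).

    P(white) is 0.8 when the left neighbour is white and 0.2 otherwise,
    so each prediction depends only on an already generated pixel.
    """
    left = np.roll(pixels, shift=1, axis=2)
    left[:, :, 0, :] = 0.0  # the first column has no left neighbour
    return 0.2 + 0.6 * left


rng = np.random.default_rng(0)
samples = sample_images(toy_predict_probs, batch=4, height=28, width=28, rng=rng)
print(samples.shape)
```

With a real model, `predict_probs` would be replaced by a forward pass of the trained network. The loop structure stays the same, which is why sampling costs one forward pass per pixel.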
Getting the Data
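Since the pixels are treated as either black or white, the grayscale images have to be binarized before training. The following is a minimal sketch of that step. The `binarize` helper and the threshold of 33% of 256 are assumptions made for illustration, and random arrays are used in place of the MNIST images so the snippet runs without downloading anything.

```python
import numpy as np


def binarize(images, threshold=0.33 * 256):
    """Map intensities below the threshold to 0 and everything else to 1."""
    return np.where(images < threshold, 0, 1).astype(np.float32)


# Random stand-in for a batch of 28x28 grayscale images (assumption: values in 0..255).
rng = np.random.default_rng(0)
fake_images = rng.integers(0, 256, size=(8, 28, 28, 1))

binary = binarize(fake_images)
print(binary.shape, binary.dtype)
```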
Create two classes for the layers the model requires
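The core of a masked convolution layer is the binary mask that is multiplied into the kernel so that each output position sees only pixels above it or to its left. The sketch below builds such a mask in NumPy. The `build_mask` helper is hypothetical, and the "A"/"B" naming follows the PixelCNN literature: type "A" hides the centre pixel, type "B" keeps it.

```python
import numpy as np


def build_mask(kernel_size, mask_type):
    """Return a (kernel_size, kernel_size) mask of ones over 'already seen' positions."""
    if mask_type not in ("A", "B"):
        raise ValueError("mask_type must be 'A' or 'B'")
    mask = np.zeros((kernel_size, kernel_size), dtype=np.float32)
    center = kernel_size // 2
    mask[:center, :] = 1.0        # every row above the centre row
    mask[center, :center] = 1.0   # pixels to the left of the centre, same row
    if mask_type == "B":
        mask[center, center] = 1.0  # type B may also look at the centre position
    return mask


mask_a = build_mask(3, "A")
mask_b = build_mask(3, "B")
print(mask_a)
print(mask_b)
```

In a layer implementation, a mask like this would be broadcast over the input and output channel dimensions and multiplied into the kernel before each convolution.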
Build the model based on the original paper
Model: "functional_1"
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━┓
┃ Layer (type)                    ┃ Output Shape              ┃    Param # ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━┩
│ input_layer (InputLayer)        │ (128, 28, 28, 1)          │          0 │
├─────────────────────────────────┼───────────────────────────┼────────────┤
│ pixel_conv_layer                │ (128, 28, 28, 128)        │      6,400 │
│ (PixelConvLayer)                │                           │            │
├─────────────────────────────────┼───────────────────────────┼────────────┤
│ residual_block (ResidualBlock)  │ (128, 28, 28, 128)        │     98,624 │
├─────────────────────────────────┼───────────────────────────┼────────────┤
│ residual_block_1                │ (128, 28, 28, 128)        │     98,624 │
│ (ResidualBlock)                 │                           │            │
├─────────────────────────────────┼───────────────────────────┼────────────┤
│ residual_block_2                │ (128, 28, 28, 128)        │     98,624 │
│ (ResidualBlock)                 │                           │            │
├─────────────────────────────────┼───────────────────────────┼────────────┤
│ residual_block_3                │ (128, 28, 28, 128)        │     98,624 │
│ (ResidualBlock)                 │                           │            │
├─────────────────────────────────┼───────────────────────────┼────────────┤
│ residual_block_4                │ (128, 28, 28, 128)        │     98,624 │
│ (ResidualBlock)                 │                           │            │
├─────────────────────────────────┼───────────────────────────┼────────────┤
│ pixel_conv_layer_6              │ (128, 28, 28, 128)        │     16,512 │
│ (PixelConvLayer)                │                           │            │
├─────────────────────────────────┼───────────────────────────┼────────────┤
│ pixel_conv_layer_7              │ (128, 28, 28, 128)        │     16,512 │
│ (PixelConvLayer)                │                           │            │
├─────────────────────────────────┼───────────────────────────┼────────────┤
│ conv2d_18 (Conv2D)              │ (128, 28, 28, 1)          │        129 │
└─────────────────────────────────┴───────────────────────────┴────────────┘
Total params: 532,673 (2.03 MB)
Trainable params: 532,673 (2.03 MB)
Non-trainable params: 0 (0.00 B)
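As a sanity check, the parameter counts in the summary can be reproduced with a little arithmetic. The kernel sizes and channel widths below are assumptions chosen because they reproduce the per-layer counts in the table: a 7x7 first convolution, residual blocks made of a 1x1 convolution (128 to 128), a 3x3 masked convolution (128 to 64) and a 1x1 convolution (64 to 128), and 1x1 convolutions after that. A mask zeroes some weights but does not remove them, so masked layers count the same as ordinary ones.

```python
def conv_params(kernel_size, in_channels, out_channels):
    """Weights plus one bias per output channel for a square 2D convolution."""
    return kernel_size * kernel_size * in_channels * out_channels + out_channels


first_conv = conv_params(7, 1, 128)          # pixel_conv_layer
residual_block = (
    conv_params(1, 128, 128)                 # 1x1 convolution
    + conv_params(3, 128, 64)                # 3x3 masked convolution
    + conv_params(1, 64, 128)                # 1x1 convolution
)
late_conv = conv_params(1, 128, 128)         # pixel_conv_layer_6 and _7
output_conv = conv_params(1, 128, 1)         # conv2d_18

total = first_conv + 5 * residual_block + 2 * late_conv + output_conv
size_mb = total * 4 / 1024**2                # assuming 4 bytes (float32) per parameter

print(first_conv, residual_block, late_conv, output_conv)
print(total, round(size_mb, 2))
```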