Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
keras-team
GitHub Repository: keras-team/keras-io
Path: blob/master/examples/generative/vae.py
3507 views
1
"""
2
Title: Variational AutoEncoder
3
Author: [fchollet](https://twitter.com/fchollet)
4
Date created: 2020/05/03
5
Last modified: 2024/04/24
6
Description: Convolutional Variational AutoEncoder (VAE) trained on MNIST digits.
7
Accelerator: GPU
8
"""
9
10
"""
11
## Setup
12
"""
13
14
import os
15
16
os.environ["KERAS_BACKEND"] = "tensorflow"
17
18
import numpy as np
19
import tensorflow as tf
20
import keras
21
from keras import ops
22
from keras import layers
23
24
"""
25
## Create a sampling layer
26
"""
27
28
29
class Sampling(layers.Layer):
    """Uses (z_mean, z_log_var) to sample z, the vector encoding a digit.

    Implements the reparameterization trick:
    z = z_mean + exp(0.5 * z_log_var) * epsilon, with epsilon ~ N(0, I),
    so gradients flow through z_mean and z_log_var rather than the noise.

    Args:
        seed: integer seed for the layer's `keras.random.SeedGenerator`.
            Defaults to 1337, the value this layer previously hard-coded.
        **kwargs: forwarded to `keras.layers.Layer`.
    """

    def __init__(self, seed=1337, **kwargs):
        super().__init__(**kwargs)
        self.seed_generator = keras.random.SeedGenerator(seed)

    def call(self, inputs):
        z_mean, z_log_var = inputs
        # Draw noise with exactly z_mean's shape. For (batch, dim) latents this
        # is identical to the old (batch, dim) shape, and it also works for
        # latents of any other rank.
        epsilon = keras.random.normal(
            shape=ops.shape(z_mean), seed=self.seed_generator
        )
        return z_mean + ops.exp(0.5 * z_log_var) * epsilon
42
43
44
"""
45
## Build the encoder
46
"""
47
48
latent_dim = 2

# Two stride-2 "same" convolutions take each 28x28x1 image down to 7x7x64;
# a small dense bottleneck then produces the latent Gaussian parameters.
encoder_inputs = keras.Input(shape=(28, 28, 1))
x = encoder_inputs
for filters in (32, 64):
    x = layers.Conv2D(
        filters, 3, activation="relu", strides=2, padding="same"
    )(x)
x = layers.Flatten()(x)
x = layers.Dense(16, activation="relu")(x)

# Mean and log-variance heads, then a reparameterized sample z.
z_mean = layers.Dense(latent_dim, name="z_mean")(x)
z_log_var = layers.Dense(latent_dim, name="z_log_var")(x)
z = Sampling()([z_mean, z_log_var])

encoder = keras.Model(encoder_inputs, [z_mean, z_log_var, z], name="encoder")
encoder.summary()
60
61
"""
62
## Build the decoder
63
"""
64
65
# Mirror of the encoder: project the latent vector onto a 7x7x64 feature map,
# upsample twice with stride-2 transposed convolutions (7 -> 14 -> 28), then
# map to a single sigmoid channel.
latent_inputs = keras.Input(shape=(latent_dim,))
x = layers.Dense(7 * 7 * 64, activation="relu")(latent_inputs)
x = layers.Reshape((7, 7, 64))(x)
for filters in (64, 32):
    x = layers.Conv2DTranspose(
        filters, 3, activation="relu", strides=2, padding="same"
    )(x)
decoder_outputs = layers.Conv2DTranspose(
    1, 3, activation="sigmoid", padding="same"
)(x)

decoder = keras.Model(latent_inputs, decoder_outputs, name="decoder")
decoder.summary()
73
74
"""
75
## Define the VAE as a `Model` with a custom `train_step`
76
"""
77
78
79
class VAE(keras.Model):
    """Variational autoencoder combining an encoder and a decoder.

    The loss is the sum of two terms, each averaged over the batch:
    - per-sample summed binary cross-entropy between inputs and
      reconstructions;
    - the closed-form KL divergence from the encoder's Gaussian to a standard
      normal prior.

    Args:
        encoder: model mapping an input batch to `[z_mean, z_log_var, z]`.
        decoder: model mapping latent vectors `z` back to input space.
        **kwargs: forwarded to `keras.Model`.
    """

    def __init__(self, encoder, decoder, **kwargs):
        super().__init__(**kwargs)
        self.encoder = encoder
        self.decoder = decoder
        self.total_loss_tracker = keras.metrics.Mean(name="total_loss")
        self.reconstruction_loss_tracker = keras.metrics.Mean(
            name="reconstruction_loss"
        )
        self.kl_loss_tracker = keras.metrics.Mean(name="kl_loss")

    @property
    def metrics(self):
        # Keras resets the metrics listed here at the start of each epoch and
        # each evaluation run, so the trackers report per-epoch means.
        return [
            self.total_loss_tracker,
            self.reconstruction_loss_tracker,
            self.kl_loss_tracker,
        ]

    def call(self, inputs):
        """Encode `inputs`, take the sampled `z` and return its decoding.

        Without a `call`, `vae(x)` and `vae.predict(x)` have nothing to run.
        """
        _, _, z = self.encoder(inputs)
        return self.decoder(z)

    @staticmethod
    def _vae_inputs(data):
        # `fit(x, y)` / `evaluate(x, y)` pass the step an `(x, y, ...)` tuple;
        # an autoencoder only needs `x`.
        return data[0] if isinstance(data, tuple) else data

    def _vae_losses(self, data):
        """Return `(total_loss, reconstruction_loss, kl_loss)` for one batch."""
        z_mean, z_log_var, z = self.encoder(data)
        reconstruction = self.decoder(z)
        # binary_crossentropy averages over the last axis; summing over axes
        # 1 and 2 then yields one value per sample (assumes (batch, H, W, C)
        # images, as prepared for training in this file).
        reconstruction_loss = ops.mean(
            ops.sum(
                keras.losses.binary_crossentropy(data, reconstruction),
                axis=(1, 2),
            )
        )
        # KL(N(z_mean, exp(z_log_var)) || N(0, I)), summed over latent dims.
        kl_loss = -0.5 * (1 + z_log_var - ops.square(z_mean) - ops.exp(z_log_var))
        kl_loss = ops.mean(ops.sum(kl_loss, axis=1))
        return reconstruction_loss + kl_loss, reconstruction_loss, kl_loss

    def _record_vae_losses(self, total_loss, reconstruction_loss, kl_loss):
        """Update the running-mean trackers and return their current values."""
        self.total_loss_tracker.update_state(total_loss)
        self.reconstruction_loss_tracker.update_state(reconstruction_loss)
        self.kl_loss_tracker.update_state(kl_loss)
        return {
            "loss": self.total_loss_tracker.result(),
            "reconstruction_loss": self.reconstruction_loss_tracker.result(),
            "kl_loss": self.kl_loss_tracker.result(),
        }

    def train_step(self, data):
        data = self._vae_inputs(data)
        with tf.GradientTape() as tape:
            total_loss, reconstruction_loss, kl_loss = self._vae_losses(data)
        grads = tape.gradient(total_loss, self.trainable_weights)
        self.optimizer.apply_gradients(zip(grads, self.trainable_weights))
        return self._record_vae_losses(total_loss, reconstruction_loss, kl_loss)

    def test_step(self, data):
        """Compute the training losses without updating weights.

        Used by `evaluate()` and by `fit(..., validation_data=...)`.
        """
        data = self._vae_inputs(data)
        return self._record_vae_losses(*self._vae_losses(data))
121
122
123
"""
124
## Train the VAE
125
"""
126
127
# The VAE is trained without labels, so both MNIST splits are pooled together.
(x_train, _), (x_test, _) = keras.datasets.mnist.load_data()
mnist_digits = np.concatenate((x_train, x_test))
# Append a trailing channel axis and scale pixel intensities by 1/255.
mnist_digits = mnist_digits[..., np.newaxis].astype("float32") / 255

vae = VAE(encoder, decoder)
vae.compile(optimizer=keras.optimizers.Adam())
vae.fit(mnist_digits, batch_size=128, epochs=30)
134
135
"""
136
## Display a grid of sampled digits
137
"""
138
139
import matplotlib.pyplot as plt
140
141
142
def plot_latent_space(vae, n=30, figsize=15, scale=1.0):
    """Plot an n x n grid of digits decoded from evenly spaced 2D latent points.

    Both latent axes span [-scale, scale]. z[1] decreases from the top row to
    the bottom row, so the picture reads like a standard x/y plot.

    Args:
        vae: model exposing a `decoder` that maps (batch, 2) latent vectors to
            square single-channel images.
        n: number of grid cells along each axis.
        figsize: width and height of the matplotlib figure, in inches.
        scale: half-width of the latent range covered by the grid. Defaults to
            1.0, the value previously hard-coded here.
    """
    # Linearly spaced coordinates corresponding to the 2D plot of digit
    # classes in the latent space.
    grid_x = np.linspace(-scale, scale, n)
    grid_y = np.linspace(-scale, scale, n)[::-1]

    # Decode the whole grid in ONE predict() call rather than n*n calls of a
    # single sample each. With the default "xy" meshgrid indexing and C-order
    # ravel, sample i * n + j is the point (grid_x[j], grid_y[i]).
    mesh_x, mesh_y = np.meshgrid(grid_x, grid_y)
    z_samples = np.stack([mesh_x.ravel(), mesh_y.ravel()], axis=-1)
    x_decoded = vae.decoder.predict(z_samples, verbose=0)

    # Take the tile size from the decoder output instead of assuming 28.
    digit_size = x_decoded.shape[1]
    # (n*n, s, s, 1) -> (n, n, s, s) -> (n, s, n, s) -> (n*s, n*s): tile (i, j)
    # lands in rows i*s:(i+1)*s and columns j*s:(j+1)*s of the figure.
    figure = (
        x_decoded.reshape(n, n, digit_size, digit_size)
        .transpose(0, 2, 1, 3)
        .reshape(n * digit_size, n * digit_size)
    )

    plt.figure(figsize=(figsize, figsize))
    # Put one tick at the centre of each tile, labelled with its latent value.
    start_range = digit_size // 2
    end_range = n * digit_size + start_range
    pixel_range = np.arange(start_range, end_range, digit_size)
    plt.xticks(pixel_range, np.round(grid_x, 1))
    plt.yticks(pixel_range, np.round(grid_y, 1))
    plt.xlabel("z[0]")
    plt.ylabel("z[1]")
    plt.imshow(figure, cmap="Greys_r")
    plt.show()
174
175
176
# Decode a 30 x 30 grid of latent points and show it at 15 x 15 inches.
plot_latent_space(vae, n=30, figsize=15)
177
178
"""
179
## Display how the latent space clusters different digit classes
180
"""
181
182
183
def plot_label_clusters(vae, data, labels):
    """Scatter the encoder's `z_mean` for `data`, coloured by `labels`.

    Only latent dimensions 0 and 1 are drawn; with a 2D latent space this is
    the whole embedding.
    """
    means, _log_vars, _samples = vae.encoder.predict(data, verbose=0)
    first_dim = means[:, 0]
    second_dim = means[:, 1]

    plt.figure(figsize=(12, 10))
    plt.scatter(first_dim, second_dim, c=labels)
    plt.colorbar()
    plt.xlabel("z[0]")
    plt.ylabel("z[1]")
    plt.show()
192
193
194
# Embed the labelled training split and colour each point by its digit class.
(x_train, y_train), _ = keras.datasets.mnist.load_data()
x_train = x_train[..., np.newaxis].astype("float32") / 255

plot_label_clusters(vae, x_train, y_train)
198
199