"""
2
Title: Density estimation using Real NVP
3
Authors: [Mandolini Giorgio Maria](https://www.linkedin.com/in/giorgio-maria-mandolini-a2a1b71b4/), [Sanna Daniele](https://www.linkedin.com/in/daniele-sanna-338629bb/), [Zannini Quirini Giorgio](https://www.linkedin.com/in/giorgio-zannini-quirini-16ab181a0/)
4
Date created: 2020/08/10
5
Last modified: 2020/08/10
6
Description: Estimating the density distribution of the "double moon" dataset.
7
Accelerator: GPU
8
"""

"""
## Introduction

The aim of this work is to map a simple distribution - one which is easy to
sample from and whose density is easy to estimate - to a more complex
distribution learned from the data. This kind of generative model is also
known as a "normalizing flow".

To achieve this, the model is trained via the maximum likelihood principle,
using the "change of variables" formula.

We will use an affine coupling function, constructed so that its inverse and
the determinant of its Jacobian are both easy to compute (see the referenced
paper for more details).
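
Concretely, for a bijection $f$ that maps a data point $x$ to a latent code
$z = f(x)$ with latent prior $p_Z$, the change of variables formula gives the
exact log-likelihood that training maximizes (this is what `log_loss` computes
below):

$$\log p_X(x) = \log p_Z(f(x)) + \log \left| \det \frac{\partial f(x)}{\partial x} \right|$$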

**Requirements:**

* TensorFlow 2.9.1
* TensorFlow Probability 0.17.0

**Reference:**

[Density estimation using Real NVP](https://arxiv.org/abs/1605.08803)
"""


"""
## Setup
"""

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
from tensorflow.keras import regularizers
from sklearn.datasets import make_moons
import numpy as np
import matplotlib.pyplot as plt
import tensorflow_probability as tfp

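# Version check (illustrative addition): the example targets TensorFlow 2.9.1
# and TensorFlow Probability 0.17.0, so surfacing the installed versions makes
# mismatches easy to spot.
print("TensorFlow:", tf.__version__)
print("TensorFlow Probability:", tfp.__version__)
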
"""
47
## Load the data
48
"""
49
50
data = make_moons(3000, noise=0.05)[0].astype("float32")
51
norm = layers.Normalization()
52
norm.adapt(data)
53
normalized_data = norm(data)
54
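# Optional sanity check (an addition, not part of the original example):
# visualize the normalized point cloud before training.
plt.figure(figsize=(6, 4))
plt.scatter(normalized_data[:, 0], normalized_data[:, 1], s=5)
plt.title("Normalized training data")
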
"""
## Affine coupling layer
"""

# Build the scale (s) and shift (t) networks of each coupling layer with the
# Keras functional API.
output_dim = 256
reg = 0.01


def Coupling(input_shape):
    inputs = keras.layers.Input(shape=input_shape)

    # Shift network t(.): unconstrained output, hence a linear final activation.
    t_layer_1 = keras.layers.Dense(
        output_dim, activation="relu", kernel_regularizer=regularizers.l2(reg)
    )(inputs)
    t_layer_2 = keras.layers.Dense(
        output_dim, activation="relu", kernel_regularizer=regularizers.l2(reg)
    )(t_layer_1)
    t_layer_3 = keras.layers.Dense(
        output_dim, activation="relu", kernel_regularizer=regularizers.l2(reg)
    )(t_layer_2)
    t_layer_4 = keras.layers.Dense(
        output_dim, activation="relu", kernel_regularizer=regularizers.l2(reg)
    )(t_layer_3)
    t_layer_5 = keras.layers.Dense(
        input_shape, activation="linear", kernel_regularizer=regularizers.l2(reg)
    )(t_layer_4)

    # Log-scale network s(.): tanh keeps the scaling numerically stable.
    s_layer_1 = keras.layers.Dense(
        output_dim, activation="relu", kernel_regularizer=regularizers.l2(reg)
    )(inputs)
    s_layer_2 = keras.layers.Dense(
        output_dim, activation="relu", kernel_regularizer=regularizers.l2(reg)
    )(s_layer_1)
    s_layer_3 = keras.layers.Dense(
        output_dim, activation="relu", kernel_regularizer=regularizers.l2(reg)
    )(s_layer_2)
    s_layer_4 = keras.layers.Dense(
        output_dim, activation="relu", kernel_regularizer=regularizers.l2(reg)
    )(s_layer_3)
    s_layer_5 = keras.layers.Dense(
        input_shape, activation="tanh", kernel_regularizer=regularizers.l2(reg)
    )(s_layer_4)

    return keras.Model(inputs=inputs, outputs=[s_layer_5, t_layer_5])


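# Minimal shape check (illustrative addition): a coupling network maps a batch
# of 2-D points to a log-scale `s` and a shift `t`, both of shape (batch, 2).
_s, _t = Coupling(2)(tf.zeros((1, 2)))
assert _s.shape == (1, 2) and _t.shape == (1, 2)
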
"""
103
## Real NVP
104
"""
105
106
107
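# Each coupling layer applies the affine transform from the Real NVP paper,
# with binary mask b (a sketch of the relations implemented in `call` below):
#   forward:  y = b * x + (1 - b) * (x * exp(s(b * x)) + t(b * x))
#   inverse:  x = b * y + (1 - b) * (y - t(b * y)) * exp(-s(b * y))
# The `direction` and `gate` variables fold both directions into one loop.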
class RealNVP(keras.Model):
    def __init__(self, num_coupling_layers):
        super().__init__()

        self.num_coupling_layers = num_coupling_layers

        # Distribution of the latent space.
        self.distribution = tfp.distributions.MultivariateNormalDiag(
            loc=[0.0, 0.0], scale_diag=[1.0, 1.0]
        )
        # Alternating binary masks, so that consecutive layers transform
        # complementary halves of the input.
        self.masks = np.array(
            [[0, 1], [1, 0]] * (num_coupling_layers // 2), dtype="float32"
        )
        self.loss_tracker = keras.metrics.Mean(name="loss")
        self.layers_list = [Coupling(2) for _ in range(num_coupling_layers)]

    @property
    def metrics(self):
        """List of the model's metrics.

        We make sure the loss tracker is listed as part of `model.metrics`
        so that `fit()` and `evaluate()` are able to reset the loss tracker
        at the start of each epoch and at the start of an `evaluate()` call.
        """
        return [self.loss_tracker]

    def call(self, x, training=True):
        log_det_inv = 0
        # direction = -1 runs the coupling layers in reverse with inverted
        # transforms (data -> latent, used during training); direction = 1
        # runs the generative direction (latent -> data).
        direction = 1
        if training:
            direction = -1
        for i in range(self.num_coupling_layers)[::direction]:
            x_masked = x * self.masks[i]
            reversed_mask = 1 - self.masks[i]
            s, t = self.layers_list[i](x_masked)
            s *= reversed_mask
            t *= reversed_mask
            # gate is 0 in the forward direction and -1 in the inverse one,
            # turning the affine update below into its exact inverse.
            gate = (direction - 1) / 2
            x = (
                reversed_mask
                * (x * tf.exp(direction * s) + direction * t * tf.exp(gate * s))
                + x_masked
            )
            # Accumulate the log-determinant of the Jacobian (non-zero only
            # in the inverse direction).
            log_det_inv += gate * tf.reduce_sum(s, [1])

        return x, log_det_inv

    # Log likelihood of the normal distribution plus the log determinant of
    # the Jacobian.

    def log_loss(self, x):
        y, logdet = self(x)
        log_likelihood = self.distribution.log_prob(y) + logdet
        return -tf.reduce_mean(log_likelihood)

    def train_step(self, data):
        with tf.GradientTape() as tape:
            loss = self.log_loss(data)

        g = tape.gradient(loss, self.trainable_variables)
        self.optimizer.apply_gradients(zip(g, self.trainable_variables))
        self.loss_tracker.update_state(loss)

        return {"loss": self.loss_tracker.result()}

    def test_step(self, data):
        loss = self.log_loss(data)
        self.loss_tracker.update_state(loss)

        return {"loss": self.loss_tracker.result()}


"""
179
## Model training
180
"""
181
182
model = RealNVP(num_coupling_layers=6)
183
184
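# Sanity check (illustrative addition, not in the original example): running
# the inverse pass followed by the forward pass should reproduce the input up
# to float32 rounding error, since the coupling layers are exact bijections.
_batch = tf.random.normal((4, 2))
_z, _ = model(_batch, training=True)  # data -> latent
_x, _ = model(_z, training=False)  # latent -> data
assert np.allclose(_batch.numpy(), _x.numpy(), atol=1e-3)
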
model.compile(optimizer=keras.optimizers.Adam(learning_rate=0.0001))

history = model.fit(
    normalized_data, batch_size=256, epochs=300, verbose=2, validation_split=0.2
)

"""
191
## Performance evaluation
192
"""
193
194
plt.figure(figsize=(15, 10))
plt.plot(history.history["loss"])
plt.plot(history.history["val_loss"])
plt.title("model loss")
plt.legend(["train", "validation"], loc="upper right")
plt.ylabel("loss")
plt.xlabel("epoch")

# From data to latent space (the default `training=True` runs the inverse
# direction).
z, _ = model(normalized_data)

# From latent space to data.
samples = model.distribution.sample(3000)
x, _ = model.predict(samples)

f, axes = plt.subplots(2, 2)
f.set_size_inches(20, 15)

axes[0, 0].scatter(normalized_data[:, 0], normalized_data[:, 1], color="r")
axes[0, 0].set(title="Inference data space X", xlabel="x", ylabel="y")
axes[0, 1].scatter(z[:, 0], z[:, 1], color="r")
axes[0, 1].set(title="Inference latent space Z", xlabel="x", ylabel="y")
axes[0, 1].set_xlim([-3.5, 4])
axes[0, 1].set_ylim([-4, 4])
axes[1, 0].scatter(samples[:, 0], samples[:, 1], color="g")
axes[1, 0].set(title="Generated latent space Z", xlabel="x", ylabel="y")
axes[1, 1].scatter(x[:, 0], x[:, 1], color="g")
axes[1, 1].set(title="Generated data space X", xlabel="x", ylabel="y")
axes[1, 1].set_xlim([-2, 2])
axes[1, 1].set_ylim([-2, 2])
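
# When running this example as a standalone script (an assumption; on keras.io
# it is rendered as a notebook), uncomment the line below to display the plots.
# plt.show()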