Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
keras-team
GitHub Repository: keras-team/keras-io
Path: blob/master/examples/keras_recipes/approximating_non_function_mappings.py
3507 views
1
"""
2
Title: Approximating non-Function Mappings with Mixture Density Networks
3
Author: [lukewood](https://twitter.com/luke_wood_ml)
4
Date created: 2023/07/15
5
Last modified: 2023/07/15
6
Description: Approximate non one to one mapping using mixture density networks.
7
Accelerator: None
8
"""
9
10
"""
11
## Approximating NonFunctions
12
13
Neural networks are universal function approximators. Key word: function!
14
While powerful function approximators, neural networks are not able to
15
approximate non-functions.
16
One important characteristic of functions is that they map one input to a
17
unique output.
18
Neural networks do not perform well when the training set has multiple values of
19
Y for a single X.
20
Instead of learning the proper distribution, a naive neural network will
21
interpret the problem as a function and learn the geometric mean of all `Y` in
22
the training set.
23
24
In this guide I'll show you how to approximate the class of non-functions
25
consisting of mappings from `x -> y` such that multiple `y` may exist for a
26
given `x`. We'll use a class of neural networks called
27
"Mixture Density Networks".
28
29
I'm going to use the new
30
[multibackend Keras V3](https://github.com/keras-team/keras) to
31
build my Mixture Density networks.
32
Great job to the Keras team on the project - it's awesome to be able to swap
33
frameworks in one line of code.
34
35
Some bad news: I use TensorFlow probability in this guide... so it
36
actually works only with TensorFlow and JAX backends.
37
38
Anyways, let's start by installing dependencies and sorting out imports:
39
"""
40
"""shell
41
pip install -q --upgrade jax tensorflow-probability[jax] keras
42
"""
43
44
import os
45
46
os.environ["KERAS_BACKEND"] = "jax"
47
48
import numpy as np
49
import matplotlib.pyplot as plt
50
import keras
51
from keras import callbacks, layers, ops
52
from tensorflow_probability.substrates.jax import distributions as tfd
53
54
"""
55
Next, lets generate a noisy spiral that we're going to attempt to approximate.
56
I've defined a few functions below to do this:
57
"""
58
59
60
def normalize(x):
61
return (x - np.min(x)) / (np.max(x) - np.min(x))
62
63
64
def create_noisy_spiral(n, jitter_std=0.2, revolutions=2):
65
angle = np.random.uniform(0, 2 * np.pi * revolutions, [n])
66
r = angle
67
68
x = r * np.cos(angle)
69
y = r * np.sin(angle)
70
71
result = np.stack([x, y], axis=1)
72
result = result + np.random.normal(scale=jitter_std, size=[n, 2])
73
result = 5 * normalize(result)
74
return result
75
76
77
"""
78
Next, lets invoke this function many times to construct a sample dataset:
79
"""
80
81
xy = create_noisy_spiral(10000)
82
83
x, y = xy[:, 0:1], xy[:, 1:]
84
85
plt.scatter(x, y)
86
plt.show()
87
88
"""
89
As you can see, there's multiple possible values for Y with respect to a given
90
X.
91
Normal neural networks will simply learn the mean of these points with
92
respect to geometric space.
93
In the context of our spiral, however, the geometric mean of the each Y occurs
94
with a probability of zero.
95
96
We can quickly show this with a simple linear model:
97
"""
98
99
N_HIDDEN = 128
100
101
model = keras.Sequential(
102
[
103
layers.Dense(N_HIDDEN, activation="relu"),
104
layers.Dense(N_HIDDEN, activation="relu"),
105
layers.Dense(1),
106
]
107
)
108
109
"""
110
Let's use mean squared error as well as the adam optimizer.
111
These tend to be reasonable prototyping choices:
112
"""
113
114
model.compile(optimizer="adam", loss="mse")
115
116
"""
117
We can fit this model quite easy
118
"""
119
120
model.fit(
121
x,
122
y,
123
epochs=300,
124
batch_size=128,
125
validation_split=0.15,
126
callbacks=[callbacks.EarlyStopping(monitor="val_loss", patience=10)],
127
)
128
129
"""
130
And let's check out the result:
131
"""
132
133
y_pred = model.predict(x)
134
135
"""
136
As expected, the model learns the geometric mean of all points in `y` for a
137
given `x`.
138
"""
139
140
plt.scatter(x, y)
141
plt.scatter(x, y_pred)
142
plt.show()
143
144
"""
145
146
## Mixture Density Networks
147
148
Mixture Density networks can alleviate this problem.
149
A mixture density is a class of complicated densities expressible in terms of simpler densities.
150
Effectively, a mixture density is the sum of various probability distributions.
151
By summing various distributions, mixture densitry distributions can
152
model arbitrarily complex distributions.
153
Mixture Density networks learn to parameterize a mixture density distribution
154
based on a given training set.
155
156
As a practitioner, all you need to know, is that Mixture Density Networks solve
157
the problem of multiple values of Y for a given X.
158
I'm hoping to add a tool to your kit- but I'm not going to formally explain the
159
derivation of Mixture Density networks in this guide.
160
The most important thing to know is that a Mixture Density network learns to
161
parameterize a mixture density distribution.
162
This is done by computing a special loss with respect to both the provided
163
`y_i` label as well as the predicted distribution for the corresponding `x_i`.
164
This loss function operates by computing the probability that `y_i` would be
165
drawn from the predicted mixture distribution.
166
167
Let's implement a Mixture density network.
168
Below, a ton of helper functions are defined based on an old Keras library
169
[`Keras Mixture Density Network Layer`](https://github.com/cpmpercussion/keras-mdn-layer).
170
171
I've adapted the code for use with Keras core.
172
173
Lets start writing a Mixture Density Network!
174
First, we need a special activation function: ELU plus a tiny epsilon.
175
This helps prevent ELU from outputting 0 which causes NaNs in Mixture Density
176
Network loss evaluation.
177
"""
178
179
180
def elu_plus_one_plus_epsilon(x):
181
return keras.activations.elu(x) + 1 + keras.backend.epsilon()
182
183
184
"""
185
Next, lets actually define a MixtureDensity layer that outputs all values needed
186
to sample from the learned mixture distribution:
187
"""
188
189
190
class MixtureDensityOutput(layers.Layer):
191
def __init__(self, output_dimension, num_mixtures, **kwargs):
192
super().__init__(**kwargs)
193
self.output_dim = output_dimension
194
self.num_mix = num_mixtures
195
self.mdn_mus = layers.Dense(
196
self.num_mix * self.output_dim, name="mdn_mus"
197
) # mix*output vals, no activation
198
self.mdn_sigmas = layers.Dense(
199
self.num_mix * self.output_dim,
200
activation=elu_plus_one_plus_epsilon,
201
name="mdn_sigmas",
202
) # mix*output vals exp activation
203
self.mdn_pi = layers.Dense(self.num_mix, name="mdn_pi") # mix vals, logits
204
205
def build(self, input_shape):
206
self.mdn_mus.build(input_shape)
207
self.mdn_sigmas.build(input_shape)
208
self.mdn_pi.build(input_shape)
209
super().build(input_shape)
210
211
@property
212
def trainable_weights(self):
213
return (
214
self.mdn_mus.trainable_weights
215
+ self.mdn_sigmas.trainable_weights
216
+ self.mdn_pi.trainable_weights
217
)
218
219
@property
220
def non_trainable_weights(self):
221
return (
222
self.mdn_mus.non_trainable_weights
223
+ self.mdn_sigmas.non_trainable_weights
224
+ self.mdn_pi.non_trainable_weights
225
)
226
227
def call(self, x, mask=None):
228
return layers.concatenate(
229
[self.mdn_mus(x), self.mdn_sigmas(x), self.mdn_pi(x)], name="mdn_outputs"
230
)
231
232
233
"""
234
Lets construct an Mixture Density Network using our new layer:
235
"""
236
237
OUTPUT_DIMS = 1
238
N_MIXES = 20
239
240
mdn_network = keras.Sequential(
241
[
242
layers.Dense(N_HIDDEN, activation="relu"),
243
layers.Dense(N_HIDDEN, activation="relu"),
244
MixtureDensityOutput(OUTPUT_DIMS, N_MIXES),
245
]
246
)
247
248
"""
249
Next, let's implement a custom loss function to train the Mixture Density
250
Network layer based on the true values and our expected outputs:
251
"""
252
253
254
def get_mixture_loss_func(output_dim, num_mixes):
255
def mdn_loss_func(y_true, y_pred):
256
# Reshape inputs in case this is used in a TimeDistributed layer
257
y_pred = ops.reshape(y_pred, [-1, (2 * num_mixes * output_dim) + num_mixes])
258
y_true = ops.reshape(y_true, [-1, output_dim])
259
# Split the inputs into parameters
260
out_mu, out_sigma, out_pi = ops.split(y_pred, 3, axis=-1)
261
# Construct the mixture models
262
cat = tfd.Categorical(logits=out_pi)
263
mus = ops.split(out_mu, num_mixes, axis=1)
264
sigs = ops.split(out_sigma, num_mixes, axis=1)
265
coll = [
266
tfd.MultivariateNormalDiag(loc=loc, scale_diag=scale)
267
for loc, scale in zip(mus, sigs)
268
]
269
mixture = tfd.Mixture(cat=cat, components=coll)
270
loss = mixture.log_prob(y_true)
271
loss = ops.negative(loss)
272
loss = ops.mean(loss)
273
return loss
274
275
return mdn_loss_func
276
277
278
mdn_network.compile(loss=get_mixture_loss_func(OUTPUT_DIMS, N_MIXES), optimizer="adam")
279
280
"""
281
Finally, we can call `model.fit()` like any other Keras model.
282
"""
283
284
mdn_network.fit(
285
x,
286
y,
287
epochs=300,
288
batch_size=128,
289
validation_split=0.15,
290
callbacks=[
291
callbacks.EarlyStopping(monitor="loss", patience=10, restore_best_weights=True),
292
callbacks.ReduceLROnPlateau(monitor="loss", patience=5),
293
],
294
)
295
296
"""
297
Let's make some predictions!
298
"""
299
300
y_pred_mixture = mdn_network.predict(x)
301
print(y_pred_mixture.shape)
302
303
"""
304
The MDN does not output a single value; instead it outputs values to
305
parameterize a mixture distribution.
306
To visualize these outputs, lets sample from the distribution.
307
308
Note that sampling is a lossy process.
309
If you want to preserve all information as part of a greater latent
310
representation (i.e. for downstream processing) I recommend you simply keep the
311
distribution parameters in place.
312
"""
313
314
315
def split_mixture_params(params, output_dim, num_mixes):
316
mus = params[: num_mixes * output_dim]
317
sigs = params[num_mixes * output_dim : 2 * num_mixes * output_dim]
318
pi_logits = params[-num_mixes:]
319
return mus, sigs, pi_logits
320
321
322
def softmax(w, t=1.0):
323
e = np.array(w) / t # adjust temperature
324
e -= e.max() # subtract max to protect from exploding exp values.
325
e = np.exp(e)
326
dist = e / np.sum(e)
327
return dist
328
329
330
def sample_from_categorical(dist):
331
r = np.random.rand(1) # uniform random number in [0,1]
332
accumulate = 0
333
for i in range(0, dist.size):
334
accumulate += dist[i]
335
if accumulate >= r:
336
return i
337
print("Error sampling categorical model.")
338
return -1
339
340
341
def sample_from_output(params, output_dim, num_mixes, temp=1.0, sigma_temp=1.0):
342
mus, sigs, pi_logits = split_mixture_params(params, output_dim, num_mixes)
343
pis = softmax(pi_logits, t=temp)
344
m = sample_from_categorical(pis)
345
# Alternative way to sample from categorical:
346
# m = np.random.choice(range(len(pis)), p=pis)
347
mus_vector = mus[m * output_dim : (m + 1) * output_dim]
348
sig_vector = sigs[m * output_dim : (m + 1) * output_dim]
349
scale_matrix = np.identity(output_dim) * sig_vector # scale matrix from diag
350
cov_matrix = np.matmul(scale_matrix, scale_matrix.T) # cov is scale squared.
351
cov_matrix = cov_matrix * sigma_temp # adjust for sigma temperature
352
sample = np.random.multivariate_normal(mus_vector, cov_matrix, 1)
353
return sample
354
355
356
"""
357
Next lets use our sampling function:
358
"""
359
360
# Sample from the predicted distributions
361
y_samples = np.apply_along_axis(
362
sample_from_output, 1, y_pred_mixture, 1, N_MIXES, temp=1.0
363
)
364
365
"""
366
Finally, we can visualize our network outputs
367
"""
368
369
plt.scatter(x, y, alpha=0.05, color="blue", label="Ground Truth")
370
plt.scatter(
371
x,
372
y_samples[:, :, 0],
373
color="green",
374
alpha=0.05,
375
label="Mixture Density Network prediction",
376
)
377
plt.show()
378
379
"""
380
Beautiful. Love to see it
381
382
# Conclusions
383
384
Neural Networks are universal function approximators - but they can only
385
approximate functions. Mixture Density networks can approximate arbitrary
386
x->y mappings using some neat probability tricks.
387
388
For more examples with `tensorflow_probability`
389
[start here](https://www.tensorflow.org/probability/examples/Probabilistic_Layers_Regression).
390
391
One more pretty graphic for the road:
392
"""
393
394
fig, axs = plt.subplots(1, 3)
395
fig.set_figheight(3)
396
fig.set_figwidth(12)
397
axs[0].set_title("Ground Truth")
398
axs[0].scatter(x, y, alpha=0.05, color="blue")
399
xlim = axs[0].get_xlim()
400
ylim = axs[0].get_ylim()
401
402
axs[1].set_title("Normal Model prediction")
403
axs[1].scatter(x, y_pred, alpha=0.05, color="red")
404
axs[1].set_xlim(xlim)
405
axs[1].set_ylim(ylim)
406
axs[2].scatter(
407
x,
408
y_samples[:, :, 0],
409
color="green",
410
alpha=0.05,
411
label="Mixture Density Network prediction",
412
)
413
axs[2].set_title("Mixture Density Network prediction")
414
axs[2].set_xlim(xlim)
415
axs[2].set_ylim(ylim)
416
plt.show()
417
418