"""1Title: Approximating non-Function Mappings with Mixture Density Networks2Author: [lukewood](https://twitter.com/luke_wood_ml)3Date created: 2023/07/154Last modified: 2023/07/155Description: Approximate non one to one mapping using mixture density networks.6Accelerator: None7"""89"""10## Approximating NonFunctions1112Neural networks are universal function approximators. Key word: function!13While powerful function approximators, neural networks are not able to14approximate non-functions.15One important characteristic of functions is that they map one input to a16unique output.17Neural networks do not perform well when the training set has multiple values of18Y for a single X.19Instead of learning the proper distribution, a naive neural network will20interpret the problem as a function and learn the geometric mean of all `Y` in21the training set.2223In this guide I'll show you how to approximate the class of non-functions24consisting of mappings from `x -> y` such that multiple `y` may exist for a25given `x`. We'll use a class of neural networks called26"Mixture Density Networks".2728I'm going to use the new29[multibackend Keras V3](https://github.com/keras-team/keras) to30build my Mixture Density networks.31Great job to the Keras team on the project - it's awesome to be able to swap32frameworks in one line of code.3334Some bad news: I use TensorFlow probability in this guide... so it35actually works only with TensorFlow and JAX backends.3637Anyways, let's start by installing dependencies and sorting out imports:38"""39"""shell40pip install -q --upgrade jax tensorflow-probability[jax] keras41"""4243import os4445os.environ["KERAS_BACKEND"] = "jax"4647import numpy as np48import matplotlib.pyplot as plt49import keras50from keras import callbacks, layers, ops51from tensorflow_probability.substrates.jax import distributions as tfd5253"""54Next, lets generate a noisy spiral that we're going to attempt to approximate.55I've defined a few functions below to do this:56"""575859def normalize(x):60return (x - np.min(x)) / (np.max(x) - np.min(x))616263def create_noisy_spiral(n, jitter_std=0.2, revolutions=2):64angle = np.random.uniform(0, 2 * np.pi * revolutions, [n])65r = angle6667x = r * np.cos(angle)68y = r * np.sin(angle)6970result = np.stack([x, y], axis=1)71result = result + np.random.normal(scale=jitter_std, size=[n, 2])72result = 5 * normalize(result)73return result747576"""77Next, lets invoke this function many times to construct a sample dataset:78"""7980xy = create_noisy_spiral(10000)8182x, y = xy[:, 0:1], xy[:, 1:]8384plt.scatter(x, y)85plt.show()8687"""88As you can see, there's multiple possible values for Y with respect to a given89X.90Normal neural networks will simply learn the mean of these points with91respect to geometric space.92In the context of our spiral, however, the geometric mean of the each Y occurs93with a probability of zero.9495We can quickly show this with a simple linear model:96"""9798N_HIDDEN = 12899100model = keras.Sequential(101[102layers.Dense(N_HIDDEN, activation="relu"),103layers.Dense(N_HIDDEN, activation="relu"),104layers.Dense(1),105]106)107108"""109Let's use mean squared error as well as the adam optimizer.110These tend to be reasonable prototyping choices:111"""112113model.compile(optimizer="adam", loss="mse")114115"""116We can fit this model quite easy117"""118119model.fit(120x,121y,122epochs=300,123batch_size=128,124validation_split=0.15,125callbacks=[callbacks.EarlyStopping(monitor="val_loss", patience=10)],126)127128"""129And let's check out the 
"""
We can quickly demonstrate this failure mode with a simple fully-connected model:
"""

N_HIDDEN = 128

model = keras.Sequential(
    [
        layers.Dense(N_HIDDEN, activation="relu"),
        layers.Dense(N_HIDDEN, activation="relu"),
        layers.Dense(1),
    ]
)

"""
Let's use mean squared error and the Adam optimizer.
These tend to be reasonable prototyping choices:
"""

model.compile(optimizer="adam", loss="mse")

"""
We can fit this model quite easily:
"""

model.fit(
    x,
    y,
    epochs=300,
    batch_size=128,
    validation_split=0.15,
    callbacks=[callbacks.EarlyStopping(monitor="val_loss", patience=10)],
)

"""
And let's check out the result:
"""

y_pred = model.predict(x)

"""
As expected, the model learns the average of all `y` values for a given `x`.
"""

plt.scatter(x, y)
plt.scatter(x, y_pred)
plt.show()

"""
## Mixture Density Networks

Mixture Density Networks can alleviate this problem.
A mixture density is a complicated density expressed in terms of simpler
densities: effectively, a weighted sum of several probability distributions.
With enough components, a mixture density can model arbitrarily complex
distributions.
A Mixture Density Network learns to parameterize a mixture density
based on a given training set.

As a practitioner, all you need to know is that Mixture Density Networks solve
the problem of multiple values of `y` for a given `x`.
I'm hoping to add a tool to your kit, but I'm not going to formally explain the
derivation of Mixture Density Networks in this guide.
The most important thing to know is that a Mixture Density Network learns to
parameterize a mixture distribution.
This is done by computing a special loss with respect to both the provided
`y_i` label and the distribution predicted for the corresponding `x_i`:
the loss is the negative log-likelihood of `y_i` under the predicted mixture,
i.e. `-log(sum_k pi_k(x_i) * N(y_i | mu_k(x_i), sigma_k(x_i)))`.

Let's implement a Mixture Density Network.
The helper functions below are based on an older Keras library,
[`Keras Mixture Density Network Layer`](https://github.com/cpmpercussion/keras-mdn-layer),
which I've adapted for use with multi-backend Keras.

Let's start writing a Mixture Density Network!
First, we need a special activation function: ELU plus one plus a tiny epsilon.
This keeps the predicted standard deviations strictly positive and prevents NaNs
when evaluating the Mixture Density Network loss.
"""


def elu_plus_one_plus_epsilon(x):
    return keras.activations.elu(x) + 1 + keras.backend.epsilon()


"""
Next, let's define a `MixtureDensityOutput` layer that outputs all values needed
to sample from the learned mixture distribution:
"""


class MixtureDensityOutput(layers.Layer):
    def __init__(self, output_dimension, num_mixtures, **kwargs):
        super().__init__(**kwargs)
        self.output_dim = output_dimension
        self.num_mix = num_mixtures
        self.mdn_mus = layers.Dense(
            self.num_mix * self.output_dim, name="mdn_mus"
        )  # num_mix * output_dim means, no activation
        self.mdn_sigmas = layers.Dense(
            self.num_mix * self.output_dim,
            activation=elu_plus_one_plus_epsilon,
            name="mdn_sigmas",
        )  # num_mix * output_dim sigmas, strictly positive activation
        self.mdn_pi = layers.Dense(self.num_mix, name="mdn_pi")  # num_mix mixture logits

    def build(self, input_shape):
        self.mdn_mus.build(input_shape)
        self.mdn_sigmas.build(input_shape)
        self.mdn_pi.build(input_shape)
        super().build(input_shape)

    @property
    def trainable_weights(self):
        return (
            self.mdn_mus.trainable_weights
            + self.mdn_sigmas.trainable_weights
            + self.mdn_pi.trainable_weights
        )

    @property
    def non_trainable_weights(self):
        return (
            self.mdn_mus.non_trainable_weights
            + self.mdn_sigmas.non_trainable_weights
            + self.mdn_pi.non_trainable_weights
        )

    def call(self, x, mask=None):
        return layers.concatenate(
            [self.mdn_mus(x), self.mdn_sigmas(x), self.mdn_pi(x)], name="mdn_outputs"
        )
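"""
As an illustrative sanity check (not required for training), the layer should
emit `2 * num_mixtures * output_dimension + num_mixtures` values per sample:
the means, the sigmas, and the mixture logits.
"""

# With 5 mixture components over a 1-D output we expect 2 * 5 * 1 + 5 = 15 values.
demo_layer = MixtureDensityOutput(output_dimension=1, num_mixtures=5)
demo_outputs = demo_layer(ops.ones((4, 8)))
print(demo_outputs.shape)  # expected: (4, 15)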
"""
Let's construct a Mixture Density Network using our new layer:
"""

OUTPUT_DIMS = 1
N_MIXES = 20

mdn_network = keras.Sequential(
    [
        layers.Dense(N_HIDDEN, activation="relu"),
        layers.Dense(N_HIDDEN, activation="relu"),
        MixtureDensityOutput(OUTPUT_DIMS, N_MIXES),
    ]
)

"""
Next, let's implement a custom loss function to train the Mixture Density
Network layer based on the true values and the predicted mixture parameters:
"""


def get_mixture_loss_func(output_dim, num_mixes):
    def mdn_loss_func(y_true, y_pred):
        # Reshape inputs in case this is used in a TimeDistributed layer
        y_pred = ops.reshape(y_pred, [-1, (2 * num_mixes * output_dim) + num_mixes])
        y_true = ops.reshape(y_true, [-1, output_dim])
        # Split the parameters into means, sigmas, and mixture logits.
        # (An even three-way split is only correct when output_dim == 1.)
        out_mu, out_sigma, out_pi = ops.split(y_pred, 3, axis=-1)
        # Construct the mixture distribution
        cat = tfd.Categorical(logits=out_pi)
        mus = ops.split(out_mu, num_mixes, axis=1)
        sigs = ops.split(out_sigma, num_mixes, axis=1)
        coll = [
            tfd.MultivariateNormalDiag(loc=loc, scale_diag=scale)
            for loc, scale in zip(mus, sigs)
        ]
        mixture = tfd.Mixture(cat=cat, components=coll)
        # Negative log-likelihood of the labels under the predicted mixture
        loss = mixture.log_prob(y_true)
        loss = ops.negative(loss)
        loss = ops.mean(loss)
        return loss

    return mdn_loss_func


mdn_network.compile(loss=get_mixture_loss_func(OUTPUT_DIMS, N_MIXES), optimizer="adam")

"""
Finally, we can train the network with `fit()`, just like any other Keras model.
"""

mdn_network.fit(
    x,
    y,
    epochs=300,
    batch_size=128,
    validation_split=0.15,
    callbacks=[
        callbacks.EarlyStopping(monitor="loss", patience=10, restore_best_weights=True),
        callbacks.ReduceLROnPlateau(monitor="loss", patience=5),
    ],
)

"""
Let's make some predictions!
"""

y_pred_mixture = mdn_network.predict(x)
print(y_pred_mixture.shape)

"""
The MDN does not output a single value; instead it outputs values to
parameterize a mixture distribution.
To visualize these outputs, let's sample from the distribution.

Note that sampling is a lossy process.
If you want to preserve all information as part of a greater latent
representation (e.g. for downstream processing), I recommend you simply keep the
distribution parameters in place.
"""
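"""
For instance, here is a minimal sketch of how you could rebuild the predicted
mixture as a `tfd` distribution (mirroring the loss function above, and therefore
assuming `OUTPUT_DIMS == 1`) and query it directly, e.g. for its closed-form mean,
instead of sampling:
"""


def mixture_from_params(params, num_mixes=N_MIXES):
    # Split the flat parameter vectors exactly like the loss function does.
    out_mu, out_sigma, out_pi = ops.split(params, 3, axis=-1)
    cat = tfd.Categorical(logits=out_pi)
    mus = ops.split(out_mu, num_mixes, axis=1)
    sigs = ops.split(out_sigma, num_mixes, axis=1)
    components = [
        tfd.MultivariateNormalDiag(loc=loc, scale_diag=scale)
        for loc, scale in zip(mus, sigs)
    ]
    return tfd.Mixture(cat=cat, components=components)


predicted_mixture = mixture_from_params(y_pred_mixture)
print(predicted_mixture.mean().shape)  # per-sample mixture mean, no sampling needed

"""
For this guide, though, let's visualize the predictions by sampling.
Below are a few helper functions that draw samples from the predicted mixture
parameters:
"""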
def split_mixture_params(params, output_dim, num_mixes):
    mus = params[: num_mixes * output_dim]
    sigs = params[num_mixes * output_dim : 2 * num_mixes * output_dim]
    pi_logits = params[-num_mixes:]
    return mus, sigs, pi_logits


def softmax(w, t=1.0):
    e = np.array(w) / t  # adjust temperature
    e -= e.max()  # subtract max to protect from exploding exp values.
    e = np.exp(e)
    dist = e / np.sum(e)
    return dist


def sample_from_categorical(dist):
    r = np.random.rand(1)  # uniform random number in [0, 1]
    accumulate = 0
    for i in range(0, dist.size):
        accumulate += dist[i]
        if accumulate >= r:
            return i
    print("Error sampling categorical model.")
    return -1


def sample_from_output(params, output_dim, num_mixes, temp=1.0, sigma_temp=1.0):
    mus, sigs, pi_logits = split_mixture_params(params, output_dim, num_mixes)
    pis = softmax(pi_logits, t=temp)
    m = sample_from_categorical(pis)
    # Alternative way to sample from the categorical:
    # m = np.random.choice(range(len(pis)), p=pis)
    mus_vector = mus[m * output_dim : (m + 1) * output_dim]
    sig_vector = sigs[m * output_dim : (m + 1) * output_dim]
    scale_matrix = np.identity(output_dim) * sig_vector  # scale matrix from diag
    cov_matrix = np.matmul(scale_matrix, scale_matrix.T)  # cov is scale squared
    cov_matrix = cov_matrix * sigma_temp  # adjust for sigma temperature
    sample = np.random.multivariate_normal(mus_vector, cov_matrix, 1)
    return sample


"""
Next, let's use our sampling function:
"""

# Sample from the predicted distributions
y_samples = np.apply_along_axis(
    sample_from_output, 1, y_pred_mixture, 1, N_MIXES, temp=1.0
)

"""
Finally, we can visualize our network's outputs:
"""

plt.scatter(x, y, alpha=0.05, color="blue", label="Ground Truth")
plt.scatter(
    x,
    y_samples[:, :, 0],
    color="green",
    alpha=0.05,
    label="Mixture Density Network prediction",
)
plt.show()

"""
Beautiful. Love to see it.

## Conclusions

Neural networks are universal function approximators, but they can only
approximate functions. Mixture Density Networks can approximate arbitrary
`x -> y` mappings using some neat probability tricks.

For more examples with `tensorflow_probability`,
[start here](https://www.tensorflow.org/probability/examples/Probabilistic_Layers_Regression).

One more pretty graphic for the road:
"""

fig, axs = plt.subplots(1, 3)
fig.set_figheight(3)
fig.set_figwidth(12)
axs[0].set_title("Ground Truth")
axs[0].scatter(x, y, alpha=0.05, color="blue")
xlim = axs[0].get_xlim()
ylim = axs[0].get_ylim()

axs[1].set_title("Normal Model prediction")
axs[1].scatter(x, y_pred, alpha=0.05, color="red")
axs[1].set_xlim(xlim)
axs[1].set_ylim(ylim)
axs[2].scatter(
    x,
    y_samples[:, :, 0],
    color="green",
    alpha=0.05,
    label="Mixture Density Network prediction",
)
axs[2].set_title("Mixture Density Network prediction")
axs[2].set_xlim(xlim)
axs[2].set_ylim(ylim)
plt.show()