Path: blob/master/examples/keras_recipes/bayesian_neural_networks.py
"""1Title: Probabilistic Bayesian Neural Networks2Author: [Khalid Salama](https://www.linkedin.com/in/khalid-salama-24403144/)3Date created: 2021/01/154Last modified: 2021/01/155Description: Building probabilistic Bayesian neural network models with TensorFlow Probability.6Accelerator: GPU7"""89"""10## Introduction1112Taking a probabilistic approach to deep learning allows to account for *uncertainty*,13so that models can assign less levels of confidence to incorrect predictions.14Sources of uncertainty can be found in the data, due to measurement error or15noise in the labels, or the model, due to insufficient data availability for16the model to learn effectively.171819This example demonstrates how to build basic probabilistic Bayesian neural networks20to account for these two types of uncertainty.21We use [TensorFlow Probability](https://www.tensorflow.org/probability) library,22which is compatible with Keras API.2324This example requires TensorFlow 2.3 or higher.25You can install Tensorflow Probability using the following command:2627```python28pip install tensorflow-probability29```30"""3132"""33## The dataset3435We use the [Wine Quality](https://archive.ics.uci.edu/ml/datasets/wine+quality)36dataset, which is available in the [TensorFlow Datasets](https://www.tensorflow.org/datasets/catalog/wine_quality).37We use the red wine subset, which contains 4,898 examples.38The dataset has 11numerical physicochemical features of the wine, and the task39is to predict the wine quality, which is a score between 0 and 10.40In this example, we treat this as a regression task.4142You can install TensorFlow Datasets using the following command:4344```python45pip install tensorflow-datasets46```47"""4849"""50## Setup51"""5253import numpy as np54import tensorflow as tf55from tensorflow import keras56from tensorflow.keras import layers57import tensorflow_datasets as tfds58import tensorflow_probability as tfp5960"""61## Create training and evaluation datasets6263Here, we load the `wine_quality` dataset using `tfds.load()`, and we convert64the target feature to float. Then, we shuffle the dataset and split it into65training and test sets. 


def get_train_and_test_splits(train_size, batch_size=1):
    # We prefetch with a buffer the same size as the dataset because the dataset
    # is very small and fits into memory.
    dataset = (
        tfds.load(name="wine_quality", as_supervised=True, split="train")
        .map(lambda x, y: (x, tf.cast(y, tf.float32)))
        .prefetch(buffer_size=dataset_size)
        .cache()
    )
    # We shuffle with a buffer the same size as the dataset.
    train_dataset = (
        dataset.take(train_size).shuffle(buffer_size=train_size).batch(batch_size)
    )
    test_dataset = dataset.skip(train_size).batch(batch_size)

    return train_dataset, test_dataset


"""
## Compile, train, and evaluate the model
"""

hidden_units = [8, 8]
learning_rate = 0.001


def run_experiment(model, loss, train_dataset, test_dataset):
    model.compile(
        optimizer=keras.optimizers.RMSprop(learning_rate=learning_rate),
        loss=loss,
        metrics=[keras.metrics.RootMeanSquaredError()],
    )

    print("Start training the model...")
    model.fit(train_dataset, epochs=num_epochs, validation_data=test_dataset)
    print("Model training finished.")
    _, rmse = model.evaluate(train_dataset, verbose=0)
    print(f"Train RMSE: {round(rmse, 3)}")

    print("Evaluating model performance...")
    _, rmse = model.evaluate(test_dataset, verbose=0)
    print(f"Test RMSE: {round(rmse, 3)}")


"""
## Create model inputs
"""

FEATURE_NAMES = [
    "fixed acidity",
    "volatile acidity",
    "citric acid",
    "residual sugar",
    "chlorides",
    "free sulfur dioxide",
    "total sulfur dioxide",
    "density",
    "pH",
    "sulphates",
    "alcohol",
]


def create_model_inputs():
    inputs = {}
    for feature_name in FEATURE_NAMES:
        inputs[feature_name] = layers.Input(
            name=feature_name, shape=(1,), dtype=tf.float32
        )
    return inputs


"""
## Experiment 1: standard neural network

We create a standard deterministic neural network model as a baseline.
"""


def create_baseline_model():
    inputs = create_model_inputs()
    input_values = [value for _, value in sorted(inputs.items())]
    features = keras.layers.concatenate(input_values)
    features = layers.BatchNormalization()(features)

    # Create hidden layers with deterministic weights using the Dense layer.
    for units in hidden_units:
        features = layers.Dense(units, activation="sigmoid")(features)
    # The output is deterministic: a single point estimate.
    outputs = layers.Dense(units=1)(features)

    model = keras.Model(inputs=inputs, outputs=outputs)
    return model
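
"""
Before training, we can take a quick look at the baseline architecture.
This summary is only a convenience check and is not part of the original recipe;
the instance actually used for training is created again below.
"""

create_baseline_model().summary()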

"""
Let's split the wine dataset into training and test sets, with 85% and 15% of
the examples, respectively.
"""

dataset_size = 4898
batch_size = 256
train_size = int(dataset_size * 0.85)
train_dataset, test_dataset = get_train_and_test_splits(train_size, batch_size)

"""
Now let's train the baseline model. We use the `MeanSquaredError`
as the loss function.
"""

num_epochs = 100
mse_loss = keras.losses.MeanSquaredError()
baseline_model = create_baseline_model()
run_experiment(baseline_model, mse_loss, train_dataset, test_dataset)

"""
We take a sample from the test set and use the model to obtain predictions for them.
Note that since the baseline model is deterministic, we get a single
*point estimate* prediction for each test example, with no information about the
uncertainty of either the model or the prediction.
"""

sample = 10
examples, targets = list(test_dataset.unbatch().shuffle(batch_size * 10).batch(sample))[
    0
]

predicted = baseline_model(examples).numpy()
for idx in range(sample):
    print(f"Predicted: {round(float(predicted[idx][0]), 1)} - Actual: {targets[idx]}")

"""
## Experiment 2: Bayesian neural network (BNN)

The objective of the Bayesian approach for modeling neural networks is to capture
the *epistemic uncertainty*, which is uncertainty about the model fitness,
due to limited training data.

The idea is that, instead of learning specific weight (and bias) *values* in the
neural network, the Bayesian approach learns weight *distributions*
- from which we can sample to produce an output for a given input -
to encode weight uncertainty.

Thus, we need to define the prior and the posterior distributions of these weights,
and the training process learns the parameters of these distributions.
"""


# Define the prior weight distribution as a Normal of mean=0 and stddev=1.
# Note that, in this example, the prior distribution is not trainable,
# as we fix its parameters.
def prior(kernel_size, bias_size, dtype=None):
    n = kernel_size + bias_size
    prior_model = keras.Sequential(
        [
            tfp.layers.DistributionLambda(
                lambda t: tfp.distributions.MultivariateNormalDiag(
                    loc=tf.zeros(n), scale_diag=tf.ones(n)
                )
            )
        ]
    )
    return prior_model


# Define the variational posterior weight distribution as a multivariate Gaussian.
# Note that the learnable parameters for this distribution are the means,
# variances, and covariances.
def posterior(kernel_size, bias_size, dtype=None):
    n = kernel_size + bias_size
    posterior_model = keras.Sequential(
        [
            tfp.layers.VariableLayer(
                tfp.layers.MultivariateNormalTriL.params_size(n), dtype=dtype
            ),
            tfp.layers.MultivariateNormalTriL(n),
        ]
    )
    return posterior_model


"""
We use the `tfp.layers.DenseVariational` layer instead of the standard
`keras.layers.Dense` layer in the neural network model.
The `kl_weight` argument scales the KL divergence penalty between the posterior
and the prior; setting it to `1 / train_size` spreads that penalty across the
training examples, so it stays on the same scale as the per-example loss.
"""


def create_bnn_model(train_size):
    inputs = create_model_inputs()
    features = keras.layers.concatenate(list(inputs.values()))
    features = layers.BatchNormalization()(features)

    # Create hidden layers with weight uncertainty using the DenseVariational layer.
    for units in hidden_units:
        features = tfp.layers.DenseVariational(
            units=units,
            make_prior_fn=prior,
            make_posterior_fn=posterior,
            kl_weight=1 / train_size,
            activation="sigmoid",
        )(features)

    # The output is deterministic: a single point estimate.
    outputs = layers.Dense(units=1)(features)
    model = keras.Model(inputs=inputs, outputs=outputs)
    return model
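
"""
To get a feel for what the `posterior` factory produces, we can build one for a
toy layer shape and draw a sample from it. This is only an illustrative sketch,
not part of the original recipe: for a kernel of 6 weights plus 2 biases, n = 8,
so each draw is a vector of 8 weight values. The `DenseVariational` layers above
perform this sampling internally on every forward pass.
"""

toy_posterior = posterior(kernel_size=6, bias_size=2)
# The VariableLayer ignores its input, so any dummy tensor works here.
toy_weight_distribution = toy_posterior(tf.zeros([1]))
print("Sampled weights:", toy_weight_distribution.sample().numpy())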
That is, the more data the BNN model sees, the more it is certain281about its estimates for the weights (distribution parameters).282Let's test this behaviour by training the BNN model on a small subset of283the training set, and then on the full training set, to compare the output variances.284"""285286"""287### Train BNN with a small training subset.288"""289290num_epochs = 500291train_sample_size = int(train_size * 0.3)292small_train_dataset = train_dataset.unbatch().take(train_sample_size).batch(batch_size)293294bnn_model_small = create_bnn_model(train_sample_size)295run_experiment(bnn_model_small, mse_loss, small_train_dataset, test_dataset)296297"""298Since we have trained a BNN model, the model produces a different output each time299we call it with the same input, since each time a new set of weights are sampled300from the distributions to construct the network and produce an output.301The less certain the mode weights are, the more variability (wider range) we will302see in the outputs of the same inputs.303"""304305306def compute_predictions(model, iterations=100):307predicted = []308for _ in range(iterations):309predicted.append(model(examples).numpy())310predicted = np.concatenate(predicted, axis=1)311312prediction_mean = np.mean(predicted, axis=1).tolist()313prediction_min = np.min(predicted, axis=1).tolist()314prediction_max = np.max(predicted, axis=1).tolist()315prediction_range = (np.max(predicted, axis=1) - np.min(predicted, axis=1)).tolist()316317for idx in range(sample):318print(319f"Predictions mean: {round(prediction_mean[idx], 2)}, "320f"min: {round(prediction_min[idx], 2)}, "321f"max: {round(prediction_max[idx], 2)}, "322f"range: {round(prediction_range[idx], 2)} - "323f"Actual: {targets[idx]}"324)325326327compute_predictions(bnn_model_small)328329"""330### Train BNN with the whole training set.331"""332333num_epochs = 500334bnn_model_full = create_bnn_model(train_size)335run_experiment(bnn_model_full, mse_loss, train_dataset, test_dataset)336337compute_predictions(bnn_model_full)338339"""340Notice that the model trained with the full training dataset shows smaller range341(uncertainty) in the prediction values for the same inputs, compared to the model342trained with a subset of the training dataset.343"""344345"""346## Experiment 3: probabilistic Bayesian neural network347348So far, the output of the standard and the Bayesian NN models that we built is349deterministic, that is, produces a point estimate as a prediction for a given example.350We can create a probabilistic NN by letting the model output a distribution.351In this case, the model captures the *aleatoric uncertainty* as well,352which is due to irreducible noise in the data, or to the stochastic nature of the353process generating the data.354355In this example, we model the output as a `IndependentNormal` distribution,356with learnable mean and variance parameters. 
If the task was classification,357we would have used `IndependentBernoulli` with binary classes, and `OneHotCategorical`358with multiple classes, to model distribution of the model output.359"""360361362def create_probablistic_bnn_model(train_size):363inputs = create_model_inputs()364features = keras.layers.concatenate(list(inputs.values()))365features = layers.BatchNormalization()(features)366367# Create hidden layers with weight uncertainty using the DenseVariational layer.368for units in hidden_units:369features = tfp.layers.DenseVariational(370units=units,371make_prior_fn=prior,372make_posterior_fn=posterior,373kl_weight=1 / train_size,374activation="sigmoid",375)(features)376377# Create a probabilisticå output (Normal distribution), and use the `Dense` layer378# to produce the parameters of the distribution.379# We set units=2 to learn both the mean and the variance of the Normal distribution.380distribution_params = layers.Dense(units=2)(features)381outputs = tfp.layers.IndependentNormal(1)(distribution_params)382383model = keras.Model(inputs=inputs, outputs=outputs)384return model385386387"""388Since the output of the model is a distribution, rather than a point estimate,389we use the [negative loglikelihood](https://en.wikipedia.org/wiki/Likelihood_function)390as our loss function to compute how likely to see the true data (targets) from the391estimated distribution produced by the model.392"""393394395def negative_loglikelihood(targets, estimated_distribution):396return -estimated_distribution.log_prob(targets)397398399num_epochs = 1000400prob_bnn_model = create_probablistic_bnn_model(train_size)401run_experiment(prob_bnn_model, negative_loglikelihood, train_dataset, test_dataset)402403"""404Now let's produce an output from the model given the test examples.405The output is now a distribution, and we can use its mean and variance406to compute the confidence intervals (CI) of the prediction.407"""408409prediction_distribution = prob_bnn_model(examples)410prediction_mean = prediction_distribution.mean().numpy().tolist()411prediction_stdv = prediction_distribution.stddev().numpy()412413# The 95% CI is computed as mean ± (1.96 * stdv)414upper = (prediction_mean + (1.96 * prediction_stdv)).tolist()415lower = (prediction_mean - (1.96 * prediction_stdv)).tolist()416prediction_stdv = prediction_stdv.tolist()417418for idx in range(sample):419print(420f"Prediction mean: {round(prediction_mean[idx][0], 2)}, "421f"stddev: {round(prediction_stdv[idx][0], 2)}, "422f"95% CI: [{round(upper[idx][0], 2)} - {round(lower[idx][0], 2)}]"423f" - Actual: {targets[idx]}"424)425426427