Path: blob/master/examples/structured_data/deep_neural_decision_forests.py
3507 views
"""1Title: Classification with Neural Decision Forests2Author: [Khalid Salama](https://www.linkedin.com/in/khalid-salama-24403144/)3Date created: 2021/01/154Last modified: 2021/01/155Description: How to train differentiable decision trees for end-to-end learning in deep neural networks.6Accelerator: GPU7"""89"""10## Introduction1112This example provides an implementation of the13[Deep Neural Decision Forest](https://ieeexplore.ieee.org/document/7410529)14model introduced by P. Kontschieder et al. for structured data classification.15It demonstrates how to build a stochastic and differentiable decision tree model,16train it end-to-end, and unify decision trees with deep representation learning.1718## The dataset1920This example uses the21[United States Census Income Dataset](https://archive.ics.uci.edu/ml/datasets/census+income)22provided by the23[UC Irvine Machine Learning Repository](https://archive.ics.uci.edu/ml/index.php).24The task is binary classification25to predict whether a person is likely to be making over USD 50,000 a year.2627The dataset includes 48,842 instances with 14 input features (such as age, work class, education, occupation, and so on): 5 numerical features28and 9 categorical features.29"""3031"""32## Setup33"""3435import keras36from keras import layers37from keras.layers import StringLookup38from keras import ops394041from tensorflow import data as tf_data42import numpy as np43import pandas as pd4445import math464748"""49## Prepare the data50"""5152CSV_HEADER = [53"age",54"workclass",55"fnlwgt",56"education",57"education_num",58"marital_status",59"occupation",60"relationship",61"race",62"gender",63"capital_gain",64"capital_loss",65"hours_per_week",66"native_country",67"income_bracket",68]6970train_data_url = (71"https://archive.ics.uci.edu/ml/machine-learning-databases/adult/adult.data"72)73train_data = pd.read_csv(train_data_url, header=None, names=CSV_HEADER)7475test_data_url = (76"https://archive.ics.uci.edu/ml/machine-learning-databases/adult/adult.test"77)78test_data = pd.read_csv(test_data_url, header=None, names=CSV_HEADER)7980print(f"Train dataset shape: {train_data.shape}")81print(f"Test dataset shape: {test_data.shape}")8283"""84Remove the first record (because it is not a valid data example) and a trailing85'dot' in the class labels.86"""8788test_data = test_data[1:]89test_data.income_bracket = test_data.income_bracket.apply(90lambda value: value.replace(".", "")91)9293"""94We store the training and test data splits locally as CSV files.95"""9697train_data_file = "train_data.csv"98test_data_file = "test_data.csv"99100train_data.to_csv(train_data_file, index=False, header=False)101test_data.to_csv(test_data_file, index=False, header=False)102103"""104## Define dataset metadata105106Here, we define the metadata of the dataset that will be useful for reading and parsing107and encoding input features.108"""109110# A list of the numerical feature names.111NUMERIC_FEATURE_NAMES = [112"age",113"education_num",114"capital_gain",115"capital_loss",116"hours_per_week",117]118# A dictionary of the categorical features and their vocabulary.119CATEGORICAL_FEATURES_WITH_VOCABULARY = {120"workclass": sorted(list(train_data["workclass"].unique())),121"education": sorted(list(train_data["education"].unique())),122"marital_status": sorted(list(train_data["marital_status"].unique())),123"occupation": sorted(list(train_data["occupation"].unique())),124"relationship": sorted(list(train_data["relationship"].unique())),125"race": sorted(list(train_data["race"].unique())),126"gender": sorted(list(train_data["gender"].unique())),127"native_country": sorted(list(train_data["native_country"].unique())),128}129# A list of the columns to ignore from the dataset.130IGNORE_COLUMN_NAMES = ["fnlwgt"]131# A list of the categorical feature names.132CATEGORICAL_FEATURE_NAMES = list(CATEGORICAL_FEATURES_WITH_VOCABULARY.keys())133# A list of all the input features.134FEATURE_NAMES = NUMERIC_FEATURE_NAMES + CATEGORICAL_FEATURE_NAMES135# A list of column default values for each feature.136COLUMN_DEFAULTS = [137[0.0] if feature_name in NUMERIC_FEATURE_NAMES + IGNORE_COLUMN_NAMES else ["NA"]138for feature_name in CSV_HEADER139]140# The name of the target feature.141TARGET_FEATURE_NAME = "income_bracket"142# A list of the labels of the target features.143TARGET_LABELS = [" <=50K", " >50K"]144145"""146## Create `tf_data.Dataset` objects for training and validation147148We create an input function to read and parse the file, and convert features and labels149into a [`tf_data.Dataset`](https://www.tensorflow.org/guide/datasets)150for training and validation. We also preprocess the input by mapping the target label151to an index.152"""153154155target_label_lookup = StringLookup(156vocabulary=TARGET_LABELS, mask_token=None, num_oov_indices=0157)158159160lookup_dict = {}161for feature_name in CATEGORICAL_FEATURE_NAMES:162vocabulary = CATEGORICAL_FEATURES_WITH_VOCABULARY[feature_name]163# Create a lookup to convert a string values to an integer indices.164# Since we are not using a mask token, nor expecting any out of vocabulary165# (oov) token, we set mask_token to None and num_oov_indices to 0.166lookup = StringLookup(vocabulary=vocabulary, mask_token=None, num_oov_indices=0)167lookup_dict[feature_name] = lookup168169170def encode_categorical(batch_x, batch_y):171for feature_name in CATEGORICAL_FEATURE_NAMES:172batch_x[feature_name] = lookup_dict[feature_name](batch_x[feature_name])173174return batch_x, batch_y175176177def get_dataset_from_csv(csv_file_path, shuffle=False, batch_size=128):178dataset = (179tf_data.experimental.make_csv_dataset(180csv_file_path,181batch_size=batch_size,182column_names=CSV_HEADER,183column_defaults=COLUMN_DEFAULTS,184label_name=TARGET_FEATURE_NAME,185num_epochs=1,186header=False,187na_value="?",188shuffle=shuffle,189)190.map(lambda features, target: (features, target_label_lookup(target)))191.map(encode_categorical)192)193194return dataset.cache()195196197"""198## Create model inputs199"""200201202def create_model_inputs():203inputs = {}204for feature_name in FEATURE_NAMES:205if feature_name in NUMERIC_FEATURE_NAMES:206inputs[feature_name] = layers.Input(207name=feature_name, shape=(), dtype="float32"208)209else:210inputs[feature_name] = layers.Input(211name=feature_name, shape=(), dtype="int32"212)213return inputs214215216"""217## Encode input features218"""219220221def encode_inputs(inputs):222encoded_features = []223for feature_name in inputs:224if feature_name in CATEGORICAL_FEATURE_NAMES:225vocabulary = CATEGORICAL_FEATURES_WITH_VOCABULARY[feature_name]226# Create a lookup to convert a string values to an integer indices.227# Since we are not using a mask token, nor expecting any out of vocabulary228# (oov) token, we set mask_token to None and num_oov_indices to 0.229value_index = inputs[feature_name]230embedding_dims = int(math.sqrt(lookup.vocabulary_size()))231# Create an embedding layer with the specified dimensions.232embedding = layers.Embedding(233input_dim=lookup.vocabulary_size(), output_dim=embedding_dims234)235# Convert the index values to embedding representations.236encoded_feature = embedding(value_index)237else:238# Use the numerical features as-is.239encoded_feature = inputs[feature_name]240if inputs[feature_name].shape[-1] is None:241encoded_feature = keras.ops.expand_dims(encoded_feature, -1)242243encoded_features.append(encoded_feature)244245encoded_features = layers.concatenate(encoded_features)246return encoded_features247248249"""250## Deep Neural Decision Tree251252A neural decision tree model has two sets of weights to learn. The first set is `pi`,253which represents the probability distribution of the classes in the tree leaves.254The second set is the weights of the routing layer `decision_fn`, which represents the probability255of going to each leave. The forward pass of the model works as follows:2562571. The model expects input `features` as a single vector encoding all the features of an instance258in the batch. This vector can be generated from a Convolution Neural Network (CNN) applied to images259or dense transformations applied to structured data features.2602. The model first applies a `used_features_mask` to randomly select a subset of input features to use.2613. Then, the model computes the probabilities (`mu`) for the input instances to reach the tree leaves262by iteratively performing a *stochastic* routing throughout the tree levels.2634. Finally, the probabilities of reaching the leaves are combined by the class probabilities at the264leaves to produce the final `outputs`.265"""266267268class NeuralDecisionTree(keras.Model):269def __init__(self, depth, num_features, used_features_rate, num_classes):270super().__init__()271self.depth = depth272self.num_leaves = 2**depth273self.num_classes = num_classes274275# Create a mask for the randomly selected features.276num_used_features = int(num_features * used_features_rate)277one_hot = np.eye(num_features)278sampled_feature_indices = np.random.choice(279np.arange(num_features), num_used_features, replace=False280)281self.used_features_mask = ops.convert_to_tensor(282one_hot[sampled_feature_indices], dtype="float32"283)284285# Initialize the weights of the classes in leaves.286self.pi = self.add_weight(287initializer="random_normal",288shape=[self.num_leaves, self.num_classes],289dtype="float32",290trainable=True,291)292293# Initialize the stochastic routing layer.294self.decision_fn = layers.Dense(295units=self.num_leaves, activation="sigmoid", name="decision"296)297298def call(self, features):299batch_size = ops.shape(features)[0]300301# Apply the feature mask to the input features.302features = ops.matmul(303features, ops.transpose(self.used_features_mask)304) # [batch_size, num_used_features]305# Compute the routing probabilities.306decisions = ops.expand_dims(307self.decision_fn(features), axis=2308) # [batch_size, num_leaves, 1]309# Concatenate the routing probabilities with their complements.310decisions = layers.concatenate(311[decisions, 1 - decisions], axis=2312) # [batch_size, num_leaves, 2]313314mu = ops.ones([batch_size, 1, 1])315316begin_idx = 1317end_idx = 2318# Traverse the tree in breadth-first order.319for level in range(self.depth):320mu = ops.reshape(mu, [batch_size, -1, 1]) # [batch_size, 2 ** level, 1]321mu = ops.tile(mu, (1, 1, 2)) # [batch_size, 2 ** level, 2]322level_decisions = decisions[323:, begin_idx:end_idx, :324] # [batch_size, 2 ** level, 2]325mu = mu * level_decisions # [batch_size, 2**level, 2]326begin_idx = end_idx327end_idx = begin_idx + 2 ** (level + 1)328329mu = ops.reshape(mu, [batch_size, self.num_leaves]) # [batch_size, num_leaves]330probabilities = keras.activations.softmax(self.pi) # [num_leaves, num_classes]331outputs = ops.matmul(mu, probabilities) # [batch_size, num_classes]332return outputs333334335"""336## Deep Neural Decision Forest337338The neural decision forest model consists of a set of neural decision trees that are339trained simultaneously. The output of the forest model is the average outputs of its trees.340"""341342343class NeuralDecisionForest(keras.Model):344def __init__(self, num_trees, depth, num_features, used_features_rate, num_classes):345super().__init__()346self.ensemble = []347# Initialize the ensemble by adding NeuralDecisionTree instances.348# Each tree will have its own randomly selected input features to use.349for _ in range(num_trees):350self.ensemble.append(351NeuralDecisionTree(depth, num_features, used_features_rate, num_classes)352)353354def call(self, inputs):355# Initialize the outputs: a [batch_size, num_classes] matrix of zeros.356batch_size = ops.shape(inputs)[0]357outputs = ops.zeros([batch_size, num_classes])358359# Aggregate the outputs of trees in the ensemble.360for tree in self.ensemble:361outputs += tree(inputs)362# Divide the outputs by the ensemble size to get the average.363outputs /= len(self.ensemble)364return outputs365366367"""368Finally, let's set up the code that will train and evaluate the model.369"""370371learning_rate = 0.01372batch_size = 265373num_epochs = 10374375376def run_experiment(model):377model.compile(378optimizer=keras.optimizers.Adam(learning_rate=learning_rate),379loss=keras.losses.SparseCategoricalCrossentropy(),380metrics=[keras.metrics.SparseCategoricalAccuracy()],381)382383print("Start training the model...")384train_dataset = get_dataset_from_csv(385train_data_file, shuffle=True, batch_size=batch_size386)387388model.fit(train_dataset, epochs=num_epochs)389print("Model training finished")390391print("Evaluating the model on the test data...")392test_dataset = get_dataset_from_csv(test_data_file, batch_size=batch_size)393394_, accuracy = model.evaluate(test_dataset)395print(f"Test accuracy: {round(accuracy * 100, 2)}%")396397398"""399## Experiment 1: train a decision tree model400401In this experiment, we train a single neural decision tree model402where we use all input features.403"""404405num_trees = 10406depth = 10407used_features_rate = 1.0408num_classes = len(TARGET_LABELS)409410411def create_tree_model():412inputs = create_model_inputs()413features = encode_inputs(inputs)414features = layers.BatchNormalization()(features)415num_features = features.shape[1]416417tree = NeuralDecisionTree(depth, num_features, used_features_rate, num_classes)418419outputs = tree(features)420model = keras.Model(inputs=inputs, outputs=outputs)421return model422423424tree_model = create_tree_model()425run_experiment(tree_model)426427428"""429## Experiment 2: train a forest model430431In this experiment, we train a neural decision forest with `num_trees` trees432where each tree uses randomly selected 50% of the input features. You can control the number433of features to be used in each tree by setting the `used_features_rate` variable.434In addition, we set the depth to 5 instead of 10 compared to the previous experiment.435"""436437num_trees = 25438depth = 5439used_features_rate = 0.5440441442def create_forest_model():443inputs = create_model_inputs()444features = encode_inputs(inputs)445features = layers.BatchNormalization()(features)446num_features = features.shape[1]447448forest_model = NeuralDecisionForest(449num_trees, depth, num_features, used_features_rate, num_classes450)451452outputs = forest_model(features)453model = keras.Model(inputs=inputs, outputs=outputs)454return model455456457forest_model = create_forest_model()458459run_experiment(forest_model)460461462