"""1Title: Structured data learning with TabTransformer2Author: [Khalid Salama](https://www.linkedin.com/in/khalid-salama-24403144/)3Date created: 2022/01/184Last modified: 2022/01/185Description: Using contextual embeddings for structured data classification.6Accelerator: GPU7"""89"""10## Introduction1112This example demonstrates how to do structured data classification using13[TabTransformer](https://arxiv.org/abs/2012.06678), a deep tabular data modeling14architecture for supervised and semi-supervised learning.15The TabTransformer is built upon self-attention based Transformers.16The Transformer layers transform the embeddings of categorical features17into robust contextual embeddings to achieve higher predictive accuracy.18192021## Setup22"""23import keras24from keras import layers25from keras import ops2627import math28import numpy as np29import pandas as pd30from tensorflow import data as tf_data31import matplotlib.pyplot as plt32from functools import partial3334"""35## Prepare the data3637This example uses the38[United States Census Income Dataset](https://archive.ics.uci.edu/ml/datasets/census+income)39provided by the40[UC Irvine Machine Learning Repository](https://archive.ics.uci.edu/ml/index.php).41The task is binary classification42to predict whether a person is likely to be making over USD 50,000 a year.4344The dataset includes 48,842 instances with 14 input features: 5 numerical features and 9 categorical features.4546First, let's load the dataset from the UCI Machine Learning Repository into a Pandas47DataFrame:48"""4950CSV_HEADER = [51"age",52"workclass",53"fnlwgt",54"education",55"education_num",56"marital_status",57"occupation",58"relationship",59"race",60"gender",61"capital_gain",62"capital_loss",63"hours_per_week",64"native_country",65"income_bracket",66]6768train_data_url = (69"https://archive.ics.uci.edu/ml/machine-learning-databases/adult/adult.data"70)71train_data = pd.read_csv(train_data_url, header=None, names=CSV_HEADER)7273test_data_url 
= (74"https://archive.ics.uci.edu/ml/machine-learning-databases/adult/adult.test"75)76test_data = pd.read_csv(test_data_url, header=None, names=CSV_HEADER)7778print(f"Train dataset shape: {train_data.shape}")79print(f"Test dataset shape: {test_data.shape}")8081"""82Remove the first record (because it is not a valid data example) and a trailing 'dot' in the class labels.83"""8485test_data = test_data[1:]86test_data.income_bracket = test_data.income_bracket.apply(87lambda value: value.replace(".", "")88)8990"""91Now we store the training and test data in separate CSV files.92"""9394train_data_file = "train_data.csv"95test_data_file = "test_data.csv"9697train_data.to_csv(train_data_file, index=False, header=False)98test_data.to_csv(test_data_file, index=False, header=False)99100"""101## Define dataset metadata102103Here, we define the metadata of the dataset that will be useful for reading and parsing104the data into input features, and encoding the input features with respect to their types.105"""106107# A list of the numerical feature names.108NUMERIC_FEATURE_NAMES = [109"age",110"education_num",111"capital_gain",112"capital_loss",113"hours_per_week",114]115# A dictionary of the categorical features and their vocabulary.116CATEGORICAL_FEATURES_WITH_VOCABULARY = {117"workclass": sorted(list(train_data["workclass"].unique())),118"education": sorted(list(train_data["education"].unique())),119"marital_status": sorted(list(train_data["marital_status"].unique())),120"occupation": sorted(list(train_data["occupation"].unique())),121"relationship": sorted(list(train_data["relationship"].unique())),122"race": sorted(list(train_data["race"].unique())),123"gender": sorted(list(train_data["gender"].unique())),124"native_country": sorted(list(train_data["native_country"].unique())),125}126# Name of the column to be used as instances weight.127WEIGHT_COLUMN_NAME = "fnlwgt"128# A list of the categorical feature names.129CATEGORICAL_FEATURE_NAMES = 
list(CATEGORICAL_FEATURES_WITH_VOCABULARY.keys())130# A list of all the input features.131FEATURE_NAMES = NUMERIC_FEATURE_NAMES + CATEGORICAL_FEATURE_NAMES132# A list of column default values for each feature.133COLUMN_DEFAULTS = [134[0.0] if feature_name in NUMERIC_FEATURE_NAMES + [WEIGHT_COLUMN_NAME] else ["NA"]135for feature_name in CSV_HEADER136]137# The name of the target feature.138TARGET_FEATURE_NAME = "income_bracket"139# A list of the labels of the target features.140TARGET_LABELS = [" <=50K", " >50K"]141142"""143## Configure the hyperparameters144145The hyperparameters includes model architecture and training configurations.146"""147148LEARNING_RATE = 0.001149WEIGHT_DECAY = 0.0001150DROPOUT_RATE = 0.2151BATCH_SIZE = 265152NUM_EPOCHS = 15153154NUM_TRANSFORMER_BLOCKS = 3 # Number of transformer blocks.155NUM_HEADS = 4 # Number of attention heads.156EMBEDDING_DIMS = 16 # Embedding dimensions of the categorical features.157MLP_HIDDEN_UNITS_FACTORS = [1582,1591,160] # MLP hidden layer units, as factors of the number of inputs.161NUM_MLP_BLOCKS = 2 # Number of MLP blocks in the baseline model.162163"""164## Implement data reading pipeline165166We define an input function that reads and parses the file, then converts features167and labels into a[`tf.data.Dataset`](https://www.tensorflow.org/guide/datasets)168for training or evaluation.169"""170171target_label_lookup = layers.StringLookup(172vocabulary=TARGET_LABELS, mask_token=None, num_oov_indices=0173)174175176def prepare_example(features, target):177target_index = target_label_lookup(target)178weights = features.pop(WEIGHT_COLUMN_NAME)179return features, target_index, weights180181182lookup_dict = {}183for feature_name in CATEGORICAL_FEATURE_NAMES:184vocabulary = CATEGORICAL_FEATURES_WITH_VOCABULARY[feature_name]185# Create a lookup to convert a string values to an integer indices.186# Since we are not using a mask token, nor expecting any out of vocabulary187# (oov) token, we set mask_token to None and 
    # num_oov_indices to 0.
    lookup = layers.StringLookup(
        vocabulary=vocabulary, mask_token=None, num_oov_indices=0
    )
    lookup_dict[feature_name] = lookup


def encode_categorical(batch_x, batch_y, weights):
    for feature_name in CATEGORICAL_FEATURE_NAMES:
        batch_x[feature_name] = lookup_dict[feature_name](batch_x[feature_name])

    return batch_x, batch_y, weights


def get_dataset_from_csv(csv_file_path, batch_size=128, shuffle=False):
    dataset = (
        tf_data.experimental.make_csv_dataset(
            csv_file_path,
            batch_size=batch_size,
            column_names=CSV_HEADER,
            column_defaults=COLUMN_DEFAULTS,
            label_name=TARGET_FEATURE_NAME,
            num_epochs=1,
            header=False,
            na_value="?",
            shuffle=shuffle,
        )
        .map(prepare_example, num_parallel_calls=tf_data.AUTOTUNE, deterministic=False)
        .map(encode_categorical)
    )
    return dataset.cache()


"""
## Implement a training and evaluation procedure
"""


def run_experiment(
    model,
    train_data_file,
    test_data_file,
    num_epochs,
    learning_rate,
    weight_decay,
    batch_size,
):
    optimizer = keras.optimizers.AdamW(
        learning_rate=learning_rate, weight_decay=weight_decay
    )

    model.compile(
        optimizer=optimizer,
        loss=keras.losses.BinaryCrossentropy(),
        metrics=[keras.metrics.BinaryAccuracy(name="accuracy")],
    )

    train_dataset = get_dataset_from_csv(train_data_file, batch_size, shuffle=True)
    validation_dataset = get_dataset_from_csv(test_data_file, batch_size)

    print("Start training the model...")
    history = model.fit(
        train_dataset, epochs=num_epochs, validation_data=validation_dataset
    )
    print("Model training finished")

    _, accuracy = model.evaluate(validation_dataset, verbose=0)

    print(f"Validation accuracy: {round(accuracy * 100, 2)}%")

    return history


"""
## Create model inputs

Now, define the inputs for the models as a dictionary, where the key is the feature name,
and the value is a `keras.layers.Input` tensor with the corresponding
feature shape and data type.
"""


def create_model_inputs():
    inputs = {}
    for feature_name in FEATURE_NAMES:
        if feature_name in NUMERIC_FEATURE_NAMES:
            inputs[feature_name] = layers.Input(
                name=feature_name, shape=(), dtype="float32"
            )
        else:
            inputs[feature_name] = layers.Input(
                name=feature_name, shape=(), dtype="int32"
            )
    return inputs


"""
## Encode features

The `encode_inputs` function returns `encoded_categorical_feature_list` and `numerical_feature_list`.
We encode the categorical features as embeddings, using a fixed `embedding_dims` for all the features,
regardless of their vocabulary sizes. This is required for the Transformer model, because the
categorical embeddings are stacked into a single sequence of equal-width vectors for the attention layers.
"""


def encode_inputs(inputs, embedding_dims):
    encoded_categorical_feature_list = []
    numerical_feature_list = []

    for feature_name in inputs:
        if feature_name in CATEGORICAL_FEATURE_NAMES:
            vocabulary = CATEGORICAL_FEATURES_WITH_VOCABULARY[feature_name]
            # The string values were already converted to integer indices by the
            # lookups applied in the data pipeline (see `encode_categorical`).

            # Create an embedding layer with the specified dimensions.
            embedding = layers.Embedding(
                input_dim=len(vocabulary), output_dim=embedding_dims
            )

            # Convert the index values to embedding representations.
            encoded_categorical_feature = embedding(inputs[feature_name])
            encoded_categorical_feature_list.append(encoded_categorical_feature)

        else:
            # Use the numerical features as-is, adding a trailing feature axis.
            numerical_feature = ops.expand_dims(inputs[feature_name], -1)
            numerical_feature_list.append(numerical_feature)

    return encoded_categorical_feature_list, numerical_feature_list


"""
## Implement an MLP block
"""


def create_mlp(hidden_units, dropout_rate, activation, normalization_layer, name=None):
    mlp_layers = []
    for units in hidden_units:
        # Each block applies normalization, then a dense layer, then dropout.
        mlp_layers.append(normalization_layer())
        mlp_layers.append(layers.Dense(units, activation=activation))
        mlp_layers.append(layers.Dropout(dropout_rate))

    return keras.Sequential(mlp_layers, name=name)


"""
## Experiment 1: a baseline model

In the first experiment, we create a simple multi-layer feed-forward network.
"""


def create_baseline_model(
    embedding_dims, num_mlp_blocks, mlp_hidden_units_factors, dropout_rate
):
    # Create model inputs.
    inputs = create_model_inputs()
    # Encode features.
    encoded_categorical_feature_list, numerical_feature_list = encode_inputs(
        inputs, embedding_dims
    )
    # Concatenate all features.
    features = layers.concatenate(
        encoded_categorical_feature_list + numerical_feature_list
    )
    # Compute feedforward layer units.
    feedforward_units = [features.shape[-1]]

    # Create several feedforward layers, applied one after another.
    for layer_idx in range(num_mlp_blocks):
        features = create_mlp(
            hidden_units=feedforward_units,
            dropout_rate=dropout_rate,
            activation=keras.activations.gelu,
            normalization_layer=layers.LayerNormalization,
            name=f"feedforward_{layer_idx}",
        )(features)

    # Compute MLP hidden_units.
    mlp_hidden_units = [
        factor * features.shape[-1] for factor in mlp_hidden_units_factors
    ]
    # Create final MLP.
    features = create_mlp(
        hidden_units=mlp_hidden_units,
        dropout_rate=dropout_rate,
        activation=keras.activations.selu,
        normalization_layer=layers.BatchNormalization,
        name="MLP",
    )(features)

    # Add a sigmoid as a binary classifier.
    outputs = layers.Dense(units=1, activation="sigmoid", name="sigmoid")(features)
    model = keras.Model(inputs=inputs, outputs=outputs)
    return model


baseline_model = create_baseline_model(
    embedding_dims=EMBEDDING_DIMS,
    num_mlp_blocks=NUM_MLP_BLOCKS,
    mlp_hidden_units_factors=MLP_HIDDEN_UNITS_FACTORS,
    dropout_rate=DROPOUT_RATE,
)

print("Total model weights:",
      baseline_model.count_params())
keras.utils.plot_model(baseline_model, show_shapes=True, rankdir="LR")

"""
Let's train and evaluate the baseline model:
"""

history = run_experiment(
    model=baseline_model,
    train_data_file=train_data_file,
    test_data_file=test_data_file,
    num_epochs=NUM_EPOCHS,
    learning_rate=LEARNING_RATE,
    weight_decay=WEIGHT_DECAY,
    batch_size=BATCH_SIZE,
)

"""
The baseline MLP model achieves ~81% validation accuracy.
"""

"""
## Experiment 2: TabTransformer

The TabTransformer architecture works as follows:

1. All the categorical features are encoded as embeddings, using the same `embedding_dims`.
This means that each value in each categorical feature will have its own embedding vector.
2. Optionally, a column embedding (one embedding vector for each categorical feature) is added
element-wise to the categorical feature embeddings. In this example, this step is controlled by
`use_column_embedding`, which defaults to `False`.
3. The embedded categorical features are fed into a stack of Transformer blocks.
Each Transformer block consists of a multi-head self-attention layer followed by a feed-forward layer,
each with a skip connection and layer normalization.
4. The outputs of the final Transformer layer, which are the *contextual embeddings* of the categorical features,
are concatenated with the layer-normalized numerical features, and fed into a final MLP block.
5. A sigmoid classifier is applied at the end of the model.

The [paper](https://arxiv.org/abs/2012.06678) discusses both addition and concatenation of the column embedding in the
*Appendix: Experiment and Model Details* section.
The architecture of TabTransformer is shown below, as presented in the paper.

<img src="https://raw.githubusercontent.com/keras-team/keras-io/master/examples/structured_data/img/tabtransformer/tabtransformer.png" width="500"/>
"""


def create_tabtransformer_classifier(
    num_transformer_blocks,
    num_heads,
    embedding_dims,
    mlp_hidden_units_factors,
    dropout_rate,
    use_column_embedding=False,
):
    # Create model inputs.
    inputs = create_model_inputs()
    # Encode features.
    encoded_categorical_feature_list, numerical_feature_list = encode_inputs(
        inputs, embedding_dims
    )
    # Stack categorical feature embeddings for the Transformer.
    encoded_categorical_features = ops.stack(encoded_categorical_feature_list, axis=1)
    # Concatenate numerical features.
    numerical_features = layers.concatenate(numerical_feature_list)

    # Add column embedding to categorical feature embeddings.
    if use_column_embedding:
        num_columns = encoded_categorical_features.shape[1]
        column_embedding = layers.Embedding(
            input_dim=num_columns, output_dim=embedding_dims
        )
        column_indices = ops.arange(start=0, stop=num_columns, step=1)
        encoded_categorical_features = encoded_categorical_features + column_embedding(
            column_indices
        )

    # Create multiple layers of the Transformer block.
    for block_idx in range(num_transformer_blocks):
        # Create a multi-head attention layer.
        attention_output = layers.MultiHeadAttention(
            num_heads=num_heads,
            key_dim=embedding_dims,
            dropout=dropout_rate,
            name=f"multihead_attention_{block_idx}",
        )(encoded_categorical_features, encoded_categorical_features)
        # Skip connection 1.
        x = layers.Add(name=f"skip_connection1_{block_idx}")(
            [attention_output,
             encoded_categorical_features]
        )
        # Layer normalization 1.
        x = layers.LayerNormalization(name=f"layer_norm1_{block_idx}", epsilon=1e-6)(x)
        # Feedforward.
        feedforward_output = create_mlp(
            hidden_units=[embedding_dims],
            dropout_rate=dropout_rate,
            activation=keras.activations.gelu,
            normalization_layer=partial(
                layers.LayerNormalization, epsilon=1e-6
            ),  # using partial to provide keyword arguments before initialization
            name=f"feedforward_{block_idx}",
        )(x)
        # Skip connection 2.
        x = layers.Add(name=f"skip_connection2_{block_idx}")([feedforward_output, x])
        # Layer normalization 2.
        encoded_categorical_features = layers.LayerNormalization(
            name=f"layer_norm2_{block_idx}", epsilon=1e-6
        )(x)

    # Flatten the "contextualized" embeddings of the categorical features.
    categorical_features = layers.Flatten()(encoded_categorical_features)
    # Apply layer normalization to the numerical features.
    numerical_features = layers.LayerNormalization(epsilon=1e-6)(numerical_features)
    # Prepare the input for the final MLP block.
    features = layers.concatenate([categorical_features, numerical_features])

    # Compute MLP hidden_units.
    mlp_hidden_units = [
        factor * features.shape[-1] for factor in mlp_hidden_units_factors
    ]
    # Create final MLP.
    features = create_mlp(
        hidden_units=mlp_hidden_units,
        dropout_rate=dropout_rate,
        activation=keras.activations.selu,
        normalization_layer=layers.BatchNormalization,
        name="MLP",
    )(features)

    # Add a sigmoid as a binary classifier.
    outputs = layers.Dense(units=1, activation="sigmoid", name="sigmoid")(features)
    model = keras.Model(inputs=inputs, outputs=outputs)
    return model


tabtransformer_model = create_tabtransformer_classifier(
    num_transformer_blocks=NUM_TRANSFORMER_BLOCKS,
    num_heads=NUM_HEADS,
    embedding_dims=EMBEDDING_DIMS,
    mlp_hidden_units_factors=MLP_HIDDEN_UNITS_FACTORS,
    dropout_rate=DROPOUT_RATE,
)

print("Total model weights:",
      tabtransformer_model.count_params())
keras.utils.plot_model(tabtransformer_model, show_shapes=True, rankdir="LR")

"""
Let's train and evaluate the TabTransformer model:
"""

history = run_experiment(
    model=tabtransformer_model,
    train_data_file=train_data_file,
    test_data_file=test_data_file,
    num_epochs=NUM_EPOCHS,
    learning_rate=LEARNING_RATE,
    weight_decay=WEIGHT_DECAY,
    batch_size=BATCH_SIZE,
)

"""
The TabTransformer model achieves ~85% validation accuracy.
Note that, with the default parameter configurations, the baseline and the TabTransformer
have a comparable number of model weights (109,895 and 87,745 respectively), and both use the same training hyperparameters.
"""

"""
## Conclusion

According to the [paper](https://arxiv.org/abs/2012.06678), TabTransformer significantly outperforms MLP and recent
deep networks for tabular data while matching the performance of tree-based ensemble models.
TabTransformer can be trained end-to-end with supervised learning on labeled examples.
For a scenario where there are only a few labeled examples and a large number of unlabeled
examples, a pre-training procedure can be employed to train the Transformer layers using unlabeled data.
This is followed by fine-tuning of the pre-trained Transformer layers along with
the top MLP layer using the labeled data.
"""
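"""
Both experiments feed vectors of the same width into their final MLP block. The following NumPy sketch reproduces those shapes as an illustrative check. It assumes this example's default configuration: 8 categorical features, 5 numerical features, `EMBEDDING_DIMS = 16`, and `MLP_HIDDEN_UNITS_FACTORS = [2, 1]`.

Random arrays stand in for the learned embeddings, and the batch size of 4 is arbitrary. The sketch also shows how adding a `(num_columns, embedding_dims)` column-embedding table to the stacked `(batch, num_columns, embedding_dims)` categorical embeddings broadcasts over the batch axis. This mirrors the addition in the `use_column_embedding=True` branch. It is a shape sketch only, not part of the Keras model.

```python
import numpy as np

# Assumed values, copied from this example's metadata and hyperparameters.
NUM_CATEGORICAL = 8  # workclass, education, ..., native_country
NUM_NUMERIC = 5  # age, education_num, capital_gain, capital_loss, hours_per_week
EMBEDDING_DIMS = 16
MLP_HIDDEN_UNITS_FACTORS = [2, 1]
batch_size = 4  # Arbitrary batch size, chosen only for this illustration.

rng = np.random.default_rng(0)

# Stand-in for the stacked categorical embeddings: (batch, num_columns, embedding_dims).
categorical_embeddings = rng.normal(size=(batch_size, NUM_CATEGORICAL, EMBEDDING_DIMS))

# Stand-in for the column embedding table: one vector per categorical column.
column_embedding_table = rng.normal(size=(NUM_CATEGORICAL, EMBEDDING_DIMS))
column_indices = np.arange(NUM_CATEGORICAL)

# A (num_columns, embedding_dims) array added to a (batch, num_columns, embedding_dims)
# array broadcasts over the batch axis: every example gets the same per-column offset.
with_columns = categorical_embeddings + column_embedding_table[column_indices]

# The Transformer blocks preserve this shape; flattening gives one vector per example.
flat_categorical = with_columns.reshape(batch_size, -1)

# Numerical features contribute one column each.
numerical = rng.normal(size=(batch_size, NUM_NUMERIC))
features = np.concatenate([flat_categorical, numerical], axis=-1)

# Same rule as the example: hidden units are factors of the input width.
mlp_hidden_units = [factor * features.shape[-1] for factor in MLP_HIDDEN_UNITS_FACTORS]

print(with_columns.shape)  # (4, 8, 16)
print(features.shape)  # (4, 133)
print(mlp_hidden_units)  # [266, 133]
```

Under these assumptions, the final MLP block receives 8 * 16 + 5 = 133 inputs in both models, so its hidden layers have 266 and 133 units.
"""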