"""
Title: Structured data learning with TabTransformer
Author: [Khalid Salama](https://www.linkedin.com/in/khalid-salama-24403144/)
Date created: 2022/01/18
Last modified: 2022/01/18
Description: Using contextual embeddings for structured data classification.
Accelerator: GPU
"""

"""
## Introduction

This example demonstrates how to do structured data classification using
[TabTransformer](https://arxiv.org/abs/2012.06678), a deep tabular data modeling
architecture for supervised and semi-supervised learning.
The TabTransformer is built upon self-attention-based Transformers.
The Transformer layers transform the embeddings of categorical features
into robust contextual embeddings to achieve higher predictive accuracy.


## Setup
"""
import keras
from keras import layers
from keras import ops

import math
import numpy as np
import pandas as pd
from tensorflow import data as tf_data
import matplotlib.pyplot as plt
from functools import partial

"""
## Prepare the data

This example uses the
[United States Census Income Dataset](https://archive.ics.uci.edu/ml/datasets/census+income)
provided by the
[UC Irvine Machine Learning Repository](https://archive.ics.uci.edu/ml/index.php).
The task is binary classification
to predict whether a person is likely to be making over USD 50,000 a year.

The dataset includes 48,842 instances with 14 input features: 5 numerical features and 9 categorical features.

First, let's load the dataset from the UCI Machine Learning Repository into a Pandas
DataFrame:
"""

CSV_HEADER = [
    "age",
    "workclass",
    "fnlwgt",
    "education",
    "education_num",
    "marital_status",
    "occupation",
    "relationship",
    "race",
    "gender",
    "capital_gain",
    "capital_loss",
    "hours_per_week",
    "native_country",
    "income_bracket",
]

train_data_url = (
    "https://archive.ics.uci.edu/ml/machine-learning-databases/adult/adult.data"
)
train_data = pd.read_csv(train_data_url, header=None, names=CSV_HEADER)

test_data_url = (
    "https://archive.ics.uci.edu/ml/machine-learning-databases/adult/adult.test"
)
test_data = pd.read_csv(test_data_url, header=None, names=CSV_HEADER)

print(f"Train dataset shape: {train_data.shape}")
print(f"Test dataset shape: {test_data.shape}")

"""
Remove the first record (because it is not a valid data example) and the trailing
'dot' from the class labels in the test set.
"""

test_data = test_data[1:]
test_data.income_bracket = test_data.income_bracket.apply(
    lambda value: value.replace(".", "")
)

"""
Now we store the training and test data in separate CSV files.
"""

train_data_file = "train_data.csv"
test_data_file = "test_data.csv"

train_data.to_csv(train_data_file, index=False, header=False)
test_data.to_csv(test_data_file, index=False, header=False)

"""
## Define dataset metadata

Here, we define the metadata of the dataset that will be useful for reading and parsing
the data into input features, and encoding the input features with respect to their types.
"""

# A list of the numerical feature names.
NUMERIC_FEATURE_NAMES = [
    "age",
    "education_num",
    "capital_gain",
    "capital_loss",
    "hours_per_week",
]
# A dictionary of the categorical features and their vocabulary.
CATEGORICAL_FEATURES_WITH_VOCABULARY = {
    "workclass": sorted(list(train_data["workclass"].unique())),
    "education": sorted(list(train_data["education"].unique())),
    "marital_status": sorted(list(train_data["marital_status"].unique())),
    "occupation": sorted(list(train_data["occupation"].unique())),
    "relationship": sorted(list(train_data["relationship"].unique())),
    "race": sorted(list(train_data["race"].unique())),
    "gender": sorted(list(train_data["gender"].unique())),
    "native_country": sorted(list(train_data["native_country"].unique())),
}
# Name of the column to be used as instance weights.
WEIGHT_COLUMN_NAME = "fnlwgt"
# A list of the categorical feature names.
CATEGORICAL_FEATURE_NAMES = list(CATEGORICAL_FEATURES_WITH_VOCABULARY.keys())
# A list of all the input features.
FEATURE_NAMES = NUMERIC_FEATURE_NAMES + CATEGORICAL_FEATURE_NAMES
# A list of column default values for each feature.
COLUMN_DEFAULTS = [
    [0.0] if feature_name in NUMERIC_FEATURE_NAMES + [WEIGHT_COLUMN_NAME] else ["NA"]
    for feature_name in CSV_HEADER
]
# The name of the target feature.
TARGET_FEATURE_NAME = "income_bracket"
# A list of the labels of the target feature.
TARGET_LABELS = [" <=50K", " >50K"]

"""
## Configure the hyperparameters

The hyperparameters include the model architecture and training configurations.
"""

LEARNING_RATE = 0.001
WEIGHT_DECAY = 0.0001
DROPOUT_RATE = 0.2
BATCH_SIZE = 265
NUM_EPOCHS = 15

NUM_TRANSFORMER_BLOCKS = 3  # Number of transformer blocks.
NUM_HEADS = 4  # Number of attention heads.
EMBEDDING_DIMS = 16  # Embedding dimensions of the categorical features.
MLP_HIDDEN_UNITS_FACTORS = [
    2,
    1,
]  # MLP hidden layer units, as factors of the number of inputs.
NUM_MLP_BLOCKS = 2  # Number of MLP blocks in the baseline model.

"""
## Implement a data reading pipeline

We define an input function that reads and parses the file, then converts features
and labels into a [`tf.data.Dataset`](https://www.tensorflow.org/guide/datasets)
for training or evaluation.
"""

target_label_lookup = layers.StringLookup(
    vocabulary=TARGET_LABELS, mask_token=None, num_oov_indices=0
)


def prepare_example(features, target):
    target_index = target_label_lookup(target)
    weights = features.pop(WEIGHT_COLUMN_NAME)
    return features, target_index, weights


lookup_dict = {}
for feature_name in CATEGORICAL_FEATURE_NAMES:
    vocabulary = CATEGORICAL_FEATURES_WITH_VOCABULARY[feature_name]
    # Create a lookup to convert string values to integer indices.
    # Since we are not using a mask token, nor expecting any out-of-vocabulary
    # (OOV) token, we set mask_token to None and num_oov_indices to 0.
    lookup = layers.StringLookup(
        vocabulary=vocabulary, mask_token=None, num_oov_indices=0
    )
    lookup_dict[feature_name] = lookup


def encode_categorical(batch_x, batch_y, weights):
    for feature_name in CATEGORICAL_FEATURE_NAMES:
        batch_x[feature_name] = lookup_dict[feature_name](batch_x[feature_name])

    return batch_x, batch_y, weights


def get_dataset_from_csv(csv_file_path, batch_size=128, shuffle=False):
    dataset = (
        tf_data.experimental.make_csv_dataset(
            csv_file_path,
            batch_size=batch_size,
            column_names=CSV_HEADER,
            column_defaults=COLUMN_DEFAULTS,
            label_name=TARGET_FEATURE_NAME,
            num_epochs=1,
            header=False,
            na_value="?",
            shuffle=shuffle,
        )
        .map(prepare_example, num_parallel_calls=tf_data.AUTOTUNE, deterministic=False)
        .map(encode_categorical)
    )
    return dataset.cache()

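
"""
As a quick sanity check, we can take a single batch from the training pipeline and verify
that it yields a `(features, target_index, weights)` tuple, with the categorical features
already encoded as integer indices:
"""

sample_features, sample_target, sample_weights = next(
    iter(get_dataset_from_csv(train_data_file, batch_size=5))
)
print("Feature keys:", sorted(sample_features.keys()))
print("Target shape:", sample_target.shape)
print("Weights shape:", sample_weights.shape)
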
"""
## Implement a training and evaluation procedure
"""


def run_experiment(
    model,
    train_data_file,
    test_data_file,
    num_epochs,
    learning_rate,
    weight_decay,
    batch_size,
):
    optimizer = keras.optimizers.AdamW(
        learning_rate=learning_rate, weight_decay=weight_decay
    )

    model.compile(
        optimizer=optimizer,
        loss=keras.losses.BinaryCrossentropy(),
        metrics=[keras.metrics.BinaryAccuracy(name="accuracy")],
    )

    train_dataset = get_dataset_from_csv(train_data_file, batch_size, shuffle=True)
    validation_dataset = get_dataset_from_csv(test_data_file, batch_size)

    print("Start training the model...")
    history = model.fit(
        train_dataset, epochs=num_epochs, validation_data=validation_dataset
    )
    print("Model training finished")

    _, accuracy = model.evaluate(validation_dataset, verbose=0)

    print(f"Validation accuracy: {round(accuracy * 100, 2)}%")

    return history


"""
## Create model inputs

Now, define the inputs for the models as a dictionary, where the key is the feature name,
and the value is a `keras.layers.Input` tensor with the corresponding feature shape
and data type.
"""


def create_model_inputs():
    inputs = {}
    for feature_name in FEATURE_NAMES:
        if feature_name in NUMERIC_FEATURE_NAMES:
            inputs[feature_name] = layers.Input(
                name=feature_name, shape=(), dtype="float32"
            )
        else:
            inputs[feature_name] = layers.Input(
                name=feature_name, shape=(), dtype="int32"
            )
    return inputs


"""
## Encode features

The `encode_inputs` function returns `encoded_categorical_feature_list` and `numerical_feature_list`.
We encode the categorical features as embeddings, using a fixed `embedding_dims` for all the features,
regardless of their vocabulary sizes. This is required for the Transformer model.
"""


def encode_inputs(inputs, embedding_dims):
    encoded_categorical_feature_list = []
    numerical_feature_list = []

    for feature_name in inputs:
        if feature_name in CATEGORICAL_FEATURE_NAMES:
            vocabulary = CATEGORICAL_FEATURES_WITH_VOCABULARY[feature_name]
            # The string values of this feature have already been converted to
            # integer indices by the lookup layers in the tf.data pipeline.

            # Create an embedding layer with the specified dimensions.
            embedding = layers.Embedding(
                input_dim=len(vocabulary), output_dim=embedding_dims
            )

            # Convert the index values to embedding representations.
            encoded_categorical_feature = embedding(inputs[feature_name])
            encoded_categorical_feature_list.append(encoded_categorical_feature)

        else:
            # Use the numerical features as-is.
            numerical_feature = ops.expand_dims(inputs[feature_name], -1)
            numerical_feature_list.append(numerical_feature)

    return encoded_categorical_feature_list, numerical_feature_list

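
"""
As a quick, illustrative check (not needed for the model itself), calling `encode_inputs`
on a fresh set of inputs yields 8 categorical embeddings of shape `(batch_size, EMBEDDING_DIMS)`
and 5 numerical features of shape `(batch_size, 1)`:
"""

demo_inputs = create_model_inputs()
demo_categorical, demo_numerical = encode_inputs(demo_inputs, EMBEDDING_DIMS)
print(len(demo_categorical), demo_categorical[0].shape)
print(len(demo_numerical), demo_numerical[0].shape)
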
"""
## Implement an MLP block
"""


def create_mlp(hidden_units, dropout_rate, activation, normalization_layer, name=None):
    mlp_layers = []
    for units in hidden_units:
        mlp_layers.append(normalization_layer())
        mlp_layers.append(layers.Dense(units, activation=activation))
        mlp_layers.append(layers.Dropout(dropout_rate))

    return keras.Sequential(mlp_layers, name=name)


"""
## Experiment 1: a baseline model

In the first experiment, we create a simple multi-layer feed-forward network.
"""


def create_baseline_model(
    embedding_dims, num_mlp_blocks, mlp_hidden_units_factors, dropout_rate
):
    # Create model inputs.
    inputs = create_model_inputs()
    # Encode features.
    encoded_categorical_feature_list, numerical_feature_list = encode_inputs(
        inputs, embedding_dims
    )
    # Concatenate all features.
    features = layers.concatenate(
        encoded_categorical_feature_list + numerical_feature_list
    )
    # Compute feedforward layer units.
    feedforward_units = [features.shape[-1]]

    # Create several feedforward layers with skip connections.
    for layer_idx in range(num_mlp_blocks):
        features = create_mlp(
            hidden_units=feedforward_units,
            dropout_rate=dropout_rate,
            activation=keras.activations.gelu,
            normalization_layer=layers.LayerNormalization,
            name=f"feedforward_{layer_idx}",
        )(features)

    # Compute MLP hidden_units.
    mlp_hidden_units = [
        factor * features.shape[-1] for factor in mlp_hidden_units_factors
    ]
    # Create final MLP.
    features = create_mlp(
        hidden_units=mlp_hidden_units,
        dropout_rate=dropout_rate,
        activation=keras.activations.selu,
        normalization_layer=layers.BatchNormalization,
        name="MLP",
    )(features)

    # Add a sigmoid as a binary classifier.
    outputs = layers.Dense(units=1, activation="sigmoid", name="sigmoid")(features)
    model = keras.Model(inputs=inputs, outputs=outputs)
    return model


baseline_model = create_baseline_model(
    embedding_dims=EMBEDDING_DIMS,
    num_mlp_blocks=NUM_MLP_BLOCKS,
    mlp_hidden_units_factors=MLP_HIDDEN_UNITS_FACTORS,
    dropout_rate=DROPOUT_RATE,
)

print("Total model weights:", baseline_model.count_params())
keras.utils.plot_model(baseline_model, show_shapes=True, rankdir="LR")

"""
Let's train and evaluate the baseline model:
"""

history = run_experiment(
    model=baseline_model,
    train_data_file=train_data_file,
    test_data_file=test_data_file,
    num_epochs=NUM_EPOCHS,
    learning_rate=LEARNING_RATE,
    weight_decay=WEIGHT_DECAY,
    batch_size=BATCH_SIZE,
)

"""
The baseline model achieves ~81% validation accuracy.
"""

"""
## Experiment 2: TabTransformer

The TabTransformer architecture works as follows:

1. All the categorical features are encoded as embeddings, using the same `embedding_dims`.
This means that each value in each categorical feature will have its own embedding vector.
2. A column embedding, one embedding vector for each categorical feature, is added (point-wise)
to the categorical feature embedding (see the small sketch below).
3. The embedded categorical features are fed into a stack of Transformer blocks.
Each Transformer block consists of a multi-head self-attention layer followed by a feed-forward layer.
4. The outputs of the final Transformer layer, which are the *contextual embeddings* of the categorical features,
are concatenated with the input numerical features, and fed into a final MLP block.
5. A `softmax` classifier is applied at the end of the model (this example uses a single
`sigmoid` unit, since the target is binary).

The [paper](https://arxiv.org/abs/2012.06678) discusses both addition and concatenation of the column embedding in the
*Appendix: Experiment and Model Details* section.
The architecture of TabTransformer is shown below, as presented in the paper.

<img src="https://raw.githubusercontent.com/keras-team/keras-io/master/examples/structured_data/img/tabtransformer/tabtransformer.png" width="500"/>
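
For instance, the column-embedding addition in step 2 is just a broadcast over the batch
axis. A minimal sketch with dummy shapes (4 examples, 8 categorical columns,
`embedding_dims = 16`), using the `layers` and `ops` imports from the setup:

```python
batch_of_embeddings = ops.ones((4, 8, 16))
column_embedding = layers.Embedding(input_dim=8, output_dim=16)
column_indices = ops.arange(start=0, stop=8, step=1)
# The (8, 16) column-embedding table broadcasts over the batch axis.
print((batch_of_embeddings + column_embedding(column_indices)).shape)  # (4, 8, 16)
```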
"""


def create_tabtransformer_classifier(
    num_transformer_blocks,
    num_heads,
    embedding_dims,
    mlp_hidden_units_factors,
    dropout_rate,
    use_column_embedding=False,
):
    # Create model inputs.
    inputs = create_model_inputs()
    # Encode features.
    encoded_categorical_feature_list, numerical_feature_list = encode_inputs(
        inputs, embedding_dims
    )
    # Stack categorical feature embeddings for the Transformer.
    encoded_categorical_features = ops.stack(encoded_categorical_feature_list, axis=1)
    # Concatenate numerical features.
    numerical_features = layers.concatenate(numerical_feature_list)

    # Add column embedding to categorical feature embeddings.
    if use_column_embedding:
        num_columns = encoded_categorical_features.shape[1]
        column_embedding = layers.Embedding(
            input_dim=num_columns, output_dim=embedding_dims
        )
        column_indices = ops.arange(start=0, stop=num_columns, step=1)
        encoded_categorical_features = encoded_categorical_features + column_embedding(
            column_indices
        )

    # Create multiple layers of the Transformer block.
    for block_idx in range(num_transformer_blocks):
        # Create a multi-head attention layer.
        attention_output = layers.MultiHeadAttention(
            num_heads=num_heads,
            key_dim=embedding_dims,
            dropout=dropout_rate,
            name=f"multihead_attention_{block_idx}",
        )(encoded_categorical_features, encoded_categorical_features)
        # Skip connection 1.
        x = layers.Add(name=f"skip_connection1_{block_idx}")(
            [attention_output, encoded_categorical_features]
        )
        # Layer normalization 1.
        x = layers.LayerNormalization(name=f"layer_norm1_{block_idx}", epsilon=1e-6)(x)
        # Feedforward.
        feedforward_output = create_mlp(
            hidden_units=[embedding_dims],
            dropout_rate=dropout_rate,
            activation=keras.activations.gelu,
            normalization_layer=partial(
                layers.LayerNormalization, epsilon=1e-6
            ),  # using partial to provide keyword arguments before initialization
            name=f"feedforward_{block_idx}",
        )(x)
        # Skip connection 2.
        x = layers.Add(name=f"skip_connection2_{block_idx}")([feedforward_output, x])
        # Layer normalization 2.
        encoded_categorical_features = layers.LayerNormalization(
            name=f"layer_norm2_{block_idx}", epsilon=1e-6
        )(x)

    # Flatten the "contextualized" embeddings of the categorical features.
    categorical_features = layers.Flatten()(encoded_categorical_features)
    # Apply layer normalization to the numerical features.
    numerical_features = layers.LayerNormalization(epsilon=1e-6)(numerical_features)
    # Prepare the input for the final MLP block.
    features = layers.concatenate([categorical_features, numerical_features])

    # Compute MLP hidden_units.
    mlp_hidden_units = [
        factor * features.shape[-1] for factor in mlp_hidden_units_factors
    ]
    # Create final MLP.
    features = create_mlp(
        hidden_units=mlp_hidden_units,
        dropout_rate=dropout_rate,
        activation=keras.activations.selu,
        normalization_layer=layers.BatchNormalization,
        name="MLP",
    )(features)

    # Add a sigmoid as a binary classifier.
    outputs = layers.Dense(units=1, activation="sigmoid", name="sigmoid")(features)
    model = keras.Model(inputs=inputs, outputs=outputs)
    return model


tabtransformer_model = create_tabtransformer_classifier(
    num_transformer_blocks=NUM_TRANSFORMER_BLOCKS,
    num_heads=NUM_HEADS,
    embedding_dims=EMBEDDING_DIMS,
    mlp_hidden_units_factors=MLP_HIDDEN_UNITS_FACTORS,
    dropout_rate=DROPOUT_RATE,
)

print("Total model weights:", tabtransformer_model.count_params())
keras.utils.plot_model(tabtransformer_model, show_shapes=True, rankdir="LR")

"""
Let's train and evaluate the TabTransformer model:
"""

history = run_experiment(
    model=tabtransformer_model,
    train_data_file=train_data_file,
    test_data_file=test_data_file,
    num_epochs=NUM_EPOCHS,
    learning_rate=LEARNING_RATE,
    weight_decay=WEIGHT_DECAY,
    batch_size=BATCH_SIZE,
)

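
"""
Before looking at the final numbers, we can plot the training curves of the TabTransformer
run with `matplotlib` (imported in the setup), for example the training and validation loss
per epoch:
"""

fig, ax = plt.subplots(figsize=(6, 4))
ax.plot(history.history["loss"], label="train loss")
ax.plot(history.history["val_loss"], label="validation loss")
ax.set_xlabel("Epoch")
ax.set_ylabel("Loss")
ax.legend()
plt.show()
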
"""
The TabTransformer model achieves ~85% validation accuracy.
Note that, with the default parameter configurations, both the baseline and the TabTransformer
models have a similar number of trainable weights: 109,895 and 87,745 respectively, and both use the same training hyperparameters.
"""
"""
## Conclusion

TabTransformer significantly outperforms MLP and recent
deep networks for tabular data while matching the performance of tree-based ensemble models.
TabTransformer can be trained end-to-end with supervised learning on labeled examples.
For a scenario where there are a few labeled examples and a large number of unlabeled
examples, a pre-training procedure can be employed to train the Transformer layers using unlabeled data.
This is followed by fine-tuning of the pre-trained Transformer layers along with
the top MLP layer using the labeled data.
"""