"""
2
Title: Classification with Gated Residual and Variable Selection Networks
3
Author: [Khalid Salama](https://www.linkedin.com/in/khalid-salama-24403144/)
4
Date created: 2021/02/10
5
Last modified: 2025/01/08
6
Description: Using Gated Residual and Variable Selection Networks for income level prediction.
7
Accelerator: GPU
8
Converted to Keras 3 by: [Sitam Meur](https://github.com/sitamgithub-MSIT) and made backend-agnostic by: [Humbulani Ndou](https://github.com/Humbulani1234)
9
"""

"""
## Introduction

This example demonstrates the use of Gated
Residual Networks (GRN) and Variable Selection Networks (VSN), proposed by
Bryan Lim et al. in
[Temporal Fusion Transformers (TFT) for Interpretable Multi-horizon Time Series Forecasting](https://arxiv.org/abs/1912.09363),
for structured data classification. GRNs give the model the flexibility to apply
non-linear processing only where needed. VSNs allow the model to softly remove any
unnecessary noisy inputs which could negatively impact performance.
Together, these techniques help improve the learning capacity of deep neural
network models.

Note that this example implements only the GRN and VSN components described in
the paper, rather than the whole TFT model, as GRN and VSN can be useful on
their own for structured data learning tasks.

The example is backend-agnostic, but TensorFlow is still required for the
`tf.data` input pipelines used below.
"""

"""
## The dataset

This example uses the
[United States Census Income Dataset](https://archive.ics.uci.edu/ml/datasets/Census-Income+%28KDD%29)
provided by the
[UC Irvine Machine Learning Repository](https://archive.ics.uci.edu/ml/index.php).
The task is binary classification to determine whether a person makes over 50K a year.

The dataset includes ~300K instances with 41 input features: 7 numerical features
and 34 categorical features.
"""

"""
## Setup
"""

import os
import subprocess
import tarfile

os.environ["KERAS_BACKEND"] = "torch"  # or jax, or tensorflow

import numpy as np
import pandas as pd
import keras
from keras import layers

"""
61
## Prepare the data
62
63
First we load the data from the UCI Machine Learning Repository into a Pandas DataFrame.
64
"""

# Column names.
CSV_HEADER = [
    "age",
    "class_of_worker",
    "detailed_industry_recode",
    "detailed_occupation_recode",
    "education",
    "wage_per_hour",
    "enroll_in_edu_inst_last_wk",
    "marital_stat",
    "major_industry_code",
    "major_occupation_code",
    "race",
    "hispanic_origin",
    "sex",
    "member_of_a_labor_union",
    "reason_for_unemployment",
    "full_or_part_time_employment_stat",
    "capital_gains",
    "capital_losses",
    "dividends_from_stocks",
    "tax_filer_stat",
    "region_of_previous_residence",
    "state_of_previous_residence",
    "detailed_household_and_family_stat",
    "detailed_household_summary_in_household",
    "instance_weight",
    "migration_code-change_in_msa",
    "migration_code-change_in_reg",
    "migration_code-move_within_reg",
    "live_in_this_house_1_year_ago",
    "migration_prev_res_in_sunbelt",
    "num_persons_worked_for_employer",
    "family_members_under_18",
    "country_of_birth_father",
    "country_of_birth_mother",
    "country_of_birth_self",
    "citizenship",
    "own_business_or_self_employed",
    "fill_inc_questionnaire_for_veterans_admin",
    "veterans_benefits",
    "weeks_worked_in_year",
    "year",
    "income_level",
]

data_url = "https://archive.ics.uci.edu/static/public/117/census+income+kdd.zip"
keras.utils.get_file(origin=data_url, extract=True)

"""
Determine the path of the downloaded `.tar.gz` archive and
extract the files it contains.
"""

extracted_path = os.path.join(
    os.path.expanduser("~"), ".keras", "datasets", "census+income+kdd.zip"
)
for root, dirs, files in os.walk(extracted_path):
    for file in files:
        if file.endswith(".tar.gz"):
            tar_gz_path = os.path.join(root, file)
            with tarfile.open(tar_gz_path, "r:gz") as tar:
                tar.extractall(path=root)

train_data_path = os.path.join(
    os.path.expanduser("~"),
    ".keras",
    "datasets",
    "census+income+kdd.zip",
    "census-income.data",
)
test_data_path = os.path.join(
    os.path.expanduser("~"),
    ".keras",
    "datasets",
    "census+income+kdd.zip",
    "census-income.test",
)

data = pd.read_csv(train_data_path, header=None, names=CSV_HEADER)
test_data = pd.read_csv(test_data_path, header=None, names=CSV_HEADER)

print(f"Data shape: {data.shape}")
print(f"Test data shape: {test_data.shape}")

"""
We convert the target column from string to integer.
"""

data["income_level"] = data["income_level"].apply(
    lambda x: 0 if x == " - 50000." else 1
)
test_data["income_level"] = test_data["income_level"].apply(
    lambda x: 0 if x == " - 50000." else 1
)

"""
Then, we split the dataset into train and validation sets.
"""

random_selection = np.random.rand(len(data.index)) <= 0.85
train_data = data[random_selection]
valid_data = data[~random_selection]

"""
Finally, we store the train, validation, and test splits locally as CSV files.
"""

train_data_file = "train_data.csv"
valid_data_file = "valid_data.csv"
test_data_file = "test_data.csv"

train_data.to_csv(train_data_file, index=False, header=False)
valid_data.to_csv(valid_data_file, index=False, header=False)
test_data.to_csv(test_data_file, index=False, header=False)

"""
## Define dataset metadata

Here, we define the metadata of the dataset that will be useful for reading and
parsing the data into input features, and encoding the input features with respect
to their types.
"""

# Target feature name.
TARGET_FEATURE_NAME = "income_level"
# Weight column name.
WEIGHT_COLUMN_NAME = "instance_weight"
# Numeric feature names.
NUMERIC_FEATURE_NAMES = [
    "age",
    "wage_per_hour",
    "capital_gains",
    "capital_losses",
    "dividends_from_stocks",
    "num_persons_worked_for_employer",
    "weeks_worked_in_year",
]
# Categorical features and their vocabulary lists.
# Note that we convert all categorical feature values to strings so that the
# vocabularies are consistently typed.
CATEGORICAL_FEATURES_WITH_VOCABULARY = {
    feature_name: sorted([str(value) for value in list(data[feature_name].unique())])
    for feature_name in CSV_HEADER
    if feature_name
    not in list(NUMERIC_FEATURE_NAMES + [WEIGHT_COLUMN_NAME, TARGET_FEATURE_NAME])
}
# All feature names.
FEATURE_NAMES = NUMERIC_FEATURE_NAMES + list(
    CATEGORICAL_FEATURES_WITH_VOCABULARY.keys()
)
# Feature default values.
COLUMN_DEFAULTS = [
    (
        [0.0]
        if feature_name
        in NUMERIC_FEATURE_NAMES + [TARGET_FEATURE_NAME, WEIGHT_COLUMN_NAME]
        else ["NA"]
    )
    for feature_name in CSV_HEADER
]
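
"""
As a quick sanity check (illustrative only, not part of the original pipeline), we
can inspect a couple of the vocabularies built above:

```python
for name in list(CATEGORICAL_FEATURES_WITH_VOCABULARY)[:2]:
    print(name, len(CATEGORICAL_FEATURES_WITH_VOCABULARY[name]))
```
"""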

"""
## Create a `tf.data.Dataset` for training and evaluation

We create an input function to read and parse the files, and to convert features
and labels into a [`tf.data.Dataset`](https://www.tensorflow.org/guide/datasets)
for training and evaluation.
"""

# TensorFlow is required for the tf.data.Dataset pipelines.
import tensorflow as tf


# We pre-process the categorical features here, converting their string values to
# integer indices, so that this step is not needed during model training
# (only the TensorFlow backend supports string tensors).
def process(features, target):
    for feature_name in features:
        if feature_name in CATEGORICAL_FEATURES_WITH_VOCABULARY:
            # Cast categorical feature values to string.
            features[feature_name] = tf.cast(features[feature_name], "string")
            vocabulary = CATEGORICAL_FEATURES_WITH_VOCABULARY[feature_name]
            # Create a lookup to convert string values to integer indices.
            # Since we are not using a mask token nor expecting any out-of-vocabulary
            # (OOV) token, we set mask_token to None and num_oov_indices to 0.
            index = layers.StringLookup(
                vocabulary=vocabulary,
                mask_token=None,
                num_oov_indices=0,
                output_mode="int",
            )
            # Convert the string input values into integer indices.
            value_index = index(features[feature_name])
            features[feature_name] = value_index
        else:
            # Leave numerical features as they are.
            pass

    # Get the instance weight.
    weight = features.pop(WEIGHT_COLUMN_NAME)
    # Change features from OrderedDict to Dict to match the model Inputs, which are a Dict.
    return dict(features), target, weight


def get_dataset_from_csv(csv_file_path, shuffle=False, batch_size=128):
    dataset = tf.data.experimental.make_csv_dataset(
        csv_file_path,
        batch_size=batch_size,
        column_names=CSV_HEADER,
        column_defaults=COLUMN_DEFAULTS,
        label_name=TARGET_FEATURE_NAME,
        num_epochs=1,
        header=False,
        shuffle=shuffle,
    ).map(process)

    return dataset
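
"""
As a quick, illustrative check (assuming the CSV files created above exist), we can
peek at a single batch to verify that each element is a `(features, target, weight)`
tuple, with the features keyed by column name:

```python
example_ds = get_dataset_from_csv(train_data_file, batch_size=4)
for features, target, weight in example_ds.take(1):
    print(sorted(features.keys())[:3])
    print(target.shape, weight.shape)
```
"""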


"""
## Create model inputs
"""


def create_model_inputs():
    inputs = {}
    for feature_name in FEATURE_NAMES:
        if feature_name in CATEGORICAL_FEATURES_WITH_VOCABULARY:
            # Categorical features are encoded as int64 indices.
            inputs[feature_name] = layers.Input(
                name=feature_name, shape=(), dtype="int64"
            )
        else:
            # Numerical features are float32 real values.
            inputs[feature_name] = layers.Input(
                name=feature_name, shape=(), dtype="float32"
            )
    return inputs


"""
## Implement the Gated Linear Unit

[Gated Linear Units (GLUs)](https://arxiv.org/abs/1612.08083) provide the
flexibility to suppress inputs that are not relevant for a given task.
"""


class GatedLinearUnit(layers.Layer):
    def __init__(self, units):
        super().__init__()
        self.linear = layers.Dense(units)
        self.sigmoid = layers.Dense(units, activation="sigmoid")

    def call(self, inputs):
        return self.linear(inputs) * self.sigmoid(inputs)

    # Remove build warnings
    def build(self):
        self.built = True
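
"""
The GLU output is the element-wise product of a linear projection and a sigmoid
gate, so each of the `units` output dimensions can be suppressed independently.
A minimal, illustrative shape check (not part of the original example):

```python
glu = GatedLinearUnit(units=8)
print(glu(keras.ops.ones((4, 16))).shape)  # (4, 8)
```
"""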


"""
## Implement the Gated Residual Network

The Gated Residual Network (GRN) works as follows:

1. Applies a nonlinear ELU transformation to the inputs.
2. Applies a linear transformation followed by dropout.
3. Applies a GLU and adds the original inputs to the output of the GLU to form a skip
(residual) connection.
4. Applies layer normalization and produces the output.
"""


class GatedResidualNetwork(layers.Layer):
    def __init__(self, units, dropout_rate):
        super().__init__()
        self.units = units
        self.elu_dense = layers.Dense(units, activation="elu")
        self.linear_dense = layers.Dense(units)
        self.dropout = layers.Dropout(dropout_rate)
        self.gated_linear_unit = GatedLinearUnit(units)
        self.layer_norm = layers.LayerNormalization()
        self.project = layers.Dense(units)

    def call(self, inputs):
        x = self.elu_dense(inputs)
        x = self.linear_dense(x)
        x = self.dropout(x)
        if inputs.shape[-1] != self.units:
            # Project the inputs so the residual addition has matching dimensions.
            inputs = self.project(inputs)
        x = inputs + self.gated_linear_unit(x)
        x = self.layer_norm(x)
        return x

    # Remove build warnings
    def build(self):
        self.built = True
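
"""
Because of the projection step in `call`, a GRN can consume inputs of any width and
always returns a tensor whose last dimension equals `units`. A minimal, illustrative
shape check (not part of the original example):

```python
grn_a = GatedResidualNetwork(units=8, dropout_rate=0.1)
print(grn_a(keras.ops.ones((4, 16))).shape)  # (4, 8) -- inputs projected for the residual
grn_b = GatedResidualNetwork(units=8, dropout_rate=0.1)
print(grn_b(keras.ops.ones((4, 8))).shape)   # (4, 8) -- widths already match
```
"""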


"""
## Implement the Variable Selection Network

The Variable Selection Network (VSN) works as follows:

1. Applies a GRN to each feature individually.
2. Applies a GRN to the concatenation of all the features, followed by a softmax to
produce feature weights.
3. Produces a weighted sum of the outputs of the individual GRNs.

Note that the output of the VSN is `[batch_size, encoding_size]`, regardless of the
number of input features. A sketch of this weighted-sum step is shown after the
class definition below.

We encode categorical features with `layers.Embedding`, using `encoding_size` as the
embedding dimension. We project numerical features with `layers.Dense` into
`encoding_size`-dimensional vectors. Thus, all the encoded features have the same
dimensionality.
"""


class VariableSelection(layers.Layer):
    def __init__(self, num_features, units, dropout_rate):
        super().__init__()
        self.units = units
        # Create embedding layers with the specified dimensions.
        self.embeddings = dict()
        for input_ in CATEGORICAL_FEATURES_WITH_VOCABULARY:
            vocabulary = CATEGORICAL_FEATURES_WITH_VOCABULARY[input_]
            embedding_encoder = layers.Embedding(
                input_dim=len(vocabulary), output_dim=self.units, name=input_
            )
            self.embeddings[input_] = embedding_encoder

        # Projection layers for numeric features.
        self.proj_layer = dict()
        for input_ in NUMERIC_FEATURE_NAMES:
            proj_layer = layers.Dense(units=self.units)
            self.proj_layer[input_] = proj_layer

        self.grns = list()
        # Create a GRN for each feature independently.
        for idx in range(num_features):
            grn = GatedResidualNetwork(units, dropout_rate)
            self.grns.append(grn)
        # Create a GRN for the concatenation of all the features.
        self.grn_concat = GatedResidualNetwork(units, dropout_rate)
        self.softmax = layers.Dense(units=num_features, activation="softmax")

    def call(self, inputs):
        concat_inputs = []
        for input_ in inputs:
            if input_ in CATEGORICAL_FEATURES_WITH_VOCABULARY:
                # Clamp the indices: the torch backend raised index errors during
                # the embedding lookup, hence the clip.
                max_index = self.embeddings[input_].input_dim - 1
                embedded_feature = self.embeddings[input_](
                    keras.ops.clip(inputs[input_], 0, max_index)
                )
                concat_inputs.append(embedded_feature)
            else:
                # Project the numeric feature to encoding_size using a linear transformation.
                proj_feature = keras.ops.expand_dims(inputs[input_], -1)
                proj_feature = self.proj_layer[input_](proj_feature)
                concat_inputs.append(proj_feature)

        v = layers.concatenate(concat_inputs)
        v = self.grn_concat(v)
        v = keras.ops.expand_dims(self.softmax(v), axis=-1)
        x = []
        for idx, input in enumerate(concat_inputs):
            x.append(self.grns[idx](input))
        x = keras.ops.stack(x, axis=1)
        # Weighted sum over the feature axis:
        # (batch, 1, num_features) @ (batch, num_features, units) -> (batch, units).
        return keras.ops.squeeze(
            keras.ops.matmul(keras.ops.transpose(v, axes=[0, 2, 1]), x), axis=1
        )

    # Remove build warnings
    def build(self):
        self.built = True
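
"""
To make the weighted-sum step above concrete, here is a minimal, standalone sketch
(illustrative only, using toy shapes): given per-feature GRN outputs stacked as
`(batch, num_features, units)` and softmax weights of shape `(batch, num_features, 1)`,
the VSN output is their weighted sum over the feature axis, of shape `(batch, units)`:

```python
batch, num_features, units = 2, 3, 4
x = keras.ops.ones((batch, num_features, units))
weights = keras.ops.softmax(keras.ops.ones((batch, num_features)), axis=-1)
weights = keras.ops.expand_dims(weights, axis=-1)
out = keras.ops.squeeze(
    keras.ops.matmul(keras.ops.transpose(weights, axes=[0, 2, 1]), x), axis=1
)
print(out.shape)  # (2, 4)
```
"""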


"""
## Create Gated Residual and Variable Selection Networks model
"""


def create_model(encoding_size):
    inputs = create_model_inputs()
    num_features = len(inputs)
    features = VariableSelection(num_features, encoding_size, dropout_rate)(inputs)
    outputs = layers.Dense(units=1, activation="sigmoid")(features)
    # Functional model.
    model = keras.Model(inputs=inputs, outputs=outputs)
    return model


"""
## Compile, train, and evaluate the model
"""

learning_rate = 0.001
dropout_rate = 0.15
batch_size = 265
num_epochs = 20  # may be adjusted to a desired value
encoding_size = 16

model = create_model(encoding_size)
model.compile(
    optimizer=keras.optimizers.Adam(learning_rate=learning_rate),
    loss=keras.losses.BinaryCrossentropy(),
    metrics=[keras.metrics.BinaryAccuracy(name="accuracy")],
)

"""
Let's visualize our connectivity graph:
"""

# `rankdir='LR'` is to make the graph horizontal.
keras.utils.plot_model(model, show_shapes=True, show_layer_names=True, rankdir="LR")


# Create an early stopping callback.
early_stopping = keras.callbacks.EarlyStopping(
    monitor="val_loss", patience=5, restore_best_weights=True
)

print("Start training the model...")
train_dataset = get_dataset_from_csv(
    train_data_file, shuffle=True, batch_size=batch_size
)
valid_dataset = get_dataset_from_csv(valid_data_file, batch_size=batch_size)
model.fit(
    train_dataset,
    epochs=num_epochs,
    validation_data=valid_dataset,
    callbacks=[early_stopping],
)
print("Model training finished.")

print("Evaluating model performance...")
test_dataset = get_dataset_from_csv(test_data_file, batch_size=batch_size)
_, accuracy = model.evaluate(test_dataset)
print(f"Test accuracy: {round(accuracy * 100, 2)}%")

"""
You should achieve more than 95% accuracy on the test set.

To increase the learning capacity of the model, you can try increasing the
`encoding_size` value, or stacking multiple GRN layers on top of the VSN layer,
as sketched below. This may also require increasing the `dropout_rate` value to
avoid overfitting.
"""

"""
**Example available on HuggingFace**

| Trained Model | Demo |
| :--: | :--: |
| [![Generic badge](https://img.shields.io/badge/%F0%9F%A4%97%20Model-Classification%20With%20GRN%20%26%20VSN-red)](https://huggingface.co/keras-io/structured-data-classification-grn-vsn) | [![Generic badge](https://img.shields.io/badge/%F0%9F%A4%97%20Space-Classification%20With%20GRN%20%26%20VSN-red)](https://huggingface.co/spaces/keras-io/structured-data-classification-grn-vsn) |
"""