Path: blob/master/examples/structured_data/classification_with_grn_and_vsn.py
3507 views
"""1Title: Classification with Gated Residual and Variable Selection Networks2Author: [Khalid Salama](https://www.linkedin.com/in/khalid-salama-24403144/)3Date created: 2021/02/104Last modified: 2025/01/085Description: Using Gated Residual and Variable Selection Networks for income level prediction.6Accelerator: GPU7Converted to Keras 3 by: [Sitam Meur](https://github.com/sitamgithub-MSIT) and made backend-agnostic by: [Humbulani Ndou](https://github.com/Humbulani1234)8"""910"""11## Introduction1213This example demonstrates the use of Gated14Residual Networks (GRN) and Variable Selection Networks (VSN), proposed by15Bryan Lim et al. in16[Temporal Fusion Transformers (TFT) for Interpretable Multi-horizon Time Series Forecasting](https://arxiv.org/abs/1912.09363),17for structured data classification. GRNs give the flexibility to the model to apply18non-linear processing only where needed. VSNs allow the model to softly remove any19unnecessary noisy inputs which could negatively impact performance.20Together, those techniques help improving the learning capacity of deep neural21network models.2223Note that this example implements only the GRN and VSN components described in24in the paper, rather than the whole TFT model, as GRN and VSN can be useful on25their own for structured data learning tasks.262728To run the code you need to use TensorFlow 2.3 or higher.29"""3031"""32## The dataset3334This example uses the35[United States Census Income Dataset](https://archive.ics.uci.edu/ml/datasets/Census-Income+%28KDD%29)36provided by the37[UC Irvine Machine Learning Repository](https://archive.ics.uci.edu/ml/index.php).38The task is binary classification to determine whether a person makes over 50K a year.3940The dataset includes ~300K instances with 41 input features: 7 numerical features41and 34 categorical features.42"""4344"""45## Setup46"""4748import os49import subprocess50import tarfile5152os.environ["KERAS_BACKEND"] = "torch" # or jax, or tensorflow5354import numpy as np55import pandas as pd56import keras57from keras import layers5859"""60## Prepare the data6162First we load the data from the UCI Machine Learning Repository into a Pandas DataFrame.63"""6465# Column names.66CSV_HEADER = [67"age",68"class_of_worker",69"detailed_industry_recode",70"detailed_occupation_recode",71"education",72"wage_per_hour",73"enroll_in_edu_inst_last_wk",74"marital_stat",75"major_industry_code",76"major_occupation_code",77"race",78"hispanic_origin",79"sex",80"member_of_a_labor_union",81"reason_for_unemployment",82"full_or_part_time_employment_stat",83"capital_gains",84"capital_losses",85"dividends_from_stocks",86"tax_filer_stat",87"region_of_previous_residence",88"state_of_previous_residence",89"detailed_household_and_family_stat",90"detailed_household_summary_in_household",91"instance_weight",92"migration_code-change_in_msa",93"migration_code-change_in_reg",94"migration_code-move_within_reg",95"live_in_this_house_1_year_ago",96"migration_prev_res_in_sunbelt",97"num_persons_worked_for_employer",98"family_members_under_18",99"country_of_birth_father",100"country_of_birth_mother",101"country_of_birth_self",102"citizenship",103"own_business_or_self_employed",104"fill_inc_questionnaire_for_veterans_admin",105"veterans_benefits",106"weeks_worked_in_year",107"year",108"income_level",109]110111data_url = "https://archive.ics.uci.edu/static/public/117/census+income+kdd.zip"112keras.utils.get_file(origin=data_url, extract=True)113114"""115Determine the downloaded .tar.gz file path and116extract the files from the downloaded .tar.gz file117"""118119extracted_path = os.path.join(120os.path.expanduser("~"), ".keras", "datasets", "census+income+kdd.zip"121)122for root, dirs, files in os.walk(extracted_path):123for file in files:124if file.endswith(".tar.gz"):125tar_gz_path = os.path.join(root, file)126with tarfile.open(tar_gz_path, "r:gz") as tar:127tar.extractall(path=root)128129train_data_path = os.path.join(130os.path.expanduser("~"),131".keras",132"datasets",133"census+income+kdd.zip",134"census-income.data",135)136test_data_path = os.path.join(137os.path.expanduser("~"),138".keras",139"datasets",140"census+income+kdd.zip",141"census-income.test",142)143144data = pd.read_csv(train_data_path, header=None, names=CSV_HEADER)145test_data = pd.read_csv(test_data_path, header=None, names=CSV_HEADER)146147print(f"Data shape: {data.shape}")148print(f"Test data shape: {test_data.shape}")149150151"""152We convert the target column from string to integer.153"""154155data["income_level"] = data["income_level"].apply(156lambda x: 0 if x == " - 50000." else 1157)158test_data["income_level"] = test_data["income_level"].apply(159lambda x: 0 if x == " - 50000." else 1160)161162163"""164Then, We split the dataset into train and validation sets.165"""166167random_selection = np.random.rand(len(data.index)) <= 0.85168train_data = data[random_selection]169valid_data = data[~random_selection]170171172"""173Finally we store the train and test data splits locally to CSV files.174"""175176train_data_file = "train_data.csv"177valid_data_file = "valid_data.csv"178test_data_file = "test_data.csv"179180train_data.to_csv(train_data_file, index=False, header=False)181valid_data.to_csv(valid_data_file, index=False, header=False)182test_data.to_csv(test_data_file, index=False, header=False)183184"""185## Define dataset metadata186187Here, we define the metadata of the dataset that will be useful for reading and188parsing the data into input features, and encoding the input features with respect189to their types.190"""191192# Target feature name.193TARGET_FEATURE_NAME = "income_level"194# Weight column name.195WEIGHT_COLUMN_NAME = "instance_weight"196# Numeric feature names.197NUMERIC_FEATURE_NAMES = [198"age",199"wage_per_hour",200"capital_gains",201"capital_losses",202"dividends_from_stocks",203"num_persons_worked_for_employer",204"weeks_worked_in_year",205]206# Categorical features and their vocabulary lists.207# Note that we add 'v=' as a prefix to all categorical feature values to make208# sure that they are treated as strings.209CATEGORICAL_FEATURES_WITH_VOCABULARY = {210feature_name: sorted([str(value) for value in list(data[feature_name].unique())])211for feature_name in CSV_HEADER212if feature_name213not in list(NUMERIC_FEATURE_NAMES + [WEIGHT_COLUMN_NAME, TARGET_FEATURE_NAME])214}215# All features names.216FEATURE_NAMES = NUMERIC_FEATURE_NAMES + list(217CATEGORICAL_FEATURES_WITH_VOCABULARY.keys()218)219# Feature default values.220COLUMN_DEFAULTS = [221(222[0.0]223if feature_name224in NUMERIC_FEATURE_NAMES + [TARGET_FEATURE_NAME, WEIGHT_COLUMN_NAME]225else ["NA"]226)227for feature_name in CSV_HEADER228]229230"""231## Create a `tf.data.Dataset` for training and evaluation232233We create an input function to read and parse the file, and convert features and234labels into a [`tf.data.Dataset`](https://www.tensorflow.org/guide/datasets) for235training and evaluation.236"""237238# Tensorflow required for tf.data.Datasets239import tensorflow as tf240241242# We process our datasets elements here (categorical) and convert them to indices to avoid this step243# during model training since only tensorflow support strings.244def process(features, target):245for feature_name in features:246if feature_name in CATEGORICAL_FEATURES_WITH_VOCABULARY:247# Cast categorical feature values to string.248features[feature_name] = tf.cast(features[feature_name], "string")249vocabulary = CATEGORICAL_FEATURES_WITH_VOCABULARY[feature_name]250# Create a lookup to convert a string values to an integer indices.251# Since we are not using a mask token nor expecting any out of vocabulary252# (oov) token, we set mask_token to None and num_oov_indices to 0.253index = layers.StringLookup(254vocabulary=vocabulary,255mask_token=None,256num_oov_indices=0,257output_mode="int",258)259# Convert the string input values into integer indices.260value_index = index(features[feature_name])261features[feature_name] = value_index262else:263# Do nothing for numerical features264pass265266# Get the instance weight.267weight = features.pop(WEIGHT_COLUMN_NAME)268# Change features from OrderedDict to Dict to match Inputs as they are Dict.269return dict(features), target, weight270271272def get_dataset_from_csv(csv_file_path, shuffle=False, batch_size=128):273dataset = tf.data.experimental.make_csv_dataset(274csv_file_path,275batch_size=batch_size,276column_names=CSV_HEADER,277column_defaults=COLUMN_DEFAULTS,278label_name=TARGET_FEATURE_NAME,279num_epochs=1,280header=False,281shuffle=shuffle,282).map(process)283284return dataset285286287"""288## Create model inputs289"""290291292def create_model_inputs():293inputs = {}294for feature_name in FEATURE_NAMES:295if feature_name in CATEGORICAL_FEATURES_WITH_VOCABULARY:296# Make them int64, they are Categorical (whole units)297inputs[feature_name] = layers.Input(298name=feature_name, shape=(), dtype="int64"299)300else:301# Make them float32, they are Real numbers302inputs[feature_name] = layers.Input(303name=feature_name, shape=(), dtype="float32"304)305return inputs306307308"""309## Implement the Gated Linear Unit310311[Gated Linear Units (GLUs)](https://arxiv.org/abs/1612.08083) provide the312flexibility to suppress input that are not relevant for a given task.313"""314315316class GatedLinearUnit(layers.Layer):317def __init__(self, units):318super().__init__()319self.linear = layers.Dense(units)320self.sigmoid = layers.Dense(units, activation="sigmoid")321322def call(self, inputs):323return self.linear(inputs) * self.sigmoid(inputs)324325# Remove build warnings326def build(self):327self.built = True328329330"""331## Implement the Gated Residual Network332333The Gated Residual Network (GRN) works as follows:3343351. Applies the nonlinear ELU transformation to the inputs.3362. Applies linear transformation followed by dropout.3374. Applies GLU and adds the original inputs to the output of the GLU to perform skip338(residual) connection.3396. Applies layer normalization and produces the output.340"""341342343class GatedResidualNetwork(layers.Layer):344def __init__(self, units, dropout_rate):345super().__init__()346self.units = units347self.elu_dense = layers.Dense(units, activation="elu")348self.linear_dense = layers.Dense(units)349self.dropout = layers.Dropout(dropout_rate)350self.gated_linear_unit = GatedLinearUnit(units)351self.layer_norm = layers.LayerNormalization()352self.project = layers.Dense(units)353354def call(self, inputs):355x = self.elu_dense(inputs)356x = self.linear_dense(x)357x = self.dropout(x)358if inputs.shape[-1] != self.units:359inputs = self.project(inputs)360x = inputs + self.gated_linear_unit(x)361x = self.layer_norm(x)362return x363364# Remove build warnings365def build(self):366self.built = True367368369"""370## Implement the Variable Selection Network371372The Variable Selection Network (VSN) works as follows:3733741. Applies a GRN to each feature individually.3752. Applies a GRN on the concatenation of all the features, followed by a softmax to376produce feature weights.3773. Produces a weighted sum of the output of the individual GRN.378379Note that the output of the VSN is [batch_size, encoding_size], regardless of the380number of the input features.381382For categorical features, we encode them using `layers.Embedding` using the383`encoding_size` as the embedding dimensions. For the numerical features,384we apply linear transformation using `layers.Dense` to project each feature into385`encoding_size`-dimensional vector. Thus, all the encoded features will have the386same dimensionality.387388"""389390391class VariableSelection(layers.Layer):392def __init__(self, num_features, units, dropout_rate):393super().__init__()394self.units = units395# Create an embedding layers with the specified dimensions396self.embeddings = dict()397for input_ in CATEGORICAL_FEATURES_WITH_VOCABULARY:398vocabulary = CATEGORICAL_FEATURES_WITH_VOCABULARY[input_]399embedding_encoder = layers.Embedding(400input_dim=len(vocabulary), output_dim=self.units, name=input_401)402self.embeddings[input_] = embedding_encoder403404# Projection layers for numeric features405self.proj_layer = dict()406for input_ in NUMERIC_FEATURE_NAMES:407proj_layer = layers.Dense(units=self.units)408self.proj_layer[input_] = proj_layer409410self.grns = list()411# Create a GRN for each feature independently412for idx in range(num_features):413grn = GatedResidualNetwork(units, dropout_rate)414self.grns.append(grn)415# Create a GRN for the concatenation of all the features416self.grn_concat = GatedResidualNetwork(units, dropout_rate)417self.softmax = layers.Dense(units=num_features, activation="softmax")418419def call(self, inputs):420concat_inputs = []421for input_ in inputs:422if input_ in CATEGORICAL_FEATURES_WITH_VOCABULARY:423max_index = self.embeddings[input_].input_dim - 1 # Clamp the indices424# torch had some index errors during embedding hence the clip function425embedded_feature = self.embeddings[input_](426keras.ops.clip(inputs[input_], 0, max_index)427)428concat_inputs.append(embedded_feature)429else:430# Project the numeric feature to encoding_size using linear transformation.431proj_feature = keras.ops.expand_dims(inputs[input_], -1)432proj_feature = self.proj_layer[input_](proj_feature)433concat_inputs.append(proj_feature)434435v = layers.concatenate(concat_inputs)436v = self.grn_concat(v)437v = keras.ops.expand_dims(self.softmax(v), axis=-1)438x = []439for idx, input in enumerate(concat_inputs):440x.append(self.grns[idx](input))441x = keras.ops.stack(x, axis=1)442return keras.ops.squeeze(443keras.ops.matmul(keras.ops.transpose(v, axes=[0, 2, 1]), x), axis=1444)445446# to remove the build warnings447def build(self):448self.built = True449450451"""452## Create Gated Residual and Variable Selection Networks model453"""454455456def create_model(encoding_size):457inputs = create_model_inputs()458num_features = len(inputs)459features = VariableSelection(num_features, encoding_size, dropout_rate)(inputs)460outputs = layers.Dense(units=1, activation="sigmoid")(features)461# Functional model462model = keras.Model(inputs=inputs, outputs=outputs)463return model464465466"""467## Compile, train, and evaluate the model468"""469470learning_rate = 0.001471dropout_rate = 0.15472batch_size = 265473num_epochs = 20 # may be adjusted to a desired value474encoding_size = 16475476model = create_model(encoding_size)477model.compile(478optimizer=keras.optimizers.Adam(learning_rate=learning_rate),479loss=keras.losses.BinaryCrossentropy(),480metrics=[keras.metrics.BinaryAccuracy(name="accuracy")],481)482483"""484Let's visualize our connectivity graph:485"""486487# `rankdir='LR'` is to make the graph horizontal.488keras.utils.plot_model(model, show_shapes=True, show_layer_names=True, rankdir="LR")489490491# Create an early stopping callback.492early_stopping = keras.callbacks.EarlyStopping(493monitor="val_loss", patience=5, restore_best_weights=True494)495496print("Start training the model...")497train_dataset = get_dataset_from_csv(498train_data_file, shuffle=True, batch_size=batch_size499)500valid_dataset = get_dataset_from_csv(valid_data_file, batch_size=batch_size)501model.fit(502train_dataset,503epochs=num_epochs,504validation_data=valid_dataset,505callbacks=[early_stopping],506)507print("Model training finished.")508509print("Evaluating model performance...")510test_dataset = get_dataset_from_csv(test_data_file, batch_size=batch_size)511_, accuracy = model.evaluate(test_dataset)512print(f"Test accuracy: {round(accuracy * 100, 2)}%")513514"""515You should achieve more than 95% accuracy on the test set.516517To increase the learning capacity of the model, you can try increasing the518`encoding_size` value, or stacking multiple GRN layers on top of the VSN layer.519This may require to also increase the `dropout_rate` value to avoid overfitting.520"""521522"""523**Example available on HuggingFace**524525| Trained Model | Demo |526| :--: | :--: |527| [](https://huggingface.co/keras-io/structured-data-classification-grn-vsn) | [](https://huggingface.co/spaces/keras-io/structured-data-classification-grn-vsn) |528"""529530531