"""1Title: Structured data learning with Wide, Deep, and Cross networks2Author: [Khalid Salama](https://www.linkedin.com/in/khalid-salama-24403144/)3Date created: 2020/12/314Last modified: 2025/01/035Description: Using Wide & Deep and Deep & Cross networks for structured data classification.6Accelerator: GPU7"""89"""10## Introduction1112This example demonstrates how to do structured data classification using the two modeling13techniques:14151. [Wide & Deep](https://ai.googleblog.com/2016/06/wide-deep-learning-better-together-with.html) models162. [Deep & Cross](https://arxiv.org/abs/1708.05123) models1718Note that this example should be run with TensorFlow 2.5 or higher.19"""2021"""22## The dataset2324This example uses the [Covertype](https://archive.ics.uci.edu/ml/datasets/covertype) dataset from the UCI25Machine Learning Repository. The task is to predict forest cover type from cartographic variables.26The dataset includes 506,011 instances with 12 input features: 10 numerical features and 227categorical features. Each instance is categorized into 1 of 7 classes.28"""2930"""31## Setup32"""3334import os3536# Only the TensorFlow backend supports string inputs.37os.environ["KERAS_BACKEND"] = "tensorflow"3839import math40import numpy as np41import pandas as pd42from tensorflow import data as tf_data43import keras44from keras import layers4546"""47## Prepare the data4849First, let's load the dataset from the UCI Machine Learning Repository into a Pandas50DataFrame:51"""5253data_url = (54"https://archive.ics.uci.edu/ml/machine-learning-databases/covtype/covtype.data.gz"55)56raw_data = pd.read_csv(data_url, header=None)57print(f"Dataset shape: {raw_data.shape}")58raw_data.head()5960"""61The two categorical features in the dataset are binary-encoded.62We will convert this dataset representation to the typical representation, where each63categorical feature is represented as a single integer value.64"""6566soil_type_values = [f"soil_type_{idx+1}" for idx in range(40)]67wilderness_area_values = [f"area_type_{idx+1}" for idx in range(4)]6869soil_type = raw_data.loc[:, 14:53].apply(70lambda x: soil_type_values[0::1][x.to_numpy().nonzero()[0][0]], axis=171)72wilderness_area = raw_data.loc[:, 10:13].apply(73lambda x: wilderness_area_values[0::1][x.to_numpy().nonzero()[0][0]], axis=174)7576CSV_HEADER = [77"Elevation",78"Aspect",79"Slope",80"Horizontal_Distance_To_Hydrology",81"Vertical_Distance_To_Hydrology",82"Horizontal_Distance_To_Roadways",83"Hillshade_9am",84"Hillshade_Noon",85"Hillshade_3pm",86"Horizontal_Distance_To_Fire_Points",87"Wilderness_Area",88"Soil_Type",89"Cover_Type",90]9192data = pd.concat(93[raw_data.loc[:, 0:9], wilderness_area, soil_type, raw_data.loc[:, 54]],94axis=1,95ignore_index=True,96)97data.columns = CSV_HEADER9899# Convert the target label indices into a range from 0 to 6 (there are 7 labels in total).100data["Cover_Type"] = data["Cover_Type"] - 1101102print(f"Dataset shape: {data.shape}")103data.head().T104105"""106The shape of the DataFrame shows there are 13 columns per sample107(12 for the features and 1 for the target label).108109Let's split the data into training (85%) and test (15%) sets.110"""111112train_splits = []113test_splits = []114115for _, group_data in data.groupby("Cover_Type"):116random_selection = np.random.rand(len(group_data.index)) <= 0.85117train_splits.append(group_data[random_selection])118test_splits.append(group_data[~random_selection])119120train_data = pd.concat(train_splits).sample(frac=1).reset_index(drop=True)121test_data = 
"""
Next, store the training and test data in separate CSV files.
"""

train_data_file = "train_data.csv"
test_data_file = "test_data.csv"

train_data.to_csv(train_data_file, index=False)
test_data.to_csv(test_data_file, index=False)

"""
## Define dataset metadata

Here, we define the metadata of the dataset that will be useful for reading and parsing
the data into input features, and encoding the input features with respect to their types.
"""

TARGET_FEATURE_NAME = "Cover_Type"

TARGET_FEATURE_LABELS = ["0", "1", "2", "3", "4", "5", "6"]

NUMERIC_FEATURE_NAMES = [
    "Aspect",
    "Elevation",
    "Hillshade_3pm",
    "Hillshade_9am",
    "Hillshade_Noon",
    "Horizontal_Distance_To_Fire_Points",
    "Horizontal_Distance_To_Hydrology",
    "Horizontal_Distance_To_Roadways",
    "Slope",
    "Vertical_Distance_To_Hydrology",
]

CATEGORICAL_FEATURES_WITH_VOCABULARY = {
    "Soil_Type": list(data["Soil_Type"].unique()),
    "Wilderness_Area": list(data["Wilderness_Area"].unique()),
}

CATEGORICAL_FEATURE_NAMES = list(CATEGORICAL_FEATURES_WITH_VOCABULARY.keys())

FEATURE_NAMES = NUMERIC_FEATURE_NAMES + CATEGORICAL_FEATURE_NAMES

COLUMN_DEFAULTS = [
    [0] if feature_name in NUMERIC_FEATURE_NAMES + [TARGET_FEATURE_NAME] else ["NA"]
    for feature_name in CSV_HEADER
]

NUM_CLASSES = len(TARGET_FEATURE_LABELS)

"""
## Experiment setup

Next, let's define an input function that reads and parses the file, then converts features
and labels into a [`tf.data.Dataset`](https://www.tensorflow.org/guide/datasets)
for training or evaluation.
"""


# Convert the dataset elements from an OrderedDict to a plain dictionary.
def process(features, target):
    return dict(features), target


def get_dataset_from_csv(csv_file_path, batch_size, shuffle=False):
    dataset = tf_data.experimental.make_csv_dataset(
        csv_file_path,
        batch_size=batch_size,
        column_names=CSV_HEADER,
        column_defaults=COLUMN_DEFAULTS,
        label_name=TARGET_FEATURE_NAME,
        num_epochs=1,
        header=True,
        shuffle=shuffle,
    ).map(process)
    return dataset.cache()
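

"""
To see what this input pipeline produces (an illustrative addition, not part of the original
example), we can pull a single small batch and inspect the feature tensors:
"""

# Take one 5-example batch from the training CSV and print each feature's shape and dtype.
example_features, example_target = next(
    iter(get_dataset_from_csv(train_data_file, batch_size=5))
)
for name, tensor in example_features.items():
    print(f"{name}: shape={tensor.shape}, dtype={tensor.dtype.name}")
print(f"target: shape={example_target.shape}, dtype={example_target.dtype.name}")
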
"""
Here we configure the parameters and implement the procedure for running a training and
evaluation experiment given a model.
"""

learning_rate = 0.001
dropout_rate = 0.1
batch_size = 265
num_epochs = 1

hidden_units = [32, 32]


def run_experiment(model):
    model.compile(
        optimizer=keras.optimizers.Adam(learning_rate=learning_rate),
        loss=keras.losses.SparseCategoricalCrossentropy(),
        metrics=[keras.metrics.SparseCategoricalAccuracy()],
    )

    train_dataset = get_dataset_from_csv(train_data_file, batch_size, shuffle=True)

    test_dataset = get_dataset_from_csv(test_data_file, batch_size)

    print("Start training the model...")
    history = model.fit(train_dataset, epochs=num_epochs)
    print("Model training finished")

    _, accuracy = model.evaluate(test_dataset, verbose=0)

    print(f"Test accuracy: {round(accuracy * 100, 2)}%")


"""
## Create model inputs

Now, define the inputs for the models as a dictionary, where the key is the feature name,
and the value is a `keras.layers.Input` tensor with the corresponding feature shape
and data type.
"""


def create_model_inputs():
    inputs = {}
    for feature_name in FEATURE_NAMES:
        if feature_name in NUMERIC_FEATURE_NAMES:
            inputs[feature_name] = layers.Input(
                name=feature_name, shape=(), dtype="float32"
            )
        else:
            inputs[feature_name] = layers.Input(
                name=feature_name, shape=(), dtype="string"
            )
    return inputs


"""
## Encode features

We create two representations of our input features, sparse and dense:

1. In the **sparse** representation, the categorical features are encoded with one-hot
encoding using the `StringLookup` layer (with `output_mode="binary"`). This representation
can be useful for the model to *memorize* particular feature values to make certain predictions.
2. In the **dense** representation, the categorical features are encoded with
low-dimensional embeddings using the `Embedding` layer. This representation helps
the model to *generalize* well to unseen feature combinations.
"""


def encode_inputs(inputs, use_embedding=False):
    encoded_features = []
    for feature_name in inputs:
        if feature_name in CATEGORICAL_FEATURE_NAMES:
            vocabulary = CATEGORICAL_FEATURES_WITH_VOCABULARY[feature_name]
            # Create a lookup to convert string values to integer indices.
            # Since we are not using a mask token, nor expecting any out-of-vocabulary
            # (OOV) token, we set mask_token to None and num_oov_indices to 0.
            lookup = layers.StringLookup(
                vocabulary=vocabulary,
                mask_token=None,
                num_oov_indices=0,
                output_mode="int" if use_embedding else "binary",
            )
            if use_embedding:
                # Convert the string input values into integer indices.
                encoded_feature = lookup(inputs[feature_name])
                embedding_dims = int(math.sqrt(len(vocabulary)))
                # Create an embedding layer with the specified dimensions.
                embedding = layers.Embedding(
                    input_dim=len(vocabulary), output_dim=embedding_dims
                )
                # Convert the index values to embedding representations.
                encoded_feature = embedding(encoded_feature)
            else:
                # Convert the string input values into a one hot encoding.
                encoded_feature = lookup(
                    keras.ops.expand_dims(inputs[feature_name], -1)
                )
        else:
            # Use the numerical features as-is.
            encoded_feature = keras.ops.expand_dims(inputs[feature_name], -1)

        encoded_features.append(encoded_feature)

    all_features = layers.concatenate(encoded_features)
    return all_features


"""
## Experiment 1: a baseline model

In the first experiment, let's create a multi-layer feed-forward network,
where the categorical features are one-hot encoded.
"""


def create_baseline_model():
    inputs = create_model_inputs()
    features = encode_inputs(inputs)

    for units in hidden_units:
        features = layers.Dense(units)(features)
        features = layers.BatchNormalization()(features)
        features = layers.ReLU()(features)
        features = layers.Dropout(dropout_rate)(features)

    outputs = layers.Dense(units=NUM_CLASSES, activation="softmax")(features)
    model = keras.Model(inputs=inputs, outputs=outputs)
    return model


baseline_model = create_baseline_model()
keras.utils.plot_model(baseline_model, show_shapes=True, rankdir="LR")

"""
Let's run it:
"""

run_experiment(baseline_model)

"""
The baseline model achieves ~76% test accuracy.
"""

"""
## Experiment 2: Wide & Deep model

In the second experiment, we create a Wide & Deep model. The wide part of the model
is a linear model, while the deep part of the model is a multi-layer feed-forward network.

We use the sparse representation of the input features in the wide part of the model and the
dense representation of the input features for the deep part of the model.

Note that every input feature contributes to both parts of the model with different
representations.
"""


def create_wide_and_deep_model():
    inputs = create_model_inputs()
    wide = encode_inputs(inputs)
    wide = layers.BatchNormalization()(wide)

    deep = encode_inputs(inputs, use_embedding=True)
    for units in hidden_units:
        deep = layers.Dense(units)(deep)
        deep = layers.BatchNormalization()(deep)
        deep = layers.ReLU()(deep)
        deep = layers.Dropout(dropout_rate)(deep)

    merged = layers.concatenate([wide, deep])
    outputs = layers.Dense(units=NUM_CLASSES, activation="softmax")(merged)
    model = keras.Model(inputs=inputs, outputs=outputs)
    return model


wide_and_deep_model = create_wide_and_deep_model()
keras.utils.plot_model(wide_and_deep_model, show_shapes=True, rankdir="LR")

"""
Let's run it:
"""

run_experiment(wide_and_deep_model)

"""
The Wide & Deep model achieves ~79% test accuracy.
"""

"""
## Experiment 3: Deep & Cross model

In the third experiment, we create a Deep & Cross model. The deep part of this model
is the same as the deep part created in the previous experiment. The key idea of
the cross part is to apply explicit feature crossing in an efficient way,
where the degree of cross features grows with layer depth.
"""
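
"""
To make the cross operation concrete, here is a minimal sketch (an illustrative addition,
not part of the original example) of a single cross step packaged as a reusable layer. It
computes `x_next = x0 * dense(x) + x`, so each step multiplies the current features by a
projection of themselves, raising the polynomial degree of the feature interactions by one.
The model function below inlines the same computation directly.
"""


class CrossLayer(layers.Layer):
    """Illustrative cross layer: x_next = x0 * dense(x) + x (a residual feature cross)."""

    def __init__(self, units, **kwargs):
        super().__init__(**kwargs)
        # `units` must equal the feature dimension of x0 so the elementwise product works.
        self.dense = layers.Dense(units)

    def call(self, inputs):
        x0, x = inputs
        return x0 * self.dense(x) + x
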

def create_deep_and_cross_model():
    inputs = create_model_inputs()
    x0 = encode_inputs(inputs, use_embedding=True)

    cross = x0
    for _ in hidden_units:
        units = cross.shape[-1]
        x = layers.Dense(units)(cross)
        cross = x0 * x + cross
    cross = layers.BatchNormalization()(cross)

    deep = x0
    for units in hidden_units:
        deep = layers.Dense(units)(deep)
        deep = layers.BatchNormalization()(deep)
        deep = layers.ReLU()(deep)
        deep = layers.Dropout(dropout_rate)(deep)

    merged = layers.concatenate([cross, deep])
    outputs = layers.Dense(units=NUM_CLASSES, activation="softmax")(merged)
    model = keras.Model(inputs=inputs, outputs=outputs)
    return model


deep_and_cross_model = create_deep_and_cross_model()
keras.utils.plot_model(deep_and_cross_model, show_shapes=True, rankdir="LR")

"""
Let's run it:
"""

run_experiment(deep_and_cross_model)

"""
The Deep & Cross model achieves ~81% test accuracy.
"""

"""
## Conclusion

You can use Keras Preprocessing Layers to easily handle categorical features
with different encoding mechanisms, including one-hot encoding and feature embedding.
In addition, different model architectures, such as wide, deep, and cross networks,
have different advantages with respect to different dataset properties.
You can explore using them independently or combining them to achieve the best result
for your dataset.
"""
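
"""
As a final, purely illustrative step (not part of the original example), any of the trained
models can be used to produce class predictions for a batch of test examples:
"""

# A small usage sketch: pull one batch from the test CSV and compare predicted and
# actual cover types. It assumes `deep_and_cross_model` has just been trained above.
sample_features, sample_labels = next(
    iter(get_dataset_from_csv(test_data_file, batch_size=5))
)
predicted_probabilities = deep_and_cross_model.predict(sample_features)
predicted_classes = np.argmax(predicted_probabilities, axis=1)
print(f"Predicted cover types: {predicted_classes}")
print(f"Actual cover types:    {sample_labels.numpy()}")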