Path: blob/master/examples/structured_data/structured_data_classification_with_feature_space.py
"""1Title: Structured data classification with FeatureSpace2Author: [fchollet](https://twitter.com/fchollet)3Date created: 2022/11/094Last modified: 2022/11/095Description: Classify tabular data in a few lines of code.6Accelerator: GPU7"""89"""10## Introduction1112This example demonstrates how to do structured data classification13(also known as tabular data classification), starting from a raw14CSV file. Our data includes numerical features,15and integer categorical features, and string categorical features.16We will use the utility `keras.utils.FeatureSpace` to index,17preprocess, and encode our features.1819The code is adapted from the example20[Structured data classification from scratch](https://keras.io/examples/structured_data/structured_data_classification_from_scratch/).21While the previous example managed its own low-level feature preprocessing and22encoding with Keras preprocessing layers, in this example we23delegate everything to `FeatureSpace`, making the workflow24extremely quick and easy.2526### The dataset2728[Our dataset](https://archive.ics.uci.edu/ml/datasets/heart+Disease) is provided by the29Cleveland Clinic Foundation for Heart Disease.30It's a CSV file with 303 rows. Each row contains information about a patient (a31**sample**), and each column describes an attribute of the patient (a **feature**). We32use the features to predict whether a patient has a heart disease33(**binary classification**).3435Here's the description of each feature:3637Column| Description| Feature Type38------------|--------------------|----------------------39Age | Age in years | Numerical40Sex | (1 = male; 0 = female) | Categorical41CP | Chest pain type (0, 1, 2, 3, 4) | Categorical42Trestbpd | Resting blood pressure (in mm Hg on admission) | Numerical43Chol | Serum cholesterol in mg/dl | Numerical44FBS | fasting blood sugar in 120 mg/dl (1 = true; 0 = false) | Categorical45RestECG | Resting electrocardiogram results (0, 1, 2) | Categorical46Thalach | Maximum heart rate achieved | Numerical47Exang | Exercise induced angina (1 = yes; 0 = no) | Categorical48Oldpeak | ST depression induced by exercise relative to rest | Numerical49Slope | Slope of the peak exercise ST segment | Numerical50CA | Number of major vessels (0-3) colored by fluoroscopy | Both numerical & categorical51Thal | 3 = normal; 6 = fixed defect; 7 = reversible defect | Categorical52Target | Diagnosis of heart disease (1 = true; 0 = false) | Target53"""5455"""56## Setup57"""5859import os6061os.environ["KERAS_BACKEND"] = "tensorflow"6263import tensorflow as tf64import pandas as pd65import keras66from keras.utils import FeatureSpace6768"""69## Preparing the data7071Let's download the data and load it into a Pandas dataframe:72"""7374file_url = "http://storage.googleapis.com/download.tensorflow.org/data/heart.csv"75dataframe = pd.read_csv(file_url)7677"""78The dataset includes 303 samples with 14 columns per sample79(13 features, plus the target label):80"""8182print(dataframe.shape)8384"""85Here's a preview of a few samples:86"""8788dataframe.head()8990"""91The last column, "target", indicates whether the patient92has a heart disease (1) or not (0).9394Let's split the data into a training and validation set:95"""9697val_dataframe = dataframe.sample(frac=0.2, random_state=1337)98train_dataframe = dataframe.drop(val_dataframe.index)99100print(101"Using %d samples for training and %d for validation"102% (len(train_dataframe), len(val_dataframe))103)104105"""106Let's generate `tf.data.Dataset` objects for each 
dataframe:107"""108109110def dataframe_to_dataset(dataframe):111dataframe = dataframe.copy()112labels = dataframe.pop("target")113ds = tf.data.Dataset.from_tensor_slices((dict(dataframe), labels))114ds = ds.shuffle(buffer_size=len(dataframe))115return ds116117118train_ds = dataframe_to_dataset(train_dataframe)119val_ds = dataframe_to_dataset(val_dataframe)120121"""122Each `Dataset` yields a tuple `(input, target)` where `input` is a dictionary of features123and `target` is the value `0` or `1`:124"""125126for x, y in train_ds.take(1):127print("Input:", x)128print("Target:", y)129130"""131Let's batch the datasets:132"""133134train_ds = train_ds.batch(32)135val_ds = val_ds.batch(32)136137"""138## Configuring a `FeatureSpace`139140To configure how each feature should be preprocessed,141we instantiate a `keras.utils.FeatureSpace`, and we142pass to it a dictionary that maps the name of our features143to a string that describes the feature type.144145We have a few "integer categorical" features such as `"FBS"`,146one "string categorical" feature (`"thal"`),147and a few numerical features, which we'd like to normalize148-- except `"age"`, which we'd like to discretize into149a number of bins.150151We also use the `crosses` argument152to capture *feature interactions* for some categorical153features, that is to say, create additional features154that represent value co-occurrences for these categorical features.155You can compute feature crosses like this for arbitrary sets of156categorical features -- not just tuples of two features.157Because the resulting co-occurences are hashed158into a fixed-sized vector, you don't need to worry about whether159the co-occurence space is too large.160"""161162feature_space = FeatureSpace(163features={164# Categorical features encoded as integers165"sex": "integer_categorical",166"cp": "integer_categorical",167"fbs": "integer_categorical",168"restecg": "integer_categorical",169"exang": "integer_categorical",170"ca": "integer_categorical",171# Categorical feature encoded as string172"thal": "string_categorical",173# Numerical features to discretize174"age": "float_discretized",175# Numerical features to normalize176"trestbps": "float_normalized",177"chol": "float_normalized",178"thalach": "float_normalized",179"oldpeak": "float_normalized",180"slope": "float_normalized",181},182# We create additional features by hashing183# value co-occurrences for the184# following groups of categorical features.185crosses=[("sex", "age"), ("thal", "ca")],186# The hashing space for these co-occurrences187# wil be 32-dimensional.188crossing_dim=32,189# Our utility will one-hot encode all categorical190# features and concat all features into a single191# vector (one vector per sample).192output_mode="concat",193)194195"""196## Further customizing a `FeatureSpace`197198Specifying the feature type via a string name is quick and easy,199but sometimes you may want to further configure the preprocessing200of each feature. For instance, in our case, our categorical201features don't have a large set of possible values -- it's only202a handful of values per feature (e.g. 
"""
## Further customizing a `FeatureSpace`

Specifying the feature type via a string name is quick and easy,
but sometimes you may want to further configure the preprocessing
of each feature. For instance, in our case, our categorical
features don't have a large set of possible values -- it's only
a handful of values per feature (e.g. `1` and `0` for the feature `"fbs"`),
and all possible values are represented in the training set.
As a result, we don't need to reserve an index to represent "out of vocabulary" values
for these features -- which would have been the default behavior.
Below, we just specify `num_oov_indices=0` in each of these features
to tell the feature preprocessor to skip "out of vocabulary" indexing.

Other customizations you have access to include specifying the number of
bins for discretizing features of type `"float_discretized"`,
or the dimensionality of the hashing space for feature crossing.
"""

feature_space = FeatureSpace(
    features={
        # Categorical features encoded as integers
        "sex": FeatureSpace.integer_categorical(num_oov_indices=0),
        "cp": FeatureSpace.integer_categorical(num_oov_indices=0),
        "fbs": FeatureSpace.integer_categorical(num_oov_indices=0),
        "restecg": FeatureSpace.integer_categorical(num_oov_indices=0),
        "exang": FeatureSpace.integer_categorical(num_oov_indices=0),
        "ca": FeatureSpace.integer_categorical(num_oov_indices=0),
        # Categorical feature encoded as string
        "thal": FeatureSpace.string_categorical(num_oov_indices=0),
        # Numerical feature to discretize
        "age": FeatureSpace.float_discretized(num_bins=30),
        # Numerical features to normalize
        "trestbps": FeatureSpace.float_normalized(),
        "chol": FeatureSpace.float_normalized(),
        "thalach": FeatureSpace.float_normalized(),
        "oldpeak": FeatureSpace.float_normalized(),
        "slope": FeatureSpace.float_normalized(),
    },
    # Specify feature crosses with custom crossing dims.
    crosses=[
        FeatureSpace.cross(feature_names=("sex", "age"), crossing_dim=64),
        FeatureSpace.cross(
            feature_names=("thal", "ca"),
            crossing_dim=16,
        ),
    ],
    output_mode="concat",
)

"""
## Adapt the `FeatureSpace` to the training data

Before we start using the `FeatureSpace` to build a model, we have
to adapt it to the training data. During `adapt()`, the `FeatureSpace` will:

- Index the set of possible values for categorical features.
- Compute the mean and variance for numerical features to normalize.
- Compute the value boundaries for the different bins for numerical features to discretize.

Note that `adapt()` should be called on a `tf.data.Dataset` which yields dicts
of feature values -- no labels.
"""

train_ds_with_no_labels = train_ds.map(lambda x, _: x)
feature_space.adapt(train_ds_with_no_labels)

"""
At this point, the `FeatureSpace` can be called on a dict of raw feature values, and will return a
single concatenated vector for each sample, combining encoded features and feature crosses.
"""

for x, _ in train_ds.take(1):
    preprocessed_x = feature_space(x)
    print("preprocessed_x.shape:", preprocessed_x.shape)
    print("preprocessed_x.dtype:", preprocessed_x.dtype)
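"""
Though not covered in the original walkthrough, an adapted `FeatureSpace` can be
saved to disk and reloaded later, which is convenient for reusing the exact same
preprocessing at inference time. This is a minimal sketch, assuming a Keras version
in which `FeatureSpace` supports the `.keras` saving format (recent releases do);
the filename is arbitrary:
"""

# Persist the adapted FeatureSpace, then reload it (e.g. in a serving process).
feature_space.save("myfeaturespace.keras")
loaded_feature_space = keras.models.load_model("myfeaturespace.keras")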
"""
## Two ways to manage preprocessing: as part of the `tf.data` pipeline, or in the model itself

There are two ways in which you can leverage your `FeatureSpace`:

### Asynchronous preprocessing in `tf.data`

You can make it part of your data pipeline, before the model. This enables asynchronous parallel
preprocessing of the data on CPU before it hits the model. Do this if you're training on GPU or TPU,
or if you want to speed up preprocessing. This is almost always the right thing to do during training.

### Synchronous preprocessing in the model

You can make it part of your model. This means that the model will expect dicts of raw feature
values, and the preprocessing will be done synchronously (in a blocking manner) before the
rest of the forward pass. Do this if you want to have an end-to-end model that can process
raw feature values -- but keep in mind that your model will only be able to run on CPU,
since most types of feature preprocessing (e.g. string preprocessing) are not GPU or TPU compatible.

Do not do this on GPU / TPU or in performance-sensitive settings. In general, in-model
preprocessing is what you want when you do inference on CPU.

In our case, we will apply the `FeatureSpace` in the tf.data pipeline during training, but we will
do inference with an end-to-end model that includes the `FeatureSpace`.
"""

"""
Let's create a training and validation dataset of preprocessed batches:
"""

preprocessed_train_ds = train_ds.map(
    lambda x, y: (feature_space(x), y), num_parallel_calls=tf.data.AUTOTUNE
)
preprocessed_train_ds = preprocessed_train_ds.prefetch(tf.data.AUTOTUNE)

preprocessed_val_ds = val_ds.map(
    lambda x, y: (feature_space(x), y), num_parallel_calls=tf.data.AUTOTUNE
)
preprocessed_val_ds = preprocessed_val_ds.prefetch(tf.data.AUTOTUNE)

"""
## Build a model

Time to build a model -- or rather two models:

- A training model that expects preprocessed features (one sample = one vector)
- An inference model that expects raw features (one sample = dict of raw feature values)
"""

dict_inputs = feature_space.get_inputs()
encoded_features = feature_space.get_encoded_features()

x = keras.layers.Dense(32, activation="relu")(encoded_features)
x = keras.layers.Dropout(0.5)(x)
predictions = keras.layers.Dense(1, activation="sigmoid")(x)

training_model = keras.Model(inputs=encoded_features, outputs=predictions)
training_model.compile(
    optimizer="adam", loss="binary_crossentropy", metrics=["accuracy"]
)

inference_model = keras.Model(inputs=dict_inputs, outputs=predictions)

"""
## Train the model

Let's train our model for 20 epochs. Note that feature preprocessing is happening
as part of the tf.data pipeline, not as part of the model.
"""

training_model.fit(
    preprocessed_train_ds,
    epochs=20,
    validation_data=preprocessed_val_ds,
    verbose=2,
)

"""
We quickly get to 80% validation accuracy.
"""

"""
## Inference on new data with the end-to-end model

Now, we can use our inference model (which includes the `FeatureSpace`)
to make predictions based on dicts of raw feature values, as follows:
"""

sample = {
    "age": 60,
    "sex": 1,
    "cp": 1,
    "trestbps": 145,
    "chol": 233,
    "fbs": 1,
    "restecg": 2,
    "thalach": 150,
    "exang": 0,
    "oldpeak": 2.3,
    "slope": 3,
    "ca": 0,
    "thal": "fixed",
}

input_dict = {name: tf.convert_to_tensor([value]) for name, value in sample.items()}
predictions = inference_model.predict(input_dict)

print(
    f"This particular patient had a {100 * predictions[0][0]:.2f}% probability "
    "of having heart disease, as evaluated by our model."
)
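"""
For completeness, here is one way you might turn the predicted probability into a
hard class decision. The conventional 0.5 threshold below is an assumption, not
something calibrated in the original example:
"""

# Probability above 0.5 -> predict "has heart disease" (class 1), else class 0.
predicted_class = int(predictions[0][0] > 0.5)
print("Predicted class:", predicted_class)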