Path: blob/master/examples/structured_data/structured_data_classification_from_scratch.py
3507 views
"""1Title: Structured data classification from scratch2Author: [fchollet](https://twitter.com/fchollet)3Date created: 2020/06/094Last modified: 2020/06/095Description: Binary classification of structured data including numerical and categorical features.6Accelerator: GPU7Made backend-agnostic by: [Humbulani Ndou](https://github.com/Humbulani1234)8"""910"""11## Introduction1213This example demonstrates how to do structured data classification, starting from a raw14CSV file. Our data includes both numerical and categorical features. We will use Keras15preprocessing layers to normalize the numerical features and vectorize the categorical16ones.1718Note that this example should be run with TensorFlow 2.5 or higher.1920### The dataset2122[Our dataset](https://archive.ics.uci.edu/ml/datasets/heart+Disease) is provided by the23Cleveland Clinic Foundation for Heart Disease.24It's a CSV file with 303 rows. Each row contains information about a patient (a25**sample**), and each column describes an attribute of the patient (a **feature**). We26use the features to predict whether a patient has a heart disease (**binary27classification**).2829Here's the description of each feature:3031Column| Description| Feature Type32------------|--------------------|----------------------33Age | Age in years | Numerical34Sex | (1 = male; 0 = female) | Categorical35CP | Chest pain type (0, 1, 2, 3, 4) | Categorical36Trestbpd | Resting blood pressure (in mm Hg on admission) | Numerical37Chol | Serum cholesterol in mg/dl | Numerical38FBS | fasting blood sugar in 120 mg/dl (1 = true; 0 = false) | Categorical39RestECG | Resting electrocardiogram results (0, 1, 2) | Categorical40Thalach | Maximum heart rate achieved | Numerical41Exang | Exercise induced angina (1 = yes; 0 = no) | Categorical42Oldpeak | ST depression induced by exercise relative to rest | Numerical43Slope | Slope of the peak exercise ST segment | Numerical44CA | Number of major vessels (0-3) colored by fluoroscopy | Both numerical & categorical45Thal | 3 = normal; 6 = fixed defect; 7 = reversible defect | Categorical46Target | Diagnosis of heart disease (1 = true; 0 = false) | Target47"""4849"""50## Setup51"""5253import os5455os.environ["KERAS_BACKEND"] = "torch" # or torch, or tensorflow5657import pandas as pd58import keras59from keras import layers6061"""62## Preparing the data6364Let's download the data and load it into a Pandas dataframe:65"""6667file_url = "http://storage.googleapis.com/download.tensorflow.org/data/heart.csv"68dataframe = pd.read_csv(file_url)6970"""71The dataset includes 303 samples with 14 columns per sample (13 features, plus the target72label):73"""7475dataframe.shape7677"""78Here's a preview of a few samples:79"""8081dataframe.head()8283"""84The last column, "target", indicates whether the patient has a heart disease (1) or not85(0).8687Let's split the data into a training and validation set:88"""8990val_dataframe = dataframe.sample(frac=0.2, random_state=1337)91train_dataframe = dataframe.drop(val_dataframe.index)9293print(94f"Using {len(train_dataframe)} samples for training "95f"and {len(val_dataframe)} for validation"96)979899"""100## Define dataset metadata101102Here, we define the metadata of the dataset that will be useful for reading and103parsing the data into input features, and encoding the input features with respect104to their types.105"""106107COLUMN_NAMES = [108"age",109"sex",110"cp",111"trestbps",112"chol",113"fbs",114"restecg",115"thalach",116"exang",117"oldpeak",118"slope",119"ca",120"thal",121"target",122]123# Target feature name.124TARGET_FEATURE_NAME = "target"125# Numeric feature names.126NUMERIC_FEATURE_NAMES = ["age", "trestbps", "thalach", "oldpeak", "slope", "chol"]127# Categorical features and their vocabulary lists.128# Note that we add 'v=' as a prefix to all categorical feature values to make129# sure that they are treated as strings.130131CATEGORICAL_FEATURES_WITH_VOCABULARY = {132feature_name: sorted(133[134# Integer categorcal must be int and string must be str135value if dataframe[feature_name].dtype == "int64" else str(value)136for value in list(dataframe[feature_name].unique())137]138)139for feature_name in COLUMN_NAMES140if feature_name not in list(NUMERIC_FEATURE_NAMES + [TARGET_FEATURE_NAME])141}142# All features names.143FEATURE_NAMES = NUMERIC_FEATURE_NAMES + list(144CATEGORICAL_FEATURES_WITH_VOCABULARY.keys()145)146147148"""149## Feature preprocessing with Keras layers150151152The following features are categorical features encoded as integers:153154- `sex`155- `cp`156- `fbs`157- `restecg`158- `exang`159- `ca`160161We will encode these features using **one-hot encoding**. We have two options162here:163164- Use `CategoryEncoding()`, which requires knowing the range of input values165and will error on input outside the range.166- Use `IntegerLookup()` which will build a lookup table for inputs and reserve167an output index for unkown input values.168169For this example, we want a simple solution that will handle out of range inputs170at inference, so we will use `IntegerLookup()`.171172We also have a categorical feature encoded as a string: `thal`. We will create an173index of all possible features and encode output using the `StringLookup()` layer.174175Finally, the following feature are continuous numerical features:176177- `age`178- `trestbps`179- `chol`180- `thalach`181- `oldpeak`182- `slope`183184For each of these features, we will use a `Normalization()` layer to make sure the mean185of each feature is 0 and its standard deviation is 1.186187Below, we define 2 utility functions to do the operations:188189- `encode_numerical_feature` to apply featurewise normalization to numerical features.190- `process` to one-hot encode string or integer categorical features.191"""192193# Tensorflow required for tf.data.Dataset194import tensorflow as tf195196197# We process our datasets elements here (categorical) and convert them to indices to avoid this step198# during model training since only tensorflow support strings.199def encode_categorical(features, target):200for feature_name in features:201if feature_name in CATEGORICAL_FEATURES_WITH_VOCABULARY:202lookup_class = (203layers.StringLookup204if features[feature_name].dtype == "string"205else layers.IntegerLookup206)207vocabulary = CATEGORICAL_FEATURES_WITH_VOCABULARY[feature_name]208# Create a lookup to convert a string values to an integer indices.209# Since we are not using a mask token nor expecting any out of vocabulary210# (oov) token, we set mask_token to None and num_oov_indices to 0.211index = lookup_class(212vocabulary=vocabulary,213mask_token=None,214num_oov_indices=0,215output_mode="binary",216)217# Convert the string input values into integer indices.218value_index = index(features[feature_name])219features[feature_name] = value_index220221else:222pass223224# Change features from OrderedDict to Dict to match Inputs as they are Dict.225return dict(features), target226227228def encode_numerical_feature(feature, name, dataset):229# Create a Normalization layer for our feature230normalizer = layers.Normalization()231# Prepare a Dataset that only yields our feature232feature_ds = dataset.map(lambda x, y: x[name])233feature_ds = feature_ds.map(lambda x: tf.expand_dims(x, -1))234# Learn the statistics of the data235normalizer.adapt(feature_ds)236# Normalize the input feature237encoded_feature = normalizer(feature)238return encoded_feature239240241"""242Let's generate `tf.data.Dataset` objects for each dataframe:243"""244245246def dataframe_to_dataset(dataframe):247dataframe = dataframe.copy()248labels = dataframe.pop("target")249ds = tf.data.Dataset.from_tensor_slices((dict(dataframe), labels)).map(250encode_categorical251)252ds = ds.shuffle(buffer_size=len(dataframe))253return ds254255256train_ds = dataframe_to_dataset(train_dataframe)257val_ds = dataframe_to_dataset(val_dataframe)258259"""260Each `Dataset` yields a tuple `(input, target)` where `input` is a dictionary of features261and `target` is the value `0` or `1`:262"""263264for x, y in train_ds.take(1):265print("Input:", x)266print("Target:", y)267268"""269Let's batch the datasets:270"""271272train_ds = train_ds.batch(32)273val_ds = val_ds.batch(32)274275276"""277## Build a model278279With this done, we can create our end-to-end model:280"""281282283# Categorical features have different shapes after the encoding, dependent on the284# vocabulary or unique values of each feature. We create them accordinly to match the285# input data elements generated by tf.data.Dataset after pre-processing them286def create_model_inputs():287inputs = {}288289# This a helper function for creating categorical features290def create_input_helper(feature_name):291num_categories = len(CATEGORICAL_FEATURES_WITH_VOCABULARY[feature_name])292inputs[feature_name] = layers.Input(293name=feature_name, shape=(num_categories,), dtype="int64"294)295return inputs296297for feature_name in FEATURE_NAMES:298if feature_name in CATEGORICAL_FEATURES_WITH_VOCABULARY:299# Categorical features300create_input_helper(feature_name)301else:302# Make them float32, they are Real numbers303feature_input = layers.Input(name=feature_name, shape=(1,), dtype="float32")304# Process the Inputs here305inputs[feature_name] = encode_numerical_feature(306feature_input, feature_name, train_ds307)308return inputs309310311# This Layer defines the logic of the Model to perform the classification312class Classifier(keras.layers.Layer):313314def __init__(self, **kwargs):315super().__init__(**kwargs)316self.dense_1 = layers.Dense(32, activation="relu")317self.dropout = layers.Dropout(0.5)318self.dense_2 = layers.Dense(1, activation="sigmoid")319320def call(self, inputs):321all_features = layers.concatenate(list(inputs.values()))322x = self.dense_1(all_features)323x = self.dropout(x)324output = self.dense_2(x)325return output326327# Surpress build warnings328def build(self, input_shape):329self.built = True330331332# Create the Classifier model333def create_model():334all_inputs = create_model_inputs()335output = Classifier()(all_inputs)336model = keras.Model(all_inputs, output)337return model338339340model = create_model()341model.compile("adam", "binary_crossentropy", metrics=["accuracy"])342343"""344Let's visualize our connectivity graph:345"""346347# `rankdir='LR'` is to make the graph horizontal.348keras.utils.plot_model(model, show_shapes=True, rankdir="LR")349350"""351## Train the model352"""353354model.fit(train_ds, epochs=50, validation_data=val_ds)355356357"""358We quickly get to 80% validation accuracy.359"""360361"""362## Inference on new data363364To get a prediction for a new sample, you can simply call `model.predict()`. There are365just two things you need to do:3663671. wrap scalars into a list so as to have a batch dimension (models only process batches368of data, not single samples)3692. Call `convert_to_tensor` on each feature370"""371372sample = {373"age": 60,374"sex": 1,375"cp": 1,376"trestbps": 145,377"chol": 233,378"fbs": 1,379"restecg": 2,380"thalach": 150,381"exang": 0,382"oldpeak": 2.3,383"slope": 3,384"ca": 0,385"thal": "fixed",386}387388389# Given the category (in the sample above - key) and the category value (in the sample above - value),390# we return its one-hot encoding391def get_cat_encoding(cat, cat_value):392# Create a list of zeros with the same length as categories393encoding = [0] * len(cat)394# Find the index of category_value in categories and set the corresponding position to 1395if cat_value in cat:396encoding[cat.index(cat_value)] = 1397return encoding398399400for name, value in sample.items():401if name in CATEGORICAL_FEATURES_WITH_VOCABULARY:402sample.update(403{404name: get_cat_encoding(405CATEGORICAL_FEATURES_WITH_VOCABULARY[name], sample[name]406)407}408)409# Convert inputs to tensors410input_dict = {name: tf.convert_to_tensor([value]) for name, value in sample.items()}411predictions = model.predict(input_dict)412413print(414f"This particular patient had a {100 * predictions[0][0]:.1f} "415"percent probability of having a heart disease, "416"as evaluated by our model."417)418419"""420## Conclusions421422- The orignal model (the one that runs only on tensorflow) converges quickly to around 80% and remains423there for extended periods and at times hits 85%424- The updated model (the backed-agnostic) model may fluctuate between 78% and 83% and at times hitting 86%425validation accuracy and converges around 80% also.426427"""428429430