Structured data classification from scratch
Author: fchollet
Date created: 2020/06/09
Last modified: 2020/06/09
Description: Binary classification of structured data including numerical and categorical features.
Introduction
This example demonstrates how to do structured data classification, starting from a raw CSV file. Our data includes both numerical and categorical features. We will use Keras preprocessing layers to normalize the numerical features and vectorize the categorical ones.
Note that this example should be run with TensorFlow 2.5 or higher.
The dataset
Our dataset is provided by the Cleveland Clinic Foundation for Heart Disease. It's a CSV file with 303 rows. Each row contains information about a patient (a sample), and each column describes an attribute of the patient (a feature). We use the features to predict whether a patient has heart disease (binary classification).
Here's the description of each feature:
Column | Description | Feature Type |
---|---|---|
Age | Age in years | Numerical |
Sex | (1 = male; 0 = female) | Categorical |
CP | Chest pain type (0, 1, 2, 3, 4) | Categorical |
Trestbpd | Resting blood pressure (in mm Hg on admission) | Numerical |
Chol | Serum cholesterol in mg/dl | Numerical |
FBS | Fasting blood sugar > 120 mg/dl (1 = true; 0 = false) | Categorical |
RestECG | Resting electrocardiogram results (0, 1, 2) | Categorical |
Thalach | Maximum heart rate achieved | Numerical |
Exang | Exercise induced angina (1 = yes; 0 = no) | Categorical |
Oldpeak | ST depression induced by exercise relative to rest | Numerical |
Slope | Slope of the peak exercise ST segment | Numerical |
CA | Number of major vessels (0-3) colored by fluoroscopy | Both numerical & categorical |
Thal | 3 = normal; 6 = fixed defect; 7 = reversible defect | Categorical |
Target | Diagnosis of heart disease (1 = true; 0 = false) | Target |
Setup
Preparing the data
Let's download the data and load it into a Pandas dataframe:
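A sketch of this step: pandas can read the CSV straight from its URL (the Google-hosted address below is the one used by the TensorFlow tutorials). Here a two-row inline excerpt with the same columns stands in for the download so the snippet runs on its own:

```python
import io

import pandas as pd

# In the published example the file is fetched directly by URL:
#   dataframe = pd.read_csv(
#       "http://storage.googleapis.com/download.tensorflow.org/data/heart.csv")
# For a self-contained sketch, parse an illustrative two-row excerpt instead.
csv_excerpt = """age,sex,cp,trestbps,chol,fbs,restecg,thalach,exang,oldpeak,slope,ca,thal,target
63,1,1,145,233,1,2,150,0,2.3,3,0,fixed,0
67,1,4,160,286,0,2,108,1,1.5,2,3,normal,1
"""
dataframe = pd.read_csv(io.StringIO(csv_excerpt))
print(dataframe.shape)  # (2, 14)
```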
The dataset includes 303 samples with 14 columns per sample (13 features, plus the target label):
Here's a preview of a few samples:
The last column, "target", indicates whether the patient has heart disease (1) or not (0).
Let's split the data into a training and validation set:
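One way to do the split: `sample(frac=0.2)` holds out 20% of the rows for validation and `drop` keeps the rest, giving a disjoint 80/20 partition. A tiny stand-in frame is used here so the sketch is self-contained:

```python
import pandas as pd

# Stand-in for the full 303-row dataframe loaded above.
dataframe = pd.DataFrame({"age": range(10), "target": [0, 1] * 5})

# Hold out 20% of the rows for validation; a fixed seed makes it reproducible.
val_dataframe = dataframe.sample(frac=0.2, random_state=1337)
train_dataframe = dataframe.drop(val_dataframe.index)
print(len(train_dataframe), len(val_dataframe))  # 8 2
```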
Define dataset metadata
Here, we define the metadata of the dataset that will be useful for reading and parsing the data into input features, and encoding the input features with respect to their types.
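For instance, the metadata can be captured as plain lists of column names grouped by how each feature will be encoded (the constant names below are illustrative, not from the original notebook; the column names come from the feature table above):

```python
# Column names grouped by encoding type. The constant names are hypothetical.
NUMERIC_FEATURE_NAMES = ["age", "trestbps", "chol", "thalach", "oldpeak", "slope"]
INTEGER_CATEGORICAL_FEATURE_NAMES = ["sex", "cp", "fbs", "restecg", "exang", "ca"]
STRING_CATEGORICAL_FEATURE_NAMES = ["thal"]
TARGET_FEATURE_NAME = "target"

ALL_FEATURE_NAMES = (
    NUMERIC_FEATURE_NAMES
    + INTEGER_CATEGORICAL_FEATURE_NAMES
    + STRING_CATEGORICAL_FEATURE_NAMES
)
print(len(ALL_FEATURE_NAMES))  # 13
```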
Feature preprocessing with Keras layers
The following features are categorical features encoded as integers:
sex
cp
fbs
restecg
exang
ca
We will encode these features using one-hot encoding. We have two options here:

- Use CategoryEncoding(), which requires knowing the range of input values and will error on input outside the range.
- Use IntegerLookup(), which will build a lookup table for inputs and reserve an output index for unknown input values.

For this example, we want a simple solution that will handle out-of-range inputs at inference, so we will use IntegerLookup().

We also have a categorical feature encoded as a string: thal. We will create an index of all possible features and encode output using the StringLookup() layer.
Finally, the following features are continuous numerical features:
age
trestbps
chol
thalach
oldpeak
slope
For each of these features, we will use a Normalization() layer to make sure the mean of each feature is 0 and its standard deviation is 1.
Below, we define 2 utility functions to do the operations:

- encode_numerical_feature to apply featurewise normalization to numerical features.
- process to one-hot encode string or integer categorical features.
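A sketch of the two helpers, assuming TensorFlow 2.6+ layer locations (output_mode="multi_hot" is the current name for what TF 2.5 called "binary"); the helper names follow the list above:

```python
import tensorflow as tf
from tensorflow.keras.layers import IntegerLookup, Normalization, StringLookup


def encode_numerical_feature(feature, name, dataset):
    # Learn the feature's mean and variance from the training data,
    # then normalize it to mean 0 and standard deviation 1.
    normalizer = Normalization()
    feature_ds = dataset.map(lambda x, y: x[name])
    feature_ds = feature_ds.map(lambda x: tf.expand_dims(x, -1))
    normalizer.adapt(feature_ds)
    return normalizer(feature)


def process(feature, name, dataset, is_string):
    # Build a vocabulary from the training data and one-hot encode it;
    # values unseen during adapt() map to a reserved out-of-vocabulary slot.
    lookup_class = StringLookup if is_string else IntegerLookup
    lookup = lookup_class(output_mode="multi_hot")
    feature_ds = dataset.map(lambda x, y: x[name])
    feature_ds = feature_ds.map(lambda x: tf.expand_dims(x, -1))
    lookup.adapt(feature_ds)
    return lookup(feature)
```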
Let's generate tf.data.Dataset objects for each dataframe:

Each Dataset yields a tuple (input, target) where input is a dictionary of features and target is the value 0 or 1:
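A minimal sketch of the conversion: pop the label column, build a dataset of (features-dict, label) pairs, and shuffle. A tiny stand-in frame makes the snippet self-contained:

```python
import pandas as pd
import tensorflow as tf


def dataframe_to_dataset(dataframe):
    # Separate the label column, then build a dataset of
    # ({feature_name: value}, target) pairs and shuffle it.
    dataframe = dataframe.copy()
    labels = dataframe.pop("target")
    ds = tf.data.Dataset.from_tensor_slices((dict(dataframe), labels))
    ds = ds.shuffle(buffer_size=len(dataframe))
    return ds


# Tiny stand-in frame so the sketch runs on its own.
df = pd.DataFrame({"age": [63, 67], "sex": [1, 1], "target": [0, 1]})
ds = dataframe_to_dataset(df)
for inputs, target in ds.take(1):
    print(inputs, target)  # a dict of scalar features and a 0/1 label
```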
Let's batch the datasets:
Build a model
With this done, we can create our end-to-end model:
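A condensed sketch of the functional wiring, using just two features as stand-ins for all thirteen. The un-adapted Normalization layer is given explicit mean/variance here so the snippet runs on its own; in the example, each input goes through its adapt()ed preprocessing layer from the utility functions:

```python
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers

# One Input per feature; only "age" and "sex" are shown here.
age_input = keras.Input(shape=(1,), name="age")
sex_input = keras.Input(shape=(1,), name="sex", dtype="int64")

# Stand-ins for the adapted preprocessing layers.
age_encoded = layers.Normalization(mean=0.0, variance=1.0)(age_input)
sex_encoded = layers.IntegerLookup(vocabulary=[0, 1], output_mode="multi_hot")(sex_input)

# Concatenate all encoded features, then a small MLP with dropout.
all_features = layers.concatenate([age_encoded, sex_encoded])
x = layers.Dense(32, activation="relu")(all_features)
x = layers.Dropout(0.5)(x)
output = layers.Dense(1, activation="sigmoid")(x)

model = keras.Model(inputs=[age_input, sex_input], outputs=output)
model.compile(optimizer="adam", loss="binary_crossentropy", metrics=["accuracy"])
```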
Let's visualize our connectivity graph:
Train the model
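Training is a single fit() call over the batched datasets. A self-contained stand-in is used below (a tiny synthetic model and dataset, since the real `model`, `train_ds`, and `val_ds` come from the steps above); the published example trains for 50 epochs:

```python
import numpy as np
import tensorflow as tf
from tensorflow import keras

# Stand-ins for the real model and datasets built earlier.
model = keras.Sequential([
    keras.Input(shape=(2,)),
    keras.layers.Dense(1, activation="sigmoid"),
])
model.compile(optimizer="adam", loss="binary_crossentropy", metrics=["accuracy"])

rng = np.random.default_rng(0)
x = rng.random((64, 2)).astype("float32")
y = (x.sum(axis=1) > 1.0).astype("float32")
train_ds = tf.data.Dataset.from_tensor_slices((x, y)).batch(32)
val_ds = tf.data.Dataset.from_tensor_slices((x, y)).batch(32)

# Two epochs keep this sketch fast; the example uses 50.
history = model.fit(train_ds, epochs=2, validation_data=val_ds, verbose=0)
```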
We quickly get to 80% validation accuracy.
Inference on new data
To get a prediction for a new sample, you can simply call model.predict(). There are just two things you need to do:

- Wrap scalars into a list so as to have a batch dimension (models only process batches of data, not single samples).
- Call convert_to_tensor on each feature.
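A sketch of both steps with a hypothetical two-feature stand-in model (the real model takes all 13 features and includes the preprocessing layers):

```python
import tensorflow as tf
from tensorflow import keras

# Two-feature stand-in model so the snippet runs on its own.
age_input = keras.Input(shape=(1,), name="age")
sex_input = keras.Input(shape=(1,), name="sex")
x = keras.layers.concatenate([age_input, sex_input])
output = keras.layers.Dense(1, activation="sigmoid")(x)
model = keras.Model([age_input, sex_input], output)

# Step 1: wrap each scalar in a list to form a batch of one.
# Step 2: convert each feature to a tensor.
sample = {"age": 60.0, "sex": 1.0}
input_dict = {name: tf.convert_to_tensor([value]) for name, value in sample.items()}
predictions = model.predict(input_dict, verbose=0)
print(float(predictions[0][0]))  # predicted probability of heart disease
```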
Conclusions
The original model (the one that runs only on TensorFlow) converges quickly to around 80% validation accuracy, remains there for extended periods, and at times hits 85%.
The updated (backend-agnostic) model may fluctuate between 78% and 83% validation accuracy, at times hitting 86%, and also converges around 80%.