"""
2
Title: Structured data classification from scratch
3
Author: [fchollet](https://twitter.com/fchollet)
4
Date created: 2020/06/09
5
Last modified: 2020/06/09
6
Description: Binary classification of structured data including numerical and categorical features.
7
Accelerator: GPU
8
Made backend-agnostic by: [Humbulani Ndou](https://github.com/Humbulani1234)
9
"""

"""
## Introduction

This example demonstrates how to do structured data classification, starting from a raw
CSV file. Our data includes both numerical and categorical features. We will use Keras
preprocessing layers to normalize the numerical features and vectorize the categorical
ones.

Note that this example should be run with TensorFlow 2.5 or higher.

### The dataset

[Our dataset](https://archive.ics.uci.edu/ml/datasets/heart+Disease) is provided by the
Cleveland Clinic Foundation for Heart Disease.
It's a CSV file with 303 rows. Each row contains information about a patient (a
**sample**), and each column describes an attribute of the patient (a **feature**). We
use the features to predict whether a patient has heart disease (**binary
classification**).

Here's the description of each feature:

Column| Description| Feature Type
------------|--------------------|----------------------
Age | Age in years | Numerical
Sex | (1 = male; 0 = female) | Categorical
CP | Chest pain type (0, 1, 2, 3, 4) | Categorical
Trestbps | Resting blood pressure (in mm Hg on admission) | Numerical
Chol | Serum cholesterol in mg/dl | Numerical
FBS | Fasting blood sugar > 120 mg/dl (1 = true; 0 = false) | Categorical
RestECG | Resting electrocardiogram results (0, 1, 2) | Categorical
Thalach | Maximum heart rate achieved | Numerical
Exang | Exercise induced angina (1 = yes; 0 = no) | Categorical
Oldpeak | ST depression induced by exercise relative to rest | Numerical
Slope | Slope of the peak exercise ST segment | Numerical
CA | Number of major vessels (0-3) colored by fluoroscopy | Both numerical & categorical
Thal | 3 = normal; 6 = fixed defect; 7 = reversible defect | Categorical
Target | Diagnosis of heart disease (1 = true; 0 = false) | Target
"""

"""
## Setup
"""

import os

os.environ["KERAS_BACKEND"] = "torch"  # or "jax", or "tensorflow"

import pandas as pd
import keras
from keras import layers
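
"""
As a quick check (added here for convenience, not part of the original example), we can
confirm which backend Keras actually picked up:
"""

# With the environment variable set above, this should print "torch".
print("Keras backend:", keras.backend.backend())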

"""
## Preparing the data

Let's download the data and load it into a Pandas dataframe:
"""

file_url = "http://storage.googleapis.com/download.tensorflow.org/data/heart.csv"
dataframe = pd.read_csv(file_url)

"""
The dataset includes 303 samples with 14 columns per sample (13 features, plus the target
label):
"""

dataframe.shape

"""
Here's a preview of a few samples:
"""

dataframe.head()

"""
The last column, "target", indicates whether the patient has heart disease (1) or not
(0).

Let's split the data into a training and validation set:
"""

val_dataframe = dataframe.sample(frac=0.2, random_state=1337)
train_dataframe = dataframe.drop(val_dataframe.index)

print(
    f"Using {len(train_dataframe)} samples for training "
    f"and {len(val_dataframe)} for validation"
)
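
"""
Because the dataset is small, a class-balanced (stratified) split can reduce the variance
of the validation metrics. Here is a minimal sketch of such a variant (an optional
alternative, not used below), sampling 20% within each target class:
"""

# Hypothetical stratified variant: both splits keep the overall positive/negative ratio.
stratified_val = dataframe.groupby("target").sample(frac=0.2, random_state=1337)
stratified_train = dataframe.drop(stratified_val.index)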


"""
## Define dataset metadata

Here, we define the metadata of the dataset that will be useful for reading and
parsing the data into input features, and encoding the input features with respect
to their types.
"""

COLUMN_NAMES = [
    "age",
    "sex",
    "cp",
    "trestbps",
    "chol",
    "fbs",
    "restecg",
    "thalach",
    "exang",
    "oldpeak",
    "slope",
    "ca",
    "thal",
    "target",
]
# Target feature name.
TARGET_FEATURE_NAME = "target"
# Numeric feature names.
NUMERIC_FEATURE_NAMES = ["age", "trestbps", "thalach", "oldpeak", "slope", "chol"]
# Categorical features and their vocabulary lists.
# Integer-valued categorical features keep their integer values, while all other values
# are cast to str, so that each vocabulary has a single, consistent element type.

CATEGORICAL_FEATURES_WITH_VOCABULARY = {
    feature_name: sorted(
        [
            # Integer categorical values must be int and string values must be str.
            value if dataframe[feature_name].dtype == "int64" else str(value)
            for value in list(dataframe[feature_name].unique())
        ]
    )
    for feature_name in COLUMN_NAMES
    if feature_name not in NUMERIC_FEATURE_NAMES + [TARGET_FEATURE_NAME]
}
# All feature names.
FEATURE_NAMES = NUMERIC_FEATURE_NAMES + list(
    CATEGORICAL_FEATURES_WITH_VOCABULARY.keys()
)
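
"""
It can be helpful to inspect the vocabularies we just built (a small sanity check added
here, not part of the original example):
"""

# Each categorical feature maps to its sorted vocabulary; for example, "thal" should
# list strings such as "fixed", "normal", and "reversible".
for feature_name, vocabulary in CATEGORICAL_FEATURES_WITH_VOCABULARY.items():
    print(f"{feature_name}: {vocabulary}")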


"""
## Feature preprocessing with Keras layers


The following features are categorical features encoded as integers:

- `sex`
- `cp`
- `fbs`
- `restecg`
- `exang`
- `ca`

We will encode these features using **one-hot encoding**. We have two options
here:

- Use `CategoryEncoding()`, which requires knowing the range of input values
and will error on input outside the range.
- Use `IntegerLookup()`, which will build a lookup table for inputs and reserve
an output index for unknown input values.

For this example, we want a simple solution that will handle out of range inputs
at inference, so we will use `IntegerLookup()`.

We also have a categorical feature encoded as a string: `thal`. We will create an
index of all possible values and encode the output using the `StringLookup()` layer.

Finally, the following features are continuous numerical features:

- `age`
- `trestbps`
- `chol`
- `thalach`
- `oldpeak`
- `slope`

For each of these features, we will use a `Normalization()` layer to make sure the mean
of each feature is 0 and its standard deviation is 1.

Below, we define two utility functions to do the operations:

- `encode_numerical_feature` to apply featurewise normalization to numerical features.
- `encode_categorical` to one-hot encode string or integer categorical features.
"""

# TensorFlow is required for tf.data.Dataset.
import tensorflow as tf


# We pre-process the categorical elements of our datasets here, converting them to
# one-hot encodings up front rather than inside the model, since only the TensorFlow
# backend supports string tensors.
def encode_categorical(features, target):
    for feature_name in features:
        if feature_name in CATEGORICAL_FEATURES_WITH_VOCABULARY:
            lookup_class = (
                layers.StringLookup
                if features[feature_name].dtype == "string"
                else layers.IntegerLookup
            )
            vocabulary = CATEGORICAL_FEATURES_WITH_VOCABULARY[feature_name]
            # Create a lookup layer that one-hot encodes the values.
            # Since we are not using a mask token nor expecting any out-of-vocabulary
            # (oov) token, we set mask_token to None and num_oov_indices to 0.
            index = lookup_class(
                vocabulary=vocabulary,
                mask_token=None,
                num_oov_indices=0,
                output_mode="binary",
            )
            # Convert the raw input values into their one-hot encoding.
            features[feature_name] = index(features[feature_name])

    # Change features from an OrderedDict to a plain dict to match the model's Inputs.
    return dict(features), target


def encode_numerical_feature(feature, name, dataset):
    # Create a Normalization layer for our feature.
    normalizer = layers.Normalization()
    # Prepare a Dataset that only yields our feature.
    feature_ds = dataset.map(lambda x, y: x[name])
    feature_ds = feature_ds.map(lambda x: tf.expand_dims(x, -1))
    # Learn the statistics of the data.
    normalizer.adapt(feature_ds)
    # Normalize the input feature.
    encoded_feature = normalizer(feature)
    return encoded_feature
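
"""
To make the normalization step concrete, here is a tiny self-contained sketch with toy
numbers (added for illustration, not part of the original example). After adapting on the
values 1, 2, and 3, the value 2 maps to roughly 0 because it equals the learned mean:
"""

import numpy as np

toy_ds = tf.data.Dataset.from_tensor_slices([[1.0], [2.0], [3.0]])
toy_normalizer = layers.Normalization()
toy_normalizer.adapt(toy_ds)
print(toy_normalizer(np.array([[2.0]], dtype="float32")))  # approximately [[0.]]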


"""
Let's generate `tf.data.Dataset` objects for each dataframe:
"""


def dataframe_to_dataset(dataframe):
    dataframe = dataframe.copy()
    labels = dataframe.pop("target")
    ds = tf.data.Dataset.from_tensor_slices((dict(dataframe), labels)).map(
        encode_categorical
    )
    ds = ds.shuffle(buffer_size=len(dataframe))
    return ds


train_ds = dataframe_to_dataset(train_dataframe)
val_ds = dataframe_to_dataset(val_dataframe)

"""
Each `Dataset` yields a tuple `(input, target)` where `input` is a dictionary of features
and `target` is the value `0` or `1`:
"""

for x, y in train_ds.take(1):
    print("Input:", x)
    print("Target:", y)

"""
Let's batch the datasets:
"""

train_ds = train_ds.batch(32)
val_ds = val_ds.batch(32)
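
"""
To double-check the shapes flowing into the model, you can inspect the batched dataset's
signature (an optional check, not part of the original example). Each categorical feature
is now a multi-hot vector sized by its vocabulary, while numerical features remain
scalars:
"""

print(train_ds.element_spec)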


"""
## Build a model

With this done, we can create our end-to-end model:
"""


# Categorical features have different shapes after the encoding, depending on the
# vocabulary or unique values of each feature. We create the inputs accordingly, to
# match the data elements generated by tf.data.Dataset after pre-processing.
def create_model_inputs():
    inputs = {}

    # This is a helper function for creating categorical feature inputs.
    def create_input_helper(feature_name):
        num_categories = len(CATEGORICAL_FEATURES_WITH_VOCABULARY[feature_name])
        inputs[feature_name] = layers.Input(
            name=feature_name, shape=(num_categories,), dtype="int64"
        )
        return inputs

    for feature_name in FEATURE_NAMES:
        if feature_name in CATEGORICAL_FEATURES_WITH_VOCABULARY:
            # Categorical features.
            create_input_helper(feature_name)
        else:
            # Numerical features are real numbers, so make them float32.
            feature_input = layers.Input(name=feature_name, shape=(1,), dtype="float32")
            # Normalize the numerical inputs here.
            inputs[feature_name] = encode_numerical_feature(
                feature_input, feature_name, train_ds
            )
    return inputs


# This layer defines the logic of the model that performs the classification.
class Classifier(keras.layers.Layer):

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.dense_1 = layers.Dense(32, activation="relu")
        self.dropout = layers.Dropout(0.5)
        self.dense_2 = layers.Dense(1, activation="sigmoid")

    def call(self, inputs):
        all_features = layers.concatenate(list(inputs.values()))
        x = self.dense_1(all_features)
        x = self.dropout(x)
        output = self.dense_2(x)
        return output

    # Suppress build warnings.
    def build(self, input_shape):
        self.built = True
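
"""
The subclassed `Layer` above keeps the classification head self-contained. For reference,
an equivalent functional-style sketch (an alternative formulation, not used below) would
look like this:
"""


def classifier_head(inputs):
    # Concatenate all encoded features, then apply the same Dense/Dropout/Dense stack.
    x = layers.concatenate(list(inputs.values()))
    x = layers.Dense(32, activation="relu")(x)
    x = layers.Dropout(0.5)(x)
    return layers.Dense(1, activation="sigmoid")(x)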


# Create the Classifier model.
def create_model():
    all_inputs = create_model_inputs()
    output = Classifier()(all_inputs)
    model = keras.Model(all_inputs, output)
    return model


model = create_model()
model.compile("adam", "binary_crossentropy", metrics=["accuracy"])

"""
Let's visualize our connectivity graph:
"""

# `rankdir='LR'` is to make the graph horizontal.
keras.utils.plot_model(model, show_shapes=True, rankdir="LR")

"""
## Train the model
"""

model.fit(train_ds, epochs=50, validation_data=val_ds)


"""
We quickly get to 80% validation accuracy.
"""

"""
## Inference on new data

To get a prediction for a new sample, you can simply call `model.predict()`. There are
just two things you need to do:

1. Wrap scalars into a list so as to have a batch dimension (models only process batches
of data, not single samples).
2. Call `convert_to_tensor` on each feature.
"""

sample = {
    "age": 60,
    "sex": 1,
    "cp": 1,
    "trestbps": 145,
    "chol": 233,
    "fbs": 1,
    "restecg": 2,
    "thalach": 150,
    "exang": 0,
    "oldpeak": 2.3,
    "slope": 3,
    "ca": 0,
    "thal": "fixed",
}


# Given a feature's vocabulary (`cat`, keyed by the feature names in the sample above)
# and the value the feature takes (`cat_value`), return the value's one-hot encoding.
def get_cat_encoding(cat, cat_value):
    # Create a list of zeros with the same length as the vocabulary.
    encoding = [0] * len(cat)
    # Find the index of cat_value in the vocabulary and set that position to 1.
    if cat_value in cat:
        encoding[cat.index(cat_value)] = 1
    return encoding
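
"""
A quick usage check with toy values (added for illustration):
"""

print(get_cat_encoding([0, 1, 2, 3], 1))  # prints [0, 1, 0, 0]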


# One-hot encode the categorical features of the sample, using the same vocabularies
# as during training.
for name, value in sample.items():
    if name in CATEGORICAL_FEATURES_WITH_VOCABULARY:
        sample[name] = get_cat_encoding(
            CATEGORICAL_FEATURES_WITH_VOCABULARY[name], sample[name]
        )

# Convert inputs to tensors.
input_dict = {name: tf.convert_to_tensor([value]) for name, value in sample.items()}
predictions = model.predict(input_dict)

print(
    f"This particular patient had a {100 * predictions[0][0]:.1f} "
    "percent probability of having heart disease, "
    "as evaluated by our model."
)
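
"""
If you need a hard class label rather than a probability, threshold the prediction at 0.5
(a common default, added here for illustration):
"""

predicted_class = int(predictions[0][0] > 0.5)
print("Predicted class:", predicted_class)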

"""
## Conclusions

- The original model (the one that runs only on TensorFlow) converges quickly to around
80% validation accuracy, remains there for extended periods, and at times hits 85%.
- The updated (backend-agnostic) model may fluctuate between 78% and 83%, at times
hitting 86% validation accuracy, and also converges around 80%.
"""