"""
Title: Structured data classification with FeatureSpace
Author: [fchollet](https://twitter.com/fchollet)
Date created: 2022/11/09
Last modified: 2022/11/09
Description: Classify tabular data in a few lines of code.
Accelerator: GPU
"""

"""
## Introduction

This example demonstrates how to do structured data classification
(also known as tabular data classification), starting from a raw
CSV file. Our data includes numerical features, integer categorical
features, and string categorical features. We will use the utility
`keras.utils.FeatureSpace` to index, preprocess, and encode our features.

The code is adapted from the example
[Structured data classification from scratch](https://keras.io/examples/structured_data/structured_data_classification_from_scratch/).
While the previous example managed its own low-level feature preprocessing and
encoding with Keras preprocessing layers, in this example we
delegate everything to `FeatureSpace`, making the workflow
extremely quick and easy.

### The dataset

[Our dataset](https://archive.ics.uci.edu/ml/datasets/heart+Disease) is provided by the
Cleveland Clinic Foundation for Heart Disease.
It's a CSV file with 303 rows. Each row contains information about a patient (a
**sample**), and each column describes an attribute of the patient (a **feature**). We
use the features to predict whether a patient has heart disease
(**binary classification**).

Here's the description of each feature:

Column| Description| Feature Type
------------|--------------------|----------------------
Age | Age in years | Numerical
Sex | (1 = male; 0 = female) | Categorical
CP | Chest pain type (0, 1, 2, 3, 4) | Categorical
Trestbpd | Resting blood pressure (in mm Hg on admission) | Numerical
Chol | Serum cholesterol in mg/dl | Numerical
FBS | Fasting blood sugar > 120 mg/dl (1 = true; 0 = false) | Categorical
RestECG | Resting electrocardiogram results (0, 1, 2) | Categorical
Thalach | Maximum heart rate achieved | Numerical
Exang | Exercise-induced angina (1 = yes; 0 = no) | Categorical
Oldpeak | ST depression induced by exercise relative to rest | Numerical
Slope | Slope of the peak exercise ST segment | Numerical
CA | Number of major vessels (0-3) colored by fluoroscopy | Both numerical & categorical
Thal | 3 = normal; 6 = fixed defect; 7 = reversible defect | Categorical
Target | Diagnosis of heart disease (1 = true; 0 = false) | Target
"""

"""
## Setup
"""

import os

os.environ["KERAS_BACKEND"] = "tensorflow"

import tensorflow as tf
import pandas as pd
import keras
from keras.utils import FeatureSpace

"""
70
## Preparing the data
71
72
Let's download the data and load it into a Pandas dataframe:
73
"""
74
75
file_url = "http://storage.googleapis.com/download.tensorflow.org/data/heart.csv"
76
dataframe = pd.read_csv(file_url)
77
78
"""
79
The dataset includes 303 samples with 14 columns per sample
80
(13 features, plus the target label):
81
"""

print(dataframe.shape)

"""
Here's a preview of a few samples:
"""

dataframe.head()

"""
The last column, "target", indicates whether the patient
has heart disease (1) or not (0).

Let's split the data into a training and validation set:
"""

val_dataframe = dataframe.sample(frac=0.2, random_state=1337)
train_dataframe = dataframe.drop(val_dataframe.index)

print(
    "Using %d samples for training and %d for validation"
    % (len(train_dataframe), len(val_dataframe))
)

"""
Let's generate `tf.data.Dataset` objects for each dataframe:
"""


def dataframe_to_dataset(dataframe):
    dataframe = dataframe.copy()
    labels = dataframe.pop("target")
    ds = tf.data.Dataset.from_tensor_slices((dict(dataframe), labels))
    ds = ds.shuffle(buffer_size=len(dataframe))
    return ds


train_ds = dataframe_to_dataset(train_dataframe)
val_ds = dataframe_to_dataset(val_dataframe)

"""
123
Each `Dataset` yields a tuple `(input, target)` where `input` is a dictionary of features
124
and `target` is the value `0` or `1`:
125
"""
126
127
for x, y in train_ds.take(1):
128
print("Input:", x)
129
print("Target:", y)
130
131
"""
132
Let's batch the datasets:
133
"""
134
135
train_ds = train_ds.batch(32)
136
val_ds = val_ds.batch(32)
137
138
"""
139
## Configuring a `FeatureSpace`
140
141
To configure how each feature should be preprocessed,
142
we instantiate a `keras.utils.FeatureSpace`, and we
143
pass to it a dictionary that maps the name of our features
144
to a string that describes the feature type.
145
146
We have a few "integer categorical" features such as `"FBS"`,
147
one "string categorical" feature (`"thal"`),
148
and a few numerical features, which we'd like to normalize
149
-- except `"age"`, which we'd like to discretize into
150
a number of bins.
151
152
We also use the `crosses` argument
153
to capture *feature interactions* for some categorical
154
features, that is to say, create additional features
155
that represent value co-occurrences for these categorical features.
156
You can compute feature crosses like this for arbitrary sets of
157
categorical features -- not just tuples of two features.
158
Because the resulting co-occurences are hashed
159
into a fixed-sized vector, you don't need to worry about whether
160
the co-occurence space is too large.
161
"""
162
feature_space = FeatureSpace(
    features={
        # Categorical features encoded as integers
        "sex": "integer_categorical",
        "cp": "integer_categorical",
        "fbs": "integer_categorical",
        "restecg": "integer_categorical",
        "exang": "integer_categorical",
        "ca": "integer_categorical",
        # Categorical feature encoded as string
        "thal": "string_categorical",
        # Numerical feature to discretize
        "age": "float_discretized",
        # Numerical features to normalize
        "trestbps": "float_normalized",
        "chol": "float_normalized",
        "thalach": "float_normalized",
        "oldpeak": "float_normalized",
        "slope": "float_normalized",
    },
    # We create additional features by hashing
    # value co-occurrences for the
    # following groups of categorical features.
    crosses=[("sex", "age"), ("thal", "ca")],
    # The hashing space for these co-occurrences
    # will be 32-dimensional.
    crossing_dim=32,
    # Our utility will one-hot encode all categorical
    # features and concatenate all features into a single
    # vector (one vector per sample).
    output_mode="concat",
)

"""
## Further customizing a `FeatureSpace`

Specifying the feature type via a string name is quick and easy,
but sometimes you may want to further configure the preprocessing
of each feature. For instance, in our case, our categorical
features don't have a large set of possible values -- it's only
a handful of values per feature (e.g. `1` and `0` for the feature `"fbs"`),
and all possible values are represented in the training set.
As a result, we don't need to reserve an index to represent "out of vocabulary" values
for these features -- which would have been the default behavior.
Below, we just specify `num_oov_indices=0` for each of these features
to tell the feature preprocessor to skip "out of vocabulary" indexing.

Other customizations you have access to include specifying the number of
bins for discretizing features of type `"float_discretized"`,
or the dimensionality of the hashing space for feature crossing.
"""

feature_space = FeatureSpace(
    features={
        # Categorical features encoded as integers
        "sex": FeatureSpace.integer_categorical(num_oov_indices=0),
        "cp": FeatureSpace.integer_categorical(num_oov_indices=0),
        "fbs": FeatureSpace.integer_categorical(num_oov_indices=0),
        "restecg": FeatureSpace.integer_categorical(num_oov_indices=0),
        "exang": FeatureSpace.integer_categorical(num_oov_indices=0),
        "ca": FeatureSpace.integer_categorical(num_oov_indices=0),
        # Categorical feature encoded as string
        "thal": FeatureSpace.string_categorical(num_oov_indices=0),
        # Numerical feature to discretize
        "age": FeatureSpace.float_discretized(num_bins=30),
        # Numerical features to normalize
        "trestbps": FeatureSpace.float_normalized(),
        "chol": FeatureSpace.float_normalized(),
        "thalach": FeatureSpace.float_normalized(),
        "oldpeak": FeatureSpace.float_normalized(),
        "slope": FeatureSpace.float_normalized(),
    },
    # Specify feature crosses with custom crossing dims.
    crosses=[
        FeatureSpace.cross(feature_names=("sex", "age"), crossing_dim=64),
        FeatureSpace.cross(
            feature_names=("thal", "ca"),
            crossing_dim=16,
        ),
    ],
    output_mode="concat",
)

"""
## Adapt the `FeatureSpace` to the training data

Before we start using the `FeatureSpace` to build a model, we have
to adapt it to the training data. During `adapt()`, the `FeatureSpace` will:

- Index the set of possible values for categorical features.
- Compute the mean and variance for numerical features to normalize.
- Compute the value boundaries for the different bins for numerical features to discretize.

Note that `adapt()` should be called on a `tf.data.Dataset` which yields dicts
of feature values -- no labels.
"""

train_ds_with_no_labels = train_ds.map(lambda x, _: x)
feature_space.adapt(train_ds_with_no_labels)

"""
264
At this point, the `FeatureSpace` can be called on a dict of raw feature values, and will return a
265
single concatenate vector for each sample, combining encoded features and feature crosses.
266
"""
267
268
for x, _ in train_ds.take(1):
269
preprocessed_x = feature_space(x)
270
print("preprocessed_x.shape:", preprocessed_x.shape)
271
print("preprocessed_x.dtype:", preprocessed_x.dtype)
272
273
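"""
An adapted `FeatureSpace` can also be persisted, so the exact same preprocessing can be
reused elsewhere (e.g. in a separate serving job). The sketch below follows the pattern
in the `FeatureSpace` API docs; the filename is arbitrary, and `.keras`-format saving
assumes a reasonably recent Keras version:
"""

# Sketch only (hypothetical filename) -- persist and reload the adapted FeatureSpace:
# feature_space.save("myfeaturespace.keras")
# loaded_feature_space = keras.models.load_model("myfeaturespace.keras")
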
"""
274
## Two ways to manage preprocessing: as part of the `tf.data` pipeline, or in the model itself
275
276
There are two ways in which you can leverage your `FeatureSpace`:
277
278
### Asynchronous preprocessing in `tf.data`
279
280
You can make it part of your data pipeline, before the model. This enables asynchronous parallel
281
preprocessing of the data on CPU before it hits the model. Do this if you're training on GPU or TPU,
282
or if you want to speed up preprocessing. Usually, this is always the right thing to do during training.
283
284
### Synchronous preprocessing in the model
285
286
You can make it part of your model. This means that the model will expect dicts of raw feature
287
values, and the preprocessing batch will be done synchronously (in a blocking manner) before the
288
rest of the forward pass. Do this if you want to have an end-to-end model that can process
289
raw feature values -- but keep in mind that your model will only be able to run on CPU,
290
since most types of feature preprocessing (e.g. string preprocessing) are not GPU or TPU compatible.
291
292
Do not do this on GPU / TPU or in performance-sensitive settings. In general, you want to do in-model
293
preprocessing when you do inference on CPU.
294
295
In our case, we will apply the `FeatureSpace` in the tf.data pipeline during training, but we will
296
do inference with an end-to-end model that includes the `FeatureSpace`.
297
"""
298
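
"""
For reference, here is a minimal sketch of what the in-model (synchronous) option looks like;
we build the full version of this end-to-end model later in this example, so the sketch is
left commented out:
"""

# Sketch only: an end-to-end model whose forward pass starts with preprocessing.
# `get_inputs()` returns a dict of symbolic inputs for the raw features, and
# `get_encoded_features()` returns the corresponding preprocessed vector.
# sketch_inputs = feature_space.get_inputs()
# sketch_features = feature_space.get_encoded_features()
# sketch_outputs = keras.layers.Dense(1, activation="sigmoid")(sketch_features)
# end_to_end_sketch = keras.Model(inputs=sketch_inputs, outputs=sketch_outputs)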

"""
Let's create a training and validation dataset of preprocessed batches:
"""

preprocessed_train_ds = train_ds.map(
    lambda x, y: (feature_space(x), y), num_parallel_calls=tf.data.AUTOTUNE
)
preprocessed_train_ds = preprocessed_train_ds.prefetch(tf.data.AUTOTUNE)

preprocessed_val_ds = val_ds.map(
    lambda x, y: (feature_space(x), y), num_parallel_calls=tf.data.AUTOTUNE
)
preprocessed_val_ds = preprocessed_val_ds.prefetch(tf.data.AUTOTUNE)

"""
314
## Build a model
315
316
Time to build a model -- or rather two models:
317
318
- A training model that expects preprocessed features (one sample = one vector)
319
- An inference model that expects raw features (one sample = dict of raw feature values)
320
"""
321
322
dict_inputs = feature_space.get_inputs()
323
encoded_features = feature_space.get_encoded_features()
324
325
x = keras.layers.Dense(32, activation="relu")(encoded_features)
326
x = keras.layers.Dropout(0.5)(x)
327
predictions = keras.layers.Dense(1, activation="sigmoid")(x)
328
329
training_model = keras.Model(inputs=encoded_features, outputs=predictions)
330
training_model.compile(
331
optimizer="adam", loss="binary_crossentropy", metrics=["accuracy"]
332
)
333
334
inference_model = keras.Model(inputs=dict_inputs, outputs=predictions)
335
336
"""
337
## Train the model
338
339
Let's train our model for 50 epochs. Note that feature preprocessing is happening
340
as part of the tf.data pipeline, not as part of the model.
341
"""
342
343
training_model.fit(
344
preprocessed_train_ds,
345
epochs=20,
346
validation_data=preprocessed_val_ds,
347
verbose=2,
348
)
349
350
"""
351
We quickly get to 80% validation accuracy.
352
"""
353
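
"""
To check the final validation metrics explicitly, we can also evaluate the training model
on the preprocessed validation set (a quick sanity check, not part of the original workflow):
"""

# `evaluate()` returns the loss followed by the compiled metrics, here accuracy.
val_loss, val_accuracy = training_model.evaluate(preprocessed_val_ds, verbose=0)
print(f"Validation accuracy: {val_accuracy:.3f}")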

"""
## Inference on new data with the end-to-end model

Now, we can use our inference model (which includes the `FeatureSpace`)
to make predictions based on dicts of raw feature values, as follows:
"""

sample = {
    "age": 60,
    "sex": 1,
    "cp": 1,
    "trestbps": 145,
    "chol": 233,
    "fbs": 1,
    "restecg": 2,
    "thalach": 150,
    "exang": 0,
    "oldpeak": 2.3,
    "slope": 3,
    "ca": 0,
    "thal": "fixed",
}

input_dict = {name: tf.convert_to_tensor([value]) for name, value in sample.items()}
predictions = inference_model.predict(input_dict)

print(
    f"This particular patient had a {100 * predictions[0][0]:.2f}% probability "
    "of having a heart disease, as evaluated by our model."
)
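
"""
Since the model outputs a probability, we can turn it into a hard class prediction by
thresholding. The 0.5 cutoff below is an illustrative choice; in practice you would pick
a threshold that reflects the relative cost of false positives and false negatives:
"""

# Convert the predicted probability into a binary class label (1 = heart disease).
predicted_class = int(predictions[0][0] > 0.5)
print("Predicted class:", predicted_class)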