"""
Title: FeatureSpace advanced use cases
Author: [Dimitre Oliveira](https://www.linkedin.com/in/dimitre-oliveira-7a1a0113a/)
Date created: 2023/07/01
Last modified: 2025/01/03
Description: How to use FeatureSpace for advanced preprocessing use cases.
Accelerator: None
"""

"""
## Introduction

This example is an extension of the
[Structured data classification with FeatureSpace](https://keras.io/examples/structured_data/structured_data_classification_with_feature_space/)
code example, and here we will extend it to cover more complex use
cases of the [`keras.utils.FeatureSpace`](https://keras.io/api/utils/feature_space/)
preprocessing utility, like feature hashing, feature crosses, handling missing values and
integrating [Keras preprocessing layers](https://keras.io/api/layers/preprocessing_layers/)
with FeatureSpace.

The general task is still structured data classification (also known as tabular data
classification), using data that includes numerical features, integer categorical
features, and string categorical features.
"""

"""
### The dataset

[Our dataset](https://archive.ics.uci.edu/dataset/222/bank+marketing) is provided by a
Portuguese banking institution.
It's a CSV file with 4119 rows. Each row contains information about marketing campaigns
based on phone calls, and each column describes an attribute of the client. We use the
features to predict whether the client subscribed ('yes') or not ('no') to the product
(bank term deposit).

Here's the description of each feature:

Column| Description| Feature Type
------|------------|-------------
Age | Age of the client | Numerical
Job | Type of job | Categorical
Marital | Marital status | Categorical
Education | Education level of the client | Categorical
Default | Has credit in default? | Categorical
Housing | Has housing loan? | Categorical
Loan | Has personal loan? | Categorical
Contact | Contact communication type | Categorical
Month | Last contact month of year | Categorical
Day_of_week | Last contact day of the week | Categorical
Duration | Last contact duration, in seconds | Numerical
Campaign | Number of contacts performed during this campaign and for this client | Numerical
Pdays | Number of days that passed after the client was last contacted in a previous campaign | Numerical
Previous | Number of contacts performed before this campaign and for this client | Numerical
Poutcome | Outcome of the previous marketing campaign | Categorical
Emp.var.rate | Employment variation rate | Numerical
Cons.price.idx | Consumer price index | Numerical
Cons.conf.idx | Consumer confidence index | Numerical
Euribor3m | Euribor 3 month rate | Numerical
Nr.employed | Number of employees | Numerical
Y | Has the client subscribed to a term deposit? | Target

**Important note regarding the feature `duration`**: this attribute strongly affects the
output target (e.g., if duration=0 then y='no'). Yet, the duration is not known before a
call is performed, and after the end of the call y is obviously known. Thus, this input
should only be included for benchmark purposes and should be discarded if the intention
is to have a realistic predictive model. For this reason we will drop it.
"""

"""
## Setup
"""

import os

os.environ["KERAS_BACKEND"] = "tensorflow"

import keras
from keras.utils import FeatureSpace
import pandas as pd
import tensorflow as tf
from pathlib import Path
from zipfile import ZipFile
"""
86
## Load the data
87
88
Let's download the data and load it into a Pandas dataframe:
89
"""
90
91
data_url = "https://archive.ics.uci.edu/static/public/222/bank+marketing.zip"
92
data_zipped_path = keras.utils.get_file("bank_marketing.zip", data_url, extract=True)
93
keras_datasets_path = Path(data_zipped_path)
94
with ZipFile(f"{keras_datasets_path}/bank-additional.zip", "r") as zip:
95
# Extract files
96
zip.extractall(path=keras_datasets_path)
97
98
dataframe = pd.read_csv(
99
f"{keras_datasets_path}/bank-additional/bank-additional.csv", sep=";"
100
)
101

"""
To demonstrate some useful preprocessing techniques, we will create a new feature,
`previously_contacted`, based on `pdays`. According to the dataset information,
`pdays = 999` means that the client was not previously contacted, so let's create a
feature to capture that.
"""

# Dropping `duration` to avoid target leak
dataframe.drop("duration", axis=1, inplace=True)
# Creating the new feature `previously_contacted`
dataframe["previously_contacted"] = dataframe["pdays"].map(
    lambda x: 0 if x == 999 else 1
)

"""
The dataset includes 4119 samples with 21 columns per sample (20 features, plus the
target label). Here's a preview of a few samples:
"""

print(f"Dataframe shape: {dataframe.shape}")
print(dataframe.head())
"""
125
The column, "y", indicates whether the client has subscribed a term deposit or not.
126
"""
127
128
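
"""
Since term-deposit subscriptions are relatively rare, it can be useful to check how
imbalanced the labels are. A minimal sketch (not part of the original example):
"""

# Relative frequency of each label in the "y" column
print(dataframe["y"].value_counts(normalize=True))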
"""
129
## Train/validation split
130
131
Let's split the data into a training and validation set:
132
"""
133
134
valid_dataframe = dataframe.sample(frac=0.2, random_state=0)
135
train_dataframe = dataframe.drop(valid_dataframe.index)
136
137
print(
138
f"Using {len(train_dataframe)} samples for training and "
139
f"{len(valid_dataframe)} for validation"
140
)
141

"""
## Generating TF datasets

Let's generate
[`tf.data.Dataset`](https://www.tensorflow.org/api_docs/python/tf/data/Dataset) objects
for each dataframe. Since our target column `y` is a string, we also need to encode it as
an integer to be able to train our model with it. To achieve this we will create a
`StringLookup` layer that will map the strings "no" and "yes" to the integers 0 and 1
respectively.
"""

label_lookup = keras.layers.StringLookup(
    # the order here is important since the first index will be encoded as 0
    vocabulary=["no", "yes"],
    num_oov_indices=0,
)


def encode_label(x, y):
    encoded_y = label_lookup(y)
    return x, encoded_y


def dataframe_to_dataset(dataframe):
    dataframe = dataframe.copy()
    labels = dataframe.pop("y")
    ds = tf.data.Dataset.from_tensor_slices((dict(dataframe), labels))
    ds = ds.map(encode_label, num_parallel_calls=tf.data.AUTOTUNE)
    ds = ds.shuffle(buffer_size=len(dataframe))
    return ds


train_ds = dataframe_to_dataset(train_dataframe)
valid_ds = dataframe_to_dataset(valid_dataframe)

"""
Each `Dataset` yields a tuple `(input, target)` where `input` is a dictionary of features
and `target` is the value `0` or `1`:
"""

for x, y in dataframe_to_dataset(train_dataframe).take(1):
    print(f"Input: {x}")
    print(f"Target: {y}")

"""
## Preprocessing

Usually our data is not in the proper or best format for modeling; this is why most of
the time we need to do some kind of preprocessing on the features to make them compatible
with the model or to extract the most out of them for the task. We need to do this
preprocessing step for training, but at inference we also need to make sure that the
data goes through the same process. This is where a utility like `FeatureSpace` shines:
we can define all the preprocessing once and re-use it at different stages of our system.

Here we will see how to use `FeatureSpace` to perform more complex transformations and
explore its flexibility, then combine everything together into a single component to
preprocess data for our model.
"""

"""
The `FeatureSpace` utility learns how to process the data by using the `adapt()` function
to learn from it. This requires a dataset containing only the features, so let's create
one, together with a utility function to show the preprocessing examples in practice:
"""

train_ds_with_no_labels = train_ds.map(lambda x, _: x)


def example_feature_space(dataset, feature_space, feature_names):
    feature_space.adapt(dataset)
    for x in dataset.take(1):
        inputs = {feature_name: x[feature_name] for feature_name in feature_names}
        preprocessed_x = feature_space(inputs)
        print(f"Input: {[{k:v.numpy()} for k, v in inputs.items()]}")
        print(
            f"Preprocessed output: {[{k:v.numpy()} for k, v in preprocessed_x.items()]}"
        )
"""
222
### Feature hashing
223
"""
224
225
"""
226
**Feature hashing** means hashing or encoding a set of values into a defined number of
227
bins, in this case we have `campaign` (number of contacts performed during this campaign
228
and for a client) which is a numerical feature that can assume a varying range of values
229
and we will hash it into 4 bins, this means that any possible value of the original
230
feature will be placed into one of those possible 4 bins. The output here can be a
231
one-hot encoded vector or a single number.
232
"""
233
234
feature_space = FeatureSpace(
235
features={
236
"campaign": FeatureSpace.integer_hashed(num_bins=4, output_mode="one_hot")
237
},
238
output_mode="dict",
239
)
240
example_feature_space(train_ds_with_no_labels, feature_space, ["campaign"])
241
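
"""
To get the bin index as a single integer instead of a one-hot vector, we can set
`output_mode="int"`; a minimal variation on the example above:
"""

feature_space = FeatureSpace(
    features={"campaign": FeatureSpace.integer_hashed(num_bins=4, output_mode="int")},
    output_mode="dict",
)
example_feature_space(train_ds_with_no_labels, feature_space, ["campaign"])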

"""
**Feature hashing** can also be used for string features.
"""

feature_space = FeatureSpace(
    features={
        "education": FeatureSpace.string_hashed(num_bins=3, output_mode="one_hot")
    },
    output_mode="dict",
)
example_feature_space(train_ds_with_no_labels, feature_space, ["education"])

"""
For numerical features we can get a similar behavior by using the `float_discretized`
option. The main difference between this and `integer_hashed` is that with the former we
bin the values while keeping some numerical relationship (close values will likely be
placed in the same bin), while with the latter (hashing) we cannot guarantee that close
numbers will be hashed into the same bin; it depends on the hashing function.
"""

feature_space = FeatureSpace(
    features={"age": FeatureSpace.float_discretized(num_bins=3, output_mode="one_hot")},
    output_mode="dict",
)
example_feature_space(train_ds_with_no_labels, feature_space, ["age"])

"""
### Feature indexing
"""

"""
**Indexing** a string feature essentially means creating a discrete numerical
representation for it. This is especially important for string features, since most
models only accept numerical features. The transformation will place the string values
into different categories. The output here can be a one-hot encoded vector or a single
number.

Note that by specifying `num_oov_indices=1` we leave one spot in our output vector for
OOV (out of vocabulary) values. This is an important tool to handle missing or unseen
values after training (values that were not seen during the `adapt()` step).
"""

feature_space = FeatureSpace(
    features={
        "default": FeatureSpace.string_categorical(
            num_oov_indices=1, output_mode="one_hot"
        )
    },
    output_mode="dict",
)
example_feature_space(train_ds_with_no_labels, feature_space, ["default"])
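
"""
To see the OOV slot in action, we can call the adapted `feature_space` directly on a
value that was never seen during `adapt()`; a hypothetical probe (not part of the
original example), which should activate the reserved OOV position of the one-hot
vector:
"""

# "not-a-real-category" is a made-up value absent from the training data
print(feature_space({"default": tf.constant("not-a-real-category")}))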

"""
We can also do **feature indexing** for integer features. This can be quite important for
some datasets where categorical features are encoded as numbers, for instance features
like `sex` or `gender`, where values like `1` and `0` do not have a numerical
relationship between them; they are just different categories. This behavior can be
perfectly captured by this transformation.

On this dataset we can use the `previously_contacted` feature that we created. For this
case we want to explicitly set `num_oov_indices=0`; the reason is that we only expect two
possible values for the feature, anything else would be either wrong input or an issue
with the data creation. For this reason we would probably just want the code to throw an
error so that we can be aware of the issue and fix it.
"""

feature_space = FeatureSpace(
    features={
        "previously_contacted": FeatureSpace.integer_categorical(
            num_oov_indices=0, output_mode="one_hot"
        )
    },
    output_mode="dict",
)
example_feature_space(train_ds_with_no_labels, feature_space, ["previously_contacted"])

"""
### Feature crosses (mixing features of diverse types)

With **crosses** we can do feature interactions between an arbitrary number of features
of mixed types, as long as they are categorical features. Instead of having a feature
{'age': 20} and another {'job': 'entrepreneur'}, we can have {'age_X_job':
20_entrepreneur}; and with `FeatureSpace` and **crosses** we can apply specific
preprocessing to each individual feature and to the feature cross itself. This option
can be very powerful for specific use cases; here it might be a good option, since age
combined with job can have different meanings in the banking domain.

We will cross `age` and `job` and hash their combined output into a vector
representation of size 8. The output here can be a one-hot encoded vector or a single
number.

Sometimes the combination of multiple features can result in a very large feature
space; think about crossing someone's ZIP code with their last name, the possibilities
would be in the thousands. This is why the `crossing_dim` parameter is so important: it
limits the output dimension of the cross feature.

Note that the combination of the 6 possible bins of `age` and the 12 values of `job`
would yield 72 possible values, so by choosing `crossing_dim = 8` we are choosing to
constrain the output vector.
"""

feature_space = FeatureSpace(
    features={
        "age": FeatureSpace.integer_hashed(num_bins=6, output_mode="one_hot"),
        "job": FeatureSpace.string_categorical(
            num_oov_indices=0, output_mode="one_hot"
        ),
    },
    crosses=[
        FeatureSpace.cross(
            feature_names=("age", "job"),
            crossing_dim=8,
            output_mode="one_hot",
        )
    ],
    output_mode="dict",
)
example_feature_space(train_ds_with_no_labels, feature_space, ["age", "job"])
"""
361
### FeatureSpace using a Keras preprocessing layer
362
363
To be a really flexible and extensible feature we cannot only rely on those pre-defined
364
transformation, we must be able to re-use other transformations from the Keras/TensorFlow
365
ecosystem and customize our own, this is why `FeatureSpace` is also designed to work with
366
[Keras preprocessing layers](https://keras.io/api/layers/preprocessing_layers/), this way we
367
can use sophisticated data transformations provided by the framework, you can even create
368
your own custom Keras preprocessing layers and use it in the same way.
369
370
Here we are going to use the
371
[`keras.layers.TextVectorization`](https://keras.io/api/layers/preprocessing_layers/text/text_vectorization/#textvectorization-class)
372
preprocessing layer to create a TF-IDF
373
feature from our data. Note that this feature is not a really good use case for TF-IDF,
374
this is just for demonstration purposes.
375
"""
376
377
custom_layer = keras.layers.TextVectorization(output_mode="tf_idf")
378
379
feature_space = FeatureSpace(
380
features={
381
"education": FeatureSpace.feature(
382
preprocessor=custom_layer, dtype="string", output_mode="float"
383
)
384
},
385
output_mode="dict",
386
)
387
example_feature_space(train_ds_with_no_labels, feature_space, ["education"])
388
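
"""
As mentioned above, you can also plug in a layer of your own. A minimal sketch of a
hypothetical custom stateless layer (not part of the original example), assuming that
`FeatureSpace` only calls `adapt()` on preprocessors that expose an `adapt()` method:
"""


class LogScaler(keras.layers.Layer):
    # Hypothetical stateless layer: log-scale a non-negative numerical feature
    def call(self, inputs):
        return tf.math.log1p(tf.cast(inputs, "float32"))


feature_space = FeatureSpace(
    features={
        "previous": FeatureSpace.feature(
            preprocessor=LogScaler(), dtype="int64", output_mode="float"
        )
    },
    output_mode="dict",
)
example_feature_space(train_ds_with_no_labels, feature_space, ["previous"])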

"""
## Configuring the final `FeatureSpace`

Now that we know how to use `FeatureSpace` for more complex use cases, let's pick the
options that look most useful for this task and create the final `FeatureSpace`
component.

To configure how each feature should be preprocessed,
we instantiate a `keras.utils.FeatureSpace`, and we
pass to it a dictionary that maps the name of our features
to the feature transformation function.
"""

feature_space = FeatureSpace(
    features={
        # Categorical features encoded as integers
        "previously_contacted": FeatureSpace.integer_categorical(num_oov_indices=0),
        # Categorical features encoded as string
        "marital": FeatureSpace.string_categorical(num_oov_indices=0),
        "education": FeatureSpace.string_categorical(num_oov_indices=0),
        "default": FeatureSpace.string_categorical(num_oov_indices=0),
        "housing": FeatureSpace.string_categorical(num_oov_indices=0),
        "loan": FeatureSpace.string_categorical(num_oov_indices=0),
        "contact": FeatureSpace.string_categorical(num_oov_indices=0),
        "month": FeatureSpace.string_categorical(num_oov_indices=0),
        "day_of_week": FeatureSpace.string_categorical(num_oov_indices=0),
        "poutcome": FeatureSpace.string_categorical(num_oov_indices=0),
        # Categorical features to hash and bin
        "job": FeatureSpace.string_hashed(num_bins=3),
        # Numerical features to hash and bin
        "pdays": FeatureSpace.integer_hashed(num_bins=4),
        # Numerical features to normalize and bin
        "age": FeatureSpace.float_discretized(num_bins=4),
        # Numerical features to normalize
        "campaign": FeatureSpace.float_normalized(),
        "previous": FeatureSpace.float_normalized(),
        "emp.var.rate": FeatureSpace.float_normalized(),
        "cons.price.idx": FeatureSpace.float_normalized(),
        "cons.conf.idx": FeatureSpace.float_normalized(),
        "euribor3m": FeatureSpace.float_normalized(),
        "nr.employed": FeatureSpace.float_normalized(),
    },
    # Specify feature crosses with a custom crossing dim.
    crosses=[
        FeatureSpace.cross(feature_names=("age", "job"), crossing_dim=8),
        FeatureSpace.cross(feature_names=("housing", "loan"), crossing_dim=6),
        FeatureSpace.cross(
            feature_names=("poutcome", "previously_contacted"), crossing_dim=2
        ),
    ],
    output_mode="concat",
)
"""
443
## Adapt the `FeatureSpace` to the training data
444
445
Before we start using the `FeatureSpace` to build a model, we have
446
to adapt it to the training data. During `adapt()`, the `FeatureSpace` will:
447
448
- Index the set of possible values for categorical features.
449
- Compute the mean and variance for numerical features to normalize.
450
- Compute the value boundaries for the different bins for numerical features to
451
discretize.
452
- Any other kind of preprocessing required by custom layers.
453
454
Note that `adapt()` should be called on a `tf.data.Dataset` which yields dicts
455
of feature values -- no labels.
456
457
But first let's batch the datasets
458
"""
459
460
train_ds = train_ds.batch(32)
461
valid_ds = valid_ds.batch(32)
462
463
train_ds_with_no_labels = train_ds.map(lambda x, _: x)
464
feature_space.adapt(train_ds_with_no_labels)
465
466
"""
467
At this point, the `FeatureSpace` can be called on a dict of raw feature values, and
468
because we set `output_mode="concat"` it will return a single concatenate vector for each
469
sample, combining encoded features and feature crosses.
470
"""
471
472
for x, _ in train_ds.take(1):
473
preprocessed_x = feature_space(x)
474
print(f"preprocessed_x shape: {preprocessed_x.shape}")
475
print(f"preprocessed_x sample: \n{preprocessed_x[0]}")
476
477
"""
478
## Saving the `FeatureSpace`
479
480
At this point we can choose to save our `FeatureSpace` component, this have many
481
advantages like re-using it on different experiments that use the same model, saving time
482
if you need to re-run the preprocessing step, and mainly for model deployment, where by
483
loading it you can be sure that you will be applying the same preprocessing steps don't
484
matter the device or environment, this is a great way to reduce
485
[training/servingskew](https://developers.google.com/machine-learning/guides/rules-of-ml#training-serving_skew).
486
"""
487
488
feature_space.save("myfeaturespace.keras")
489
490
"""
491
## Preprocessing with `FeatureSpace` as part of the tf.data pipeline
492
493
We will opt to use our component asynchronously by making it part of the tf.data
494
pipeline, as noted at the
495
[previous guide](https://keras.io/examples/structured_data/structured_data_classification_with_feature_space/)
496
This enables asynchronous parallel preprocessing of the data on CPU before it
497
hits the model. Usually, this is always the right thing to do during training.
498
499
Let's create a training and validation dataset of preprocessed batches:
500
"""
501
502
preprocessed_train_ds = train_ds.map(
503
lambda x, y: (feature_space(x), y), num_parallel_calls=tf.data.AUTOTUNE
504
).prefetch(tf.data.AUTOTUNE)
505
506
preprocessed_valid_ds = valid_ds.map(
507
lambda x, y: (feature_space(x), y), num_parallel_calls=tf.data.AUTOTUNE
508
).prefetch(tf.data.AUTOTUNE)
509
510
"""
511
## Model
512
513
We will take advantage of our `FeatureSpace` component to build the model, as we want the
514
model to be compatible with our preprocessing function, let's use the the `FeatureSpace`
515
feature map as the input of our model.
516
"""
517
518
encoded_features = feature_space.get_encoded_features()
519
print(encoded_features)
520
521
"""
522
This model is quite trivial only for demonstration purposes so don't pay too much
523
attention to the architecture.
524
"""
525
526
x = keras.layers.Dense(64, activation="relu")(encoded_features)
527
x = keras.layers.Dropout(0.5)(x)
528
output = keras.layers.Dense(1, activation="sigmoid")(x)
529
530
model = keras.Model(inputs=encoded_features, outputs=output)
531
model.compile(optimizer="adam", loss="binary_crossentropy", metrics=["accuracy"])
532
533
"""
534
## Training
535
536
Let's train our model for 20 epochs. Note that feature preprocessing is happening as part
537
of the tf.data pipeline, not as part of the model.
538
"""
539
540
model.fit(
541
preprocessed_train_ds, validation_data=preprocessed_valid_ds, epochs=10, verbose=2
542
)
543
544
"""
545
## Inference on new data with the end-to-end model
546
547
Now, we can build our inference model (which includes the `FeatureSpace`) to make
548
predictions based on dicts of raw features values, as follows:
549
"""
550
551
"""
552
### Loading the `FeatureSpace`
553
554
First let's load the `FeatureSpace` that we saved a few moment ago, this can be quite
555
handy if you train a model but want to do inference at different time, possibly using a
556
different device or environment.
557
"""
558
559
loaded_feature_space = keras.saving.load_model("myfeaturespace.keras")
560
561
"""
562
### Building the inference end-to-end model
563
564
To build the inference model we need both the feature input map and the preprocessing
565
encoded Keras tensors.
566
"""
567
568
dict_inputs = loaded_feature_space.get_inputs()
569
encoded_features = loaded_feature_space.get_encoded_features()
570
print(encoded_features)
571
572
print(dict_inputs)
573
574
outputs = model(encoded_features)
575
inference_model = keras.Model(inputs=dict_inputs, outputs=outputs)
576
577

sample = {
    "age": 30,
    "job": "blue-collar",
    "marital": "married",
    "education": "basic.9y",
    "default": "no",
    "housing": "yes",
    "loan": "no",
    "contact": "cellular",
    "month": "may",
    "day_of_week": "fri",
    "campaign": 2,
    "pdays": 999,
    "previous": 0,
    "poutcome": "nonexistent",
    "emp.var.rate": -1.8,
    "cons.price.idx": 92.893,
    "cons.conf.idx": -46.2,
    "euribor3m": 1.313,
    "nr.employed": 5099.1,
    "previously_contacted": 0,
}

input_dict = {
    name: keras.ops.convert_to_tensor([value]) for name, value in sample.items()
}
predictions = inference_model.predict(input_dict)

print(
    f"This particular client has a {100 * predictions[0][0]:.2f}% probability "
    "of subscribing to a term deposit, as evaluated by our model."
)
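
"""
Because `inference_model` bundles the preprocessing and the trained model together, it
can itself be saved and reloaded as a single artifact. A minimal sketch, assuming all
the preprocessing layers involved are serializable with the `.keras` format used above:
"""

# Hypothetical final step: persist the end-to-end model for deployment
inference_model.save("inference_model.keras")
loaded_inference_model = keras.saving.load_model("inference_model.keras")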