Path: blob/master/examples/structured_data/feature_space_advanced.py
3507 views
"""1Title: FeatureSpace advanced use cases2Author: [Dimitre Oliveira](https://www.linkedin.com/in/dimitre-oliveira-7a1a0113a/)3Date created: 2023/07/014Last modified: 2025/01/035Description: How to use FeatureSpace for advanced preprocessing use cases.6Accelerator: None7"""89"""10## Introduction1112This example is an extension of the13[Structured data classification with FeatureSpace](https://keras.io/examples/structured_data/structured_data_classification_with_feature_space/)14code example, and here we will extend it to cover more complex use15cases of the [`keras.utils.FeatureSpace`](https://keras.io/api/utils/feature_space/)16preprocessing utility, like feature hashing, feature crosses, handling missing values and17integrating [Keras preprocessing layers](https://keras.io/api/layers/preprocessing_layers/)18with FeatureSpace.1920The general task still is structured data classification (also known as tabular data21classification) using a data that includes numerical features, integer categorical22features, and string categorical features.23"""2425"""26### The dataset2728[Our dataset](https://archive.ics.uci.edu/dataset/222/bank+marketing) is provided by a29Portuguese banking institution.30It's a CSV file with 4119 rows. Each row contains information about marketing campaigns31based on phone calls, and each column describes an attribute of the client. We use the32features to predict whether the client subscribed ('yes') or not ('no') to the product33(bank term deposit).3435Here's the description of each feature:3637Column| Description| Feature Type38------|------------|-------------39Age | Age of the client | Numerical40Job | Type of job | Categorical41Marital | Marital status | Categorical42Education | Education level of the client | Categorical43Default | Has credit in default? | Categorical44Housing | Has housing loan? | Categorical45Loan | Has personal loan? | Categorical46Contact | Contact communication type | Categorical47Month | Last contact month of year | Categorical48Day_of_week | Last contact day of the week | Categorical49Duration | Last contact duration, in seconds | Numerical50Campaign | Number of contacts performed during this campaign and for this client | Numerical51Pdays | Number of days that passed by after the client was last contacted from a previous campaign | Numerical52Previous | Number of contacts performed before this campaign and for this client | Numerical53Poutcome | Outcome of the previous marketing campaign | Categorical54Emp.var.rate | Employment variation rate | Numerical55Cons.price.idx | Consumer price index | Numerical56Cons.conf.idx | Consumer confidence index | Numerical57Euribor3m | Euribor 3 month rate | Numerical58Nr.employed | Number of employees | Numerical59Y | Has the client subscribed a term deposit? | Target6061**Important note regarding the feature `duration`**: this attribute highly affects the62output target (e.g., if duration=0 then y='no'). Yet, the duration is not known before a63call is performed. Also, after the end of the call y is obviously known. Thus, this input64should only be included for benchmark purposes and should be discarded if the intention65is to have a realistic predictive model. For this reason we will drop it.6667"""6869"""70## Setup71"""7273import os7475os.environ["KERAS_BACKEND"] = "tensorflow"7677import keras78from keras.utils import FeatureSpace79import pandas as pd80import tensorflow as tf81from pathlib import Path82from zipfile import ZipFile8384"""85## Load the data8687Let's download the data and load it into a Pandas dataframe:88"""8990data_url = "https://archive.ics.uci.edu/static/public/222/bank+marketing.zip"91data_zipped_path = keras.utils.get_file("bank_marketing.zip", data_url, extract=True)92keras_datasets_path = Path(data_zipped_path)93with ZipFile(f"{keras_datasets_path}/bank-additional.zip", "r") as zip:94# Extract files95zip.extractall(path=keras_datasets_path)9697dataframe = pd.read_csv(98f"{keras_datasets_path}/bank-additional/bank-additional.csv", sep=";"99)100101"""102We will create a new feature `previously_contacted` to be able to demonstrate some useful103preprocessing techniques, this feature is based on `pdays`. According to the dataset104information if `pdays = 999` it means that the client was not previously contacted, so105let's create a feature to capture that.106"""107108# Droping `duration` to avoid target leak109dataframe.drop("duration", axis=1, inplace=True)110# Creating the new feature `previously_contacted`111dataframe["previously_contacted"] = dataframe["pdays"].map(112lambda x: 0 if x == 999 else 1113)114115"""116The dataset includes 4119 samples with 21 columns per sample (20 features, plus the117target label), here's a preview of a few samples:118"""119120print(f"Dataframe shape: {dataframe.shape}")121print(dataframe.head())122123"""124The column, "y", indicates whether the client has subscribed a term deposit or not.125"""126127"""128## Train/validation split129130Let's split the data into a training and validation set:131"""132133valid_dataframe = dataframe.sample(frac=0.2, random_state=0)134train_dataframe = dataframe.drop(valid_dataframe.index)135136print(137f"Using {len(train_dataframe)} samples for training and "138f"{len(valid_dataframe)} for validation"139)140141"""142## Generating TF datasets143144Let's generate145[`tf.data.Dataset`](https://www.tensorflow.org/api_docs/python/tf/data/Dataset) objects146for each dataframe, since our target column `y` is a string we also need to encode it as147an integer to be able to train our model with it. To achieve this we will create a148`StringLookup` layer that will map the strings "no" and "yes" into "0" and "1"149respectively.150"""151152label_lookup = keras.layers.StringLookup(153# the order here is important since the first index will be encoded as 0154vocabulary=["no", "yes"],155num_oov_indices=0,156)157158159def encode_label(x, y):160encoded_y = label_lookup(y)161return x, encoded_y162163164def dataframe_to_dataset(dataframe):165dataframe = dataframe.copy()166labels = dataframe.pop("y")167ds = tf.data.Dataset.from_tensor_slices((dict(dataframe), labels))168ds = ds.map(encode_label, num_parallel_calls=tf.data.AUTOTUNE)169ds = ds.shuffle(buffer_size=len(dataframe))170return ds171172173train_ds = dataframe_to_dataset(train_dataframe)174valid_ds = dataframe_to_dataset(valid_dataframe)175176"""177Each `Dataset` yields a tuple `(input, target)` where `input` is a dictionary of features178and `target` is the value `0` or `1`:179"""180181for x, y in dataframe_to_dataset(train_dataframe).take(1):182print(f"Input: {x}")183print(f"Target: {y}")184185"""186## Preprocessing187188Usually our data is not on the proper or best format for modeling, this is why most of189the time we need to do some kind of preprocessing on the features to make them compatible190with the model or to extract the most of them for the task. We need to do this191preprocessing step for training but but at inference we also need to make sure that the192data goes through the same process, this where a utility like `FeatureSpace` shines, we193can define all the preprocessing once and re-use it at different stages of our system.194195Here we will see how to use `FeatureSpace` to perform more complex transformations and196its flexibility, then combine everything together into a single component to preprocess197data for our model.198"""199200"""201The `FeatureSpace` utility learns how to process the data by using the `adapt()` function202to learn from it, this requires a dataset containing only feature, so let's create it203together with a utility function to show the preprocessing example in practice:204"""205206train_ds_with_no_labels = train_ds.map(lambda x, _: x)207208209def example_feature_space(dataset, feature_space, feature_names):210feature_space.adapt(dataset)211for x in dataset.take(1):212inputs = {feature_name: x[feature_name] for feature_name in feature_names}213preprocessed_x = feature_space(inputs)214print(f"Input: {[{k:v.numpy()} for k, v in inputs.items()]}")215print(216f"Preprocessed output: {[{k:v.numpy()} for k, v in preprocessed_x.items()]}"217)218219220"""221### Feature hashing222"""223224"""225**Feature hashing** means hashing or encoding a set of values into a defined number of226bins, in this case we have `campaign` (number of contacts performed during this campaign227and for a client) which is a numerical feature that can assume a varying range of values228and we will hash it into 4 bins, this means that any possible value of the original229feature will be placed into one of those possible 4 bins. The output here can be a230one-hot encoded vector or a single number.231"""232233feature_space = FeatureSpace(234features={235"campaign": FeatureSpace.integer_hashed(num_bins=4, output_mode="one_hot")236},237output_mode="dict",238)239example_feature_space(train_ds_with_no_labels, feature_space, ["campaign"])240241"""242**Feature hashing** can also be used for string features.243"""244245feature_space = FeatureSpace(246features={247"education": FeatureSpace.string_hashed(num_bins=3, output_mode="one_hot")248},249output_mode="dict",250)251example_feature_space(train_ds_with_no_labels, feature_space, ["education"])252253"""254For numerical features we can get a similar behavior by using the `float_discretized`255option, the main difference between this and `integer_hashed` is that with the former we256bin the values while keeping some numerical relationship (close values will likely be257placed at the same bin) while the later (hashing) we cannot guarantee that those numbers258will be hashed into the same bin, it depends on the hashing function.259"""260261feature_space = FeatureSpace(262features={"age": FeatureSpace.float_discretized(num_bins=3, output_mode="one_hot")},263output_mode="dict",264)265example_feature_space(train_ds_with_no_labels, feature_space, ["age"])266267"""268### Feature indexing269"""270271"""272**Indexing** a string feature essentially means creating a discrete numerical273representation for it, this is especially important for string features since most models274only accept numerical features. This transformation will place the string values into275different categories. The output here can be a one-hot encoded vector or a single number.276277Note that by specifying `num_oov_indices=1` we leave one spot at our output vector for278OOV (out of vocabulary) values this is an important tool to handle missing or unseen279values after the training (values that were not seen during the `adapt()` step)280"""281282feature_space = FeatureSpace(283features={284"default": FeatureSpace.string_categorical(285num_oov_indices=1, output_mode="one_hot"286)287},288output_mode="dict",289)290example_feature_space(train_ds_with_no_labels, feature_space, ["default"])291292"""293We also can do **feature indexing** for integer features, this can be quite important for294some datasets where categorical features are replaced by numbers, for instance features295like `sex` or `gender` where values like (`1 and 0`) do not have a numerical relationship296between them, they are just different categories, this behavior can be perfectly captured297by this transformation.298299On this dataset we can use the feature that we created `previously_contacted`. For this300case we want to explicitly set `num_oov_indices=0`, the reason is that we only expect two301possible values for the feature, anything else would be either wrong input or an issue302with the data creation, for this reason we would probably just want the code to throw an303error so that we can be aware of the issue and fix it.304"""305306feature_space = FeatureSpace(307features={308"previously_contacted": FeatureSpace.integer_categorical(309num_oov_indices=0, output_mode="one_hot"310)311},312output_mode="dict",313)314example_feature_space(train_ds_with_no_labels, feature_space, ["previously_contacted"])315316"""317### Feature crosses (mixing features of diverse types)318319With **crosses** we can do feature interactions between an arbitrary number of features320of mixed types as long as they are categorical features, you can think of instead of321having a feature {'age': 20} and another {'job': 'entrepreneur'} we can have322{'age_X_job': 20_entrepreneur}, but with `FeatureSpace` and **crosses** we can apply323specific preprocessing to each individual feature and to the feature cross itself. This324option can be very powerful for specific use cases, here might be a good option since age325combined with job can have different meanings for the banking domain.326327We will cross `age` and `job` and hash the combination output of them into a vector328representation of size 8. The output here can be a one-hot encoded vector or a single329number.330331Sometimes the combination of multiple features can result into on a super large feature332space, think about crossing someone's ZIP code with its last name, the possibilities333would be in the thousands, that is why the `crossing_dim` parameter is so important it334limits the output dimension of the cross feature.335336Note that the combination of possible values of the 6 bins of `age` and the 12 values of337`job` would be 72, so by choosing `crossing_dim = 8` we are choosing to constrain the338output vector.339"""340341feature_space = FeatureSpace(342features={343"age": FeatureSpace.integer_hashed(num_bins=6, output_mode="one_hot"),344"job": FeatureSpace.string_categorical(345num_oov_indices=0, output_mode="one_hot"346),347},348crosses=[349FeatureSpace.cross(350feature_names=("age", "job"),351crossing_dim=8,352output_mode="one_hot",353)354],355output_mode="dict",356)357example_feature_space(train_ds_with_no_labels, feature_space, ["age", "job"])358359"""360### FeatureSpace using a Keras preprocessing layer361362To be a really flexible and extensible feature we cannot only rely on those pre-defined363transformation, we must be able to re-use other transformations from the Keras/TensorFlow364ecosystem and customize our own, this is why `FeatureSpace` is also designed to work with365[Keras preprocessing layers](https://keras.io/api/layers/preprocessing_layers/), this way we366can use sophisticated data transformations provided by the framework, you can even create367your own custom Keras preprocessing layers and use it in the same way.368369Here we are going to use the370[`keras.layers.TextVectorization`](https://keras.io/api/layers/preprocessing_layers/text/text_vectorization/#textvectorization-class)371preprocessing layer to create a TF-IDF372feature from our data. Note that this feature is not a really good use case for TF-IDF,373this is just for demonstration purposes.374"""375376custom_layer = keras.layers.TextVectorization(output_mode="tf_idf")377378feature_space = FeatureSpace(379features={380"education": FeatureSpace.feature(381preprocessor=custom_layer, dtype="string", output_mode="float"382)383},384output_mode="dict",385)386example_feature_space(train_ds_with_no_labels, feature_space, ["education"])387388"""389## Configuring the final `FeatureSpace`390391Now that we know how to use `FeatureSpace` for more complex use cases let's pick the ones392that looks more useful for this task and create the final `FeatureSpace` component.393394To configure how each feature should be preprocessed,395we instantiate a `keras.utils.FeatureSpace`, and we396pass to it a dictionary that maps the name of our features397to the feature transformation function.398399"""400401feature_space = FeatureSpace(402features={403# Categorical features encoded as integers404"previously_contacted": FeatureSpace.integer_categorical(num_oov_indices=0),405# Categorical features encoded as string406"marital": FeatureSpace.string_categorical(num_oov_indices=0),407"education": FeatureSpace.string_categorical(num_oov_indices=0),408"default": FeatureSpace.string_categorical(num_oov_indices=0),409"housing": FeatureSpace.string_categorical(num_oov_indices=0),410"loan": FeatureSpace.string_categorical(num_oov_indices=0),411"contact": FeatureSpace.string_categorical(num_oov_indices=0),412"month": FeatureSpace.string_categorical(num_oov_indices=0),413"day_of_week": FeatureSpace.string_categorical(num_oov_indices=0),414"poutcome": FeatureSpace.string_categorical(num_oov_indices=0),415# Categorical features to hash and bin416"job": FeatureSpace.string_hashed(num_bins=3),417# Numerical features to hash and bin418"pdays": FeatureSpace.integer_hashed(num_bins=4),419# Numerical features to normalize and bin420"age": FeatureSpace.float_discretized(num_bins=4),421# Numerical features to normalize422"campaign": FeatureSpace.float_normalized(),423"previous": FeatureSpace.float_normalized(),424"emp.var.rate": FeatureSpace.float_normalized(),425"cons.price.idx": FeatureSpace.float_normalized(),426"cons.conf.idx": FeatureSpace.float_normalized(),427"euribor3m": FeatureSpace.float_normalized(),428"nr.employed": FeatureSpace.float_normalized(),429},430# Specify feature cross with a custom crossing dim.431crosses=[432FeatureSpace.cross(feature_names=("age", "job"), crossing_dim=8),433FeatureSpace.cross(feature_names=("housing", "loan"), crossing_dim=6),434FeatureSpace.cross(435feature_names=("poutcome", "previously_contacted"), crossing_dim=2436),437],438output_mode="concat",439)440441"""442## Adapt the `FeatureSpace` to the training data443444Before we start using the `FeatureSpace` to build a model, we have445to adapt it to the training data. During `adapt()`, the `FeatureSpace` will:446447- Index the set of possible values for categorical features.448- Compute the mean and variance for numerical features to normalize.449- Compute the value boundaries for the different bins for numerical features to450discretize.451- Any other kind of preprocessing required by custom layers.452453Note that `adapt()` should be called on a `tf.data.Dataset` which yields dicts454of feature values -- no labels.455456But first let's batch the datasets457"""458459train_ds = train_ds.batch(32)460valid_ds = valid_ds.batch(32)461462train_ds_with_no_labels = train_ds.map(lambda x, _: x)463feature_space.adapt(train_ds_with_no_labels)464465"""466At this point, the `FeatureSpace` can be called on a dict of raw feature values, and467because we set `output_mode="concat"` it will return a single concatenate vector for each468sample, combining encoded features and feature crosses.469"""470471for x, _ in train_ds.take(1):472preprocessed_x = feature_space(x)473print(f"preprocessed_x shape: {preprocessed_x.shape}")474print(f"preprocessed_x sample: \n{preprocessed_x[0]}")475476"""477## Saving the `FeatureSpace`478479At this point we can choose to save our `FeatureSpace` component, this have many480advantages like re-using it on different experiments that use the same model, saving time481if you need to re-run the preprocessing step, and mainly for model deployment, where by482loading it you can be sure that you will be applying the same preprocessing steps don't483matter the device or environment, this is a great way to reduce484[training/servingskew](https://developers.google.com/machine-learning/guides/rules-of-ml#training-serving_skew).485"""486487feature_space.save("myfeaturespace.keras")488489"""490## Preprocessing with `FeatureSpace` as part of the tf.data pipeline491492We will opt to use our component asynchronously by making it part of the tf.data493pipeline, as noted at the494[previous guide](https://keras.io/examples/structured_data/structured_data_classification_with_feature_space/)495This enables asynchronous parallel preprocessing of the data on CPU before it496hits the model. Usually, this is always the right thing to do during training.497498Let's create a training and validation dataset of preprocessed batches:499"""500501preprocessed_train_ds = train_ds.map(502lambda x, y: (feature_space(x), y), num_parallel_calls=tf.data.AUTOTUNE503).prefetch(tf.data.AUTOTUNE)504505preprocessed_valid_ds = valid_ds.map(506lambda x, y: (feature_space(x), y), num_parallel_calls=tf.data.AUTOTUNE507).prefetch(tf.data.AUTOTUNE)508509"""510## Model511512We will take advantage of our `FeatureSpace` component to build the model, as we want the513model to be compatible with our preprocessing function, let's use the the `FeatureSpace`514feature map as the input of our model.515"""516517encoded_features = feature_space.get_encoded_features()518print(encoded_features)519520"""521This model is quite trivial only for demonstration purposes so don't pay too much522attention to the architecture.523"""524525x = keras.layers.Dense(64, activation="relu")(encoded_features)526x = keras.layers.Dropout(0.5)(x)527output = keras.layers.Dense(1, activation="sigmoid")(x)528529model = keras.Model(inputs=encoded_features, outputs=output)530model.compile(optimizer="adam", loss="binary_crossentropy", metrics=["accuracy"])531532"""533## Training534535Let's train our model for 20 epochs. Note that feature preprocessing is happening as part536of the tf.data pipeline, not as part of the model.537"""538539model.fit(540preprocessed_train_ds, validation_data=preprocessed_valid_ds, epochs=10, verbose=2541)542543"""544## Inference on new data with the end-to-end model545546Now, we can build our inference model (which includes the `FeatureSpace`) to make547predictions based on dicts of raw features values, as follows:548"""549550"""551### Loading the `FeatureSpace`552553First let's load the `FeatureSpace` that we saved a few moment ago, this can be quite554handy if you train a model but want to do inference at different time, possibly using a555different device or environment.556"""557558loaded_feature_space = keras.saving.load_model("myfeaturespace.keras")559560"""561### Building the inference end-to-end model562563To build the inference model we need both the feature input map and the preprocessing564encoded Keras tensors.565"""566567dict_inputs = loaded_feature_space.get_inputs()568encoded_features = loaded_feature_space.get_encoded_features()569print(encoded_features)570571print(dict_inputs)572573outputs = model(encoded_features)574inference_model = keras.Model(inputs=dict_inputs, outputs=outputs)575576sample = {577"age": 30,578"job": "blue-collar",579"marital": "married",580"education": "basic.9y",581"default": "no",582"housing": "yes",583"loan": "no",584"contact": "cellular",585"month": "may",586"day_of_week": "fri",587"campaign": 2,588"pdays": 999,589"previous": 0,590"poutcome": "nonexistent",591"emp.var.rate": -1.8,592"cons.price.idx": 92.893,593"cons.conf.idx": -46.2,594"euribor3m": 1.313,595"nr.employed": 5099.1,596"previously_contacted": 0,597}598599input_dict = {600name: keras.ops.convert_to_tensor([value]) for name, value in sample.items()601}602predictions = inference_model.predict(input_dict)603604print(605f"This particular client has a {100 * predictions[0][0]:.2f}% probability "606"of subscribing a term deposit, as evaluated by our model."607)608609610