Path: blob/master/examples/nlp/multi_label_classification.py
"""
Title: Large-scale multi-label text classification
Author: [Sayak Paul](https://twitter.com/RisingSayak), [Soumik Rakshit](https://github.com/soumik12345)
Date created: 2020/09/25
Last modified: 2025/02/27
Description: Implementing a large-scale multi-label text classification model.
Accelerator: GPU
Converted to Keras 3 and made backend-agnostic by: [Humbulani Ndou](https://github.com/Humbulani1234)
"""

"""
## Introduction

In this example, we will build a multi-label text classifier to predict the subject areas
of arXiv papers from their abstract bodies. This type of classifier can be useful for
conference submission portals like [OpenReview](https://openreview.net/). Given a paper
abstract, the portal could provide suggestions for which areas the paper would
best belong to.

The dataset was collected using the
[`arXiv` Python library](https://github.com/lukasschwab/arxiv.py)
that provides a wrapper around the
[original arXiv API](http://arxiv.org/help/api/index).
To learn more about the data collection process, please refer to
[this notebook](https://github.com/soumik12345/multi-label-text-classification/blob/master/arxiv_scrape.ipynb).
You can also find the dataset on
[Kaggle](https://www.kaggle.com/spsayakpaul/arxiv-paper-abstracts).
"""

"""
## Imports
"""

import os

os.environ["KERAS_BACKEND"] = "jax"  # or "tensorflow", or "torch"

import keras
from keras import layers, ops

from sklearn.model_selection import train_test_split

from ast import literal_eval
import matplotlib.pyplot as plt
import pandas as pd
import numpy as np

"""
## Perform exploratory data analysis

In this section, we first load the dataset into a `pandas` dataframe and then perform
some basic exploratory data analysis (EDA).
"""

arxiv_data = pd.read_csv(
    "https://github.com/soumik12345/multi-label-text-classification/releases/download/v0.2/arxiv_data.csv"
)
arxiv_data.head()

"""
Our text features are present in the `summaries` column and their corresponding labels
are in `terms`. As you can see, a single entry can have multiple categories associated
with it.
"""

print(f"There are {len(arxiv_data)} rows in the dataset.")

"""
Real-world data is noisy. One of the most commonly observed sources of noise is data
duplication. Here we notice that our initial dataset has about 13k duplicate entries.
"""

total_duplicate_titles = sum(arxiv_data["titles"].duplicated())
print(f"There are {total_duplicate_titles} duplicate titles.")

"""
Before proceeding further, we drop these entries.
"""

arxiv_data = arxiv_data[~arxiv_data["titles"].duplicated()]
print(f"There are {len(arxiv_data)} rows in the deduplicated dataset.")

# There are some terms with occurrence as low as 1.
print(sum(arxiv_data["terms"].value_counts() == 1))

# How many unique terms?
print(arxiv_data["terms"].nunique())

"""
As observed above, out of 3,157 unique combinations of `terms`, 2,321 occur only once.
To prepare our train, validation, and test sets with
[stratification](https://en.wikipedia.org/wiki/Stratified_sampling), we need to drop
these terms, because a stratified split needs at least two samples per group.
"""

# Filtering the rare terms.
arxiv_data_filtered = arxiv_data.groupby("terms").filter(lambda x: len(x) > 1)
arxiv_data_filtered.shape

"""
## Convert the string labels to lists of strings

The initial labels are represented as raw strings. Here we make them `List[str]` for a
more compact representation.
"""

arxiv_data_filtered["terms"] = arxiv_data_filtered["terms"].apply(
    lambda x: literal_eval(x)
)
arxiv_data_filtered["terms"].values[:5]

"""
## Use stratified splits because of class imbalance

The dataset has a
[class imbalance problem](https://developers.google.com/machine-learning/glossary/#class-imbalanced-dataset).
So, to have a fair evaluation result, we need to ensure the datasets are sampled with
stratification. To learn more about different strategies to deal with the class imbalance
problem, you can follow
[this tutorial](https://www.tensorflow.org/tutorials/structured_data/imbalanced_data).
For an end-to-end demonstration of classification with imbalanced data, refer to
[Imbalanced classification: credit card fraud detection](https://keras.io/examples/structured_data/imbalanced_classification/).
"""

test_split = 0.1

# Initial train and test split.
train_df, test_df = train_test_split(
    arxiv_data_filtered,
    test_size=test_split,
    stratify=arxiv_data_filtered["terms"].values,
)

# Splitting the test set further into validation
# and new test sets.
val_df = test_df.sample(frac=0.5)
test_df.drop(val_df.index, inplace=True)

print(f"Number of rows in training set: {len(train_df)}")
print(f"Number of rows in validation set: {len(val_df)}")
print(f"Number of rows in test set: {len(test_df)}")

"""
## Multi-label binarization

Now we preprocess our labels using the
[`StringLookup`](https://keras.io/api/layers/preprocessing_layers/categorical/string_lookup)
layer.
"""

# For RaggedTensor
import tensorflow as tf

terms = tf.ragged.constant(train_df["terms"].values)
lookup = layers.StringLookup(output_mode="multi_hot")
lookup.adapt(terms)
vocab = lookup.get_vocabulary()


def invert_multi_hot(encoded_labels):
    """Reverse a single multi-hot encoded label to a tuple of vocab terms."""
    hot_indices = np.argwhere(encoded_labels == 1.0)[..., 0]
    return np.take(vocab, hot_indices)


print("Vocabulary:\n")
print(vocab)


"""
Here we are separating the individual unique classes available from the label
pool and then using this information to represent a given label set with 0's and 1's.
Below is an example.
"""

sample_label = train_df["terms"].iloc[0]
print(f"Original label: {sample_label}")

label_binarized = lookup([sample_label])
print(f"Label-binarized representation: {label_binarized}")

"""
## Data preprocessing and `tf.data.Dataset` objects

We first get percentile estimates of the sequence lengths. The purpose will be clear in a
moment.
"""

train_df["summaries"].apply(lambda x: len(x.split(" "))).describe()

"""
Notice that the median abstract length is 154 words (you may get a different number
based on the split). So, any number close to that value is a good enough approximation
for the maximum sequence length.

Now, we implement utilities to prepare our datasets.
"""

max_seqlen = 150
batch_size = 128
padding_token = "<pad>"
auto = tf.data.AUTOTUNE


def make_dataset(dataframe, is_train=True):
    labels = tf.ragged.constant(dataframe["terms"].values)
    label_binarized = lookup(labels).numpy()
    dataset = tf.data.Dataset.from_tensor_slices(
        (dataframe["summaries"].values, label_binarized)
    )
    dataset = dataset.shuffle(batch_size * 10) if is_train else dataset
    return dataset.batch(batch_size)


"""
Now we can prepare the `tf.data.Dataset` objects.
"""

train_dataset = make_dataset(train_df, is_train=True)
validation_dataset = make_dataset(val_df, is_train=False)
test_dataset = make_dataset(test_df, is_train=False)

"""
## Dataset preview
"""

text_batch, label_batch = next(iter(train_dataset))

for i, text in enumerate(text_batch[:5]):
    label = label_batch[i].numpy()[None, ...]
    print(f"Abstract: {text}")
    print(f"Label(s): {invert_multi_hot(label[0])}")
    print(" ")

"""
## Vectorization

Before we feed the data to our model, we need to vectorize it (represent it in a numerical form).
For that purpose, we will use the
[`TextVectorization` layer](https://keras.io/api/layers/preprocessing_layers/text/text_vectorization).
It can operate as a part of your main model so that the model is excluded from the core
preprocessing logic. This greatly reduces the chances of training / serving skew during inference.

We first calculate the number of unique words present in the abstracts.
"""

# Source: https://stackoverflow.com/a/18937309/7636462
vocabulary = set()
train_df["summaries"].str.lower().str.split().apply(vocabulary.update)
vocabulary_size = len(vocabulary)
print(vocabulary_size)


"""
We now create our vectorization layer and `map()` it over the `tf.data.Dataset`s created
earlier.
"""

text_vectorizer = layers.TextVectorization(
    max_tokens=vocabulary_size, ngrams=2, output_mode="tf_idf"
)

# `TextVectorization` layer needs to be adapted as per the vocabulary from our
# training set.
with tf.device("/CPU:0"):
    text_vectorizer.adapt(train_dataset.map(lambda text, label: text))

train_dataset = train_dataset.map(
    lambda text, label: (text_vectorizer(text), label), num_parallel_calls=auto
).prefetch(auto)
validation_dataset = validation_dataset.map(
    lambda text, label: (text_vectorizer(text), label), num_parallel_calls=auto
).prefetch(auto)
test_dataset = test_dataset.map(
    lambda text, label: (text_vectorizer(text), label), num_parallel_calls=auto
).prefetch(auto)


"""
A batch of raw text will first go through the `TextVectorization` layer, which will
generate its numerical representation. Internally, the `TextVectorization` layer will
first create bi-grams out of the sequences and then represent them using
[TF-IDF](https://wikipedia.org/wiki/Tf%E2%80%93idf). The output representations will then
be passed to the shallow model responsible for text classification.

To learn more about other possible configurations with `TextVectorization`, please consult
the
[official documentation](https://keras.io/api/layers/preprocessing_layers/text/text_vectorization).

**Note**: Setting the `max_tokens` argument to a pre-calculated vocabulary size is
not a requirement.
"""

"""
## Create a text classification model

We will keep our model simple -- it will be a small stack of fully-connected layers with
ReLU as the non-linearity.

"""


def make_model():
    shallow_mlp_model = keras.Sequential(
        [
            layers.Dense(512, activation="relu"),
            layers.Dense(256, activation="relu"),
            layers.Dense(lookup.vocabulary_size(), activation="sigmoid"),
        ]  # More on why "sigmoid" has been used here in a moment.
    )
    return shallow_mlp_model


"""
## Train the model

We will train our model using the binary crossentropy loss. This is because the labels
are not disjoint. For a given abstract, we may have multiple categories. So, we will
divide the prediction task into a series of multiple binary classification problems. This
is also why we set the activation function of the classification layer in our model to
sigmoid. Researchers have used other combinations of loss function and activation
function as well. For example, in [Exploring the Limits of Weakly Supervised Pretraining](https://arxiv.org/abs/1805.00932),
Mahajan et al. used the softmax activation function and cross-entropy loss to train
their models.

There are several metrics that can be used in multi-label classification.
To keep this code example narrow we decided to use the
[binary accuracy metric](https://keras.io/api/metrics/accuracy_metrics/#binaryaccuracy-class).
For an explanation of why this metric is used, see this
[pull-request](https://github.com/keras-team/keras-io/pull/1133#issuecomment-1322736860).
There are also other suitable metrics for multi-label classification, like
[F1 Score](https://www.tensorflow.org/addons/api_docs/python/tfa/metrics/F1Score) or
[Hamming loss](https://www.tensorflow.org/addons/api_docs/python/tfa/metrics/HammingLoss).
"""

epochs = 20

shallow_mlp_model = make_model()
shallow_mlp_model.compile(
    loss="binary_crossentropy", optimizer="adam", metrics=["binary_accuracy"]
)

history = shallow_mlp_model.fit(
    train_dataset, validation_data=validation_dataset, epochs=epochs
)


def plot_result(item):
    plt.plot(history.history[item], label=item)
    plt.plot(history.history["val_" + item], label="val_" + item)
    plt.xlabel("Epochs")
    plt.ylabel(item)
    plt.title("Train and Validation {} Over Epochs".format(item), fontsize=14)
    plt.legend()
    plt.grid()
    plt.show()


plot_result("loss")
plot_result("binary_accuracy")

"""
While training, we notice an initial sharp fall in the loss followed by a gradual decay.
"""

"""
### Evaluate the model
"""

_, binary_acc = shallow_mlp_model.evaluate(test_dataset)
print(f"Binary accuracy on the test set: {round(binary_acc * 100, 2)}%.")

"""
The trained model gives us an evaluation accuracy of ~99%.
"""

"""
## Inference

An important feature of the
[preprocessing layers provided by Keras](https://keras.io/api/layers/preprocessing_layers/)
is that they can be included inside a `keras.Model`. We will export an inference model
by including the `text_vectorizer` layer on top of `shallow_mlp_model`. This will
allow our inference model to directly operate on raw strings.

**Note** that during training it is always preferable to use these preprocessing
layers as a part of the data input pipeline rather than the model to avoid
surfacing bottlenecks for the hardware accelerators. This also allows for
asynchronous data processing.
"""


# We create a custom Model to override the predict method so
# that it first vectorizes text data
class ModelEndtoEnd(keras.Model):

    def predict(self, inputs):
        indices = text_vectorizer(inputs)
        return super().predict(indices)


def get_inference_model(model):
    # Reuse the trained model's inputs and outputs so the weights are shared.
    inputs = model.inputs
    outputs = model.outputs
    end_to_end_model = ModelEndtoEnd(inputs, outputs, name="end_to_end_model")
    end_to_end_model.compile(
        optimizer="adam", loss="binary_crossentropy", metrics=["accuracy"]
    )
    return end_to_end_model


model_for_inference = get_inference_model(shallow_mlp_model)

# Create a small dataset just for demonstrating inference.
inference_dataset = make_dataset(test_df.sample(2), is_train=False)
text_batch, label_batch = next(iter(inference_dataset))
predicted_probabilities = model_for_inference.predict(text_batch)


# Perform inference.
for i, text in enumerate(text_batch[:5]):
    label = label_batch[i].numpy()[None, ...]
    print(f"Abstract: {text}")
    print(f"Label(s): {invert_multi_hot(label[0])}")
    predicted_proba = [proba for proba in predicted_probabilities[i]]
    top_3_labels = [
        x
        for _, x in sorted(
            zip(predicted_probabilities[i], lookup.get_vocabulary()),
            key=lambda pair: pair[0],
            reverse=True,
        )
    ][:3]
    print(f"Predicted Label(s): ({', '.join([label for label in top_3_labels])})")
    print(" ")

"""
The prediction results are not that great, but they are not below par for a simple model
like ours. We can improve this performance with models that consider word order like LSTM or
even those that use Transformers ([Vaswani et al.](https://arxiv.org/abs/1706.03762)).
"""

"""
## Acknowledgements

We would like to thank [Matt Watson](https://github.com/mattdangerw) for helping us
tackle the multi-label binarization part and inverse-transforming the processed labels
to the original form.

Thanks to [Cingis Kratochvil](https://github.com/cumbalik) for suggesting and extending this code example by introducing binary accuracy as the evaluation metric.
"""
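"""
The training section mentions F1 score and Hamming loss as alternative multi-label
metrics. The following is a minimal pure-Python sketch of how they could be computed from
predicted probabilities. It is not part of the original example: the helper names
`binarize`, `hamming_loss` and `micro_f1` are hypothetical, the 0.5 decision threshold is
an assumption, and the toy arrays are made up for illustration.

```python
# Hypothetical helpers for illustration only; they are not Keras APIs.


def binarize(probabilities, threshold=0.5):
    # Turn per-label probabilities into 0/1 decisions (0.5 is an assumed cutoff).
    return [[1 if p >= threshold else 0 for p in row] for row in probabilities]


def hamming_loss(y_true, y_pred):
    # Fraction of individual label slots that were predicted incorrectly.
    pairs = [
        (t, p)
        for t_row, p_row in zip(y_true, y_pred)
        for t, p in zip(t_row, p_row)
    ]
    if not pairs:
        return 0.0
    return sum(1 for t, p in pairs if t != p) / len(pairs)


def micro_f1(y_true, y_pred):
    # Pool true positives, false positives and false negatives over all
    # labels, then compute F1 = 2*TP / (2*TP + FP + FN).
    tp = fp = fn = 0
    for t_row, p_row in zip(y_true, y_pred):
        for t, p in zip(t_row, p_row):
            if p == 1 and t == 1:
                tp += 1
            elif p == 1 and t == 0:
                fp += 1
            elif p == 0 and t == 1:
                fn += 1
    denominator = 2 * tp + fp + fn
    return 2 * tp / denominator if denominator else 0.0


# Toy data: two samples, four labels each.
toy_true = [[1, 0, 1, 0], [0, 1, 0, 0]]
toy_prob = [[0.9, 0.2, 0.4, 0.1], [0.3, 0.8, 0.6, 0.2]]
toy_pred = binarize(toy_prob)

print(f"Hamming loss: {hamming_loss(toy_true, toy_pred)}")
print(f"Micro F1: {micro_f1(toy_true, toy_pred)}")
```

With the objects from this example, the same helpers could presumably be applied to
`label_batch.numpy()` and `predicted_probabilities`, since both are arrays of shape
`(batch, num_labels)`.
"""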