Large-scale multi-label text classification
Authors: Sayak Paul, Soumik Rakshit
Date created: 2020/09/25
Last modified: 2025/02/27
Description: Implementing a large-scale multi-label text classification model.
Introduction
In this example, we will build a multi-label text classifier to predict the subject areas of arXiv papers from their abstract bodies. This type of classifier can be useful for conference submission portals like OpenReview. Given a paper abstract, the portal could provide suggestions for which areas the paper would best belong to.
The dataset was collected using the arXiv Python library, which provides a wrapper around the original arXiv API. To learn more about the data collection process, please refer to this notebook. You can also find the dataset on Kaggle.
Imports
Perform exploratory data analysis
In this section, we first load the dataset into a pandas dataframe and then perform some basic exploratory data analysis (EDA).
Our text features are present in the summaries column and their corresponding labels are in terms. As you can notice, there are multiple categories associated with a particular entry.
Real-world data is noisy. One of the most commonly observed sources of noise is data duplication. Here we notice that our initial dataset has about 13k duplicate entries.
Before proceeding further, we drop these entries.
As observed above, out of 3,157 unique combinations of terms, 2,321 have the lowest occurrence. To prepare our train, validation, and test sets with stratification, we need to drop the entries that carry these rare combinations.
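The filtering step can be sketched without any library, as below. This is only an illustration on hypothetical toy rows, and it assumes that "lowest occurrence" means a label combination that appears exactly once. The notebook itself works on the pandas dataframe.

```python
from collections import Counter

# Hypothetical toy rows: (abstract, raw label string), mimicking the
# `summaries` and `terms` columns described above.
rows = [
    ("abstract a", "['cs.CV', 'cs.LG']"),
    ("abstract b", "['cs.CV', 'cs.LG']"),
    ("abstract c", "['cs.LG']"),
    ("abstract d", "['cs.LG']"),
    ("abstract e", "['stat.ML', 'cs.AI']"),  # this combination appears once
]

# Count how often each exact label combination occurs.
combo_counts = Counter(terms for _, terms in rows)

# Assumption: "rare" means a count of 1. Keep only rows whose combination
# occurs more than once, so each one can land on both sides of a split.
filtered_rows = [row for row in rows if combo_counts[row[1]] > 1]

print(len(rows), "->", len(filtered_rows))
```

A stratified split needs at least two samples per stratum, which is why combinations seen only once cannot be kept.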
Convert the string labels to lists of strings
The initial labels are represented as raw strings. Here we make them List[str] for a more compact representation.
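One way to do this conversion is sketched below. It assumes each raw label is the string form of a Python list literal (for example "['cs.CV', 'cs.LG']"). The sample values are hypothetical.

```python
import ast

# Hypothetical raw labels, stored as strings rather than lists.
raw_labels = ["['cs.CV', 'cs.LG']", "['cs.LG']"]

# `ast.literal_eval` safely parses Python literals without executing code,
# unlike the built-in `eval`.
parsed_labels = [ast.literal_eval(label) for label in raw_labels]

print(parsed_labels)  # [['cs.CV', 'cs.LG'], ['cs.LG']]
```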
Use stratified splits because of class imbalance
The dataset has a class imbalance problem. So, to have a fair evaluation result, we need to ensure the datasets are sampled with stratification. To know more about different strategies to deal with the class imbalance problem, you can follow this tutorial. For an end-to-end demonstration of classification with imbalanced data, refer to Imbalanced classification: credit card fraud detection.
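To make the idea of stratification concrete, here is a minimal library-free sketch. The helper `stratified_split` is hypothetical and is not the notebook's code; in practice a library routine that supports stratified sampling would normally be used. It treats each exact label combination as one stratum.

```python
import random
from collections import defaultdict


def stratified_split(rows, key, test_fraction, seed=0):
    """Split `rows` so each stratum (value of `key(row)`) keeps roughly
    the same share in both parts. Illustrative helper only."""
    rng = random.Random(seed)
    strata = defaultdict(list)
    for row in rows:
        strata[key(row)].append(row)

    train, test = [], []
    for members in strata.values():
        rng.shuffle(members)
        # Aim for at least one test sample per stratum, but always leave
        # at least one sample for training.
        n_test = max(1, round(len(members) * test_fraction))
        n_test = min(n_test, len(members) - 1)
        test.extend(members[:n_test])
        train.extend(members[n_test:])
    return train, test


# Toy data: 90 rows of a frequent label combination, 10 of a rare one.
rows = [(f"abstract {i}", "['cs.LG']") for i in range(90)]
rows += [(f"abstract {i}", "['cs.CV', 'cs.LG']") for i in range(90, 100)]

train_rows, test_rows = stratified_split(
    rows, key=lambda row: row[1], test_fraction=0.1
)
rare_in_test = sum(1 for _, terms in test_rows if terms == "['cs.CV', 'cs.LG']")
print(len(train_rows), len(test_rows), rare_in_test)
```

With a purely random split, the rare combination could be missing from the test set entirely. Splitting per stratum keeps its 10% share on both sides.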
Multi-label binarization
Now we preprocess our labels using the StringLookup layer.
Here we are separating the individual unique classes available from the label pool and then using this information to represent a given label set with 0's and 1's. Below is an example.
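The encoding itself can be illustrated in plain Python. This sketch only shows what a multi-hot representation looks like and how to invert it. It does not reproduce the StringLookup layer, and the helper names and label sets are hypothetical.

```python
# Hypothetical label sets for three abstracts.
label_sets = [["cs.CV", "cs.LG"], ["cs.LG"], ["cs.AI", "cs.LG", "stat.ML"]]

# Vocabulary: every unique class in the label pool, in a fixed order.
vocabulary = sorted({label for labels in label_sets for label in labels})
index_of = {label: i for i, label in enumerate(vocabulary)}


def multi_hot(labels):
    """Return a 0/1 vector with a 1 at the position of each label."""
    vector = [0] * len(vocabulary)
    for label in labels:
        vector[index_of[label]] = 1
    return vector


def invert_multi_hot(vector):
    """Map a 0/1 vector back to the label names it encodes."""
    return [vocabulary[i] for i, flag in enumerate(vector) if flag == 1]


encoded = multi_hot(["cs.CV", "cs.LG"])
print(vocabulary)                 # ['cs.AI', 'cs.CV', 'cs.LG', 'stat.ML']
print(encoded)                    # [0, 1, 1, 0]
print(invert_multi_hot(encoded))  # ['cs.CV', 'cs.LG']
```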
Data preprocessing and tf.data.Dataset objects
We first get percentile estimates of the sequence lengths. The purpose will be clear in a moment.
Notice that 50% of the abstracts have a length of at most 154 (you may get a different number based on the split). So, any number close to that value is a good enough approximation for the maximum sequence length.
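As a rough sketch of this estimate, the snippet below computes the median length of a few hypothetical abstracts. It assumes that length is measured in whitespace-separated words. The variable `max_seqlen` is an illustrative name.

```python
import statistics

# Hypothetical abstracts; real ones are far longer.
abstracts = [
    "graph neural networks",
    "a survey of deep reinforcement learning",
    "we propose a new optimizer",
    "attention is all you need for text classification tasks",
    "sparse models scale well",
]

# Assumption: sequence length is the number of whitespace-separated words.
lengths = sorted(len(abstract.split()) for abstract in abstracts)

# The 50th percentile (median): half of the abstracts are at most this long.
median_length = statistics.median(lengths)
max_seqlen = int(median_length)

print(lengths, median_length)  # [3, 4, 5, 6, 9] 5
```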
Now, we implement utilities to prepare our datasets.
Now we can prepare the tf.data.Dataset objects.
Dataset preview
Vectorization
Before we feed the data to our model, we need to vectorize it (represent it in a numerical form). For that purpose, we will use the TextVectorization layer. It can operate as a part of your main model, so that the core preprocessing logic ships with the model instead of living in separate code. This greatly reduces the chances of training / serving skew during inference.
We first calculate the number of unique words present in the abstracts.
We now create our vectorization layer and map() it over the tf.data.Datasets created earlier.
A batch of raw text will first go through the TextVectorization layer, which will generate its numerical representation. Internally, the TextVectorization layer will first create bi-grams out of the sequences and then represent them using TF-IDF. The output representations will then be passed to the shallow model responsible for text classification.
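To show what "bi-grams represented with TF-IDF" means, here is a small standalone sketch. It uses one common smoothed TF-IDF variant, so the exact weights produced by the TextVectorization layer may differ. The documents and helper names are hypothetical.

```python
import math
from collections import Counter


def bigrams(text):
    """Lowercase, split on whitespace, and pair up adjacent tokens."""
    tokens = text.lower().split()
    return [f"{a} {b}" for a, b in zip(tokens, tokens[1:])]


docs = [
    "deep learning for text classification",
    "deep learning for image segmentation",
    "bayesian methods for text classification",
]
doc_bigrams = [bigrams(doc) for doc in docs]

# Document frequency: in how many documents each bigram appears.
doc_freq = Counter(gram for grams in doc_bigrams for gram in set(grams))
vocab = sorted(doc_freq)


def tfidf_vector(grams):
    """Bigram count times a smoothed inverse document frequency.
    Assumption: this smoothing is illustrative, not the Keras formula."""
    counts = Counter(grams)
    n_docs = len(docs)
    return [
        counts[gram] * (math.log((1 + n_docs) / (1 + doc_freq[gram])) + 1.0)
        for gram in vocab
    ]


vector = tfidf_vector(doc_bigrams[1])
for gram, weight in zip(vocab, vector):
    print(f"{gram!r}: {weight:.3f}")
```

Bigrams that occur in fewer documents (such as "for image") get a larger weight than widely shared ones (such as "deep learning"), and bigrams absent from the document get a weight of zero.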
To learn more about other possible configurations with TextVectorization, please consult the official documentation.
Note: Setting the max_tokens argument to a pre-calculated vocabulary size is not a requirement.
Create a text classification model
We will keep our model simple -- it will be a small stack of fully-connected layers with ReLU as the non-linearity.
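The computation such a model performs can be sketched in plain Python as a single forward pass. The layer sizes, the random weights, and the input vector below are arbitrary stand-ins, and the sigmoid output assumes one independent probability per label. A real model would be built from Keras layers.

```python
import math
import random


def dense(inputs, weights, biases):
    """Fully-connected layer: one weighted sum per output unit."""
    return [
        sum(x * w for x, w in zip(inputs, unit_weights)) + b
        for unit_weights, b in zip(weights, biases)
    ]


def relu(values):
    return [max(0.0, v) for v in values]


def sigmoid(values):
    return [1.0 / (1.0 + math.exp(-v)) for v in values]


rng = random.Random(0)
n_features, n_hidden, n_labels = 6, 4, 3  # toy sizes, chosen arbitrarily

# Randomly initialised parameters; `w1[j]` holds the weights feeding unit j.
w1 = [[rng.uniform(-1, 1) for _ in range(n_features)] for _ in range(n_hidden)]
b1 = [0.0] * n_hidden
w2 = [[rng.uniform(-1, 1) for _ in range(n_hidden)] for _ in range(n_labels)]
b2 = [0.0] * n_labels

features = [0.0, 1.2, 0.0, 0.7, 0.0, 2.1]  # stand-in for one vectorized abstract
hidden = relu(dense(features, w1, b1))
probabilities = sigmoid(dense(hidden, w2, b2))

print(probabilities)  # one value in (0, 1) per label
```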
Train the model
We will train our model using the binary crossentropy loss, because the labels are not disjoint: a given abstract may belong to multiple categories. We therefore divide the prediction task into a series of binary classification problems, one per category. This is also why we set the activation function of the classification layer in our model to sigmoid. Researchers have used other combinations of loss function and activation function as well. For example, in Exploring the Limits of Weakly Supervised Pretraining, Mahajan et al. used the softmax activation function and cross-entropy loss to train their models.
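The following sketch works through the loss by hand for a single example with four labels. The logits and targets are hypothetical, and the clipping constant `eps` is an illustrative choice rather than a value taken from the notebook.

```python
import math


def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))


def binary_crossentropy(y_true, y_prob, eps=1e-7):
    """Mean of the per-label binary cross-entropy terms."""
    total = 0.0
    for target, prob in zip(y_true, y_prob):
        prob = min(max(prob, eps), 1.0 - eps)  # avoid log(0)
        total += -(target * math.log(prob) + (1 - target) * math.log(1.0 - prob))
    return total / len(y_true)


logits = [2.0, -1.0, 0.5, -3.0]  # hypothetical raw model outputs
y_true = [1, 0, 1, 0]            # multi-hot target: two labels are active
y_prob = [sigmoid(z) for z in logits]

loss = binary_crossentropy(y_true, y_prob)
print(f"probabilities sum to {sum(y_prob):.3f}, loss = {loss:.3f}")
```

Each sigmoid output is scored independently, so the probabilities do not need to sum to 1 and several labels can be predicted at once. A softmax would instead force the labels to compete for a single unit of probability mass.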
There are several metrics that can be used in multi-label classification. To keep this code example narrow, we decided to use the binary accuracy metric. For an explanation of why this metric is used, refer to this pull-request. Other suitable metrics for multi-label classification include the F1 score and the Hamming loss.
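To compare these metrics on the same predictions, here is a small hand-rolled sketch. The targets and predictions are hypothetical, and the F1 score is computed in its micro-averaged form, which is only one of several averaging choices.

```python
def binary_accuracy(y_true, y_pred):
    """Fraction of individual label decisions that are correct."""
    pairs = [(t, p) for row_t, row_p in zip(y_true, y_pred) for t, p in zip(row_t, row_p)]
    return sum(t == p for t, p in pairs) / len(pairs)


def hamming_loss(y_true, y_pred):
    """Fraction of individual label decisions that are wrong."""
    pairs = [(t, p) for row_t, row_p in zip(y_true, y_pred) for t, p in zip(row_t, row_p)]
    return sum(t != p for t, p in pairs) / len(pairs)


def micro_f1(y_true, y_pred):
    """F1 over all label decisions pooled together (micro averaging)."""
    pairs = [(t, p) for row_t, row_p in zip(y_true, y_pred) for t, p in zip(row_t, row_p)]
    tp = sum(1 for t, p in pairs if t == 1 and p == 1)
    fp = sum(1 for t, p in pairs if t == 0 and p == 1)
    fn = sum(1 for t, p in pairs if t == 1 and p == 0)
    denominator = 2 * tp + fp + fn
    return 2 * tp / denominator if denominator else 0.0


# Two samples, five labels each; the predictions miss two of three labels.
y_true = [[1, 0, 0, 0, 0], [0, 1, 1, 0, 0]]
y_pred = [[0, 0, 0, 0, 0], [0, 1, 0, 0, 0]]

print(binary_accuracy(y_true, y_pred))  # 0.8
print(hamming_loss(y_true, y_pred))     # 0.2
print(micro_f1(y_true, y_pred))         # 0.5
```

Because most entries of a multi-hot target are 0, binary accuracy can look high even when many true labels are missed, which is worth keeping in mind when reading the evaluation numbers.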
While training, we notice an initial sharp fall in the loss followed by a gradual decay.
Evaluate the model
The trained model gives us a binary accuracy of ~99% on evaluation.
Inference
An important feature of the preprocessing layers provided by Keras is that they can be included inside a tf.keras.Model. We will export an inference model by including the text_vectorization layer on top of shallow_mlp_model. This will allow our inference model to directly operate on raw strings.
Note that during training it is always preferable to use these preprocessing layers as a part of the data input pipeline rather than the model to avoid surfacing bottlenecks for the hardware accelerators. This also allows for asynchronous data processing.
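Once an inference model returns one probability per label, those probabilities still have to be mapped back to label names. The sketch below shows two simple decoding strategies on hypothetical values. The helper names, `k=3`, and the 0.5 threshold are illustrative choices, not settings confirmed by the notebook.

```python
def top_k_labels(probabilities, vocabulary, k=3):
    """Return the k labels with the highest predicted probability."""
    ranked = sorted(
        zip(probabilities, vocabulary), key=lambda pair: pair[0], reverse=True
    )
    return [label for _, label in ranked[:k]]


def labels_above_threshold(probabilities, vocabulary, threshold=0.5):
    """Return every label whose probability reaches the threshold."""
    return [
        label for prob, label in zip(probabilities, vocabulary) if prob >= threshold
    ]


vocabulary = ["cs.AI", "cs.CV", "cs.LG", "stat.ML"]  # hypothetical label order
probabilities = [0.10, 0.85, 0.92, 0.40]             # hypothetical model output

print(top_k_labels(probabilities, vocabulary))            # ['cs.LG', 'cs.CV', 'stat.ML']
print(labels_above_threshold(probabilities, vocabulary))  # ['cs.CV', 'cs.LG']
```

Top-k always returns a fixed number of suggestions, while thresholding lets the number of predicted labels vary per abstract.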
The prediction results are not that great, but they are not below par for a simple model like ours. We can improve this performance with models that consider word order, like LSTMs, or even those that use Transformers (Vaswani et al.).
Acknowledgements
We would like to thank Matt Watson for helping us tackle the multi-label binarization part and inverse-transforming the processed labels to the original form.
Thanks to Cingis Kratochvil for suggesting and extending this code example by introducing binary accuracy as the evaluation metric.