GitHub Repository: keras-team/keras-io
Path: blob/master/examples/nlp/ipynb/multi_label_classification.ipynb
Kernel: Python 3

Large-scale multi-label text classification

Author: Sayak Paul, Soumik Rakshit
Date created: 2020/09/25
Last modified: 2025/02/27
Description: Implementing a large-scale multi-label text classification model.

Introduction

In this example, we will build a multi-label text classifier to predict the subject areas of arXiv papers from their abstract bodies. This type of classifier can be useful for conference submission portals like OpenReview. Given a paper abstract, the portal could provide suggestions for which areas the paper would best belong to.

The dataset was collected using the arXiv Python library, which provides a wrapper around the original arXiv API. To learn more about the data collection process, please refer to this notebook. You can also find the dataset on Kaggle.

Imports

import os

os.environ["KERAS_BACKEND"] = "jax"  # or tensorflow, or torch

import keras
from keras import layers, ops

from sklearn.model_selection import train_test_split

from ast import literal_eval
import matplotlib.pyplot as plt
import pandas as pd
import numpy as np

Perform exploratory data analysis

In this section, we first load the dataset into a pandas dataframe and then perform some basic exploratory data analysis (EDA).

arxiv_data = pd.read_csv(
    "https://github.com/soumik12345/multi-label-text-classification/releases/download/v0.2/arxiv_data.csv"
)
arxiv_data.head()

Our text features are present in the summaries column and their corresponding labels are in terms. Notice that there are multiple categories associated with a particular entry.

print(f"There are {len(arxiv_data)} rows in the dataset.")

Real-world data is noisy. One of the most common sources of noise is data duplication. Here we notice that our initial dataset has about 13k duplicate entries.

total_duplicate_titles = sum(arxiv_data["titles"].duplicated())
print(f"There are {total_duplicate_titles} duplicate titles.")

Before proceeding further, we drop these entries.

arxiv_data = arxiv_data[~arxiv_data["titles"].duplicated()]
print(f"There are {len(arxiv_data)} rows in the deduplicated dataset.")

# There are some terms with occurrence as low as 1.
print(sum(arxiv_data["terms"].value_counts() == 1))

# How many unique terms?
print(arxiv_data["terms"].nunique())

As observed above, out of 3,157 unique combinations of terms, 2,321 occur only once. To prepare our train, validation, and test sets with stratification, we need to drop these terms.
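To see why singleton label combinations must go, here is a minimal sketch (toy data, not the arXiv dataset) showing that scikit-learn's stratified split requires at least two samples per class:

```python
# Toy demonstration: stratified splitting fails when a label
# combination appears only once in the data.
import pandas as pd
from sklearn.model_selection import train_test_split

toy = pd.DataFrame(
    {
        "summaries": ["a", "b", "c", "d", "e"],
        "terms": ["['cs.CV']", "['cs.CV']", "['cs.LG']", "['cs.LG']", "['cs.AI']"],
    }
)

# "['cs.AI']" occurs once, so stratification raises a ValueError.
try:
    train_test_split(toy, test_size=0.4, stratify=toy["terms"].values)
except ValueError as e:
    print(e)

# Dropping the rare combinations (as we do above) makes the split possible.
filtered = toy.groupby("terms").filter(lambda x: len(x) > 1)
train, test = train_test_split(
    filtered, test_size=0.5, stratify=filtered["terms"].values
)
print(len(train), len(test))  # 2 2
```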

# Filtering the rare terms.
arxiv_data_filtered = arxiv_data.groupby("terms").filter(lambda x: len(x) > 1)
arxiv_data_filtered.shape

Convert the string labels to lists of strings

The initial labels are represented as raw strings. Here we make them List[str] for a more compact representation.

arxiv_data_filtered["terms"] = arxiv_data_filtered["terms"].apply(
    lambda x: literal_eval(x)
)
arxiv_data_filtered["terms"].values[:5]

Use stratified splits because of class imbalance

The dataset has a class imbalance problem. So, to have a fair evaluation result, we need to ensure the datasets are sampled with stratification. To know more about different strategies to deal with the class imbalance problem, you can follow this tutorial. For an end-to-end demonstration of classification with imbalanced data, refer to Imbalanced classification: credit card fraud detection.

test_split = 0.1

# Initial train and test split.
train_df, test_df = train_test_split(
    arxiv_data_filtered,
    test_size=test_split,
    stratify=arxiv_data_filtered["terms"].values,
)

# Splitting the test set further into validation
# and new test sets.
val_df = test_df.sample(frac=0.5)
test_df.drop(val_df.index, inplace=True)

print(f"Number of rows in training set: {len(train_df)}")
print(f"Number of rows in validation set: {len(val_df)}")
print(f"Number of rows in test set: {len(test_df)}")

Multi-label binarization

Now we preprocess our labels using the StringLookup layer.

# For RaggedTensor
import tensorflow as tf

terms = tf.ragged.constant(train_df["terms"].values)
lookup = layers.StringLookup(output_mode="multi_hot")
lookup.adapt(terms)
vocab = lookup.get_vocabulary()


def invert_multi_hot(encoded_labels):
    """Reverse a single multi-hot encoded label to a tuple of vocab terms."""
    hot_indices = np.argwhere(encoded_labels == 1.0)[..., 0]
    return np.take(vocab, hot_indices)


print("Vocabulary:\n")
print(vocab)

Here we separate out the individual unique classes available in the label pool and then use this information to represent a given label set with 0's and 1's. Below is an example.

sample_label = train_df["terms"].iloc[0]
print(f"Original label: {sample_label}")

label_binarized = lookup([sample_label])
print(f"Label-binarized representation: {label_binarized}")

Data preprocessing and tf.data.Dataset objects

We first get percentile estimates of the sequence lengths. The purpose will be clear in a moment.

train_df["summaries"].apply(lambda x: len(x.split(" "))).describe()

Notice that 50% of the abstracts have a length of 154 (you may get a different number based on the split). So, any number close to that value is a good enough approximation of the maximum sequence length.

Now, we implement utilities to prepare our datasets.

max_seqlen = 150
batch_size = 128
padding_token = "<pad>"
auto = tf.data.AUTOTUNE


def make_dataset(dataframe, is_train=True):
    labels = tf.ragged.constant(dataframe["terms"].values)
    label_binarized = lookup(labels).numpy()
    dataset = tf.data.Dataset.from_tensor_slices(
        (dataframe["summaries"].values, label_binarized)
    )
    dataset = dataset.shuffle(batch_size * 10) if is_train else dataset
    return dataset.batch(batch_size)

Now we can prepare the tf.data.Dataset objects.

train_dataset = make_dataset(train_df, is_train=True)
validation_dataset = make_dataset(val_df, is_train=False)
test_dataset = make_dataset(test_df, is_train=False)

Dataset preview

text_batch, label_batch = next(iter(train_dataset))

for i, text in enumerate(text_batch[:5]):
    label = label_batch[i].numpy()[None, ...]
    print(f"Abstract: {text}")
    print(f"Label(s): {invert_multi_hot(label[0])}")
    print(" ")

Vectorization

Before we feed the data to our model, we need to vectorize it (represent it in a numerical form). For that purpose, we will use the TextVectorization layer. It can operate as a part of your main model, so the preprocessing logic travels with the model. This greatly reduces the chances of training / serving skew during inference.

We first calculate the number of unique words present in the abstracts.

# Source: https://stackoverflow.com/a/18937309/7636462
vocabulary = set()
train_df["summaries"].str.lower().str.split().apply(vocabulary.update)
vocabulary_size = len(vocabulary)
print(vocabulary_size)

We now create our vectorization layer and map() it to the tf.data.Dataset objects created earlier.

text_vectorizer = layers.TextVectorization(
    max_tokens=vocabulary_size, ngrams=2, output_mode="tf_idf"
)

# `TextVectorization` layer needs to be adapted as per the vocabulary from our
# training set.
with tf.device("/CPU:0"):
    text_vectorizer.adapt(train_dataset.map(lambda text, label: text))

train_dataset = train_dataset.map(
    lambda text, label: (text_vectorizer(text), label), num_parallel_calls=auto
).prefetch(auto)
validation_dataset = validation_dataset.map(
    lambda text, label: (text_vectorizer(text), label), num_parallel_calls=auto
).prefetch(auto)
test_dataset = test_dataset.map(
    lambda text, label: (text_vectorizer(text), label), num_parallel_calls=auto
).prefetch(auto)

A batch of raw text will first go through the TextVectorization layer. Internally, the layer will first create bi-grams out of the sequences and then represent them using TF-IDF. These output representations will then be passed to the shallow model responsible for text classification.
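To make this concrete, here is a small sketch (toy corpus, not part of the original pipeline) of what TextVectorization produces with ngrams=2 and output_mode="tf_idf":

```python
# Toy demonstration of bigram + TF-IDF vectorization.
import tensorflow as tf
from keras import layers

toy_vectorizer = layers.TextVectorization(ngrams=2, output_mode="tf_idf")
toy_corpus = ["deep learning models", "deep graph networks"]
toy_vectorizer.adapt(toy_corpus)

# The learned vocabulary contains both unigrams and bigrams,
# e.g. "deep" as well as "deep learning".
print(toy_vectorizer.get_vocabulary())

# Each document becomes a dense TF-IDF-weighted vector over that vocabulary.
vec = toy_vectorizer(tf.constant(["deep learning models"]))
print(vec.shape)  # (1, vocab_size)
```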

To learn more about other possible configurations of TextVectorization, please consult the official documentation.

Note: Setting the max_tokens argument to a pre-calculated vocabulary size is not a requirement.

Create a text classification model

We will keep our model simple -- it will be a small stack of fully-connected layers with ReLU as the non-linearity.

def make_model():
    shallow_mlp_model = keras.Sequential(
        [
            layers.Dense(512, activation="relu"),
            layers.Dense(256, activation="relu"),
            layers.Dense(lookup.vocabulary_size(), activation="sigmoid"),
        ]  # More on why "sigmoid" has been used here in a moment.
    )
    return shallow_mlp_model

Train the model

We will train our model using the binary crossentropy loss. This is because the labels are not disjoint: a given abstract may have multiple categories. So, we divide the prediction task into a series of binary classification problems, one per label. This is also why we set the activation function of the classification layer in our model to sigmoid. Researchers have used other combinations of loss function and activation function as well. For example, in Exploring the Limits of Weakly Supervised Pretraining, Mahajan et al. used the softmax activation function and cross-entropy loss to train their models.
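The per-label framing above can be sketched numerically: with a multi-hot target and sigmoid outputs, binary crossentropy is simply the mean of independent binary losses, one per class (the probabilities below are made up for illustration):

```python
# Numeric sketch of binary crossentropy on a multi-hot label.
import numpy as np

y_true = np.array([1.0, 0.0, 1.0, 0.0])  # multi-hot label (2 active classes)
y_pred = np.array([0.9, 0.1, 0.6, 0.2])  # hypothetical sigmoid outputs

# Each label contributes an independent binary crossentropy term;
# the total loss is their mean.
bce = -np.mean(y_true * np.log(y_pred) + (1 - y_true) * np.log(1 - y_pred))
print(round(bce, 4))  # ≈ 0.2362
```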

There are several options of metrics that can be used in multi-label classification. To keep this code example narrow, we decided to use the binary accuracy metric. For an explanation of why this metric is used, refer to this pull request. There are also other suitable metrics for multi-label classification, like F1 score or Hamming loss.
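As a quick sketch of those alternative metrics, here is how they behave on toy multi-hot predictions, computed with scikit-learn (not part of the original example):

```python
# Toy comparison of multi-label metrics on multi-hot arrays.
import numpy as np
from sklearn.metrics import f1_score, hamming_loss

y_true = np.array([[1, 0, 1], [0, 1, 0]])
y_pred = np.array([[1, 0, 0], [0, 1, 0]])  # one missed label out of six

# Hamming loss: fraction of individual label assignments that are wrong.
print(hamming_loss(y_true, y_pred))  # 1/6 ≈ 0.1667

# Micro-averaged F1: aggregates true/false positives across all labels.
print(f1_score(y_true, y_pred, average="micro"))  # 0.8
```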

epochs = 20

shallow_mlp_model = make_model()
shallow_mlp_model.compile(
    loss="binary_crossentropy", optimizer="adam", metrics=["binary_accuracy"]
)

history = shallow_mlp_model.fit(
    train_dataset, validation_data=validation_dataset, epochs=epochs
)


def plot_result(item):
    plt.plot(history.history[item], label=item)
    plt.plot(history.history["val_" + item], label="val_" + item)
    plt.xlabel("Epochs")
    plt.ylabel(item)
    plt.title("Train and Validation {} Over Epochs".format(item), fontsize=14)
    plt.legend()
    plt.grid()
    plt.show()


plot_result("loss")
plot_result("binary_accuracy")

While training, we notice an initial sharp fall in the loss followed by a gradual decay.

Evaluate the model

_, binary_acc = shallow_mlp_model.evaluate(test_dataset)
print(f"Binary accuracy on the test set: {round(binary_acc * 100, 2)}%.")

The trained model gives us a binary accuracy of ~99% on the test set.

Inference

An important feature of the preprocessing layers provided by Keras is that they can be included inside a tf.keras.Model. We will export an inference model by including the text_vectorization layer on top of shallow_mlp_model. This will allow our inference model to directly operate on raw strings.

Note that during training it is always preferable to use these preprocessing layers as a part of the data input pipeline rather than the model to avoid surfacing bottlenecks for the hardware accelerators. This also allows for asynchronous data processing.

# We create a custom Model to override the predict method so
# that it first vectorizes text data
class ModelEndtoEnd(keras.Model):
    def predict(self, inputs):
        indices = text_vectorizer(inputs)
        return super().predict(indices)


def get_inference_model(model):
    inputs = model.inputs
    outputs = model.outputs
    end_to_end_model = ModelEndtoEnd(inputs, outputs, name="end_to_end_model")
    end_to_end_model.compile(
        optimizer="adam", loss="binary_crossentropy", metrics=["accuracy"]
    )
    return end_to_end_model


model_for_inference = get_inference_model(shallow_mlp_model)

# Create a small dataset just for demonstrating inference.
inference_dataset = make_dataset(test_df.sample(2), is_train=False)
text_batch, label_batch = next(iter(inference_dataset))

# Perform inference.
predicted_probabilities = model_for_inference.predict(text_batch)

for i, text in enumerate(text_batch[:5]):
    label = label_batch[i].numpy()[None, ...]
    print(f"Abstract: {text}")
    print(f"Label(s): {invert_multi_hot(label[0])}")
    top_3_labels = [
        x
        for _, x in sorted(
            zip(predicted_probabilities[i], lookup.get_vocabulary()),
            key=lambda pair: pair[0],
            reverse=True,
        )
    ][:3]
    print(f"Predicted Label(s): ({', '.join([label for label in top_3_labels])})")
    print(" ")

The prediction results are not that great, but not below par for a simple model like ours. We can improve this performance with models that take word order into account, like LSTMs, or even models that use Transformers (Vaswani et al.).
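As a rough sketch (not part of the original example, and untrained here), an order-aware variant would swap the TF-IDF vectors for integer token sequences and add an embedding plus an LSTM encoder; the label vocabulary size and token cap below are assumptions:

```python
# Hypothetical order-aware variant: integer sequences + LSTM encoder.
import keras
from keras import layers

num_labels = 25     # assumption: size of the multi-hot label vocabulary
max_tokens = 20000  # assumption: capped token vocabulary

# Integer output mode preserves word order, unlike TF-IDF vectors.
sequence_vectorizer = layers.TextVectorization(
    max_tokens=max_tokens, output_mode="int", output_sequence_length=150
)

lstm_model = keras.Sequential(
    [
        layers.Embedding(max_tokens, 128),
        layers.LSTM(64),
        layers.Dense(num_labels, activation="sigmoid"),  # still multi-label
    ]
)
lstm_model.compile(
    loss="binary_crossentropy", optimizer="adam", metrics=["binary_accuracy"]
)
```

The output head and loss stay the same as in the shallow MLP; only the text representation changes.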

Acknowledgements

We would like to thank Matt Watson for helping us tackle the multi-label binarization part and inverse-transforming the processed labels to the original form.

Thanks to Cingis Kratochvil for suggesting and extending this code example by introducing binary accuracy as the evaluation metric.