"""
Title: Large-scale multi-label text classification
Author: [Sayak Paul](https://twitter.com/RisingSayak), [Soumik Rakshit](https://github.com/soumik12345)
Date created: 2020/09/25
Last modified: 2025/02/27
Description: Implementing a large-scale multi-label text classification model.
Accelerator: GPU
Converted to Keras 3 and made backend-agnostic by: [Humbulani Ndou](https://github.com/Humbulani1234)
"""

"""
## Introduction

In this example, we will build a multi-label text classifier to predict the subject areas
of arXiv papers from their abstract bodies. This type of classifier can be useful for
conference submission portals like [OpenReview](https://openreview.net/). Given a paper
abstract, the portal could suggest which areas the paper would best belong to.

The dataset was collected using the
[`arXiv` Python library](https://github.com/lukasschwab/arxiv.py),
which provides a wrapper around the
[original arXiv API](http://arxiv.org/help/api/index).
To learn more about the data collection process, please refer to
[this notebook](https://github.com/soumik12345/multi-label-text-classification/blob/master/arxiv_scrape.ipynb).
Additionally, you can find the dataset on
[Kaggle](https://www.kaggle.com/spsayakpaul/arxiv-paper-abstracts).
"""

"""
## Imports
"""

import os

os.environ["KERAS_BACKEND"] = "jax"  # or tensorflow, or torch

import keras
from keras import layers, ops

from sklearn.model_selection import train_test_split

from ast import literal_eval
import matplotlib.pyplot as plt
import pandas as pd
import numpy as np

"""
## Perform exploratory data analysis

In this section, we first load the dataset into a `pandas` dataframe and then perform
some basic exploratory data analysis (EDA).
"""

arxiv_data = pd.read_csv(
    "https://github.com/soumik12345/multi-label-text-classification/releases/download/v0.2/arxiv_data.csv"
)
arxiv_data.head()

"""
Our text features are present in the `summaries` column and their corresponding labels
are in `terms`. As you can notice, there are multiple categories associated with a
particular entry.
"""

print(f"There are {len(arxiv_data)} rows in the dataset.")

"""
Real-world data is noisy. One of the most commonly observed sources of noise is data
duplication. Here we notice that our initial dataset has about 13k duplicate titles.
"""

total_duplicate_titles = sum(arxiv_data["titles"].duplicated())
print(f"There are {total_duplicate_titles} duplicate titles.")

"""
Before proceeding further, we drop these entries.
"""

arxiv_data = arxiv_data[~arxiv_data["titles"].duplicated()]
print(f"There are {len(arxiv_data)} rows in the deduplicated dataset.")

# There are some terms with occurrence as low as 1.
print(sum(arxiv_data["terms"].value_counts() == 1))

# How many unique terms?
print(arxiv_data["terms"].nunique())

"""
As observed above, out of 3,157 unique combinations of `terms`, 2,321 occur only once.
To prepare our train, validation, and test sets with
[stratification](https://en.wikipedia.org/wiki/Stratified_sampling), we need to drop
these rare terms.
"""

# Filtering the rare terms.
arxiv_data_filtered = arxiv_data.groupby("terms").filter(lambda x: len(x) > 1)
arxiv_data_filtered.shape

"""
## Convert the string labels to lists of strings

The initial labels are represented as raw strings. Here we make them `List[str]` for a
more compact representation.
"""

arxiv_data_filtered["terms"] = arxiv_data_filtered["terms"].apply(
    lambda x: literal_eval(x)
)
arxiv_data_filtered["terms"].values[:5]
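
"""
For instance, `literal_eval` safely parses the Python list literal inside each raw label
string (the terms below are made up for illustration):
"""

print(literal_eval("['cs.CV', 'cs.LG']"))  # -> ['cs.CV', 'cs.LG']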

"""
## Use stratified splits because of class imbalance

The dataset has a
[class imbalance problem](https://developers.google.com/machine-learning/glossary/#class-imbalanced-dataset).
So, to have a fair evaluation result, we need to ensure the datasets are sampled with
stratification. To know more about different strategies to deal with the class imbalance
problem, you can follow
[this tutorial](https://www.tensorflow.org/tutorials/structured_data/imbalanced_data).
For an end-to-end demonstration of classification with imbalanced data, refer to
[Imbalanced classification: credit card fraud detection](https://keras.io/examples/structured_data/imbalanced_classification/).
"""

test_split = 0.1

# Initial train and test split.
train_df, test_df = train_test_split(
    arxiv_data_filtered,
    test_size=test_split,
    stratify=arxiv_data_filtered["terms"].values,
)

# Splitting the test set further into validation
# and new test sets.
val_df = test_df.sample(frac=0.5)
test_df.drop(val_df.index, inplace=True)

print(f"Number of rows in training set: {len(train_df)}")
print(f"Number of rows in validation set: {len(val_df)}")
print(f"Number of rows in test set: {len(test_df)}")

"""
## Multi-label binarization

Now we preprocess our labels using the
[`StringLookup`](https://keras.io/api/layers/preprocessing_layers/categorical/string_lookup)
layer.
"""

# For RaggedTensor
import tensorflow as tf

terms = tf.ragged.constant(train_df["terms"].values)
lookup = layers.StringLookup(output_mode="multi_hot")
lookup.adapt(terms)
vocab = lookup.get_vocabulary()


def invert_multi_hot(encoded_labels):
    """Reverse a single multi-hot encoded label to a tuple of vocab terms."""
    hot_indices = np.argwhere(encoded_labels == 1.0)[..., 0]
    return np.take(vocab, hot_indices)


print("Vocabulary:\n")
print(vocab)


"""
Here we separate the individual unique classes available from the label pool and then
use this information to represent a given label set with 0's and 1's.
Below is an example.
"""

sample_label = train_df["terms"].iloc[0]
print(f"Original label: {sample_label}")

label_binarized = lookup([sample_label])
print(f"Label-binarized representation: {label_binarized}")

"""
## Data preprocessing and `tf.data.Dataset` objects

We first get percentile estimates of the sequence lengths. The purpose will be clear in a
moment.
"""

train_df["summaries"].apply(lambda x: len(x.split(" "))).describe()

"""
Notice that 50% of the abstracts have a length of 154 words or fewer (you may get a
different number based on the split). So, any number close to that value is a good enough
approximation for the maximum sequence length.

Now, we implement utilities to prepare our datasets.
"""

max_seqlen = 150
batch_size = 128
padding_token = "<pad>"
auto = tf.data.AUTOTUNE


def make_dataset(dataframe, is_train=True):
    labels = tf.ragged.constant(dataframe["terms"].values)
    label_binarized = lookup(labels).numpy()
    dataset = tf.data.Dataset.from_tensor_slices(
        (dataframe["summaries"].values, label_binarized)
    )
    dataset = dataset.shuffle(batch_size * 10) if is_train else dataset
    return dataset.batch(batch_size)


"""
Now we can prepare the `tf.data.Dataset` objects.
"""

train_dataset = make_dataset(train_df, is_train=True)
validation_dataset = make_dataset(val_df, is_train=False)
test_dataset = make_dataset(test_df, is_train=False)

"""
## Dataset preview
"""

text_batch, label_batch = next(iter(train_dataset))

for i, text in enumerate(text_batch[:5]):
    label = label_batch[i].numpy()[None, ...]
    print(f"Abstract: {text}")
    print(f"Label(s): {invert_multi_hot(label[0])}")
    print(" ")

"""
## Vectorization

Before we feed the data to our model, we need to vectorize it (represent it in a numerical form).
For that purpose, we will use the
[`TextVectorization` layer](https://keras.io/api/layers/preprocessing_layers/text/text_vectorization).
It can operate as a part of your main model so that the preprocessing logic is bundled
with the model itself. This greatly reduces the chances of training / serving skew.

We first calculate the number of unique words present in the abstracts.
"""

# Source: https://stackoverflow.com/a/18937309/7636462
vocabulary = set()
train_df["summaries"].str.lower().str.split().apply(vocabulary.update)
vocabulary_size = len(vocabulary)
print(vocabulary_size)


"""
We now create our vectorization layer and `map()` it to the `tf.data.Dataset`s created
earlier.
"""

text_vectorizer = layers.TextVectorization(
    max_tokens=vocabulary_size, ngrams=2, output_mode="tf_idf"
)

# `TextVectorization` layer needs to be adapted as per the vocabulary from our
# training set.
with tf.device("/CPU:0"):
    text_vectorizer.adapt(train_dataset.map(lambda text, label: text))

train_dataset = train_dataset.map(
    lambda text, label: (text_vectorizer(text), label), num_parallel_calls=auto
).prefetch(auto)
validation_dataset = validation_dataset.map(
    lambda text, label: (text_vectorizer(text), label), num_parallel_calls=auto
).prefetch(auto)
test_dataset = test_dataset.map(
    lambda text, label: (text_vectorizer(text), label), num_parallel_calls=auto
).prefetch(auto)


"""
A batch of raw text will first go through the `TextVectorization` layer and it will
generate their vector representations. Internally, the `TextVectorization` layer will
first create bi-grams out of the sequences and then represent them using
[TF-IDF](https://wikipedia.org/wiki/Tf%E2%80%93idf). The output representations will then
be passed to the shallow model responsible for text classification.

To learn more about other possible configurations with `TextVectorization`, please
consult the
[official documentation](https://keras.io/api/layers/preprocessing_layers/text/text_vectorization).

**Note**: Setting the `max_tokens` argument to a pre-calculated vocabulary size is
not a requirement.
"""

"""
## Create a text classification model

We will keep our model simple -- it will be a small stack of fully-connected layers with
ReLU as the non-linearity.
"""


def make_model():
    shallow_mlp_model = keras.Sequential(
        [
            layers.Dense(512, activation="relu"),
            layers.Dense(256, activation="relu"),
            layers.Dense(lookup.vocabulary_size(), activation="sigmoid"),
        ]  # More on why "sigmoid" has been used here in a moment.
    )
    return shallow_mlp_model


"""
## Train the model

We will train our model using the binary crossentropy loss. This is because the labels
are not disjoint. For a given abstract, we may have multiple categories. So, we will
divide the prediction task into a series of multiple binary classification problems. This
is also why we kept the activation function of the classification layer in our model as
sigmoid. Researchers have used other combinations of loss function and activation
function as well. For example, in [Exploring the Limits of Weakly Supervised Pretraining](https://arxiv.org/abs/1805.00932),
Mahajan et al. used the softmax activation function and cross-entropy loss to train
their models.

There are several options of metrics that can be used in multi-label classification.
To keep this code example narrow, we decided to use the
[binary accuracy metric](https://keras.io/api/metrics/accuracy_metrics/#binaryaccuracy-class).
For an explanation of why this metric is used, refer to this
[pull-request](https://github.com/keras-team/keras-io/pull/1133#issuecomment-1322736860).
There are also other suitable metrics for multi-label classification, like
[F1 Score](https://www.tensorflow.org/addons/api_docs/python/tfa/metrics/F1Score) or
[Hamming loss](https://www.tensorflow.org/addons/api_docs/python/tfa/metrics/HammingLoss).
"""

epochs = 20

shallow_mlp_model = make_model()
shallow_mlp_model.compile(
    loss="binary_crossentropy", optimizer="adam", metrics=["binary_accuracy"]
)

history = shallow_mlp_model.fit(
    train_dataset, validation_data=validation_dataset, epochs=epochs
)


def plot_result(item):
    plt.plot(history.history[item], label=item)
    plt.plot(history.history["val_" + item], label="val_" + item)
    plt.xlabel("Epochs")
    plt.ylabel(item)
    plt.title("Train and Validation {} Over Epochs".format(item), fontsize=14)
    plt.legend()
    plt.grid()
    plt.show()


plot_result("loss")
plot_result("binary_accuracy")

"""
While training, we notice an initial sharp fall in the loss followed by a gradual decay.
"""

"""
### Evaluate the model
"""

_, binary_acc = shallow_mlp_model.evaluate(test_dataset)
print(f"Binary accuracy on the test set: {round(binary_acc * 100, 2)}%.")

"""
The trained model gives us a binary accuracy of ~99% on the test set.
"""

"""
## Inference

An important feature of the
[preprocessing layers provided by Keras](https://keras.io/api/layers/preprocessing_layers/)
is that they can be included inside a `keras.Model`. We will export an inference model
by including the `text_vectorizer` layer on top of `shallow_mlp_model`. This will
allow our inference model to directly operate on raw strings.

**Note** that during training it is always preferable to use these preprocessing
layers as a part of the data input pipeline rather than the model, to avoid
surfacing bottlenecks for the hardware accelerators. This also allows for
asynchronous data processing.
"""


# We create a custom Model to override the predict method so
# that it first vectorizes the raw text input.
class ModelEndtoEnd(keras.Model):
    def predict(self, inputs):
        indices = text_vectorizer(inputs)
        return super().predict(indices)


def get_inference_model(model):
    inputs = model.inputs
    outputs = model.outputs
    end_to_end_model = ModelEndtoEnd(inputs, outputs, name="end_to_end_model")
    end_to_end_model.compile(
        optimizer="adam", loss="binary_crossentropy", metrics=["accuracy"]
    )
    return end_to_end_model


model_for_inference = get_inference_model(shallow_mlp_model)

# Create a small dataset just for demonstrating inference.
inference_dataset = make_dataset(test_df.sample(2), is_train=False)
text_batch, label_batch = next(iter(inference_dataset))
predicted_probabilities = model_for_inference.predict(text_batch)


# Perform inference.
for i, text in enumerate(text_batch[:5]):
    label = label_batch[i].numpy()[None, ...]
    print(f"Abstract: {text}")
    print(f"Label(s): {invert_multi_hot(label[0])}")
    top_3_labels = [
        x
        for _, x in sorted(
            zip(predicted_probabilities[i], lookup.get_vocabulary()),
            key=lambda pair: pair[0],
            reverse=True,
        )
    ][:3]
    print(f"Predicted Label(s): ({', '.join(top_3_labels)})")
    print(" ")

"""
The prediction results are not that great, but not below par for a simple model like
ours. We can improve this performance with models that consider word order, like LSTMs,
or even with models that use Transformers ([Vaswani et al.](https://arxiv.org/abs/1706.03762)).
"""

"""
## Acknowledgements

We would like to thank [Matt Watson](https://github.com/mattdangerw) for helping us
tackle the multi-label binarization part and inverse-transforming the processed labels
to the original form.

Thanks to [Cingis Kratochvil](https://github.com/cumbalik) for suggesting and extending
this code example by introducing binary accuracy as the evaluation metric.
"""