"""
Title: Abstractive Text Summarization with BART
Author: [Abheesht Sharma](https://github.com/abheesht17/)
Date created: 2023/07/08
Last modified: 2024/03/20
Description: Use KerasHub to fine-tune BART on the abstractive summarization task.
Accelerator: GPU
Converted to Keras 3 by: [Sitam Meur](https://github.com/sitamgithub-MSIT)
"""

"""
## Introduction

In the era of information overload, it has become crucial to extract the crux
of a long document or a conversation and express it in a few sentences. Because
summarization has widespread applications across domains, it has become a key,
well-studied NLP task in recent years.

[Bidirectional and Auto-Regressive Transformers (BART)](https://arxiv.org/abs/1910.13461)
is a Transformer-based encoder-decoder model, often used for
sequence-to-sequence tasks like summarization and neural machine translation.
BART is pre-trained in a self-supervised fashion on a large text corpus. During
pre-training, the text is corrupted and BART is trained to reconstruct the
original text (hence the name "denoising autoencoder"). Pre-training tasks
include token masking, token deletion, and sentence permutation (shuffling the
sentences and training BART to restore the original order); a toy illustration
of these corruptions is sketched below.

In this example, we will demonstrate how to fine-tune BART on the abstractive
summarization task (on conversations!) using KerasHub, and generate summaries
using the fine-tuned model.
"""

"""
## Setup

Before we start implementing the pipeline, let's install and import all the
libraries we need. We'll be using the KerasHub library. We will also need a
couple of utility libraries.
"""

"""shell
pip install git+https://github.com/keras-team/keras-hub.git py7zr -q
"""

"""
This example uses [Keras 3](https://keras.io/keras_3/), so it works with any of
the `"tensorflow"`, `"jax"` or `"torch"` backends. Support for Keras 3 is baked
into KerasHub; simply set the `"KERAS_BACKEND"` environment variable to select
the backend of your choice. We select the JAX backend below.
"""

import os

os.environ["KERAS_BACKEND"] = "jax"

"""
Import all necessary libraries.
"""

import py7zr
import time

import keras_hub
import keras
import tensorflow as tf
import tensorflow_datasets as tfds
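
"""
As a quick check, we can confirm that Keras picked up the backend we requested.
The call below uses `keras.config.backend()` from Keras 3; older Keras versions
may expose this differently.
"""

# Should print "jax", since we set KERAS_BACKEND before importing Keras.
print(keras.config.backend())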

"""
Let's also define our hyperparameters.
"""

BATCH_SIZE = 8
NUM_BATCHES = 600
EPOCHS = 1  # Can be set to a higher value for better results
MAX_ENCODER_SEQUENCE_LENGTH = 512
MAX_DECODER_SEQUENCE_LENGTH = 128
MAX_GENERATION_LENGTH = 40

"""
## Dataset

Let's load the [SAMSum dataset](https://arxiv.org/abs/1911.12237). This dataset
contains around 15,000 pairs of conversations/dialogues and summaries.
"""

# Download the dataset.
filename = keras.utils.get_file(
    "corpus.7z",
    origin="https://huggingface.co/datasets/samsum/resolve/main/data/corpus.7z",
)

# Extract the `.7z` file.
with py7zr.SevenZipFile(filename, mode="r") as z:
    z.extractall(path="/root/tensorflow_datasets/downloads/manual")

# Load data using TFDS.
samsum_ds = tfds.load("samsum", split="train", as_supervised=True)

"""
The dataset has two fields: `dialogue` and `summary`. Let's see a sample.
"""
for dialogue, summary in samsum_ds:
    print(dialogue.numpy())
    print(summary.numpy())
    break

"""
We'll now batch the dataset and retain only a subset of it for the purpose of
this example. The dialogue is fed to the encoder, and the corresponding summary
serves as input to the decoder. We will, therefore, change the format of the
dataset to a dictionary having two keys: `"encoder_text"` and `"decoder_text"`.
This is the input format that `keras_hub.models.BartSeq2SeqLMPreprocessor`
expects.
"""

train_ds = (
    samsum_ds.map(
        lambda dialogue, summary: {"encoder_text": dialogue, "decoder_text": summary}
    )
    .batch(BATCH_SIZE)
    .cache()
)
train_ds = train_ds.take(NUM_BATCHES)
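
"""
As an optional sanity check, we can inspect the element structure of the batched
dataset to confirm it matches the dictionary format described above.
"""

# Each element is a dict mapping "encoder_text" and "decoder_text" to batched
# string tensors.
print(train_ds.element_spec)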

"""
## Fine-tune BART

Let's load the model and preprocessor first. We use sequence lengths of 512
and 128 for the encoder and decoder, respectively, instead of 1024 (which is the
default sequence length). This will allow us to run this example quickly
on Colab.

Note that the preprocessor is attached to the model, so we don't have to worry
about preprocessing the text inputs ourselves; everything is done internally.
The preprocessor tokenizes the encoder text and the decoder text, adds special
tokens and pads them. To generate labels for auto-regressive training, the
preprocessor shifts the decoder text one position to the right. This is done
because at every timestep, the model is trained to predict the next token.
"""

preprocessor = keras_hub.models.BartSeq2SeqLMPreprocessor.from_preset(
    "bart_base_en",
    encoder_sequence_length=MAX_ENCODER_SEQUENCE_LENGTH,
    decoder_sequence_length=MAX_DECODER_SEQUENCE_LENGTH,
)
bart_lm = keras_hub.models.BartSeq2SeqLM.from_preset(
    "bart_base_en", preprocessor=preprocessor
)

bart_lm.summary()
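
"""
To see the label shifting described above in action, we can call the
preprocessor directly on a single example. The output sketched below (an
`(x, y, sample_weight)` tuple, with token ids and padding masks keyed by
`"encoder_token_ids"`, `"decoder_token_ids"`, and so on) follows the usual
KerasHub preprocessor convention, but the exact keys may differ across
versions, so treat this as an optional sanity check.
"""

x, y, sample_weight = preprocessor(
    {
        "encoder_text": ["Amanda: I baked cookies. Do you want some?"],
        "decoder_text": ["Amanda baked cookies."],
    }
)
# The labels `y` are the decoder token ids offset by one position, so the token
# at position t in the decoder input is trained to predict the token at t + 1.
print(x["decoder_token_ids"][0, :8])
print(y[0, :8])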

"""
Define the optimizer and loss. We use the AdamW optimizer with weight decay and
gradient clipping. Compile the model.
"""

optimizer = keras.optimizers.AdamW(
    learning_rate=5e-5,
    weight_decay=0.01,
    epsilon=1e-6,
    global_clipnorm=1.0,  # Gradient clipping.
)
# Exclude layernorm and bias terms from weight decay.
optimizer.exclude_from_weight_decay(var_names=["bias"])
optimizer.exclude_from_weight_decay(var_names=["gamma"])
optimizer.exclude_from_weight_decay(var_names=["beta"])

loss = keras.losses.SparseCategoricalCrossentropy(from_logits=True)

bart_lm.compile(
    optimizer=optimizer,
    loss=loss,
    weighted_metrics=["accuracy"],
)

"""
Let's train the model!
"""

bart_lm.fit(train_ds, epochs=EPOCHS)

"""
## Generate summaries and evaluate them!

Now that the model has been trained, let's get to the fun part: actually
generating summaries! Let's pick the first 100 samples from the validation set
and generate summaries for them. We will use the default decoding strategy,
i.e., greedy search.

Generation in KerasHub is highly optimized: it is backed by XLA, and the
key/value tensors in the decoder's self-attention and cross-attention layers
are cached to avoid recomputation at every timestep.
"""


def generate_text(model, input_text, max_length=200, print_time_taken=False):
    start = time.time()
    output = model.generate(input_text, max_length=max_length)
    end = time.time()
    # Only report timing when requested.
    if print_time_taken:
        print(f"Total Time Elapsed: {end - start:.2f}s")
    return output


# Load the dataset.
val_ds = tfds.load("samsum", split="validation", as_supervised=True)
val_ds = val_ds.take(100)

dialogues = []
ground_truth_summaries = []
for dialogue, summary in val_ds:
    dialogues.append(dialogue.numpy())
    ground_truth_summaries.append(summary.numpy())

# Let's make a dummy call; the first call to XLA generally takes a bit longer.
_ = generate_text(
    bart_lm, "sample text", max_length=MAX_GENERATION_LENGTH, print_time_taken=True
)

# Generate summaries.
generated_summaries = generate_text(
    bart_lm,
    val_ds.map(lambda dialogue, _: dialogue).batch(8),
    max_length=MAX_GENERATION_LENGTH,
    print_time_taken=True,
)

"""
Let's see some of the summaries.
"""
for dialogue, generated_summary, ground_truth_summary in zip(
    dialogues[:5], generated_summaries[:5], ground_truth_summaries[:5]
):
    print("Dialogue:", dialogue)
    print("Generated Summary:", generated_summary)
    print("Ground Truth Summary:", ground_truth_summary)
    print("=============================")

"""
The generated summaries look awesome! Not bad for a model trained only for 1
epoch and on roughly 5,000 examples :)
"""