GitHub Repository: pytorch/tutorials
Path: blob/main/beginner_source/t5_tutorial.py
"""
T5-Base Model for Summarization, Sentiment Classification, and Translation
==========================================================================

**Authors**: Pendo Abbo, Joe Cummings

"""

######################################################################
# Overview
# --------
#
# This tutorial demonstrates how to use a pretrained T5 Model for summarization, sentiment classification, and
# translation tasks. We will demonstrate how to use the torchtext library to:
#
# 1. Build a text preprocessing pipeline for a T5 model
# 2. Instantiate a pretrained T5 model with base configuration
# 3. Read in the CNNDM, IMDB, and Multi30k datasets and preprocess their texts in preparation for the model
# 4. Perform text summarization, sentiment classification, and translation
#
# .. note::
#    This tutorial requires PyTorch 2.0.0 or later.
#
#######################################################################
# Data Transformation
# -------------------
#
# The T5 model does not work with raw text. Instead, it requires the text to be transformed into numerical form
# in order to perform training and inference. The following transformations are required for the T5 model:
#
# 1. Tokenize text
# 2. Convert tokens into (integer) IDs
# 3. Truncate the sequences to a specified maximum length
# 4. Add end-of-sequence (EOS) and padding token IDs
#
# T5 uses a ``SentencePiece`` model for text tokenization. Below, we use a pretrained ``SentencePiece`` model to build
# the text preprocessing pipeline using torchtext's ``T5Transform``. Note that the transform supports both
# batched and non-batched text input (for example, one can pass either a single sentence or a list of sentences);
# however, the T5 model expects the input to be batched.
#

from torchtext.models import T5Transform

padding_idx = 0
eos_idx = 1
max_seq_len = 512
t5_sp_model_path = "https://download.pytorch.org/models/text/t5_tokenizer_base.model"

transform = T5Transform(
    sp_model_path=t5_sp_model_path,
    max_seq_len=max_seq_len,
    eos_idx=eos_idx,
    padding_idx=padding_idx,
)

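#######################################################################
# To make the four transformation steps listed above concrete, the following is a minimal, torchtext-free sketch.
# It is not how ``T5Transform`` is implemented: the whitespace tokenizer, ``TOY_VOCAB``, ``UNK_IDX``, and
# ``toy_transform`` are hypothetical stand-ins for the SentencePiece model, and truncating before appending
# the EOS token is an assumption of this sketch.
#

```python
# Toy stand-ins: a real pipeline uses a SentencePiece model, not a whitespace split.
PADDING_IDX = 0
EOS_IDX = 1
UNK_IDX = 2  # hypothetical ID for out-of-vocabulary tokens
TOY_VOCAB = {"summarize:": 3, "the": 4, "cat": 5, "sat": 6, "down": 7}


def toy_transform(texts, max_seq_len):
    """Tokenize, map to IDs, truncate, append EOS, then pad to the longest sequence."""
    batch = []
    for text in texts:
        tokens = text.lower().split()                          # 1. tokenize
        ids = [TOY_VOCAB.get(tok, UNK_IDX) for tok in tokens]  # 2. tokens -> IDs
        ids = ids[:max_seq_len]                                # 3. truncate (assumed to precede EOS)
        ids.append(EOS_IDX)                                    # 4a. add EOS
        batch.append(ids)
    longest = max(len(ids) for ids in batch)
    # 4b. pad every sequence to the length of the longest one
    return [ids + [PADDING_IDX] * (longest - len(ids)) for ids in batch]


toy_batch = toy_transform(["summarize: the cat sat down", "the cat"], max_seq_len=4)
print(toy_batch)  # [[3, 4, 5, 6, 1], [4, 5, 1, 0, 0]]
```

#######################################################################
# The first sentence is cut to ``max_seq_len`` tokens before EOS is appended, and the shorter second sentence is
# padded with ``PADDING_IDX`` so that the batch is rectangular.
#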
#######################################################################
# Alternatively, we can use the transform shipped with the pretrained models, which does all of the above out of the box.
#
# .. code-block::
#
#    from torchtext.models import T5_BASE_GENERATION
#    transform = T5_BASE_GENERATION.transform()
#


######################################################################
# Model Preparation
# -----------------
#
# torchtext provides state-of-the-art (SOTA) pretrained models that can be used directly for NLP tasks or fine-tuned on
# downstream tasks. Below we use the pretrained T5 model with standard base configuration to perform text summarization,
# sentiment classification, and translation. For additional details on available pretrained models, see
# `the torchtext documentation <https://pytorch.org/text/main/models.html>`__.
#
#
from torchtext.models import T5_BASE_GENERATION


t5_base = T5_BASE_GENERATION
transform = t5_base.transform()
model = t5_base.get_model()
model.eval()  # switch to inference mode


#######################################################################
# Using ``GenerationUtils``
# -------------------------
#
# We can use torchtext's ``GenerationUtils`` to produce an output sequence based on the input sequence provided. This calls on the
# model's encoder and decoder, and iteratively expands the decoded sequences until the end-of-sequence token is generated
# for all sequences in the batch. The ``generate`` calls later in this tutorial use greedy search (a beam size of 1) to
# generate the sequences. Beam search and other decoding strategies are also supported.
#
#
from torchtext.prototype.generate import GenerationUtils

sequence_generator = GenerationUtils(model)


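#######################################################################
# The iterative expansion described above can be illustrated without a real model. The sketch below is a toy greedy
# decoder, not the ``GenerationUtils`` implementation: ``TOY_SCORES``, ``toy_next_token_scores``, and
# ``greedy_generate`` are hypothetical names, the score table replaces the encoder and decoder, and starting every
# sequence with the padding ID is an assumption of this sketch.
#

```python
PADDING_IDX = 0
EOS_IDX = 1

# Hypothetical stand-in for the model: scores over a 5-token vocabulary, keyed by
# (source sentence, last decoded token). A real decoder conditions on the encoder
# output and on the entire decoded prefix.
TOY_SCORES = {
    ("a", 0): [0.0, 0.1, 0.2, 0.6, 0.1],  # token 3 scores highest
    ("a", 3): [0.0, 0.2, 0.1, 0.1, 0.6],  # token 4 scores highest
    ("a", 4): [0.0, 0.7, 0.1, 0.1, 0.1],  # EOS scores highest
    ("b", 0): [0.0, 0.9, 0.0, 0.1, 0.0],  # EOS scores highest immediately
}


def toy_next_token_scores(source, prefix):
    return TOY_SCORES[(source, prefix[-1])]


def greedy_generate(sources, max_len=10):
    """Append the highest-scoring token to each sequence until all have produced EOS."""
    sequences = [[PADDING_IDX] for _ in sources]  # assumed start token
    finished = [False] * len(sources)
    for _ in range(max_len):
        for i, source in enumerate(sources):
            if finished[i]:
                sequences[i].append(PADDING_IDX)  # keep the batch rectangular
                continue
            scores = toy_next_token_scores(source, sequences[i])
            best = max(range(len(scores)), key=scores.__getitem__)
            sequences[i].append(best)
            finished[i] = best == EOS_IDX
        if all(finished):
            break
    return sequences


print(greedy_generate(["a", "b"]))  # [[0, 3, 4, 1], [0, 1, 0, 0]]
```

#######################################################################
# Sequence ``"b"`` finishes after one step and is padded while sequence ``"a"`` keeps decoding; the loop stops once
# every sequence in the batch has emitted EOS.
#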
#######################################################################
# Datasets
# --------
# torchtext provides several standard NLP datasets. For a complete list, refer to the documentation
# at https://pytorch.org/text/stable/datasets.html. These datasets are built using composable torchdata
# datapipes and hence support standard flow-control and mapping/transformation using user-defined
# functions and transforms.
#
# Below we demonstrate how to preprocess the CNNDM dataset to include the prefix necessary for the
# model to identify the task it is performing. The CNNDM dataset has a train, validation, and test
# split. Below we demo on the test split.
#
# The T5 model uses the prefix "summarize" for text summarization. For more information on task
# prefixes, please visit Appendix D of the `T5 Paper <https://arxiv.org/pdf/1910.10683.pdf>`__.
#
# .. note::
#    Using datapipes is currently subject to a few caveats. If you wish
#    to extend this example to include shuffling, multi-processing, or
#    distributed learning, please see :ref:`this note <datapipes_warnings>`
#    for further instructions.

from functools import partial

from torch.utils.data import DataLoader
from torchtext.datasets import CNNDM

cnndm_batch_size = 5
cnndm_datapipe = CNNDM(split="test")
task = "summarize"


def apply_prefix(task, x):
    return f"{task}: " + x[0], x[1]


cnndm_datapipe = cnndm_datapipe.map(partial(apply_prefix, task))
cnndm_datapipe = cnndm_datapipe.batch(cnndm_batch_size)
cnndm_datapipe = cnndm_datapipe.rows2columnar(["article", "abstract"])
cnndm_dataloader = DataLoader(cnndm_datapipe, shuffle=True, batch_size=None)

#######################################################################
# Alternatively, we can use the batched API, for example to apply the prefix to the whole batch:
#
# .. code-block::
#
#    def batch_prefix(task, x):
#        return {
#            "article": [f'{task}: ' + y for y in x["article"]],
#            "abstract": x["abstract"]
#        }
#
#    cnndm_batch_size = 5
#    cnndm_datapipe = CNNDM(split="test")
#    task = 'summarize'
#
#    cnndm_datapipe = cnndm_datapipe.batch(cnndm_batch_size).rows2columnar(["article", "abstract"])
#    cnndm_datapipe = cnndm_datapipe.map(partial(batch_prefix, task))
#    cnndm_dataloader = DataLoader(cnndm_datapipe, batch_size=None)
#

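#######################################################################
# The row-wise and batched approaches above should yield the same batches. The sketch below imitates both in plain
# Python so that this can be checked without downloading CNNDM. ``toy_rows``, ``batch_rows``, and
# ``rows_to_columnar`` are hypothetical helpers that only approximate the datapipe ``batch`` and ``rows2columnar``
# operations; they are not the torchdata implementations.
#

```python
from functools import partial

# Hypothetical (article, abstract) rows standing in for the CNNDM test split.
toy_rows = [
    ("article one", "abstract one"),
    ("article two", "abstract two"),
    ("article three", "abstract three"),
]
columns = ["article", "abstract"]


def apply_prefix(task, x):
    return f"{task}: " + x[0], x[1]


def batch_prefix(task, x):
    return {
        "article": [f"{task}: " + y for y in x["article"]],
        "abstract": x["abstract"],
    }


def batch_rows(rows, batch_size):
    """Group consecutive rows into lists of at most ``batch_size`` items."""
    return [rows[i:i + batch_size] for i in range(0, len(rows), batch_size)]


def rows_to_columnar(rows, column_names):
    """Turn a list of row tuples into a dict mapping column name to a list of values."""
    return {name: [row[i] for row in rows] for i, name in enumerate(column_names)}


# Row-wise: prefix each row, then batch, then convert each batch to columns.
prefixed_rows = list(map(partial(apply_prefix, "summarize"), toy_rows))
row_wise = [rows_to_columnar(b, columns) for b in batch_rows(prefixed_rows, 2)]

# Batched: batch and convert first, then prefix each whole batch.
batch_wise = [
    batch_prefix("summarize", rows_to_columnar(b, columns))
    for b in batch_rows(toy_rows, 2)
]

print(row_wise[0])
print(row_wise == batch_wise)  # True
```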
#######################################################################
# We can also load the IMDB dataset, which will be used to demonstrate sentiment classification using the T5 model.
# This dataset has a train and test split. Below we demo on the test split.
#
# The T5 model was trained on the SST2 dataset (also available in torchtext) for sentiment classification using the
# prefix ``sst2 sentence``. Therefore, we will use this prefix to perform sentiment classification on the IMDB dataset.
#

from torchtext.datasets import IMDB

imdb_batch_size = 3
imdb_datapipe = IMDB(split="test")
task = "sst2 sentence"
labels = {"1": "negative", "2": "positive"}


def process_labels(labels, x):
    return x[1], labels[str(x[0])]


imdb_datapipe = imdb_datapipe.map(partial(process_labels, labels))
imdb_datapipe = imdb_datapipe.map(partial(apply_prefix, task))
imdb_datapipe = imdb_datapipe.batch(imdb_batch_size)
imdb_datapipe = imdb_datapipe.rows2columnar(["text", "label"])
imdb_dataloader = DataLoader(imdb_datapipe, batch_size=None)

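#######################################################################
# ``process_labels`` runs before ``apply_prefix`` because it also reorders each row so that the text comes first,
# which is where ``apply_prefix`` expects it. The sketch below traces two rows through both functions. The rows in
# ``toy_reviews`` are made up, and their ``(label, text)`` layout with integer labels 1 and 2 is an assumption
# inferred from the indexing in ``process_labels``.
#

```python
labels = {"1": "negative", "2": "positive"}

# Hypothetical rows in the assumed (label, text) layout.
toy_reviews = [(1, "Dull and predictable."), (2, "A delightful surprise.")]


def process_labels(labels, x):
    # Swap to (text, label) and replace the numeric label with its name.
    return x[1], labels[str(x[0])]


def apply_prefix(task, x):
    # Prepend the task prefix to the first element, which must be the text.
    return f"{task}: " + x[0], x[1]


processed = [
    apply_prefix("sst2 sentence", process_labels(labels, row)) for row in toy_reviews
]
for text, label in processed:
    print(f"{label!r} <- {text!r}")
```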
#######################################################################
# Finally, we can also load the Multi30k dataset to demonstrate English to German translation using the T5 model.
# This dataset has a train, validation, and test split. Below we demo on the test split.
#
# The T5 model uses the prefix "translate English to German" for this task.

from torchtext.datasets import Multi30k

multi_batch_size = 5
language_pair = ("en", "de")
multi_datapipe = Multi30k(split="test", language_pair=language_pair)
task = "translate English to German"

multi_datapipe = multi_datapipe.map(partial(apply_prefix, task))
multi_datapipe = multi_datapipe.batch(multi_batch_size)
multi_datapipe = multi_datapipe.rows2columnar(["english", "german"])
multi_dataloader = DataLoader(multi_datapipe, batch_size=None)

#######################################################################
# Generate Summaries
# ------------------
#
# We can put all of the components together to generate summaries on the first batch of articles in the CNNDM test set
# using a beam size of 1.
#

batch = next(iter(cnndm_dataloader))
input_text = batch["article"]
target = batch["abstract"]
beam_size = 1

model_input = transform(input_text)
model_output = sequence_generator.generate(model_input, eos_idx=eos_idx, num_beams=beam_size)
output_text = transform.decode(model_output.tolist())

for i in range(cnndm_batch_size):
    print(f"Example {i+1}:\n")
    print(f"prediction: {output_text[i]}\n")
    print(f"target: {target[i]}\n\n")


#######################################################################
# Summarization Output
# --------------------
#
# Summarization output might vary since we shuffle the dataloader.
#
# .. code-block::
#
#    Example 1:
#
#    prediction: the 24-year-old has been tattooed for over a decade . he has landed in australia
#    to start work on a new campaign . he says he is 'taking it in your stride' to be honest .
#
#    target: London-based model Stephen James Hendry famed for his full body tattoo . The supermodel
#    is in Sydney for a new modelling campaign . Australian fans understood to have already located
#    him at his hotel . The 24-year-old heartthrob is recently single .
#
#
#    Example 2:
#
#    prediction: a stray pooch has used up at least three of her own after being hit by a
#    car and buried in a field . the dog managed to stagger to a nearby farm, dirt-covered
#    and emaciated, where she was found . she suffered a dislocated jaw, leg injuries and a
#    caved-in sinus cavity -- and still requires surgery to help her breathe .
#
#    target: Theia, a bully breed mix, was apparently hit by a car, whacked with a hammer
#    and buried in a field . "She's a true miracle dog and she deserves a good life," says
#    Sara Mellado, who is looking for a home for Theia .
#
#
#    Example 3:
#
#    prediction: mohammad Javad Zarif arrived in Iran on a sunny friday morning . he has gone
#    a long way to bring Iran in from the cold and allow it to rejoin the international
#    community . but there are some facts about him that are less well-known .
#
#    target: Mohammad Javad Zarif has spent more time with John Kerry than any other
#    foreign minister . He once participated in a takeover of the Iranian Consulate in San
#    Francisco . The Iranian foreign minister tweets in English .
#
#
#    Example 4:
#
#    prediction: five americans were monitored for three weeks after being exposed to Ebola in
#    west africa . one of the five had a heart-related issue and has been discharged but hasn't
#    left the area . they are clinicians for Partners in Health, a Boston-based aid group .
#
#    target: 17 Americans were exposed to the Ebola virus while in Sierra Leone in March .
#    Another person was diagnosed with the disease and taken to hospital in Maryland .
#    National Institutes of Health says the patient is in fair condition after weeks of
#    treatment .
#
#
#    Example 5:
#
#    prediction: the student was identified during an investigation by campus police and
#    the office of student affairs . he admitted to placing the noose on the tree early
#    Wednesday morning . the incident is one of several recent racist events to affect
#    college students .
#
#    target: Student is no longer on Duke University campus and will face disciplinary
#    review . School officials identified student during investigation and the person
#    admitted to hanging the noose, Duke says . The noose, made of rope, was discovered on
#    campus about 2 a.m.
#


#######################################################################
# Generate Sentiment Classifications
# ----------------------------------
#
# Similarly, we can use the model to generate sentiment classifications on the first batch of reviews from the IMDB test set
# using a beam size of 1.
#

batch = next(iter(imdb_dataloader))
input_text = batch["text"]
target = batch["label"]
beam_size = 1

model_input = transform(input_text)
model_output = sequence_generator.generate(model_input, eos_idx=eos_idx, num_beams=beam_size)
output_text = transform.decode(model_output.tolist())

for i in range(imdb_batch_size):
    print(f"Example {i+1}:\n")
    print(f"input_text: {input_text[i]}\n")
    print(f"prediction: {output_text[i]}\n")
    print(f"target: {target[i]}\n\n")


#######################################################################
# Sentiment Output
# ----------------
#
# .. code-block:: bash
#
#    Example 1:
#
#    input_text: sst2 sentence: I love sci-fi and am willing to put up with a lot. Sci-fi
#    movies/TV are usually underfunded, under-appreciated and misunderstood. I tried to like
#    this, I really did, but it is to good TV sci-fi as Babylon 5 is to Star Trek (the original).
#    Silly prosthetics, cheap cardboard sets, stilted dialogues, CG that doesn't match the
#    background, and painfully one-dimensional characters cannot be overcome with a 'sci-fi'
#    setting. (I'm sure there are those of you out there who think Babylon 5 is good sci-fi TV.
#    It's not. It's clichéd and uninspiring.) While US viewers might like emotion and character
#    development, sci-fi is a genre that does not take itself seriously (cf. Star Trek). It may
#    treat important issues, yet not as a serious philosophy. It's really difficult to care about
#    the characters here as they are not simply foolish, just missing a spark of life. Their
#    actions and reactions are wooden and predictable, often painful to watch. The makers of Earth
#    KNOW it's rubbish as they have to always say "Gene Roddenberry's Earth..." otherwise people
#    would not continue watching. Roddenberry's ashes must be turning in their orbit as this dull,
#    cheap, poorly edited (watching it without advert breaks really brings this home) trudging
#    Trabant of a show lumbers into space. Spoiler. So, kill off a main character. And then bring
#    him back as another actor. Jeeez. Dallas all over again.
#
#    prediction: negative
#
#    target: negative
#
#
#    Example 2:
#
#    input_text: sst2 sentence: Worth the entertainment value of a rental, especially if you like
#    action movies. This one features the usual car chases, fights with the great Van Damme kick
#    style, shooting battles with the 40 shell load shotgun, and even terrorist style bombs. All
#    of this is entertaining and competently handled but there is nothing that really blows you
#    away if you've seen your share before.<br /><br />The plot is made interesting by the
#    inclusion of a rabbit, which is clever but hardly profound. Many of the characters are
#    heavily stereotyped -- the angry veterans, the terrified illegal aliens, the crooked cops,
#    the indifferent feds, the bitchy tough lady station head, the crooked politician, the fat
#    federale who looks like he was typecast as the Mexican in a Hollywood movie from the 1940s.
#    All passably acted but again nothing special.<br /><br />I thought the main villains were
#    pretty well done and fairly well acted. By the end of the movie you certainly knew who the
#    good guys were and weren't. There was an emotional lift as the really bad ones got their just
#    deserts. Very simplistic, but then you weren't expecting Hamlet, right? The only thing I found
#    really annoying was the constant cuts to VDs daughter during the last fight scene.<br /><br />
#    Not bad. Not good. Passable 4.
#
#    prediction: positive
#
#    target: negative
#
#
#    Example 3:
#
#    input_text: sst2 sentence: its a totally average film with a few semi-alright action sequences
#    that make the plot seem a little better and remind the viewer of the classic van dam films.
#    parts of the plot don't make sense and seem to be added in to use up time. the end plot is that
#    of a very basic type that doesn't leave the viewer guessing and any twists are obvious from the
#    beginning. the end scene with the flask backs don't make sense as they are added in and seem to
#    have little relevance to the history of van dam's character. not really worth watching again,
#    bit disappointed in the end production, even though it is apparent it was shot on a low budget
#    certain shots and sections in the film are of poor directed quality.
#
#    prediction: negative
#
#    target: negative
#


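#######################################################################
# Because the model emits its classification as text, agreement with the target labels can be checked by plain string
# comparison. The snippet below is a small illustration that hard-codes the three predictions and targets printed
# above; in the tutorial these values would come from ``output_text`` and ``target``. Normalizing with
# ``strip().lower()`` is a precaution assumed here, not something the tutorial code does.
#

```python
# Predictions and targets copied from the sample output above.
predictions = ["negative", "positive", "negative"]
targets = ["negative", "negative", "negative"]

# Count exact matches after trimming whitespace and lowercasing the prediction.
correct = sum(p.strip().lower() == t for p, t in zip(predictions, targets))
accuracy = correct / len(targets)

print(f"{correct}/{len(targets)} correct ({accuracy:.0%})")  # 2/3 correct (67%)
```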
#######################################################################
# Generate Translations
# ---------------------
#
# Finally, we can also use the model to generate English to German translations on the first batch of examples from the Multi30k
# test set, reusing the ``beam_size`` of 1 defined above.
#

batch = next(iter(multi_dataloader))
input_text = batch["english"]
target = batch["german"]

model_input = transform(input_text)
model_output = sequence_generator.generate(model_input, eos_idx=eos_idx, num_beams=beam_size)
output_text = transform.decode(model_output.tolist())

for i in range(multi_batch_size):
    print(f"Example {i+1}:\n")
    print(f"input_text: {input_text[i]}\n")
    print(f"prediction: {output_text[i]}\n")
    print(f"target: {target[i]}\n\n")


#######################################################################
# Translation Output
# ------------------
#
# .. code-block:: bash
#
#    Example 1:
#
#    input_text: translate English to German: A man in an orange hat starring at something.
#
#    prediction: Ein Mann in einem orangen Hut, der an etwas schaut.
#
#    target: Ein Mann mit einem orangefarbenen Hut, der etwas anstarrt.
#
#
#    Example 2:
#
#    input_text: translate English to German: A Boston Terrier is running on lush green grass in front of a white fence.
#
#    prediction: Ein Boston Terrier läuft auf üppigem grünem Gras vor einem weißen Zaun.
#
#    target: Ein Boston Terrier läuft über saftig-grünes Gras vor einem weißen Zaun.
#
#
#    Example 3:
#
#    input_text: translate English to German: A girl in karate uniform breaking a stick with a front kick.
#
#    prediction: Ein Mädchen in Karate-Uniform bricht einen Stöck mit einem Frontkick.
#
#    target: Ein Mädchen in einem Karateanzug bricht ein Brett mit einem Tritt.
#
#
#    Example 4:
#
#    input_text: translate English to German: Five people wearing winter jackets and helmets stand in the snow, with snowmobiles in the background.
#
#    prediction: Fünf Menschen mit Winterjacken und Helmen stehen im Schnee, mit Schneemobilen im Hintergrund.
#
#    target: Fünf Leute in Winterjacken und mit Helmen stehen im Schnee mit Schneemobilen im Hintergrund.
#
#
#    Example 5:
#
#    input_text: translate English to German: People are fixing the roof of a house.
#
#    prediction: Die Leute fixieren das Dach eines Hauses.
#
#    target: Leute Reparieren das Dach eines Hauses.
#