GitHub Repository: keras-team/keras-io
Path: blob/master/examples/generative/text_generation_gpt.py
"""
Title: GPT text generation from scratch with KerasHub
Author: [Jesse Chan](https://github.com/jessechancy)
Date created: 2022/07/25
Last modified: 2022/07/25
Description: Using KerasHub to train a mini-GPT model for text generation.
Accelerator: GPU
"""

"""
## Introduction

In this example, we will use KerasHub to build a scaled-down Generative Pre-trained
Transformer (GPT) model. GPT is a Transformer-based model that allows you to generate
sophisticated text from a prompt.

We will train the model on the [simplebooks-92](https://arxiv.org/abs/1911.12391) corpus,
which is a dataset made from several novels. It is a good dataset for this example since
it has a small vocabulary and high word frequency, which is beneficial when training a
model with few parameters.

This example combines concepts from
[Text generation with a miniature GPT](https://keras.io/examples/generative/text_generation_with_miniature_gpt/)
with KerasHub abstractions. We will demonstrate how KerasHub tokenization, layers, and
metrics simplify the training process, and then show how to generate output text using
the KerasHub sampling utilities.

Note: If you are running this example on Colab,
make sure to enable the GPU runtime for faster training.

This example requires KerasHub. You can install it via the following command:
`pip install keras-hub`
"""

"""
## Setup
"""

"""shell
pip install -q --upgrade keras-hub
pip install -q --upgrade keras  # Upgrade to Keras 3.
"""

import os
import keras_hub
import keras

import tensorflow.data as tf_data
import tensorflow.strings as tf_strings

"""
## Settings & hyperparameters
"""

# Data
BATCH_SIZE = 64
MIN_STRING_LEN = 512  # Strings shorter than this will be discarded
SEQ_LEN = 128  # Length of training sequences, in tokens

# Model
EMBED_DIM = 256
FEED_FORWARD_DIM = 128
NUM_HEADS = 3
NUM_LAYERS = 2
VOCAB_SIZE = 5000  # Limits parameters in model.

# Training
EPOCHS = 5

# Inference
NUM_TOKENS_TO_GENERATE = 80

"""
## Load the data

Now, let's download the dataset! The SimpleBooks dataset consists of 1,573 Gutenberg books, and has
one of the smallest ratios of vocabulary size to word-level tokens. Its vocabulary size is ~98k,
a third of WikiText-103's, with around the same number of tokens (~100M). This makes it easy to fit a small model.
"""

keras.utils.get_file(
    origin="https://dldata-public.s3.us-east-2.amazonaws.com/simplebooks.zip",
    extract=True,
)
dir = os.path.expanduser("~/.keras/datasets/simplebooks/")

# Load simplebooks-92 train set and filter out short lines.
raw_train_ds = (
    tf_data.TextLineDataset(dir + "simplebooks-92-raw/train.txt")
    .filter(lambda x: tf_strings.length(x) > MIN_STRING_LEN)
    .batch(BATCH_SIZE)
    .shuffle(buffer_size=256)
)

# Load simplebooks-92 validation set and filter out short lines.
raw_val_ds = (
    tf_data.TextLineDataset(dir + "simplebooks-92-raw/valid.txt")
    .filter(lambda x: tf_strings.length(x) > MIN_STRING_LEN)
    .batch(BATCH_SIZE)
)

"""
## Train the tokenizer

We train the tokenizer from the training dataset for a vocabulary size of `VOCAB_SIZE`,
which is a tuned hyperparameter. We want to limit the vocabulary as much as possible,
since, as we will see later, it has a large effect on the number of model parameters.
At the same time, we don't want to include *too few* vocabulary terms, or there would be
too many out-of-vocabulary (OOV) sub-words. In addition, three tokens are reserved in
the vocabulary:

- `"[PAD]"` for padding sequences to `SEQ_LEN`. This token has index 0 in both
`reserved_tokens` and `vocab`, since `WordPieceTokenizer` (and other layers) consider
`0`/`vocab[0]` as the default padding.
- `"[UNK]"` for OOV sub-words, which should match the default `oov_token="[UNK]"` in
`WordPieceTokenizer`.
- `"[BOS]"` stands for beginning of sentence, but here technically it is a token
representing the beginning of each line of training data.
"""

# Train tokenizer vocabulary
vocab = keras_hub.tokenizers.compute_word_piece_vocabulary(
    raw_train_ds,
    vocabulary_size=VOCAB_SIZE,
    lowercase=True,
    reserved_tokens=["[PAD]", "[UNK]", "[BOS]"],
)
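
"""
As a quick sanity check (an illustrative aside, not part of the original example), we can
peek at the learned vocabulary: the reserved tokens should occupy the first indices,
followed by the most common sub-words.
"""

# Illustrative check: reserved tokens come first, then learned sub-words.
print("Vocabulary size:", len(vocab))
print("First five entries:", vocab[:5])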

"""
## Load tokenizer

We use the vocabulary data to initialize
`keras_hub.tokenizers.WordPieceTokenizer`. WordPieceTokenizer is an efficient
implementation of the WordPiece algorithm used by BERT and other models. It strips and
lower-cases the text, and applies other irreversible preprocessing operations.
"""

tokenizer = keras_hub.tokenizers.WordPieceTokenizer(
    vocabulary=vocab,
    sequence_length=SEQ_LEN,
    lowercase=True,
)
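
"""
To see what the tokenizer does, we can run a short string through a tokenize/detokenize
round trip (an illustrative aside, not part of the original example; the exact sub-word
splits depend on the vocabulary learned above).
"""

# Illustrative round trip: the tokenizer lower-cases the text and pads the
# output to SEQ_LEN with id 0, so tokenization is not fully reversible.
sample = "The quick brown Fox ran away."
sample_ids = tokenizer(sample)
print("First 8 token IDs:", sample_ids[:8])
print("Detokenized:", tokenizer.detokenize(sample_ids[:8]))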

"""
## Tokenize data

We preprocess the dataset by tokenizing and splitting it into `features` and `labels`.
"""

# packer adds a start token
start_packer = keras_hub.layers.StartEndPacker(
    sequence_length=SEQ_LEN,
    start_value=tokenizer.token_to_id("[BOS]"),
)


def preprocess(inputs):
    outputs = tokenizer(inputs)
    features = start_packer(outputs)
    labels = outputs
    return features, labels


# Tokenize and split into train and label sequences.
train_ds = raw_train_ds.map(preprocess, num_parallel_calls=tf_data.AUTOTUNE).prefetch(
    tf_data.AUTOTUNE
)
val_ds = raw_val_ds.map(preprocess, num_parallel_calls=tf_data.AUTOTUNE).prefetch(
    tf_data.AUTOTUNE
)
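
"""
To confirm the next-token-prediction setup (an illustrative check, not part of the
original example), we can look at one batch: the features are the labels shifted one
position to the right, with the `"[BOS]"` token prepended.
"""

# Illustrative peek at one preprocessed batch.
for features, labels in train_ds.take(1):
    print("features[0, :5]:", features[0, :5])
    print("labels[0, :5]:  ", labels[0, :5])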

"""
## Build the model

We create our scaled-down GPT model with the following layers:

- One `keras_hub.layers.TokenAndPositionEmbedding` layer, which combines the embedding
for the token and its position.
- Multiple `keras_hub.layers.TransformerDecoder` layers, with the default causal masking.
The layer has no cross-attention when run with a decoder sequence only.
- One final dense linear layer that outputs the logits over the vocabulary.
"""

inputs = keras.layers.Input(shape=(None,), dtype="int32")
# Embedding.
embedding_layer = keras_hub.layers.TokenAndPositionEmbedding(
    vocabulary_size=VOCAB_SIZE,
    sequence_length=SEQ_LEN,
    embedding_dim=EMBED_DIM,
    mask_zero=True,
)
x = embedding_layer(inputs)
# Transformer decoders.
for _ in range(NUM_LAYERS):
    decoder_layer = keras_hub.layers.TransformerDecoder(
        num_heads=NUM_HEADS,
        intermediate_dim=FEED_FORWARD_DIM,
    )
    x = decoder_layer(x)  # Giving one argument only skips cross-attention.
# Output.
outputs = keras.layers.Dense(VOCAB_SIZE)(x)
model = keras.Model(inputs=inputs, outputs=outputs)
loss_fn = keras.losses.SparseCategoricalCrossentropy(from_logits=True)
perplexity = keras_hub.metrics.Perplexity(from_logits=True, mask_token_id=0)
model.compile(optimizer="adam", loss=loss_fn, metrics=[perplexity])

"""
Let's take a look at our model summary: a large majority of the
parameters are in the `token_and_position_embedding` layer and the output `dense` layer!
This means that the vocabulary size (`VOCAB_SIZE`) has a large effect on the size of the model,
while the number of Transformer decoder layers (`NUM_LAYERS`) doesn't affect it as much.
"""

model.summary()
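
"""
To make the claim above concrete, here is a rough back-of-the-envelope count for the two
largest layers (an illustrative aside that ignores the decoder blocks and layer norms):
both the embedding tables and the output projection scale with `VOCAB_SIZE`.
"""

# Rough parameter accounting for the two largest layers.
embedding_params = VOCAB_SIZE * EMBED_DIM + SEQ_LEN * EMBED_DIM  # token + position tables
output_params = EMBED_DIM * VOCAB_SIZE + VOCAB_SIZE  # dense kernel + bias
print("Token and position embedding parameters:", embedding_params)
print("Output dense parameters:", output_params)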

"""
## Training

Now that we have our model, let's train it with the `fit()` method.
"""

model.fit(train_ds, validation_data=val_ds, epochs=EPOCHS)

"""
## Inference

With our trained model, we can test it out to gauge its performance. To do this
we can seed our model with an input sequence starting with the `"[BOS]"` token,
and progressively sample the model by making predictions for each subsequent
token in a loop.

To start, let's build a prompt with the same shape as our model inputs, containing
only the `"[BOS]"` token.
"""

# The "packer" layer adds the [BOS] token for us.
prompt_tokens = start_packer(tokenizer([""]))
prompt_tokens

"""
We will use the `keras_hub.samplers` module for inference, which requires a
callback function wrapping the model we just trained. This wrapper calls
the model and returns the logit predictions for the current token we are
generating.

Note: There are two pieces of more advanced functionality available when
defining your callback. The first is the ability to take in a `cache` of states
computed in previous generation steps, which can be used to speed up generation.
The second is the ability to output the final dense "hidden state" of each
generated token. This is used by `keras_hub.samplers.ContrastiveSampler`, which
avoids repetition by penalizing repeated hidden states. Both are optional, and
we will ignore them for now.
"""


def next(prompt, cache, index):
    logits = model(prompt)[:, index - 1, :]
    # Ignore hidden states for now; only needed for contrastive search.
    hidden_states = None
    return logits, hidden_states, cache
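

"""
Before handing the callback to a sampler, we can check its contract directly (an
illustrative check, not part of the original example): for a given `index`, it should
return one row of logits per prompt in the batch.
"""

# Illustrative shape check: logits should have shape (batch_size, VOCAB_SIZE).
check_logits, _, _ = next(prompt_tokens, None, 1)
print("Logits shape:", check_logits.shape)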

"""
Creating the wrapper function is the most complex part of using these functions. Now that
it's done, let's test out the different utilities, starting with greedy search.
"""

"""
### Greedy search

We greedily pick the most probable token at each timestep. In other words, we get the
argmax of the model output.
"""

sampler = keras_hub.samplers.GreedySampler()
output_tokens = sampler(
    next=next,
    prompt=prompt_tokens,
    index=1,  # Start sampling immediately after the [BOS] token.
)
txt = tokenizer.detokenize(output_tokens)
print(f"Greedy search generated text: \n{txt}\n")

"""
As you can see, greedy search starts out making some sense, but quickly starts repeating
itself. This is a common problem with text generation that can be fixed by some of the
probabilistic text generation utilities shown later on!
"""

"""
### Beam search

At a high level, beam search keeps track of the `num_beams` most probable sequences at
each timestep, and predicts the best next token from all sequences. It is an improvement
over greedy search since it stores more possibilities. However, it is less efficient than
greedy search since it has to compute and store multiple potential sequences.

**Note:** beam search with `num_beams=1` is identical to greedy search.
"""

sampler = keras_hub.samplers.BeamSampler(num_beams=10)
output_tokens = sampler(
    next=next,
    prompt=prompt_tokens,
    index=1,
)
txt = tokenizer.detokenize(output_tokens)
print(f"Beam search generated text: \n{txt}\n")

"""
Similar to greedy search, beam search quickly starts repeating itself, since it is still
a deterministic method.
"""

"""
### Random search

Random search is our first probabilistic method. At each time step, it samples the next
token using the softmax probabilities provided by the model.
"""

sampler = keras_hub.samplers.RandomSampler()
output_tokens = sampler(
    next=next,
    prompt=prompt_tokens,
    index=1,
)
txt = tokenizer.detokenize(output_tokens)
print(f"Random search generated text: \n{txt}\n")

"""
Voilà, no repetitions! However, with random search, we may see some nonsensical words
appearing, since any word in the vocabulary has a chance of appearing with this sampling
method. This is fixed by our next search utility, top-k search.
"""

"""
### Top-K search

Similar to random search, we sample the next token from the probability distribution
provided by the model. The only difference is that here, we select the top `k` most
probable tokens and redistribute the probability mass over them before sampling. This way,
we won't be sampling from low-probability tokens, and hence we will have fewer
nonsensical words!
"""

sampler = keras_hub.samplers.TopKSampler(k=10)
output_tokens = sampler(
    next=next,
    prompt=prompt_tokens,
    index=1,
)
txt = tokenizer.detokenize(output_tokens)
print(f"Top-K search generated text: \n{txt}\n")

"""
### Top-P search

Even with top-k search, there is something to improve upon. With top-k search, the
number `k` is fixed, which means it selects the same number of tokens for any probability
distribution. Consider two scenarios: one where the probability mass is concentrated over
2 words, and another where it is spread evenly across 10. Should we choose `k=2` or
`k=10`? There is no one size that fits all `k` here.

This is where top-p search comes in! Instead of choosing a `k`, we choose a probability
`p` that we want the probabilities of the top tokens to sum up to. This way, we can
dynamically adjust `k` based on the probability distribution. By setting `p=0.9`, if
90% of the probability mass is concentrated on the top 2 tokens, we sample from only
those 2 tokens. If instead the 90% is spread over the top 10 tokens, sampling will
similarly be restricted to those 10 tokens.
"""

sampler = keras_hub.samplers.TopPSampler(p=0.5)
output_tokens = sampler(
    next=next,
    prompt=prompt_tokens,
    index=1,
)
txt = tokenizer.detokenize(output_tokens)
print(f"Top-P search generated text: \n{txt}\n")

"""
### Using callbacks for text generation

We can also wrap the utilities in a callback, which allows you to print out a prediction
sequence at the end of every training epoch! Here is an example of a callback for top-k search:
"""


class TopKTextGenerator(keras.callbacks.Callback):
    """A callback to generate text from a trained model using top-k."""

    def __init__(self, k):
        self.sampler = keras_hub.samplers.TopKSampler(k)

    def on_epoch_end(self, epoch, logs=None):
        output_tokens = self.sampler(
            next=next,
            prompt=prompt_tokens,
            index=1,
        )
        txt = tokenizer.detokenize(output_tokens)
        print(f"Top-K search generated text: \n{txt}\n")


text_generation_callback = TopKTextGenerator(k=10)
# Dummy training loop to demonstrate callback.
model.fit(train_ds.take(1), verbose=2, epochs=2, callbacks=[text_generation_callback])

"""
## Conclusion

To recap, in this example, we used KerasHub layers to train a sub-word vocabulary,
tokenize training data, create a miniature GPT model, and perform inference with the
text generation library.

If you would like to understand how Transformers work, or learn more about training the
full GPT model, here are some further readings:

- Attention Is All You Need [Vaswani et al., 2017](https://arxiv.org/abs/1706.03762)
- GPT-3 Paper [Brown et al., 2020](https://arxiv.org/abs/2005.14165)
"""