"""
Title: GPT2 Text Generation with KerasHub
Author: Chen Qian
Date created: 2023/04/17
Last modified: 2024/04/12
Description: Use KerasHub GPT2 model and `samplers` to do text generation.
Accelerator: GPU
"""

"""
11
In this tutorial, you will learn to use [KerasHub](https://keras.io/keras_hub/) to load a
12
pre-trained Large Language Model (LLM) - [GPT-2 model](https://openai.com/research/better-language-models)
13
(originally invented by OpenAI), finetune it to a specific text style, and
14
generate text based on users' input (also known as prompt). You will also learn
15
how GPT2 adapts quickly to non-English languages, such as Chinese.
16
"""
17
18
"""
19
## Before we begin
20
21
Colab offers different kinds of runtimes. Make sure to go to **Runtime ->
22
Change runtime type** and choose the GPU Hardware Accelerator runtime
23
(which should have >12G host RAM and ~15G GPU RAM) since you will finetune the
24
GPT-2 model. Running this tutorial on CPU runtime will take hours.
25
"""
26
27
"""
28
## Install KerasHub, Choose Backend and Import Dependencies
29
30
This examples uses [Keras 3](https://keras.io/keras_3/) to work in any of
31
`"tensorflow"`, `"jax"` or `"torch"`. Support for Keras 3 is baked into
32
KerasHub, simply change the `"KERAS_BACKEND"` environment variable to select
33
the backend of your choice. We select the JAX backend below.
34
"""
35
36
"""shell
37
pip install git+https://github.com/keras-team/keras-hub.git -q
38
"""
39
40
import os

os.environ["KERAS_BACKEND"] = "jax"  # or "tensorflow" or "torch"

import keras_hub
import keras
import tensorflow as tf
import time

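# Run computations in float16 while keeping variables in float32; mixed
# precision speeds up training and generation and reduces memory use.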
keras.mixed_precision.set_global_policy("mixed_float16")

"""
## Introduction to Generative Large Language Models (LLMs)

Large language models (LLMs) are a type of machine learning model trained on a
large corpus of text data to generate outputs for various natural
language processing (NLP) tasks, such as text generation, question answering,
and machine translation.

Generative LLMs are typically based on deep learning neural networks, such as
the [Transformer architecture](https://arxiv.org/abs/1706.03762) invented by
Google researchers in 2017, and are trained on massive amounts of text data,
often involving billions of words. These models, such as Google [LaMDA](https://blog.google/technology/ai/lamda/)
and [PaLM](https://ai.googleblog.com/2022/04/pathways-language-model-palm-scaling-to.html),
are trained on large datasets from a variety of sources, which allows them to
generate output for many tasks. The core of generative LLMs is predicting the
next word in a sentence, often referred to as **Causal LM Pretraining**. In this
way LLMs can generate coherent text based on user prompts. For a more
pedagogical discussion on language models, you can refer to the
[Stanford CS324 LLM class](https://stanford-cs324.github.io/winter2022/lectures/introduction/).
"""

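"""
To make "predicting the next word" concrete: in causal LM pretraining, the
features and labels are the same token sequence shifted by one position, so
each token is trained to predict the token that follows it. Below is a minimal
sketch of that shift (the token IDs are made up for illustration; a real
pipeline would produce them with a tokenizer):
"""

token_ids = [5, 42, 17, 23, 9]
features = token_ids[:-1]  # input to the model
labels = token_ids[1:]  # targets: the same sequence shifted left by one
print(features, labels)
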
"""
## Introduction to KerasHub

Large Language Models are complex to build and expensive to train from scratch.
Luckily there are pretrained LLMs available for use right away. [KerasHub](https://keras.io/keras_hub/)
provides a large number of pre-trained checkpoints that allow you to experiment
with SOTA models without needing to train them yourself.

KerasHub is a natural language processing library that supports users through
their entire development cycle. KerasHub offers both pretrained models and
modularized building blocks, so developers can easily reuse pretrained models
or stack together their own LLMs.

In a nutshell, for generative LLMs, KerasHub offers:

- Pretrained models with a `generate()` method, e.g.,
  `keras_hub.models.GPT2CausalLM` and `keras_hub.models.OPTCausalLM`.
- `Sampler` classes that implement generation algorithms such as top-k, beam, and
  contrastive search. These samplers can be used to generate text with
  custom models.
"""

"""
95
## Load a pre-trained GPT-2 model and generate some text
96
97
KerasHub provides a number of pre-trained models, such as [Google
98
Bert](https://ai.googleblog.com/2018/11/open-sourcing-bert-state-of-art-pre.html)
99
and [GPT-2](https://openai.com/research/better-language-models). You can see
100
the list of models available in the [KerasHub repository](https://github.com/keras-team/keras-hub/tree/master/keras_hub/models).
101
102
It's very easy to load the GPT-2 model as you can see below:
103
"""
104
105
# To speed up training and generation, we use a preprocessor with a sequence
# length of 128 instead of the full length of 1024.
preprocessor = keras_hub.models.GPT2CausalLMPreprocessor.from_preset(
    "gpt2_base_en",
    sequence_length=128,
)
gpt2_lm = keras_hub.models.GPT2CausalLM.from_preset(
    "gpt2_base_en", preprocessor=preprocessor
)

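"""
A `GPT2CausalLM` is a regular `keras.Model`, so if you want a quick look at
what you just loaded, the standard Keras introspection tools apply:
"""

gpt2_lm.summary()
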
"""
116
Once the model is loaded, you can use it to generate some text right away. Run
117
the cells below to give it a try. It's as simple as calling a single function
118
*generate()*:
119
"""
120
121
start = time.time()

output = gpt2_lm.generate("My trip to Yosemite was", max_length=200)
print("\nGPT-2 output:")
print(output)

end = time.time()
print(f"TOTAL TIME ELAPSED: {end - start:.2f}s")

"""
131
Try another one:
132
"""
133
134
start = time.time()

output = gpt2_lm.generate("That Italian restaurant is", max_length=200)
print("\nGPT-2 output:")
print(output)

end = time.time()
print(f"TOTAL TIME ELAPSED: {end - start:.2f}s")

"""
144
Notice how much faster the second call is. This is because the computational
145
graph is [XLA compiled](https://www.tensorflow.org/xla) in the 1st run and
146
re-used in the 2nd behind the scenes.
147
148
The quality of the generated text looks OK, but we can improve it via
149
fine-tuning.
150
"""
151
152
"""
153
## More on the GPT-2 model from KerasHub
154
155
Next up, we will actually fine-tune the model to update its parameters, but
156
before we do, let's take a look at the full set of tools we have to for working
157
with for GPT2.
158
159
The code of GPT2 can be found
160
[here](https://github.com/keras-team/keras-hub/blob/master/keras_hub/models/gpt2/).
161
Conceptually the `GPT2CausalLM` can be hierarchically broken down into several
162
modules in KerasHub, all of which have a *from_preset()* function that loads a
163
pretrained model:
164
165
- `keras_hub.models.GPT2Tokenizer`: The tokenizer used by GPT2 model, which is a
166
[byte-pair encoder](https://huggingface.co/course/chapter6/5?fw=pt).
167
- `keras_hub.models.GPT2CausalLMPreprocessor`: the preprocessor used by GPT2
168
causal LM training. It does the tokenization along with other preprocessing
169
works such as creating the label and appending the end token.
170
- `keras_hub.models.GPT2Backbone`: the GPT2 model, which is a stack of
171
`keras_hub.layers.TransformerDecoder`. This is usually just referred as
172
`GPT2`.
173
- `keras_hub.models.GPT2CausalLM`: wraps `GPT2Backbone`, it multiplies the
174
output of `GPT2Backbone` by embedding matrix to generate logits over
175
vocab tokens.
176
"""
177
178
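"""
As a quick illustration of the first two modules, we can pull the tokenizer out
of the model loaded earlier and round-trip a prompt through it. This is a
minimal sketch (the prompt is arbitrary): the tokenizer maps text to integer
token IDs, and `detokenize()` maps IDs back to text:
"""

tokenizer = gpt2_lm.preprocessor.tokenizer
token_ids = tokenizer("My trip to Yosemite was")
print(token_ids)
print(tokenizer.detokenize(token_ids))
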
"""
179
## Finetune on Reddit dataset
180
181
Now you have the knowledge of the GPT-2 model from KerasHub, you can take one
182
step further to finetune the model so that it generates text in a specific
183
style, short or long, strict or casual. In this tutorial, we will use reddit
184
dataset for example.
185
"""
186
187
import tensorflow_datasets as tfds

reddit_ds = tfds.load("reddit_tifu", split="train", as_supervised=True)

"""
192
Let's take a look inside sample data from the reddit TensorFlow Dataset. There
193
are two features:
194
195
- **__document__**: text of the post.
196
- **__title__**: the title.
197
198
"""
199
200
for document, title in reddit_ds:
    print(document.numpy())
    print(title.numpy())
    break

"""
206
In our case, we are performing next word prediction in a language model, so we
207
only need the 'document' feature.
208
"""
209
210
train_ds = (
    reddit_ds.map(lambda document, _: document)
    .batch(32)
    .cache()
    .prefetch(tf.data.AUTOTUNE)
)

"""
218
Now you can finetune the model using the familiar *fit()* function. Note that
219
`preprocessor` will be automatically called inside `fit` method since
220
`GPT2CausalLM` is a `keras_hub.models.Task` instance.
221
222
This step takes quite a bit of GPU memory and a long time if we were to train
223
it all the way to a fully trained state. Here we just use part of the dataset
224
for demo purposes.
225
"""
226
227
train_ds = train_ds.take(500)
num_epochs = 1

# Linearly decaying learning rate.
learning_rate = keras.optimizers.schedules.PolynomialDecay(
    5e-5,
    decay_steps=train_ds.cardinality() * num_epochs,
    end_learning_rate=0.0,
)
loss = keras.losses.SparseCategoricalCrossentropy(from_logits=True)
gpt2_lm.compile(
    optimizer=keras.optimizers.Adam(learning_rate),
    loss=loss,
    weighted_metrics=["accuracy"],
)

gpt2_lm.fit(train_ds, epochs=num_epochs)

"""
246
After fine-tuning is finished, you can again generate text using the same
247
*generate()* function. This time, the text will be closer to Reddit writing
248
style, and the generated length will be close to our preset length in the
249
training set.
250
"""
251
252
start = time.time()

output = gpt2_lm.generate("I like basketball", max_length=200)
print("\nGPT-2 output:")
print(output)

end = time.time()
print(f"TOTAL TIME ELAPSED: {end - start:.2f}s")

"""
262
## Into the Sampling Method
263
264
In KerasHub, we offer a few sampling methods, e.g., contrastive search,
265
Top-K and beam sampling. By default, our `GPT2CausalLM` uses Top-k search, but
266
you can choose your own sampling method.
267
268
Much like optimizer and activations, there are two ways to specify your custom
269
sampler:
270
271
- Use a string identifier, such as "greedy", you are using the default
272
configuration via this way.
273
- Pass a `keras_hub.samplers.Sampler` instance, you can use custom configuration
274
via this way.
275
"""
276
277
# Use a string identifier.
gpt2_lm.compile(sampler="top_k")
output = gpt2_lm.generate("I like basketball", max_length=200)
print("\nGPT-2 output:")
print(output)

# Use a `Sampler` instance. `GreedySampler` tends to repeat itself.
greedy_sampler = keras_hub.samplers.GreedySampler()
gpt2_lm.compile(sampler=greedy_sampler)

output = gpt2_lm.generate("I like basketball", max_length=200)
print("\nGPT-2 output:")
print(output)

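"""
A `Sampler` instance also lets you tweak the algorithm's configuration. As a
minimal sketch (the value `k=10` below is an arbitrary choice for illustration),
top-k search can be configured with a custom number of candidate tokens:
"""

# Sample from the 10 most likely tokens at each generation step.
gpt2_lm.compile(sampler=keras_hub.samplers.TopKSampler(k=10))
output = gpt2_lm.generate("I like basketball", max_length=200)
print("\nGPT-2 output:")
print(output)
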
"""
292
For more details on KerasHub `Sampler` class, you can check the code
293
[here](https://github.com/keras-team/keras-hub/tree/master/keras_hub/samplers).
294
"""
295
296
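"""
You can also subclass `keras_hub.samplers.Sampler` to implement your own
decoding strategy. The sketch below follows the pattern shown in the `Sampler`
docstring (we assume `get_next_token()` is the override point; see the linked
code for the exact contract) and re-implements greedy search by taking the
argmax of the next-token probability distribution:
"""

from keras import ops


class CustomGreedySampler(keras_hub.samplers.Sampler):
    """Picks the single most likely token at each step."""

    def get_next_token(self, probabilities):
        return ops.argmax(probabilities, axis=-1)


gpt2_lm.compile(sampler=CustomGreedySampler())
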
"""
297
## Finetune on Chinese Poem Dataset
298
299
We can also finetune GPT2 on non-English datasets. For readers knowing Chinese,
300
this part illustrates how to fine-tune GPT2 on Chinese poem dataset to teach our
301
model to become a poet!
302
303
Because GPT2 uses byte-pair encoder, and the original pretraining dataset
304
contains some Chinese characters, we can use the original vocab to finetune on
305
Chinese dataset.
306
"""
307
308
"""shell
309
# Load chinese poetry dataset.
310
git clone https://github.com/chinese-poetry/chinese-poetry.git
311
"""
312
313
"""
314
Load text from the json file. We only use《全唐诗》for demo purposes.
315
"""
316
317
import os
import json

poem_collection = []
for file in os.listdir("chinese-poetry/全唐诗"):
    if ".json" not in file or "poet" not in file:
        continue
    full_filename = "%s/%s" % ("chinese-poetry/全唐诗", file)
    with open(full_filename, "r") as f:
        content = json.load(f)
        poem_collection.extend(content)

paragraphs = ["".join(data["paragraphs"]) for data in poem_collection]

"""
332
Let's take a look at sample data.
333
"""
334
335
print(paragraphs[0])
336
337
"""
338
Similar as Reddit example, we convert to TF dataset, and only use partial data
339
to train.
340
"""
341
342
train_ds = (
    tf.data.Dataset.from_tensor_slices(paragraphs)
    .batch(16)
    .cache()
    .prefetch(tf.data.AUTOTUNE)
)

# Running through the whole dataset takes a long time, so we only take 500
# batches and run 1 epoch for demo purposes.
train_ds = train_ds.take(500)
num_epochs = 1

learning_rate = keras.optimizers.schedules.PolynomialDecay(
    5e-4,
    decay_steps=train_ds.cardinality() * num_epochs,
    end_learning_rate=0.0,
)
loss = keras.losses.SparseCategoricalCrossentropy(from_logits=True)
gpt2_lm.compile(
    optimizer=keras.optimizers.Adam(learning_rate),
    loss=loss,
    weighted_metrics=["accuracy"],
)

gpt2_lm.fit(train_ds, epochs=num_epochs)

"""
369
Let's check the result!
370
"""
371
372
output = gpt2_lm.generate("昨夜雨疏风骤", max_length=200)
373
print(output)
374
375
"""
376
Not bad 😀
377
"""
378
379