"""
2
Title: Text generation with a miniature GPT
3
Author: [Apoorv Nandan](https://twitter.com/NandanApoorv)
4
Date created: 2020/05/29
5
Last modified: 2020/05/29
6
Description: Implement a miniature version of GPT and train it to generate text.
7
Accelerator: GPU
8
"""
9
"""
## Introduction

This example demonstrates how to implement an autoregressive language model
using a miniature version of the GPT model.
The model consists of a single Transformer block with causal masking
in its attention layer.
We use the text from the IMDB sentiment classification dataset for training
and generate new movie reviews for a given prompt.
When using this script with your own dataset, make sure it has at least
1 million words.

This example should be run with TensorFlow 2.3 or higher.

**References:**

- [GPT](https://www.semanticscholar.org/paper/Improving-Language-Understanding-by-Generative-Radford/cd18800a0fe0b668a1cc19f2ec95b5003d0a5035)
- [GPT-2](https://www.semanticscholar.org/paper/Language-Models-are-Unsupervised-Multitask-Learners-Radford-Wu/9405cc0d6169988371b2755e573cc28650d14dfe)
- [GPT-3](https://arxiv.org/abs/2005.14165)
"""

"""
32
## Setup
33
"""
34
# We set the backend to TensorFlow. The code works with
35
# both `tensorflow` and `torch`. It does not work with JAX
36
# due to the behavior of `jax.numpy.tile` in a jit scope
37
# (used in `causal_attention_mask()`: `tile` in JAX does
38
# not support a dynamic `reps` argument.
39
# You can make the code work in JAX by wrapping the
40
# inside of the `causal_attention_mask` function in
41
# a decorator to prevent jit compilation:
42
# `with jax.ensure_compile_time_eval():`.
43
import os

os.environ["KERAS_BACKEND"] = "tensorflow"

import keras
from keras import layers
from keras import ops
from keras.layers import TextVectorization
import numpy as np
import string
import random
import tensorflow
import tensorflow.data as tf_data
import tensorflow.strings as tf_strings


"""
61
## Implement a Transformer block as a layer
62
"""
63
64
65
def causal_attention_mask(batch_size, n_dest, n_src, dtype):
66
"""
67
Mask the upper half of the dot product matrix in self attention.
68
This prevents flow of information from future tokens to current token.
69
1's in the lower triangle, counting from the lower right corner.
70
"""
71
i = ops.arange(n_dest)[:, None]
72
j = ops.arange(n_src)
73
m = i >= j - n_src + n_dest
74
mask = ops.cast(m, dtype)
75
mask = ops.reshape(mask, [1, n_dest, n_src])
76
mult = ops.concatenate(
77
[ops.expand_dims(batch_size, -1), ops.convert_to_tensor([1, 1])], 0
78
)
79
return ops.tile(mask, mult)
80
81
82
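# As a small illustration, build the mask for a batch size of 1 and a sequence
# length of 4 (these numbers are arbitrary): position i can only attend to
# positions 0..i, so the matrix is lower triangular.
demo_mask = causal_attention_mask(1, 4, 4, "int32")
print(ops.convert_to_numpy(demo_mask)[0])
# [[1 0 0 0]
#  [1 1 0 0]
#  [1 1 1 0]
#  [1 1 1 1]]
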
class TransformerBlock(layers.Layer):
    def __init__(self, embed_dim, num_heads, ff_dim, rate=0.1):
        super().__init__()
        self.att = layers.MultiHeadAttention(num_heads, embed_dim)
        self.ffn = keras.Sequential(
            [
                layers.Dense(ff_dim, activation="relu"),
                layers.Dense(embed_dim),
            ]
        )
        self.layernorm1 = layers.LayerNormalization(epsilon=1e-6)
        self.layernorm2 = layers.LayerNormalization(epsilon=1e-6)
        self.dropout1 = layers.Dropout(rate)
        self.dropout2 = layers.Dropout(rate)

    def call(self, inputs):
        input_shape = ops.shape(inputs)
        batch_size = input_shape[0]
        seq_len = input_shape[1]
        causal_mask = causal_attention_mask(batch_size, seq_len, seq_len, "bool")
        attention_output = self.att(inputs, inputs, attention_mask=causal_mask)
        attention_output = self.dropout1(attention_output)
        out1 = self.layernorm1(inputs + attention_output)
        ffn_output = self.ffn(out1)
        ffn_output = self.dropout2(ffn_output)
        return self.layernorm2(out1 + ffn_output)

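# A quick shape check with arbitrary small dimensions: the block maps a batch
# of embedded sequences to an output of the same shape.
demo_block = TransformerBlock(embed_dim=8, num_heads=2, ff_dim=16)
print(demo_block(ops.ones((2, 5, 8))).shape)  # (2, 5, 8)
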
"""
111
## Implement an embedding layer
112
113
Create two separate embedding layers: one for tokens and one for token index
114
(positions).
115
"""
116
117
118
class TokenAndPositionEmbedding(layers.Layer):
119
def __init__(self, maxlen, vocab_size, embed_dim):
120
super().__init__()
121
self.token_emb = layers.Embedding(input_dim=vocab_size, output_dim=embed_dim)
122
self.pos_emb = layers.Embedding(input_dim=maxlen, output_dim=embed_dim)
123
124
def call(self, x):
125
maxlen = ops.shape(x)[-1]
126
positions = ops.arange(0, maxlen, 1)
127
positions = self.pos_emb(positions)
128
x = self.token_emb(x)
129
return x + positions
130
131
132
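# An illustrative shape check with arbitrary token ids and small dimensions:
# each position in each sequence gets a single vector that combines its token
# embedding and its position embedding.
demo_embedding = TokenAndPositionEmbedding(maxlen=5, vocab_size=100, embed_dim=8)
demo_ids = ops.convert_to_tensor([[3, 7, 1, 0, 0], [5, 5, 9, 2, 4]])
print(demo_embedding(demo_ids).shape)  # (2, 5, 8)
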
"""
133
## Implement the miniature GPT model
134
"""
135
vocab_size = 20000 # Only consider the top 20k words
136
maxlen = 80 # Max sequence size
137
embed_dim = 256 # Embedding size for each token
138
num_heads = 2 # Number of attention heads
139
feed_forward_dim = 256 # Hidden layer size in feed forward network inside transformer
140
141
142
def create_model():
143
inputs = layers.Input(shape=(maxlen,), dtype="int32")
144
embedding_layer = TokenAndPositionEmbedding(maxlen, vocab_size, embed_dim)
145
x = embedding_layer(inputs)
146
transformer_block = TransformerBlock(embed_dim, num_heads, feed_forward_dim)
147
x = transformer_block(x)
148
outputs = layers.Dense(vocab_size)(x)
149
model = keras.Model(inputs=inputs, outputs=[outputs, x])
150
loss_fn = keras.losses.SparseCategoricalCrossentropy(from_logits=True)
151
model.compile(
152
"adam",
153
loss=[loss_fn, None],
154
) # No loss and optimization based on word embeddings from transformer block
155
return model
156
157
158
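# An illustrative check of the two model outputs: per-position token logits
# (the only output with a loss attached) and the transformer block activations.
# The expected shapes assume the hyperparameters defined above.
demo_logits, demo_activations = create_model()(ops.zeros((1, maxlen), dtype="int32"))
print(demo_logits.shape, demo_activations.shape)  # (1, 80, 20000) (1, 80, 256)
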
"""
159
## Prepare the data for word-level language modelling
160
161
Download the IMDB dataset and combine training and validation sets for a text
162
generation task.
163
"""
164
165
"""shell
166
curl -O https://ai.stanford.edu/~amaas/data/sentiment/aclImdb_v1.tar.gz
167
tar -xf aclImdb_v1.tar.gz
168
"""
169
170
171
batch_size = 128

# The dataset contains each review in a separate text file
# The text files are present in four different folders
# Create a list of all files
filenames = []
directories = [
    "aclImdb/train/pos",
    "aclImdb/train/neg",
    "aclImdb/test/pos",
    "aclImdb/test/neg",
]
for dir in directories:
    for f in os.listdir(dir):
        filenames.append(os.path.join(dir, f))

print(f"{len(filenames)} files")

# Create a dataset from text files
random.shuffle(filenames)
text_ds = tf_data.TextLineDataset(filenames)
text_ds = text_ds.shuffle(buffer_size=256)
text_ds = text_ds.batch(batch_size)


def custom_standardization(input_string):
    """Remove HTML line-break tags and handle punctuation."""
    lowercased = tf_strings.lower(input_string)
    stripped_html = tf_strings.regex_replace(lowercased, "<br />", " ")
    return tf_strings.regex_replace(stripped_html, f"([{string.punctuation}])", r" \1")

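# For example, with an arbitrary made-up review, standardization lowercases the
# text, strips the <br /> tags, and puts a space before each punctuation mark so
# that punctuation becomes its own token:
print(custom_standardization(tensorflow.constant("Great movie!<br />Loved it.")).numpy())
# b'great movie ! loved it .'
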
# Create a vectorization layer and adapt it to the text
vectorize_layer = TextVectorization(
    standardize=custom_standardization,
    max_tokens=vocab_size - 1,
    output_mode="int",
    output_sequence_length=maxlen + 1,
)
vectorize_layer.adapt(text_ds)
vocab = vectorize_layer.get_vocabulary()  # To get words back from token indices


def prepare_lm_inputs_labels(text):
    """
    Shift word sequences by 1 position so that the target for position (i) is
    the word at position (i+1). The model will use all words up till position (i)
    to predict the next word.
    """
    text = tensorflow.expand_dims(text, -1)
    tokenized_sentences = vectorize_layer(text)
    x = tokenized_sentences[:, :-1]
    y = tokenized_sentences[:, 1:]
    return x, y

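# As a small illustration with a single made-up review, the labels are the same
# token ids as the inputs, shifted one position to the left:
demo_x, demo_y = prepare_lm_inputs_labels(tensorflow.constant(["this movie is great"]))
print(demo_x.shape, demo_y.shape)  # (1, 80) (1, 80)
print(bool(ops.all(demo_x[:, 1:] == demo_y[:, :-1])))  # True
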
text_ds = text_ds.map(prepare_lm_inputs_labels, num_parallel_calls=tf_data.AUTOTUNE)
text_ds = text_ds.prefetch(tf_data.AUTOTUNE)


"""
## Implement a Keras callback for generating text
"""

class TextGenerator(keras.callbacks.Callback):
    """A callback to generate text from a trained model.
    1. Feed some starting prompt to the model
    2. Predict probabilities for the next token
    3. Sample the next token and add it to the next input

    Arguments:
        max_tokens: Integer, the number of tokens to be generated after prompt.
        start_tokens: List of integers, the token indices for the starting prompt.
        index_to_word: List of strings, obtained from the TextVectorization layer.
        top_k: Integer, sample from the `top_k` token predictions.
        print_every: Integer, print after this many epochs.
    """

    def __init__(
        self, max_tokens, start_tokens, index_to_word, top_k=10, print_every=1
    ):
        self.max_tokens = max_tokens
        self.start_tokens = start_tokens
        self.index_to_word = index_to_word
        self.print_every = print_every
        self.k = top_k

    def sample_from(self, logits):
        logits, indices = ops.top_k(logits, k=self.k, sorted=True)
        indices = np.asarray(indices).astype("int32")
        preds = keras.activations.softmax(ops.expand_dims(logits, 0))[0]
        preds = np.asarray(preds).astype("float32")
        return np.random.choice(indices, p=preds)

    def detokenize(self, number):
        return self.index_to_word[number]

    def on_epoch_end(self, epoch, logs=None):
        start_tokens = [_ for _ in self.start_tokens]
        if (epoch + 1) % self.print_every != 0:
            return
        num_tokens_generated = 0
        tokens_generated = []
        while num_tokens_generated <= self.max_tokens:
            pad_len = maxlen - len(start_tokens)
            sample_index = len(start_tokens) - 1
            if pad_len < 0:
                x = start_tokens[:maxlen]
                sample_index = maxlen - 1
            elif pad_len > 0:
                x = start_tokens + [0] * pad_len
            else:
                x = start_tokens
            x = np.array([x])
            y, _ = self.model.predict(x, verbose=0)
            sample_token = self.sample_from(y[0][sample_index])
            tokens_generated.append(sample_token)
            start_tokens.append(sample_token)
            num_tokens_generated = len(tokens_generated)
        txt = " ".join(
            [self.detokenize(_) for _ in self.start_tokens + tokens_generated]
        )
        print(f"generated text:\n{txt}\n")

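# A quick illustration of the top-k sampling step with made-up logits: with
# top_k=2, only the indices of the two largest logits (0 and 1 here) can ever
# be drawn, with probabilities given by a softmax over those two logits.
demo_generator = TextGenerator(max_tokens=1, start_tokens=[1], index_to_word=vocab, top_k=2)
print(demo_generator.sample_from(np.array([2.0, 1.0, 0.5, -1.0], dtype="float32")))  # 0 or 1
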
# Tokenize starting prompt
word_to_index = {}
for index, word in enumerate(vocab):
    word_to_index[word] = index

start_prompt = "this movie is"
start_tokens = [word_to_index.get(_, 1) for _ in start_prompt.split()]
num_tokens_generated = 40
text_gen_callback = TextGenerator(num_tokens_generated, start_tokens, vocab)

"""
309
## Train the model
310
311
Note: This code should preferably be run on GPU.
312
"""
313
314
model = create_model()
315
316
model.fit(text_ds, verbose=2, epochs=25, callbacks=[text_gen_callback])
317
318