GPT text generation from scratch with KerasHub
Author: Jesse Chan
Date created: 2022/07/25
Last modified: 2022/07/25
Description: Using KerasHub to train a mini-GPT model for text generation.
Introduction
In this example, we will use KerasHub to build a scaled-down Generative Pre-trained Transformer (GPT) model. GPT is a Transformer-based model that allows you to generate sophisticated text from a prompt.
We will train the model on the simplebooks-92 corpus, which is a dataset made from several novels. It is a good dataset for this example since it has a small vocabulary and high word frequency, which is beneficial when training a model with few parameters.
This example combines concepts from Text generation with a miniature GPT with KerasHub abstractions. We will demonstrate how KerasHub tokenization, layers and metrics simplify the training process, and then show how to generate output text using the KerasHub sampling utilities.
Note: If you are running this example on a Colab, make sure to enable GPU runtime for faster training.
This example requires KerasHub. You can install it via the following command: pip install keras-hub
Setup
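The setup comes down to a few imports. A minimal sketch, assuming the TensorFlow backend (the data pipeline below relies on tf.data):

```python
import os

import keras
import keras_hub

import tensorflow.data as tf_data
import tensorflow.strings as tf_strings
```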
Settings & hyperparameters
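A sketch of the settings used throughout this example. SEQ_LEN, VOCAB_SIZE, NUM_LAYERS and EPOCHS are the names referenced in the text; the remaining names and all of the values are illustrative choices:

```python
# Data
BATCH_SIZE = 64
MIN_STRING_LEN = 512  # Strings shorter than this will be discarded
SEQ_LEN = 128  # Length of training sequences, in tokens

# Model
EMBED_DIM = 256
FEED_FORWARD_DIM = 128
NUM_HEADS = 3
NUM_LAYERS = 2
VOCAB_SIZE = 5000  # Limits parameters in model

# Training
EPOCHS = 5
```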
Load the data
Now, let's download the dataset! The SimpleBooks dataset consists of 1,573 Gutenberg books, and has one of the smallest ratios of vocabulary size to word-level tokens. It has a vocabulary size of ~98k, a third of WikiText-103's, with around the same number of tokens (~100M). This makes it easy to fit a small model.
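A sketch of the download and tf.data pipeline. The archive URL and directory layout below are taken from the published version of this example and may need adjusting:

```python
keras.utils.get_file(
    origin="https://dldata-public.s3.us-east-2.amazonaws.com/simplebooks.zip",
    extract=True,
)
data_dir = os.path.expanduser("~/.keras/datasets/simplebooks/")

# Load the simplebooks-92 train set and filter out short lines.
raw_train_ds = (
    tf_data.TextLineDataset(data_dir + "simplebooks-92-raw/train.txt")
    .filter(lambda x: tf_strings.length(x) > MIN_STRING_LEN)
    .batch(BATCH_SIZE)
    .shuffle(buffer_size=256)
)

# Load the simplebooks-92 validation set and filter out short lines.
raw_val_ds = (
    tf_data.TextLineDataset(data_dir + "simplebooks-92-raw/valid.txt")
    .filter(lambda x: tf_strings.length(x) > MIN_STRING_LEN)
    .batch(BATCH_SIZE)
)
```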
Train the tokenizer
We train the tokenizer on the training dataset for a vocabulary size of VOCAB_SIZE, which is a tuned hyperparameter. We want to limit the vocabulary as much as possible, since, as we will see later, it has a large effect on the number of model parameters. We also don't want to include too few vocabulary terms, or there would be too many out-of-vocabulary (OOV) sub-words. In addition, three tokens are reserved in the vocabulary:

- "[PAD]" for padding sequences to SEQ_LEN. This token has index 0 in both reserved_tokens and vocab, since WordPieceTokenizer (and other layers) consider 0/vocab[0] as the default padding.
- "[UNK]" for OOV sub-words, which should match the default oov_token="[UNK]" in WordPieceTokenizer.
- "[BOS]" stands for beginning of sentence, but here technically it is a token representing the beginning of each line of training data.
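A sketch of the vocabulary training step, using keras_hub.tokenizers.compute_word_piece_vocabulary on the raw training dataset with the reserved tokens listed above:

```python
# Train a WordPiece vocabulary on the raw training data.
vocab = keras_hub.tokenizers.compute_word_piece_vocabulary(
    raw_train_ds,
    vocabulary_size=VOCAB_SIZE,
    lowercase=True,
    reserved_tokens=["[PAD]", "[UNK]", "[BOS]"],
)
```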
Load tokenizer
We use the vocabulary data to initialize keras_hub.tokenizers.WordPieceTokenizer. WordPieceTokenizer is an efficient implementation of the WordPiece algorithm used by BERT and other models. It will strip, lower-case, and do other irreversible preprocessing operations.
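A sketch of the tokenizer construction, assuming the vocab computed above:

```python
tokenizer = keras_hub.tokenizers.WordPieceTokenizer(
    vocabulary=vocab,
    sequence_length=SEQ_LEN,
    lowercase=True,
)
```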
Tokenize data
We preprocess the dataset by tokenizing and splitting it into features and labels.
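A sketch of this preprocessing, assuming a keras_hub.layers.StartEndPacker is used to prepend the "[BOS]" token, so that the features are the labels shifted right by one position:

```python
# The packer prepends the [BOS] token to every sequence.
start_packer = keras_hub.layers.StartEndPacker(
    sequence_length=SEQ_LEN,
    start_value=tokenizer.token_to_id("[BOS]"),
)


def preprocess(inputs):
    outputs = tokenizer(inputs)
    features = start_packer(outputs)  # [BOS] + tokens, padded/truncated to SEQ_LEN
    labels = outputs  # The same tokens, one step ahead of the features
    return features, labels


# Tokenize and split into feature/label pairs.
train_ds = raw_train_ds.map(preprocess, num_parallel_calls=tf_data.AUTOTUNE).prefetch(
    tf_data.AUTOTUNE
)
val_ds = raw_val_ds.map(preprocess, num_parallel_calls=tf_data.AUTOTUNE).prefetch(
    tf_data.AUTOTUNE
)
```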
Build the model
We create our scaled down GPT model with the following layers:
- One keras_hub.layers.TokenAndPositionEmbedding layer, which combines the embedding for the token and its position.
- Multiple keras_hub.layers.TransformerDecoder layers, with the default causal masking. The layer has no cross-attention when run with a decoder sequence only.
- One final dense linear layer.
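A sketch of such a model, assuming the hyperparameter names from the settings above; perplexity is tracked as a metric, masking out the padding token (id 0):

```python
inputs = keras.layers.Input(shape=(None,), dtype="int32")
# Token and position embedding.
embedding_layer = keras_hub.layers.TokenAndPositionEmbedding(
    vocabulary_size=VOCAB_SIZE,
    sequence_length=SEQ_LEN,
    embedding_dim=EMBED_DIM,
    mask_zero=True,
)
x = embedding_layer(inputs)
# Stack of Transformer decoders (no cross-attention when called with one argument).
for _ in range(NUM_LAYERS):
    decoder_layer = keras_hub.layers.TransformerDecoder(
        num_heads=NUM_HEADS,
        intermediate_dim=FEED_FORWARD_DIM,
    )
    x = decoder_layer(x)
# Final linear projection to vocabulary logits.
outputs = keras.layers.Dense(VOCAB_SIZE)(x)
model = keras.Model(inputs=inputs, outputs=outputs)

loss_fn = keras.losses.SparseCategoricalCrossentropy(from_logits=True)
perplexity = keras_hub.metrics.Perplexity(from_logits=True, mask_token_id=0)
model.compile(optimizer="adam", loss=loss_fn, metrics=[perplexity])
model.summary()
```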
Let's take a look at our model summary - a large majority of the parameters are in the token_and_position_embedding and the output dense layer! This means that the vocabulary size (VOCAB_SIZE) has a large effect on the size of the model, while the number of Transformer decoder layers (NUM_LAYERS) doesn't affect it as much.
Training
Now that we have our model, let's train it with the fit() method.
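Assuming the datasets and EPOCHS defined above, that is simply:

```python
model.fit(train_ds, validation_data=val_ds, epochs=EPOCHS)
```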
Inference
With our trained model, we can test it out to gauge its performance. To do this we can seed our model with an input sequence starting with the "[BOS]" token, and progressively sample the model by making predictions for each subsequent token in a loop.

To start, let's build a prompt with the same shape as our model inputs, containing only the "[BOS]" token.
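A sketch, reusing the start_packer from preprocessing so the "[BOS]" token is added for us:

```python
# The packer adds the [BOS] token when given an empty string.
prompt_tokens = start_packer(tokenizer([""]))
```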
We will use the keras_hub.samplers module for inference, which requires a callback function wrapping the model we just trained. This wrapper calls the model and returns the logit predictions for the current token we are generating.

Note: There are two pieces of more advanced functionality available when defining your callback. The first is the ability to take in a cache of states computed in previous generation steps, which can be used to speed up generation. The second is the ability to output the final dense "hidden state" of each generated token. This is used by keras_hub.samplers.ContrastiveSampler, which avoids repetition by penalizing repeated hidden states. Both are optional, and we will ignore them for now.
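A minimal sketch of such a wrapper, ignoring both the cache and the hidden states:

```python
def next(prompt, cache, index):
    # Logits for the token at position `index - 1`, i.e. the next-token prediction.
    logits = model(prompt)[:, index - 1, :]
    # Ignore hidden states for now; only needed for contrastive search.
    hidden_states = None
    return logits, hidden_states, cache
```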
Creating the wrapper function is the most complex part of using these functions. Now that it's done, let's test out the different utilities, starting with greedy search.
Greedy search
We greedily pick the most probable token at each timestep. In other words, we get the argmax of the model output.
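A sketch using keras_hub.samplers.GreedySampler with the wrapper defined above:

```python
sampler = keras_hub.samplers.GreedySampler()
output_tokens = sampler(
    next=next,
    prompt=prompt_tokens,
    index=1,  # Start sampling immediately after the [BOS] token.
)
txt = tokenizer.detokenize(output_tokens)
print(f"Greedy search generated text: \n{txt}\n")
```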
As you can see, greedy search starts out making some sense, but quickly starts repeating itself. This is a common problem with text generation that can be fixed by some of the probabilistic text generation utilities shown later on!
Beam search
At a high level, beam search keeps track of the num_beams most probable sequences at each timestep, and predicts the best next token from all sequences. It is an improvement over greedy search since it stores more possibilities. However, it is less efficient than greedy search since it has to compute and store multiple potential sequences.

Note: beam search with num_beams=1 is identical to greedy search.
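A sketch using keras_hub.samplers.BeamSampler; num_beams=10 is an illustrative choice:

```python
sampler = keras_hub.samplers.BeamSampler(num_beams=10)
output_tokens = sampler(
    next=next,
    prompt=prompt_tokens,
    index=1,
)
txt = tokenizer.detokenize(output_tokens)
print(f"Beam search generated text: \n{txt}\n")
```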
Similar to greedy search, beam search quickly starts repeating itself, since it is still a deterministic method.
Random search
Random search is our first probabilistic method. At each time step, it samples the next token using the softmax probabilities provided by the model.
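A sketch using keras_hub.samplers.RandomSampler:

```python
sampler = keras_hub.samplers.RandomSampler()
output_tokens = sampler(
    next=next,
    prompt=prompt_tokens,
    index=1,
)
txt = tokenizer.detokenize(output_tokens)
print(f"Random search generated text: \n{txt}\n")
```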
Voilà, no repetitions! However, with random search, we may see some nonsensical words appearing, since any word in the vocabulary has a chance of appearing with this sampling method. This is fixed by our next search utility, top-k search.
Top-K search
Similar to random search, we sample the next token from the probability distribution provided by the model. The only difference is that here, we select the top k most probable tokens and distribute the probability mass over them before sampling. This way, we won't be sampling from low-probability tokens, and hence we will have fewer nonsensical words!
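A sketch using keras_hub.samplers.TopKSampler; k=10 is an illustrative choice:

```python
sampler = keras_hub.samplers.TopKSampler(k=10)
output_tokens = sampler(
    next=next,
    prompt=prompt_tokens,
    index=1,
)
txt = tokenizer.detokenize(output_tokens)
print(f"Top-K search generated text: \n{txt}\n")
```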
Top-P search
Even with top-k search, there is something to improve upon. With top-k search, the number k is fixed, which means it selects the same number of tokens for any probability distribution. Consider two scenarios: one where the probability mass is concentrated over 2 words, and another where it is spread evenly across 10. Should we choose k=2 or k=10? There is no one-size-fits-all k here.

This is where top-p search comes in! Instead of choosing a k, we choose a probability p that we want the probabilities of the top tokens to sum up to. This way, we can dynamically adjust k based on the probability distribution. By setting p=0.9, if 90% of the probability mass is concentrated on the top 2 tokens, we only keep those 2 tokens to sample from. If instead the 90% is spread over the top 10 tokens, we similarly sample from only those 10 tokens.
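A sketch using keras_hub.samplers.TopPSampler; p=0.5 is an illustrative choice:

```python
sampler = keras_hub.samplers.TopPSampler(p=0.5)
output_tokens = sampler(
    next=next,
    prompt=prompt_tokens,
    index=1,
)
txt = tokenizer.detokenize(output_tokens)
print(f"Top-P search generated text: \n{txt}\n")
```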
Using callbacks for text generation
We can also wrap the utilities in a callback, which allows you to print out a prediction sequence for every epoch of the model! Here is an example of a callback for top-k search:
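A sketch of such a callback, assuming the next wrapper, prompt_tokens and tokenizer defined above; the fit() call at the end is only a dummy loop to demonstrate the callback:

```python
class TopKTextGenerator(keras.callbacks.Callback):
    """A callback to generate text from a trained model using top-k."""

    def __init__(self, k):
        super().__init__()
        self.sampler = keras_hub.samplers.TopKSampler(k)

    def on_epoch_end(self, epoch, logs=None):
        output_tokens = self.sampler(
            next=next,
            prompt=prompt_tokens,
            index=1,
        )
        txt = tokenizer.detokenize(output_tokens)
        print(f"Top-K search generated text: \n{txt}\n")


text_generation_callback = TopKTextGenerator(k=10)
# Dummy training loop to demonstrate the callback.
model.fit(train_ds.take(1), verbose=2, epochs=2, callbacks=[text_generation_callback])
```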
Conclusion
To recap, in this example, we use KerasHub layers to train a sub-word vocabulary, tokenize training data, create a miniature GPT model, and perform inference with the text generation library.
If you would like to understand how Transformers work, or learn more about training the full GPT model, here are some further readings:
- Attention Is All You Need (Vaswani et al., 2017)
- GPT-3 paper (Brown et al., 2020)