Path: blob/master/examples/generative/text_generation_gpt.py
3507 views
"""1Title: GPT text generation from scratch with KerasHub2Author: [Jesse Chan](https://github.com/jessechancy)3Date created: 2022/07/254Last modified: 2022/07/255Description: Using KerasHub to train a mini-GPT model for text generation.6Accelerator: GPU7"""89"""10## Introduction1112In this example, we will use KerasHub to build a scaled down Generative13Pre-Trained (GPT) model. GPT is a Transformer-based model that allows you to generate14sophisticated text from a prompt.1516We will train the model on the [simplebooks-92](https://arxiv.org/abs/1911.12391) corpus,17which is a dataset made from several novels. It is a good dataset for this example since18it has a small vocabulary and high word frequency, which is beneficial when training a19model with few parameters.2021This example combines concepts from22[Text generation with a miniature GPT](https://keras.io/examples/generative/text_generation_with_miniature_gpt/)23with KerasHub abstractions. We will demonstrate how KerasHub tokenization, layers and24metrics simplify the training25process, and then show how to generate output text using the KerasHub sampling utilities.2627Note: If you are running this example on a Colab,28make sure to enable GPU runtime for faster training.2930This example requires KerasHub. You can install it via the following command:31`pip install keras-hub`32"""3334"""35## Setup36"""3738"""shell39pip install -q --upgrade keras-hub40pip install -q --upgrade keras # Upgrade to Keras 3.41"""4243import os44import keras_hub45import keras4647import tensorflow.data as tf_data48import tensorflow.strings as tf_strings4950"""51## Settings & hyperparameters52"""5354# Data55BATCH_SIZE = 6456MIN_STRING_LEN = 512 # Strings shorter than this will be discarded57SEQ_LEN = 128 # Length of training sequences, in tokens5859# Model60EMBED_DIM = 25661FEED_FORWARD_DIM = 12862NUM_HEADS = 363NUM_LAYERS = 264VOCAB_SIZE = 5000 # Limits parameters in model.6566# Training67EPOCHS = 56869# Inference70NUM_TOKENS_TO_GENERATE = 807172"""73## Load the data7475Now, let's download the dataset! The SimpleBooks dataset consists of 1,573 Gutenberg books, and has76one of the smallest vocabulary size to word-level tokens ratio. It has a vocabulary size of ~98k,77a third of WikiText-103's, with around the same number of tokens (~100M). This makes it easy to fit a small model.78"""7980keras.utils.get_file(81origin="https://dldata-public.s3.us-east-2.amazonaws.com/simplebooks.zip",82extract=True,83)84dir = os.path.expanduser("~/.keras/datasets/simplebooks/")8586# Load simplebooks-92 train set and filter out short lines.87raw_train_ds = (88tf_data.TextLineDataset(dir + "simplebooks-92-raw/train.txt")89.filter(lambda x: tf_strings.length(x) > MIN_STRING_LEN)90.batch(BATCH_SIZE)91.shuffle(buffer_size=256)92)9394# Load simplebooks-92 validation set and filter out short lines.95raw_val_ds = (96tf_data.TextLineDataset(dir + "simplebooks-92-raw/valid.txt")97.filter(lambda x: tf_strings.length(x) > MIN_STRING_LEN)98.batch(BATCH_SIZE)99)100101"""102## Train the tokenizer103104We train the tokenizer from the training dataset for a vocabulary size of `VOCAB_SIZE`,105which is a tuned hyperparameter. We want to limit the vocabulary as much as possible, as106we will see later on107that it has a large effect on the number of model parameters. We also don't want to include108*too few* vocabulary terms, or there would be too many out-of-vocabulary (OOV) sub-words. In109addition, three tokens are reserved in the vocabulary:110111- `"[PAD]"` for padding sequences to `SEQ_LEN`. This token has index 0 in both112`reserved_tokens` and `vocab`, since `WordPieceTokenizer` (and other layers) consider113`0`/`vocab[0]` as the default padding.114- `"[UNK]"` for OOV sub-words, which should match the default `oov_token="[UNK]"` in115`WordPieceTokenizer`.116- `"[BOS]"` stands for beginning of sentence, but here technically it is a token117representing the beginning of each line of training data.118"""119120# Train tokenizer vocabulary121vocab = keras_hub.tokenizers.compute_word_piece_vocabulary(122raw_train_ds,123vocabulary_size=VOCAB_SIZE,124lowercase=True,125reserved_tokens=["[PAD]", "[UNK]", "[BOS]"],126)127128"""129## Load tokenizer130131We use the vocabulary data to initialize132`keras_hub.tokenizers.WordPieceTokenizer`. WordPieceTokenizer is an efficient133implementation of the WordPiece algorithm used by BERT and other models. It will strip,134lower-case and do other irreversible preprocessing operations.135"""136137tokenizer = keras_hub.tokenizers.WordPieceTokenizer(138vocabulary=vocab,139sequence_length=SEQ_LEN,140lowercase=True,141)142143"""144## Tokenize data145146We preprocess the dataset by tokenizing and splitting it into `features` and `labels`.147"""148149# packer adds a start token150start_packer = keras_hub.layers.StartEndPacker(151sequence_length=SEQ_LEN,152start_value=tokenizer.token_to_id("[BOS]"),153)154155156def preprocess(inputs):157outputs = tokenizer(inputs)158features = start_packer(outputs)159labels = outputs160return features, labels161162163# Tokenize and split into train and label sequences.164train_ds = raw_train_ds.map(preprocess, num_parallel_calls=tf_data.AUTOTUNE).prefetch(165tf_data.AUTOTUNE166)167val_ds = raw_val_ds.map(preprocess, num_parallel_calls=tf_data.AUTOTUNE).prefetch(168tf_data.AUTOTUNE169)170171"""172## Build the model173174We create our scaled down GPT model with the following layers:175176- One `keras_hub.layers.TokenAndPositionEmbedding` layer, which combines the embedding177for the token and its position.178- Multiple `keras_hub.layers.TransformerDecoder` layers, with the default causal masking.179The layer has no cross-attention when run with decoder sequence only.180- One final dense linear layer181"""182183inputs = keras.layers.Input(shape=(None,), dtype="int32")184# Embedding.185embedding_layer = keras_hub.layers.TokenAndPositionEmbedding(186vocabulary_size=VOCAB_SIZE,187sequence_length=SEQ_LEN,188embedding_dim=EMBED_DIM,189mask_zero=True,190)191x = embedding_layer(inputs)192# Transformer decoders.193for _ in range(NUM_LAYERS):194decoder_layer = keras_hub.layers.TransformerDecoder(195num_heads=NUM_HEADS,196intermediate_dim=FEED_FORWARD_DIM,197)198x = decoder_layer(x) # Giving one argument only skips cross-attention.199# Output.200outputs = keras.layers.Dense(VOCAB_SIZE)(x)201model = keras.Model(inputs=inputs, outputs=outputs)202loss_fn = keras.losses.SparseCategoricalCrossentropy(from_logits=True)203perplexity = keras_hub.metrics.Perplexity(from_logits=True, mask_token_id=0)204model.compile(optimizer="adam", loss=loss_fn, metrics=[perplexity])205206"""207Let's take a look at our model summary - a large majority of the208parameters are in the `token_and_position_embedding` and the output `dense` layer!209This means that the vocabulary size (`VOCAB_SIZE`) has a large effect on the size of the model,210while the number of Transformer decoder layers (`NUM_LAYERS`) doesn't affect it as much.211"""212213model.summary()214215"""216## Training217218Now that we have our model, let's train it with the `fit()` method.219"""220221model.fit(train_ds, validation_data=val_ds, epochs=EPOCHS)222223"""224## Inference225226With our trained model, we can test it out to gauge its performance. To do this227we can seed our model with an input sequence starting with the `"[BOS]"` token,228and progressively sample the model by making predictions for each subsequent229token in a loop.230231To start lets build a prompt with the same shape as our model inputs, containing232only the `"[BOS]"` token.233"""234235# The "packer" layers adds the [BOS] token for us.236prompt_tokens = start_packer(tokenizer([""]))237prompt_tokens238239"""240We will use the `keras_hub.samplers` module for inference, which requires a241callback function wrapping the model we just trained. This wrapper calls242the model and returns the logit predictions for the current token we are243generating.244245Note: There are two pieces of more advanced functionality available when246defining your callback. The first is the ability to take in a `cache` of states247computed in previous generation steps, which can be used to speed up generation.248The second is the ability to output the final dense "hidden state" of each249generated token. This is used by `keras_hub.samplers.ContrastiveSampler`, which250avoids repetition by penalizing repeated hidden states. Both are optional, and251we will ignore them for now.252"""253254255def next(prompt, cache, index):256logits = model(prompt)[:, index - 1, :]257# Ignore hidden states for now; only needed for contrastive search.258hidden_states = None259return logits, hidden_states, cache260261262"""263Creating the wrapper function is the most complex part of using these functions. Now that264it's done, let's test out the different utilities, starting with greedy search.265"""266267"""268### Greedy search269270We greedily pick the most probable token at each timestep. In other words, we get the271argmax of the model output.272"""273274sampler = keras_hub.samplers.GreedySampler()275output_tokens = sampler(276next=next,277prompt=prompt_tokens,278index=1, # Start sampling immediately after the [BOS] token.279)280txt = tokenizer.detokenize(output_tokens)281print(f"Greedy search generated text: \n{txt}\n")282283"""284As you can see, greedy search starts out making some sense, but quickly starts repeating285itself. This is a common problem with text generation that can be fixed by some of the286probabilistic text generation utilities shown later on!287"""288289"""290### Beam search291292At a high-level, beam search keeps track of the `num_beams` most probable sequences at293each timestep, and predicts the best next token from all sequences. It is an improvement294over greedy search since it stores more possibilities. However, it is less efficient than295greedy search since it has to compute and store multiple potential sequences.296297**Note:** beam search with `num_beams=1` is identical to greedy search.298"""299300sampler = keras_hub.samplers.BeamSampler(num_beams=10)301output_tokens = sampler(302next=next,303prompt=prompt_tokens,304index=1,305)306txt = tokenizer.detokenize(output_tokens)307print(f"Beam search generated text: \n{txt}\n")308309"""310Similar to greedy search, beam search quickly starts repeating itself, since it is still311a deterministic method.312"""313314"""315### Random search316317Random search is our first probabilistic method. At each time step, it samples the next318token using the softmax probabilities provided by the model.319"""320321sampler = keras_hub.samplers.RandomSampler()322output_tokens = sampler(323next=next,324prompt=prompt_tokens,325index=1,326)327txt = tokenizer.detokenize(output_tokens)328print(f"Random search generated text: \n{txt}\n")329330"""331VoilĂ , no repetitions! However, with random search, we may see some nonsensical words332appearing since any word in the vocabulary has a chance of appearing with this sampling333method. This is fixed by our next search utility, top-k search.334"""335336"""337### Top-K search338339Similar to random search, we sample the next token from the probability distribution340provided by the model. The only difference is that here, we select out the top `k` most341probable tokens, and distribute the probability mass over them before sampling. This way,342we won't be sampling from low probability tokens, and hence we would have less343nonsensical words!344"""345346sampler = keras_hub.samplers.TopKSampler(k=10)347output_tokens = sampler(348next=next,349prompt=prompt_tokens,350index=1,351)352txt = tokenizer.detokenize(output_tokens)353print(f"Top-K search generated text: \n{txt}\n")354355"""356### Top-P search357358Even with the top-k search, there is something to improve upon. With top-k search, the359number `k` is fixed, which means it selects the same number of tokens for any probability360distribution. Consider two scenarios, one where the probability mass is concentrated over3612 words and another where the probability mass is evenly concentrated across 10. Should362we choose `k=2` or `k=10`? There is no one size that fits all `k` here.363364This is where top-p search comes in! Instead of choosing a `k`, we choose a probability365`p` that we want the probabilities of the top tokens to sum up to. This way, we can366dynamically adjust the `k` based on the probability distribution. By setting `p=0.9`, if36790% of the probability mass is concentrated on the top 2 tokens, we can filter out the368top 2 tokens to sample from. If instead the 90% is distributed over 10 tokens, it will369similarly filter out the top 10 tokens to sample from.370"""371372sampler = keras_hub.samplers.TopPSampler(p=0.5)373output_tokens = sampler(374next=next,375prompt=prompt_tokens,376index=1,377)378txt = tokenizer.detokenize(output_tokens)379print(f"Top-P search generated text: \n{txt}\n")380381"""382### Using callbacks for text generation383384We can also wrap the utilities in a callback, which allows you to print out a prediction385sequence for every epoch of the model! Here is an example of a callback for top-k search:386"""387388389class TopKTextGenerator(keras.callbacks.Callback):390"""A callback to generate text from a trained model using top-k."""391392def __init__(self, k):393self.sampler = keras_hub.samplers.TopKSampler(k)394395def on_epoch_end(self, epoch, logs=None):396output_tokens = self.sampler(397next=next,398prompt=prompt_tokens,399index=1,400)401txt = tokenizer.detokenize(output_tokens)402print(f"Top-K search generated text: \n{txt}\n")403404405text_generation_callback = TopKTextGenerator(k=10)406# Dummy training loop to demonstrate callback.407model.fit(train_ds.take(1), verbose=2, epochs=2, callbacks=[text_generation_callback])408409"""410## Conclusion411412To recap, in this example, we use KerasHub layers to train a sub-word vocabulary,413tokenize training data, create a miniature GPT model, and perform inference with the414text generation library.415416If you would like to understand how Transformers work, or learn more about training the417full GPT model, here are some further readings:418419- Attention Is All You Need [Vaswani et al., 2017](https://arxiv.org/abs/1706.03762)420- GPT-3 Paper [Brown et al., 2020](https://arxiv.org/abs/2005.14165)421"""422423424