Path: blob/master/examples/generative/text_generation_with_miniature_gpt.py
3507 views
"""1Title: Text generation with a miniature GPT2Author: [Apoorv Nandan](https://twitter.com/NandanApoorv)3Date created: 2020/05/294Last modified: 2020/05/295Description: Implement a miniature version of GPT and train it to generate text.6Accelerator: GPU7"""89"""10## Introduction1112This example demonstrates how to implement an autoregressive language model13using a miniature version of the GPT model.14The model consists of a single Transformer block with causal masking15in its attention layer.16We use the text from the IMDB sentiment classification dataset for training17and generate new movie reviews for a given prompt.18When using this script with your own dataset, make sure it has at least191 million words.2021This example should be run with `tf-nightly>=2.3.0-dev20200531` or22with TensorFlow 2.3 or higher.2324**References:**2526- [GPT](https://www.semanticscholar.org/paper/Improving-Language-Understanding-by-Generative-Radford/cd18800a0fe0b668a1cc19f2ec95b5003d0a5035)27- [GPT-2](https://www.semanticscholar.org/paper/Language-Models-are-Unsupervised-Multitask-Learners-Radford-Wu/9405cc0d6169988371b2755e573cc28650d14dfe)28- [GPT-3](https://arxiv.org/abs/2005.14165)29"""30"""31## Setup32"""33# We set the backend to TensorFlow. The code works with34# both `tensorflow` and `torch`. It does not work with JAX35# due to the behavior of `jax.numpy.tile` in a jit scope36# (used in `causal_attention_mask()`: `tile` in JAX does37# not support a dynamic `reps` argument.38# You can make the code work in JAX by wrapping the39# inside of the `causal_attention_mask` function in40# a decorator to prevent jit compilation:41# `with jax.ensure_compile_time_eval():`.42import os4344os.environ["KERAS_BACKEND"] = "tensorflow"4546import keras47from keras import layers48from keras import ops49from keras.layers import TextVectorization50import numpy as np51import os52import string53import random54import tensorflow55import tensorflow.data as tf_data56import tensorflow.strings as tf_strings575859"""60## Implement a Transformer block as a layer61"""626364def causal_attention_mask(batch_size, n_dest, n_src, dtype):65"""66Mask the upper half of the dot product matrix in self attention.67This prevents flow of information from future tokens to current token.681's in the lower triangle, counting from the lower right corner.69"""70i = ops.arange(n_dest)[:, None]71j = ops.arange(n_src)72m = i >= j - n_src + n_dest73mask = ops.cast(m, dtype)74mask = ops.reshape(mask, [1, n_dest, n_src])75mult = ops.concatenate(76[ops.expand_dims(batch_size, -1), ops.convert_to_tensor([1, 1])], 077)78return ops.tile(mask, mult)798081class TransformerBlock(layers.Layer):82def __init__(self, embed_dim, num_heads, ff_dim, rate=0.1):83super().__init__()84self.att = layers.MultiHeadAttention(num_heads, embed_dim)85self.ffn = keras.Sequential(86[87layers.Dense(ff_dim, activation="relu"),88layers.Dense(embed_dim),89]90)91self.layernorm1 = layers.LayerNormalization(epsilon=1e-6)92self.layernorm2 = layers.LayerNormalization(epsilon=1e-6)93self.dropout1 = layers.Dropout(rate)94self.dropout2 = layers.Dropout(rate)9596def call(self, inputs):97input_shape = ops.shape(inputs)98batch_size = input_shape[0]99seq_len = input_shape[1]100causal_mask = causal_attention_mask(batch_size, seq_len, seq_len, "bool")101attention_output = self.att(inputs, inputs, attention_mask=causal_mask)102attention_output = self.dropout1(attention_output)103out1 = self.layernorm1(inputs + attention_output)104ffn_output = self.ffn(out1)105ffn_output = self.dropout2(ffn_output)106return self.layernorm2(out1 + ffn_output)107108109"""110## Implement an embedding layer111112Create two separate embedding layers: one for tokens and one for token index113(positions).114"""115116117class TokenAndPositionEmbedding(layers.Layer):118def __init__(self, maxlen, vocab_size, embed_dim):119super().__init__()120self.token_emb = layers.Embedding(input_dim=vocab_size, output_dim=embed_dim)121self.pos_emb = layers.Embedding(input_dim=maxlen, output_dim=embed_dim)122123def call(self, x):124maxlen = ops.shape(x)[-1]125positions = ops.arange(0, maxlen, 1)126positions = self.pos_emb(positions)127x = self.token_emb(x)128return x + positions129130131"""132## Implement the miniature GPT model133"""134vocab_size = 20000 # Only consider the top 20k words135maxlen = 80 # Max sequence size136embed_dim = 256 # Embedding size for each token137num_heads = 2 # Number of attention heads138feed_forward_dim = 256 # Hidden layer size in feed forward network inside transformer139140141def create_model():142inputs = layers.Input(shape=(maxlen,), dtype="int32")143embedding_layer = TokenAndPositionEmbedding(maxlen, vocab_size, embed_dim)144x = embedding_layer(inputs)145transformer_block = TransformerBlock(embed_dim, num_heads, feed_forward_dim)146x = transformer_block(x)147outputs = layers.Dense(vocab_size)(x)148model = keras.Model(inputs=inputs, outputs=[outputs, x])149loss_fn = keras.losses.SparseCategoricalCrossentropy(from_logits=True)150model.compile(151"adam",152loss=[loss_fn, None],153) # No loss and optimization based on word embeddings from transformer block154return model155156157"""158## Prepare the data for word-level language modelling159160Download the IMDB dataset and combine training and validation sets for a text161generation task.162"""163164"""shell165curl -O https://ai.stanford.edu/~amaas/data/sentiment/aclImdb_v1.tar.gz166tar -xf aclImdb_v1.tar.gz167"""168169170batch_size = 128171172# The dataset contains each review in a separate text file173# The text files are present in four different folders174# Create a list all files175filenames = []176directories = [177"aclImdb/train/pos",178"aclImdb/train/neg",179"aclImdb/test/pos",180"aclImdb/test/neg",181]182for dir in directories:183for f in os.listdir(dir):184filenames.append(os.path.join(dir, f))185186print(f"{len(filenames)} files")187188# Create a dataset from text files189random.shuffle(filenames)190text_ds = tf_data.TextLineDataset(filenames)191text_ds = text_ds.shuffle(buffer_size=256)192text_ds = text_ds.batch(batch_size)193194195def custom_standardization(input_string):196"""Remove html line-break tags and handle punctuation"""197lowercased = tf_strings.lower(input_string)198stripped_html = tf_strings.regex_replace(lowercased, "<br />", " ")199return tf_strings.regex_replace(stripped_html, f"([{string.punctuation}])", r" \1")200201202# Create a vectorization layer and adapt it to the text203vectorize_layer = TextVectorization(204standardize=custom_standardization,205max_tokens=vocab_size - 1,206output_mode="int",207output_sequence_length=maxlen + 1,208)209vectorize_layer.adapt(text_ds)210vocab = vectorize_layer.get_vocabulary() # To get words back from token indices211212213def prepare_lm_inputs_labels(text):214"""215Shift word sequences by 1 position so that the target for position (i) is216word at position (i+1). The model will use all words up till position (i)217to predict the next word.218"""219text = tensorflow.expand_dims(text, -1)220tokenized_sentences = vectorize_layer(text)221x = tokenized_sentences[:, :-1]222y = tokenized_sentences[:, 1:]223return x, y224225226text_ds = text_ds.map(prepare_lm_inputs_labels, num_parallel_calls=tf_data.AUTOTUNE)227text_ds = text_ds.prefetch(tf_data.AUTOTUNE)228229230"""231## Implement a Keras callback for generating text232"""233234235class TextGenerator(keras.callbacks.Callback):236"""A callback to generate text from a trained model.2371. Feed some starting prompt to the model2382. Predict probabilities for the next token2393. Sample the next token and add it to the next input240241Arguments:242max_tokens: Integer, the number of tokens to be generated after prompt.243start_tokens: List of integers, the token indices for the starting prompt.244index_to_word: List of strings, obtained from the TextVectorization layer.245top_k: Integer, sample from the `top_k` token predictions.246print_every: Integer, print after this many epochs.247"""248249def __init__(250self, max_tokens, start_tokens, index_to_word, top_k=10, print_every=1251):252self.max_tokens = max_tokens253self.start_tokens = start_tokens254self.index_to_word = index_to_word255self.print_every = print_every256self.k = top_k257258def sample_from(self, logits):259logits, indices = ops.top_k(logits, k=self.k, sorted=True)260indices = np.asarray(indices).astype("int32")261preds = keras.activations.softmax(ops.expand_dims(logits, 0))[0]262preds = np.asarray(preds).astype("float32")263return np.random.choice(indices, p=preds)264265def detokenize(self, number):266return self.index_to_word[number]267268def on_epoch_end(self, epoch, logs=None):269start_tokens = [_ for _ in self.start_tokens]270if (epoch + 1) % self.print_every != 0:271return272num_tokens_generated = 0273tokens_generated = []274while num_tokens_generated <= self.max_tokens:275pad_len = maxlen - len(start_tokens)276sample_index = len(start_tokens) - 1277if pad_len < 0:278x = start_tokens[:maxlen]279sample_index = maxlen - 1280elif pad_len > 0:281x = start_tokens + [0] * pad_len282else:283x = start_tokens284x = np.array([x])285y, _ = self.model.predict(x, verbose=0)286sample_token = self.sample_from(y[0][sample_index])287tokens_generated.append(sample_token)288start_tokens.append(sample_token)289num_tokens_generated = len(tokens_generated)290txt = " ".join(291[self.detokenize(_) for _ in self.start_tokens + tokens_generated]292)293print(f"generated text:\n{txt}\n")294295296# Tokenize starting prompt297word_to_index = {}298for index, word in enumerate(vocab):299word_to_index[word] = index300301start_prompt = "this movie is"302start_tokens = [word_to_index.get(_, 1) for _ in start_prompt.split()]303num_tokens_generated = 40304text_gen_callback = TextGenerator(num_tokens_generated, start_tokens, vocab)305306307"""308## Train the model309310Note: This code should preferably be run on GPU.311"""312313model = create_model()314315model.fit(text_ds, verbose=2, epochs=25, callbacks=[text_gen_callback])316317318