"""1Title: Text Generation using FNet2Author: [Darshan Deshpande](https://twitter.com/getdarshan)3Date created: 2021/10/054Last modified: 2021/10/055Description: FNet transformer for text generation in Keras.6Accelerator: GPU7"""89"""10## Introduction1112The original transformer implementation (Vaswani et al., 2017) was one of the major13breakthroughs in Natural Language Processing, giving rise to important architectures such BERT and GPT.14However, the drawback of these architectures is15that the self-attention mechanism they use is computationally expensive. The FNet16architecture proposes to replace this self-attention attention with a leaner mechanism:17a Fourier transformation-based linear mixer for input tokens.1819The FNet model was able to achieve 92-97% of BERT's accuracy while training 80% faster on20GPUs and almost 70% faster on TPUs. This type of design provides an efficient and small21model size, leading to faster inference times.2223In this example, we will implement and train this architecture on the Cornell Movie24Dialog corpus to show the applicability of this model to text generation.25"""2627"""28## Imports29"""3031import tensorflow as tf32from tensorflow import keras33from tensorflow.keras import layers34import os3536# Defining hyperparameters3738VOCAB_SIZE = 819239MAX_SAMPLES = 5000040BUFFER_SIZE = 2000041MAX_LENGTH = 4042EMBED_DIM = 25643LATENT_DIM = 51244NUM_HEADS = 845BATCH_SIZE = 644647"""48## Loading data4950We will be using the Cornell Dialog Corpus. We will parse the movie conversations into51questions and answers sets.52"""5354path_to_zip = keras.utils.get_file(55"cornell_movie_dialogs.zip",56origin="http://www.cs.cornell.edu/~cristian/data/cornell_movie_dialogs_corpus.zip",57extract=True,58)5960path_to_dataset = os.path.join(61os.path.dirname(path_to_zip), "cornell movie-dialogs corpus"62)63path_to_movie_lines = os.path.join(path_to_dataset, "movie_lines.txt")64path_to_movie_conversations = os.path.join(path_to_dataset, "movie_conversations.txt")656667def load_conversations():68# Helper function for loading the conversation splits69id2line = {}70with open(path_to_movie_lines, errors="ignore") as file:71lines = file.readlines()72for line in lines:73parts = line.replace("\n", "").split(" +++$+++ ")74id2line[parts[0]] = parts[4]7576inputs, outputs = [], []77with open(path_to_movie_conversations, "r") as file:78lines = file.readlines()79for line in lines:80parts = line.replace("\n", "").split(" +++$+++ ")81# get conversation in a list of line ID82conversation = [line[1:-1] for line in parts[3][1:-1].split(", ")]83for i in range(len(conversation) - 1):84inputs.append(id2line[conversation[i]])85outputs.append(id2line[conversation[i + 1]])86if len(inputs) >= MAX_SAMPLES:87return inputs, outputs88return inputs, outputs899091questions, answers = load_conversations()9293# Splitting training and validation sets9495train_dataset = tf.data.Dataset.from_tensor_slices((questions[:40000], answers[:40000]))96val_dataset = tf.data.Dataset.from_tensor_slices((questions[40000:], answers[40000:]))9798"""99### Preprocessing and Tokenization100"""101102103def preprocess_text(sentence):104sentence = tf.strings.lower(sentence)105# Adding a space between the punctuation and the last word to allow better tokenization106sentence = tf.strings.regex_replace(sentence, r"([?.!,])", r" \1 ")107# Replacing multiple continuous spaces with a single space108sentence = tf.strings.regex_replace(sentence, r"\s\s+", " ")109# Replacing non english words with spaces110sentence = 

vectorizer = layers.TextVectorization(
    VOCAB_SIZE,
    standardize=preprocess_text,
    output_mode="int",
    output_sequence_length=MAX_LENGTH,
)

# We will adapt the vectorizer to both the questions and answers
# This dataset is batched to parallelize and speed up the process
vectorizer.adapt(tf.data.Dataset.from_tensor_slices((questions + answers)).batch(128))

"""
### Tokenizing and padding sentences using `TextVectorization`
"""


def vectorize_text(inputs, outputs):
    inputs, outputs = vectorizer(inputs), vectorizer(outputs)
    # One extra padding token to the right to match the output shape
    outputs = tf.pad(outputs, [[0, 1]])
    return (
        {"encoder_inputs": inputs, "decoder_inputs": outputs[:-1]},
        {"outputs": outputs[1:]},
    )


train_dataset = train_dataset.map(vectorize_text, num_parallel_calls=tf.data.AUTOTUNE)
val_dataset = val_dataset.map(vectorize_text, num_parallel_calls=tf.data.AUTOTUNE)

train_dataset = (
    train_dataset.cache()
    .shuffle(BUFFER_SIZE)
    .batch(BATCH_SIZE)
    .prefetch(tf.data.AUTOTUNE)
)
val_dataset = val_dataset.cache().batch(BATCH_SIZE).prefetch(tf.data.AUTOTUNE)

"""
## Creating the FNet Encoder

The FNet paper proposes a replacement for the standard attention mechanism used by the
Transformer architecture (Vaswani et al., 2017).

The outputs of the FFT layer are complex numbers. To avoid dealing with complex layers,
only the real part of the output is extracted.

The dense layers that follow the Fourier transformation act as convolutions applied on
the frequency domain.
"""


class FNetEncoder(layers.Layer):
    def __init__(self, embed_dim, dense_dim, **kwargs):
        super().__init__(**kwargs)
        self.embed_dim = embed_dim
        self.dense_dim = dense_dim
        self.dense_proj = keras.Sequential(
            [
                layers.Dense(dense_dim, activation="relu"),
                layers.Dense(embed_dim),
            ]
        )
        self.layernorm_1 = layers.LayerNormalization()
        self.layernorm_2 = layers.LayerNormalization()

    def call(self, inputs):
        # Casting the inputs to complex64
        inp_complex = tf.cast(inputs, tf.complex64)
        # Projecting the inputs to the frequency domain using FFT2D and
        # extracting the real part of the output
        fft = tf.math.real(tf.signal.fft2d(inp_complex))
        proj_input = self.layernorm_1(inputs + fft)
        proj_output = self.dense_proj(proj_input)
        return self.layernorm_2(proj_input + proj_output)
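
"""
To build intuition for the token-mixing step in isolation, here is a minimal sketch
(on a random tensor with assumed toy dimensions) of the Fourier transform used inside
`FNetEncoder.call`:
"""

toy_inputs = tf.random.normal((1, 4, 8))  # (batch, sequence, hidden) -- toy sizes
toy_fft = tf.math.real(tf.signal.fft2d(tf.cast(toy_inputs, tf.complex64)))
# fft2d transforms over the two innermost axes (sequence and hidden), so every
# output position is a mixture of all tokens; the shape is unchanged
print(toy_fft.shape)  # (1, 4, 8)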

"""
## Creating the Decoder

The decoder architecture remains the same as the one proposed in the original
transformer architecture (Vaswani et al., 2017), consisting of an embedding, positional
encoding, a causally masked self-attention layer, a cross-attention layer over the
encoder outputs, and finally the dense output layers.
The architecture that follows is taken from
[Deep Learning with Python, second edition, chapter 11](https://www.manning.com/books/deep-learning-with-python-second-edition).
"""


class PositionalEmbedding(layers.Layer):
    def __init__(self, sequence_length, vocab_size, embed_dim, **kwargs):
        super().__init__(**kwargs)
        self.token_embeddings = layers.Embedding(
            input_dim=vocab_size, output_dim=embed_dim
        )
        self.position_embeddings = layers.Embedding(
            input_dim=sequence_length, output_dim=embed_dim
        )
        self.sequence_length = sequence_length
        self.vocab_size = vocab_size
        self.embed_dim = embed_dim

    def call(self, inputs):
        length = tf.shape(inputs)[-1]
        positions = tf.range(start=0, limit=length, delta=1)
        embedded_tokens = self.token_embeddings(inputs)
        embedded_positions = self.position_embeddings(positions)
        return embedded_tokens + embedded_positions

    def compute_mask(self, inputs, mask=None):
        # Mask out the padding token (index 0)
        return tf.math.not_equal(inputs, 0)


class FNetDecoder(layers.Layer):
    def __init__(self, embed_dim, latent_dim, num_heads, **kwargs):
        super().__init__(**kwargs)
        self.embed_dim = embed_dim
        self.latent_dim = latent_dim
        self.num_heads = num_heads
        self.attention_1 = layers.MultiHeadAttention(
            num_heads=num_heads, key_dim=embed_dim
        )
        self.attention_2 = layers.MultiHeadAttention(
            num_heads=num_heads, key_dim=embed_dim
        )
        self.dense_proj = keras.Sequential(
            [
                layers.Dense(latent_dim, activation="relu"),
                layers.Dense(embed_dim),
            ]
        )
        self.layernorm_1 = layers.LayerNormalization()
        self.layernorm_2 = layers.LayerNormalization()
        self.layernorm_3 = layers.LayerNormalization()
        self.supports_masking = True

    def call(self, inputs, encoder_outputs, mask=None):
        causal_mask = self.get_causal_attention_mask(inputs)
        if mask is not None:
            padding_mask = tf.cast(mask[:, tf.newaxis, :], dtype="int32")
            padding_mask = tf.minimum(padding_mask, causal_mask)
        else:
            # Fall back to no padding mask if none is propagated
            padding_mask = None

        attention_output_1 = self.attention_1(
            query=inputs, value=inputs, key=inputs, attention_mask=causal_mask
        )
        out_1 = self.layernorm_1(inputs + attention_output_1)

        attention_output_2 = self.attention_2(
            query=out_1,
            value=encoder_outputs,
            key=encoder_outputs,
            attention_mask=padding_mask,
        )
        out_2 = self.layernorm_2(out_1 + attention_output_2)

        proj_output = self.dense_proj(out_2)
        return self.layernorm_3(out_2 + proj_output)

    def get_causal_attention_mask(self, inputs):
        input_shape = tf.shape(inputs)
        batch_size, sequence_length = input_shape[0], input_shape[1]
        i = tf.range(sequence_length)[:, tf.newaxis]
        j = tf.range(sequence_length)
        mask = tf.cast(i >= j, dtype="int32")
        mask = tf.reshape(mask, (1, input_shape[1], input_shape[1]))
        mult = tf.concat(
            [tf.expand_dims(batch_size, -1), tf.constant([1, 1], dtype=tf.int32)],
            axis=0,
        )
        return tf.tile(mask, mult)
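
"""
To see what `get_causal_attention_mask` builds, here is a minimal sketch of the same
lower-triangular pattern for an assumed sequence length of 4: position `i` may only
attend to positions `j <= i`.
"""

toy_i = tf.range(4)[:, tf.newaxis]
toy_j = tf.range(4)
print(tf.cast(toy_i >= toy_j, tf.int32).numpy())
# Expected:
# [[1 0 0 0]
#  [1 1 0 0]
#  [1 1 1 0]
#  [1 1 1 1]]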

def create_model():
    encoder_inputs = keras.Input(shape=(None,), dtype="int32", name="encoder_inputs")
    x = PositionalEmbedding(MAX_LENGTH, VOCAB_SIZE, EMBED_DIM)(encoder_inputs)
    encoder_outputs = FNetEncoder(EMBED_DIM, LATENT_DIM)(x)
    encoder = keras.Model(encoder_inputs, encoder_outputs)
    decoder_inputs = keras.Input(shape=(None,), dtype="int32", name="decoder_inputs")
    encoded_seq_inputs = keras.Input(
        shape=(None, EMBED_DIM), name="decoder_state_inputs"
    )
    x = PositionalEmbedding(MAX_LENGTH, VOCAB_SIZE, EMBED_DIM)(decoder_inputs)
    x = FNetDecoder(EMBED_DIM, LATENT_DIM, NUM_HEADS)(x, encoded_seq_inputs)
    x = layers.Dropout(0.5)(x)
    decoder_outputs = layers.Dense(VOCAB_SIZE, activation="softmax")(x)
    decoder = keras.Model(
        [decoder_inputs, encoded_seq_inputs], decoder_outputs, name="outputs"
    )
    decoder_outputs = decoder([decoder_inputs, encoder_outputs])
    fnet = keras.Model([encoder_inputs, decoder_inputs], decoder_outputs, name="fnet")
    return fnet


"""
## Creating and Training the model
"""

fnet = create_model()
fnet.compile("adam", loss="sparse_categorical_crossentropy", metrics=["accuracy"])

"""
Here, the `epochs` parameter is set to a single epoch, but in practice the model will
take around **20-30 epochs** of training before it starts producing comprehensible
sentences. Although accuracy is not a good measure for this task, we use it to get a
rough sense of how the network improves.
"""

fnet.fit(train_dataset, epochs=1, validation_data=val_dataset)

"""
## Performing inference
"""

VOCAB = vectorizer.get_vocabulary()


def decode_sentence(input_sentence):
    # Map the input sentence to tokens; the vectorizer's `standardize` step
    # (`preprocess_text`) lowercases the text and adds the [start] and [end] tokens
    tokenized_input_sentence = vectorizer(tf.constant(input_sentence))
    # Initializing the target sentence, consisting of only the start token
    tokenized_target_sentence = tf.expand_dims(VOCAB.index("[start]"), 0)
    decoded_sentence = ""

    for i in range(MAX_LENGTH):
        # Get the predictions
        predictions = fnet.predict(
            {
                "encoder_inputs": tf.expand_dims(tokenized_input_sentence, 0),
                "decoder_inputs": tf.expand_dims(
                    tf.pad(
                        tokenized_target_sentence,
                        [[0, MAX_LENGTH - tf.shape(tokenized_target_sentence)[0]]],
                    ),
                    0,
                ),
            }
        )
        # Calculating the token with maximum probability and getting the corresponding word
        sampled_token_index = tf.argmax(predictions[0, i, :])
        sampled_token = VOCAB[sampled_token_index.numpy()]
        # If the sampled token is the end token, stop generating and return the sentence
        if tf.equal(sampled_token_index, VOCAB.index("[end]")):
            break
        decoded_sentence += sampled_token + " "
        tokenized_target_sentence = tf.concat(
            [tokenized_target_sentence, [sampled_token_index]], 0
        )

    return decoded_sentence


decode_sentence("Where have you been all this time?")

"""
## Conclusion

This example shows how to train and perform inference using the FNet model.
For deeper insight into the architecture, or for further reading, you can refer to:

1. [FNet: Mixing Tokens with Fourier Transforms](https://arxiv.org/abs/2105.03824v3)
(Lee-Thorp et al., 2021)
2. [Attention Is All You Need](https://arxiv.org/abs/1706.03762v5) (Vaswani et al.,
2017)

Thanks to François Chollet for his Keras example on
[English-to-Spanish translation with a sequence-to-sequence Transformer](https://keras.io/examples/nlp/neural_machine_translation_with_transformer/)
from which the decoder implementation was extracted.
"""