"""
Title: Text Generation using FNet
Author: [Darshan Deshpande](https://twitter.com/getdarshan)
Date created: 2021/10/05
Last modified: 2021/10/05
Description: FNet transformer for text generation in Keras.
Accelerator: GPU
"""

"""
## Introduction

The original transformer implementation (Vaswani et al., 2017) was one of the major
breakthroughs in Natural Language Processing, giving rise to important architectures
such as BERT and GPT. However, a drawback of these architectures is that the
self-attention mechanism they use is computationally expensive. The FNet architecture
proposes to replace self-attention with a leaner mechanism: a Fourier
transformation-based linear mixer for input tokens.

The FNet model achieves 92-97% of BERT's accuracy while training 80% faster on GPUs
and almost 70% faster on TPUs. The design also keeps the model small and efficient,
leading to faster inference times.

In this example, we will implement and train this architecture on the Cornell
Movie-Dialogs Corpus to show the applicability of this model to text generation.
"""

"""
## Imports
"""

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
import os

# Defining hyperparameters

VOCAB_SIZE = 8192
MAX_SAMPLES = 50000
BUFFER_SIZE = 20000
MAX_LENGTH = 40
EMBED_DIM = 256
LATENT_DIM = 512
NUM_HEADS = 8
BATCH_SIZE = 64
"""
49
## Loading data
50
51
We will be using the Cornell Dialog Corpus. We will parse the movie conversations into
52
questions and answers sets.
53
"""

path_to_zip = keras.utils.get_file(
    "cornell_movie_dialogs.zip",
    origin="http://www.cs.cornell.edu/~cristian/data/cornell_movie_dialogs_corpus.zip",
    extract=True,
)

path_to_dataset = os.path.join(
    os.path.dirname(path_to_zip), "cornell movie-dialogs corpus"
)
path_to_movie_lines = os.path.join(path_to_dataset, "movie_lines.txt")
path_to_movie_conversations = os.path.join(path_to_dataset, "movie_conversations.txt")


def load_conversations():
    # Helper function for loading the conversation splits
    id2line = {}
    with open(path_to_movie_lines, errors="ignore") as file:
        lines = file.readlines()
    for line in lines:
        parts = line.replace("\n", "").split(" +++$+++ ")
        id2line[parts[0]] = parts[4]

    inputs, outputs = [], []
    with open(path_to_movie_conversations, "r") as file:
        lines = file.readlines()
    for line in lines:
        parts = line.replace("\n", "").split(" +++$+++ ")
        # Get the conversation as a list of line IDs
        conversation = [line[1:-1] for line in parts[3][1:-1].split(", ")]
        for i in range(len(conversation) - 1):
            inputs.append(id2line[conversation[i]])
            outputs.append(id2line[conversation[i + 1]])
            if len(inputs) >= MAX_SAMPLES:
                return inputs, outputs
    return inputs, outputs


questions, answers = load_conversations()

# Splitting training and validation sets

train_dataset = tf.data.Dataset.from_tensor_slices((questions[:40000], answers[:40000]))
val_dataset = tf.data.Dataset.from_tensor_slices((questions[40000:], answers[40000:]))
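
"""
As a quick sanity check, we can look at one of the parsed question-answer pairs (the
exact text depends on the downloaded dataset files).
"""

print(f"Question: {questions[0]}")
print(f"Answer: {answers[0]}")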
"""
100
### Preprocessing and Tokenization
101
"""
102
103
104
def preprocess_text(sentence):
105
sentence = tf.strings.lower(sentence)
106
# Adding a space between the punctuation and the last word to allow better tokenization
107
sentence = tf.strings.regex_replace(sentence, r"([?.!,])", r" \1 ")
108
# Replacing multiple continuous spaces with a single space
109
sentence = tf.strings.regex_replace(sentence, r"\s\s+", " ")
110
# Replacing non english words with spaces
111
sentence = tf.strings.regex_replace(sentence, r"[^a-z?.!,]+", " ")
112
sentence = tf.strings.strip(sentence)
113
sentence = tf.strings.join(["[start]", sentence, "[end]"], separator=" ")
114
return sentence
115


vectorizer = layers.TextVectorization(
    max_tokens=VOCAB_SIZE,
    standardize=preprocess_text,
    output_mode="int",
    output_sequence_length=MAX_LENGTH,
)

# We will adapt the vectorizer to both the questions and answers
# This dataset is batched to parallelize and speed up the process
vectorizer.adapt(tf.data.Dataset.from_tensor_slices((questions + answers)).batch(128))
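
"""
As a quick check, we can run the adapted vectorizer on a sample sentence. The exact
token IDs depend on the adapted vocabulary, but the result should be an integer tensor
of length `MAX_LENGTH`, with the `[start]` and `[end]` tokens added by
`preprocess_text` and zero padding at the end.
"""

print(vectorizer(tf.constant("How are you doing?")))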

"""
### Tokenizing and padding sentences using `TextVectorization`
"""


def vectorize_text(inputs, outputs):
    inputs, outputs = vectorizer(inputs), vectorizer(outputs)
    # One extra padding token to the right to match the output shape
    outputs = tf.pad(outputs, [[0, 1]])
    return (
        {"encoder_inputs": inputs, "decoder_inputs": outputs[:-1]},
        {"outputs": outputs[1:]},
    )


train_dataset = train_dataset.map(vectorize_text, num_parallel_calls=tf.data.AUTOTUNE)
val_dataset = val_dataset.map(vectorize_text, num_parallel_calls=tf.data.AUTOTUNE)

train_dataset = (
    train_dataset.cache()
    .shuffle(BUFFER_SIZE)
    .batch(BATCH_SIZE)
    .prefetch(tf.data.AUTOTUNE)
)
val_dataset = val_dataset.cache().batch(BATCH_SIZE).prefetch(tf.data.AUTOTUNE)
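
"""
Taking one batch from the pipeline confirms the teacher-forcing layout produced by
`vectorize_text`: the targets under the "outputs" key are the decoder inputs shifted
one token to the left, and every tensor has shape `(BATCH_SIZE, MAX_LENGTH)`.
"""

for sample_inputs, sample_outputs in train_dataset.take(1):
    print(sample_inputs["encoder_inputs"].shape, sample_inputs["decoder_inputs"].shape)
    print(sample_outputs["outputs"].shape)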
"""
155
## Creating the FNet Encoder
156
157
The FNet paper proposes a replacement for the standard attention mechanism used by the
158
Transformer architecture (Vaswani et al., 2017).
159
160
![Architecture](https://i.imgur.com/rLg47qU.png)
161
162
The outputs of the FFT layer are complex numbers. To avoid dealing with complex layers,
163
only the real part (the magnitude) is extracted.
164
165
The dense layers that follow the Fourier transformation act as convolutions applied on
166
the frequency domain.
167
"""


class FNetEncoder(layers.Layer):
    def __init__(self, embed_dim, dense_dim, **kwargs):
        super().__init__(**kwargs)
        self.embed_dim = embed_dim
        self.dense_dim = dense_dim
        self.dense_proj = keras.Sequential(
            [
                layers.Dense(dense_dim, activation="relu"),
                layers.Dense(embed_dim),
            ]
        )
        self.layernorm_1 = layers.LayerNormalization()
        self.layernorm_2 = layers.LayerNormalization()

    def call(self, inputs):
        # Casting the inputs to complex64
        inp_complex = tf.cast(inputs, tf.complex64)
        # Projecting the inputs to the frequency domain using FFT2D and
        # extracting the real part of the output
        fft = tf.math.real(tf.signal.fft2d(inp_complex))
        proj_input = self.layernorm_1(inputs + fft)
        proj_output = self.dense_proj(proj_input)
        return self.layernorm_2(proj_input + proj_output)
"""
196
## Creating the Decoder
197
198
The decoder architecture remains the same as the one proposed by (Vaswani et al., 2017)
199
in the original transformer architecture, consisting of an embedding, positional
200
encoding, two masked multi-head attention layers and finally the dense output layers.
201
The architecture that follows is taken from
202
[Deep Learning with Python, second edition, chapter 11](https://www.manning.com/books/deep-learning-with-python-second-edition).
203
204
"""


class PositionalEmbedding(layers.Layer):
    def __init__(self, sequence_length, vocab_size, embed_dim, **kwargs):
        super().__init__(**kwargs)
        self.token_embeddings = layers.Embedding(
            input_dim=vocab_size, output_dim=embed_dim
        )
        self.position_embeddings = layers.Embedding(
            input_dim=sequence_length, output_dim=embed_dim
        )
        self.sequence_length = sequence_length
        self.vocab_size = vocab_size
        self.embed_dim = embed_dim

    def call(self, inputs):
        # Summing the token embeddings and the learned position embeddings
        length = tf.shape(inputs)[-1]
        positions = tf.range(start=0, limit=length, delta=1)
        embedded_tokens = self.token_embeddings(inputs)
        embedded_positions = self.position_embeddings(positions)
        return embedded_tokens + embedded_positions

    def compute_mask(self, inputs, mask=None):
        # Propagating a padding mask so that downstream layers ignore the 0 token
        return tf.math.not_equal(inputs, 0)


class FNetDecoder(layers.Layer):
    def __init__(self, embed_dim, latent_dim, num_heads, **kwargs):
        super().__init__(**kwargs)
        self.embed_dim = embed_dim
        self.latent_dim = latent_dim
        self.num_heads = num_heads
        self.attention_1 = layers.MultiHeadAttention(
            num_heads=num_heads, key_dim=embed_dim
        )
        self.attention_2 = layers.MultiHeadAttention(
            num_heads=num_heads, key_dim=embed_dim
        )
        self.dense_proj = keras.Sequential(
            [
                layers.Dense(latent_dim, activation="relu"),
                layers.Dense(embed_dim),
            ]
        )
        self.layernorm_1 = layers.LayerNormalization()
        self.layernorm_2 = layers.LayerNormalization()
        self.layernorm_3 = layers.LayerNormalization()
        self.supports_masking = True

    def call(self, inputs, encoder_outputs, mask=None):
        causal_mask = self.get_causal_attention_mask(inputs)
        # Combining the padding mask (if any) with the causal mask
        padding_mask = None
        if mask is not None:
            padding_mask = tf.cast(mask[:, tf.newaxis, :], dtype="int32")
            padding_mask = tf.minimum(padding_mask, causal_mask)

        attention_output_1 = self.attention_1(
            query=inputs, value=inputs, key=inputs, attention_mask=causal_mask
        )
        out_1 = self.layernorm_1(inputs + attention_output_1)

        attention_output_2 = self.attention_2(
            query=out_1,
            value=encoder_outputs,
            key=encoder_outputs,
            attention_mask=padding_mask,
        )
        out_2 = self.layernorm_2(out_1 + attention_output_2)

        proj_output = self.dense_proj(out_2)
        return self.layernorm_3(out_2 + proj_output)

    def get_causal_attention_mask(self, inputs):
        # Lower-triangular mask: position i may only attend to positions j <= i
        input_shape = tf.shape(inputs)
        batch_size, sequence_length = input_shape[0], input_shape[1]
        i = tf.range(sequence_length)[:, tf.newaxis]
        j = tf.range(sequence_length)
        mask = tf.cast(i >= j, dtype="int32")
        mask = tf.reshape(mask, (1, input_shape[1], input_shape[1]))
        mult = tf.concat(
            [tf.expand_dims(batch_size, -1), tf.constant([1, 1], dtype=tf.int32)],
            axis=0,
        )
        return tf.tile(mask, mult)
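
"""
To see what `get_causal_attention_mask` produces, we can instantiate a throwaway
decoder and inspect the mask for a batch of one sequence of length 4: a
lower-triangular matrix, so position `i` can only attend to positions `j <= i`.
"""

toy_decoder = FNetDecoder(EMBED_DIM, LATENT_DIM, NUM_HEADS)
print(toy_decoder.get_causal_attention_mask(tf.zeros((1, 4))).numpy()[0])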


def create_model():
    encoder_inputs = keras.Input(shape=(None,), dtype="int32", name="encoder_inputs")
    x = PositionalEmbedding(MAX_LENGTH, VOCAB_SIZE, EMBED_DIM)(encoder_inputs)
    encoder_outputs = FNetEncoder(EMBED_DIM, LATENT_DIM)(x)
    encoder = keras.Model(encoder_inputs, encoder_outputs)
    decoder_inputs = keras.Input(shape=(None,), dtype="int32", name="decoder_inputs")
    encoded_seq_inputs = keras.Input(
        shape=(None, EMBED_DIM), name="decoder_state_inputs"
    )
    x = PositionalEmbedding(MAX_LENGTH, VOCAB_SIZE, EMBED_DIM)(decoder_inputs)
    x = FNetDecoder(EMBED_DIM, LATENT_DIM, NUM_HEADS)(x, encoded_seq_inputs)
    x = layers.Dropout(0.5)(x)
    decoder_outputs = layers.Dense(VOCAB_SIZE, activation="softmax")(x)
    # Naming the decoder model "outputs" so its predictions match the "outputs"
    # target key produced by `vectorize_text`
    decoder = keras.Model(
        [decoder_inputs, encoded_seq_inputs], decoder_outputs, name="outputs"
    )
    decoder_outputs = decoder([decoder_inputs, encoder_outputs])
    fnet = keras.Model([encoder_inputs, decoder_inputs], decoder_outputs, name="fnet")
    return fnet
"""
312
## Creating and Training the model
313
"""
314
315
fnet = create_model()
316
fnet.compile("adam", loss="sparse_categorical_crossentropy", metrics=["accuracy"])
317
318
"""
319
Here, the `epochs` parameter is set to a single epoch, but in practice the model will take around
320
**20-30 epochs** of training to start outputting comprehensible sentences. Although accuracy
321
is not a good measure for this task, we will use it just to get a hint of the improvement
322
of the network.
323
"""

fnet.fit(train_dataset, epochs=1, validation_data=val_dataset)
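
"""
For a longer run, one option (not part of the original example) is to keep the best
weights seen so far with a `ModelCheckpoint` callback. This is only a sketch; the
checkpoint path is a placeholder.
"""

# Hypothetical longer training run with checkpointing:
# fnet.fit(
#     train_dataset,
#     epochs=30,
#     validation_data=val_dataset,
#     callbacks=[
#         keras.callbacks.ModelCheckpoint(
#             "fnet_checkpoint.h5", save_best_only=True, save_weights_only=True
#         )
#     ],
# )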
"""
328
## Performing inference
329
"""
330
331
VOCAB = vectorizer.get_vocabulary()
332
333
334


def decode_sentence(input_sentence):
    # Mapping the input sentence to tokens. The vectorizer's `standardize` step
    # (`preprocess_text`) already lowercases the text and adds the [start] and
    # [end] tokens, so the raw sentence can be passed in directly.
    tokenized_input_sentence = vectorizer(tf.constant(input_sentence))
    # Initializing the target sentence with only the start token
    # (int64, to match the dtype of the sampled token indices below)
    tokenized_target_sentence = tf.constant([VOCAB.index("[start]")], dtype=tf.int64)
    decoded_sentence = ""

    for i in range(MAX_LENGTH):
        # Get the predictions
        predictions = fnet.predict(
            {
                "encoder_inputs": tf.expand_dims(tokenized_input_sentence, 0),
                "decoder_inputs": tf.expand_dims(
                    tf.pad(
                        tokenized_target_sentence,
                        [[0, MAX_LENGTH - tf.shape(tokenized_target_sentence)[0]]],
                    ),
                    0,
                ),
            }
        )
        # Calculating the token with maximum probability and getting the corresponding word
        sampled_token_index = tf.argmax(predictions[0, i, :])
        sampled_token = VOCAB[sampled_token_index.numpy()]
        # If the sampled token is the end token, stop generating and return the sentence
        if tf.equal(sampled_token_index, VOCAB.index("[end]")):
            break
        decoded_sentence += sampled_token + " "
        tokenized_target_sentence = tf.concat(
            [tokenized_target_sentence, [sampled_token_index]], 0
        )

    return decoded_sentence


decode_sentence("Where have you been all this time?")
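
"""
`decode_sentence` decodes greedily, always picking the highest-probability token. A
common variant (not part of the original example) is to sample from the predicted
distribution instead, which yields more varied responses. A minimal sketch of that
change, replacing the `tf.argmax` step inside the loop, would be:
"""

# Hypothetical replacement for the argmax step inside the decoding loop:
# sampled_token_index = tf.random.categorical(
#     tf.math.log(predictions[:, i, :]), num_samples=1
# )[0, 0]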
"""
374
## Conclusion
375
376
This example shows how to train and perform inference using the FNet model.
377
For getting insight into the architecture or for further reading, you can refer to:
378
379
1. [FNet: Mixing Tokens with Fourier Transforms](https://arxiv.org/abs/2105.03824v3)
380
(Lee-Thorp et al., 2021)
381
2. [Attention Is All You Need](https://arxiv.org/abs/1706.03762v5) (Vaswani et al.,
382
2017)
383
384
Thanks to François Chollet for his Keras example on
385
[English-to-Spanish translation with a sequence-to-sequence Transformer](https://keras.io/examples/nlp/neural_machine_translation_with_transformer/)
386
from which the decoder implementation was extracted.
387
"""
388
389