"""
2
Title: Text Classification using FNet
3
Author: [Abheesht Sharma](https://github.com/abheesht17/)
4
Date created: 2022/06/01
5
Last modified: 2022/12/21
6
Description: Text Classification on the IMDb Dataset using `keras_hub.layers.FNetEncoder` layer.
7
Accelerator: GPU
8
"""
9
10
"""
11
## Introduction
12
13
In this example, we will demonstrate the ability of FNet to achieve comparable
14
results with a vanilla Transformer model on the text classification task.
15
We will be using the IMDb dataset, which is a
16
collection of movie reviews labelled either positive or negative (sentiment
17
analysis).
18
19
To build the tokenizer, model, etc., we will use components from
20
[KerasHub](https://github.com/keras-team/keras-hub). KerasHub makes life easier
21
for people who want to build NLP pipelines! :)
22
23
### Model
24
25
Transformer-based language models (LMs) such as BERT, RoBERTa, XLNet, etc. have
26
demonstrated the effectiveness of the self-attention mechanism for computing
27
rich embeddings for input text. However, the self-attention mechanism is an
28
expensive operation, with a time complexity of `O(n^2)`, where `n` is the number
29
of tokens in the input. Hence, there has been an effort to reduce the time
30
complexity of the self-attention mechanism and improve performance without
31
sacrificing the quality of results.
32
33
In 2020, a paper titled
34
[FNet: Mixing Tokens with Fourier Transforms](https://arxiv.org/abs/2105.03824)
35
replaced the self-attention layer in BERT with a simple Fourier Transform layer
36
for "token mixing". This resulted in comparable accuracy and a speed-up during
37
training. In particular, a couple of points from the paper stand out:
38
39
* The authors claim that FNet is 80% faster than BERT on GPUs and 70% faster on
40
TPUs. The reason for this speed-up is two-fold: a) the Fourier Transform layer
41
is unparametrized, it does not have any parameters, and b) the authors use Fast
42
Fourier Transform (FFT); this reduces the time complexity from `O(n^2)`
43
(in the case of self-attention) to `O(n log n)`.
44
* FNet manages to achieve 92-97% of the accuracy of BERT on the GLUE benchmark.
45
"""

"""
## Setup

Before we start with the implementation, let's import all the necessary packages.
"""

"""shell
pip install -q --upgrade keras-hub
pip install -q --upgrade keras  # Upgrade to Keras 3.
"""

import keras_hub
import keras
import tensorflow as tf
import os

keras.utils.set_random_seed(42)

"""
Let's also define our hyperparameters.
"""

BATCH_SIZE = 64
EPOCHS = 3
MAX_SEQUENCE_LENGTH = 512
VOCAB_SIZE = 15000

EMBED_DIM = 128
INTERMEDIATE_DIM = 512
"""
77
## Loading the dataset
78
79
First, let's download the IMDB dataset and extract it.
80
"""
81
82
"""shell
83
wget http://ai.stanford.edu/~amaas/data/sentiment/aclImdb_v1.tar.gz
84
tar -xzf aclImdb_v1.tar.gz
85
"""
86
87
"""
88
Samples are present in the form of text files. Let's inspect the structure of
89
the directory.
90
"""
91
92
print(os.listdir("./aclImdb"))
93
print(os.listdir("./aclImdb/train"))
94
print(os.listdir("./aclImdb/test"))
95
96
"""
97
The directory contains two sub-directories: `train` and `test`. Each subdirectory
98
in turn contains two folders: `pos` and `neg` for positive and negative reviews,
99
respectively. Before we load the dataset, let's delete the `./aclImdb/train/unsup`
100
folder since it has unlabelled samples.
101
"""
102
103
"""shell
104
rm -rf aclImdb/train/unsup
105
"""
106
107
"""
108
We'll use the `keras.utils.text_dataset_from_directory` utility to generate
109
our labelled `tf.data.Dataset` dataset from text files.
110
"""
111
112
train_ds = keras.utils.text_dataset_from_directory(
113
"aclImdb/train",
114
batch_size=BATCH_SIZE,
115
validation_split=0.2,
116
subset="training",
117
seed=42,
118
)
119
val_ds = keras.utils.text_dataset_from_directory(
120
"aclImdb/train",
121
batch_size=BATCH_SIZE,
122
validation_split=0.2,
123
subset="validation",
124
seed=42,
125
)
126
test_ds = keras.utils.text_dataset_from_directory("aclImdb/test", batch_size=BATCH_SIZE)
127
128
"""
129
We will now convert the text to lowercase.
130
"""
131
train_ds = train_ds.map(lambda x, y: (tf.strings.lower(x), y))
132
val_ds = val_ds.map(lambda x, y: (tf.strings.lower(x), y))
133
test_ds = test_ds.map(lambda x, y: (tf.strings.lower(x), y))
134
135
"""
136
Let's print a few samples.
137
"""
138
for text_batch, label_batch in train_ds.take(1):
139
for i in range(3):
140
print(text_batch.numpy()[i])
141
print(label_batch.numpy()[i])
142
143
144
"""
145
### Tokenizing the data
146
147
We'll be using the `keras_hub.tokenizers.WordPieceTokenizer` layer to tokenize
148
the text. `keras_hub.tokenizers.WordPieceTokenizer` takes a WordPiece vocabulary
149
and has functions for tokenizing the text, and detokenizing sequences of tokens.
150
151
Before we define the tokenizer, we first need to train it on the dataset
152
we have. The WordPiece tokenization algorithm is a subword tokenization algorithm;
153
training it on a corpus gives us a vocabulary of subwords. A subword tokenizer
154
is a compromise between word tokenizers (word tokenizers need very large
155
vocabularies for good coverage of input words), and character tokenizers
156
(characters don't really encode meaning like words do). Luckily, KerasHub
157
makes it very simple to train WordPiece on a corpus with the
158
`keras_hub.tokenizers.compute_word_piece_vocabulary` utility.
159
160
Note: The official implementation of FNet uses the SentencePiece Tokenizer.
161
"""
162
163
164
def train_word_piece(ds, vocab_size, reserved_tokens):
165
word_piece_ds = ds.unbatch().map(lambda x, y: x)
166
vocab = keras_hub.tokenizers.compute_word_piece_vocabulary(
167
word_piece_ds.batch(1000).prefetch(2),
168
vocabulary_size=vocab_size,
169
reserved_tokens=reserved_tokens,
170
)
171
return vocab
172
173
174
"""
175
Every vocabulary has a few special, reserved tokens. We have two such tokens:
176
177
- `"[PAD]"` - Padding token. Padding tokens are appended to the input sequence length
178
when the input sequence length is shorter than the maximum sequence length.
179
- `"[UNK]"` - Unknown token.
180
"""
181
reserved_tokens = ["[PAD]", "[UNK]"]
182
train_sentences = [element[0] for element in train_ds]
183
vocab = train_word_piece(train_ds, VOCAB_SIZE, reserved_tokens)
184
185
"""
186
Let's see some tokens!
187
"""
188
print("Tokens: ", vocab[100:110])
189
190
"""
191
Now, let's define the tokenizer. We will configure the tokenizer with the
192
the vocabularies trained above. We will define a maximum sequence length so that
193
all sequences are padded to the same length, if the length of the sequence is
194
less than the specified sequence length. Otherwise, the sequence is truncated.
195
"""
196
tokenizer = keras_hub.tokenizers.WordPieceTokenizer(
197
vocabulary=vocab,
198
lowercase=False,
199
sequence_length=MAX_SEQUENCE_LENGTH,
200
)
201
202
"""
203
Let's try and tokenize a sample from our dataset! To verify whether the text has
204
been tokenized correctly, we can also detokenize the list of tokens back to the
205
original text.
206
"""
207
input_sentence_ex = train_ds.take(1).get_single_element()[0][0]
208
input_tokens_ex = tokenizer(input_sentence_ex)
209
210
print("Sentence: ", input_sentence_ex)
211
print("Tokens: ", input_tokens_ex)
212
print("Recovered text after detokenizing: ", tokenizer.detokenize(input_tokens_ex))
213
214
215
"""
216
## Formatting the dataset
217
218
Next, we'll format our datasets in the form that will be fed to the models. We
219
need to tokenize the text.
220
"""
221
222
223
def format_dataset(sentence, label):
224
sentence = tokenizer(sentence)
225
return ({"input_ids": sentence}, label)
226
227
228
def make_dataset(dataset):
229
dataset = dataset.map(format_dataset, num_parallel_calls=tf.data.AUTOTUNE)
230
return dataset.shuffle(512).prefetch(16).cache()
231
232
233
train_ds = make_dataset(train_ds)
234
val_ds = make_dataset(val_ds)
235
test_ds = make_dataset(test_ds)
236
237
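
"""
As a quick sanity check (an illustrative addition, not part of the original
pipeline), let's peek at one formatted batch to confirm that each element is a
feature dictionary with an `"input_ids"` tensor plus a label:
"""

for features, labels in train_ds.take(1):
    print(features["input_ids"].shape)  # (BATCH_SIZE, MAX_SEQUENCE_LENGTH)
    print(labels.shape)  # (BATCH_SIZE,)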
"""
238
## Building the model
239
240
Now, let's move on to the exciting part - defining our model!
241
We first need an embedding layer, i.e., a layer that maps every token in the input
242
sequence to a vector. This embedding layer can be initialised randomly. We also
243
need a positional embedding layer which encodes the word order in the sequence.
244
The convention is to add, i.e., sum, these two embeddings. KerasHub has a
245
`keras_hub.layers.TokenAndPositionEmbedding ` layer which does all of the above
246
steps for us.
247
248
Our FNet classification model consists of three `keras_hub.layers.FNetEncoder`
249
layers with a `keras.layers.Dense` layer on top.
250
251
Note: For FNet, masking the padding tokens has a minimal effect on results. In the
252
official implementation, the padding tokens are not masked.
253
"""

input_ids = keras.Input(shape=(None,), dtype="int64", name="input_ids")

x = keras_hub.layers.TokenAndPositionEmbedding(
    vocabulary_size=VOCAB_SIZE,
    sequence_length=MAX_SEQUENCE_LENGTH,
    embedding_dim=EMBED_DIM,
    mask_zero=True,
)(input_ids)

x = keras_hub.layers.FNetEncoder(intermediate_dim=INTERMEDIATE_DIM)(inputs=x)
x = keras_hub.layers.FNetEncoder(intermediate_dim=INTERMEDIATE_DIM)(inputs=x)
x = keras_hub.layers.FNetEncoder(intermediate_dim=INTERMEDIATE_DIM)(inputs=x)

x = keras.layers.GlobalAveragePooling1D()(x)
x = keras.layers.Dropout(0.1)(x)
outputs = keras.layers.Dense(1, activation="sigmoid")(x)

fnet_classifier = keras.Model(input_ids, outputs, name="fnet_classifier")
"""
276
## Training our model
277
278
We'll use accuracy to monitor training progress on the validation data. Let's
279
train our model for 3 epochs.
280
"""
281
fnet_classifier.summary()
282
fnet_classifier.compile(
283
optimizer=keras.optimizers.Adam(learning_rate=0.001),
284
loss="binary_crossentropy",
285
metrics=["accuracy"],
286
)
287
fnet_classifier.fit(train_ds, epochs=EPOCHS, validation_data=val_ds)
288
289
"""
290
We obtain a train accuracy of around 92% and a validation accuracy of around
291
85%. Moreover, for 3 epochs, it takes around 86 seconds to train the model
292
(on Colab with a 16 GB Tesla T4 GPU).
293
294
Let's calculate the test accuracy.
295
"""
296
fnet_classifier.evaluate(test_ds, batch_size=BATCH_SIZE)
297
298
299
"""
300
## Comparison with Transformer model
301
302
Let's compare our FNet Classifier model with a Transformer Classifier model. We
303
keep all the parameters/hyperparameters the same. For example, we use three
304
`TransformerEncoder` layers.
305
306
We set the number of heads to 2.
307
"""
308
NUM_HEADS = 2
input_ids = keras.Input(shape=(None,), dtype="int64", name="input_ids")

x = keras_hub.layers.TokenAndPositionEmbedding(
    vocabulary_size=VOCAB_SIZE,
    sequence_length=MAX_SEQUENCE_LENGTH,
    embedding_dim=EMBED_DIM,
    mask_zero=True,
)(input_ids)

x = keras_hub.layers.TransformerEncoder(
    intermediate_dim=INTERMEDIATE_DIM, num_heads=NUM_HEADS
)(inputs=x)
x = keras_hub.layers.TransformerEncoder(
    intermediate_dim=INTERMEDIATE_DIM, num_heads=NUM_HEADS
)(inputs=x)
x = keras_hub.layers.TransformerEncoder(
    intermediate_dim=INTERMEDIATE_DIM, num_heads=NUM_HEADS
)(inputs=x)

x = keras.layers.GlobalAveragePooling1D()(x)
x = keras.layers.Dropout(0.1)(x)
outputs = keras.layers.Dense(1, activation="sigmoid")(x)

transformer_classifier = keras.Model(input_ids, outputs, name="transformer_classifier")

transformer_classifier.summary()
transformer_classifier.compile(
    optimizer=keras.optimizers.Adam(learning_rate=0.001),
    loss="binary_crossentropy",
    metrics=["accuracy"],
)
transformer_classifier.fit(train_ds, epochs=EPOCHS, validation_data=val_ds)
"""
346
We obtain a train accuracy of around 94% and a validation accuracy of around
347
86.5%. It takes around 146 seconds to train the model (on Colab with a 16 GB Tesla
348
T4 GPU).
349
350
Let's calculate the test accuracy.
351
"""
352
transformer_classifier.evaluate(test_ds, batch_size=BATCH_SIZE)
353
354
"""
355
Let's make a table and compare the two models. We can see that FNet
356
significantly speeds up our run time (1.7x), with only a small sacrifice in
357
overall accuracy (drop of 0.75%).
358
359
| | **FNet Classifier** | **Transformer Classifier** |
360
|:-----------------------:|:-------------------:|:--------------------------:|
361
| **Training Time** | 86 seconds | 146 seconds |
362
| **Train Accuracy** | 92.34% | 93.85% |
363
| **Validation Accuracy** | 85.21% | 86.42% |
364
| **Test Accuracy** | 83.94% | 84.69% |
365
| **#Params** | 2,321,921 | 2,520,065 |
366
"""
367
368