"""1Title: Text Classification using FNet2Author: [Abheesht Sharma](https://github.com/abheesht17/)3Date created: 2022/06/014Last modified: 2022/12/215Description: Text Classification on the IMDb Dataset using `keras_hub.layers.FNetEncoder` layer.6Accelerator: GPU7"""89"""10## Introduction1112In this example, we will demonstrate the ability of FNet to achieve comparable13results with a vanilla Transformer model on the text classification task.14We will be using the IMDb dataset, which is a15collection of movie reviews labelled either positive or negative (sentiment16analysis).1718To build the tokenizer, model, etc., we will use components from19[KerasHub](https://github.com/keras-team/keras-hub). KerasHub makes life easier20for people who want to build NLP pipelines! :)2122### Model2324Transformer-based language models (LMs) such as BERT, RoBERTa, XLNet, etc. have25demonstrated the effectiveness of the self-attention mechanism for computing26rich embeddings for input text. However, the self-attention mechanism is an27expensive operation, with a time complexity of `O(n^2)`, where `n` is the number28of tokens in the input. Hence, there has been an effort to reduce the time29complexity of the self-attention mechanism and improve performance without30sacrificing the quality of results.3132In 2020, a paper titled33[FNet: Mixing Tokens with Fourier Transforms](https://arxiv.org/abs/2105.03824)34replaced the self-attention layer in BERT with a simple Fourier Transform layer35for "token mixing". This resulted in comparable accuracy and a speed-up during36training. In particular, a couple of points from the paper stand out:3738* The authors claim that FNet is 80% faster than BERT on GPUs and 70% faster on39TPUs. The reason for this speed-up is two-fold: a) the Fourier Transform layer40is unparametrized, it does not have any parameters, and b) the authors use Fast41Fourier Transform (FFT); this reduces the time complexity from `O(n^2)`42(in the case of self-attention) to `O(n log n)`.43* FNet manages to achieve 92-97% of the accuracy of BERT on the GLUE benchmark.44"""4546"""47## Setup4849Before we start with the implementation, let's import all the necessary packages.50"""5152"""shell53pip install -q --upgrade keras-hub54pip install -q --upgrade keras # Upgrade to Keras 3.55"""5657import keras_hub58import keras59import tensorflow as tf60import os6162keras.utils.set_random_seed(42)6364"""65Let's also define our hyperparameters.66"""67BATCH_SIZE = 6468EPOCHS = 369MAX_SEQUENCE_LENGTH = 51270VOCAB_SIZE = 150007172EMBED_DIM = 12873INTERMEDIATE_DIM = 5127475"""76## Loading the dataset7778First, let's download the IMDB dataset and extract it.79"""8081"""shell82wget http://ai.stanford.edu/~amaas/data/sentiment/aclImdb_v1.tar.gz83tar -xzf aclImdb_v1.tar.gz84"""8586"""87Samples are present in the form of text files. Let's inspect the structure of88the directory.89"""9091print(os.listdir("./aclImdb"))92print(os.listdir("./aclImdb/train"))93print(os.listdir("./aclImdb/test"))9495"""96The directory contains two sub-directories: `train` and `test`. Each subdirectory97in turn contains two folders: `pos` and `neg` for positive and negative reviews,98respectively. 

"""
Let's also define our hyperparameters.
"""
BATCH_SIZE = 64
EPOCHS = 3
MAX_SEQUENCE_LENGTH = 512
VOCAB_SIZE = 15000

EMBED_DIM = 128
INTERMEDIATE_DIM = 512

"""
## Loading the dataset

First, let's download the IMDB dataset and extract it.
"""

"""shell
wget http://ai.stanford.edu/~amaas/data/sentiment/aclImdb_v1.tar.gz
tar -xzf aclImdb_v1.tar.gz
"""

"""
Samples are present in the form of text files. Let's inspect the structure of
the directory.
"""

print(os.listdir("./aclImdb"))
print(os.listdir("./aclImdb/train"))
print(os.listdir("./aclImdb/test"))

"""
The directory contains two sub-directories: `train` and `test`. Each subdirectory
in turn contains two folders: `pos` and `neg` for positive and negative reviews,
respectively. Before we load the dataset, let's delete the `./aclImdb/train/unsup`
folder since it has unlabelled samples.
"""

"""shell
rm -rf aclImdb/train/unsup
"""

"""
We'll use the `keras.utils.text_dataset_from_directory` utility to generate
our labelled `tf.data.Dataset` dataset from text files.
"""

train_ds = keras.utils.text_dataset_from_directory(
    "aclImdb/train",
    batch_size=BATCH_SIZE,
    validation_split=0.2,
    subset="training",
    seed=42,
)
val_ds = keras.utils.text_dataset_from_directory(
    "aclImdb/train",
    batch_size=BATCH_SIZE,
    validation_split=0.2,
    subset="validation",
    seed=42,
)
test_ds = keras.utils.text_dataset_from_directory("aclImdb/test", batch_size=BATCH_SIZE)

"""
We will now convert the text to lowercase.
"""
train_ds = train_ds.map(lambda x, y: (tf.strings.lower(x), y))
val_ds = val_ds.map(lambda x, y: (tf.strings.lower(x), y))
test_ds = test_ds.map(lambda x, y: (tf.strings.lower(x), y))

"""
Let's print a few samples.
"""
for text_batch, label_batch in train_ds.take(1):
    for i in range(3):
        print(text_batch.numpy()[i])
        print(label_batch.numpy()[i])


"""
### Tokenizing the data

We'll be using the `keras_hub.tokenizers.WordPieceTokenizer` layer to tokenize
the text. `keras_hub.tokenizers.WordPieceTokenizer` takes a WordPiece vocabulary
and has functions for tokenizing text and detokenizing sequences of tokens.

Before we define the tokenizer, we first need to train it on the dataset
we have. The WordPiece tokenization algorithm is a subword tokenization algorithm;
training it on a corpus gives us a vocabulary of subwords. A subword tokenizer
is a compromise between word tokenizers (word tokenizers need very large
vocabularies for good coverage of input words) and character tokenizers
(characters don't really encode meaning like words do). Luckily, KerasHub
makes it very simple to train WordPiece on a corpus with the
`keras_hub.tokenizers.compute_word_piece_vocabulary` utility.

Note: The official implementation of FNet uses the SentencePiece tokenizer.
"""


def train_word_piece(ds, vocab_size, reserved_tokens):
    word_piece_ds = ds.unbatch().map(lambda x, y: x)
    vocab = keras_hub.tokenizers.compute_word_piece_vocabulary(
        word_piece_ds.batch(1000).prefetch(2),
        vocabulary_size=vocab_size,
        reserved_tokens=reserved_tokens,
    )
    return vocab


"""
Every vocabulary has a few special, reserved tokens. We have two such tokens:

- `"[PAD]"` - Padding token. Padding tokens are appended to the input sequence
when the input sequence is shorter than the maximum sequence length.
- `"[UNK]"` - Unknown token.
"""
reserved_tokens = ["[PAD]", "[UNK]"]
train_sentences = [element[0] for element in train_ds]
vocab = train_word_piece(train_ds, VOCAB_SIZE, reserved_tokens)

"""
Let's see some tokens!
"""
print("Tokens: ", vocab[100:110])

"""
Now, let's define the tokenizer. We will configure the tokenizer with the
vocabulary trained above, and define a maximum sequence length: sequences
shorter than `MAX_SEQUENCE_LENGTH` are padded to that length, and longer
sequences are truncated.
"""
tokenizer = keras_hub.tokenizers.WordPieceTokenizer(
    vocabulary=vocab,
    lowercase=False,
    sequence_length=MAX_SEQUENCE_LENGTH,
)
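
"""
As a quick, optional sanity check of the padding behaviour described above, we
can tokenize a short made-up sentence and confirm that the output has length
`MAX_SEQUENCE_LENGTH` and is mostly filled with the `"[PAD]"` token id (0, since
`"[PAD]"` is the first reserved token in our vocabulary).
"""
padded_example = tokenizer(tf.constant("a short example sentence"))
print("Output shape: ", padded_example.shape)
print(
    "Number of padding tokens: ",
    int(tf.reduce_sum(tf.cast(padded_example == 0, "int32"))),
)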

"""
Let's try and tokenize a sample from our dataset! To verify whether the text has
been tokenized correctly, we can also detokenize the list of tokens back to the
original text.
"""
input_sentence_ex = train_ds.take(1).get_single_element()[0][0]
input_tokens_ex = tokenizer(input_sentence_ex)

print("Sentence: ", input_sentence_ex)
print("Tokens: ", input_tokens_ex)
print("Recovered text after detokenizing: ", tokenizer.detokenize(input_tokens_ex))


"""
## Formatting the dataset

Next, we'll format our datasets in the form that will be fed to the models. We
need to tokenize the text.
"""


def format_dataset(sentence, label):
    sentence = tokenizer(sentence)
    return ({"input_ids": sentence}, label)


def make_dataset(dataset):
    dataset = dataset.map(format_dataset, num_parallel_calls=tf.data.AUTOTUNE)
    return dataset.shuffle(512).prefetch(16).cache()


train_ds = make_dataset(train_ds)
val_ds = make_dataset(val_ds)
test_ds = make_dataset(test_ds)

"""
## Building the model

Now, let's move on to the exciting part - defining our model!
We first need an embedding layer, i.e., a layer that maps every token in the input
sequence to a vector. This embedding layer can be initialised randomly. We also
need a positional embedding layer which encodes the word order in the sequence.
The convention is to add, i.e., sum, these two embeddings. KerasHub has a
`keras_hub.layers.TokenAndPositionEmbedding` layer which does all of the above
steps for us.

Our FNet classification model consists of three `keras_hub.layers.FNetEncoder`
layers with a `keras.layers.Dense` layer on top.

Note: For FNet, masking the padding tokens has a minimal effect on results. In the
official implementation, the padding tokens are not masked.
"""

input_ids = keras.Input(shape=(None,), dtype="int64", name="input_ids")

x = keras_hub.layers.TokenAndPositionEmbedding(
    vocabulary_size=VOCAB_SIZE,
    sequence_length=MAX_SEQUENCE_LENGTH,
    embedding_dim=EMBED_DIM,
    mask_zero=True,
)(input_ids)

x = keras_hub.layers.FNetEncoder(intermediate_dim=INTERMEDIATE_DIM)(inputs=x)
x = keras_hub.layers.FNetEncoder(intermediate_dim=INTERMEDIATE_DIM)(inputs=x)
x = keras_hub.layers.FNetEncoder(intermediate_dim=INTERMEDIATE_DIM)(inputs=x)


x = keras.layers.GlobalAveragePooling1D()(x)
x = keras.layers.Dropout(0.1)(x)
outputs = keras.layers.Dense(1, activation="sigmoid")(x)

fnet_classifier = keras.Model(input_ids, outputs, name="fnet_classifier")

"""
## Training our model

We'll use accuracy to monitor training progress on the validation data. Let's
train our model for 3 epochs.
"""
fnet_classifier.summary()
fnet_classifier.compile(
    optimizer=keras.optimizers.Adam(learning_rate=0.001),
    loss="binary_crossentropy",
    metrics=["accuracy"],
)
fnet_classifier.fit(train_ds, epochs=EPOCHS, validation_data=val_ds)

"""
We obtain a train accuracy of around 92% and a validation accuracy of around
85%. Moreover, for 3 epochs, it takes around 86 seconds to train the model
(on Colab with a 16 GB Tesla T4 GPU).

Let's calculate the test accuracy.
"""
fnet_classifier.evaluate(test_ds, batch_size=BATCH_SIZE)
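
"""
As a quick, informal sanity check on raw text (an illustrative aside), we can
classify a made-up review: we preprocess it the same way as the training data
(lowercase, then tokenize) and read off the model's sigmoid output, which is the
probability of the review being positive.
"""
sample_review = tf.constant(
    ["the movie was a delightful surprise from start to finish!"]
)
sample_tokens = tokenizer(tf.strings.lower(sample_review))
# The model outputs the probability of the positive class.
print(fnet_classifier.predict({"input_ids": sample_tokens}))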

"""
## Comparison with Transformer model

Let's compare our FNet Classifier model with a Transformer Classifier model. We
keep all the parameters/hyperparameters the same. For example, we use three
`TransformerEncoder` layers.

We set the number of heads to 2.
"""
NUM_HEADS = 2
input_ids = keras.Input(shape=(None,), dtype="int64", name="input_ids")


x = keras_hub.layers.TokenAndPositionEmbedding(
    vocabulary_size=VOCAB_SIZE,
    sequence_length=MAX_SEQUENCE_LENGTH,
    embedding_dim=EMBED_DIM,
    mask_zero=True,
)(input_ids)

x = keras_hub.layers.TransformerEncoder(
    intermediate_dim=INTERMEDIATE_DIM, num_heads=NUM_HEADS
)(inputs=x)
x = keras_hub.layers.TransformerEncoder(
    intermediate_dim=INTERMEDIATE_DIM, num_heads=NUM_HEADS
)(inputs=x)
x = keras_hub.layers.TransformerEncoder(
    intermediate_dim=INTERMEDIATE_DIM, num_heads=NUM_HEADS
)(inputs=x)


x = keras.layers.GlobalAveragePooling1D()(x)
x = keras.layers.Dropout(0.1)(x)
outputs = keras.layers.Dense(1, activation="sigmoid")(x)

transformer_classifier = keras.Model(input_ids, outputs, name="transformer_classifier")


transformer_classifier.summary()
transformer_classifier.compile(
    optimizer=keras.optimizers.Adam(learning_rate=0.001),
    loss="binary_crossentropy",
    metrics=["accuracy"],
)
transformer_classifier.fit(train_ds, epochs=EPOCHS, validation_data=val_ds)

"""
We obtain a train accuracy of around 94% and a validation accuracy of around
86.5%. It takes around 146 seconds to train the model (on Colab with a 16 GB Tesla
T4 GPU).

Let's calculate the test accuracy.
"""
transformer_classifier.evaluate(test_ds, batch_size=BATCH_SIZE)

"""
Let's make a table and compare the two models. We can see that FNet
significantly speeds up our run time (1.7x), with only a small sacrifice in
overall accuracy (drop of 0.75%).

|                         | **FNet Classifier** | **Transformer Classifier** |
|:-----------------------:|:-------------------:|:--------------------------:|
| **Training Time**       | 86 seconds          | 146 seconds                |
| **Train Accuracy**      | 92.34%              | 93.85%                     |
| **Validation Accuracy** | 85.21%              | 86.42%                     |
| **Test Accuracy**       | 83.94%              | 84.69%                     |
| **#Params**             | 2,321,921           | 2,520,065                  |
"""
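
"""
If you would like to verify the parameter counts in the table above for your
own run, they can be read off the two models directly:
"""
print("FNet parameters: ", fnet_classifier.count_params())
print("Transformer parameters: ", transformer_classifier.count_params())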