Text classification with Switch Transformer
Author: Khalid Salama
Date created: 2020/05/10
Last modified: 2021/02/15
Description: Implement a Switch Transformer for text classification.
Introduction
This example demonstrates the implementation of the Switch Transformer model for text classification.
The Switch Transformer replaces the feedforward network (FFN) layer in the standard Transformer with a Mixture of Experts (MoE) routing layer, where each expert operates independently on the tokens in the sequence. This allows increasing the model size without increasing the computation needed to process each example.
Note that, for training the Switch Transformer efficiently, data and model parallelism need to be applied, so that expert modules can run simultaneously, each on its own accelerator. While the implementation described in the paper uses the TensorFlow Mesh framework for distributed training, this example presents a simple, non-distributed implementation of the Switch Transformer model for demonstration purposes.
Setup
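The original setup cell was not preserved; a minimal sketch, assuming TensorFlow 2.x with its bundled Keras:

```python
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
```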
Download and prepare dataset
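The data-loading cell was not preserved. A sketch using the IMDB reviews dataset bundled with Keras (the vocabulary size and sequence length here are illustrative assumptions):

```python
from tensorflow import keras

vocab_size = 20000  # Only consider the top 20k words.
num_tokens_per_example = 200  # Pad/truncate each review to 200 tokens.

# Downloads the IMDB sentiment dataset on first use.
(x_train, y_train), (x_val, y_val) = keras.datasets.imdb.load_data(
    num_words=vocab_size
)
x_train = keras.utils.pad_sequences(x_train, maxlen=num_tokens_per_example)
x_val = keras.utils.pad_sequences(x_val, maxlen=num_tokens_per_example)
```

Padding to a fixed length matters here: the Switch layer routes over a flattened batch of tokens, so every batch must contain the same number of tokens.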
Define hyperparameters
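The hyperparameter cell was not preserved; the values below are plausible small-scale settings, not necessarily the author's exact ones:

```python
num_tokens_per_example = 200  # Tokens per padded example (from the data step).
embed_dim = 32  # Embedding size for each token.
num_heads = 2  # Number of attention heads.
ff_dim = 32  # Hidden size of the feedforward network.
num_experts = 10  # Number of experts in the Switch layer.
batch_size = 50  # Batch size (fixed, so tokens-per-batch is constant).
learning_rate = 0.001
dropout_rate = 0.25
num_epochs = 3
# Total tokens routed per batch; determines each expert's capacity.
num_tokens_per_batch = batch_size * num_tokens_per_example
```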
Implement token & position embedding layer
It consists of two separate embedding layers: one for the tokens and one for the token positions (indices).
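The embedding cell was not preserved; a sketch of the two-embedding layer described above:

```python
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers


class TokenAndPositionEmbedding(layers.Layer):
    def __init__(self, maxlen, vocab_size, embed_dim):
        super().__init__()
        # One embedding table for token ids, one for positions.
        self.token_emb = layers.Embedding(input_dim=vocab_size, output_dim=embed_dim)
        self.pos_emb = layers.Embedding(input_dim=maxlen, output_dim=embed_dim)

    def call(self, x):
        maxlen = tf.shape(x)[-1]
        positions = tf.range(start=0, limit=maxlen, delta=1)
        positions = self.pos_emb(positions)  # [maxlen, embed_dim]
        x = self.token_emb(x)  # [batch, maxlen, embed_dim]
        return x + positions  # Positions broadcast over the batch dimension.
```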
Implement the feedforward network
Instances of this network serve as the experts in the Switch Transformer's Mixture of Experts.
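The FFN cell was not preserved; a sketch of a small two-layer feedforward network usable as an expert:

```python
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers


def create_feedforward_network(ff_dim, name=None):
    # Two dense layers; each expert in the Switch layer is one such network.
    return keras.Sequential(
        [layers.Dense(ff_dim, activation="relu"), layers.Dense(ff_dim)], name=name
    )
```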
Implement the load-balanced loss
This is an auxiliary loss to encourage a balanced load across experts.
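The loss cell was not preserved. A sketch of the auxiliary load-balancing loss from the Switch Transformer paper: it multiplies, per expert, the fraction of tokens actually dispatched to it by the mean router probability assigned to it, so the loss is minimized (value 1, after scaling by the squared number of experts) when both are uniform:

```python
import tensorflow as tf


def load_balanced_loss(router_probs, expert_mask):
    # router_probs: [tokens_per_batch, num_experts], router softmax outputs.
    # expert_mask: [tokens_per_batch, num_experts], one-hot of chosen expert.
    num_experts = tf.shape(expert_mask)[-1]
    # Fraction of tokens dispatched to each expert.
    density = tf.reduce_mean(expert_mask, axis=0)
    # Mean routing probability per expert (a differentiable proxy for density).
    density_proxy = tf.reduce_mean(router_probs, axis=0)
    # Scaled so that a perfectly uniform assignment gives a loss of 1.
    return tf.reduce_mean(density_proxy * density) * tf.cast(
        num_experts**2, tf.float32
    )
```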
Implement the router as a layer
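The router cell was not preserved. The sketch below implements top-1 routing with a fixed per-expert capacity; details such as the exact form of the training-time jitter noise are assumptions, not the author's exact code:

```python
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers


def load_balanced_loss(router_probs, expert_mask):
    # (Repeated from the previous section so this sketch runs standalone.)
    num_experts = tf.shape(expert_mask)[-1]
    density = tf.reduce_mean(expert_mask, axis=0)
    density_proxy = tf.reduce_mean(router_probs, axis=0)
    return tf.reduce_mean(density_proxy * density) * tf.cast(
        num_experts**2, tf.float32
    )


class Router(layers.Layer):
    """Top-1 router: sends each token to its highest-probability expert,
    dropping tokens once an expert's capacity is full."""

    def __init__(self, num_experts, expert_capacity):
        super().__init__()
        self.num_experts = num_experts
        self.expert_capacity = expert_capacity
        self.route = layers.Dense(units=num_experts)

    def call(self, inputs, training=False):
        # inputs: [tokens_per_batch, embed_dim]
        router_logits = self.route(inputs)  # [tokens_per_batch, num_experts]
        if training:
            # Multiplicative jitter encourages exploration across experts.
            router_logits *= tf.random.uniform(
                tf.shape(router_logits), minval=0.9, maxval=1.1
            )
        router_probs = tf.nn.softmax(router_logits, axis=-1)
        # Top-1 expert and its gate value for each token.
        expert_gate = tf.reduce_max(router_probs, axis=-1)  # [tokens]
        expert_index = tf.argmax(router_probs, axis=-1)  # [tokens]
        expert_mask = tf.one_hot(expert_index, depth=self.num_experts)
        # Auxiliary loss pushing the router towards a balanced assignment.
        self.add_loss(load_balanced_loss(router_probs, expert_mask))
        # 1-based position of each token in its expert's buffer (0 if unrouted).
        position_in_expert = tf.cast(
            tf.cumsum(expert_mask, axis=0) * expert_mask, tf.int32
        )
        # Drop tokens that overflow the expert's capacity.
        expert_mask *= tf.cast(
            position_in_expert <= self.expert_capacity, tf.float32
        )
        expert_gate *= tf.reduce_sum(expert_mask, axis=-1)
        # dispatch_tensor[t, e, c] == 1 iff token t occupies slot c of expert e;
        # out-of-range positions (dropped or unrouted tokens) one-hot to zeros.
        dispatch_tensor = tf.one_hot(
            position_in_expert - 1, depth=self.expert_capacity
        )  # [tokens, num_experts, expert_capacity]
        # combine_tensor additionally scales each routed token by its gate.
        combine_tensor = dispatch_tensor * expert_gate[:, None, None]
        return dispatch_tensor, combine_tensor
```

The dispatch tensor scatters tokens into per-expert buffers; the combine tensor gathers the expert outputs back, weighted by the router probability so gradients flow through the routing decision.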
Implement a Switch layer
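The Switch-layer cell was not preserved. A sketch of the layer: flatten the batch to a list of tokens, route each token to one expert, run the experts on their buffers, and scatter the results back. For brevity the experts' hidden and output sizes both use `embed_dim`, and condensed copies of the earlier helpers are repeated so the sketch runs standalone (the full `Router` also adds the balancing loss and jitter):

```python
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers


def create_feedforward_network(ff_dim, name=None):
    # (Condensed repeat of the FFN from the section above.)
    return keras.Sequential(
        [layers.Dense(ff_dim, activation="relu"), layers.Dense(ff_dim)], name=name
    )


class Router(layers.Layer):
    # (Condensed repeat of the router from the section above.)
    def __init__(self, num_experts, expert_capacity):
        super().__init__()
        self.num_experts = num_experts
        self.expert_capacity = expert_capacity
        self.route = layers.Dense(units=num_experts)

    def call(self, inputs, training=False):
        router_probs = tf.nn.softmax(self.route(inputs), axis=-1)
        expert_gate = tf.reduce_max(router_probs, axis=-1)
        expert_index = tf.argmax(router_probs, axis=-1)
        expert_mask = tf.one_hot(expert_index, depth=self.num_experts)
        position_in_expert = tf.cast(
            tf.cumsum(expert_mask, axis=0) * expert_mask, tf.int32
        )
        expert_mask *= tf.cast(
            position_in_expert <= self.expert_capacity, tf.float32
        )
        expert_gate *= tf.reduce_sum(expert_mask, axis=-1)
        dispatch = tf.one_hot(position_in_expert - 1, depth=self.expert_capacity)
        return dispatch, dispatch * expert_gate[:, None, None]


class Switch(layers.Layer):
    """Replaces a dense FFN with `num_experts` expert FFNs plus a top-1 router."""

    def __init__(self, num_experts, embed_dim, num_tokens_per_batch):
        super().__init__()
        self.num_experts = num_experts
        self.embed_dim = embed_dim
        # Give each expert an equal share of the tokens in a batch.
        self.expert_capacity = num_tokens_per_batch // num_experts
        self.experts = [
            create_feedforward_network(embed_dim) for _ in range(num_experts)
        ]
        self.router = Router(num_experts, self.expert_capacity)

    def call(self, inputs):
        batch_size = tf.shape(inputs)[0]
        seq_len = tf.shape(inputs)[1]
        # Flatten the batch: routing operates on tokens, not examples.
        tokens = tf.reshape(inputs, [-1, self.embed_dim])
        # dispatch/combine: [tokens, num_experts, expert_capacity]
        dispatch_tensor, combine_tensor = self.router(tokens)
        # Gather each expert's input buffer: [num_experts, capacity, embed_dim].
        expert_inputs = tf.einsum("tec,td->ecd", dispatch_tensor, tokens)
        expert_outputs = tf.stack(
            [expert(expert_inputs[i]) for i, expert in enumerate(self.experts)],
            axis=0,
        )
        # Scatter expert outputs back to token order, scaled by the gates.
        outputs = tf.einsum("tec,ecd->td", combine_tensor, expert_outputs)
        return tf.reshape(outputs, [batch_size, seq_len, self.embed_dim])
```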
Implement a Transformer block layer
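The Transformer-block cell was not preserved; a sketch of a standard pre-built block where the FFN sub-layer is pluggable, so either a dense FFN or the Switch layer above can be used:

```python
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers


class TransformerBlock(layers.Layer):
    """Self-attention followed by an FFN; `ffn` may be a dense feedforward
    network or a Switch (Mixture of Experts) layer."""

    def __init__(self, embed_dim, num_heads, ffn, dropout_rate=0.1):
        super().__init__()
        self.att = layers.MultiHeadAttention(num_heads=num_heads, key_dim=embed_dim)
        self.ffn = ffn
        self.layernorm1 = layers.LayerNormalization(epsilon=1e-6)
        self.layernorm2 = layers.LayerNormalization(epsilon=1e-6)
        self.dropout1 = layers.Dropout(dropout_rate)
        self.dropout2 = layers.Dropout(dropout_rate)

    def call(self, inputs, training=False):
        # Self-attention with a residual connection and layer norm.
        attn_output = self.att(inputs, inputs)
        attn_output = self.dropout1(attn_output, training=training)
        out1 = self.layernorm1(inputs + attn_output)
        # FFN (or Switch layer) with a second residual connection.
        ffn_output = self.ffn(out1)
        ffn_output = self.dropout2(ffn_output, training=training)
        return self.layernorm2(out1 + ffn_output)
```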
Implement the classifier
The TransformerBlock layer outputs one vector for each time step of our input sequence. Here, we take the mean across all time steps and use a feedforward network on top of it to classify text.
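The classifier cell was not preserved. The sketch below is a hypothetical helper that wires an embedding layer and a transformer block (such as the ones defined above) into a binary classifier with the mean-pooling head described; the layer arguments and hidden sizes are assumptions:

```python
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers


def create_classifier(
    embedding_layer, transformer_block, num_tokens_per_example,
    ff_dim=32, dropout_rate=0.25,
):
    inputs = layers.Input(shape=(num_tokens_per_example,))
    x = embedding_layer(inputs)
    x = transformer_block(x)
    # Mean over all time steps...
    x = layers.GlobalAveragePooling1D()(x)
    x = layers.Dropout(dropout_rate)(x)
    # ...then a small feedforward network on top for classification.
    x = layers.Dense(ff_dim, activation="relu")(x)
    x = layers.Dropout(dropout_rate)(x)
    outputs = layers.Dense(2, activation="softmax")(x)
    return keras.Model(inputs=inputs, outputs=outputs)
```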
Train and evaluate the model
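The training cell was not preserved. A hypothetical `run_experiment` helper sketching the compile/fit loop; when a Switch layer is in the model, the router's auxiliary balancing loss is collected by Keras automatically via `add_loss`:

```python
from tensorflow import keras


def run_experiment(
    model, x_train, y_train, x_val, y_val,
    learning_rate=0.001, batch_size=50, num_epochs=3,
):
    model.compile(
        optimizer=keras.optimizers.Adam(learning_rate),
        loss="sparse_categorical_crossentropy",
        metrics=["accuracy"],
    )
    return model.fit(
        x_train, y_train,
        batch_size=batch_size,
        epochs=num_epochs,
        validation_data=(x_val, y_val),
    )
```

Note that with a Switch layer inside, the batch size should stay fixed (and divide the training set evenly), since each expert's capacity was derived from the number of tokens per batch.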
Conclusion
Compared to the standard Transformer architecture, the Switch Transformer can have a much larger number of parameters, leading to increased model capacity, while maintaining a reasonable computational cost.