Path: blob/master/examples/nlp/text_classification_with_switch_transformer.py
"""1Title: Text classification with Switch Transformer2Author: [Khalid Salama](https://www.linkedin.com/in/khalid-salama-24403144/)3Date created: 2020/05/104Last modified: 2021/02/155Description: Implement a Switch Transformer for text classification.6Accelerator: GPU7"""89"""10## Introduction1112This example demonstrates the implementation of the13[Switch Transformer](https://arxiv.org/abs/2101.03961) model for text14classification.1516The Switch Transformer replaces the feedforward network (FFN) layer in the standard17Transformer with a Mixture of Expert (MoE) routing layer, where each expert operates18independently on the tokens in the sequence. This allows increasing the model size without19increasing the computation needed to process each example.2021Note that, for training the Switch Transformer efficiently, data and model parallelism22need to be applied, so that expert modules can run simultaneously, each on its own accelerator.23While the implementation described in the paper uses the24[TensorFlow Mesh](https://github.com/tensorflow/mesh) framework for distributed training,25this example presents a simple, non-distributed implementation of the Switch Transformer26model for demonstration purposes.27"""2829"""30## Setup31"""3233import keras34from keras import ops35from keras import layers3637"""38## Download and prepare dataset39"""4041vocab_size = 20000 # Only consider the top 20k words42num_tokens_per_example = 200 # Only consider the first 200 words of each movie review43(x_train, y_train), (x_val, y_val) = keras.datasets.imdb.load_data(num_words=vocab_size)44print(len(x_train), "Training sequences")45print(len(x_val), "Validation sequences")46x_train = keras.utils.pad_sequences(x_train, maxlen=num_tokens_per_example)47x_val = keras.utils.pad_sequences(x_val, maxlen=num_tokens_per_example)4849"""50## Define hyperparameters51"""5253embed_dim = 32 # Embedding size for each token.54num_heads = 2 # Number of attention heads55ff_dim = 32 # Hidden layer size in feedforward network.56num_experts = 10 # Number of experts used in the Switch Transformer.57batch_size = 50 # Batch size.58learning_rate = 0.001 # Learning rate.59dropout_rate = 0.25 # Dropout rate.60num_epochs = 3 # Number of epochs.61num_tokens_per_batch = (62batch_size * num_tokens_per_example63) # Total number of tokens per batch.64print(f"Number of tokens per batch: {num_tokens_per_batch}")6566"""67## Implement token & position embedding layer6869It consists of two separate embedding layers, one for tokens, one for token index (positions).70"""717273class TokenAndPositionEmbedding(layers.Layer):74def __init__(self, maxlen, vocab_size, embed_dim):75super().__init__()76self.token_emb = layers.Embedding(input_dim=vocab_size, output_dim=embed_dim)77self.pos_emb = layers.Embedding(input_dim=maxlen, output_dim=embed_dim)7879def call(self, x):80maxlen = ops.shape(x)[-1]81positions = ops.arange(start=0, stop=maxlen, step=1)82positions = self.pos_emb(positions)83x = self.token_emb(x)84return x + positions858687"""88## Implement the feedforward network8990This is used as the Mixture of Experts in the Switch Transformer.91"""929394def create_feedforward_network(ff_dim, embed_dim, name=None):95return keras.Sequential(96[layers.Dense(ff_dim, activation="relu"), layers.Dense(embed_dim)], name=name97)9899100"""101## Implement the load-balanced loss102103This is an auxiliary loss to encourage a balanced load across experts.104"""105106107def load_balanced_loss(router_probs, expert_mask):108# router_probs [tokens_per_batch, num_experts] is 
"""
### Implement the router as a layer
"""


class Router(layers.Layer):
    def __init__(self, num_experts, expert_capacity):
        super().__init__()
        self.num_experts = num_experts
        self.route = layers.Dense(units=num_experts)
        self.expert_capacity = expert_capacity

    def call(self, inputs, training=False):
        # inputs shape: [tokens_per_batch, embed_dim]
        # router_logits shape: [tokens_per_batch, num_experts]
        router_logits = self.route(inputs)

        if training:
            # Add noise for exploration across experts.
            router_logits += keras.random.uniform(
                shape=router_logits.shape, minval=0.9, maxval=1.1
            )
        # Probabilities for each token of what expert it should be sent to.
        router_probs = keras.activations.softmax(router_logits, axis=-1)
        # Get the top-1 expert for each token. expert_gate is the top-1 probability
        # from the router for each token. expert_index is which expert each token
        # is going to be routed to.
        expert_gate, expert_index = ops.top_k(router_probs, k=1)
        # expert_mask shape: [tokens_per_batch, num_experts]
        expert_mask = ops.one_hot(expert_index, self.num_experts)
        # Compute the load balancing loss.
        aux_loss = load_balanced_loss(router_probs, expert_mask)
        self.add_loss(aux_loss)
        # Experts have a fixed capacity; ensure we do not exceed it. Construct
        # the position of each token within its assigned expert, so that no more
        # than expert_capacity tokens are routed to any single expert.
        position_in_expert = ops.cast(
            ops.cumsum(expert_mask, axis=0) * expert_mask, "int32"
        )
        # Keep only tokens that fit within the expert capacity.
        expert_mask *= ops.cast(
            ops.less(ops.cast(position_in_expert, "int32"), self.expert_capacity),
            "float32",
        )
        expert_mask_flat = ops.sum(expert_mask, axis=-1)
        # Mask out the experts that have overflowed the expert capacity.
        expert_gate *= expert_mask_flat
        # Combine expert assignments and scale them by the router probability.
        # combine_tensor shape: [tokens_per_batch, num_experts, expert_capacity]
        combined_tensor = ops.expand_dims(
            expert_gate
            * expert_mask_flat
            * ops.squeeze(ops.one_hot(expert_index, self.num_experts), 1),
            -1,
        ) * ops.squeeze(ops.one_hot(position_in_expert, self.expert_capacity), 1)
        # dispatch_tensor [tokens_per_batch, num_experts, expert_capacity] is nonzero
        # where a token gets routed to the corresponding expert capacity slot.
        dispatch_tensor = ops.cast(combined_tensor, "float32")

        return dispatch_tensor, combined_tensor

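"""
A quick shape check of the router (not part of the original example). With 8 tokens,
4 experts, and an expert capacity of 2, both returned tensors have shape
`[tokens_per_batch, num_experts, expert_capacity]`. The toy sizes are arbitrary.
"""

demo_router = Router(num_experts=4, expert_capacity=2)
demo_token_embeddings = keras.random.normal((8, embed_dim))  # 8 token embeddings
demo_dispatch, demo_combine = demo_router(demo_token_embeddings)
print(demo_dispatch.shape, demo_combine.shape)  # (8, 4, 2) and (8, 4, 2)
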
"""
### Implement a Switch layer
"""


class Switch(layers.Layer):
    def __init__(
        self, num_experts, embed_dim, ff_dim, num_tokens_per_batch, capacity_factor=1
    ):
        super().__init__()
        self.num_experts = num_experts
        self.embed_dim = embed_dim
        self.experts = [
            create_feedforward_network(ff_dim, embed_dim) for _ in range(num_experts)
        ]

        self.expert_capacity = num_tokens_per_batch // self.num_experts
        self.router = Router(self.num_experts, self.expert_capacity)

    def call(self, inputs):
        batch_size = ops.shape(inputs)[0]
        num_tokens_per_example = ops.shape(inputs)[1]

        # inputs shape: [num_tokens_per_batch, embed_dim]
        inputs = ops.reshape(inputs, [num_tokens_per_batch, self.embed_dim])
        # dispatch_tensor shape: [tokens_per_batch, num_experts, expert_capacity]
        # combine_tensor shape: [tokens_per_batch, num_experts, expert_capacity]
        dispatch_tensor, combine_tensor = self.router(inputs)
        # expert_inputs shape: [num_experts, expert_capacity, embed_dim]
        expert_inputs = ops.einsum("ab,acd->cdb", inputs, dispatch_tensor)
        expert_inputs = ops.reshape(
            expert_inputs, [self.num_experts, self.expert_capacity, self.embed_dim]
        )
        # Dispatch to the experts.
        expert_input_list = ops.unstack(expert_inputs, axis=0)
        expert_output_list = [
            self.experts[idx](expert_input)
            for idx, expert_input in enumerate(expert_input_list)
        ]
        # expert_outputs shape: [expert_capacity, num_experts, embed_dim]
        expert_outputs = ops.stack(expert_output_list, axis=1)
        # expert_outputs_combined shape: [tokens_per_batch, embed_dim]
        expert_outputs_combined = ops.einsum(
            "abc,xba->xc", expert_outputs, combine_tensor
        )
        # output shape: [batch_size, num_tokens_per_example, embed_dim]
        outputs = ops.reshape(
            expert_outputs_combined,
            [batch_size, num_tokens_per_example, self.embed_dim],
        )
        return outputs


"""
## Implement a Transformer block layer
"""


class TransformerBlock(layers.Layer):
    def __init__(self, embed_dim, num_heads, ffn, dropout_rate=0.1):
        super().__init__()
        self.att = layers.MultiHeadAttention(num_heads=num_heads, key_dim=embed_dim)
        # The ffn can be either a standard feedforward network or a Switch
        # layer with a Mixture of Experts.
        self.ffn = ffn
        self.layernorm1 = layers.LayerNormalization(epsilon=1e-6)
        self.layernorm2 = layers.LayerNormalization(epsilon=1e-6)
        self.dropout1 = layers.Dropout(dropout_rate)
        self.dropout2 = layers.Dropout(dropout_rate)

    def call(self, inputs, training=False):
        attn_output = self.att(inputs, inputs)
        attn_output = self.dropout1(attn_output, training=training)
        out1 = self.layernorm1(inputs + attn_output)
        ffn_output = self.ffn(out1)
        ffn_output = self.dropout2(ffn_output, training=training)
        return self.layernorm2(out1 + ffn_output)

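"""
A quick shape check (not part of the original example): the Transformer block is
shape-preserving. Here it is instantiated with a plain feedforward network rather than
the Switch layer, since the Switch layer's reshape assumes a full batch of
`num_tokens_per_batch` tokens.
"""

demo_block = TransformerBlock(
    embed_dim, num_heads, create_feedforward_network(ff_dim, embed_dim)
)
demo_inputs = keras.random.normal((2, num_tokens_per_example, embed_dim))
print(demo_block(demo_inputs).shape)  # (2, num_tokens_per_example, embed_dim)
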
"""
## Implement the classifier

The `TransformerBlock` layer outputs one vector for each time step of our input sequence.
Here, we take the mean across all time steps and use a feedforward network on top
of it to classify text.
"""


def create_classifier():
    switch = Switch(num_experts, embed_dim, ff_dim, num_tokens_per_batch)
    transformer_block = TransformerBlock(embed_dim // num_heads, num_heads, switch)

    inputs = layers.Input(shape=(num_tokens_per_example,))
    embedding_layer = TokenAndPositionEmbedding(
        num_tokens_per_example, vocab_size, embed_dim
    )
    x = embedding_layer(inputs)
    x = transformer_block(x)
    x = layers.GlobalAveragePooling1D()(x)
    x = layers.Dropout(dropout_rate)(x)
    x = layers.Dense(ff_dim, activation="relu")(x)
    x = layers.Dropout(dropout_rate)(x)
    outputs = layers.Dense(2, activation="softmax")(x)

    classifier = keras.Model(inputs=inputs, outputs=outputs)
    return classifier


"""
## Train and evaluate the model
"""


def run_experiment(classifier):
    classifier.compile(
        optimizer=keras.optimizers.Adam(learning_rate),
        loss="sparse_categorical_crossentropy",
        metrics=["accuracy"],
    )
    history = classifier.fit(
        x_train,
        y_train,
        batch_size=batch_size,
        epochs=num_epochs,
        validation_data=(x_val, y_val),
    )
    return history


classifier = create_classifier()
run_experiment(classifier)


"""
## Conclusion

Compared to the standard Transformer architecture, the Switch Transformer can have a much
larger number of parameters, leading to increased model
capacity, while maintaining a reasonable computational cost.
"""
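"""
As a quick follow-up (not part of the original example), the trained classifier can be
used directly for inference on a few held-out reviews; the predicted class index is the
sentiment label (0 = negative, 1 = positive).
"""

probs = classifier.predict(x_val[:5])
print("Predicted labels:", probs.argmax(axis=-1))
print("Actual labels:   ", y_val[:5])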