"""
Title: Text classification with Switch Transformer
Author: [Khalid Salama](https://www.linkedin.com/in/khalid-salama-24403144/)
Date created: 2020/05/10
Last modified: 2021/02/15
Description: Implement a Switch Transformer for text classification.
Accelerator: GPU
"""

"""
## Introduction

This example demonstrates the implementation of the
[Switch Transformer](https://arxiv.org/abs/2101.03961) model for text
classification.

The Switch Transformer replaces the feedforward network (FFN) layer in the standard
Transformer with a Mixture of Experts (MoE) routing layer, where each expert operates
independently on the tokens in the sequence. This allows increasing the model size without
increasing the computation needed to process each example.

Note that, for training the Switch Transformer efficiently, data and model parallelism
need to be applied, so that expert modules can run simultaneously, each on its own accelerator.
While the implementation described in the paper uses the
[TensorFlow Mesh](https://github.com/tensorflow/mesh) framework for distributed training,
this example presents a simple, non-distributed implementation of the Switch Transformer
model for demonstration purposes.
"""

"""
## Setup
"""

import keras
from keras import ops
from keras import layers

"""
39
## Download and prepare dataset
40
"""
41
42
vocab_size = 20000 # Only consider the top 20k words
43
num_tokens_per_example = 200 # Only consider the first 200 words of each movie review
44
(x_train, y_train), (x_val, y_val) = keras.datasets.imdb.load_data(num_words=vocab_size)
45
print(len(x_train), "Training sequences")
46
print(len(x_val), "Validation sequences")
47
x_train = keras.utils.pad_sequences(x_train, maxlen=num_tokens_per_example)
48
x_val = keras.utils.pad_sequences(x_val, maxlen=num_tokens_per_example)
49
50
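"""
The dataset comes pre-tokenized as sequences of integer word indices. As a quick sanity
check (not part of the original example), a review can be decoded back to text with
`keras.datasets.imdb.get_word_index()`. Note that `load_data` offsets the indices by 3,
reserving 0, 1, and 2 for padding, start-of-sequence, and out-of-vocabulary markers.
"""

# Optional sanity check: decode the first training review back into words.
# The offset of 3 matches the default `index_from` argument of `load_data`;
# padding and special tokens show up as "[UNK]".
word_index = keras.datasets.imdb.get_word_index()
reverse_index = {index + 3: word for word, index in word_index.items()}
decoded_review = " ".join(reverse_index.get(int(i), "[UNK]") for i in x_train[0])
print(decoded_review[:200])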
"""
51
## Define hyperparameters
52
"""
53
54
embed_dim = 32 # Embedding size for each token.
55
num_heads = 2 # Number of attention heads
56
ff_dim = 32 # Hidden layer size in feedforward network.
57
num_experts = 10 # Number of experts used in the Switch Transformer.
58
batch_size = 50 # Batch size.
59
learning_rate = 0.001 # Learning rate.
60
dropout_rate = 0.25 # Dropout rate.
61
num_epochs = 3 # Number of epochs.
62
num_tokens_per_batch = (
63
batch_size * num_tokens_per_example
64
) # Total number of tokens per batch.
65
print(f"Number of tokens per batch: {num_tokens_per_batch}")
66
67
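"""
The Switch layer defined later derives a fixed per-expert token budget from these
hyperparameters. The following lines are illustrative only (not part of the original
example) and mirror the computation done inside the `Switch` layer below.
"""

# With capacity_factor=1, each expert can process at most this many tokens per batch.
expert_capacity_preview = num_tokens_per_batch // num_experts
print(f"Tokens per expert per batch: {expert_capacity_preview}")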
"""
68
## Implement token & position embedding layer
69
70
It consists of two separate embedding layers, one for tokens, one for token index (positions).
71
"""
72
73
74
class TokenAndPositionEmbedding(layers.Layer):
75
def __init__(self, maxlen, vocab_size, embed_dim):
76
super().__init__()
77
self.token_emb = layers.Embedding(input_dim=vocab_size, output_dim=embed_dim)
78
self.pos_emb = layers.Embedding(input_dim=maxlen, output_dim=embed_dim)
79
80
def call(self, x):
81
maxlen = ops.shape(x)[-1]
82
positions = ops.arange(start=0, stop=maxlen, step=1)
83
positions = self.pos_emb(positions)
84
x = self.token_emb(x)
85
return x + positions
86
87
88
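"""
A quick shape check (not part of the original example): embedding a couple of padded
reviews yields one `embed_dim`-sized vector per token.
"""

demo_embedding = TokenAndPositionEmbedding(num_tokens_per_example, vocab_size, embed_dim)
print(ops.shape(demo_embedding(x_train[:2])))  # (2, 200, 32)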
"""
89
## Implement the feedforward network
90
91
This is used as the Mixture of Experts in the Switch Transformer.
92
"""
93
94
95
def create_feedforward_network(ff_dim, embed_dim, name=None):
96
return keras.Sequential(
97
[layers.Dense(ff_dim, activation="relu"), layers.Dense(embed_dim)], name=name
98
)
99
100
101
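"""
Illustrative check (not part of the original example): a single expert maps
`embed_dim`-sized token vectors back to `embed_dim`, through a hidden layer of size
`ff_dim`.
"""

demo_expert = create_feedforward_network(ff_dim, embed_dim, name="demo_expert")
print(ops.shape(demo_expert(keras.random.normal((8, embed_dim)))))  # (8, 32)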
"""
102
## Implement the load-balanced loss
103
104
This is an auxiliary loss to encourage a balanced load across experts.
105
"""
106
107
108
def load_balanced_loss(router_probs, expert_mask):
109
# router_probs [tokens_per_batch, num_experts] is the probability assigned for
110
# each expert per token. expert_mask [tokens_per_batch, num_experts] contains
111
# the expert with the highest router probability in one−hot format.
112
113
num_experts = ops.shape(expert_mask)[-1]
114
# Get the fraction of tokens routed to each expert.
115
# density is a vector of length num experts that sums to 1.
116
density = ops.mean(expert_mask, axis=0)
117
# Get fraction of probability mass assigned to each expert from the router
118
# across all tokens. density_proxy is a vector of length num experts that sums to 1.
119
density_proxy = ops.mean(router_probs, axis=0)
120
# Want both vectors to have uniform allocation (1/num experts) across all
121
# num_expert elements. The two vectors will be pushed towards uniform allocation
122
# when the dot product is minimized.
123
loss = ops.mean(density_proxy * density) * ops.cast((num_experts**2), "float32")
124
return loss
125
126
127
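"""
As a quick numerical check (not part of the original example), perfectly uniform routing
gives a loss of exactly 1.0: with `E` experts, both `density` and `density_proxy` are
`1 / E` everywhere, so the loss is `mean(1 / E^2) * E^2 = 1`.
"""

# Eight tokens routed evenly across four experts, with a uniform router distribution.
toy_router_probs = ops.ones((8, 4)) / 4.0
toy_expert_mask = ops.one_hot(ops.convert_to_tensor([0, 1, 2, 3, 0, 1, 2, 3]), 4)
print(load_balanced_loss(toy_router_probs, toy_expert_mask))  # 1.0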
"""
128
### Implement the router as a layer
129
"""
130
131
132
class Router(layers.Layer):
133
def __init__(self, num_experts, expert_capacity):
134
self.num_experts = num_experts
135
self.route = layers.Dense(units=num_experts)
136
self.expert_capacity = expert_capacity
137
super().__init__()
138
139
def call(self, inputs, training=False):
140
# inputs shape: [tokens_per_batch, embed_dim]
141
# router_logits shape: [tokens_per_batch, num_experts]
142
router_logits = self.route(inputs)
143
144
if training:
145
# Add noise for exploration across experts.
146
router_logits += keras.random.uniform(
147
shape=router_logits.shape, minval=0.9, maxval=1.1
148
)
149
# Probabilities for each token of what expert it should be sent to.
150
router_probs = keras.activations.softmax(router_logits, axis=-1)
151
# Get the top−1 expert for each token. expert_gate is the top−1 probability
152
# from the router for each token. expert_index is what expert each token
153
# is going to be routed to.
154
expert_gate, expert_index = ops.top_k(router_probs, k=1)
155
# expert_mask shape: [tokens_per_batch, num_experts]
156
expert_mask = ops.one_hot(expert_index, self.num_experts)
157
# Compute load balancing loss.
158
aux_loss = load_balanced_loss(router_probs, expert_mask)
159
self.add_loss(aux_loss)
160
# Experts have a fixed capacity, ensure we do not exceed it. Construct
161
# the batch indices, to each expert, with position in expert make sure that
162
# not more that expert capacity examples can be routed to each expert.
163
position_in_expert = ops.cast(
164
ops.cumsum(expert_mask, axis=0) * expert_mask, "int32"
165
)
166
# Keep only tokens that fit within expert capacity.
167
expert_mask *= ops.cast(
168
ops.less(ops.cast(position_in_expert, "int32"), self.expert_capacity),
169
"float32",
170
)
171
expert_mask_flat = ops.sum(expert_mask, axis=-1)
172
# Mask out the experts that have overflowed the expert capacity.
173
expert_gate *= expert_mask_flat
174
# Combine expert outputs and scaling with router probability.
175
# combine_tensor shape: [tokens_per_batch, num_experts, expert_capacity]
176
combined_tensor = ops.expand_dims(
177
expert_gate
178
* expert_mask_flat
179
* ops.squeeze(ops.one_hot(expert_index, self.num_experts), 1),
180
-1,
181
) * ops.squeeze(ops.one_hot(position_in_expert, self.expert_capacity), 1)
182
# Create binary dispatch_tensor [tokens_per_batch, num_experts, expert_capacity]
183
# that is 1 if the token gets routed to the corresponding expert.
184
dispatch_tensor = ops.cast(combined_tensor, "float32")
185
186
return dispatch_tensor, combined_tensor
187
188
189
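"""
To make the routing step above concrete, here is a tiny illustration (not part of the
original example) of top-1 expert selection for three tokens and two experts.
"""

toy_probs = ops.convert_to_tensor([[0.9, 0.1], [0.4, 0.6], [0.7, 0.3]])
toy_gate, toy_index = ops.top_k(toy_probs, k=1)
print(toy_gate)  # probability of the selected expert per token: [[0.9], [0.6], [0.7]]
print(ops.squeeze(ops.one_hot(toy_index, 2), 1))  # one-hot expert assignment per token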
"""
190
### Implement a Switch layer
191
"""
192
193
194
class Switch(layers.Layer):
195
def __init__(
196
self, num_experts, embed_dim, ff_dim, num_tokens_per_batch, capacity_factor=1
197
):
198
self.num_experts = num_experts
199
self.embed_dim = embed_dim
200
self.experts = [
201
create_feedforward_network(ff_dim, embed_dim) for _ in range(num_experts)
202
]
203
204
self.expert_capacity = num_tokens_per_batch // self.num_experts
205
self.router = Router(self.num_experts, self.expert_capacity)
206
super().__init__()
207
208
def call(self, inputs):
209
batch_size = ops.shape(inputs)[0]
210
num_tokens_per_example = ops.shape(inputs)[1]
211
212
# inputs shape: [num_tokens_per_batch, embed_dim]
213
inputs = ops.reshape(inputs, [num_tokens_per_batch, self.embed_dim])
214
# dispatch_tensor shape: [expert_capacity, num_experts, tokens_per_batch]
215
# combine_tensor shape: [tokens_per_batch, num_experts, expert_capacity]
216
dispatch_tensor, combine_tensor = self.router(inputs)
217
# expert_inputs shape: [num_experts, expert_capacity, embed_dim]
218
expert_inputs = ops.einsum("ab,acd->cdb", inputs, dispatch_tensor)
219
expert_inputs = ops.reshape(
220
expert_inputs, [self.num_experts, self.expert_capacity, self.embed_dim]
221
)
222
# Dispatch to experts
223
expert_input_list = ops.unstack(expert_inputs, axis=0)
224
expert_output_list = [
225
self.experts[idx](expert_input)
226
for idx, expert_input in enumerate(expert_input_list)
227
]
228
# expert_outputs shape: [expert_capacity, num_experts, embed_dim]
229
expert_outputs = ops.stack(expert_output_list, axis=1)
230
# expert_outputs_combined shape: [tokens_per_batch, embed_dim]
231
expert_outputs_combined = ops.einsum(
232
"abc,xba->xc", expert_outputs, combine_tensor
233
)
234
# output shape: [batch_size, num_tokens_per_example, embed_dim]
235
outputs = ops.reshape(
236
expert_outputs_combined,
237
[batch_size, num_tokens_per_example, self.embed_dim],
238
)
239
return outputs
240
241
242
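"""
Quick shape check (not part of the original example): the Switch layer returns the same
shape it receives. Note that, because of the reshape above, it expects exactly
`num_tokens_per_batch` tokens per call, i.e. a full batch of `batch_size` sequences.
"""

demo_switch = Switch(num_experts, embed_dim, ff_dim, num_tokens_per_batch)
demo_inputs = keras.random.normal((batch_size, num_tokens_per_example, embed_dim))
print(ops.shape(demo_switch(demo_inputs)))  # (50, 200, 32)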
"""
243
## Implement a Transformer block layer
244
"""
245
246
247
class TransformerBlock(layers.Layer):
248
def __init__(self, embed_dim, num_heads, ffn, dropout_rate=0.1):
249
super().__init__()
250
self.att = layers.MultiHeadAttention(num_heads=num_heads, key_dim=embed_dim)
251
# The ffn can be either a standard feedforward network or a switch
252
# layer with a Mixture of Experts.
253
self.ffn = ffn
254
self.layernorm1 = layers.LayerNormalization(epsilon=1e-6)
255
self.layernorm2 = layers.LayerNormalization(epsilon=1e-6)
256
self.dropout1 = layers.Dropout(dropout_rate)
257
self.dropout2 = layers.Dropout(dropout_rate)
258
259
def call(self, inputs, training=False):
260
attn_output = self.att(inputs, inputs)
261
attn_output = self.dropout1(attn_output, training=training)
262
out1 = self.layernorm1(inputs + attn_output)
263
ffn_output = self.ffn(out1)
264
ffn_output = self.dropout2(ffn_output, training=training)
265
return self.layernorm2(out1 + ffn_output)
266
267
268
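"""
Shape check (not part of the original example): the block preserves the input shape, and
the same class works with either FFN flavour. Here it is built with a plain feedforward
network rather than the Switch layer, so any batch size is accepted.
"""

demo_block = TransformerBlock(embed_dim, num_heads, create_feedforward_network(ff_dim, embed_dim))
print(ops.shape(demo_block(keras.random.normal((2, num_tokens_per_example, embed_dim)))))  # (2, 200, 32)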
"""
269
## Implement the classifier
270
271
The `TransformerBlock` layer outputs one vector for each time step of our input sequence.
272
Here, we take the mean across all time steps and use a feedforward network on top
273
of it to classify text.
274
"""
275
276
277
def create_classifier():
278
switch = Switch(num_experts, embed_dim, ff_dim, num_tokens_per_batch)
279
transformer_block = TransformerBlock(embed_dim // num_heads, num_heads, switch)
280
281
inputs = layers.Input(shape=(num_tokens_per_example,))
282
embedding_layer = TokenAndPositionEmbedding(
283
num_tokens_per_example, vocab_size, embed_dim
284
)
285
x = embedding_layer(inputs)
286
x = transformer_block(x)
287
x = layers.GlobalAveragePooling1D()(x)
288
x = layers.Dropout(dropout_rate)(x)
289
x = layers.Dense(ff_dim, activation="relu")(x)
290
x = layers.Dropout(dropout_rate)(x)
291
outputs = layers.Dense(2, activation="softmax")(x)
292
293
classifier = keras.Model(inputs=inputs, outputs=outputs)
294
return classifier
295
296
297
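"""
As a quick check of the pooling step described above (not part of the original example),
`GlobalAveragePooling1D` is simply the mean over the time-step axis:
"""

toy_sequence = keras.random.normal((2, 4, 3))
pooled = layers.GlobalAveragePooling1D()(toy_sequence)
print(ops.max(ops.abs(pooled - ops.mean(toy_sequence, axis=1))))  # ~0.0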
"""
298
## Train and evaluate the model
299
"""
300
301
302
def run_experiment(classifier):
303
classifier.compile(
304
optimizer=keras.optimizers.Adam(learning_rate),
305
loss="sparse_categorical_crossentropy",
306
metrics=["accuracy"],
307
)
308
history = classifier.fit(
309
x_train,
310
y_train,
311
batch_size=batch_size,
312
epochs=num_epochs,
313
validation_data=(x_val, y_val),
314
)
315
return history
316
317
318
classifier = create_classifier()
319
run_experiment(classifier)
320
321
322
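"""
After training, the classifier can be used for inference on sequences padded to the same
length. A minimal sketch (not part of the original example); note that predictions must be
made with `batch_size=50`, since the Switch layer expects exactly `num_tokens_per_batch`
tokens per batch.
"""

# Predict sentiment for the first batch of validation reviews.
# The softmax output has two entries: index 0 = negative, index 1 = positive.
predictions = classifier.predict(x_val[:batch_size], batch_size=batch_size)
for probs, label in zip(predictions[:3], y_val[:3]):
    print(f"predicted={int(probs.argmax())}, actual={int(label)}, p(positive)={probs[1]:.3f}")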
"""
323
## Conclusion
324
325
Compared to the standard Transformer architecture, the Switch Transformer can have a much
326
larger number of parameters, leading to increased model
327
capacity, while maintaining a reasonable computational cost.
328
"""
329
330
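"""
A rough illustration of that trade-off (not part of the original example): the Switch layer
stores `num_experts` expert FFNs, so its FFN parameter count is roughly `num_experts` times
that of a single dense FFN, yet each token is processed by only one expert per forward pass.
"""

dense_ffn = create_feedforward_network(ff_dim, embed_dim, name="dense_ffn")
dense_ffn(ops.zeros((1, embed_dim)))  # build the layers so parameters are created
print(f"Parameters in one dense FFN: {dense_ffn.count_params()}")
print(f"Parameters across {num_experts} Switch experts: {num_experts * dense_ffn.count_params()} (router excluded)")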