"""
2
Title: English-to-Spanish translation with a sequence-to-sequence Transformer
3
Author: [fchollet](https://twitter.com/fchollet)
4
Date created: 2021/05/26
5
Last modified: 2024/11/18
6
Description: Implementing a sequence-to-sequence Transformer and training it on a machine translation task.
7
Accelerator: GPU
8
"""
9
"""
## Introduction

In this example, we'll build a sequence-to-sequence Transformer model, which
we'll train on an English-to-Spanish machine translation task.

You'll learn how to:

- Vectorize text using the Keras `TextVectorization` layer.
- Implement a `TransformerEncoder` layer, a `TransformerDecoder` layer,
  and a `PositionalEmbedding` layer.
- Prepare data for training a sequence-to-sequence model.
- Use the trained model to generate translations of never-seen-before
  input sentences (sequence-to-sequence inference).

The code featured here is adapted from the book
[Deep Learning with Python, Second Edition](https://www.manning.com/books/deep-learning-with-python-second-edition)
(chapter 11: Deep learning for text).
The present example is fairly barebones, so for detailed explanations of
how each building block works, as well as the theory behind Transformers,
I recommend reading the book.
"""

"""
## Setup
"""

# We set the backend to TensorFlow. The code works with
# both `tensorflow` and `torch`. It does not work with JAX
# due to the behavior of `jax.numpy.tile` in a jit scope
# (used in `TransformerDecoder.get_causal_attention_mask()`):
# `tile` in JAX does not support a dynamic `reps` argument.
# You can make the code work in JAX by wrapping the
# inside of the `get_causal_attention_mask` method in
# a context manager that prevents jit compilation:
# `with jax.ensure_compile_time_eval():`.
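# For reference, a minimal sketch of that JAX-only workaround (illustrative;
# `jax` is not imported in this script, and this is only relevant when
# `KERAS_BACKEND="jax"`):
#
#     def get_causal_attention_mask(self, inputs):
#         with jax.ensure_compile_time_eval():
#             ...  # same body as in the version defined below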
import os

os.environ["KERAS_BACKEND"] = "tensorflow"

import pathlib
import random
import string
import re
import numpy as np

import tensorflow.data as tf_data
import tensorflow.strings as tf_strings

import keras
from keras import layers
from keras import ops
from keras.layers import TextVectorization

"""
64
## Downloading the data
65
66
We'll be working with an English-to-Spanish translation dataset
67
provided by [Anki](https://www.manythings.org/anki/). Let's download it:
68
"""
69
70
text_file = keras.utils.get_file(
71
fname="spa-eng.zip",
72
origin="http://storage.googleapis.com/download.tensorflow.org/data/spa-eng.zip",
73
extract=True,
74
)
75
text_file = pathlib.Path(text_file).parent / "spa-eng" / "spa.txt"
76
"""
## Parsing the data

Each line contains an English sentence and its corresponding Spanish sentence.
The English sentence is the *source sequence* and the Spanish one is the *target sequence*.
We prepend the token `"[start]"` and append the token `"[end]"` to the Spanish sentence.
"""

with open(text_file) as f:
    lines = f.read().split("\n")[:-1]
text_pairs = []
for line in lines:
    eng, spa = line.split("\t")
    spa = "[start] " + spa + " [end]"
    text_pairs.append((eng, spa))

"""
Here's what our sentence pairs look like:
"""

for _ in range(5):
    print(random.choice(text_pairs))

"""
Now, let's split the sentence pairs into a training set, a validation set,
and a test set.
"""

random.shuffle(text_pairs)
num_val_samples = int(0.15 * len(text_pairs))
num_train_samples = len(text_pairs) - 2 * num_val_samples
train_pairs = text_pairs[:num_train_samples]
val_pairs = text_pairs[num_train_samples : num_train_samples + num_val_samples]
test_pairs = text_pairs[num_train_samples + num_val_samples :]

print(f"{len(text_pairs)} total pairs")
print(f"{len(train_pairs)} training pairs")
print(f"{len(val_pairs)} validation pairs")
print(f"{len(test_pairs)} test pairs")

"""
118
## Vectorizing the text data
119
120
We'll use two instances of the `TextVectorization` layer to vectorize the text
121
data (one for English and one for Spanish),
122
that is to say, to turn the original strings into integer sequences
123
where each integer represents the index of a word in a vocabulary.
124
125
The English layer will use the default string standardization (strip punctuation characters)
126
and splitting scheme (split on whitespace), while
127
the Spanish layer will use a custom standardization, where we add the character
128
`"¿"` to the set of punctuation characters to be stripped.
129
130
Note: in a production-grade machine translation model, I would not recommend
131
stripping the punctuation characters in either language. Instead, I would recommend turning
132
each punctuation character into its own token,
133
which you could achieve by providing a custom `split` function to the `TextVectorization` layer.
134
"""
strip_chars = string.punctuation + "¿"
strip_chars = strip_chars.replace("[", "")
strip_chars = strip_chars.replace("]", "")

vocab_size = 15000
sequence_length = 20
batch_size = 64


def custom_standardization(input_string):
    lowercase = tf_strings.lower(input_string)
    return tf_strings.regex_replace(lowercase, "[%s]" % re.escape(strip_chars), "")


eng_vectorization = TextVectorization(
    max_tokens=vocab_size,
    output_mode="int",
    output_sequence_length=sequence_length,
)
spa_vectorization = TextVectorization(
    max_tokens=vocab_size,
    output_mode="int",
    # The Spanish sequences are one token longer: they will be offset by one step
    # to build the decoder inputs and the targets.
    output_sequence_length=sequence_length + 1,
    standardize=custom_standardization,
)
train_eng_texts = [pair[0] for pair in train_pairs]
train_spa_texts = [pair[1] for pair in train_pairs]
eng_vectorization.adapt(train_eng_texts)
spa_vectorization.adapt(train_spa_texts)
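
"""
As a quick sanity check (illustrative only, not part of the pipeline), we can peek at the first
few entries of each learned vocabulary; index 0 is reserved for padding and index 1 for
out-of-vocabulary tokens:
"""

print(eng_vectorization.get_vocabulary()[:10])
print(spa_vectorization.get_vocabulary()[:10])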

"""
Next, we'll format our datasets.

At each training step, the model will seek to predict target words N+1 (and beyond)
using the source sentence and the target words 0 to N.

As such, the training dataset will yield a tuple `(inputs, targets)`, where:

- `inputs` is a dictionary with the keys `encoder_inputs` and `decoder_inputs`.
  `encoder_inputs` is the vectorized source sentence and `decoder_inputs` is the target sentence "so far",
  that is to say, the words 0 to N used to predict word N+1 (and beyond) in the target sentence.
- `targets` is the target sentence offset by one step:
  it provides the next words in the target sentence -- what the model will try to predict.
"""


def format_dataset(eng, spa):
    eng = eng_vectorization(eng)
    spa = spa_vectorization(spa)
    return (
        {
            "encoder_inputs": eng,
            "decoder_inputs": spa[:, :-1],
        },
        spa[:, 1:],
    )

def make_dataset(pairs):
    eng_texts, spa_texts = zip(*pairs)
    eng_texts = list(eng_texts)
    spa_texts = list(spa_texts)
    dataset = tf_data.Dataset.from_tensor_slices((eng_texts, spa_texts))
    dataset = dataset.batch(batch_size)
    dataset = dataset.map(format_dataset)
    return dataset.cache().shuffle(2048).prefetch(16)


train_ds = make_dataset(train_pairs)
val_ds = make_dataset(val_pairs)

"""
Let's take a quick look at the sequence shapes
(we have batches of 64 pairs, and all sequences are 20 steps long):
"""

for inputs, targets in train_ds.take(1):
    print(f'inputs["encoder_inputs"].shape: {inputs["encoder_inputs"].shape}')
    print(f'inputs["decoder_inputs"].shape: {inputs["decoder_inputs"].shape}')
    print(f"targets.shape: {targets.shape}")

"""
## Building the model

Our sequence-to-sequence Transformer consists of a `TransformerEncoder`
and a `TransformerDecoder` chained together. To make the model aware of word order,
we also use a `PositionalEmbedding` layer.

The source sequence will be passed to the `TransformerEncoder`,
which will produce a new representation of it.
This new representation will then be passed
to the `TransformerDecoder`, together with the target sequence so far (target words 0 to N).
The `TransformerDecoder` will then seek to predict the next words in the target sequence (N+1 and beyond).

A key detail that makes this possible is causal masking
(see the `get_causal_attention_mask()` method on the `TransformerDecoder`).
The `TransformerDecoder` sees the entire sequence at once, so we must make
sure that it only uses information from target tokens 0 to N when predicting token N+1
(otherwise, it could use information from the future, which would
result in a model that cannot be used at inference time).
"""


class TransformerEncoder(layers.Layer):
    def __init__(self, embed_dim, dense_dim, num_heads, **kwargs):
        super().__init__(**kwargs)
        self.embed_dim = embed_dim
        self.dense_dim = dense_dim
        self.num_heads = num_heads
        self.attention = layers.MultiHeadAttention(
            num_heads=num_heads, key_dim=embed_dim
        )
        self.dense_proj = keras.Sequential(
            [
                layers.Dense(dense_dim, activation="relu"),
                layers.Dense(embed_dim),
            ]
        )
        self.layernorm_1 = layers.LayerNormalization()
        self.layernorm_2 = layers.LayerNormalization()
        self.supports_masking = True

    def call(self, inputs, mask=None):
        if mask is not None:
            padding_mask = ops.cast(mask[:, None, :], dtype="int32")
        else:
            padding_mask = None

        attention_output = self.attention(
            query=inputs, value=inputs, key=inputs, attention_mask=padding_mask
        )
        proj_input = self.layernorm_1(inputs + attention_output)
        proj_output = self.dense_proj(proj_input)
        return self.layernorm_2(proj_input + proj_output)

    def get_config(self):
        config = super().get_config()
        config.update(
            {
                "embed_dim": self.embed_dim,
                "dense_dim": self.dense_dim,
                "num_heads": self.num_heads,
            }
        )
        return config


class PositionalEmbedding(layers.Layer):
    def __init__(self, sequence_length, vocab_size, embed_dim, **kwargs):
        super().__init__(**kwargs)
        self.token_embeddings = layers.Embedding(
            input_dim=vocab_size, output_dim=embed_dim
        )
        self.position_embeddings = layers.Embedding(
            input_dim=sequence_length, output_dim=embed_dim
        )
        self.sequence_length = sequence_length
        self.vocab_size = vocab_size
        self.embed_dim = embed_dim

    def call(self, inputs):
        length = ops.shape(inputs)[-1]
        positions = ops.arange(0, length, 1)
        embedded_tokens = self.token_embeddings(inputs)
        embedded_positions = self.position_embeddings(positions)
        return embedded_tokens + embedded_positions

    def compute_mask(self, inputs, mask=None):
        # Mask out padding tokens (index 0) so downstream attention layers can ignore them.
        return ops.not_equal(inputs, 0)

    def get_config(self):
        config = super().get_config()
        config.update(
            {
                "sequence_length": self.sequence_length,
                "vocab_size": self.vocab_size,
                "embed_dim": self.embed_dim,
            }
        )
        return config


class TransformerDecoder(layers.Layer):
    def __init__(self, embed_dim, latent_dim, num_heads, **kwargs):
        super().__init__(**kwargs)
        self.embed_dim = embed_dim
        self.latent_dim = latent_dim
        self.num_heads = num_heads
        self.attention_1 = layers.MultiHeadAttention(
            num_heads=num_heads, key_dim=embed_dim
        )
        self.attention_2 = layers.MultiHeadAttention(
            num_heads=num_heads, key_dim=embed_dim
        )
        self.dense_proj = keras.Sequential(
            [
                layers.Dense(latent_dim, activation="relu"),
                layers.Dense(embed_dim),
            ]
        )
        self.layernorm_1 = layers.LayerNormalization()
        self.layernorm_2 = layers.LayerNormalization()
        self.layernorm_3 = layers.LayerNormalization()
        self.supports_masking = True

    def call(self, inputs, mask=None):
        inputs, encoder_outputs = inputs
        causal_mask = self.get_causal_attention_mask(inputs)

        if mask is None:
            inputs_padding_mask, encoder_outputs_padding_mask = None, None
        else:
            inputs_padding_mask, encoder_outputs_padding_mask = mask

        # Self-attention over the target sequence, restricted by the causal mask.
        attention_output_1 = self.attention_1(
            query=inputs,
            value=inputs,
            key=inputs,
            attention_mask=causal_mask,
            query_mask=inputs_padding_mask,
        )
        out_1 = self.layernorm_1(inputs + attention_output_1)

        # Cross-attention over the encoder outputs.
        attention_output_2 = self.attention_2(
            query=out_1,
            value=encoder_outputs,
            key=encoder_outputs,
            query_mask=inputs_padding_mask,
            key_mask=encoder_outputs_padding_mask,
        )
        out_2 = self.layernorm_2(out_1 + attention_output_2)

        proj_output = self.dense_proj(out_2)
        return self.layernorm_3(out_2 + proj_output)

    def get_causal_attention_mask(self, inputs):
        input_shape = ops.shape(inputs)
        batch_size, sequence_length = input_shape[0], input_shape[1]
        # Lower-triangular matrix: position i may attend to positions j <= i.
        i = ops.arange(sequence_length)[:, None]
        j = ops.arange(sequence_length)
        mask = ops.cast(i >= j, dtype="int32")
        mask = ops.reshape(mask, (1, input_shape[1], input_shape[1]))
        # Tile the mask across the batch dimension.
        mult = ops.concatenate(
            [ops.expand_dims(batch_size, -1), ops.convert_to_tensor([1, 1])],
            axis=0,
        )
        return ops.tile(mask, mult)

    def get_config(self):
        config = super().get_config()
        config.update(
            {
                "embed_dim": self.embed_dim,
                "latent_dim": self.latent_dim,
                "num_heads": self.num_heads,
            }
        )
        return config
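

"""
A quick shape check (illustrative only; these throwaway layers are not reused below): the encoder
and decoder blocks both preserve the `(batch, sequence, embed_dim)` shape of their inputs.
"""

demo_embeddings = ops.ones((1, 10, 64))
demo_encoder = TransformerEncoder(embed_dim=64, dense_dim=128, num_heads=2)
demo_decoder = TransformerDecoder(embed_dim=64, latent_dim=128, num_heads=2)
demo_encoded = demo_encoder(demo_embeddings)
print(demo_encoded.shape)
print(demo_decoder([demo_embeddings, demo_encoded]).shape)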

"""
Next, we assemble the end-to-end model.
"""

embed_dim = 256
latent_dim = 2048
num_heads = 8

encoder_inputs = keras.Input(shape=(None,), dtype="int64", name="encoder_inputs")
x = PositionalEmbedding(sequence_length, vocab_size, embed_dim)(encoder_inputs)
encoder_outputs = TransformerEncoder(embed_dim, latent_dim, num_heads)(x)
encoder = keras.Model(encoder_inputs, encoder_outputs)

decoder_inputs = keras.Input(shape=(None,), dtype="int64", name="decoder_inputs")
encoded_seq_inputs = keras.Input(shape=(None, embed_dim), name="decoder_state_inputs")
x = PositionalEmbedding(sequence_length, vocab_size, embed_dim)(decoder_inputs)
x = TransformerDecoder(embed_dim, latent_dim, num_heads)([x, encoder_outputs])
x = layers.Dropout(0.5)(x)
decoder_outputs = layers.Dense(vocab_size, activation="softmax")(x)
decoder = keras.Model([decoder_inputs, encoded_seq_inputs], decoder_outputs)

transformer = keras.Model(
    {"encoder_inputs": encoder_inputs, "decoder_inputs": decoder_inputs},
    decoder_outputs,
    name="transformer",
)
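
"""
Before training, a quick end-to-end check (illustrative only): one batch run through the untrained
model should produce an output of shape `(batch_size, sequence_length, vocab_size)`.
"""

for inputs, _ in train_ds.take(1):
    print(transformer(inputs).shape)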

"""
## Training our model

We'll use accuracy as a quick way to monitor training progress on the validation data.
Note that machine translation is typically evaluated with BLEU scores and other metrics,
rather than accuracy (see the optional sketch at the end of this example).

Here we only train for 1 epoch, but to get the model to actually converge
you should train for at least 30 epochs.
"""

epochs = 1  # This should be at least 30 for convergence

transformer.summary()
transformer.compile(
    "rmsprop",
    loss=keras.losses.SparseCategoricalCrossentropy(ignore_class=0),
    metrics=["accuracy"],
)
transformer.fit(train_ds, epochs=epochs, validation_data=val_ds)
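
"""
Optionally (not part of the original example), you can also report token-level loss and accuracy on
the held-out test pairs:
"""

test_ds = make_dataset(test_pairs)
print(transformer.evaluate(test_ds))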

"""
## Decoding test sentences

Finally, let's demonstrate how to translate brand new English sentences.
We simply feed the model the vectorized English sentence
as well as the target token `"[start]"`, then we repeatedly generate the next token, until
we hit the token `"[end]"`.
"""

spa_vocab = spa_vectorization.get_vocabulary()
spa_index_lookup = dict(zip(range(len(spa_vocab)), spa_vocab))
max_decoded_sentence_length = 20


def decode_sequence(input_sentence):
    tokenized_input_sentence = eng_vectorization([input_sentence])
    decoded_sentence = "[start]"
    for i in range(max_decoded_sentence_length):
        tokenized_target_sentence = spa_vectorization([decoded_sentence])[:, :-1]
        predictions = transformer(
            {
                "encoder_inputs": tokenized_input_sentence,
                "decoder_inputs": tokenized_target_sentence,
            }
        )

        # ops.argmax(predictions[0, i, :]) is not a concrete value for jax here
        sampled_token_index = ops.convert_to_numpy(
            ops.argmax(predictions[0, i, :])
        ).item(0)
        sampled_token = spa_index_lookup[sampled_token_index]
        decoded_sentence += " " + sampled_token

        if sampled_token == "[end]":
            break
    return decoded_sentence


test_eng_texts = [pair[0] for pair in test_pairs]
for _ in range(30):
    input_sentence = random.choice(test_eng_texts)
    translated = decode_sequence(input_sentence)
    print(input_sentence)
    print(translated)

"""
After 30 epochs, we get results such as:

> She handed him the money.
> [start] ella le pasó el dinero [end]

> Tom has never heard Mary sing.
> [start] tom nunca ha oído cantar a mary [end]

> Perhaps she will come tomorrow.
> [start] tal vez ella vendrá mañana [end]

> I love to write.
> [start] me encanta escribir [end]

> His French is improving little by little.
> [start] su francés va a [UNK] sólo un poco [end]

> My hotel told me to call you.
> [start] mi hotel me dijo que te [UNK] [end]
"""