"""1Title: English-to-Spanish translation with a sequence-to-sequence Transformer2Author: [fchollet](https://twitter.com/fchollet)3Date created: 2021/05/264Last modified: 2024/11/185Description: Implementing a sequence-to-sequence Transformer and training it on a machine translation task.6Accelerator: GPU7"""89"""10## Introduction1112In this example, we'll build a sequence-to-sequence Transformer model, which13we'll train on an English-to-Spanish machine translation task.1415You'll learn how to:1617- Vectorize text using the Keras `TextVectorization` layer.18- Implement a `TransformerEncoder` layer, a `TransformerDecoder` layer,19and a `PositionalEmbedding` layer.20- Prepare data for training a sequence-to-sequence model.21- Use the trained model to generate translations of never-seen-before22input sentences (sequence-to-sequence inference).2324The code featured here is adapted from the book25[Deep Learning with Python, Second Edition](https://www.manning.com/books/deep-learning-with-python-second-edition)26(chapter 11: Deep learning for text).27The present example is fairly barebones, so for detailed explanations of28how each building block works, as well as the theory behind Transformers,29I recommend reading the book.30"""31"""32## Setup33"""3435# We set the backend to TensorFlow. The code works with36# both `tensorflow` and `torch`. It does not work with JAX37# due to the behavior of `jax.numpy.tile` in a jit scope38# (used in `TransformerDecoder.get_causal_attention_mask()`:39# `tile` in JAX does not support a dynamic `reps` argument.40# You can make the code work in JAX by wrapping the41# inside of the `get_causal_attention_mask` method in42# a decorator to prevent jit compilation:43# `with jax.ensure_compile_time_eval():`.44import os4546os.environ["KERAS_BACKEND"] = "tensorflow"4748import pathlib49import random50import string51import re52import numpy as np5354import tensorflow.data as tf_data55import tensorflow.strings as tf_strings5657import keras58from keras import layers59from keras import ops60from keras.layers import TextVectorization6162"""63## Downloading the data6465We'll be working with an English-to-Spanish translation dataset66provided by [Anki](https://www.manythings.org/anki/). 

"""
## Downloading the data

We'll be working with an English-to-Spanish translation dataset
provided by [Anki](https://www.manythings.org/anki/). Let's download it:
"""

text_file = keras.utils.get_file(
    fname="spa-eng.zip",
    origin="http://storage.googleapis.com/download.tensorflow.org/data/spa-eng.zip",
    extract=True,
)
text_file = pathlib.Path(text_file).parent / "spa-eng" / "spa.txt"

"""
## Parsing the data

Each line contains an English sentence and its corresponding Spanish sentence.
The English sentence is the *source sequence* and the Spanish one is the *target sequence*.
We prepend the token `"[start]"` and append the token `"[end]"` to the Spanish sentence.
"""

with open(text_file) as f:
    lines = f.read().split("\n")[:-1]
text_pairs = []
for line in lines:
    eng, spa = line.split("\t")
    spa = "[start] " + spa + " [end]"
    text_pairs.append((eng, spa))

"""
Here's what our sentence pairs look like:
"""

for _ in range(5):
    print(random.choice(text_pairs))

"""
Now, let's split the sentence pairs into a training set, a validation set,
and a test set.
"""

random.shuffle(text_pairs)
num_val_samples = int(0.15 * len(text_pairs))
num_train_samples = len(text_pairs) - 2 * num_val_samples
train_pairs = text_pairs[:num_train_samples]
val_pairs = text_pairs[num_train_samples : num_train_samples + num_val_samples]
test_pairs = text_pairs[num_train_samples + num_val_samples :]

print(f"{len(text_pairs)} total pairs")
print(f"{len(train_pairs)} training pairs")
print(f"{len(val_pairs)} validation pairs")
print(f"{len(test_pairs)} test pairs")

"""
## Vectorizing the text data

We'll use two instances of the `TextVectorization` layer to vectorize the text
data (one for English and one for Spanish),
that is to say, to turn the original strings into integer sequences
where each integer represents the index of a word in a vocabulary.

The English layer will use the default string standardization (strip punctuation characters)
and splitting scheme (split on whitespace), while
the Spanish layer will use a custom standardization, where we add the character
`"¿"` to the set of punctuation characters to be stripped.

Note: in a production-grade machine translation model, I would not recommend
stripping the punctuation characters in either language. Instead, I would recommend turning
each punctuation character into its own token,
which you could achieve by providing a custom `split` function to the `TextVectorization` layer.
"""

strip_chars = string.punctuation + "¿"
strip_chars = strip_chars.replace("[", "")
strip_chars = strip_chars.replace("]", "")

vocab_size = 15000
sequence_length = 20
batch_size = 64


def custom_standardization(input_string):
    lowercase = tf_strings.lower(input_string)
    return tf_strings.regex_replace(lowercase, "[%s]" % re.escape(strip_chars), "")


eng_vectorization = TextVectorization(
    max_tokens=vocab_size,
    output_mode="int",
    output_sequence_length=sequence_length,
)
spa_vectorization = TextVectorization(
    max_tokens=vocab_size,
    output_mode="int",
    output_sequence_length=sequence_length + 1,
    standardize=custom_standardization,
)
train_eng_texts = [pair[0] for pair in train_pairs]
train_spa_texts = [pair[1] for pair in train_pairs]
eng_vectorization.adapt(train_eng_texts)
spa_vectorization.adapt(train_spa_texts)
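
"""
As a quick, optional sanity check, here is what the adapted English layer produces for a
single sentence: a length-20 sequence of token indices, zero-padded at the end (index 0 is
reserved for padding and index 1 for out-of-vocabulary words).
"""

# Optional sanity check: inspect the vectorized form of one English sentence.
print(eng_vectorization(["I love to write."]))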

"""
Next, we'll format our datasets.

At each training step, the model will seek to predict target words N+1 (and beyond)
using the source sentence and the target words 0 to N.

As such, the training dataset will yield a tuple `(inputs, targets)`, where:

- `inputs` is a dictionary with the keys `encoder_inputs` and `decoder_inputs`.
`encoder_inputs` is the vectorized source sentence and `decoder_inputs` is the target sentence "so far",
that is to say, the words 0 to N used to predict word N+1 (and beyond) in the target sentence.
- `targets` is the target sentence offset by one step:
it provides the next words in the target sentence -- what the model will try to predict.
"""


def format_dataset(eng, spa):
    eng = eng_vectorization(eng)
    spa = spa_vectorization(spa)
    return (
        {
            "encoder_inputs": eng,
            "decoder_inputs": spa[:, :-1],
        },
        spa[:, 1:],
    )


def make_dataset(pairs):
    eng_texts, spa_texts = zip(*pairs)
    eng_texts = list(eng_texts)
    spa_texts = list(spa_texts)
    dataset = tf_data.Dataset.from_tensor_slices((eng_texts, spa_texts))
    dataset = dataset.batch(batch_size)
    dataset = dataset.map(format_dataset)
    return dataset.cache().shuffle(2048).prefetch(16)


train_ds = make_dataset(train_pairs)
val_ds = make_dataset(val_pairs)

"""
Let's take a quick look at the sequence shapes
(we have batches of 64 pairs, and all sequences are 20 steps long):
"""

for inputs, targets in train_ds.take(1):
    print(f'inputs["encoder_inputs"].shape: {inputs["encoder_inputs"].shape}')
    print(f'inputs["decoder_inputs"].shape: {inputs["decoder_inputs"].shape}')
    print(f"targets.shape: {targets.shape}")

"""
## Building the model

Our sequence-to-sequence Transformer consists of a `TransformerEncoder`
and a `TransformerDecoder` chained together. To make the model aware of word order,
we also use a `PositionalEmbedding` layer.

The source sequence will be passed to the `TransformerEncoder`,
which will produce a new representation of it.
This new representation will then be passed
to the `TransformerDecoder`, together with the target sequence so far (target words 0 to N).
The `TransformerDecoder` will then seek to predict the next words in the target sequence (N+1 and beyond).

A key detail that makes this possible is causal masking
(see method `get_causal_attention_mask()` on the `TransformerDecoder`).
The `TransformerDecoder` sees the entire target sequence at once, and thus we must make
sure that it only uses information from target tokens 0 to N when predicting token N+1
(otherwise, it could use information from the future, which would
result in a model that cannot be used at inference time).
"""
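
"""
To make the causal mask concrete, here is a small, purely illustrative example for a sequence
of length 4, built the same way as in `get_causal_attention_mask()` below (before tiling across
the batch dimension): entry `(i, j)` is 1 if token `i` is allowed to attend to token `j`,
i.e. if `j <= i`.
"""

# Illustrative only: a 4x4 causal attention mask (lower-triangular matrix of ones).
demo_i = ops.arange(4)[:, None]
demo_j = ops.arange(4)
print(ops.cast(demo_i >= demo_j, dtype="int32"))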


class TransformerEncoder(layers.Layer):
    def __init__(self, embed_dim, dense_dim, num_heads, **kwargs):
        super().__init__(**kwargs)
        self.embed_dim = embed_dim
        self.dense_dim = dense_dim
        self.num_heads = num_heads
        self.attention = layers.MultiHeadAttention(
            num_heads=num_heads, key_dim=embed_dim
        )
        self.dense_proj = keras.Sequential(
            [
                layers.Dense(dense_dim, activation="relu"),
                layers.Dense(embed_dim),
            ]
        )
        self.layernorm_1 = layers.LayerNormalization()
        self.layernorm_2 = layers.LayerNormalization()
        self.supports_masking = True

    def call(self, inputs, mask=None):
        if mask is not None:
            padding_mask = ops.cast(mask[:, None, :], dtype="int32")
        else:
            padding_mask = None

        attention_output = self.attention(
            query=inputs, value=inputs, key=inputs, attention_mask=padding_mask
        )
        proj_input = self.layernorm_1(inputs + attention_output)
        proj_output = self.dense_proj(proj_input)
        return self.layernorm_2(proj_input + proj_output)

    def get_config(self):
        config = super().get_config()
        config.update(
            {
                "embed_dim": self.embed_dim,
                "dense_dim": self.dense_dim,
                "num_heads": self.num_heads,
            }
        )
        return config


class PositionalEmbedding(layers.Layer):
    def __init__(self, sequence_length, vocab_size, embed_dim, **kwargs):
        super().__init__(**kwargs)
        self.token_embeddings = layers.Embedding(
            input_dim=vocab_size, output_dim=embed_dim
        )
        self.position_embeddings = layers.Embedding(
            input_dim=sequence_length, output_dim=embed_dim
        )
        self.sequence_length = sequence_length
        self.vocab_size = vocab_size
        self.embed_dim = embed_dim

    def call(self, inputs):
        length = ops.shape(inputs)[-1]
        positions = ops.arange(0, length, 1)
        embedded_tokens = self.token_embeddings(inputs)
        embedded_positions = self.position_embeddings(positions)
        return embedded_tokens + embedded_positions

    def compute_mask(self, inputs, mask=None):
        return ops.not_equal(inputs, 0)

    def get_config(self):
        config = super().get_config()
        config.update(
            {
                "sequence_length": self.sequence_length,
                "vocab_size": self.vocab_size,
                "embed_dim": self.embed_dim,
            }
        )
        return config


class TransformerDecoder(layers.Layer):
    def __init__(self, embed_dim, latent_dim, num_heads, **kwargs):
        super().__init__(**kwargs)
        self.embed_dim = embed_dim
        self.latent_dim = latent_dim
        self.num_heads = num_heads
        self.attention_1 = layers.MultiHeadAttention(
            num_heads=num_heads, key_dim=embed_dim
        )
        self.attention_2 = layers.MultiHeadAttention(
            num_heads=num_heads, key_dim=embed_dim
        )
        self.dense_proj = keras.Sequential(
            [
                layers.Dense(latent_dim, activation="relu"),
                layers.Dense(embed_dim),
            ]
        )
        self.layernorm_1 = layers.LayerNormalization()
        self.layernorm_2 = layers.LayerNormalization()
        self.layernorm_3 = layers.LayerNormalization()
        self.supports_masking = True

    def call(self, inputs, mask=None):
        inputs, encoder_outputs = inputs
        causal_mask = self.get_causal_attention_mask(inputs)

        if mask is None:
            inputs_padding_mask, encoder_outputs_padding_mask = None, None
        else:
            inputs_padding_mask, encoder_outputs_padding_mask = mask

        attention_output_1 = self.attention_1(
            query=inputs,
            value=inputs,
            key=inputs,
            attention_mask=causal_mask,
            query_mask=inputs_padding_mask,
        )
        out_1 = self.layernorm_1(inputs + attention_output_1)

        attention_output_2 = self.attention_2(
            query=out_1,
            value=encoder_outputs,
            key=encoder_outputs,
            query_mask=inputs_padding_mask,
            key_mask=encoder_outputs_padding_mask,
        )
        out_2 = self.layernorm_2(out_1 + attention_output_2)

        proj_output = self.dense_proj(out_2)
        return self.layernorm_3(out_2 + proj_output)

    def get_causal_attention_mask(self, inputs):
        input_shape = ops.shape(inputs)
        batch_size, sequence_length = input_shape[0], input_shape[1]
        i = ops.arange(sequence_length)[:, None]
        j = ops.arange(sequence_length)
        mask = ops.cast(i >= j, dtype="int32")
        mask = ops.reshape(mask, (1, input_shape[1], input_shape[1]))
        mult = ops.concatenate(
            [ops.expand_dims(batch_size, -1), ops.convert_to_tensor([1, 1])],
            axis=0,
        )
        return ops.tile(mask, mult)

    def get_config(self):
        config = super().get_config()
        config.update(
            {
                "embed_dim": self.embed_dim,
                "latent_dim": self.latent_dim,
                "num_heads": self.num_heads,
            }
        )
        return config


"""
Next, we assemble the end-to-end model.
"""

embed_dim = 256
latent_dim = 2048
num_heads = 8

encoder_inputs = keras.Input(shape=(None,), dtype="int64", name="encoder_inputs")
x = PositionalEmbedding(sequence_length, vocab_size, embed_dim)(encoder_inputs)
encoder_outputs = TransformerEncoder(embed_dim, latent_dim, num_heads)(x)
encoder = keras.Model(encoder_inputs, encoder_outputs)

decoder_inputs = keras.Input(shape=(None,), dtype="int64", name="decoder_inputs")
encoded_seq_inputs = keras.Input(shape=(None, embed_dim), name="decoder_state_inputs")
x = PositionalEmbedding(sequence_length, vocab_size, embed_dim)(decoder_inputs)
x = TransformerDecoder(embed_dim, latent_dim, num_heads)([x, encoder_outputs])
x = layers.Dropout(0.5)(x)
decoder_outputs = layers.Dense(vocab_size, activation="softmax")(x)
decoder = keras.Model([decoder_inputs, encoded_seq_inputs], decoder_outputs)

transformer = keras.Model(
    {"encoder_inputs": encoder_inputs, "decoder_inputs": decoder_inputs},
    decoder_outputs,
    name="transformer",
)
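
"""
As an optional check before training, let's run a single batch through the model: for each of
the 20 target positions we get a probability distribution over the 15,000-word Spanish vocabulary.
"""

# Optional sanity check: one forward pass on a single batch of training data.
for inputs, _ in train_ds.take(1):
    preds = transformer(inputs)
    print(f"predictions.shape: {preds.shape}")  # (batch_size, sequence_length, vocab_size)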

"""
## Training our model

We'll use accuracy as a quick way to monitor training progress on the validation data.
Note that machine translation typically uses BLEU scores as well as other metrics, rather than accuracy.
We also pass `ignore_class=0` to the loss so that padded positions (token index 0) don't
contribute to it.

Here we only train for 1 epoch, but to get the model to actually converge
you should train for at least 30 epochs.
"""

epochs = 1  # This should be at least 30 for convergence

transformer.summary()
transformer.compile(
    "rmsprop",
    loss=keras.losses.SparseCategoricalCrossentropy(ignore_class=0),
    metrics=["accuracy"],
)
transformer.fit(train_ds, epochs=epochs, validation_data=val_ds)
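
"""
Because the custom layers all implement `get_config()`, the trained model can be saved in the
native Keras format and reloaded later, as sketched below; the custom classes just need to be
passed via `custom_objects` at loading time (the file name here is only an example).
"""

# Optional: serialize the trained model and restore it.
transformer.save("translation_transformer.keras")
restored_transformer = keras.saving.load_model(
    "translation_transformer.keras",
    custom_objects={
        "TransformerEncoder": TransformerEncoder,
        "PositionalEmbedding": PositionalEmbedding,
        "TransformerDecoder": TransformerDecoder,
    },
)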

"""
## Decoding test sentences

Finally, let's demonstrate how to translate brand new English sentences.
We simply feed into the model the vectorized English sentence
as well as the target token `"[start]"`, then we repeatedly generate the next token, until
we hit the token `"[end]"`.
"""

spa_vocab = spa_vectorization.get_vocabulary()
spa_index_lookup = dict(zip(range(len(spa_vocab)), spa_vocab))
max_decoded_sentence_length = 20


def decode_sequence(input_sentence):
    tokenized_input_sentence = eng_vectorization([input_sentence])
    decoded_sentence = "[start]"
    for i in range(max_decoded_sentence_length):
        tokenized_target_sentence = spa_vectorization([decoded_sentence])[:, :-1]
        predictions = transformer(
            {
                "encoder_inputs": tokenized_input_sentence,
                "decoder_inputs": tokenized_target_sentence,
            }
        )

        # ops.argmax(predictions[0, i, :]) is not a concrete value for jax here
        sampled_token_index = ops.convert_to_numpy(
            ops.argmax(predictions[0, i, :])
        ).item(0)
        sampled_token = spa_index_lookup[sampled_token_index]
        decoded_sentence += " " + sampled_token

        if sampled_token == "[end]":
            break
    return decoded_sentence


test_eng_texts = [pair[0] for pair in test_pairs]
for _ in range(30):
    input_sentence = random.choice(test_eng_texts)
    translated = decode_sequence(input_sentence)

"""
After 30 epochs, we get results such as:

> She handed him the money.
> [start] ella le pasó el dinero [end]

> Tom has never heard Mary sing.
> [start] tom nunca ha oído cantar a mary [end]

> Perhaps she will come tomorrow.
> [start] tal vez ella vendrá mañana [end]

> I love to write.
> [start] me encanta escribir [end]

> His French is improving little by little.
> [start] su francés va a [UNK] sólo un poco [end]

> My hotel told me to call you.
> [start] mi hotel me dijo que te [UNK] [end]
"""