"""
Title: English-to-Spanish translation with KerasHub
Author: [Abheesht Sharma](https://github.com/abheesht17/)
Date created: 2022/05/26
Last modified: 2024/04/30
Description: Use KerasHub to train a sequence-to-sequence Transformer model on the machine translation task.
Accelerator: GPU
"""

"""
## Introduction

KerasHub provides building blocks for NLP (model layers, tokenizers, metrics, etc.) and
makes it convenient to construct NLP pipelines.

In this example, we'll use KerasHub layers to build an encoder-decoder Transformer
model, and train it on the English-to-Spanish machine translation task.

This example is based on the
[English-to-Spanish NMT
example](https://keras.io/examples/nlp/neural_machine_translation_with_transformer/)
by [fchollet](https://twitter.com/fchollet). The original example is more low-level
and implements layers from scratch, whereas this example uses KerasHub to show
some more advanced approaches, such as subword tokenization and using metrics
to compute the quality of generated translations.

You'll learn how to:

- Tokenize text using `keras_hub.tokenizers.WordPieceTokenizer`.
- Implement a sequence-to-sequence Transformer model using KerasHub's
`keras_hub.layers.TransformerEncoder`, `keras_hub.layers.TransformerDecoder` and
`keras_hub.layers.TokenAndPositionEmbedding` layers, and train it.
- Use `keras_hub.samplers` to generate translations of unseen input sentences
using the greedy decoding strategy!

Don't worry if you aren't familiar with KerasHub. This tutorial will start with
the basics.
Let's dive right in!
"""

"""
## Setup

Before we start implementing the pipeline, let's import all the libraries we need.
"""

"""shell
pip install -q --upgrade rouge-score
pip install -q --upgrade keras-hub
pip install -q --upgrade keras  # Upgrade to Keras 3.
"""

import keras_hub
import pathlib
import random

import keras
from keras import ops

import tensorflow.data as tf_data
from tensorflow_text.tools.wordpiece_vocab import (
    bert_vocab_from_dataset as bert_vocab,
)

"""
Let's also define our parameters/hyperparameters.
"""

BATCH_SIZE = 64
EPOCHS = 1  # This should be at least 10 for convergence
MAX_SEQUENCE_LENGTH = 40
ENG_VOCAB_SIZE = 15000
SPA_VOCAB_SIZE = 15000

EMBED_DIM = 256
INTERMEDIATE_DIM = 2048
NUM_HEADS = 8

"""
## Downloading the data

We'll be working with an English-to-Spanish translation dataset
provided by [Anki](https://www.manythings.org/anki/). Let's download it:
"""

text_file = keras.utils.get_file(
    fname="spa-eng.zip",
    origin="http://storage.googleapis.com/download.tensorflow.org/data/spa-eng.zip",
    extract=True,
)
text_file = pathlib.Path(text_file).parent / "spa-eng" / "spa.txt"

"""
## Parsing the data

Each line contains an English sentence and its corresponding Spanish sentence,
separated by a tab. The English sentence is the *source sequence* and the Spanish
one is the *target sequence*. Before adding the text to a list, we convert it to
lowercase.
"""

with open(text_file, encoding="utf-8") as f:
    lines = f.read().split("\n")[:-1]
text_pairs = []
for line in lines:
    eng, spa = line.split("\t")
    eng = eng.lower()
    spa = spa.lower()
    text_pairs.append((eng, spa))

"""
Here's what our sentence pairs look like:
"""

for _ in range(5):
    print(random.choice(text_pairs))

"""
Now, let's split the sentence pairs into a training set, a validation set,
and a test set.
"""

random.shuffle(text_pairs)
num_val_samples = int(0.15 * len(text_pairs))
num_train_samples = len(text_pairs) - 2 * num_val_samples
train_pairs = text_pairs[:num_train_samples]
val_pairs = text_pairs[num_train_samples : num_train_samples + num_val_samples]
test_pairs = text_pairs[num_train_samples + num_val_samples :]

print(f"{len(text_pairs)} total pairs")
print(f"{len(train_pairs)} training pairs")
print(f"{len(val_pairs)} validation pairs")
print(f"{len(test_pairs)} test pairs")


"""
## Tokenizing the data

We'll define two tokenizers: one for the source language (English), and the other
for the target language (Spanish). We'll be using
`keras_hub.tokenizers.WordPieceTokenizer` to tokenize the text.
`keras_hub.tokenizers.WordPieceTokenizer` takes a WordPiece vocabulary
and has methods for tokenizing text and detokenizing sequences of tokens.

Before we define the two tokenizers, we first need to train them on the dataset
we have. The WordPiece tokenization algorithm is a subword tokenization algorithm;
training it on a corpus gives us a vocabulary of subwords. A subword tokenizer
is a compromise between word tokenizers (word tokenizers need very large
vocabularies for good coverage of input words) and character tokenizers
(characters don't really encode meaning like words do). Luckily, KerasHub
makes it very simple to train WordPiece on a corpus with the
`keras_hub.tokenizers.compute_word_piece_vocabulary` utility.
"""


def train_word_piece(text_samples, vocab_size, reserved_tokens):
    word_piece_ds = tf_data.Dataset.from_tensor_slices(text_samples)
    vocab = keras_hub.tokenizers.compute_word_piece_vocabulary(
        word_piece_ds.batch(1000).prefetch(2),
        vocabulary_size=vocab_size,
        reserved_tokens=reserved_tokens,
    )
    return vocab


"""
Every vocabulary has a few special, reserved tokens. We have four such tokens:

- `"[PAD]"` - Padding token. Padding tokens are appended to the input sequence
  when its length is shorter than the maximum sequence length.
- `"[UNK]"` - Unknown token.
- `"[START]"` - Token that marks the start of the input sequence.
- `"[END]"` - Token that marks the end of the input sequence.
"""

reserved_tokens = ["[PAD]", "[UNK]", "[START]", "[END]"]

eng_samples = [text_pair[0] for text_pair in train_pairs]
eng_vocab = train_word_piece(eng_samples, ENG_VOCAB_SIZE, reserved_tokens)

spa_samples = [text_pair[1] for text_pair in train_pairs]
spa_vocab = train_word_piece(spa_samples, SPA_VOCAB_SIZE, reserved_tokens)

"""
Let's see some tokens!
"""

print("English Tokens: ", eng_vocab[100:110])
print("Spanish Tokens: ", spa_vocab[100:110])

"""
Now, let's define the tokenizers. We will configure the tokenizers with the
vocabularies trained above.
"""

eng_tokenizer = keras_hub.tokenizers.WordPieceTokenizer(
    vocabulary=eng_vocab, lowercase=False
)
spa_tokenizer = keras_hub.tokenizers.WordPieceTokenizer(
    vocabulary=spa_vocab, lowercase=False
)

"""
Let's try to tokenize a sample from our dataset!
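
Under the hood, WordPiece splits each word greedily: it takes the longest prefix
that exists in the vocabulary, then repeats on the remainder, marking word-internal
pieces with a `##` prefix. The toy sketch below illustrates that lookup on single
words. The tiny `TOY_VOCAB` and the `wordpiece_tokenize` helper are hypothetical
and not part of KerasHub; the real tokenizer also splits on whitespace and
punctuation first and uses the vocabulary we just trained.

```python
# Hypothetical toy vocabulary, for illustration only.
# Pieces that continue a word carry the "##" prefix.
TOY_VOCAB = {"[UNK]", "un", "believ", "able", "##believ", "##able", "##s"}


def wordpiece_tokenize(word, vocab, unk_token="[UNK]"):
    # Greedy longest-match-first lookup over a single word (toy sketch).
    pieces = []
    start = 0
    while start < len(word):
        end = len(word)
        match = None
        # Shrink the candidate substring until it is found in the vocabulary.
        while start < end:
            candidate = word[start:end]
            if start > 0:
                candidate = "##" + candidate
            if candidate in vocab:
                match = candidate
                break
            end -= 1
        if match is None:
            # No prefix of the remaining characters matches:
            # the whole word maps to the unknown token.
            return [unk_token]
        pieces.append(match)
        start = end
    return pieces


print(wordpiece_tokenize("unbelievable", TOY_VOCAB))  # ['un', '##believ', '##able']
print(wordpiece_tokenize("xyz", TOY_VOCAB))  # ['[UNK]']
```
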
To verify whether the text has been tokenized correctly, we can also detokenize
the list of tokens back to the original text.
"""

eng_input_ex = text_pairs[0][0]
eng_tokens_ex = eng_tokenizer.tokenize(eng_input_ex)
print("English sentence: ", eng_input_ex)
print("Tokens: ", eng_tokens_ex)
print(
    "Recovered text after detokenizing: ",
    eng_tokenizer.detokenize(eng_tokens_ex),
)

print()

spa_input_ex = text_pairs[0][1]
spa_tokens_ex = spa_tokenizer.tokenize(spa_input_ex)
print("Spanish sentence: ", spa_input_ex)
print("Tokens: ", spa_tokens_ex)
print(
    "Recovered text after detokenizing: ",
    spa_tokenizer.detokenize(spa_tokens_ex),
)

"""
## Format datasets

Next, we'll format our datasets.

At each training step, the model will seek to predict target words N+1 (and beyond)
using the source sentence and the target words 0 to N.

As such, the training dataset will yield a tuple `(inputs, targets)`, where:

- `inputs` is a dictionary with the keys `encoder_inputs` and `decoder_inputs`.
  `encoder_inputs` is the tokenized source sentence and `decoder_inputs` is the target
  sentence "so far", that is to say, the words 0 to N used to predict word N+1
  (and beyond) in the target sentence.
- `targets` is the target sentence offset by one step:
  it provides the next words in the target sentence, i.e. what the model will try
  to predict.

We will add the special tokens `"[START]"` and `"[END]"` to the Spanish (target)
sentence after tokenizing the text.
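
Ignoring padding for a moment, the one-step offset between decoder inputs and
targets can be sketched in plain Python. The ids below are made up: `2` and `3`
are only assumed to be the ids of `"[START]"` and `"[END]"`, and the three middle
ids stand in for the WordPiece tokens of a short Spanish sentence. Dropping the
last element and dropping the first element mirror the `spa[:, :-1]` and
`spa[:, 1:]` slices in `preprocess_batch`.

```python
# Hypothetical ids: 2 and 3 stand in for "[START]" and "[END]";
# 57, 91 and 12 stand in for the tokens of a short Spanish sentence.
START_ID, END_ID = 2, 3
tokens = [57, 91, 12]

sequence = [START_ID] + tokens + [END_ID]
decoder_inputs = sequence[:-1]  # [2, 57, 91, 12]
targets = sequence[1:]  # [57, 91, 12, 3]

# At step i, the decoder may look at decoder_inputs[: i + 1]
# and is trained to predict targets[i].
for i, target in enumerate(targets):
    print(decoder_inputs[: i + 1], "->", target)
```

The decoder's causal masking, described in the model section, is what restricts
each step to the tokens on the left of the arrow.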
We will also pad the input to a fixed length.
This can easily be done using `keras_hub.layers.StartEndPacker`.
"""


def preprocess_batch(eng, spa):
    eng = eng_tokenizer(eng)
    spa = spa_tokenizer(spa)

    # Pad `eng` to `MAX_SEQUENCE_LENGTH`.
    eng_start_end_packer = keras_hub.layers.StartEndPacker(
        sequence_length=MAX_SEQUENCE_LENGTH,
        pad_value=eng_tokenizer.token_to_id("[PAD]"),
    )
    eng = eng_start_end_packer(eng)

    # Add special tokens (`"[START]"` and `"[END]"`) to `spa` and pad it as well.
    spa_start_end_packer = keras_hub.layers.StartEndPacker(
        sequence_length=MAX_SEQUENCE_LENGTH + 1,
        start_value=spa_tokenizer.token_to_id("[START]"),
        end_value=spa_tokenizer.token_to_id("[END]"),
        pad_value=spa_tokenizer.token_to_id("[PAD]"),
    )
    spa = spa_start_end_packer(spa)

    return (
        {
            "encoder_inputs": eng,
            "decoder_inputs": spa[:, :-1],
        },
        spa[:, 1:],
    )


def make_dataset(pairs):
    eng_texts, spa_texts = zip(*pairs)
    eng_texts = list(eng_texts)
    spa_texts = list(spa_texts)
    dataset = tf_data.Dataset.from_tensor_slices((eng_texts, spa_texts))
    dataset = dataset.batch(BATCH_SIZE)
    dataset = dataset.map(preprocess_batch, num_parallel_calls=tf_data.AUTOTUNE)
    # Cache before shuffling, so that every epoch sees a freshly shuffled order.
    return dataset.cache().shuffle(2048).prefetch(16)


train_ds = make_dataset(train_pairs)
val_ds = make_dataset(val_pairs)

"""
Let's take a quick look at the sequence shapes
(we have batches of 64 pairs, and all sequences are 40 steps long):
"""

for inputs, targets in train_ds.take(1):
    print(f'inputs["encoder_inputs"].shape: {inputs["encoder_inputs"].shape}')
    print(f'inputs["decoder_inputs"].shape: {inputs["decoder_inputs"].shape}')
    print(f"targets.shape: {targets.shape}")


"""
## Building the model

Now, let's move on to the exciting part: defining our model!
We first need an embedding layer, i.e., a vector for every token in our input sequence.
This embedding layer can be initialized randomly. We also need a positional
embedding layer which encodes the word order in the sequence. The convention is
to add these two embeddings. KerasHub has a `keras_hub.layers.TokenAndPositionEmbedding`
layer which does all of the above steps for us.

Our sequence-to-sequence Transformer consists of a `keras_hub.layers.TransformerEncoder`
layer and a `keras_hub.layers.TransformerDecoder` layer chained together.

The source sequence will be passed to `keras_hub.layers.TransformerEncoder`, which
will produce a new representation of it. This new representation will then be passed
to the `keras_hub.layers.TransformerDecoder`, together with the target sequence
so far (target words 0 to N). The `keras_hub.layers.TransformerDecoder` will
then seek to predict the next words in the target sequence (N+1 and beyond).

A key detail that makes this possible is causal masking.
The `keras_hub.layers.TransformerDecoder` sees the entire sequence at once, and
thus we must make sure that it only uses information from target tokens 0 to N
when predicting token N+1 (otherwise, it could use information from the future,
which would result in a model that cannot be used at inference time). Causal masking
is enabled by default in `keras_hub.layers.TransformerDecoder`.

We also need to mask the padding tokens (`"[PAD]"`). For this, we can set the
`mask_zero` argument of the `keras_hub.layers.TokenAndPositionEmbedding` layer
to `True`, which masks every position holding token id 0; `"[PAD]"` has id 0
because it is the first of our reserved tokens.
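
To make the two kinds of masking concrete, here is a small NumPy sketch of how a
padding mask (what `mask_zero=True` derives from token id 0) and a causal mask can
be combined for the decoder's self-attention. It is a conceptual illustration with
made-up token ids, not KerasHub's internal implementation: a query at position `i`
may attend to key `j` only if `j <= i` and token `j` is not padding.

```python
import numpy as np

# Made-up decoder input: three real tokens followed by two "[PAD]" tokens (id 0).
token_ids = np.array([2, 57, 91, 0, 0])

# Padding mask, as mask_zero=True would derive it: True for real tokens.
padding_mask = token_ids != 0

# Causal mask: position i may attend to positions j <= i only.
seq_len = len(token_ids)
causal_mask = np.tril(np.ones((seq_len, seq_len), dtype=bool))

# Combined self-attention mask: rows are queries, columns are keys.
# A key is visible only if it is not in the future and not padding.
attention_mask = causal_mask & padding_mask[np.newaxis, :]
print(attention_mask.astype(int))
```

The printed matrix is lower-triangular, and its last two columns are all zeros,
since those keys are `"[PAD]"` tokens.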
The resulting padding mask is then propagated to all subsequent layers.
"""

# Encoder
encoder_inputs = keras.Input(shape=(None,), name="encoder_inputs")

x = keras_hub.layers.TokenAndPositionEmbedding(
    vocabulary_size=ENG_VOCAB_SIZE,
    sequence_length=MAX_SEQUENCE_LENGTH,
    embedding_dim=EMBED_DIM,
    mask_zero=True,  # Mask "[PAD]" tokens (id 0).
)(encoder_inputs)

encoder_outputs = keras_hub.layers.TransformerEncoder(
    intermediate_dim=INTERMEDIATE_DIM, num_heads=NUM_HEADS
)(inputs=x)
encoder = keras.Model(encoder_inputs, encoder_outputs)


# Decoder
decoder_inputs = keras.Input(shape=(None,), name="decoder_inputs")
encoded_seq_inputs = keras.Input(shape=(None, EMBED_DIM), name="decoder_state_inputs")

x = keras_hub.layers.TokenAndPositionEmbedding(
    vocabulary_size=SPA_VOCAB_SIZE,
    sequence_length=MAX_SEQUENCE_LENGTH,
    embedding_dim=EMBED_DIM,
    mask_zero=True,  # Mask "[PAD]" tokens (id 0).
)(decoder_inputs)

x = keras_hub.layers.TransformerDecoder(
    intermediate_dim=INTERMEDIATE_DIM, num_heads=NUM_HEADS
)(decoder_sequence=x, encoder_sequence=encoded_seq_inputs)
x = keras.layers.Dropout(0.5)(x)
decoder_outputs = keras.layers.Dense(SPA_VOCAB_SIZE, activation="softmax")(x)
decoder = keras.Model(
    [
        decoder_inputs,
        encoded_seq_inputs,
    ],
    decoder_outputs,
)
decoder_outputs = decoder([decoder_inputs, encoder_outputs])

transformer = keras.Model(
    [encoder_inputs, decoder_inputs],
    decoder_outputs,
    name="transformer",
)

"""
## Training our model

We'll use accuracy as a quick way to monitor training progress on the validation data.
Note that machine translation typically uses BLEU scores as well as other metrics,
rather than accuracy. However, to use metrics like ROUGE or BLEU, we have to decode
the probabilities and generate the text.
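
Accuracy, by contrast, needs no decoding. Assuming Keras resolves `"accuracy"` to
sparse categorical accuracy for these integer targets, it is computed token by
token: at every position, the id with the highest predicted probability is
compared with the target id. The NumPy sketch below uses made-up predictions over
a toy vocabulary of 5 ids, with id 0 playing the role of `"[PAD]"`, and shows how
counting padding positions can make the score look better than the real tokens
warrant.

```python
import numpy as np

# Made-up predicted distributions for 4 target positions over 5 token ids.
y_pred = np.array(
    [
        [0.1, 0.1, 0.1, 0.6, 0.1],
        [0.2, 0.5, 0.1, 0.1, 0.1],
        [0.7, 0.1, 0.1, 0.05, 0.05],
        [0.9, 0.025, 0.025, 0.025, 0.025],
    ]
)
# Sparse integer targets; in this toy setup, id 0 plays the role of "[PAD]".
y_true = np.array([3, 2, 0, 0])

# Per-token accuracy: does the most likely id match the target id?
matches = np.argmax(y_pred, axis=-1) == y_true
print(matches.mean())  # 0.75

# Scoring only the non-padding positions of the same predictions.
not_pad = y_true != 0
print(matches[not_pad].mean())  # 0.5
```
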
Text generation is computationally expensive, and performing it during training
is not recommended.

Here we only train for 1 epoch, but to get the model to actually converge
you should train for at least 10 epochs.
"""

transformer.summary()
transformer.compile(
    "rmsprop", loss="sparse_categorical_crossentropy", metrics=["accuracy"]
)
transformer.fit(train_ds, epochs=EPOCHS, validation_data=val_ds)

"""
## Decoding test sentences (qualitative analysis)

Finally, let's demonstrate how to translate brand new English sentences.
We simply feed into the model the tokenized English sentence
as well as the target token `"[START]"`. The model outputs probabilities of the
next token. We then repeatedly generate the next token conditioned on the
tokens generated so far, until we hit the token `"[END]"`.

For decoding, we will use the `keras_hub.samplers` module from
KerasHub. Greedy decoding is a text decoding method which outputs the most
likely next token at each time step, i.e., the token with the highest probability.
"""


def decode_sequences(input_sentences):
    batch_size = len(input_sentences)

    # Tokenize the encoder input and pad (or truncate) it to
    # `MAX_SEQUENCE_LENGTH`, exactly as during training.
    eng_start_end_packer = keras_hub.layers.StartEndPacker(
        sequence_length=MAX_SEQUENCE_LENGTH,
        pad_value=eng_tokenizer.token_to_id("[PAD]"),
    )
    encoder_input_tokens = eng_start_end_packer(eng_tokenizer(input_sentences))

    # Define a function that returns the next token's probabilities given the
    # tokens generated so far.
    def next_token_probs(prompt, cache, index):
        # The decoder ends in a softmax, so these are probabilities rather than
        # logits; greedy decoding only takes the argmax, which is unaffected.
        logits = transformer([encoder_input_tokens, prompt])[:, index - 1, :]
        # Ignore hidden states for now; only needed for contrastive search.
        hidden_states = None
        return logits, hidden_states, cache

    # Build a prompt of length `MAX_SEQUENCE_LENGTH` with a start token
    # followed by padding tokens.
    length = MAX_SEQUENCE_LENGTH
    start = ops.full(
        (batch_size, 1), spa_tokenizer.token_to_id("[START]"), dtype="int32"
    )
    pad = ops.full(
        (batch_size, length - 1), spa_tokenizer.token_to_id("[PAD]"), dtype="int32"
    )
    prompt = ops.concatenate((start, pad), axis=-1)

    generated_tokens = keras_hub.samplers.GreedySampler()(
        next_token_probs,
        prompt,
        stop_token_ids=[spa_tokenizer.token_to_id("[END]")],
        index=1,  # Start sampling after start token.
    )
    generated_sentences = spa_tokenizer.detokenize(generated_tokens)
    return generated_sentences


test_eng_texts = [pair[0] for pair in test_pairs]
for i in range(2):
    input_sentence = random.choice(test_eng_texts)
    translated = decode_sequences([input_sentence])
    translated = translated.numpy()[0].decode("utf-8")
    translated = (
        translated.replace("[PAD]", "")
        .replace("[START]", "")
        .replace("[END]", "")
        .strip()
    )
    print(f"** Example {i} **")
    print(input_sentence)
    print(translated)
    print()

"""
## Evaluating our model (quantitative analysis)

There are many metrics which are used for text generation tasks. Here, to
evaluate translations generated by our model, let's compute the ROUGE-1 and
ROUGE-2 scores. Essentially, ROUGE-N is a score based on the number of common
n-grams between the reference text and the generated text.
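
As a rough illustration of that definition, here is a toy ROUGE-N in plain Python.
It counts clipped n-gram overlaps and reports precision (overlap divided by the
candidate's n-gram count), recall (overlap divided by the reference's n-gram count)
and their harmonic mean, F1. The `ngrams` and `rouge_n` helpers are hypothetical
and split on whitespace only; `keras_hub.metrics.RougeN` is built on the
`rouge-score` package installed in the setup step, which applies its own
tokenization, so its numbers can differ.

```python
from collections import Counter


def ngrams(tokens, n):
    # All contiguous n-grams of a token list, as tuples.
    return [tuple(tokens[i : i + n]) for i in range(len(tokens) - n + 1)]


def rouge_n(reference, candidate, n):
    # Toy sketch: whitespace tokenization is a simplifying assumption.
    ref_counts = Counter(ngrams(reference.split(), n))
    cand_counts = Counter(ngrams(candidate.split(), n))
    # Clipped overlap: each n-gram counts at most min(count in ref, count in cand).
    overlap = sum((ref_counts & cand_counts).values())
    precision = overlap / max(sum(cand_counts.values()), 1)
    recall = overlap / max(sum(ref_counts.values()), 1)
    f1 = 0.0 if overlap == 0 else 2 * precision * recall / (precision + recall)
    return precision, recall, f1


reference = "el perro come en la casa"
candidate = "el perro come en casa"
print(rouge_n(reference, candidate, n=1))  # approximately (1.0, 0.833, 0.909)
print(rouge_n(reference, candidate, n=2))  # approximately (0.75, 0.6, 0.667)
```
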
ROUGE-1 and ROUGE-2 use the number of common unigrams and bigrams, respectively.

We will calculate the score over 30 test samples (since decoding is an
expensive process).
"""

rouge_1 = keras_hub.metrics.RougeN(order=1)
rouge_2 = keras_hub.metrics.RougeN(order=2)

for test_pair in test_pairs[:30]:
    input_sentence = test_pair[0]
    reference_sentence = test_pair[1]

    translated_sentence = decode_sequences([input_sentence])
    translated_sentence = translated_sentence.numpy()[0].decode("utf-8")
    translated_sentence = (
        translated_sentence.replace("[PAD]", "")
        .replace("[START]", "")
        .replace("[END]", "")
        .strip()
    )

    rouge_1(reference_sentence, translated_sentence)
    rouge_2(reference_sentence, translated_sentence)

print("ROUGE-1 Score: ", rouge_1.result())
print("ROUGE-2 Score: ", rouge_2.result())

"""
After 10 epochs, the scores are as follows:

|               | **ROUGE-1** | **ROUGE-2** |
|:-------------:|:-----------:|:-----------:|
| **Precision** |    0.568    |    0.374    |
| **Recall**    |    0.615    |    0.394    |
| **F1 Score**  |    0.579    |    0.381    |
"""