"""1Title: End-to-end Masked Language Modeling with BERT2Author: [Ankur Singh](https://twitter.com/ankur310794)3Date created: 2020/09/184Last modified: 2024/03/155Description: Implement a Masked Language Model (MLM) with BERT and fine-tune it on the IMDB Reviews dataset.6Accelerator: GPU7Converted to Keras 3 by: [Sitam Meur](https://github.com/sitamgithub-MSIT) and made backend-agnostic by: [Humbulani Ndou](https://github.com/Humbulani1234)8"""910"""11## Introduction1213Masked Language Modeling is a fill-in-the-blank task,14where a model uses the context words surrounding a mask token to try to predict what the15masked word should be.1617For an input that contains one or more mask tokens,18the model will generate the most likely substitution for each.1920Example:2122- Input: "I have watched this [MASK] and it was awesome."23- Output: "I have watched this movie and it was awesome."2425Masked language modeling is a great way to train a language26model in a self-supervised setting (without human-annotated labels).27Such a model can then be fine-tuned to accomplish various supervised28NLP tasks.2930This example teaches you how to build a BERT model from scratch,31train it with the masked language modeling task,32and then fine-tune this model on a sentiment classification task.3334We will use the Keras `TextVectorization` and `MultiHeadAttention` layers35to create a BERT Transformer-Encoder network architecture.3637Note: This example should be run with `tf-nightly`.38"""3940"""41## Setup4243Install `tf-nightly` via `pip install tf-nightly`.44"""4546import os4748os.environ["KERAS_BACKEND"] = "torch" # or jax, or tensorflow4950import keras_hub5152import keras53from keras import layers54from keras.layers import TextVectorization5556from dataclasses import dataclass57import pandas as pd58import numpy as np59import glob60import re61from pprint import pprint6263"""64## Set-up Configuration65"""666768@dataclass69class Config:70MAX_LEN = 25671BATCH_SIZE = 3272LR = 0.00173VOCAB_SIZE = 3000074EMBED_DIM = 12875NUM_HEAD = 8 # used in bert model76FF_DIM = 128 # used in bert model77NUM_LAYERS = 1787980config = Config()8182"""83## Load the data8485We will first download the IMDB data and load into a Pandas dataframe.86"""8788"""shell89curl -O https://ai.stanford.edu/~amaas/data/sentiment/aclImdb_v1.tar.gz90tar -xf aclImdb_v1.tar.gz91"""929394def get_text_list_from_files(files):95text_list = []96for name in files:97with open(name) as f:98for line in f:99text_list.append(line)100return text_list101102103def get_data_from_text_files(folder_name):104pos_files = glob.glob("aclImdb/" + folder_name + "/pos/*.txt")105pos_texts = get_text_list_from_files(pos_files)106neg_files = glob.glob("aclImdb/" + folder_name + "/neg/*.txt")107neg_texts = get_text_list_from_files(neg_files)108df = pd.DataFrame(109{110"review": pos_texts + neg_texts,111"sentiment": [0] * len(pos_texts) + [1] * len(neg_texts),112}113)114df = df.sample(len(df)).reset_index(drop=True)115return df116117118train_df = get_data_from_text_files("train")119test_df = get_data_from_text_files("test")120121all_data = pd.concat([train_df, test_df], ignore_index=True)122123"""124## Dataset preparation125126We will use the `TextVectorization` layer to vectorize the text into integer token ids.127It transforms a batch of strings into either128a sequence of token indices (one sample = 1D array of integer token indices, in order)129or a dense representation (one sample = 1D array of float values encoding an unordered set of tokens).130131Below, we define 3 
# For data pre-processing and tf.data.Dataset
import tensorflow as tf


def custom_standardization(input_data):
    lowercase = tf.strings.lower(input_data)
    stripped_html = tf.strings.regex_replace(lowercase, "<br />", " ")
    return tf.strings.regex_replace(
        stripped_html, "[%s]" % re.escape("!#$%&'()*+,-./:;<=>?@\^_`{|}~"), ""
    )


def get_vectorize_layer(texts, vocab_size, max_seq, special_tokens=["[mask]"]):
    """Build Text vectorization layer

    Args:
      texts (list): List of strings, i.e. input texts
      vocab_size (int): vocab size
      max_seq (int): Maximum sequence length.
      special_tokens (list, optional): List of special tokens. Defaults to ['[mask]'].

    Returns:
        layers.Layer: Return TextVectorization Keras Layer
    """
    vectorize_layer = TextVectorization(
        max_tokens=vocab_size,
        output_mode="int",
        standardize=custom_standardization,
        output_sequence_length=max_seq,
    )
    vectorize_layer.adapt(texts)

    # Insert the special (mask) tokens at the end of the vocabulary
    vocab = vectorize_layer.get_vocabulary()
    vocab = vocab[2 : vocab_size - len(special_tokens)] + special_tokens
    vectorize_layer.set_vocabulary(vocab)
    return vectorize_layer


vectorize_layer = get_vectorize_layer(
    all_data.review.values.tolist(),
    config.VOCAB_SIZE,
    config.MAX_LEN,
    special_tokens=["[mask]"],
)

# Get mask token id for masked language model
mask_token_id = vectorize_layer(["[mask]"]).numpy()[0][0]


def encode(texts):
    encoded_texts = vectorize_layer(texts)
    return encoded_texts.numpy()


def get_masked_input_and_labels(encoded_texts):
    # 15% BERT masking
    inp_mask = np.random.rand(*encoded_texts.shape) < 0.15
    # Do not mask special tokens
    inp_mask[encoded_texts <= 2] = False
    # Set targets to -1 by default, it means ignore
    labels = -1 * np.ones(encoded_texts.shape, dtype=int)
    # Set labels for masked tokens
    labels[inp_mask] = encoded_texts[inp_mask]

    # Prepare input
    encoded_texts_masked = np.copy(encoded_texts)
    # Set 90% of the selected tokens to [MASK]; the rest are left unchanged
    # or replaced with a random token below
    inp_mask_2mask = inp_mask & (np.random.rand(*encoded_texts.shape) < 0.90)
    encoded_texts_masked[inp_mask_2mask] = (
        mask_token_id  # mask token is the last in the dict
    )

    # Set 10% of the selected tokens to a random token
    inp_mask_2random = inp_mask_2mask & (np.random.rand(*encoded_texts.shape) < 1 / 9)
    encoded_texts_masked[inp_mask_2random] = np.random.randint(
        3, mask_token_id, inp_mask_2random.sum()
    )

    # Prepare sample_weights to pass to .fit() method
    sample_weights = np.ones(labels.shape)
    sample_weights[labels == -1] = 0

    # y_labels would be same as encoded_texts i.e input tokens
    y_labels = np.copy(encoded_texts)

    return encoded_texts_masked, y_labels, sample_weights
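# Illustrative check (not part of the original example): apply the masking
# function to a couple of encoded reviews and inspect the outputs. Positions
# with a sample weight of 0 will be ignored by the MLM loss later on.
demo_encoded = encode(all_data.review.values[:2].tolist())
demo_masked, demo_labels, demo_weights = get_masked_input_and_labels(demo_encoded)
print("masked inputs:", demo_masked.shape)  # (2, MAX_LEN)
print("num masked positions:", int(demo_weights.sum()))
print("mask token present:", (demo_masked == mask_token_id).any())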
# We have 25000 examples for training
x_train = encode(train_df.review.values)  # encode reviews with vectorizer
y_train = train_df.sentiment.values
train_classifier_ds = (
    tf.data.Dataset.from_tensor_slices((x_train, y_train))
    .shuffle(1000)
    .batch(config.BATCH_SIZE)
)

# We have 25000 examples for testing
x_test = encode(test_df.review.values)
y_test = test_df.sentiment.values
test_classifier_ds = tf.data.Dataset.from_tensor_slices((x_test, y_test)).batch(
    config.BATCH_SIZE
)

# Dataset for end-to-end model input (will be used at the end)
test_raw_classifier_ds = test_df

# Prepare data for masked language model
x_all_review = encode(all_data.review.values)
x_masked_train, y_masked_labels, sample_weights = get_masked_input_and_labels(
    x_all_review
)

mlm_ds = tf.data.Dataset.from_tensor_slices(
    (x_masked_train, y_masked_labels, sample_weights)
)
mlm_ds = mlm_ds.shuffle(1000).batch(config.BATCH_SIZE)

"""
## Create BERT model (Pretraining Model) for masked language modeling

We will create a BERT-like pretraining model architecture
using the `MultiHeadAttention` layer.
It will take token ids as inputs (including masked tokens)
and it will predict the correct ids for the masked input tokens.
"""


def bert_module(query, key, value, i):
    # Multi headed self-attention
    attention_output = layers.MultiHeadAttention(
        num_heads=config.NUM_HEAD,
        key_dim=config.EMBED_DIM // config.NUM_HEAD,
        name="encoder_{}_multiheadattention".format(i),
    )(query, key, value)
    attention_output = layers.Dropout(0.1, name="encoder_{}_att_dropout".format(i))(
        attention_output
    )
    attention_output = layers.LayerNormalization(
        epsilon=1e-6, name="encoder_{}_att_layernormalization".format(i)
    )(query + attention_output)

    # Feed-forward layer
    ffn = keras.Sequential(
        [
            layers.Dense(config.FF_DIM, activation="relu"),
            layers.Dense(config.EMBED_DIM),
        ],
        name="encoder_{}_ffn".format(i),
    )
    ffn_output = ffn(attention_output)
    ffn_output = layers.Dropout(0.1, name="encoder_{}_ffn_dropout".format(i))(
        ffn_output
    )
    sequence_output = layers.LayerNormalization(
        epsilon=1e-6, name="encoder_{}_ffn_layernormalization".format(i)
    )(attention_output + ffn_output)
    return sequence_output


loss_fn = keras.losses.SparseCategoricalCrossentropy(reduction=None)
loss_tracker = keras.metrics.Mean(name="loss")


class MaskedLanguageModel(keras.Model):

    def compute_loss(self, x=None, y=None, y_pred=None, sample_weight=None):
        loss = loss_fn(y, y_pred, sample_weight)
        loss_tracker.update_state(loss, sample_weight=sample_weight)
        return keras.ops.sum(loss)

    def compute_metrics(self, x, y, y_pred, sample_weight):
        # Return a dict mapping metric names to current value
        return {"loss": loss_tracker.result()}

    @property
    def metrics(self):
        # We list our `Metric` objects here so that `reset_states()` can be
        # called automatically at the start of each epoch
        # or at the start of `evaluate()`.
        # If you don't implement this property, you have to call
        # `reset_states()` yourself at the time of your choosing.
        return [loss_tracker]
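# Minimal sketch (toy shapes, not part of the original example) of how the
# sample weights restrict the loss to masked positions: with `reduction=None`,
# the loss is returned per token, and a weight of 0 silences that position.
toy_y_true = np.array([[5, 7]])  # two token positions
toy_y_pred = keras.ops.softmax(keras.random.normal((1, 2, 10)))  # toy vocab of 10
toy_weights = np.array([[1.0, 0.0]])  # only the first position was masked
print(loss_fn(toy_y_true, toy_y_pred, toy_weights))  # weight-0 entry contributes nothing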
def create_masked_language_bert_model():
    inputs = layers.Input((config.MAX_LEN,), dtype="int64")

    word_embeddings = layers.Embedding(
        config.VOCAB_SIZE, config.EMBED_DIM, name="word_embedding"
    )(inputs)
    position_embeddings = keras_hub.layers.PositionEmbedding(
        sequence_length=config.MAX_LEN
    )(word_embeddings)
    embeddings = word_embeddings + position_embeddings

    encoder_output = embeddings
    for i in range(config.NUM_LAYERS):
        encoder_output = bert_module(encoder_output, encoder_output, encoder_output, i)

    mlm_output = layers.Dense(config.VOCAB_SIZE, name="mlm_cls", activation="softmax")(
        encoder_output
    )
    mlm_model = MaskedLanguageModel(inputs, mlm_output, name="masked_bert_model")

    optimizer = keras.optimizers.Adam(learning_rate=config.LR)
    mlm_model.compile(optimizer=optimizer)
    return mlm_model


id2token = dict(enumerate(vectorize_layer.get_vocabulary()))
token2id = {y: x for x, y in id2token.items()}


class MaskedTextGenerator(keras.callbacks.Callback):
    def __init__(self, sample_tokens, top_k=5):
        self.sample_tokens = sample_tokens
        self.k = top_k

    def decode(self, tokens):
        return " ".join([id2token[t] for t in tokens if t != 0])

    def convert_ids_to_tokens(self, id):
        return id2token[id]

    def on_epoch_end(self, epoch, logs=None):
        prediction = self.model.predict(self.sample_tokens)

        masked_index = np.where(self.sample_tokens == mask_token_id)
        masked_index = masked_index[1]
        mask_prediction = prediction[0][masked_index]

        top_indices = mask_prediction[0].argsort()[-self.k :][::-1]
        values = mask_prediction[0][top_indices]

        for i in range(len(top_indices)):
            p = top_indices[i]
            v = values[i]
            tokens = np.copy(self.sample_tokens[0])
            tokens[masked_index[0]] = p
            result = {
                "input_text": self.decode(self.sample_tokens[0]),
                "prediction": self.decode(tokens),
                "probability": v,
                "predicted mask token": self.convert_ids_to_tokens(p),
            }
            pprint(result)


sample_tokens = vectorize_layer(["I have watched this [mask] and it was awesome"])
generator_callback = MaskedTextGenerator(sample_tokens.numpy())

bert_masked_model = create_masked_language_bert_model()
bert_masked_model.summary()

"""
## Train and Save
"""

bert_masked_model.fit(mlm_ds, epochs=5, callbacks=[generator_callback])
bert_masked_model.save("bert_mlm_imdb.keras")

"""
## Fine-tune a sentiment classification model

We will fine-tune our self-supervised model on a downstream task of sentiment classification.
To do this, let's create a classifier by adding a pooling layer and a `Dense` layer on top of the
pretrained BERT features.
"""

# Load pretrained bert model
mlm_model = keras.models.load_model(
    "bert_mlm_imdb.keras", custom_objects={"MaskedLanguageModel": MaskedLanguageModel}
)
pretrained_bert_model = keras.Model(
    mlm_model.input, mlm_model.get_layer("encoder_0_ffn_layernormalization").output
)

# Freeze it
pretrained_bert_model.trainable = False
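# Quick sanity check (illustrative, not part of the original example): the
# frozen encoder maps a batch of token ids to per-token features of shape
# (batch, MAX_LEN, EMBED_DIM), which the classifier head below will pool.
print(pretrained_bert_model.predict(x_test[:2]).shape)  # (2, 256, 128)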
def create_classifier_bert_model():
    inputs = layers.Input((config.MAX_LEN,), dtype="int64")
    sequence_output = pretrained_bert_model(inputs)
    pooled_output = layers.GlobalMaxPooling1D()(sequence_output)
    hidden_layer = layers.Dense(64, activation="relu")(pooled_output)
    outputs = layers.Dense(1, activation="sigmoid")(hidden_layer)
    classifier_model = keras.Model(inputs, outputs, name="classification")
    optimizer = keras.optimizers.Adam()
    classifier_model.compile(
        optimizer=optimizer, loss="binary_crossentropy", metrics=["accuracy"]
    )
    return classifier_model


classifier_model = create_classifier_bert_model()
classifier_model.summary()

# Train the classifier with the BERT encoder frozen
classifier_model.fit(
    train_classifier_ds,
    epochs=5,
    validation_data=test_classifier_ds,
)

# Unfreeze the BERT model for fine-tuning
pretrained_bert_model.trainable = True
optimizer = keras.optimizers.Adam()
classifier_model.compile(
    optimizer=optimizer, loss="binary_crossentropy", metrics=["accuracy"]
)
classifier_model.fit(
    train_classifier_ds,
    epochs=5,
    validation_data=test_classifier_ds,
)

"""
## Create an end-to-end model and evaluate it

When you want to deploy a model, it's best if it already includes its preprocessing
pipeline, so that you don't have to reimplement the preprocessing logic in your
production environment. Let's create an end-to-end model that incorporates
the `TextVectorization` preprocessing inside its `evaluate` method, and evaluate it
by passing raw strings as input.
"""


# We create a custom Model that overrides the evaluate method so
# that it first pre-processes the raw text data
class ModelEndtoEnd(keras.Model):

    def evaluate(self, inputs):
        features = encode(inputs.review.values)
        labels = inputs.sentiment.values
        test_classifier_ds = (
            tf.data.Dataset.from_tensor_slices((features, labels))
            .shuffle(1000)
            .batch(config.BATCH_SIZE)
        )
        return super().evaluate(test_classifier_ds)

    # Build the model
    def build(self, input_shape):
        self.built = True


def get_end_to_end(model):
    inputs = model.inputs[0]
    outputs = model.outputs
    end_to_end_model = ModelEndtoEnd(inputs, outputs, name="end_to_end_model")
    optimizer = keras.optimizers.Adam(learning_rate=config.LR)
    end_to_end_model.compile(
        optimizer=optimizer, loss="binary_crossentropy", metrics=["accuracy"]
    )
    return end_to_end_model


end_to_end_classification_model = get_end_to_end(classifier_model)
# Pass the raw text dataframe to the model
end_to_end_classification_model.evaluate(test_raw_classifier_ds)
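# Illustrative usage sketch (not part of the original example): scoring new raw
# reviews. The end-to-end model above only wraps `evaluate`, so for predictions
# we reuse the same `encode` helper in front of the fine-tuned classifier.
new_reviews = [
    "I have watched this movie and it was awesome",
    "A dull plot and terrible acting",
]
probs = classifier_model.predict(encode(new_reviews))
# Sigmoid output is the probability of label 1 (negative reviews in this dataset's encoding).
print(probs)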