Path: blob/master/examples/nlp/semantic_similarity_with_bert.py
"""1Title: Semantic Similarity with BERT2Author: [Mohamad Merchant](https://twitter.com/mohmadmerchant1)3Date created: 2020/08/154Last modified: 2020/08/295Description: Natural Language Inference by fine-tuning BERT model on SNLI Corpus.6Accelerator: GPU7"""89"""10## Introduction1112Semantic Similarity is the task of determining how similar13two sentences are, in terms of what they mean.14This example demonstrates the use of SNLI (Stanford Natural Language Inference) Corpus15to predict sentence semantic similarity with Transformers.16We will fine-tune a BERT model that takes two sentences as inputs17and that outputs a similarity score for these two sentences.1819### References2021* [BERT](https://arxiv.org/pdf/1810.04805.pdf)22* [SNLI](https://nlp.stanford.edu/projects/snli/)23"""2425"""26## Setup2728Note: install HuggingFace `transformers` via `pip install transformers` (version >= 2.11.0).29"""30import numpy as np31import pandas as pd32import tensorflow as tf33import transformers3435"""36## Configuration37"""3839max_length = 128 # Maximum length of input sentence to the model.40batch_size = 3241epochs = 24243# Labels in our dataset.44labels = ["contradiction", "entailment", "neutral"]4546"""47## Load the Data48"""4950"""shell51curl -LO https://raw.githubusercontent.com/MohamadMerchant/SNLI/master/data.tar.gz52tar -xvzf data.tar.gz53"""54# There are more than 550k samples in total; we will use 100k for this example.55train_df = pd.read_csv("SNLI_Corpus/snli_1.0_train.csv", nrows=100000)56valid_df = pd.read_csv("SNLI_Corpus/snli_1.0_dev.csv")57test_df = pd.read_csv("SNLI_Corpus/snli_1.0_test.csv")5859# Shape of the data60print(f"Total train samples : {train_df.shape[0]}")61print(f"Total validation samples: {valid_df.shape[0]}")62print(f"Total test samples: {valid_df.shape[0]}")6364"""65Dataset Overview:6667- sentence1: The premise caption that was supplied to the author of the pair.68- sentence2: The hypothesis caption that was written by the author of the pair.69- similarity: This is the label chosen by the majority of annotators.70Where no majority exists, the label "-" is used (we will skip such samples here).7172Here are the "similarity" label values in our dataset:7374- Contradiction: The sentences share no similarity.75- Entailment: The sentences have similar meaning.76- Neutral: The sentences are neutral.77"""7879"""80Let's look at one sample from the dataset:81"""82print(f"Sentence1: {train_df.loc[1, 'sentence1']}")83print(f"Sentence2: {train_df.loc[1, 'sentence2']}")84print(f"Similarity: {train_df.loc[1, 'similarity']}")8586"""87## Preprocessing88"""8990# We have some NaN entries in our train data, we will simply drop them.91print("Number of missing values")92print(train_df.isnull().sum())93train_df.dropna(axis=0, inplace=True)9495"""96Distribution of our training targets.97"""98print("Train Target Distribution")99print(train_df.similarity.value_counts())100101"""102Distribution of our validation targets.103"""104print("Validation Target Distribution")105print(valid_df.similarity.value_counts())106107"""108The value "-" appears as part of our training and validation targets.109We will skip these samples.110"""111train_df = (112train_df[train_df.similarity != "-"]113.sample(frac=1.0, random_state=42)114.reset_index(drop=True)115)116valid_df = (117valid_df[valid_df.similarity != "-"]118.sample(frac=1.0, random_state=42)119.reset_index(drop=True)120)121122"""123One-hot encode training, validation, and test labels.124"""125train_df["label"] = train_df["similarity"].apply(126lambda x: 
0 if x == "contradiction" else 1 if x == "entailment" else 2127)128y_train = tf.keras.utils.to_categorical(train_df.label, num_classes=3)129130valid_df["label"] = valid_df["similarity"].apply(131lambda x: 0 if x == "contradiction" else 1 if x == "entailment" else 2132)133y_val = tf.keras.utils.to_categorical(valid_df.label, num_classes=3)134135test_df["label"] = test_df["similarity"].apply(136lambda x: 0 if x == "contradiction" else 1 if x == "entailment" else 2137)138y_test = tf.keras.utils.to_categorical(test_df.label, num_classes=3)139140"""141## Create a custom data generator142"""143144145class BertSemanticDataGenerator(tf.keras.utils.Sequence):146"""Generates batches of data.147148Args:149sentence_pairs: Array of premise and hypothesis input sentences.150labels: Array of labels.151batch_size: Integer batch size.152shuffle: boolean, whether to shuffle the data.153include_targets: boolean, whether to include the labels.154155Returns:156Tuples `([input_ids, attention_mask, `token_type_ids], labels)`157(or just `[input_ids, attention_mask, `token_type_ids]`158if `include_targets=False`)159"""160161def __init__(162self,163sentence_pairs,164labels,165batch_size=batch_size,166shuffle=True,167include_targets=True,168):169self.sentence_pairs = sentence_pairs170self.labels = labels171self.shuffle = shuffle172self.batch_size = batch_size173self.include_targets = include_targets174# Load our BERT Tokenizer to encode the text.175# We will use base-base-uncased pretrained model.176self.tokenizer = transformers.BertTokenizer.from_pretrained(177"bert-base-uncased", do_lower_case=True178)179self.indexes = np.arange(len(self.sentence_pairs))180self.on_epoch_end()181182def __len__(self):183# Denotes the number of batches per epoch.184return len(self.sentence_pairs) // self.batch_size185186def __getitem__(self, idx):187# Retrieves the batch of index.188indexes = self.indexes[idx * self.batch_size : (idx + 1) * self.batch_size]189sentence_pairs = self.sentence_pairs[indexes]190191# With BERT tokenizer's batch_encode_plus batch of both the sentences are192# encoded together and separated by [SEP] token.193encoded = self.tokenizer.batch_encode_plus(194sentence_pairs.tolist(),195add_special_tokens=True,196max_length=max_length,197return_attention_mask=True,198return_token_type_ids=True,199pad_to_max_length=True,200return_tensors="tf",201)202203# Convert batch of encoded features to numpy array.204input_ids = np.array(encoded["input_ids"], dtype="int32")205attention_masks = np.array(encoded["attention_mask"], dtype="int32")206token_type_ids = np.array(encoded["token_type_ids"], dtype="int32")207208# Set to true if data generator is used for training/validation.209if self.include_targets:210labels = np.array(self.labels[indexes], dtype="int32")211return [input_ids, attention_masks, token_type_ids], labels212else:213return [input_ids, attention_masks, token_type_ids]214215def on_epoch_end(self):216# Shuffle indexes after each epoch if shuffle is set to True.217if self.shuffle:218np.random.RandomState(42).shuffle(self.indexes)219220221"""222## Build the model223"""224# Create the model under a distribution strategy scope.225strategy = tf.distribute.MirroredStrategy()226227with strategy.scope():228# Encoded token ids from BERT tokenizer.229input_ids = tf.keras.layers.Input(230shape=(max_length,), dtype=tf.int32, name="input_ids"231)232# Attention masks indicates to the model which tokens should be attended to.233attention_masks = tf.keras.layers.Input(234shape=(max_length,), dtype=tf.int32, 
name="attention_masks"235)236# Token type ids are binary masks identifying different sequences in the model.237token_type_ids = tf.keras.layers.Input(238shape=(max_length,), dtype=tf.int32, name="token_type_ids"239)240# Loading pretrained BERT model.241bert_model = transformers.TFBertModel.from_pretrained("bert-base-uncased")242# Freeze the BERT model to reuse the pretrained features without modifying them.243bert_model.trainable = False244245bert_output = bert_model.bert(246input_ids, attention_mask=attention_masks, token_type_ids=token_type_ids247)248sequence_output = bert_output.last_hidden_state249pooled_output = bert_output.pooler_output250251# Add trainable layers on top of frozen layers to adapt the pretrained features on the new data.252bi_lstm = tf.keras.layers.Bidirectional(253tf.keras.layers.LSTM(64, return_sequences=True)254)(sequence_output)255# Applying hybrid pooling approach to bi_lstm sequence output.256avg_pool = tf.keras.layers.GlobalAveragePooling1D()(bi_lstm)257max_pool = tf.keras.layers.GlobalMaxPooling1D()(bi_lstm)258concat = tf.keras.layers.concatenate([avg_pool, max_pool])259dropout = tf.keras.layers.Dropout(0.3)(concat)260output = tf.keras.layers.Dense(3, activation="softmax")(dropout)261model = tf.keras.models.Model(262inputs=[input_ids, attention_masks, token_type_ids], outputs=output263)264265model.compile(266optimizer=tf.keras.optimizers.Adam(),267loss="categorical_crossentropy",268metrics=["acc"],269)270271272print(f"Strategy: {strategy}")273model.summary()274275"""276Create train and validation data generators277"""278train_data = BertSemanticDataGenerator(279train_df[["sentence1", "sentence2"]].values.astype("str"),280y_train,281batch_size=batch_size,282shuffle=True,283)284valid_data = BertSemanticDataGenerator(285valid_df[["sentence1", "sentence2"]].values.astype("str"),286y_val,287batch_size=batch_size,288shuffle=False,289)290291"""292## Train the Model293294Training is done only for the top layers to perform "feature extraction",295which will allow the model to use the representations of the pretrained model.296"""297history = model.fit(298train_data,299validation_data=valid_data,300epochs=epochs,301use_multiprocessing=True,302workers=-1,303)304305"""306## Fine-tuning307308This step must only be performed after the feature extraction model has309been trained to convergence on the new data.310311This is an optional last step where `bert_model` is unfreezed and retrained312with a very low learning rate. 

# Unfreeze the bert_model.
bert_model.trainable = True
# Recompile the model to make the change effective.
model.compile(
    optimizer=tf.keras.optimizers.Adam(1e-5),
    loss="categorical_crossentropy",
    metrics=["accuracy"],
)
model.summary()

"""
## Train the entire model end-to-end
"""
history = model.fit(
    train_data,
    validation_data=valid_data,
    epochs=epochs,
    use_multiprocessing=True,
    workers=-1,
)

"""
## Evaluate model on the test set
"""
test_data = BertSemanticDataGenerator(
    test_df[["sentence1", "sentence2"]].values.astype("str"),
    y_test,
    batch_size=batch_size,
    shuffle=False,
)
model.evaluate(test_data, verbose=1)

"""
## Inference on custom sentences
"""


def check_similarity(sentence1, sentence2):
    sentence_pairs = np.array([[str(sentence1), str(sentence2)]])
    test_data = BertSemanticDataGenerator(
        sentence_pairs,
        labels=None,
        batch_size=1,
        shuffle=False,
        include_targets=False,
    )

    proba = model.predict(test_data[0])[0]
    idx = np.argmax(proba)
    proba = f"{proba[idx] * 100: .2f}%"
    pred = labels[idx]
    return pred, proba


"""
Check results on some example sentence pairs.
"""
sentence1 = "Two women are observing something together."
sentence2 = "Two women are standing with their eyes closed."
check_similarity(sentence1, sentence2)

"""
Check results on some example sentence pairs.
"""
sentence1 = "A smiling costumed woman is holding an umbrella"
sentence2 = "A happy woman in a fairy costume holds an umbrella"
check_similarity(sentence1, sentence2)

"""
Check results on some example sentence pairs.
"""
sentence1 = "A soccer game with multiple males playing"
sentence2 = "Some men are playing a sport"
check_similarity(sentence1, sentence2)

"""
This example is also available on HuggingFace:

| Trained Model | Demo |
| :--: | :--: |
| [Model](https://huggingface.co/keras-io/bert-semantic-similarity) | [Space](https://huggingface.co/spaces/keras-io/bert-semantic-similarity) |
"""
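
"""
To make the predictions easier to read, we can also loop over the same example pairs
and print the label and confidence returned by `check_similarity`. This is only a
minimal usage sketch; it reuses the helper and sentence pairs defined above.
"""
example_pairs = [
    ("Two women are observing something together.",
     "Two women are standing with their eyes closed."),
    ("A smiling costumed woman is holding an umbrella",
     "A happy woman in a fairy costume holds an umbrella"),
    ("A soccer game with multiple males playing",
     "Some men are playing a sport"),
]
for premise, hypothesis in example_pairs:
    # check_similarity returns a (predicted_label, confidence_string) tuple.
    pred, proba = check_similarity(premise, hypothesis)
    print(f"{premise!r} / {hypothesis!r} -> {pred} ({proba})")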