Sentence embeddings using Siamese RoBERTa-networks
Author: Mohammed Abu El-Nasr
Date created: 2023/07/14
Last modified: 2023/07/14
Description: Fine-tune a RoBERTa model to generate sentence embeddings using KerasHub.
Introduction
BERT and RoBERTa can be used for semantic textual similarity tasks, where two sentences are passed to the model and the network predicts whether they are similar or not. But what if we have a large collection of sentences and want to find the most similar pairs in that collection? That requires n*(n-1)/2 inference computations, where n is the number of sentences in the collection. For example, if n = 10,000, that is about 50 million computations, which takes about 65 hours on a V100 GPU.
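To make the quadratic growth concrete, here is a small sketch that counts the pairwise comparisons for a few collection sizes. The helper name `num_pairs` is made up for this illustration.

```python
def num_pairs(n: int) -> int:
    """Number of unordered sentence pairs in a collection of n sentences."""
    return n * (n - 1) // 2


# Growing the collection 10x grows the number of comparisons roughly 100x.
for n in (100, 1_000, 10_000):
    print(f"n = {n:>6}: {num_pairs(n):>12,} pair comparisons")
```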
A common way to avoid this overhead is to pass each sentence through the model once, then average the model's output or take the first token (the [CLS] token) and use the result as a sentence embedding. A vector similarity measure such as cosine similarity or Manhattan / Euclidean distance can then be used to find close (semantically similar) sentences. That reduces the time to find the most similar pairs in a collection of 10,000 sentences from 65 hours to 5 seconds!
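As a rough sketch of the idea, the snippet below searches a tiny collection of precomputed embeddings with cosine similarity. The vectors are made-up 3-dimensional values and the helper names are hypothetical; real sentence embeddings have hundreds of dimensions and would normally be compared with a vectorized library.

```python
import math


def cosine_similarity(u, v):
    """Cosine of the angle between two equal-length vectors."""
    dot = sum(a * b for a, b in zip(u, v))
    norm_u = math.sqrt(sum(a * a for a in u))
    norm_v = math.sqrt(sum(b * b for b in v))
    return dot / (norm_u * norm_v)


def most_similar(query, corpus):
    """Index of the corpus vector with the highest cosine similarity to query."""
    scores = [cosine_similarity(query, vec) for vec in corpus]
    return max(range(len(corpus)), key=scores.__getitem__)


# Toy "embeddings" (made up for illustration).
corpus = [
    [0.9, 0.1, 0.0],
    [0.0, 1.0, 0.2],
    [0.1, 0.0, 1.0],
]
query = [0.8, 0.2, 0.1]

best = most_similar(query, corpus)
print(best)  # the first vector points in nearly the same direction as the query
```

Each sentence is encoded once, so the expensive model runs n times; only the cheap vector comparison is done per pair.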
If we use RoBERTa directly, that will yield rather bad sentence embeddings. But if we fine-tune RoBERTa using a Siamese network, that will generate semantically meaningful sentence embeddings. This will enable RoBERTa to be used for new tasks. These tasks include:
Large-scale semantic similarity comparison.
Clustering.
Information retrieval via semantic search.
In this example, we will show how to fine-tune a RoBERTa model using a Siamese network so that it produces semantically meaningful sentence embeddings, and we will use them in a semantic search and a clustering example. This method of fine-tuning was introduced in Sentence-BERT.
Setup
Let's install and import the libraries we need. We'll be using the KerasHub library in this example.
We will also enable mixed precision training. This will help us reduce the training time.
Fine-tune the model using siamese networks
A Siamese network is a neural network architecture that contains two or more subnetworks. The subnetworks share the same weights. It is used to generate feature vectors for each input and then compare them for similarity.
For our example, the subnetwork will be a RoBERTa model that has a pooling layer on top of it to produce the embeddings of the input sentences. These embeddings will then be compared to each other to learn to produce semantically meaningful embeddings.
The pooling strategies used are mean, max, and CLS pooling. Mean pooling produces the best results. We will use it in our examples.
Fine-tune using the regression objective function
With the regression objective function, the Siamese network is asked to predict the cosine similarity between the embeddings of the two input sentences.
Cosine similarity indicates the angle between the sentence embeddings. If the cosine similarity is high, that means there is a small angle between the embeddings; hence, they are semantically similar.
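For reference, the standard definition of cosine similarity between two embedding vectors is given below. The symbols $u$, $v$, and $\theta$ (the angle between them) are introduced here for illustration.

```latex
\[
\cos\theta
  = \frac{u \cdot v}{\lVert u \rVert \, \lVert v \rVert}
  = \frac{\sum_{i} u_i v_i}{\sqrt{\sum_{i} u_i^{2}} \, \sqrt{\sum_{i} v_i^{2}}},
\qquad -1 \le \cos\theta \le 1 .
\]
```

A value near 1 corresponds to an angle near 0 degrees, a value of 0 to orthogonal vectors, and a value of -1 to vectors pointing in opposite directions.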
Load the dataset
We will use the STSB dataset to fine-tune the model for the regression objective. STSB consists of a collection of sentence pairs that are labelled in the range [0, 5]. 0 indicates the least semantic similarity between the two sentences, and 5 indicates the most semantic similarity between the two sentences.
The range of the cosine similarity is [-1, 1] and it's the output of the siamese network, but the range of the labels in the dataset is [0, 5]. We need to unify the range between the cosine similarity and the dataset labels, so while preparing the dataset, we will divide the labels by 2.5 and subtract 1.
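The rescaling itself is a one-line linear map. A minimal sketch, with a hypothetical helper name `rescale_label`:

```python
def rescale_label(score: float) -> float:
    """Map an STSB similarity score in [0, 5] to the cosine range [-1, 1]."""
    return score / 2.5 - 1.0


# The endpoints and the midpoint of the label range.
for score in (0.0, 2.5, 5.0):
    print(score, "->", rescale_label(score))
```

Dividing by 2.5 first stretches [0, 5] to [0, 2]; subtracting 1 then shifts it to [-1, 1].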
Let's see examples from the dataset of two sentences and their similarity.
Build the encoder model
Now, we'll build the encoder model that will produce the sentence embeddings. It consists of:
A preprocessor layer to tokenize and generate padding masks for the sentences.
A backbone model that will generate the contextual representation of each token in the sentence.
A mean pooling layer to produce the embeddings. We will use keras.layers.GlobalAveragePooling1D to apply the mean pooling to the backbone outputs. We will pass the padding mask to the layer to exclude padded tokens from being averaged.
A normalization layer to normalize the embeddings, as we are using the cosine similarity.
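To illustrate what the last two steps compute, here is a plain-Python sketch of masked mean pooling followed by L2 normalization. It is not the Keras implementation; the function names and the toy values are made up, and it assumes every sentence has at least one non-padding token.

```python
import math


def masked_mean_pool(token_vectors, padding_mask):
    """Average token vectors, skipping positions where the mask is 0 (padding)."""
    dim = len(token_vectors[0])
    total = [0.0] * dim
    count = 0
    for vec, keep in zip(token_vectors, padding_mask):
        if keep:
            total = [t + x for t, x in zip(total, vec)]
            count += 1
    return [t / count for t in total]


def l2_normalize(vec):
    """Scale a vector to unit Euclidean length."""
    norm = math.sqrt(sum(x * x for x in vec))
    return [x / norm for x in vec]


# Toy backbone output: 4 token positions, 2 dimensions; the last 2 are padding.
tokens = [[1.0, 3.0], [5.0, 1.0], [9.0, 9.0], [9.0, 9.0]]
mask = [1, 1, 0, 0]

pooled = masked_mean_pool(tokens, mask)  # [3.0, 2.0], padding is ignored
embedding = l2_normalize(pooled)         # same direction, length 1
print(pooled, embedding)
```

Without the mask, the two padding vectors would pull the average toward [6.0, 5.5], which is why the padding mask matters.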
Build the Siamese network with the regression objective function
As described above, a Siamese network has two or more subnetworks, so this Siamese model needs two encoders. Instead of building two encoders, we will build a single encoder and pass both sentences through it. That way, we get two paths that produce the embeddings, and the weights are shared between the two paths.
After passing the two sentences to the model and getting the normalized embeddings, we will multiply the two normalized embeddings (a dot product) to get the cosine similarity between the two sentences.
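The reason a plain product is enough is that both vectors already have unit length, so the denominator of the cosine formula is 1. A small sketch with made-up 2-dimensional embeddings and hypothetical helper names:

```python
import math


def l2_normalize(vec):
    """Scale a vector to unit Euclidean length."""
    norm = math.sqrt(sum(x * x for x in vec))
    return [x / norm for x in vec]


def dot(u, v):
    """Dot product of two equal-length vectors."""
    return sum(a * b for a, b in zip(u, v))


# Made-up pooled embeddings for two sentences (before normalization).
u = [3.0, 4.0]
v = [4.0, 3.0]

# After L2 normalization both vectors have length 1, so the dot product
# alone is the cosine similarity; no division by the norms is needed.
similarity = dot(l2_normalize(u), l2_normalize(v))
print(round(similarity, 4))  # 0.96
```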
Fit the model
Let's try this example before training and compare it to the output after training.
For the training, we will use MeanSquaredError() as the loss function and the Adam() optimizer with a learning rate of 2e-5.
Let's try the model after training. We will notice a huge difference in the output. That means that the fine-tuned model is capable of producing semantically meaningful embeddings, where semantically similar sentences have a small angle between them and semantically dissimilar sentences have a large angle between them.
Fine-tune using the triplet objective function
For the Siamese network with the triplet objective function, three sentences are passed to the Siamese network: an anchor, a positive, and a negative sentence. The anchor and positive sentences are semantically similar, and the anchor and negative sentences are semantically dissimilar. The objective is to minimize the distance between the anchor sentence and the positive sentence, and to maximize the distance between the anchor sentence and the negative sentence.
Load the dataset
We will use the Wikipedia-sections-triplets dataset for fine-tuning. This dataset consists of sentences derived from the Wikipedia website. Each example is a collection of 3 sentences: anchor, positive, and negative. The anchor and positive sentences are derived from the same section. The anchor and negative sentences are derived from different sections.
This dataset has 1.8 million training triplets and 220,000 test triplets. In this example, we will only use 1200 triplets for training and 300 for testing.
Build the encoder model
For this encoder model, we will use RoBERTa with mean pooling and we will not normalize the output embeddings. The encoder model consists of:
A preprocessor layer to tokenize and generate padding masks for the sentences.
A backbone model that will generate the contextual representation of each token in the sentence.
A mean pooling layer to produce the embeddings.
Build the Siamese network with the triplet objective function
For the Siamese network with the triplet objective function, we will build the model with an encoder, and we will pass the three sentences through that encoder. We will get an embedding for each sentence, and we will calculate the positive_dist and negative_dist that will be passed to the loss function described below.
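As an illustration of the two distances, the sketch below assumes they are Euclidean distances between embeddings; the text here does not state which distance is used, so treat that choice, the helper name, and the toy vectors as assumptions.

```python
import math


def euclidean_distance(u, v):
    """Euclidean (L2) distance between two equal-length vectors."""
    return math.sqrt(sum((a - b) ** 2 for a, b in zip(u, v)))


# Made-up 2-dimensional embeddings for one triplet.
anchor = [0.0, 0.0]
positive = [3.0, 4.0]
negative = [6.0, 8.0]

positive_dist = euclidean_distance(anchor, positive)  # 5.0
negative_dist = euclidean_distance(anchor, negative)  # 10.0
print(positive_dist, negative_dist)
```

Because all three sentences go through the same encoder, the two distances are measured in the same embedding space and can be compared directly.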
We will use a custom loss function for the triplet objective. The loss function will receive the distance between the anchor and the positive embeddings, positive_dist, and the distance between the anchor and the negative embeddings, negative_dist, stacked together in y_pred.
We will use positive_dist and negative_dist to compute the loss such that negative_dist is larger than positive_dist by at least a specific margin. Mathematically, we will minimize this loss function: `max(positive_dist - negative_dist + margin, 0)`.
There is no y_true used in this loss function. Note that we set the labels in the dataset to zero, but they will not be used.
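A minimal sketch of this loss for a single triplet is shown below. It assumes the margin is added to the difference of the two distances before clipping at zero, and the default margin of 1.0 is an assumption chosen for illustration, not a value taken from the text.

```python
def triplet_loss(positive_dist: float, negative_dist: float, margin: float = 1.0) -> float:
    """Hinge-style triplet loss: zero once the negative is farther by `margin`."""
    return max(positive_dist - negative_dist + margin, 0.0)


print(triplet_loss(0.5, 3.0))  # negative already far enough -> 0.0
print(triplet_loss(2.0, 2.5))  # margin violated -> 0.5
```

Once a triplet satisfies the margin, its loss and gradient are zero, so training effort goes to the triplets that are still ordered incorrectly or too tightly.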
Fit the model
For the training, we will use the custom TripletLoss() loss function and the Adam() optimizer with a learning rate of 2e-5.
Let's try this model in a clustering example. Here are 6 questions: the first 3 questions are about learning English, and the last 3 questions are about working online. Let's see if the embeddings produced by our encoder will cluster them correctly.
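To show what "cluster them correctly" means, here is a toy sketch that groups six made-up 2-dimensional vectors with a bare-bones k-means loop. It stands in for whichever clustering routine is applied to the real embeddings; the vectors, the `kmeans` helper, and the centroid seeding are all assumptions made for illustration.

```python
import math


def euclidean_distance(u, v):
    """Euclidean (L2) distance between two equal-length vectors."""
    return math.sqrt(sum((a - b) ** 2 for a, b in zip(u, v)))


def kmeans(points, centroids, iterations=10):
    """Plain k-means: returns one cluster index per point."""
    labels = [0] * len(points)
    for _ in range(iterations):
        # Assignment step: each point joins its nearest centroid.
        labels = [
            min(range(len(centroids)), key=lambda k: euclidean_distance(p, centroids[k]))
            for p in points
        ]
        # Update step: move each centroid to the mean of its members.
        for k in range(len(centroids)):
            members = [p for p, label in zip(points, labels) if label == k]
            if members:
                centroids[k] = [sum(dim) / len(members) for dim in zip(*members)]
    return labels


# Made-up embeddings standing in for the six questions:
# the first three lie close together, and so do the last three.
embeddings = [
    [1.0, 1.1], [0.9, 1.0], [1.1, 0.9],
    [5.0, 5.2], [5.1, 4.9], [4.9, 5.0],
]
# Seed the two centroids with the first and last point (a simplification).
labels = kmeans(embeddings, centroids=[embeddings[0][:], embeddings[-1][:]])
print(labels)  # [0, 0, 0, 1, 1, 1]
```

A correct result puts the first three questions under one label and the last three under the other; which label is 0 and which is 1 is arbitrary.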