Path: blob/master/examples/nlp/sentence_embeddings_with_sbert.py
"""1Title: Sentence embeddings using Siamese RoBERTa-networks2Author: [Mohammed Abu El-Nasr](https://github.com/abuelnasr0)3Date created: 2023/07/144Last modified: 2023/07/145Description: Fine-tune a RoBERTa model to generate sentence embeddings using KerasHub.6Accelerator: GPU7"""89"""10## Introduction1112BERT and RoBERTa can be used for semantic textual similarity tasks, where two sentences13are passed to the model and the network predicts whether they are similar or not. But14what if we have a large collection of sentences and want to find the most similar pairs15in that collection? That will take n*(n-1)/2 inference computations, where n is the16number of sentences in the collection. For example, if n = 10000, the required time will17be 65 hours on a V100 GPU.1819A common method to overcome the time overhead issue is to pass one sentence to the model,20then average the output of the model, or take the first token (the [CLS] token) and use21them as a [sentence embedding](https://en.wikipedia.org/wiki/Sentence_embedding), then22use a vector similarity measure like cosine similarity or Manhatten / Euclidean distance23to find close sentences (semantically similar sentences). That will reduce the time to24find the most similar pairs in a collection of 10,000 sentences from 65 hours to 525seconds!2627If we use RoBERTa directly, that will yield rather bad sentence embeddings. But if we28fine-tune RoBERTa using a Siamese network, that will generate semantically meaningful29sentence embeddings. This will enable RoBERTa to be used for new tasks. These tasks30include:3132- Large-scale semantic similarity comparison.33- Clustering.34- Information retrieval via semantic search.3536In this example, we will show how to fine-tune a RoBERTa model using a Siamese network37such that it will be able to produce semantically meaningful sentence embeddings and use38them in a semantic search and clustering example.39This method of fine-tuning was introduced in40[Sentence-BERT](https://arxiv.org/abs/1908.10084)41"""4243"""44## Setup4546Let's install and import the libraries we need. We'll be using the KerasHub library in47this example.4849We will also enable [mixed precision](https://www.tensorflow.org/guide/mixed_precision)50training. This will help us reduce the training time.51"""5253"""shell54pip install -q --upgrade keras-hub55pip install -q --upgrade keras # Upgrade to Keras 3.56"""5758import os5960os.environ["KERAS_BACKEND"] = "tensorflow"6162import keras63import keras_hub64import tensorflow as tf65import tensorflow_datasets as tfds66import sklearn.cluster as cluster6768keras.mixed_precision.set_global_policy("mixed_float16")6970"""71## Fine-tune the model using siamese networks7273[Siamese network](https://en.wikipedia.org/wiki/Siamese_neural_network) is a neural74network architecture that contains two or more subnetworks. The subnetworks share the75same weights. It is used to generate feature vectors for each input and then compare them76for similarity.7778For our example, the subnetwork will be a RoBERTa model that has a pooling layer on top79of it to produce the embeddings of the input sentences. These embeddings will then be80compared to each other to learn to produce semantically meaningful embeddings.8182The pooling strategies used are mean, max, and CLS pooling. Mean pooling produces the83best results. 
"""
### Fine-tune using the regression objective function

For building the Siamese network with the regression objective function, the Siamese
network is asked to predict the cosine similarity between the embeddings of the two input
sentences.

Cosine similarity indicates the angle between the sentence embeddings. If the cosine
similarity is high, that means there is a small angle between the embeddings; hence, they
are semantically similar.
"""

"""
#### Load the dataset

We will use the STSB dataset to fine-tune the model for the regression objective. STSB
consists of a collection of sentence pairs that are labelled in the range [0, 5]. 0
indicates the least semantic similarity between the two sentences, and 5 indicates the
most semantic similarity between the two sentences.

The cosine similarity, which is the output of the Siamese network, lies in the range
[-1, 1], but the labels in the dataset lie in the range [0, 5]. We need to unify the
range between the cosine similarity and the dataset labels, so while preparing the
dataset, we will divide the labels by 2.5 and subtract 1.
"""

TRAIN_BATCH_SIZE = 6
VALIDATION_BATCH_SIZE = 8

TRAIN_NUM_BATCHES = 300
VALIDATION_NUM_BATCHES = 40

AUTOTUNE = tf.data.experimental.AUTOTUNE


def change_range(x):
    return (x / 2.5) - 1


def prepare_dataset(dataset, num_batches, batch_size):
    dataset = dataset.map(
        lambda z: (
            [z["sentence1"], z["sentence2"]],
            [tf.cast(change_range(z["label"]), tf.float32)],
        ),
        num_parallel_calls=AUTOTUNE,
    )
    dataset = dataset.batch(batch_size)
    dataset = dataset.take(num_batches)
    dataset = dataset.prefetch(AUTOTUNE)
    return dataset


stsb_ds = tfds.load(
    "glue/stsb",
)
stsb_train, stsb_valid = stsb_ds["train"], stsb_ds["validation"]

stsb_train = prepare_dataset(stsb_train, TRAIN_NUM_BATCHES, TRAIN_BATCH_SIZE)
stsb_valid = prepare_dataset(stsb_valid, VALIDATION_NUM_BATCHES, VALIDATION_BATCH_SIZE)

"""
Let's look at examples of two sentences from the dataset and their similarity.
"""

for x, y in stsb_train:
    for i, example in enumerate(x):
        print(f"sentence 1 : {example[0]} ")
        print(f"sentence 2 : {example[1]} ")
        print(f"similarity : {y[i]} \n")
    break
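"""
As a quick, illustrative sanity check of the rescaling described above, the endpoints of
the label range map onto the endpoints of the cosine similarity range:
"""

# Illustrative only: 0 -> -1, 2.5 -> 0, and 5 -> 1, matching the [-1, 1] range
# of the cosine similarity predicted by the Siamese network.
for label in [0.0, 2.5, 5.0]:
    print(f"label {label} -> rescaled {change_range(label)}")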
"""
#### Build the encoder model

Now, we'll build the encoder model that will produce the sentence embeddings. It consists
of:

- A preprocessor layer to tokenize and generate padding masks for the sentences.
- A backbone model that will generate the contextual representation of each token in the
sentence.
- A mean pooling layer to produce the embeddings. We will use `keras.layers.GlobalAveragePooling1D`
to apply the mean pooling to the backbone outputs. We will pass the padding mask to the
layer to exclude padded tokens from being averaged.
- A normalization layer to normalize the embeddings, since we are using cosine similarity.
"""

preprocessor = keras_hub.models.RobertaPreprocessor.from_preset("roberta_base_en")
backbone = keras_hub.models.RobertaBackbone.from_preset("roberta_base_en")
inputs = keras.Input(shape=(1,), dtype="string", name="sentence")
x = preprocessor(inputs)
h = backbone(x)
embedding = keras.layers.GlobalAveragePooling1D(name="pooling_layer")(
    h, x["padding_mask"]
)
n_embedding = keras.layers.UnitNormalization(axis=1)(embedding)
roberta_normal_encoder = keras.Model(inputs=inputs, outputs=n_embedding)

roberta_normal_encoder.summary()

"""
#### Build the Siamese network with the regression objective function

As described above, a Siamese network has two or more subnetworks, so for this Siamese
model we need two encoders. We only build one encoder, however, and pass both sentences
through it. That gives us two paths to the embeddings while sharing the weights between
the two paths.

After passing the two sentences to the model and getting the normalized embeddings, we
will multiply the two normalized embeddings to get the cosine similarity between the two
sentences.
"""


class RegressionSiamese(keras.Model):
    def __init__(self, encoder, **kwargs):
        inputs = keras.Input(shape=(2,), dtype="string", name="sentences")
        sen1, sen2 = keras.ops.split(inputs, 2, axis=1)
        u = encoder(sen1)
        v = encoder(sen2)
        cosine_similarity_scores = keras.ops.matmul(u, keras.ops.transpose(v))

        super().__init__(
            inputs=inputs,
            outputs=cosine_similarity_scores,
            **kwargs,
        )

        self.encoder = encoder

    def get_encoder(self):
        return self.encoder


"""
#### Fit the model

Let's try the model before training and compare its output to the output after training.
"""

sentences = [
    "Today is a very sunny day.",
    "I am hungry, I will get my meal.",
    "The dog is eating his food.",
]
query = ["The dog is enjoying his meal."]

encoder = roberta_normal_encoder

sentence_embeddings = encoder(tf.constant(sentences))
query_embedding = encoder(tf.constant(query))

cosine_similarity_scores = tf.matmul(query_embedding, tf.transpose(sentence_embeddings))
for i, sim in enumerate(cosine_similarity_scores[0]):
    print(f"cosine similarity score between sentence {i+1} and the query = {sim} ")
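"""
Because the encoder outputs are unit-normalized, the matrix product above is exactly the
cosine similarity between the query and each sentence. As a small, illustrative aside
with hypothetical 2D vectors, the dot product of two unit-length vectors equals the
cosine of the angle between them:
"""

# Illustrative only: unit-normalize two toy vectors and check that their dot product
# equals the expected cosine value (24 / 25 = 0.96).
a = keras.ops.convert_to_tensor([[3.0, 4.0]])
b = keras.ops.convert_to_tensor([[4.0, 3.0]])
a_unit = a / keras.ops.sqrt(keras.ops.sum(a * a, axis=1, keepdims=True))
b_unit = b / keras.ops.sqrt(keras.ops.sum(b * b, axis=1, keepdims=True))
print(keras.ops.matmul(a_unit, keras.ops.transpose(b_unit)))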
"""
For the training, we will use `MeanSquaredError()` as the loss function, and the `Adam()`
optimizer with a learning rate of 2e-5.
"""

roberta_regression_siamese = RegressionSiamese(roberta_normal_encoder)

roberta_regression_siamese.compile(
    loss=keras.losses.MeanSquaredError(),
    optimizer=keras.optimizers.Adam(2e-5),
    jit_compile=False,
)

roberta_regression_siamese.fit(stsb_train, validation_data=stsb_valid, epochs=1)

"""
Let's try the model after training; we will notice a huge difference in the output. That
means that the model after fine-tuning is capable of producing semantically meaningful
embeddings, where semantically similar sentences have a small angle between them, and
semantically dissimilar sentences have a large angle between them.
"""

sentences = [
    "Today is a very sunny day.",
    "I am hungry, I will get my meal.",
    "The dog is eating his food.",
]
query = ["The dog is enjoying his food."]

encoder = roberta_regression_siamese.get_encoder()

sentence_embeddings = encoder(tf.constant(sentences))
query_embedding = encoder(tf.constant(query))

cosine_similarities = tf.matmul(query_embedding, tf.transpose(sentence_embeddings))
for i, sim in enumerate(cosine_similarities[0]):
    print(f"cosine similarity between sentence {i+1} and the query = {sim} ")

"""
### Fine-tune using the triplet objective function

For the Siamese network with the triplet objective function, three sentences are passed
to the Siamese network: the *anchor*, *positive*, and *negative* sentences. The *anchor*
and *positive* sentences are semantically similar, and the *anchor* and *negative*
sentences are semantically dissimilar. The objective is to minimize the distance between
the *anchor* sentence and the *positive* sentence, and to maximize the distance between
the *anchor* sentence and the *negative* sentence.
"""

"""
#### Load the dataset

We will use the Wikipedia-sections-triplets dataset for fine-tuning. This dataset
consists of sentences derived from the Wikipedia website. Each example is a collection of
3 sentences: an *anchor*, a *positive*, and a *negative* sentence. The *anchor* and
*positive* sentences are derived from the same section, while the *anchor* and *negative*
sentences are derived from different sections.

This dataset has 1.8 million training triplets and 220,000 test triplets. In this
example, we will only use 1200 triplets for training and 450 for testing.
"""

"""shell
wget https://sbert.net/datasets/wikipedia-sections-triplets.zip -q
unzip wikipedia-sections-triplets.zip -d wikipedia-sections-triplets
"""

NUM_TRAIN_BATCHES = 200
NUM_TEST_BATCHES = 75
AUTOTUNE = tf.data.experimental.AUTOTUNE


def prepare_wiki_data(dataset, num_batches):
    dataset = dataset.map(
        lambda z: ((z["Sentence1"], z["Sentence2"], z["Sentence3"]), 0)
    )
    dataset = dataset.batch(6)
    dataset = dataset.take(num_batches)
    dataset = dataset.prefetch(AUTOTUNE)
    return dataset


wiki_train = tf.data.experimental.make_csv_dataset(
    "wikipedia-sections-triplets/train.csv",
    batch_size=1,
    num_epochs=1,
)
wiki_test = tf.data.experimental.make_csv_dataset(
    "wikipedia-sections-triplets/test.csv",
    batch_size=1,
    num_epochs=1,
)

wiki_train = prepare_wiki_data(wiki_train, NUM_TRAIN_BATCHES)
wiki_test = prepare_wiki_data(wiki_test, NUM_TEST_BATCHES)
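"""
As a quick, illustrative peek at the prepared data (mirroring the STSB preview above),
let's print one *anchor*, *positive*, *negative* triplet from the training set:
"""

# Illustrative only: each batch is ((anchors, positives, negatives), labels); the
# labels are the placeholder zeros added in `prepare_wiki_data` and are not used.
for (anchors, positives, negatives), _ in wiki_train.take(1):
    print(f"anchor   : {anchors[0][0]}")
    print(f"positive : {positives[0][0]}")
    print(f"negative : {negatives[0][0]}")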
"""
#### Build the encoder model

For this encoder model, we will use RoBERTa with mean pooling, and we will not normalize
the output embeddings. The encoder model consists of:

- A preprocessor layer to tokenize and generate padding masks for the sentences.
- A backbone model that will generate the contextual representation of each token in the
sentence.
- A mean pooling layer to produce the embeddings.
"""

preprocessor = keras_hub.models.RobertaPreprocessor.from_preset("roberta_base_en")
backbone = keras_hub.models.RobertaBackbone.from_preset("roberta_base_en")
input = keras.Input(shape=(1,), dtype="string", name="sentence")

x = preprocessor(input)
h = backbone(x)
embedding = keras.layers.GlobalAveragePooling1D(name="pooling_layer")(
    h, x["padding_mask"]
)

roberta_encoder = keras.Model(inputs=input, outputs=embedding)


roberta_encoder.summary()

"""
#### Build the Siamese network with the triplet objective function

For the Siamese network with the triplet objective function, we will build the model with
an encoder, and we will pass the three sentences through that encoder. We will get an
embedding for each sentence, and we will calculate the `positive_dist` and
`negative_dist` that will be passed to the loss function described below.
"""


class TripletSiamese(keras.Model):
    def __init__(self, encoder, **kwargs):
        anchor = keras.Input(shape=(1,), dtype="string")
        positive = keras.Input(shape=(1,), dtype="string")
        negative = keras.Input(shape=(1,), dtype="string")

        ea = encoder(anchor)
        ep = encoder(positive)
        en = encoder(negative)

        positive_dist = keras.ops.sum(keras.ops.square(ea - ep), axis=1)
        negative_dist = keras.ops.sum(keras.ops.square(ea - en), axis=1)

        positive_dist = keras.ops.sqrt(positive_dist)
        negative_dist = keras.ops.sqrt(negative_dist)

        output = keras.ops.stack([positive_dist, negative_dist], axis=0)

        super().__init__(inputs=[anchor, positive, negative], outputs=output, **kwargs)

        self.encoder = encoder

    def get_encoder(self):
        return self.encoder


"""
We will use a custom loss function for the triplet objective. The loss function will
receive the distance between the *anchor* and the *positive* embeddings `positive_dist`,
and the distance between the *anchor* and the *negative* embeddings `negative_dist`,
where they are stacked together in `y_pred`.

We will use `positive_dist` and `negative_dist` to compute the loss such that
`negative_dist` is larger than `positive_dist` by at least a specific margin.
Mathematically, we will minimize this loss function: `max(positive_dist - negative_dist
+ margin, 0)`.

There is no `y_true` used in this loss function. Note that we set the labels in the
dataset to zero, but they will not be used.
"""


class TripletLoss(keras.losses.Loss):
    def __init__(self, margin=1, **kwargs):
        super().__init__(**kwargs)
        self.margin = margin

    def call(self, y_true, y_pred):
        positive_dist, negative_dist = tf.unstack(y_pred, axis=0)

        losses = keras.ops.relu(positive_dist - negative_dist + self.margin)
        return keras.ops.mean(losses, axis=0)
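"""
Before fitting, here is a small, illustrative worked example of the margin formula used
by the loss above, with toy distances only and the default margin of 1:
"""

# Illustrative only: the first toy triplet already has negative_dist larger than
# positive_dist by more than the margin, so its loss is 0; the second does not,
# so its loss is 2.0 - 0.3 + 1 = 2.7.
toy_positive_dist = keras.ops.convert_to_tensor([0.5, 2.0])
toy_negative_dist = keras.ops.convert_to_tensor([2.0, 0.3])
print(keras.ops.relu(toy_positive_dist - toy_negative_dist + 1.0))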
"""
#### Fit the model

For the training, we will use the custom `TripletLoss()` loss function, and the `Adam()`
optimizer with a learning rate of 2e-5.
"""

roberta_triplet_siamese = TripletSiamese(roberta_encoder)

roberta_triplet_siamese.compile(
    loss=TripletLoss(),
    optimizer=keras.optimizers.Adam(2e-5),
    jit_compile=False,
)

roberta_triplet_siamese.fit(wiki_train, validation_data=wiki_test, epochs=1)

"""
Let's try this model in a clustering example. Here are 6 questions: the first 3 are about
learning English, and the last 3 are about working online. Let's see if the embeddings
produced by our encoder will cluster them correctly.
"""

questions = [
    "What should I do to improve my English writing?",
    "How to be good at speaking English?",
    "How can I improve my English?",
    "How to earn money online?",
    "How do I earn money online?",
    "How to work and earn money through internet?",
]

encoder = roberta_triplet_siamese.get_encoder()
embeddings = encoder(tf.constant(questions))
kmeans = cluster.KMeans(n_clusters=2, random_state=0, n_init="auto").fit(embeddings)

for i, label in enumerate(kmeans.labels_):
    print(f"sentence ({questions[i]}) belongs to cluster {label}")
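"""
As an optional, illustrative sketch, the same embeddings can also power a small semantic
search: rank the questions by cosine similarity to a query and print the top matches. The
query string and the top-3 cutoff below are hypothetical choices for illustration.
"""

# Illustrative only: normalize the embeddings so the dot product equals the cosine
# similarity, then rank the questions against the query.
query_embedding = encoder(tf.constant(["How can I make money on the internet?"]))

normalized_embeddings = tf.math.l2_normalize(embeddings, axis=1)
normalized_query = tf.math.l2_normalize(query_embedding, axis=1)
scores = tf.matmul(normalized_query, tf.transpose(normalized_embeddings))[0]

top_k = tf.argsort(scores, direction="DESCENDING")[:3]
for rank, idx in enumerate(top_k):
    print(f"{rank + 1}. {questions[int(idx)]} (score = {float(scores[idx]):.3f})")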