"""
Title: Sentence embeddings using Siamese RoBERTa-networks
Author: [Mohammed Abu El-Nasr](https://github.com/abuelnasr0)
Date created: 2023/07/14
Last modified: 2023/07/14
Description: Fine-tune a RoBERTa model to generate sentence embeddings using KerasHub.
Accelerator: GPU
"""

"""
## Introduction

BERT and RoBERTa can be used for semantic textual similarity tasks, where two sentences
are passed to the model and the network predicts whether they are similar or not. But
what if we have a large collection of sentences and want to find the most similar pairs
in that collection? That would require n*(n-1)/2 inference computations, where n is the
number of sentences in the collection. For example, if n = 10000, the required time would
be 65 hours on a V100 GPU.

A common method to overcome this time overhead is to pass one sentence to the model,
then average the output of the model, or take the first token (the [CLS] token), and use
that as a [sentence embedding](https://en.wikipedia.org/wiki/Sentence_embedding), then
use a vector similarity measure like cosine similarity or Manhattan / Euclidean distance
to find close (semantically similar) sentences. That reduces the time to
find the most similar pairs in a collection of 10,000 sentences from 65 hours to about 5
seconds!

If we use RoBERTa directly, that will yield rather bad sentence embeddings. But if we
fine-tune RoBERTa using a Siamese network, it will generate semantically meaningful
sentence embeddings. This will enable RoBERTa to be used for new tasks. These tasks
include:

- Large-scale semantic similarity comparison.
- Clustering.
- Information retrieval via semantic search.

In this example, we will show how to fine-tune a RoBERTa model using a Siamese network
such that it will be able to produce semantically meaningful sentence embeddings and use
them in a semantic search and clustering example.
This method of fine-tuning was introduced in
[Sentence-BERT](https://arxiv.org/abs/1908.10084).
"""

"""
## Setup

Let's install and import the libraries we need. We'll be using the KerasHub library in
this example.

We will also enable [mixed precision](https://www.tensorflow.org/guide/mixed_precision)
training. This will help us reduce the training time.
"""

"""shell
pip install -q --upgrade keras-hub
pip install -q --upgrade keras  # Upgrade to Keras 3.
"""

import os

os.environ["KERAS_BACKEND"] = "tensorflow"

import keras
import keras_hub
import tensorflow as tf
import tensorflow_datasets as tfds
import sklearn.cluster as cluster

keras.mixed_precision.set_global_policy("mixed_float16")

"""
72
## Fine-tune the model using siamese networks
73
74
[Siamese network](https://en.wikipedia.org/wiki/Siamese_neural_network) is a neural
75
network architecture that contains two or more subnetworks. The subnetworks share the
76
same weights. It is used to generate feature vectors for each input and then compare them
77
for similarity.
78
79
For our example, the subnetwork will be a RoBERTa model that has a pooling layer on top
80
of it to produce the embeddings of the input sentences. These embeddings will then be
81
compared to each other to learn to produce semantically meaningful embeddings.
82
83
The pooling strategies used are mean, max, and CLS pooling. Mean pooling produces the
84
best results. We will use it in our examples.
85
"""

"""
### Fine-tune using the regression objective function

With the regression objective, the Siamese network is asked to predict the cosine
similarity between the embeddings of the two input sentences.

Cosine similarity indicates the angle between the sentence embeddings. If the cosine
similarity is high, that means there is a small angle between the embeddings; hence, they
are semantically similar.
"""

"""
#### Load the dataset

We will use the STSB dataset to fine-tune the model for the regression objective. STSB
consists of a collection of sentence pairs that are labelled in the range [0, 5]. 0
indicates the least semantic similarity between the two sentences, and 5 indicates the
most semantic similarity between the two sentences.

The range of the cosine similarity is [-1, 1] and it's the output of the Siamese network,
but the range of the labels in the dataset is [0, 5]. We need to unify the range between
the cosine similarity and the dataset labels, so while preparing the dataset, we will
divide the labels by 2.5 and subtract 1.
"""

TRAIN_BATCH_SIZE = 6
VALIDATION_BATCH_SIZE = 8

TRAIN_NUM_BATCHES = 300
VALIDATION_NUM_BATCHES = 40

AUTOTUNE = tf.data.experimental.AUTOTUNE


def change_range(x):
    return (x / 2.5) - 1


def prepare_dataset(dataset, num_batches, batch_size):
    dataset = dataset.map(
        lambda z: (
            [z["sentence1"], z["sentence2"]],
            [tf.cast(change_range(z["label"]), tf.float32)],
        ),
        num_parallel_calls=AUTOTUNE,
    )
    dataset = dataset.batch(batch_size)
    dataset = dataset.take(num_batches)
    dataset = dataset.prefetch(AUTOTUNE)
    return dataset


stsb_ds = tfds.load(
    "glue/stsb",
)
stsb_train, stsb_valid = stsb_ds["train"], stsb_ds["validation"]

stsb_train = prepare_dataset(stsb_train, TRAIN_NUM_BATCHES, TRAIN_BATCH_SIZE)
stsb_valid = prepare_dataset(stsb_valid, VALIDATION_NUM_BATCHES, VALIDATION_BATCH_SIZE)
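
"""
As a quick sanity check (purely illustrative), the rescaling maps the label endpoints 0
and 5 to the cosine-similarity endpoints -1 and 1, and the midpoint 2.5 to 0.
"""

print(change_range(0.0), change_range(2.5), change_range(5.0))  # -1.0 0.0 1.0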

"""
Let's look at a few examples of sentence pairs from the dataset and their similarity
scores.
"""

for x, y in stsb_train:
    for i, example in enumerate(x):
        print(f"sentence 1 : {example[0]} ")
        print(f"sentence 2 : {example[1]} ")
        print(f"similarity : {y[i]} \n")
    break

"""
#### Build the encoder model

Now, we'll build the encoder model that will produce the sentence embeddings. It consists
of:

- A preprocessor layer to tokenize and generate padding masks for the sentences.
- A backbone model that will generate the contextual representation of each token in the
sentence.
- A mean pooling layer to produce the embeddings. We will use `keras.layers.GlobalAveragePooling1D`
to apply the mean pooling to the backbone outputs. We will pass the padding mask to the
layer to exclude padded tokens from being averaged.
- A normalization layer to normalize the embeddings, since we are using cosine similarity.
"""

preprocessor = keras_hub.models.RobertaPreprocessor.from_preset("roberta_base_en")
backbone = keras_hub.models.RobertaBackbone.from_preset("roberta_base_en")
inputs = keras.Input(shape=(1,), dtype="string", name="sentence")
x = preprocessor(inputs)
h = backbone(x)
embedding = keras.layers.GlobalAveragePooling1D(name="pooling_layer")(
    h, x["padding_mask"]
)
n_embedding = keras.layers.UnitNormalization(axis=1)(embedding)
roberta_normal_encoder = keras.Model(inputs=inputs, outputs=n_embedding)

roberta_normal_encoder.summary()
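
"""
As a quick illustrative check, the encoder maps a batch of raw strings to a
`(batch_size, hidden_dim)` matrix of unit-normalized embeddings; for `roberta_base_en`
the hidden dimension is 768.
"""

example_embedding = roberta_normal_encoder(tf.constant(["A quick shape check."]))
print(example_embedding.shape)  # (1, 768)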

"""
#### Build the Siamese network with the regression objective function

As described above, the Siamese network has two or more subnetworks, and for this
Siamese model we need two encoders. But we don't need to build two separate encoders:
we build a single encoder and pass both sentences through it. That way, we get two paths
for computing the embeddings, while the weights are shared between the two paths.

After passing the two sentences to the model and getting the normalized embeddings, we
will multiply the two normalized embeddings to get the cosine similarity between the two
sentences.
"""


class RegressionSiamese(keras.Model):
    def __init__(self, encoder, **kwargs):
        inputs = keras.Input(shape=(2,), dtype="string", name="sentences")
        sen1, sen2 = keras.ops.split(inputs, 2, axis=1)
        u = encoder(sen1)
        v = encoder(sen2)
        cosine_similarity_scores = keras.ops.matmul(u, keras.ops.transpose(v))

        super().__init__(
            inputs=inputs,
            outputs=cosine_similarity_scores,
            **kwargs,
        )

        self.encoder = encoder

    def get_encoder(self):
        return self.encoder


"""
#### Fit the model

Let's try this example before training and compare it to the output after training.
"""

sentences = [
    "Today is a very sunny day.",
    "I am hungry, I will get my meal.",
    "The dog is eating his food.",
]
query = ["The dog is enjoying his meal."]

encoder = roberta_normal_encoder

sentence_embeddings = encoder(tf.constant(sentences))
query_embedding = encoder(tf.constant(query))

cosine_similarity_scores = tf.matmul(query_embedding, tf.transpose(sentence_embeddings))
for i, sim in enumerate(cosine_similarity_scores[0]):
    print(f"cosine similarity score between sentence {i+1} and the query = {sim} ")

"""
For training, we will use `MeanSquaredError()` as the loss function, and the `Adam()`
optimizer with a learning rate of 2e-5.
"""

roberta_regression_siamese = RegressionSiamese(roberta_normal_encoder)

roberta_regression_siamese.compile(
    loss=keras.losses.MeanSquaredError(),
    optimizer=keras.optimizers.Adam(2e-5),
    jit_compile=False,
)

roberta_regression_siamese.fit(stsb_train, validation_data=stsb_valid, epochs=1)

"""
Let's try the model after training; we will notice a huge difference in the output. That
means that the model after fine-tuning is capable of producing semantically meaningful
embeddings, where semantically similar sentences have a small angle between them, and
semantically dissimilar sentences have a large angle between them.
"""

sentences = [
    "Today is a very sunny day.",
    "I am hungry, I will get my meal.",
    "The dog is eating his food.",
]
query = ["The dog is enjoying his food."]

encoder = roberta_regression_siamese.get_encoder()

sentence_embeddings = encoder(tf.constant(sentences))
query_embedding = encoder(tf.constant(query))

cosine_similarities = tf.matmul(query_embedding, tf.transpose(sentence_embeddings))
for i, sim in enumerate(cosine_similarities[0]):
    print(f"cosine similarity between sentence {i+1} and the query = {sim} ")

"""
### Fine-tune using the triplet objective function

For the Siamese network with the triplet objective function, three sentences are passed
to the Siamese network: an *anchor*, a *positive*, and a *negative* sentence. The *anchor* and
*positive* sentences are semantically similar, and the *anchor* and *negative* sentences are
semantically dissimilar. The objective is to minimize the distance between the *anchor*
sentence and the *positive* sentence, and to maximize the distance between the *anchor*
sentence and the *negative* sentence.
"""

"""
#### Load the dataset

We will use the Wikipedia-sections-triplets dataset for fine-tuning. This dataset
consists of sentences derived from the Wikipedia website. Each example is a collection of 3
sentences: an *anchor*, a *positive*, and a *negative*. The *anchor* and *positive* are derived
from the same section, while the *anchor* and *negative* are derived from different sections.

This dataset has 1.8 million training triplets and 220,000 test triplets. In this
example, we will only use 1200 triplets for training and 450 for testing.
"""

"""shell
wget https://sbert.net/datasets/wikipedia-sections-triplets.zip -q
unzip wikipedia-sections-triplets.zip -d wikipedia-sections-triplets
"""

NUM_TRAIN_BATCHES = 200
NUM_TEST_BATCHES = 75
AUTOTUNE = tf.data.experimental.AUTOTUNE


def prepare_wiki_data(dataset, num_batches):
    dataset = dataset.map(
        lambda z: ((z["Sentence1"], z["Sentence2"], z["Sentence3"]), 0)
    )
    dataset = dataset.batch(6)
    dataset = dataset.take(num_batches)
    dataset = dataset.prefetch(AUTOTUNE)
    return dataset


wiki_train = tf.data.experimental.make_csv_dataset(
    "wikipedia-sections-triplets/train.csv",
    batch_size=1,
    num_epochs=1,
)
wiki_test = tf.data.experimental.make_csv_dataset(
    "wikipedia-sections-triplets/test.csv",
    batch_size=1,
    num_epochs=1,
)

wiki_train = prepare_wiki_data(wiki_train, NUM_TRAIN_BATCHES)
wiki_test = prepare_wiki_data(wiki_test, NUM_TEST_BATCHES)
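
"""
Let's look at one example triplet from the test split. The three CSV columns are fed to
the Siamese model below as the *anchor*, *positive*, and *negative* inputs, in that
order. Note that each feature tensor has shape `(batch_size, 1)` because
`make_csv_dataset` already batches rows with `batch_size=1` before we batch again.
"""

for (anchors, positives, negatives), _ in wiki_test.take(1):
    print(f"anchor   : {anchors[0][0]}")
    print(f"positive : {positives[0][0]}")
    print(f"negative : {negatives[0][0]}")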

"""
#### Build the encoder model

For this encoder model, we will use RoBERTa with mean pooling, and we will not normalize
the output embeddings. The encoder model consists of:

- A preprocessor layer to tokenize and generate padding masks for the sentences.
- A backbone model that will generate the contextual representation of each token in the
sentence.
- A mean pooling layer to produce the embeddings.
"""

preprocessor = keras_hub.models.RobertaPreprocessor.from_preset("roberta_base_en")
backbone = keras_hub.models.RobertaBackbone.from_preset("roberta_base_en")
input = keras.Input(shape=(1,), dtype="string", name="sentence")

x = preprocessor(input)
h = backbone(x)
embedding = keras.layers.GlobalAveragePooling1D(name="pooling_layer")(
    h, x["padding_mask"]
)

roberta_encoder = keras.Model(inputs=input, outputs=embedding)


roberta_encoder.summary()

"""
#### Build the Siamese network with the triplet objective function

For the Siamese network with the triplet objective function, we will build the model with
an encoder, and we will pass the three sentences through that encoder. We will get an
embedding for each sentence, and we will calculate the `positive_dist` and
`negative_dist` that will be passed to the loss function described below.
"""


class TripletSiamese(keras.Model):
    def __init__(self, encoder, **kwargs):
        anchor = keras.Input(shape=(1,), dtype="string")
        positive = keras.Input(shape=(1,), dtype="string")
        negative = keras.Input(shape=(1,), dtype="string")

        ea = encoder(anchor)
        ep = encoder(positive)
        en = encoder(negative)

        positive_dist = keras.ops.sum(keras.ops.square(ea - ep), axis=1)
        negative_dist = keras.ops.sum(keras.ops.square(ea - en), axis=1)

        positive_dist = keras.ops.sqrt(positive_dist)
        negative_dist = keras.ops.sqrt(negative_dist)

        output = keras.ops.stack([positive_dist, negative_dist], axis=0)

        super().__init__(inputs=[anchor, positive, negative], outputs=output, **kwargs)

        self.encoder = encoder

    def get_encoder(self):
        return self.encoder


"""
We will use a custom loss function for the triplet objective. The loss function will
receive the distance between the *anchor* and the *positive* embeddings `positive_dist`,
and the distance between the *anchor* and the *negative* embeddings `negative_dist`,
where they are stacked together in `y_pred`.

We will use `positive_dist` and `negative_dist` to compute the loss such that
`negative_dist` is larger than `positive_dist` by at least a specific margin.
Mathematically, we will minimize this loss function:
`max(positive_dist - negative_dist + margin, 0)`.

There is no `y_true` used in this loss function. Note that we set the labels in the
dataset to zero, but they will not be used.
"""


class TripletLoss(keras.losses.Loss):
    def __init__(self, margin=1, **kwargs):
        super().__init__(**kwargs)
        self.margin = margin

    def call(self, y_true, y_pred):
        positive_dist, negative_dist = tf.unstack(y_pred, axis=0)

        losses = keras.ops.relu(positive_dist - negative_dist + self.margin)
        return keras.ops.mean(losses, axis=0)
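
"""
As a small worked example with made-up distances and the default margin of 1: a triplet
where `positive_dist = 0.6` and `negative_dist = 2.0` already satisfies the margin and
contributes zero loss, while `positive_dist = 1.5` and `negative_dist = 1.8` incurs a
loss of 0.7. The `y_true` argument is unused, so we pass a dummy tensor of zeros.
"""

example_distances = tf.constant([[0.6, 1.5], [2.0, 1.8]])  # row 0: positive, row 1: negative
print(TripletLoss()(tf.zeros((2,)), example_distances))  # mean of [0.0, 0.7] = 0.35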

"""
#### Fit the model

For training, we will use the custom `TripletLoss()` loss function, and the `Adam()`
optimizer with a learning rate of 2e-5.
"""

roberta_triplet_siamese = TripletSiamese(roberta_encoder)

roberta_triplet_siamese.compile(
    loss=TripletLoss(),
    optimizer=keras.optimizers.Adam(2e-5),
    jit_compile=False,
)

roberta_triplet_siamese.fit(wiki_train, validation_data=wiki_test, epochs=1)

"""
Let's try this model in a clustering example. Here are 6 questions: the first 3 are
about learning English, and the last 3 are about working online. Let's see if the
embeddings produced by our encoder will cluster them correctly.
"""

questions = [
    "What should I do to improve my English writting?",
    "How to be good at speaking English?",
    "How can I improve my English?",
    "How to earn money online?",
    "How do I earn money online?",
    "How to work and earn money through internet?",
]

encoder = roberta_triplet_siamese.get_encoder()
embeddings = encoder(tf.constant(questions))
kmeans = cluster.KMeans(n_clusters=2, random_state=0, n_init="auto").fit(embeddings)

for i, label in enumerate(kmeans.labels_):
    print(f"sentence ({questions[i]}) belongs to cluster {label}")