"""
Title: Semantic Similarity with BERT
Author: [Mohamad Merchant](https://twitter.com/mohmadmerchant1)
Date created: 2020/08/15
Last modified: 2020/08/29
Description: Natural Language Inference by fine-tuning BERT model on SNLI Corpus.
Accelerator: GPU
"""

"""
## Introduction

Semantic Similarity is the task of determining how similar
two sentences are, in terms of what they mean.
This example demonstrates the use of the SNLI (Stanford Natural Language Inference) Corpus
to predict sentence semantic similarity with Transformers.
We will fine-tune a BERT model that takes two sentences as inputs
and outputs a similarity score for these two sentences.

### References

* [BERT](https://arxiv.org/pdf/1810.04805.pdf)
* [SNLI](https://nlp.stanford.edu/projects/snli/)
"""

"""
## Setup

Note: install HuggingFace `transformers` via `pip install transformers` (version >= 2.11.0).
"""
import numpy as np
import pandas as pd
import tensorflow as tf
import transformers

"""
## Configuration
"""

max_length = 128  # Maximum length of input sentence to the model.
batch_size = 32
epochs = 2

# Labels in our dataset.
labels = ["contradiction", "entailment", "neutral"]

"""
## Load the Data
"""

"""shell
curl -LO https://raw.githubusercontent.com/MohamadMerchant/SNLI/master/data.tar.gz
tar -xvzf data.tar.gz
"""
# There are more than 550k samples in total; we will use 100k for this example.
train_df = pd.read_csv("SNLI_Corpus/snli_1.0_train.csv", nrows=100000)
valid_df = pd.read_csv("SNLI_Corpus/snli_1.0_dev.csv")
test_df = pd.read_csv("SNLI_Corpus/snli_1.0_test.csv")

# Shape of the data
print(f"Total train samples: {train_df.shape[0]}")
print(f"Total validation samples: {valid_df.shape[0]}")
print(f"Total test samples: {test_df.shape[0]}")
"""
66
Dataset Overview:
67
68
- sentence1: The premise caption that was supplied to the author of the pair.
69
- sentence2: The hypothesis caption that was written by the author of the pair.
70
- similarity: This is the label chosen by the majority of annotators.
71
Where no majority exists, the label "-" is used (we will skip such samples here).
72
73
Here are the "similarity" label values in our dataset:
74
75
- Contradiction: The sentences share no similarity.
76
- Entailment: The sentences have similar meaning.
77
- Neutral: The sentences are neutral.
78
"""
79
80
"""
81
Let's look at one sample from the dataset:
82
"""
83
print(f"Sentence1: {train_df.loc[1, 'sentence1']}")
84
print(f"Sentence2: {train_df.loc[1, 'sentence2']}")
85
print(f"Similarity: {train_df.loc[1, 'similarity']}")
86
87
"""
88
## Preprocessing
89
"""
90
91
# We have some NaN entries in our train data, we will simply drop them.
92
print("Number of missing values")
93
print(train_df.isnull().sum())
94
train_df.dropna(axis=0, inplace=True)
"""
Distribution of our training targets.
"""
print("Train Target Distribution")
print(train_df.similarity.value_counts())

"""
Distribution of our validation targets.
"""
print("Validation Target Distribution")
print(valid_df.similarity.value_counts())

"""
The value "-" appears as part of our training and validation targets.
We will skip these samples.
"""
train_df = (
    train_df[train_df.similarity != "-"]
    .sample(frac=1.0, random_state=42)
    .reset_index(drop=True)
)
valid_df = (
    valid_df[valid_df.similarity != "-"]
    .sample(frac=1.0, random_state=42)
    .reset_index(drop=True)
)

"""
One-hot encode training, validation, and test labels.
"""
train_df["label"] = train_df["similarity"].apply(
    lambda x: 0 if x == "contradiction" else 1 if x == "entailment" else 2
)
y_train = tf.keras.utils.to_categorical(train_df.label, num_classes=3)

valid_df["label"] = valid_df["similarity"].apply(
    lambda x: 0 if x == "contradiction" else 1 if x == "entailment" else 2
)
y_val = tf.keras.utils.to_categorical(valid_df.label, num_classes=3)

test_df["label"] = test_df["similarity"].apply(
    lambda x: 0 if x == "contradiction" else 1 if x == "entailment" else 2
)
y_test = tf.keras.utils.to_categorical(test_df.label, num_classes=3)
"""
## Create a custom data generator
"""


class BertSemanticDataGenerator(tf.keras.utils.Sequence):
    """Generates batches of data.

    Args:
        sentence_pairs: Array of premise and hypothesis input sentences.
        labels: Array of labels.
        batch_size: Integer batch size.
        shuffle: boolean, whether to shuffle the data.
        include_targets: boolean, whether to include the labels.

    Returns:
        Tuples `([input_ids, attention_mask, token_type_ids], labels)`
        (or just `[input_ids, attention_mask, token_type_ids]`
        if `include_targets=False`).
    """

    def __init__(
        self,
        sentence_pairs,
        labels,
        batch_size=batch_size,
        shuffle=True,
        include_targets=True,
    ):
        self.sentence_pairs = sentence_pairs
        self.labels = labels
        self.shuffle = shuffle
        self.batch_size = batch_size
        self.include_targets = include_targets
        # Load our BERT Tokenizer to encode the text.
        # We will use the bert-base-uncased pretrained model.
        self.tokenizer = transformers.BertTokenizer.from_pretrained(
            "bert-base-uncased", do_lower_case=True
        )
        self.indexes = np.arange(len(self.sentence_pairs))
        self.on_epoch_end()

    def __len__(self):
        # Denotes the number of batches per epoch.
        return len(self.sentence_pairs) // self.batch_size

    def __getitem__(self, idx):
        # Retrieve the sentence pairs belonging to this batch index.
        indexes = self.indexes[idx * self.batch_size : (idx + 1) * self.batch_size]
        sentence_pairs = self.sentence_pairs[indexes]

        # With the BERT tokenizer's batch_encode_plus, the two sentences in each pair
        # are encoded together, separated by the [SEP] token.
        encoded = self.tokenizer.batch_encode_plus(
            sentence_pairs.tolist(),
            add_special_tokens=True,
            max_length=max_length,
            return_attention_mask=True,
            return_token_type_ids=True,
            pad_to_max_length=True,
            return_tensors="tf",
        )

        # Convert the batch of encoded features to numpy arrays.
        input_ids = np.array(encoded["input_ids"], dtype="int32")
        attention_masks = np.array(encoded["attention_mask"], dtype="int32")
        token_type_ids = np.array(encoded["token_type_ids"], dtype="int32")

        # Labels are included only when the generator is used for training/validation.
        if self.include_targets:
            labels = np.array(self.labels[indexes], dtype="int32")
            return [input_ids, attention_masks, token_type_ids], labels
        else:
            return [input_ids, attention_masks, token_type_ids]

    def on_epoch_end(self):
        # Shuffle indexes after each epoch if shuffle is set to True.
        if self.shuffle:
            np.random.RandomState(42).shuffle(self.indexes)
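"""
Optionally, we can sanity-check the generator before building the model. The short snippet
below is not part of the original pipeline; it simply reuses the dataframes and labels defined
above to build a generator over the first 64 training pairs and inspect the shapes of one batch.
"""

# Each batch should yield three (batch_size, max_length) int32 arrays
# (token ids, attention masks, token type ids) plus a (batch_size, 3) label array.
_sample_generator = BertSemanticDataGenerator(
    train_df[["sentence1", "sentence2"]].values.astype("str")[:64],
    y_train[:64],
    batch_size=batch_size,
    shuffle=False,
)
_sample_inputs, _sample_labels = _sample_generator[0]
print([features.shape for features in _sample_inputs])  # [(32, 128), (32, 128), (32, 128)]
print(_sample_labels.shape)  # (32, 3)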
"""
## Build the model
"""
# Create the model under a distribution strategy scope.
strategy = tf.distribute.MirroredStrategy()

with strategy.scope():
    # Encoded token ids from BERT tokenizer.
    input_ids = tf.keras.layers.Input(
        shape=(max_length,), dtype=tf.int32, name="input_ids"
    )
    # Attention masks indicate to the model which tokens should be attended to.
    attention_masks = tf.keras.layers.Input(
        shape=(max_length,), dtype=tf.int32, name="attention_masks"
    )
    # Token type ids are binary masks identifying the two sequences in the input.
    token_type_ids = tf.keras.layers.Input(
        shape=(max_length,), dtype=tf.int32, name="token_type_ids"
    )
    # Load the pretrained BERT model.
    bert_model = transformers.TFBertModel.from_pretrained("bert-base-uncased")
    # Freeze the BERT model to reuse the pretrained features without modifying them.
    bert_model.trainable = False

    bert_output = bert_model.bert(
        input_ids, attention_mask=attention_masks, token_type_ids=token_type_ids
    )
    sequence_output = bert_output.last_hidden_state
    pooled_output = bert_output.pooler_output

    # Add trainable layers on top of the frozen layers to adapt the pretrained features to the new data.
    bi_lstm = tf.keras.layers.Bidirectional(
        tf.keras.layers.LSTM(64, return_sequences=True)
    )(sequence_output)
    # Apply a hybrid pooling approach to the bi_lstm sequence output.
    avg_pool = tf.keras.layers.GlobalAveragePooling1D()(bi_lstm)
    max_pool = tf.keras.layers.GlobalMaxPooling1D()(bi_lstm)
    concat = tf.keras.layers.concatenate([avg_pool, max_pool])
    dropout = tf.keras.layers.Dropout(0.3)(concat)
    output = tf.keras.layers.Dense(3, activation="softmax")(dropout)
    model = tf.keras.models.Model(
        inputs=[input_ids, attention_masks, token_type_ids], outputs=output
    )

    model.compile(
        optimizer=tf.keras.optimizers.Adam(),
        loss="categorical_crossentropy",
        metrics=["acc"],
    )


print(f"Strategy: {strategy}")
model.summary()
"""
Create train and validation data generators.
"""
train_data = BertSemanticDataGenerator(
    train_df[["sentence1", "sentence2"]].values.astype("str"),
    y_train,
    batch_size=batch_size,
    shuffle=True,
)
valid_data = BertSemanticDataGenerator(
    valid_df[["sentence1", "sentence2"]].values.astype("str"),
    y_val,
    batch_size=batch_size,
    shuffle=False,
)

"""
## Train the Model

Training is done only for the top layers to perform "feature extraction",
which allows the model to reuse the representations of the pretrained model.
"""
history = model.fit(
    train_data,
    validation_data=valid_data,
    epochs=epochs,
    use_multiprocessing=True,
    workers=-1,
)
"""
## Fine-tuning

This step must only be performed after the feature extraction model has
been trained to convergence on the new data.

This is an optional last step where `bert_model` is unfrozen and retrained
with a very low learning rate. This can deliver meaningful improvement by
incrementally adapting the pretrained features to the new data.
"""

# Unfreeze the bert_model.
bert_model.trainable = True
# Recompile the model to make the change effective.
model.compile(
    optimizer=tf.keras.optimizers.Adam(1e-5),
    loss="categorical_crossentropy",
    metrics=["accuracy"],
)
model.summary()

"""
## Train the entire model end-to-end
"""
history = model.fit(
    train_data,
    validation_data=valid_data,
    epochs=epochs,
    use_multiprocessing=True,
    workers=-1,
)

"""
## Evaluate model on the test set
"""
test_data = BertSemanticDataGenerator(
    test_df[["sentence1", "sentence2"]].values.astype("str"),
    y_test,
    batch_size=batch_size,
    shuffle=False,
)
model.evaluate(test_data, verbose=1)
"""
## Inference on custom sentences
"""


def check_similarity(sentence1, sentence2):
    sentence_pairs = np.array([[str(sentence1), str(sentence2)]])
    test_data = BertSemanticDataGenerator(
        sentence_pairs,
        labels=None,
        batch_size=1,
        shuffle=False,
        include_targets=False,
    )

    proba = model.predict(test_data[0])[0]
    idx = np.argmax(proba)
    # Scale the softmax probability to a percentage before formatting.
    proba = f"{proba[idx] * 100: .2f}%"
    pred = labels[idx]
    return pred, proba


"""
Check results on some example sentence pairs.
"""
sentence1 = "Two women are observing something together."
sentence2 = "Two women are standing with their eyes closed."
check_similarity(sentence1, sentence2)

"""
Check results on some example sentence pairs.
"""
sentence1 = "A smiling costumed woman is holding an umbrella"
sentence2 = "A happy woman in a fairy costume holds an umbrella"
check_similarity(sentence1, sentence2)

"""
Check results on some example sentence pairs.
"""
sentence1 = "A soccer game with multiple males playing"
sentence2 = "Some men are playing a sport"
check_similarity(sentence1, sentence2)

"""
Example available on HuggingFace

| Trained Model | Demo |
| :--: | :--: |
| [![Generic badge](https://img.shields.io/badge/%F0%9F%A4%97%20Model-semantic%20similarity%20with%20bert-black.svg)](https://huggingface.co/keras-io/bert-semantic-similarity) | [![Generic badge](https://img.shields.io/badge/%F0%9F%A4%97%20Spaces-semantic%20similarity%20with%20bert-black.svg)](https://huggingface.co/spaces/keras-io/bert-semantic-similarity) |
"""