"""
Title: Semantic Similarity with KerasHub
Author: [Anshuman Mishra](https://github.com/shivance/)
Date created: 2023/02/25
Last modified: 2023/02/25
Description: Use pretrained models from KerasHub for the Semantic Similarity Task.
Accelerator: GPU
"""

"""
## Introduction

Semantic similarity refers to the task of determining the degree of similarity between two
sentences in terms of their meaning. We already saw in [this](https://keras.io/examples/nlp/semantic_similarity_with_bert/)
example how to use the SNLI (Stanford Natural Language Inference) corpus to predict sentence
semantic similarity with the HuggingFace Transformers library. In this tutorial we will
learn how to use [KerasHub](https://keras.io/keras_hub/), an extension of the core Keras API,
for the same task. Furthermore, we will discover how KerasHub effectively reduces boilerplate
code and simplifies the process of building and using models. For more information on KerasHub,
please refer to [KerasHub's official documentation](https://keras.io/keras_hub/).

This guide is broken down into the following parts:

1. *Setup*, task definition, and establishing a baseline.
2. *Establishing a baseline* with BERT.
3. *Saving and reloading* the model.
4. *Performing inference* with the model.
5. *Improving accuracy* with RoBERTa.

## Setup

The following guide uses [Keras Core](https://keras.io/keras_core/) to work in
any of `tensorflow`, `jax`, or `torch`. Support for Keras Core is baked into
KerasHub; simply change the `KERAS_BACKEND` environment variable below to the
backend you would like to use. We select the `jax` backend below, which gives
us a particularly fast train step.
"""

"""shell
pip install -q --upgrade keras-hub
pip install -q --upgrade keras  # Upgrade to Keras 3.
"""

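"""
Keras reads the `KERAS_BACKEND` environment variable when it is first imported, so the
backend has to be selected before the imports that follow. A minimal sketch of that step
(the exact placement here is our addition; any of `"jax"`, `"tensorflow"`, or `"torch"`
works):
"""

import os

# Must be set before `keras` is imported anywhere in the process.
os.environ["KERAS_BACKEND"] = "jax"
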
import numpy as np
import tensorflow as tf
import keras
import keras_hub
import tensorflow_datasets as tfds

"""
To load the SNLI dataset, we use the tensorflow-datasets library. The dataset
contains over 550,000 samples in total; however, to ensure that this example runs
quickly, we use only 20% of the training samples.

## Overview of SNLI Dataset

Every sample in the dataset contains three components: `hypothesis`, `premise`,
and `label`. The premise represents the original caption provided to the author of the pair,
while the hypothesis refers to the hypothesis caption created by the author of
the pair. The label is assigned by annotators to indicate the similarity between
the two sentences.

The dataset contains three possible similarity label values: Contradiction, Entailment,
and Neutral. Contradiction represents completely dissimilar sentences, while Entailment
denotes sentences with similar meaning. Lastly, Neutral refers to sentences where no clear
similarity or dissimilarity can be established between them.
"""

snli_train = tfds.load("snli", split="train[:20%]")
snli_val = tfds.load("snli", split="validation")
snli_test = tfds.load("snli", split="test")

# Here's an example of what our samples look like. We take a batch of
# four samples from the test split:
sample = snli_test.batch(4).take(1).get_single_element()
sample

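"""
The integer labels above correspond to the three classes described in the overview. A quick,
optional way to confirm the exact ordering (this lookup is our addition) is to reload a split
with its metadata and read the class names of the `label` feature:
"""

# Reload the test split together with the dataset metadata and print the label names.
_, snli_info = tfds.load("snli", split="test", with_info=True)
print(snli_info.features["label"].names)
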
"""
### Preprocessing

In our dataset, we have identified that some samples have missing or incorrectly labeled
data, which is denoted by a value of -1. To ensure the accuracy and reliability of our model,
we simply filter out these samples from our dataset.
"""


def filter_labels(sample):
    return sample["label"] >= 0


"""
Here's a utility function that splits the example into an `(x, y)` tuple that is suitable
for `model.fit()`. By default, `keras_hub.models.BertClassifier` will tokenize and pack
together raw strings using a `"[SEP]"` token during training. Therefore, this label
splitting is all the data preparation that we need to perform.
"""


def split_labels(sample):
    x = (sample["hypothesis"], sample["premise"])
    y = sample["label"]
    return x, y


train_ds = (
    snli_train.filter(filter_labels)
    .map(split_labels, num_parallel_calls=tf.data.AUTOTUNE)
    .batch(16)
)
val_ds = (
    snli_val.filter(filter_labels)
    .map(split_labels, num_parallel_calls=tf.data.AUTOTUNE)
    .batch(16)
)
test_ds = (
    snli_test.filter(filter_labels)
    .map(split_labels, num_parallel_calls=tf.data.AUTOTUNE)
    .batch(16)
)


"""
## Establishing a baseline with BERT

We use the BERT model from KerasHub to establish a baseline for our semantic similarity
task. The `keras_hub.models.BertClassifier` class attaches a classification head to the BERT
backbone, mapping the backbone outputs to a logit output suitable for a classification task.
This significantly reduces the need for custom code.

KerasHub models have built-in tokenization capabilities that handle tokenization by default
based on the selected model. However, users can also use custom preprocessing techniques
as per their specific needs. If we pass a tuple as input, the model will tokenize all the
strings and concatenate them with a `"[SEP]"` separator.

We load this model with pretrained weights using the `from_preset()` method, which also
accepts a custom preprocessor if needed. For the SNLI dataset, we set `num_classes` to 3.
"""

bert_classifier = keras_hub.models.BertClassifier.from_preset(
    "bert_tiny_en_uncased", num_classes=3
)

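"""
To see the built-in preprocessing described above in action, we can call the classifier's
attached preprocessor directly on a sentence pair. This peek is purely illustrative (the
example sentences are ours) and is not needed for training:
"""

# The preprocessor tokenizes the (hypothesis, premise) pair and packs it into the
# token_ids / segment_ids / padding_mask tensors that the backbone consumes.
bert_classifier.preprocessor(("A man is running.", "A person is outside."))
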
"""
Please note that the BERT Tiny model has only 4,386,307 trainable parameters.

KerasHub task models come with compilation defaults. We can now train the model we just
instantiated by calling the `fit()` method.
"""

bert_classifier.fit(train_ds, validation_data=val_ds, epochs=1)

"""
Our BERT classifier achieved an accuracy of around 76% on the validation split. Now,
let's evaluate its performance on the test split.

### Evaluate the performance of the trained model on test data
"""

bert_classifier.evaluate(test_ds)

"""
Our baseline BERT model achieved a similar accuracy of around 76% on the test split.
Now, let's try to improve its performance by recompiling the model with a slightly
higher learning rate.
"""

bert_classifier = keras_hub.models.BertClassifier.from_preset(
    "bert_tiny_en_uncased", num_classes=3
)
bert_classifier.compile(
    loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    optimizer=keras.optimizers.Adam(5e-5),
    metrics=["accuracy"],
)

bert_classifier.fit(train_ds, validation_data=val_ds, epochs=1)
bert_classifier.evaluate(test_ds)

"""
Just tweaking the learning rate alone was not enough to boost performance, which
stayed right around 76%. Let's try again, but this time with
`keras.optimizers.AdamW` and a learning rate schedule.
"""


class TriangularSchedule(keras.optimizers.schedules.LearningRateSchedule):
    """Linear ramp up for `warmup` steps, then linear decay to zero at `total` steps."""

    def __init__(self, rate, warmup, total):
        self.rate = rate
        self.warmup = warmup
        self.total = total

    def get_config(self):
        config = {"rate": self.rate, "warmup": self.warmup, "total": self.total}
        return config

    def __call__(self, step):
        step = keras.ops.cast(step, dtype="float32")
        rate = keras.ops.cast(self.rate, dtype="float32")
        warmup = keras.ops.cast(self.warmup, dtype="float32")
        total = keras.ops.cast(self.total, dtype="float32")

        warmup_rate = rate * step / warmup
        cooldown_rate = rate * (total - step) / (total - warmup)
        triangular_rate = keras.ops.minimum(warmup_rate, cooldown_rate)
        return keras.ops.maximum(triangular_rate, 0.0)


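"""
As a quick, optional sanity check (this snippet is our addition), we can evaluate the
schedule at a few steps: the learning rate should ramp up to `rate` at step `warmup` and
decay back to zero at step `total`.
"""

# Print the learning rate at a handful of steps to see the triangular shape.
demo_schedule = TriangularSchedule(rate=1e-4, warmup=10, total=100)
for step in (0, 5, 10, 55, 100):
    print(step, float(demo_schedule(step)))
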
bert_classifier = keras_hub.models.BertClassifier.from_preset(
    "bert_tiny_en_uncased", num_classes=3
)

# Get the total count of training batches.
# This requires walking the dataset to filter all -1 labels.
epochs = 3
total_steps = sum(1 for _ in train_ds.as_numpy_iterator()) * epochs
warmup_steps = int(total_steps * 0.2)

bert_classifier.compile(
    loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    optimizer=keras.optimizers.AdamW(
        TriangularSchedule(1e-4, warmup_steps, total_steps)
    ),
    metrics=["accuracy"],
)

bert_classifier.fit(train_ds, validation_data=val_ds, epochs=epochs)

"""
Success! With the learning rate scheduler and the `AdamW` optimizer, our validation
accuracy improved to around 79%.

Now, let's evaluate our final model on the test set and see how it performs.
"""

bert_classifier.evaluate(test_ds)

"""
Our Tiny BERT model achieved an accuracy of approximately 79% on the test set
with the use of a learning rate scheduler. This is a significant improvement over
our previous results. Fine-tuning a pretrained BERT model can be a powerful tool
in natural language processing tasks, and even a small model like Tiny BERT can
achieve impressive results.

Let's save our model for now and move on to learning how to perform inference
with it.

## Save and Reload the model
"""
bert_classifier.save("bert_classifier.keras")
restored_model = keras.models.load_model("bert_classifier.keras")
restored_model.evaluate(test_ds)

"""
## Performing inference with the model

Let's see how to perform inference with KerasHub models.
"""

# Convert to a (hypothesis, premise) tuple for a forward pass through the model.
sample = (sample["hypothesis"], sample["premise"])
sample

"""
The default preprocessor in KerasHub models handles input tokenization automatically,
so we don't need to perform tokenization explicitly.
"""
predictions = bert_classifier.predict(sample)


def softmax(x):
    # Normalize the logits over the class axis (the last axis) of each row.
    return np.exp(x) / np.exp(x).sum(axis=-1, keepdims=True)


# Convert the logits to per-class probabilities.
predictions = softmax(predictions)

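"""
To read off a predicted class for each pair, we can take the arg max over the class axis
(this step is our addition). The resulting indices follow the label order of the `label`
feature in the tfds `snli` dataset.
"""

# Index of the highest-probability class for each of the four sample pairs.
print(np.argmax(predictions, axis=1))
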
"""
## Improving accuracy with RoBERTa

Now that we have established a baseline, we can attempt to improve our results
by experimenting with different models. Thanks to KerasHub, fine-tuning a RoBERTa
checkpoint on the same dataset is easy with just a few lines of code.
"""

# Initializing a RoBERTa classifier from a preset.
roberta_classifier = keras_hub.models.RobertaClassifier.from_preset(
    "roberta_base_en", num_classes=3
)

roberta_classifier.fit(train_ds, validation_data=val_ds, epochs=1)

roberta_classifier.evaluate(test_ds)

"""
The RoBERTa base model has significantly more trainable parameters than the BERT
Tiny model, with almost 30 times as many at 124,645,635 parameters. As a result, it took
approximately 1.5 hours to train on a P100 GPU. However, the performance
improvement was substantial, with accuracy increasing to 88% on both the validation
and test splits. With RoBERTa, we were able to fit a maximum batch size of 16 on
our P100 GPU.

Despite using a different model, the steps to perform inference with RoBERTa are
the same as with BERT!
"""

predictions = roberta_classifier.predict(sample)
print(tf.math.argmax(predictions, axis=1).numpy())

"""
We hope this tutorial has been helpful in demonstrating the ease and effectiveness
of using KerasHub and BERT for semantic similarity tasks.

Throughout this tutorial, we demonstrated how to use a pretrained BERT model to
establish a baseline and improve performance by training a larger RoBERTa model
using just a few lines of code.

The KerasHub toolbox provides a range of modular building blocks for preprocessing
text, including pretrained state-of-the-art models and low-level Transformer Encoder
layers. We believe that this makes experimenting with natural language solutions
more accessible and efficient.
"""