Path: blob/master/examples/nlp/semantic_similarity_with_keras_hub.py
"""1Title: Semantic Similarity with KerasHub2Author: [Anshuman Mishra](https://github.com/shivance/)3Date created: 2023/02/254Last modified: 2023/02/255Description: Use pretrained models from KerasHub for the Semantic Similarity Task.6Accelerator: GPU7"""89"""10## Introduction1112Semantic similarity refers to the task of determining the degree of similarity between two13sentences in terms of their meaning. We already saw in [this](https://keras.io/examples/nlp/semantic_similarity_with_bert/)14example how to use SNLI (Stanford Natural Language Inference) corpus to predict sentence15semantic similarity with the HuggingFace Transformers library. In this tutorial we will16learn how to use [KerasHub](https://keras.io/keras_hub/), an extension of the core Keras API,17for the same task. Furthermore, we will discover how KerasHub effectively reduces boilerplate18code and simplifies the process of building and utilizing models. For more information on KerasHub,19please refer to [KerasHub's official documentation](https://keras.io/keras_hub/).2021This guide is broken down into the following parts:22231. *Setup*, task definition, and establishing a baseline.242. *Establishing baseline* with BERT.253. *Saving and Reloading* the model.264. *Performing inference* with the model.275 *Improving accuracy* with RoBERTa2829## Setup3031The following guide uses [Keras Core](https://keras.io/keras_core/) to work in32any of `tensorflow`, `jax` or `torch`. Support for Keras Core is baked into33KerasHub, simply change the `KERAS_BACKEND` environment variable below to change34the backend you would like to use. We select the `jax` backend below, which will35give us a particularly fast train step below.36"""3738"""shell39pip install -q --upgrade keras-hub40pip install -q --upgrade keras # Upgrade to Keras 3.41"""4243import numpy as np44import tensorflow as tf45import keras46import keras_hub47import tensorflow_datasets as tfds4849"""50To load the SNLI dataset, we use the tensorflow-datasets library, which51contains over 550,000 samples in total. However, to ensure that this example runs52quickly, we use only 20% of the training samples.5354## Overview of SNLI Dataset5556Every sample in the dataset contains three components: `hypothesis`, `premise`,57and `label`. epresents the original caption provided to the author of the pair,58while the hypothesis refers to the hypothesis caption created by the author of59the pair. The label is assigned by annotators to indicate the similarity between60the two sentences.6162The dataset contains three possible similarity label values: Contradiction, Entailment,63and Neutral. Contradiction represents completely dissimilar sentences, while Entailment64denotes similar meaning sentences. Lastly, Neutral refers to sentences where no clear65similarity or dissimilarity can be established between them.66"""6768snli_train = tfds.load("snli", split="train[:20%]")69snli_val = tfds.load("snli", split="validation")70snli_test = tfds.load("snli", split="test")7172# Here's an example of how our training samples look like, where we randomly select73# four samples:74sample = snli_test.batch(4).take(1).get_single_element()75sample7677"""78### Preprocessing7980In our dataset, we have identified that some samples have missing or incorrectly labeled81data, which is denoted by a value of -1. 

"""
### Preprocessing

In our dataset, we have identified that some samples have missing or incorrectly labeled
data, which is denoted by a value of -1. To ensure the accuracy and reliability of our model,
we simply filter out these samples from our dataset.
"""


def filter_labels(sample):
    return sample["label"] >= 0


"""
Here's a utility function that splits the example into an `(x, y)` tuple that is suitable
for `model.fit()`. By default, `keras_hub.models.BertClassifier` will tokenize and pack
together raw strings using a `"[SEP]"` token during training. Therefore, this label
splitting is all the data preparation that we need to perform.
"""


def split_labels(sample):
    x = (sample["hypothesis"], sample["premise"])
    y = sample["label"]
    return x, y


train_ds = (
    snli_train.filter(filter_labels)
    .map(split_labels, num_parallel_calls=tf.data.AUTOTUNE)
    .batch(16)
)
val_ds = (
    snli_val.filter(filter_labels)
    .map(split_labels, num_parallel_calls=tf.data.AUTOTUNE)
    .batch(16)
)
test_ds = (
    snli_test.filter(filter_labels)
    .map(split_labels, num_parallel_calls=tf.data.AUTOTUNE)
    .batch(16)
)
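
"""
Before training, it can help to sanity-check the pipeline. The short snippet below (an
illustrative addition, not part of the original walkthrough) pulls a single batch to confirm
the structure produced by `split_labels`: `x` is a (hypothesis, premise) tuple of string
tensors, and `y` holds the integer labels.
"""

# Pull one batch from the training pipeline and inspect its shapes.
example_x, example_y = next(iter(train_ds.take(1)))
print(example_x[0].shape, example_x[1].shape, example_y.shape)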

"""
## Establishing baseline with BERT.

We use the BERT model from KerasHub to establish a baseline for our semantic similarity
task. The `keras_hub.models.BertClassifier` class attaches a classification head to the BERT
Backbone, mapping the backbone outputs to a logit output suitable for a classification task.
This significantly reduces the need for custom code.

KerasHub models have built-in tokenization capabilities that handle tokenization by default
based on the selected model. However, users can also use custom preprocessing techniques
as per their specific needs. If we pass a tuple as input, the model will tokenize all the
strings and concatenate them with a `"[SEP]"` separator.

We load this model with pretrained weights via the `from_preset()` method, which also allows
us to pass in our own preprocessor. For the SNLI dataset, we set `num_classes` to 3.
"""

bert_classifier = keras_hub.models.BertClassifier.from_preset(
    "bert_tiny_en_uncased", num_classes=3
)

"""
Please note that the BERT Tiny model has only 4,386,307 trainable parameters.

KerasHub task models come with compilation defaults. We can now train the model we just
instantiated by calling the `fit()` method.
"""

bert_classifier.fit(train_ds, validation_data=val_ds, epochs=1)

"""
Our BERT classifier achieved an accuracy of around 76% on the validation split. Now,
let's evaluate its performance on the test split.

### Evaluate the performance of the trained model on test data.
"""

bert_classifier.evaluate(test_ds)

"""
Our baseline BERT model achieved a similar accuracy of around 76% on the test split.
Now, let's try to improve its performance by recompiling the model with a slightly
higher learning rate.
"""

bert_classifier = keras_hub.models.BertClassifier.from_preset(
    "bert_tiny_en_uncased", num_classes=3
)
bert_classifier.compile(
    loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    optimizer=keras.optimizers.Adam(5e-5),
    metrics=["accuracy"],
)

bert_classifier.fit(train_ds, validation_data=val_ds, epochs=1)
bert_classifier.evaluate(test_ds)

"""
Just tweaking the learning rate alone was not enough to boost performance, which
stayed right around 76%. Let's try again, but this time with
`keras.optimizers.AdamW` and a learning rate schedule.
"""


class TriangularSchedule(keras.optimizers.schedules.LearningRateSchedule):
    """Linear ramp up for `warmup` steps, then linear decay to zero at `total` steps."""

    def __init__(self, rate, warmup, total):
        self.rate = rate
        self.warmup = warmup
        self.total = total

    def get_config(self):
        config = {"rate": self.rate, "warmup": self.warmup, "total": self.total}
        return config

    def __call__(self, step):
        step = keras.ops.cast(step, dtype="float32")
        rate = keras.ops.cast(self.rate, dtype="float32")
        warmup = keras.ops.cast(self.warmup, dtype="float32")
        total = keras.ops.cast(self.total, dtype="float32")

        warmup_rate = rate * step / warmup
        cooldown_rate = rate * (total - step) / (total - warmup)
        triangular_rate = keras.ops.minimum(warmup_rate, cooldown_rate)
        return keras.ops.maximum(triangular_rate, 0.0)


bert_classifier = keras_hub.models.BertClassifier.from_preset(
    "bert_tiny_en_uncased", num_classes=3
)

# Get the total count of training batches.
# This requires walking the dataset to filter all -1 labels.
epochs = 3
total_steps = sum(1 for _ in train_ds.as_numpy_iterator()) * epochs
warmup_steps = int(total_steps * 0.2)

bert_classifier.compile(
    loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    optimizer=keras.optimizers.AdamW(
        TriangularSchedule(1e-4, warmup_steps, total_steps)
    ),
    metrics=["accuracy"],
)

bert_classifier.fit(train_ds, validation_data=val_ds, epochs=epochs)

"""
Success! With the learning rate scheduler and the `AdamW` optimizer, our validation
accuracy improved to around 79%.

Now, let's evaluate our final model on the test set and see how it performs.
"""

bert_classifier.evaluate(test_ds)

"""
Our Tiny BERT model achieved an accuracy of approximately 79% on the test set
with the use of a learning rate scheduler. This is a significant improvement over
our previous results. Fine-tuning a pretrained BERT model can be a powerful tool in
natural language processing tasks, and even a small model like Tiny BERT can achieve
impressive results.

Let's save our model for now and move on to learning how to perform inference with it.

## Save and Reload the model
"""

bert_classifier.save("bert_classifier.keras")
restored_model = keras.models.load_model("bert_classifier.keras")
restored_model.evaluate(test_ds)

"""
## Performing inference with the model.

Let's see how to perform inference with KerasHub models.
"""

# Convert to a (hypothesis, premise) pair for a forward pass through the model.
sample = (sample["hypothesis"], sample["premise"])
sample

"""
The default preprocessor in KerasHub models handles input tokenization automatically,
so we don't need to perform tokenization explicitly.
"""
predictions = bert_classifier.predict(sample)


def softmax(x):
    # Normalize over the class axis so each row of probabilities sums to 1.
    return np.exp(x) / np.exp(x).sum(axis=-1, keepdims=True)


# Convert the logits into per-class probabilities.
predictions = softmax(predictions)
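
"""
Each row of `predictions` now holds one probability per class. As a small illustrative step
(reusing the `snli_label_names` list we read from the dataset metadata earlier), we can map
the highest-probability index for each pair back to a human-readable label.
"""

# Pick the most likely class index per pair and look up its name.
predicted_classes = np.argmax(predictions, axis=1)
print([snli_label_names[idx] for idx in predicted_classes])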

"""
## Improving accuracy with RoBERTa

Now that we have established a baseline, we can attempt to improve our results
by experimenting with different models. Thanks to KerasHub, fine-tuning a RoBERTa
checkpoint on the same dataset is easy with just a few lines of code.
"""

# Initializing a RoBERTa classifier from a preset.
roberta_classifier = keras_hub.models.RobertaClassifier.from_preset(
    "roberta_base_en", num_classes=3
)

roberta_classifier.fit(train_ds, validation_data=val_ds, epochs=1)

roberta_classifier.evaluate(test_ds)

"""
The RoBERTa base model has significantly more trainable parameters than the BERT
Tiny model, with almost 30 times as many at 124,645,635 parameters. As a result, it took
approximately 1.5 hours to train on a P100 GPU. However, the performance
improvement was substantial, with accuracy increasing to 88% on both the validation
and test splits. With RoBERTa, we were able to fit a maximum batch size of 16 on
our P100 GPU.

Despite using a different model, the steps to perform inference with RoBERTa are
the same as with BERT!
"""

predictions = roberta_classifier.predict(sample)
print(tf.math.argmax(predictions, axis=1).numpy())

"""
We hope this tutorial has been helpful in demonstrating the ease and effectiveness
of using KerasHub and BERT for semantic similarity tasks.

Throughout this tutorial, we demonstrated how to use a pretrained BERT model to
establish a baseline and improve performance by training a larger RoBERTa model
using just a few lines of code.

The KerasHub toolbox provides a range of modular building blocks for preprocessing
text, including pretrained state-of-the-art models and low-level Transformer Encoder
layers. We believe that this makes experimenting with natural language solutions
more accessible and efficient.
"""