"""1Title: MultipleChoice Task with Transfer Learning2Author: Md Awsafur Rahman3Date created: 2023/09/144Last modified: 2025/06/165Description: Use pre-trained nlp models for multiplechoice task.6Accelerator: GPU7"""89"""10## Introduction1112In this example, we will demonstrate how to perform the **MultipleChoice** task by13finetuning pre-trained DebertaV3 model. In this task, several candidate answers are14provided along with a context and the model is trained to select the correct answer15unlike question answering. We will use SWAG dataset to demonstrate this example.16"""1718"""19## Setup20"""2122"""shell23"""2425import keras_hub26import keras27import tensorflow as tf # For tf.data only.2829import numpy as np30import pandas as pd3132import matplotlib.pyplot as plt3334"""35## Dataset36In this example we'll use **SWAG** dataset for multiplechoice task.37"""3839"""shell40wget "https://github.com/rowanz/swagaf/archive/refs/heads/master.zip" -O swag.zip41unzip -q swag.zip42"""4344"""shell45ls swagaf-master/data46"""4748"""49## Configuration50"""515253class CFG:54preset = "deberta_v3_extra_small_en" # Name of pretrained models55sequence_length = 200 # Input sequence length56seed = 42 # Random seed57epochs = 5 # Training epochs58batch_size = 8 # Batch size59augment = True # Augmentation (Shuffle Options)606162"""63## Reproducibility64Sets value for random seed to produce similar result in each run.65"""6667keras.utils.set_random_seed(CFG.seed)686970"""71## Meta Data72* **train.csv** - will be used for training.73* `sent1` and `sent2`: these fields show how a sentence starts, and if you put the two74together, you get the `startphrase` field.75* `ending_<i>`: suggests a possible ending for how a sentence can end, but only one of76them is correct.77* `label`: identifies the correct sentence ending.7879* **val.csv** - similar to `train.csv` but will be used for validation.80"""8182# Train data83train_df = pd.read_csv(84"swagaf-master/data/train.csv", index_col=085) # Read CSV file into a DataFrame86train_df = train_df.sample(frac=0.02)87print("# Train Data: {:,}".format(len(train_df)))8889# Valid data90valid_df = pd.read_csv(91"swagaf-master/data/val.csv", index_col=092) # Read CSV file into a DataFrame93valid_df = valid_df.sample(frac=0.02)94print("# Valid Data: {:,}".format(len(valid_df)))9596"""97## Contextualize Options9899Our approach entails furnishing the model with question and answer pairs, as opposed to100employing a single question for all five options. In practice, this signifies that for101the five options, we will supply the model with the same set of five questions combined102with each respective answer choice (e.g., `(Q + A)`, `(Q + B)`, and so on). 
"""
## Contextualize Options

Our approach entails furnishing the model with question and answer pairs, as opposed to
employing a single question for all four options. In practice, this signifies that for
the four options, we will supply the model with the same question combined with each
respective answer choice (e.g. `(Q + A)`, `(Q + B)`, and so on). This analogy draws
parallels to the practice of revisiting a question multiple times during an exam to
promote a deeper understanding of the problem at hand.

> Notably, in the context of the SWAG dataset, the question is the start of a sentence
and the options are possible endings of that sentence.
"""


# Define a function to create options based on the prompt and choices
def make_options(row):
    row["options"] = [
        f"{row.startphrase}\n{row.ending0}",  # Option 0
        f"{row.startphrase}\n{row.ending1}",  # Option 1
        f"{row.startphrase}\n{row.ending2}",  # Option 2
        f"{row.startphrase}\n{row.ending3}",  # Option 3
    ]
    return row


"""
Apply the `make_options` function to each row of the dataframe.
"""

train_df = train_df.apply(make_options, axis=1)
valid_df = valid_df.apply(make_options, axis=1)

"""
## Preprocessing

**What it does:** The preprocessor takes input strings and transforms them into a
dictionary (`token_ids`, `padding_mask`) containing preprocessed tensors. This process
starts with tokenization, where input strings are converted into sequences of token IDs.

**Why it's important:** Initially, raw text data is complex and challenging for modeling
due to its high dimensionality. By converting text into a compact set of tokens, such as
transforming `"The quick brown fox"` into `["the", "qu", "##ick", "br", "##own", "fox"]`,
we simplify the data. Many models rely on special tokens and additional tensors to
understand input. These tokens help divide input and identify padding, among other tasks.
Making all sequences the same length through padding boosts computational efficiency,
making subsequent steps smoother.

Explore the following pages to access the available preprocessing and tokenizer layers in
**KerasHub**:
- [Preprocessing](https://keras.io/api/keras_hub/preprocessing_layers/)
- [Tokenizers](https://keras.io/api/keras_hub/tokenizers/)
"""

preprocessor = keras_hub.models.DebertaV3Preprocessor.from_preset(
    preset=CFG.preset,  # Name of the model
    sequence_length=CFG.sequence_length,  # Max sequence length, will be padded if shorter
)

"""
Now, let's examine what the output shape of the preprocessing layer looks like. The
output shape of the layer can be represented as $(num\_choices, sequence\_length)$.
"""

outs = preprocessor(train_df.options.iloc[0])  # Process options for the first row

# Display the shape of each processed output
for k, v in outs.items():
    print(k, ":", v.shape)
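"""
Out of curiosity (purely illustrative), we can also peek at the sub-word pieces the
underlying tokenizer produces for a short sentence. This sketch assumes the preprocessor
exposes its tokenizer via the `tokenizer` attribute and that the SentencePiece tokenizer
provides an `id_to_token` helper; the exact pieces depend on the DebertaV3 vocabulary.
"""

# Tokenize a sample sentence and map the token ids back to their sub-word pieces
sample_ids = preprocessor.tokenizer("The quick brown fox")
print([preprocessor.tokenizer.id_to_token(int(i)) for i in sample_ids])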
"""
We'll use the `preprocess_fn` function to transform each text option using the
`dataset.map(preprocess_fn)` method.
"""


def preprocess_fn(text, label=None):
    text = preprocessor(text)  # Preprocess text
    return (
        (text, label) if label is not None else text
    )  # Return processed text and label if available


"""
## Augmentation

In this notebook, we'll experiment with an interesting augmentation technique,
`option_shuffle`. Since we're providing the model with one option at a time, we can
introduce a shuffle to the order of options. For instance, options `[A, C, D, B]` would
be rearranged as `[D, B, A, C]`. This practice will help the model focus on the content
of the options themselves, rather than being influenced by their positions.

**Note:** Even though the `option_shuffle` function is written in pure TensorFlow, it can
be used with any backend (e.g. JAX, PyTorch) as it is only used in the
`tf.data.Dataset` pipeline, which is compatible with Keras 3 routines.
"""


def option_shuffle(options, labels, prob=0.50, seed=None):
    if tf.random.uniform([]) > prob:  # Shuffle probability check
        return options, labels
    # Shuffle indices of options and labels in the same order
    indices = tf.random.shuffle(tf.range(tf.shape(options)[0]), seed=seed)
    # Shuffle options and labels
    options = tf.gather(options, indices)
    labels = tf.gather(labels, indices)
    return options, labels


"""
In the following function, we'll merge all augmentation functions to apply to the text.
These augmentations will be applied to the data using the `dataset.map(augment_fn)`
approach.
"""


def augment_fn(text, label=None):
    text, label = option_shuffle(text, label, prob=0.5)  # Shuffle the options
    return (text, label) if label is not None else text


"""
## DataLoader

The code below sets up a robust data flow pipeline using `tf.data.Dataset` for data
processing. Notable aspects of `tf.data` include its ability to simplify pipeline
construction and represent components in sequences.

To learn more about `tf.data`, refer to this
[documentation](https://www.tensorflow.org/guide/data).
"""


def build_dataset(
    texts,
    labels=None,
    batch_size=32,
    cache=False,
    augment=False,
    repeat=False,
    shuffle=1024,
):
    AUTO = tf.data.AUTOTUNE  # AUTOTUNE option
    slices = (
        (texts,)
        if labels is None
        else (texts, keras.utils.to_categorical(labels, num_classes=4))
    )  # Create slices
    ds = tf.data.Dataset.from_tensor_slices(slices)  # Create dataset from slices
    ds = ds.cache() if cache else ds  # Cache dataset if enabled
    if augment:  # Apply augmentation if enabled
        ds = ds.map(augment_fn, num_parallel_calls=AUTO)
    ds = ds.map(preprocess_fn, num_parallel_calls=AUTO)  # Map preprocessing function
    ds = ds.repeat() if repeat else ds  # Repeat dataset if enabled
    opt = tf.data.Options()  # Create dataset options
    if shuffle:
        ds = ds.shuffle(shuffle, seed=CFG.seed)  # Shuffle dataset if enabled
        opt.experimental_deterministic = False
    ds = ds.with_options(opt)  # Set dataset options
    ds = ds.batch(batch_size, drop_remainder=True)  # Batch dataset
    ds = ds.prefetch(AUTO)  # Prefetch next batch
    return ds  # Return the built dataset


"""
Now let's create train and valid dataloaders using the above function.
"""

# Build train dataloader
train_texts = train_df.options.tolist()  # Extract training texts
train_labels = train_df.label.tolist()  # Extract training labels
train_ds = build_dataset(
    train_texts,
    train_labels,
    batch_size=CFG.batch_size,
    cache=True,
    shuffle=True,
    repeat=True,
    augment=CFG.augment,
)

# Build valid dataloader
valid_texts = valid_df.options.tolist()  # Extract validation texts
valid_labels = valid_df.label.tolist()  # Extract validation labels
valid_ds = build_dataset(
    valid_texts,
    valid_labels,
    batch_size=CFG.batch_size,
    cache=True,
    shuffle=False,
    repeat=False,
    augment=False,
)
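"""
As a quick sanity check (purely illustrative), let's pull a single batch from the
training pipeline and confirm the shapes: each input tensor should be
`(batch_size, num_choices, sequence_length)` and the one-hot labels
`(batch_size, num_choices)`.
"""

# Fetch one batch from the training dataset (illustrative only)
sample_inputs, sample_labels = next(iter(train_ds))
for k, v in sample_inputs.items():
    print(k, ":", v.shape)
print("label :", sample_labels.shape)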
"""
## LR Schedule

Implementing a learning rate scheduler is crucial for transfer learning. The learning
rate starts at `lr_start`, ramps up to `lr_max`, and then gradually tapers down to
`lr_min` following a **cosine** curve.

**Importance:** A well-structured learning rate schedule is essential for efficient model
training, ensuring optimal convergence and avoiding issues such as overshooting or
stagnation.
"""

import math


def get_lr_callback(batch_size=8, mode="cos", epochs=10, plot=False):
    lr_start, lr_max, lr_min = 1.0e-6, 0.6e-6 * batch_size, 1e-6
    lr_ramp_ep, lr_sus_ep = 2, 0

    def lrfn(epoch):  # Learning rate update function
        if epoch < lr_ramp_ep:
            lr = (lr_max - lr_start) / lr_ramp_ep * epoch + lr_start
        elif epoch < lr_ramp_ep + lr_sus_ep:
            lr = lr_max
        else:
            decay_total_epochs, decay_epoch_index = (
                epochs - lr_ramp_ep - lr_sus_ep + 3,
                epoch - lr_ramp_ep - lr_sus_ep,
            )
            phase = math.pi * decay_epoch_index / decay_total_epochs
            lr = (lr_max - lr_min) * 0.5 * (1 + math.cos(phase)) + lr_min
        return lr

    if plot:  # Plot lr curve if plot is True
        plt.figure(figsize=(10, 5))
        plt.plot(
            np.arange(epochs),
            [lrfn(epoch) for epoch in np.arange(epochs)],
            marker="o",
        )
        plt.xlabel("epoch")
        plt.ylabel("lr")
        plt.title("LR Scheduler")
        plt.show()

    return keras.callbacks.LearningRateScheduler(
        lrfn, verbose=False
    )  # Create lr callback


_ = get_lr_callback(CFG.batch_size, plot=True)

"""
## Callbacks

The function below will gather all the training callbacks, such as `lr_scheduler` and
`model_checkpoint`.
"""


def get_callbacks():
    callbacks = []
    lr_cb = get_lr_callback(CFG.batch_size)  # Get lr callback
    ckpt_cb = keras.callbacks.ModelCheckpoint(
        "best.keras",
        monitor="val_accuracy",
        save_best_only=True,
        save_weights_only=False,
        mode="max",
    )  # Get Model checkpoint callback
    callbacks.extend([lr_cb, ckpt_cb])  # Add lr and checkpoint callbacks
    return callbacks  # Return the list of callbacks


callbacks = get_callbacks()

"""
## MultipleChoice Model
"""

"""
### Pre-trained Models

The `KerasHub` library provides comprehensive, ready-to-use implementations of popular
NLP model architectures. It features a variety of pre-trained models including `Bert`,
`Roberta`, `DebertaV3`, and more. In this notebook, we'll showcase the usage of
`DebertaV3`. However, feel free to explore all available models in the [KerasHub
documentation](https://keras.io/api/keras_hub/models/). Also, for a deeper understanding
of `KerasHub`, refer to the informative [getting started
guide](https://keras.io/guides/keras_hub/getting_started/).

Our approach involves using `keras_hub.models.XXClassifier` to process each question and
option pair (e.g. `(Q + A)`, `(Q + B)`, and so on), generating logits. These logits are
then combined and passed through a softmax function to produce the final output.
"""

"""
### Classifier for Multiple-Choice Tasks

When dealing with multiple-choice questions, instead of giving the model the question and
all options together `(Q + A + B + C ...)`, we provide the model with one option at a
time along with the question. For instance, `(Q + A)`, `(Q + B)`, and so on. Once we have
the prediction scores (logits) for all options, we combine them using the `Softmax`
function to get the ultimate result. If we had given all options at once to the model,
the text's length would increase, making it harder for the model to handle. The picture
below illustrates this idea:

<div align="center"><b> Picture Credit: </b> <a href="https://twitter.com/johnowhitaker">
@johnowhitaker </a> </div><br>

From a coding perspective, remember that we use the same model for all four options, with
shared weights. Despite the figure suggesting separate models for each option, they are,
in fact, one model with shared weights. Another point to consider is the difference in
input shapes between the Classifier and the MultipleChoice model.

* Input shape for **Multiple Choice**: $(batch\_size, num\_choices, seq\_length)$
* Input shape for **Classifier**: $(batch\_size, seq\_length)$

Certainly, it's clear that we can't directly give the data for the multiple-choice task
to the model because the input shapes don't match. To handle this, we'll use **slicing**.
This means we'll separate the features of each option, like $feature_{(Q + A)}$ and
$feature_{(Q + B)}$, and give them one by one to the NLP classifier. After we get the
prediction scores $logits_{(Q + A)}$ and $logits_{(Q + B)}$ for all the options, we'll
use the Softmax function, like $\operatorname{Softmax}([logits_{(Q + A)}, logits_{(Q +
B)}])$, to combine them. This final step helps us make the ultimate decision or choice.

> Note that in the classifier, we set `num_classes=1` instead of `4`. This is because the
classifier produces a single output for each option. When dealing with four options,
these individual outputs are joined together and then processed through a softmax
function to generate the final result, which has a dimension of `4`.
"""
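"""
Before building the full model, here is a tiny, self-contained illustration (with made-up
logit values) of how the per-option logits are combined with a softmax to produce a
probability distribution over the four endings.
"""

# Made-up per-option logits, one scalar per (Q + option) pair
example_logits = np.array([1.2, -0.3, 0.4, 0.1])
# Softmax turns them into probabilities that sum to 1; the largest logit wins
print(keras.ops.softmax(example_logits))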
# Selects one option from four
class SelectOption(keras.layers.Layer):
    def __init__(self, index, **kwargs):
        super().__init__(**kwargs)
        self.index = index

    def call(self, inputs):
        # Select a specific slice from the inputs tensor
        return inputs[:, self.index, :]

    def get_config(self):
        # For serializing the model
        base_config = super().get_config()
        config = {
            "index": self.index,
        }
        return {**base_config, **config}


def build_model():
    # Define input layers
    inputs = {
        "token_ids": keras.Input(shape=(4, None), dtype="int32", name="token_ids"),
        "padding_mask": keras.Input(
            shape=(4, None), dtype="int32", name="padding_mask"
        ),
    }
    # Create a DebertaV3Classifier model
    classifier = keras_hub.models.DebertaV3Classifier.from_preset(
        CFG.preset,
        preprocessor=None,
        num_classes=1,  # one output per option; four outputs in total for four options
    )
    logits = []
    # Loop through each option (Q+A), (Q+B), etc. and compute the associated logits
    for option_idx in range(4):
        option = {
            k: SelectOption(option_idx, name=f"{k}_{option_idx}")(v)
            for k, v in inputs.items()
        }
        logit = classifier(option)
        logits.append(logit)

    # Compute final output
    logits = keras.layers.Concatenate(axis=-1)(logits)
    outputs = keras.layers.Softmax(axis=-1)(logits)
    model = keras.Model(inputs, outputs)

    # Compile the model with optimizer, loss, and metrics
    model.compile(
        optimizer=keras.optimizers.AdamW(5e-6),
        loss=keras.losses.CategoricalCrossentropy(label_smoothing=0.02),
        metrics=[
            keras.metrics.CategoricalAccuracy(name="accuracy"),
        ],
        jit_compile=True,
    )
    return model


# Build the model
model = build_model()

"""
Let's check out the model summary to get a better insight into the model.
"""

model.summary()

"""
Finally, let's check the model structure visually to verify everything is in place.
"""

keras.utils.plot_model(model, show_shapes=True)
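"""
As an optional smoke test (purely illustrative), we can run the not-yet-fine-tuned model
on a single validation batch and confirm that the output is a probability distribution of
shape `(batch_size, num_choices)`.
"""

# Predict on one batch of the validation set (weights are still only pre-trained)
sample_probs = model.predict(valid_ds.take(1), verbose=0)
print(sample_probs.shape)  # (batch_size, 4)
print(sample_probs.sum(axis=-1))  # each row sums to ~1 thanks to the softmax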
"""
## Training
"""

# Start training the model
history = model.fit(
    train_ds,
    epochs=CFG.epochs,
    validation_data=valid_ds,
    callbacks=callbacks,
    steps_per_epoch=int(len(train_df) / CFG.batch_size),
    verbose=1,
)

"""
## Inference
"""

# Make predictions on the validation data using the trained model
predictions = model.predict(
    valid_ds,
    batch_size=CFG.batch_size,  # max batch size = valid size
    verbose=1,
)

# Format predictions and true answers
pred_answers = np.arange(4)[np.argsort(-predictions)][:, 0]
true_answers = valid_df.label.values

# Check 5 Predictions
print("# Predictions\n")
for i in range(0, 50, 10):
    row = valid_df.iloc[i]
    question = row.startphrase
    pred_answer = f"ending{pred_answers[i]}"
    true_answer = f"ending{true_answers[i]}"
    print(f"❓ Sentence {i+1}:\n{question}\n")
    print(f"✅ True Ending: {true_answer}\n >> {row[true_answer]}\n")
    print(f"🤖 Predicted Ending: {pred_answer}\n >> {row[pred_answer]}\n")
    print("-" * 90, "\n")

"""
## Reference
* [Multiple Choice with
HF](https://twitter.com/johnowhitaker/status/1689790373454041089?s=20)
* [KerasHub](https://keras.io/api/keras_hub/)
* [BirdCLEF23: Pretraining is All you Need
[Train]](https://www.kaggle.com/code/awsaf49/birdclef23-pretraining-is-all-you-need-train)
* [Triple Stratified KFold with
TFRecords](https://www.kaggle.com/code/cdeotte/triple-stratified-kfold-with-tfrecords)
"""