"""
2
Title: MultipleChoice Task with Transfer Learning
3
Author: Md Awsafur Rahman
4
Date created: 2023/09/14
5
Last modified: 2025/06/16
6
Description: Use pre-trained nlp models for multiplechoice task.
7
Accelerator: GPU
8
"""
9
10
"""
11
## Introduction
12
13
In this example, we will demonstrate how to perform the **MultipleChoice** task by
14
finetuning pre-trained DebertaV3 model. In this task, several candidate answers are
15
provided along with a context and the model is trained to select the correct answer
16
unlike question answering. We will use SWAG dataset to demonstrate this example.
17
"""
18
19
"""
20
## Setup
21
"""
22
23
"""shell
24
"""
25
26
import keras_hub
import keras
import tensorflow as tf  # For tf.data only.

import numpy as np
import pandas as pd

import matplotlib.pyplot as plt

"""
36
## Dataset
37
In this example we'll use **SWAG** dataset for multiplechoice task.
38
"""
39
40
"""shell
41
wget "https://github.com/rowanz/swagaf/archive/refs/heads/master.zip" -O swag.zip
42
unzip -q swag.zip
43
"""
44
45
"""shell
46
ls swagaf-master/data
47
"""
48
49
"""
50
## Configuration
51
"""
52
53
54
class CFG:
55
preset = "deberta_v3_extra_small_en" # Name of pretrained models
56
sequence_length = 200 # Input sequence length
57
seed = 42 # Random seed
58
epochs = 5 # Training epochs
59
batch_size = 8 # Batch size
60
augment = True # Augmentation (Shuffle Options)
61
62
63
"""
64
## Reproducibility
65
Sets value for random seed to produce similar result in each run.
66
"""
67
68
keras.utils.set_random_seed(CFG.seed)
69
70
71
"""
72
## Meta Data
73
* **train.csv** - will be used for training.
74
* `sent1` and `sent2`: these fields show how a sentence starts, and if you put the two
75
together, you get the `startphrase` field.
76
* `ending_<i>`: suggests a possible ending for how a sentence can end, but only one of
77
them is correct.
78
* `label`: identifies the correct sentence ending.
79
80
* **val.csv** - similar to `train.csv` but will be used for validation.
81
"""
82
83
# Train data
train_df = pd.read_csv(
    "swagaf-master/data/train.csv", index_col=0
)  # Read CSV file into a DataFrame
train_df = train_df.sample(frac=0.02)  # Sample 2% of the data
print("# Train Data: {:,}".format(len(train_df)))

# Valid data
valid_df = pd.read_csv(
    "swagaf-master/data/val.csv", index_col=0
)  # Read CSV file into a DataFrame
valid_df = valid_df.sample(frac=0.02)  # Sample 2% of the data
print("# Valid Data: {:,}".format(len(valid_df)))

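"""
As a quick sanity check (this peek is not part of the original example), let's look at the
fields described above for a single training sample:
"""

peek_cols = [
    "sent1", "sent2", "startphrase", "ending0", "ending1", "ending2", "ending3", "label",
]
print(train_df[peek_cols].head(1).T)  # Transpose one row for readability
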
"""
98
## Contextualize Options
99
100
Our approach entails furnishing the model with question and answer pairs, as opposed to
101
employing a single question for all five options. In practice, this signifies that for
102
the five options, we will supply the model with the same set of five questions combined
103
with each respective answer choice (e.g., `(Q + A)`, `(Q + B)`, and so on). This analogy
104
draws parallels to the practice of revisiting a question multiple times during an exam to
105
promote a deeper understanding of the problem at hand.
106
107
> Notably, in the context of SWAG dataset, question is the start of a sentence and
108
options are possible ending of that sentence.
109
"""
110
111
112
# Define a function to create options based on the prompt and choices
def make_options(row):
    row["options"] = [
        f"{row.startphrase}\n{row.ending0}",  # Option 0
        f"{row.startphrase}\n{row.ending1}",  # Option 1
        f"{row.startphrase}\n{row.ending2}",  # Option 2
        f"{row.startphrase}\n{row.ending3}",  # Option 3
    ]
    return row


"""
124
Apply the `make_options` function to each row of the dataframe
125
"""
126
127
train_df = train_df.apply(make_options, axis=1)
128
valid_df = valid_df.apply(make_options, axis=1)
129
130
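"""
For illustration (this peek is not part of the original example), here is what one
contextualized `(context + ending)` pair looks like after `make_options`:
"""

print(train_df.options.iloc[0][0])
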
"""
131
## Preprocessing
132
133
**What it does:** The preprocessor takes input strings and transforms them into a
134
dictionary (`token_ids`, `padding_mask`) containing preprocessed tensors. This process
135
starts with tokenization, where input strings are converted into sequences of token IDs.
136
137
**Why it's important:** Initially, raw text data is complex and challenging for modeling
138
due to its high dimensionality. By converting text into a compact set of tokens, such as
139
transforming `"The quick brown fox"` into `["the", "qu", "##ick", "br", "##own", "fox"]`,
140
we simplify the data. Many models rely on special tokens and additional tensors to
141
understand input. These tokens help divide input and identify padding, among other tasks.
142
Making all sequences the same length through padding boosts computational efficiency,
143
making subsequent steps smoother.
144
145
Explore the following pages to access the available preprocessing and tokenizer layers in
146
**KerasHub**:
147
- [Preprocessing](https://keras.io/api/keras_hub/preprocessing_layers/)
148
- [Tokenizers](https://keras.io/api/keras_hub/tokenizers/)
149
"""
150
151
preprocessor = keras_hub.models.DebertaV3Preprocessor.from_preset(
    preset=CFG.preset,  # Name of the model
    sequence_length=CFG.sequence_length,  # Max sequence length, will be padded if shorter
)

"""
157
Now, let's examine what the output shape of the preprocessing layer looks like. The
158
output shape of the layer can be represented as $(num\_choices, sequence\_length)$.
159
"""
160
161
outs = preprocessor(train_df.options.iloc[0]) # Process options for the first row
162
163
# Display the shape of each processed output
164
for k, v in outs.items():
165
print(k, ":", v.shape)
166
167
"""
168
We'll use the `preprocessing_fn` function to transform each text option using the
169
`dataset.map(preprocessing_fn)` method.
170
"""
171
172
173
def preprocess_fn(text, label=None):
174
text = preprocessor(text) # Preprocess text
175
return (
176
(text, label) if label is not None else text
177
) # Return processed text and label if available
178
179
180
"""
181
## Augmentation
182
183
In this notebook, we'll experiment with an interesting augmentation technique,
184
`option_shuffle`. Since we're providing the model with one option at a time, we can
185
introduce a shuffle to the order of options. For instance, options `[A, C, E, D, B]`
186
would be rearranged as `[D, B, A, E, C]`. This practice will help the model focus on the
187
content of the options themselves, rather than being influenced by their positions.
188
189
**Note:** Even though `option_shuffle` function is written in pure
190
tensorflow, it can be used with any backend (e.g. JAX, PyTorch) as it is only used
191
in `tf.data.Dataset` pipeline which is compatible with Keras 3 routines.
192
"""
193
194
195
def option_shuffle(options, labels, prob=0.50, seed=None):
    if tf.random.uniform([]) > prob:  # Shuffle probability check
        return options, labels
    # Shuffle indices of options and labels in the same order
    indices = tf.random.shuffle(tf.range(tf.shape(options)[0]), seed=seed)
    # Shuffle options and labels
    options = tf.gather(options, indices)
    labels = tf.gather(labels, indices)
    return options, labels


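"""
As a quick sanity check (not part of the original example), we can verify that the options
and labels stay aligned after shuffling by calling `option_shuffle` with `prob=1.0`, which
makes the shuffle always happen:
"""

dummy_options = tf.constant(["Q + A", "Q + B", "Q + C", "Q + D"])
dummy_labels = tf.constant([0.0, 1.0, 0.0, 0.0])  # one-hot label, as used in the pipeline
print(option_shuffle(dummy_options, dummy_labels, prob=1.0))
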
"""
207
In the following function, we'll merge all augmentation functions to apply to the text.
208
These augmentations will be applied to the data using the `dataset.map(augment_fn)`
209
approach.
210
"""
211
212
213
def augment_fn(text, label=None):
214
text, label = option_shuffle(text, label, prob=0.5) # Shuffle the options
215
return (text, label) if label is not None else text
216
217
218
"""
219
## DataLoader
220
221
The code below sets up a robust data flow pipeline using `tf.data.Dataset` for data
222
processing. Notable aspects of `tf.data` include its ability to simplify pipeline
223
construction and represent components in sequences.
224
225
To learn more about `tf.data`, refer to this
226
[documentation](https://www.tensorflow.org/guide/data).
227
"""
228
229
230
def build_dataset(
    texts,
    labels=None,
    batch_size=32,
    cache=False,
    augment=False,
    repeat=False,
    shuffle=1024,
):
    AUTO = tf.data.AUTOTUNE  # AUTOTUNE option
    slices = (
        (texts,)
        if labels is None
        else (texts, keras.utils.to_categorical(labels, num_classes=4))
    )  # Create slices
    ds = tf.data.Dataset.from_tensor_slices(slices)  # Create dataset from slices
    ds = ds.cache() if cache else ds  # Cache dataset if enabled
    if augment:  # Apply augmentation if enabled
        ds = ds.map(augment_fn, num_parallel_calls=AUTO)
    ds = ds.map(preprocess_fn, num_parallel_calls=AUTO)  # Map preprocessing function
    ds = ds.repeat() if repeat else ds  # Repeat dataset if enabled
    opt = tf.data.Options()  # Create dataset options
    if shuffle:
        ds = ds.shuffle(shuffle, seed=CFG.seed)  # Shuffle dataset if enabled
        opt.experimental_deterministic = False
    ds = ds.with_options(opt)  # Set dataset options
    ds = ds.batch(batch_size, drop_remainder=True)  # Batch dataset
    ds = ds.prefetch(AUTO)  # Prefetch next batch
    return ds  # Return the built dataset


"""
262
Now let's create train and valid dataloader using above function.
263
"""
264
265
# Build train dataloader
266
train_texts = train_df.options.tolist() # Extract training texts
267
train_labels = train_df.label.tolist() # Extract training labels
268
train_ds = build_dataset(
269
train_texts,
270
train_labels,
271
batch_size=CFG.batch_size,
272
cache=True,
273
shuffle=True,
274
repeat=True,
275
augment=CFG.augment,
276
)
277
278
# Build valid dataloader
279
valid_texts = valid_df.options.tolist() # Extract validation texts
280
valid_labels = valid_df.label.tolist() # Extract validation labels
281
valid_ds = build_dataset(
282
valid_texts,
283
valid_labels,
284
batch_size=CFG.batch_size,
285
cache=True,
286
shuffle=False,
287
repeat=False,
288
augment=False,
289
)
290
291
292
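"""
To verify the pipeline output (a check that is not part of the original example), we can
inspect the shapes of a single batch produced by the train dataloader:
"""

sample_inputs, sample_labels = next(iter(train_ds))
for k, v in sample_inputs.items():
    print(k, ":", v.shape)  # (batch_size, num_choices, seq_length)
print("labels :", sample_labels.shape)  # (batch_size, num_choices)
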
"""
293
## LR Schedule
294
295
Implementing a learning rate scheduler is crucial for transfer learning. The learning
296
rate initiates at `lr_start` and gradually tapers down to `lr_min` using **cosine**
297
curve.
298
299
**Importance:** A well-structured learning rate schedule is essential for efficient model
300
training, ensuring optimal convergence and avoiding issues such as overshooting or
301
stagnation.
302
"""
303
304
import math
305
306
307
def get_lr_callback(batch_size=8, mode="cos", epochs=10, plot=False):
    lr_start, lr_max, lr_min = 1.0e-6, 0.6e-6 * batch_size, 1e-6
    lr_ramp_ep, lr_sus_ep = 2, 0

    def lrfn(epoch):  # Learning rate update function
        if epoch < lr_ramp_ep:
            lr = (lr_max - lr_start) / lr_ramp_ep * epoch + lr_start
        elif epoch < lr_ramp_ep + lr_sus_ep:
            lr = lr_max
        else:
            decay_total_epochs, decay_epoch_index = (
                epochs - lr_ramp_ep - lr_sus_ep + 3,
                epoch - lr_ramp_ep - lr_sus_ep,
            )
            phase = math.pi * decay_epoch_index / decay_total_epochs
            lr = (lr_max - lr_min) * 0.5 * (1 + math.cos(phase)) + lr_min
        return lr

    if plot:  # Plot lr curve if plot is True
        plt.figure(figsize=(10, 5))
        plt.plot(
            np.arange(epochs),
            [lrfn(epoch) for epoch in np.arange(epochs)],
            marker="o",
        )
        plt.xlabel("epoch")
        plt.ylabel("lr")
        plt.title("LR Scheduler")
        plt.show()

    return keras.callbacks.LearningRateScheduler(
        lrfn, verbose=False
    )  # Create lr callback


_ = get_lr_callback(CFG.batch_size, plot=True)

"""
345
## Callbacks
346
347
The function below will gather all the training callbacks, such as `lr_scheduler`,
348
`model_checkpoint`.
349
"""
350
351
352
def get_callbacks():
353
callbacks = []
354
lr_cb = get_lr_callback(CFG.batch_size) # Get lr callback
355
ckpt_cb = keras.callbacks.ModelCheckpoint(
356
f"best.keras",
357
monitor="val_accuracy",
358
save_best_only=True,
359
save_weights_only=False,
360
mode="max",
361
) # Get Model checkpoint callback
362
callbacks.extend([lr_cb, ckpt_cb]) # Add lr and checkpoint callbacks
363
return callbacks # Return the list of callbacks
364
365
366
callbacks = get_callbacks()
367
368
"""
369
## MultipleChoice Model
370
371
372
373
374
375
"""
376
377
"""
378
379
### Pre-trained Models
380
381
The `KerasHub` library provides comprehensive, ready-to-use implementations of popular
382
NLP model architectures. It features a variety of pre-trained models including `Bert`,
383
`Roberta`, `DebertaV3`, and more. In this notebook, we'll showcase the usage of
384
`DistillBert`. However, feel free to explore all available models in the [KerasHub
385
documentation](https://keras.io/api/keras_hub/models/). Also for a deeper understanding
386
of `KerasHub`, refer to the informative [getting started
387
guide](https://keras.io/guides/keras_hub/getting_started/).
388
389
Our approach involves using `keras_hub.models.XXClassifier` to process each question and
390
option pari (e.g. (Q+A), (Q+B), etc.), generating logits. These logits are then combined
391
and passed through a softmax function to produce the final output.
392
"""
393
394
"""
395
396
### Classifier for Multiple-Choice Tasks
397
398
When dealing with multiple-choice questions, instead of giving the model the question and
399
all options together `(Q + A + B + C ...)`, we provide the model with one option at a
400
time along with the question. For instance, `(Q + A)`, `(Q + B)`, and so on. Once we have
401
the prediction scores (logits) for all options, we combine them using the `Softmax`
402
function to get the ultimate result. If we had given all options at once to the model,
403
the text's length would increase, making it harder for the model to handle. The picture
404
below illustrates this idea:
405
406
![Model Diagram](https://pbs.twimg.com/media/F3NUju_a8AAS8Fq?format=png&name=large)
407
408
<div align="center"><b> Picture Credit: </b> <a href="https://twitter.com/johnowhitaker">
409
@johnowhitaker </a> </div><br>
410
411
From a coding perspective, remember that we use the same model for all five options, with
412
shared weights. Despite the figure suggesting five separate models, they are, in fact,
413
one model with shared weights. Another point to consider is the the input shapes of
414
Classifier and MultipleChoice.
415
416
* Input shape for **Multiple Choice**: $(batch\_size, num\_choices, seq\_length)$
417
* Input shape for **Classifier**: $(batch\_size, seq\_length)$
418
419
Certainly, it's clear that we can't directly give the data for the multiple-choice task
420
to the model because the input shapes don't match. To handle this, we'll use **slicing**.
421
This means we'll separate the features of each option, like $feature_{(Q + A)}$ and
422
$feature_{(Q + B)}$, and give them one by one to the NLP classifier. After we get the
423
prediction scores $logits_{(Q + A)}$ and $logits_{(Q + B)}$ for all the options, we'll
424
use the Softmax function, like $\operatorname{Softmax}([logits_{(Q + A)}, logits_{(Q +
425
B)}])$, to combine them. This final step helps us make the ultimate decision or choice.
426
427
> Note that in the classifier, we set `num_classes=1` instead of `5`. This is because the
428
classifier produces a single output for each option. When dealing with five options,
429
these individual outputs are joined together and then processed through a softmax
430
function to generate the final result, which has a dimension of `5`.
431
"""
432
433
434
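"""
To make the "combine per-option logits with softmax" idea concrete, here is a small
numeric illustration (not part of the original example):
"""

option_logits = np.array([[2.1, -0.3, 0.4, 1.2]])  # one logit per (Q + option) pair
print(keras.ops.softmax(option_logits, axis=-1))  # probabilities over the four options

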
# Selects one option from four
class SelectOption(keras.layers.Layer):
    def __init__(self, index, **kwargs):
        super().__init__(**kwargs)
        self.index = index

    def call(self, inputs):
        # Selects a specific slice from the inputs tensor
        return inputs[:, self.index, :]

    def get_config(self):
        # For serializing the model
        base_config = super().get_config()
        config = {
            "index": self.index,
        }
        return {**base_config, **config}


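"""
As a quick shape check (not part of the original example), `SelectOption` slices a single
option out of the `(batch_size, num_choices, seq_length)` input:
"""

dummy_batch = keras.ops.ones((2, 4, CFG.sequence_length), dtype="int32")
print(SelectOption(index=0)(dummy_batch).shape)  # -> (2, 200)

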
def build_model():
    # Define input layers
    inputs = {
        "token_ids": keras.Input(shape=(4, None), dtype="int32", name="token_ids"),
        "padding_mask": keras.Input(
            shape=(4, None), dtype="int32", name="padding_mask"
        ),
    }
    # Create a DebertaV3Classifier model
    classifier = keras_hub.models.DebertaV3Classifier.from_preset(
        CFG.preset,
        preprocessor=None,
        num_classes=1,  # one output per option; for four options, four outputs in total
    )
    logits = []
    # Loop through each option (Q+A), (Q+B) etc. and compute the associated logits
    for option_idx in range(4):
        option = {
            k: SelectOption(option_idx, name=f"{k}_{option_idx}")(v)
            for k, v in inputs.items()
        }
        logit = classifier(option)
        logits.append(logit)

    # Compute final output
    logits = keras.layers.Concatenate(axis=-1)(logits)
    outputs = keras.layers.Softmax(axis=-1)(logits)
    model = keras.Model(inputs, outputs)

    # Compile the model with optimizer, loss, and metrics
    model.compile(
        optimizer=keras.optimizers.AdamW(5e-6),
        loss=keras.losses.CategoricalCrossentropy(label_smoothing=0.02),
        metrics=[
            keras.metrics.CategoricalAccuracy(name="accuracy"),
        ],
        jit_compile=True,
    )
    return model


# Build the model
model = build_model()

"""
498
Let's checkout the model summary to have a better insight on the model.
499
"""
500
501
model.summary()
502
503
"""
504
Finally, let's check the model structure visually if everything is in place.
505
"""
506
507
keras.utils.plot_model(model, show_shapes=True)
508
509
"""
510
## Training
511
"""
512
513
# Start training the model
514
history = model.fit(
515
train_ds,
516
epochs=CFG.epochs,
517
validation_data=valid_ds,
518
callbacks=callbacks,
519
steps_per_epoch=int(len(train_df) / CFG.batch_size),
520
verbose=1,
521
)
522
523
"""
524
## Inference
525
"""
526
527
# Make predictions using the trained model on last validation data
528
predictions = model.predict(
529
valid_ds,
530
batch_size=CFG.batch_size, # max batch size = valid size
531
verbose=1,
532
)
533
534
# Format predictions and true answers
535
pred_answers = np.arange(4)[np.argsort(-predictions)][:, 0]
536
true_answers = valid_df.label.values
537
538
# Check 5 Predictions
539
print("# Predictions\n")
540
for i in range(0, 50, 10):
541
row = valid_df.iloc[i]
542
question = row.startphrase
543
pred_answer = f"ending{pred_answers[i]}"
544
true_answer = f"ending{true_answers[i]}"
545
print(f"❓ Sentence {i+1}:\n{question}\n")
546
print(f"✅ True Ending: {true_answer}\n >> {row[true_answer]}\n")
547
print(f"🤖 Predicted Ending: {pred_answer}\n >> {row[pred_answer]}\n")
548
print("-" * 90, "\n")
549
550
"""
551
## Reference
552
* [Multiple Choice with
553
HF](https://twitter.com/johnowhitaker/status/1689790373454041089?s=20)
554
* [Keras NLP](https://keras.io/api/keras_hub/)
555
* [BirdCLEF23: Pretraining is All you Need
556
[Train]](https://www.kaggle.com/code/awsaf49/birdclef23-pretraining-is-all-you-need-train)
557
[Train]](https://www.kaggle.com/code/awsaf49/birdclef23-pretraining-is-all-you-need-train)
558
* [Triple Stratified KFold with
559
TFRecords](https://www.kaggle.com/code/cdeotte/triple-stratified-kfold-with-tfrecords)
560
"""
561
562