"""
Title: Image Classification using BigTransfer (BiT)
Author: [Sayan Nath](https://twitter.com/sayannath2350)
Date created: 2021/09/24
Last modified: 2024/01/03
Description: BigTransfer (BiT) State-of-the-art transfer learning for image classification.
Accelerator: GPU
Converted to Keras 3 by: [Sitam Meur](https://github.com/sitamgithub-MSIT)
"""

"""
## Introduction

BigTransfer (also known as BiT) is a state-of-the-art transfer learning method for image
classification. Transfer of pre-trained representations improves sample efficiency and
simplifies hyperparameter tuning when training deep neural networks for vision. BiT
revisits the paradigm of pre-training on large supervised datasets and fine-tuning the
model on a target task, and highlights the importance of appropriately choosing
normalization layers and of scaling the architecture capacity as the amount of
pre-training data increases.

BigTransfer (BiT) is trained on public datasets, and the code is released in
[TF2, Jax and PyTorch](https://github.com/google-research/big_transfer). This helps anyone reach
state-of-the-art performance on their task of interest, even with just a handful of
labeled images per class.

You can find BiT models pre-trained on
[ImageNet](https://image-net.org/challenges/LSVRC/2012/index) and ImageNet-21k in
[TFHub](https://tfhub.dev/google/collections/bit/1) as TensorFlow 2 SavedModels that you
can easily use as Keras layers. There are a variety of sizes, ranging from a standard
ResNet50 to a ResNet152x4 (152 layers deep, 4x wider than a typical ResNet50), for users
with larger computational and memory budgets but higher accuracy requirements.

![](https://i.imgur.com/XeWVfe7.jpeg)
Figure: The x-axis shows the number of images used per class, ranging from 1 to the full
dataset. On the plots on the left, the upper blue curve is the BiT-L model, while the
lower curve is a ResNet-50 pre-trained on ImageNet (ILSVRC-2012).
"""

"""
## Setup
"""

import os

os.environ["KERAS_BACKEND"] = "tensorflow"
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

import keras
from keras import ops
import tensorflow as tf
import tensorflow_hub as hub
import tensorflow_datasets as tfds

tfds.disable_progress_bar()

SEEDS = 42

keras.utils.set_random_seed(SEEDS)

"""
## Gather Flower Dataset
"""

train_ds, validation_ds = tfds.load(
    "tf_flowers",
    split=["train[:85%]", "train[85%:]"],
    as_supervised=True,
)
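
"""
As a quick sanity check, you can print the cardinality of each split; `tf_flowers`
contains 3,670 labeled images in total, so the 85%/15% split yields roughly 3,120
training and 550 validation examples:

```python
print("Training examples:", train_ds.cardinality().numpy())
print("Validation examples:", validation_ds.cardinality().numpy())
```
"""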

"""
## Visualise the dataset
"""

plt.figure(figsize=(10, 10))
for i, (image, label) in enumerate(train_ds.take(9)):
    ax = plt.subplot(3, 3, i + 1)
    plt.imshow(image)
    plt.title(int(label))
    plt.axis("off")

"""
## Define hyperparameters
"""

RESIZE_TO = 384
CROP_TO = 224
BATCH_SIZE = 64
STEPS_PER_EPOCH = 10
AUTO = tf.data.AUTOTUNE  # optimise the pipeline performance
NUM_CLASSES = 5  # number of classes
SCHEDULE_LENGTH = (
    500  # we will train on lower resolution images and will still attain good results
)
SCHEDULE_BOUNDARIES = [
    200,
    300,
    400,
]  # the schedule length increases with the dataset size

"""
The hyperparameters like `SCHEDULE_LENGTH` and `SCHEDULE_BOUNDARIES` are determined based
on empirical results. The method has been explained in the [original
paper](https://arxiv.org/abs/1912.11370) and in their [Google AI Blog
Post](https://ai.googleblog.com/2020/05/open-sourcing-bit-exploring-large-scale.html).

The `SCHEDULE_LENGTH` also determines whether to use [MixUp
Augmentation](https://arxiv.org/abs/1710.09412) or not. You can also find an easy MixUp
implementation in [Keras Coding Examples](https://keras.io/examples/vision/mixup/).

![](https://i.imgur.com/oSaIBYZ.jpeg)
"""

"""
## Define preprocessing helper functions
"""

# Rescale the schedule from the reference batch size of 512 to our batch size.
SCHEDULE_LENGTH = SCHEDULE_LENGTH * 512 / BATCH_SIZE

random_flip = keras.layers.RandomFlip("horizontal")
random_crop = keras.layers.RandomCrop(CROP_TO, CROP_TO)


def preprocess_train(image, label):
    image = random_flip(image)
    image = ops.image.resize(image, (RESIZE_TO, RESIZE_TO))
    image = random_crop(image)
    image = image / 255.0
    return (image, label)


def preprocess_test(image, label):
    image = ops.image.resize(image, (RESIZE_TO, RESIZE_TO))
    image = ops.cast(image, dtype="float32")
    image = image / 255.0
    return (image, label)


DATASET_NUM_TRAIN_EXAMPLES = train_ds.cardinality().numpy()

repeat_count = int(
    SCHEDULE_LENGTH * BATCH_SIZE / DATASET_NUM_TRAIN_EXAMPLES * STEPS_PER_EPOCH
)
repeat_count += 50 + 1  # to ensure there are at least 50 epochs of training
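
"""
With the defaults above, `SCHEDULE_LENGTH = 500 * 512 / 64 = 4000` steps. Since the 85%
training split of `tf_flowers` holds roughly 3,120 examples, `repeat_count` works out to
about `int(4000 * 64 / 3120 * 10) + 51 = 871` repetitions of the dataset.
"""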

"""
## Define the data pipeline
"""

# Training pipeline
pipeline_train = (
    train_ds.shuffle(10000)
    .repeat(repeat_count)  # repeat the dataset to cover the full training schedule
    .map(preprocess_train, num_parallel_calls=AUTO)
    .batch(BATCH_SIZE)
    .prefetch(AUTO)
)

# Validation pipeline
pipeline_validation = (
    validation_ds.map(preprocess_test, num_parallel_calls=AUTO)
    .batch(BATCH_SIZE)
    .prefetch(AUTO)
)

"""
## Visualise the training samples
"""

image_batch, label_batch = next(iter(pipeline_train))

plt.figure(figsize=(10, 10))
for n in range(25):
    ax = plt.subplot(5, 5, n + 1)
    plt.imshow(image_batch[n])
    plt.title(label_batch[n].numpy())
    plt.axis("off")

"""
## Load the pretrained TF-Hub model
"""

bit_model_url = "https://tfhub.dev/google/bit/m-r50x1/1"
bit_module = hub.load(bit_model_url)
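
"""
As a quick sanity check (a sketch, not required for the rest of the example), the module
can be called on a dummy batch; for the R50x1 model the output should be a
2048-dimensional "pre-logits" feature vector per image:

```python
dummy_images = tf.zeros((1, CROP_TO, CROP_TO, 3))
print(bit_module(dummy_images).shape)  # expected: (1, 2048)
```
"""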

"""
## Create BigTransfer (BiT) model

To create the new model, we:

1. Cut off the BiT model’s original head. This leaves us with the “pre-logits” output.
We do not have to do this if we use the ‘feature extractor’ models (i.e. all those in
subdirectories titled `feature_vectors`), since for those models the head has already
been cut off.

2. Add a new head with the number of outputs equal to the number of classes of our new
task. Note that it is important that we initialise the head to all zeroes.
"""


class MyBiTModel(keras.Model):
    def __init__(self, num_classes, module, **kwargs):
        super().__init__(**kwargs)

        self.num_classes = num_classes
        self.head = keras.layers.Dense(num_classes, kernel_initializer="zeros")
        self.bit_model = module

    def call(self, images):
        bit_embedding = self.bit_model(images)
        return self.head(bit_embedding)


model = MyBiTModel(num_classes=NUM_CLASSES, module=bit_module)
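
"""
Because the head's kernel (and its default bias) start at zero, the initial logits are
all zero and the model starts from a uniform prediction over the classes. A small sketch
of this intuition:

```python
logits = model(tf.zeros((1, CROP_TO, CROP_TO, 3)))
print(ops.softmax(logits))  # approximately [0.2, 0.2, 0.2, 0.2, 0.2]
```
"""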

"""
## Define optimizer and loss
"""

# Linearly scale the base learning rate of 0.003 with the batch size
# (the reference batch size is 512).
learning_rate = 0.003 * BATCH_SIZE / 512

# Decay learning rate by a factor of 10 at SCHEDULE_BOUNDARIES.
lr_schedule = keras.optimizers.schedules.PiecewiseConstantDecay(
    boundaries=SCHEDULE_BOUNDARIES,
    values=[
        learning_rate,
        learning_rate * 0.1,
        learning_rate * 0.01,
        learning_rate * 0.001,
    ],
)
optimizer = keras.optimizers.SGD(learning_rate=lr_schedule, momentum=0.9)

loss_fn = keras.losses.SparseCategoricalCrossentropy(from_logits=True)
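
"""
`PiecewiseConstantDecay` schedules are callable on a step index, so you can sketch a
quick check that the learning rate drops by 10x after steps 200, 300, and 400:

```python
for step in [0, 250, 350, 450]:
    print(step, float(lr_schedule(step)))
```
"""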

"""
## Compile the model
"""

model.compile(optimizer=optimizer, loss=loss_fn, metrics=["accuracy"])

"""
## Set up callbacks
"""

train_callbacks = [
    keras.callbacks.EarlyStopping(
        monitor="val_accuracy", patience=2, restore_best_weights=True
    )
]

"""
## Train the model
"""

history = model.fit(
    pipeline_train,
    batch_size=BATCH_SIZE,
    epochs=int(SCHEDULE_LENGTH / STEPS_PER_EPOCH),
    steps_per_epoch=STEPS_PER_EPOCH,
    validation_data=pipeline_validation,
    callbacks=train_callbacks,
)
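
"""
Note that `epochs=int(SCHEDULE_LENGTH / STEPS_PER_EPOCH)` works out to 400 epochs of 10
steps each, but the `EarlyStopping` callback above typically halts training much earlier
and restores the best weights seen on the validation set.
"""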

"""
## Plot the training and validation metrics
"""


def plot_hist(hist):
    plt.plot(hist.history["accuracy"])
    plt.plot(hist.history["val_accuracy"])
    plt.plot(hist.history["loss"])
    plt.plot(hist.history["val_loss"])
    plt.title("Training Progress")
    plt.ylabel("Accuracy/Loss")
    plt.xlabel("Epochs")
    plt.legend(["train_acc", "val_acc", "train_loss", "val_loss"], loc="upper left")
    plt.show()


plot_hist(history)

"""
## Evaluate the model
"""

accuracy = model.evaluate(pipeline_validation)[1] * 100
print("Accuracy: {:.2f}%".format(accuracy))

"""
## Conclusion

BiT performs well across a surprisingly wide range of data regimes
-- from 1 example per class to 1M total examples. BiT achieves 87.5% top-1 accuracy on
ILSVRC-2012, 99.4% on CIFAR-10, and 76.3% on the 19-task Visual Task Adaptation Benchmark
(VTAB). On small datasets, BiT attains 76.8% on ILSVRC-2012 with 10 examples per class,
and 97.0% on CIFAR-10 with 10 examples per class.

![](https://i.imgur.com/b1Lw5fz.png)

You can experiment further with the BigTransfer method by following the
[original paper](https://arxiv.org/abs/1912.11370).

**Example available on HuggingFace**

| Trained Model | Demo |
| :--: | :--: |
| [![Generic badge](https://img.shields.io/badge/%F0%9F%A4%97%20Model-bit-black.svg)](https://huggingface.co/keras-io/bit) | [![Generic badge](https://img.shields.io/badge/%F0%9F%A4%97%20Spaces-bit-black.svg)](https://huggingface.co/spaces/keras-io/siamese-contrastive) |
"""