Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
pytorch
GitHub Repository: pytorch/tutorials
Path: blob/main/beginner_source/transfer_learning_tutorial.py
1367 views
1
# -*- coding: utf-8 -*-
2
"""
3
Transfer Learning for Computer Vision Tutorial
4
==============================================
5
**Author**: `Sasank Chilamkurthy <https://chsasank.github.io>`_
6
7
In this tutorial, you will learn how to train a convolutional neural network for
8
image classification using transfer learning. You can read more about transfer
9
learning at `cs231n notes <https://cs231n.github.io/transfer-learning/>`__
10
11
Quoting these notes,
12
13
In practice, very few people train an entire Convolutional Network
14
from scratch (with random initialization), because it is relatively
15
rare to have a dataset of sufficient size. Instead, it is common to
16
pretrain a ConvNet on a very large dataset (e.g. ImageNet, which
17
contains 1.2 million images with 1000 categories), and then use the
18
ConvNet either as an initialization or a fixed feature extractor for
19
the task of interest.
20
21
These two major transfer learning scenarios look as follows:
22
23
- **Finetuning the ConvNet**: Instead of random initialization, we
24
initialize the network with a pretrained network, like the one that is
25
trained on the ImageNet 1000 dataset. The rest of the training looks as
26
usual.
27
- **ConvNet as fixed feature extractor**: Here, we will freeze the weights
28
for all of the network except that of the final fully connected
29
layer. This last fully connected layer is replaced with a new one
30
with random weights and only this layer is trained.
31
32
"""
33
# License: BSD
34
# Author: Sasank Chilamkurthy
35
36
import torch
37
import torch.nn as nn
38
import torch.optim as optim
39
from torch.optim import lr_scheduler
40
import torch.backends.cudnn as cudnn
41
import numpy as np
42
import torchvision
43
from torchvision import datasets, models, transforms
44
import matplotlib.pyplot as plt
45
import time
46
import os
47
from PIL import Image
48
from tempfile import TemporaryDirectory
49
50
# Enable cuDNN benchmark mode.
cudnn.benchmark = True

# Switch matplotlib to interactive mode.
plt.ion()
52
53
######################################################################
54
# Load Data
55
# ---------
56
#
57
# We will use torchvision and torch.utils.data packages for loading the
58
# data.
59
#
60
# The problem we're going to solve today is to train a model to classify
61
# **ants** and **bees**. We have about 120 training images each for ants and bees.
62
# There are 75 validation images for each class. Usually, this is a very
63
# small dataset to generalize upon, if trained from scratch. Since we
64
# are using transfer learning, we should be able to generalize reasonably
65
# well.
66
#
67
# This dataset is a very small subset of imagenet.
68
#
69
# .. Note ::
70
# Download the data from
71
# `here <https://download.pytorch.org/tutorial/hymenoptera_data.zip>`_
72
# and extract it to the current directory.
73
74
# Data augmentation and normalization for training
75
# Just normalization for validation
76
# The training pipeline augments (random resized crop + horizontal flip) and
# then normalizes; the validation pipeline only resizes, center-crops and
# normalizes.
data_transforms = dict(
    train=transforms.Compose([
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ]),
    val=transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ]),
)

data_dir = 'data/hymenoptera_data'

# One ImageFolder dataset and one DataLoader per split, keyed by split name.
image_datasets = {
    split: datasets.ImageFolder(os.path.join(data_dir, split),
                                data_transforms[split])
    for split in ['train', 'val']
}
dataloaders = {
    split: torch.utils.data.DataLoader(image_datasets[split], batch_size=4,
                                       shuffle=True, num_workers=4)
    for split in ['train', 'val']
}
dataset_sizes = {split: len(image_datasets[split]) for split in ['train', 'val']}
class_names = image_datasets['train'].classes

# We want to be able to train our model on an `accelerator <https://pytorch.org/docs/stable/torch.html#accelerators>`__
# such as CUDA, MPS, MTIA, or XPU. If the current accelerator is available, we will use it. Otherwise, we use the CPU.

if torch.accelerator.is_available():
    device = torch.accelerator.current_accelerator().type
else:
    device = "cpu"
print(f"Using {device} device")
106
107
######################################################################
108
# Visualize a few images
109
# ^^^^^^^^^^^^^^^^^^^^^^
110
# Let's visualize a few training images so as to understand the data
111
# augmentations.
112
113
def imshow(inp, title=None):
    """Draw a normalized channel-first image tensor with matplotlib.

    The tensor is converted to a NumPy array, reordered so the channel axis
    comes last, de-normalized with the same mean/std values that the
    ``transforms.Normalize`` calls in this file use, clipped to ``[0, 1]``
    and shown with ``plt.imshow``. When ``title`` is not ``None`` it is set
    as the plot title.
    """
    channel_mean = np.array([0.485, 0.456, 0.406])
    channel_std = np.array([0.229, 0.224, 0.225])

    # Move the first axis to the end: (0, 1, 2) -> (1, 2, 0).
    image = inp.numpy().transpose((1, 2, 0))
    # Undo the normalization and keep the values in the displayable range.
    image = np.clip(channel_std * image + channel_mean, 0, 1)

    plt.imshow(image)
    if title is not None:
        plt.title(title)
    # Pause a bit so that plots are updated.
    plt.pause(0.001)
124
125
126
# Get a batch of training data
127
# Draw one batch from the training loader.
inputs, classes = next(iter(dataloaders['train']))

# Tile the batch into a single image grid and show it, titled with the
# class name of every sample in the batch.
out = torchvision.utils.make_grid(inputs)
imshow(out, title=[class_names[label] for label in classes])
133
134
135
######################################################################
136
# Training the model
137
# ------------------
138
#
139
# Now, let's write a general function to train a model. Here, we will
140
# illustrate:
141
#
142
# - Scheduling the learning rate
143
# - Saving the best model
144
#
145
# In the following, parameter ``scheduler`` is an LR scheduler object from
146
# ``torch.optim.lr_scheduler``.
147
148
149
def train_model(model, criterion, optimizer, scheduler, num_epochs=25):
    """Train ``model`` and return it loaded with its best validation weights.

    Every epoch runs a ``'train'`` phase followed by a ``'val'`` phase over
    the module-level ``dataloaders``. Gradients are tracked, and the
    optimizer stepped, only during the training phase; ``scheduler.step()``
    is called once after each training phase. Whenever the validation
    accuracy beats the best value seen so far, the model's ``state_dict`` is
    saved to a file in a temporary directory, and that best checkpoint is
    loaded back into ``model`` before returning.

    Relies on the module-level ``dataloaders``, ``dataset_sizes`` and
    ``device``.

    Args:
        model: Network to train; it is modified in place.
        criterion: Loss callable, invoked as ``criterion(outputs, labels)``.
        optimizer: Optimizer that updates the parameters being trained.
        scheduler: LR scheduler object from ``torch.optim.lr_scheduler``.
        num_epochs: Number of epochs to run.

    Returns:
        The same ``model`` object, holding the best-scoring weights.
    """
    since = time.time()

    # Create a temporary directory to save training checkpoints
    with TemporaryDirectory() as tempdir:
        best_model_params_path = os.path.join(tempdir, 'best_model_params.pt')

        # Save the initial weights so a checkpoint exists even if validation
        # accuracy never rises above 0.
        torch.save(model.state_dict(), best_model_params_path)
        best_acc = 0.0

        for epoch in range(num_epochs):
            print(f'Epoch {epoch}/{num_epochs - 1}')
            print('-' * 10)

            # Each epoch has a training and validation phase
            for phase in ['train', 'val']:
                if phase == 'train':
                    model.train()  # Set model to training mode
                else:
                    model.eval()   # Set model to evaluate mode

                running_loss = 0.0
                running_corrects = 0

                # Iterate over data.
                for inputs, labels in dataloaders[phase]:
                    inputs = inputs.to(device)
                    labels = labels.to(device)

                    # zero the parameter gradients
                    optimizer.zero_grad()

                    # forward
                    # track history if only in train
                    with torch.set_grad_enabled(phase == 'train'):
                        outputs = model(inputs)
                        _, preds = torch.max(outputs, 1)
                        loss = criterion(outputs, labels)

                        # backward + optimize only if in training phase
                        if phase == 'train':
                            loss.backward()
                            optimizer.step()

                    # statistics
                    # The loss is weighted by batch size so the epoch figure
                    # is a per-sample mean.
                    running_loss += loss.item() * inputs.size(0)
                    # Keep the count as a plain int so the accuracy below is
                    # an ordinary float, even when the loader yields nothing.
                    running_corrects += torch.sum(preds == labels).item()
                if phase == 'train':
                    scheduler.step()

                epoch_loss = running_loss / dataset_sizes[phase]
                epoch_acc = running_corrects / dataset_sizes[phase]

                print(f'{phase} Loss: {epoch_loss:.4f} Acc: {epoch_acc:.4f}')

                # checkpoint the model whenever validation accuracy improves
                if phase == 'val' and epoch_acc > best_acc:
                    best_acc = epoch_acc
                    torch.save(model.state_dict(), best_model_params_path)

            print()

        time_elapsed = time.time() - since
        print(f'Training complete in {time_elapsed // 60:.0f}m {time_elapsed % 60:.0f}s')
        # ``.4f`` (four decimals) matches the per-epoch accuracy output above.
        print(f'Best val Acc: {best_acc:.4f}')

        # load best model weights
        model.load_state_dict(torch.load(best_model_params_path, weights_only=True))
    return model
218
219
220
######################################################################
221
# Visualizing the model predictions
222
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
223
#
224
# Generic function to display predictions for a few images
225
#
226
227
def visualize_model(model, num_images=6):
    """Plot validation images titled with the class the model predicts.

    Batches are drawn from the module-level ``dataloaders['val']`` and moved
    to ``device``; each image is drawn with ``imshow`` in a two-column
    subplot grid and titled with ``class_names[prediction]``. Plotting stops
    after ``num_images`` images, or when the validation loader is exhausted.
    The model is put into eval mode for the duration and its previous
    training/eval mode is restored afterwards, including on error.

    Args:
        model: Network used to predict the class of each image.
        num_images: Maximum number of images to show. Values ``<= 0`` plot
            nothing.
    """
    if num_images <= 0:
        return

    was_training = model.training
    model.eval()
    images_so_far = 0
    plt.figure()
    # Two columns; round the row count up so an odd ``num_images`` (or 1)
    # still fits in the grid.
    num_rows = (num_images + 1) // 2

    try:
        with torch.no_grad():
            for inputs, _ in dataloaders['val']:
                inputs = inputs.to(device)

                outputs = model(inputs)
                _, preds = torch.max(outputs, 1)

                for j in range(inputs.size(0)):
                    images_so_far += 1
                    ax = plt.subplot(num_rows, 2, images_so_far)
                    ax.axis('off')
                    ax.set_title(f'predicted: {class_names[preds[j]]}')
                    # imshow calls ``.numpy()``, so pass a CPU tensor.
                    imshow(inputs[j].cpu())

                    if images_so_far == num_images:
                        return
    finally:
        # Put the model back in whichever mode the caller had it in.
        model.train(mode=was_training)
252
253
######################################################################
254
# Finetuning the ConvNet
255
# ----------------------
256
#
257
# Load a pretrained model and reset final fully connected layer.
258
#
259
260
model_ft = models.resnet18(weights='IMAGENET1K_V1')
num_ftrs = model_ft.fc.in_features
# Replace the final fully connected layer with one output per dataset class
# (2 for the ants/bees data), so the script keeps working if the image
# folder contains a different number of classes.
model_ft.fc = nn.Linear(num_ftrs, len(class_names))

model_ft = model_ft.to(device)

criterion = nn.CrossEntropyLoss()

# Observe that all parameters are being optimized
optimizer_ft = optim.SGD(model_ft.parameters(), lr=0.001, momentum=0.9)

# Decay LR by a factor of 0.1 every 7 epochs
exp_lr_scheduler = lr_scheduler.StepLR(optimizer_ft, step_size=7, gamma=0.1)

######################################################################
# Train and evaluate
# ^^^^^^^^^^^^^^^^^^
#
# It should take around 15-25 min on CPU. On GPU though, it takes less than a
# minute.
#

model_ft = train_model(model_ft, criterion, optimizer_ft, exp_lr_scheduler,
                       num_epochs=25)

######################################################################
#

visualize_model(model_ft)
291
292
293
######################################################################
294
# ConvNet as fixed feature extractor
295
# ----------------------------------
296
#
297
# Here, we need to freeze all the network except the final layer. We need
298
# to set ``requires_grad = False`` to freeze the parameters so that the
299
# gradients are not computed in ``backward()``.
300
#
301
# You can read more about this in the documentation
302
# `here <https://pytorch.org/docs/notes/autograd.html#excluding-subgraphs-from-backward>`__.
303
#
304
305
model_conv = torchvision.models.resnet18(weights='IMAGENET1K_V1')
# Freeze every pretrained parameter so no gradients are computed for them.
for param in model_conv.parameters():
    param.requires_grad = False

# Parameters of newly constructed modules have requires_grad=True by default,
# so the replacement head below is the only trainable part. It gets one
# output per dataset class (2 for the ants/bees data).
num_ftrs = model_conv.fc.in_features
model_conv.fc = nn.Linear(num_ftrs, len(class_names))

model_conv = model_conv.to(device)

criterion = nn.CrossEntropyLoss()

# Observe that only parameters of final layer are being optimized as
# opposed to before.
optimizer_conv = optim.SGD(model_conv.fc.parameters(), lr=0.001, momentum=0.9)

# Decay LR by a factor of 0.1 every 7 epochs
exp_lr_scheduler = lr_scheduler.StepLR(optimizer_conv, step_size=7, gamma=0.1)


######################################################################
# Train and evaluate
# ^^^^^^^^^^^^^^^^^^
#
# On CPU this will take about half the time compared to previous scenario.
# This is expected as gradients don't need to be computed for most of the
# network. However, forward does need to be computed.
#

model_conv = train_model(model_conv, criterion, optimizer_conv,
                         exp_lr_scheduler, num_epochs=25)

######################################################################
#

visualize_model(model_conv)

plt.ioff()
plt.show()
344
345
346
######################################################################
347
# Inference on custom images
348
# --------------------------
349
#
350
# Use the trained model to make predictions on custom images and visualize
351
# the predicted class labels along with the images.
352
#
353
354
def visualize_model_predictions(model, img_path):
    """Predict the class of one image file and plot it with that label.

    The image is opened with PIL, converted to RGB, passed through
    ``data_transforms['val']``, given a leading batch dimension and moved to
    ``device``. The predicted class name (looked up in ``class_names``) is
    used as the subplot title and the image is drawn with ``imshow``. The
    model is put into eval mode for the duration and its previous
    training/eval mode is restored afterwards, including on error.

    Relies on the module-level ``data_transforms``, ``device``,
    ``class_names`` and ``imshow``.

    Args:
        model: Network used to predict the image's class.
        img_path: Path of the image file to classify.
    """
    was_training = model.training
    model.eval()

    try:
        # Close the file as soon as it is decoded, and force three channels:
        # the validation transform normalizes with 3-element mean/std, which
        # does not fit grayscale or RGBA input.
        with Image.open(img_path) as raw_img:
            img = data_transforms['val'](raw_img.convert('RGB'))
        img = img.unsqueeze(0)
        img = img.to(device)

        with torch.no_grad():
            outputs = model(img)
            _, preds = torch.max(outputs, 1)

            ax = plt.subplot(2, 2, 1)
            ax.axis('off')
            ax.set_title(f'Predicted: {class_names[preds[0]]}')
            # imshow calls ``.numpy()``, so pass a CPU tensor.
            imshow(img[0].cpu())
    finally:
        # Put the model back in whichever mode the caller had it in.
        model.train(mode=was_training)
373
374
######################################################################
375
#
376
377
# Classify a single validation image with the feature-extractor model.
visualize_model_predictions(
    model_conv,
    img_path='data/hymenoptera_data/val/bees/72100438_73de9f17af.jpg',
)

# Leave interactive mode and show the figure.
plt.ioff()
plt.show()
384
385
386
######################################################################
387
# Further Learning
388
# -----------------
389
#
390
# If you would like to learn more about the applications of transfer learning,
391
# check out our `Quantized Transfer Learning for Computer Vision Tutorial <https://pytorch.org/tutorials/intermediate/quantized_transfer_learning_tutorial.html>`_.
392
#
393
#
394
395
396