# Path: blob/main/beginner_source/transfer_learning_tutorial.py
# 1367 views
# -*- coding: utf-8 -*-
"""
Transfer Learning for Computer Vision Tutorial
==============================================
**Author**: `Sasank Chilamkurthy <https://chsasank.github.io>`_

In this tutorial, you will learn how to train a convolutional neural network for
image classification using transfer learning. You can read more about transfer
learning in the `cs231n notes <https://cs231n.github.io/transfer-learning/>`__

Quoting these notes,

    In practice, very few people train an entire Convolutional Network
    from scratch (with random initialization), because it is relatively
    rare to have a dataset of sufficient size. Instead, it is common to
    pretrain a ConvNet on a very large dataset (e.g. ImageNet, which
    contains 1.2 million images with 1000 categories), and then use the
    ConvNet either as an initialization or a fixed feature extractor for
    the task of interest.

These two major transfer learning scenarios look as follows:

-  **Finetuning the ConvNet**: Instead of random initialization, we
   initialize the network with a pretrained network, like the one that is
   trained on imagenet 1000 dataset. Rest of the training looks as
   usual.
-  **ConvNet as fixed feature extractor**: Here, we will freeze the weights
   for all of the network except that of the final fully connected
   layer. This last fully connected layer is replaced with a new one
   with random weights and only this layer is trained.

"""
# License: BSD
# Author: Sasank Chilamkurthy

import os
import time
from tempfile import TemporaryDirectory

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.backends.cudnn as cudnn
import torch.nn as nn
import torch.optim as optim
import torchvision
from PIL import Image
from torch.optim import lr_scheduler
from torchvision import datasets, models, transforms

cudnn.benchmark = True
plt.ion()   # interactive mode

######################################################################
# Load Data
# ---------
#
# We will use torchvision and torch.utils.data packages for loading the
# data.
#
# The problem we're going to solve today is to train a model to classify
# **ants** and **bees**. We have about 120 training images each for ants and bees.
# There are 75 validation images for each class. Usually, this is a very
# small dataset to generalize upon, if trained from scratch. Since we
# are using transfer learning, we should be able to generalize reasonably
# well.
#
# This dataset is a very small subset of imagenet.
#
# .. Note ::
#    Download the data from
#    `here <https://download.pytorch.org/tutorial/hymenoptera_data.zip>`_
#    and extract it so that the images end up under ``data/hymenoptera_data``
#    (the ``data_dir`` used below).

# Channel-wise mean/std shared by both pipelines. ``imshow`` uses the same
# values to undo the normalization before plotting.
normalize = transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])

# Data augmentation and normalization for training
# Just normalization for validation
data_transforms = {
    'train': transforms.Compose([
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize,
    ]),
    'val': transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        normalize,
    ]),
}

data_dir = 'data/hymenoptera_data'

# One dataset / loader / size entry per phase ('train' and 'val'), keyed the
# same way as ``data_transforms`` so the phases cannot drift apart.
image_datasets = {
    x: datasets.ImageFolder(os.path.join(data_dir, x), data_transforms[x])
    for x in data_transforms
}
dataloaders = {
    x: torch.utils.data.DataLoader(image_datasets[x], batch_size=4,
                                   shuffle=True, num_workers=4)
    for x in data_transforms
}
dataset_sizes = {x: len(image_datasets[x]) for x in data_transforms}
class_names = image_datasets['train'].classes

# We want to be able to train our model on an `accelerator <https://pytorch.org/docs/stable/torch.html#accelerators>`__
# such as CUDA, MPS, MTIA, or XPU. If the current accelerator is available, we will use it.
# Otherwise, we use the CPU.

device = torch.accelerator.current_accelerator().type if torch.accelerator.is_available() else "cpu"
print(f"Using {device} device")

######################################################################
# Visualize a few images
# ^^^^^^^^^^^^^^^^^^^^^^
# Let's visualize a few training images so as to understand the data
# augmentations.

def imshow(inp, title=None):
    """Display a normalized image tensor with matplotlib.

    Args:
        inp: Image tensor laid out channel-first; it is transposed with
            ``(1, 2, 0)`` before plotting.
        title: Optional title for the current axes.
    """
    # detach()/cpu() are no-ops for plain CPU tensors but make the helper safe
    # for tensors that still live on an accelerator or are attached to autograd.
    inp = inp.detach().cpu().numpy().transpose((1, 2, 0))
    # Undo the per-channel normalization applied by the data transforms.
    mean = np.array([0.485, 0.456, 0.406])
    std = np.array([0.229, 0.224, 0.225])
    inp = std * inp + mean
    inp = np.clip(inp, 0, 1)
    plt.imshow(inp)
    if title is not None:
        plt.title(title)
    plt.pause(0.001)  # pause a bit so that plots are updated


# Get a batch of training data
inputs, classes = next(iter(dataloaders['train']))

# Make a grid from batch
out = torchvision.utils.make_grid(inputs)

imshow(out, title=[class_names[x] for x in classes])


######################################################################
# Training the model
# ------------------
#
# Now, let's write a general function to train a model.
# Here, we will illustrate:
#
# -  Scheduling the learning rate
# -  Saving the best model
#
# In the following, parameter ``scheduler`` is an LR scheduler object from
# ``torch.optim.lr_scheduler``.


def train_model(model, criterion, optimizer, scheduler, num_epochs=25):
    """Train ``model`` and return it with the best validation weights loaded.

    Every epoch runs a training phase followed by a validation phase over the
    module-level ``dataloaders``. Whenever validation accuracy improves, the
    weights are checkpointed to a temporary file; the best checkpoint is
    loaded back into ``model`` before returning.

    Args:
        model: Network to train; expected to already be on ``device``.
        criterion: Loss callable invoked as ``criterion(outputs, labels)``.
        optimizer: Optimizer stepped once per training batch.
        scheduler: LR scheduler stepped once per epoch, after the training phase.
        num_epochs: Number of epochs to run.

    Returns:
        The same ``model`` object, holding the best-scoring weights.
    """
    since = time.time()

    # Create a temporary directory to save training checkpoints
    with TemporaryDirectory() as tempdir:
        best_model_params_path = os.path.join(tempdir, 'best_model_params.pt')

        # Save the initial weights so a checkpoint exists even if validation
        # accuracy never rises above zero.
        torch.save(model.state_dict(), best_model_params_path)
        best_acc = 0.0

        for epoch in range(num_epochs):
            print(f'Epoch {epoch}/{num_epochs - 1}')
            print('-' * 10)

            # Each epoch has a training and validation phase
            for phase in ['train', 'val']:
                if phase == 'train':
                    model.train()  # Set model to training mode
                else:
                    model.eval()   # Set model to evaluate mode

                running_loss = 0.0
                running_corrects = 0

                # Iterate over data.
                for inputs, labels in dataloaders[phase]:
                    inputs = inputs.to(device)
                    labels = labels.to(device)

                    # zero the parameter gradients
                    optimizer.zero_grad()

                    # forward
                    # track history if only in train
                    with torch.set_grad_enabled(phase == 'train'):
                        outputs = model(inputs)
                        _, preds = torch.max(outputs, 1)
                        loss = criterion(outputs, labels)

                        # backward + optimize only if in training phase
                        if phase == 'train':
                            loss.backward()
                            optimizer.step()

                    # statistics
                    # ``loss`` is averaged over the batch, so weight it by the
                    # batch size to get a correct dataset-wide mean below.
                    running_loss += loss.item() * inputs.size(0)
                    # Accumulate a plain Python int. This avoids a later float64
                    # cast on the device, which not every accelerator backend
                    # supports (MPS in particular -- NOTE(review): confirm).
                    running_corrects += torch.sum(preds == labels).item()
                if phase == 'train':
                    scheduler.step()

                epoch_loss = running_loss / dataset_sizes[phase]
                epoch_acc = running_corrects / dataset_sizes[phase]

                print(f'{phase} Loss: {epoch_loss:.4f} Acc: {epoch_acc:.4f}')

                # checkpoint the model whenever validation accuracy improves
                if phase == 'val' and epoch_acc > best_acc:
                    best_acc = epoch_acc
                    torch.save(model.state_dict(), best_model_params_path)

            print()

        time_elapsed = time.time() - since
        print(f'Training complete in {time_elapsed // 60:.0f}m {time_elapsed % 60:.0f}s')
        print(f'Best val Acc: {best_acc:.4f}')

        # load best model weights (must happen before the temp dir is removed)
        model.load_state_dict(torch.load(best_model_params_path, weights_only=True))
    return model


######################################################################
# Visualizing the model predictions
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
#
# Generic function to display predictions for a few images
#

def visualize_model(model, num_images=6):
    """Plot predictions for up to ``num_images`` validation images.

    Images are drawn on a new figure in a two-column grid. The model is put in
    eval mode for the duration of the call and its previous training/eval mode
    is restored afterwards, even if plotting raises.

    Args:
        model: Trained network, expected to already be on ``device``.
        num_images: Maximum number of images to display.
    """
    if num_images <= 0:
        return

    was_training = model.training
    model.eval()
    images_so_far = 0
    # Round up so an odd ``num_images`` still gets enough grid cells.
    num_rows = (num_images + 1) // 2
    plt.figure()

    try:
        with torch.no_grad():
            for inputs, _ in dataloaders['val']:
                inputs = inputs.to(device)

                outputs = model(inputs)
                _, preds = torch.max(outputs, 1)

                for j in range(inputs.size(0)):
                    images_so_far += 1
                    ax = plt.subplot(num_rows, 2, images_so_far)
                    ax.axis('off')
                    ax.set_title(f'predicted: {class_names[preds[j]]}')
                    imshow(inputs.cpu().data[j])

                    if images_so_far == num_images:
                        return
    finally:
        model.train(mode=was_training)

######################################################################
# Finetuning the ConvNet
# ----------------------
#
# Load a pretrained model and reset final fully connected layer.
#

model_ft = models.resnet18(weights='IMAGENET1K_V1')
num_ftrs = model_ft.fc.in_features
# The new head has one output per class (2 for the ants/bees dataset).
model_ft.fc = nn.Linear(num_ftrs, len(class_names))

model_ft = model_ft.to(device)

criterion = nn.CrossEntropyLoss()

# Observe that all parameters are being optimized
optimizer_ft = optim.SGD(model_ft.parameters(), lr=0.001, momentum=0.9)

# Decay LR by a factor of 0.1 every 7 epochs
exp_lr_scheduler = lr_scheduler.StepLR(optimizer_ft, step_size=7, gamma=0.1)

######################################################################
# Train and evaluate
# ^^^^^^^^^^^^^^^^^^
#
# It should take around 15-25 min on CPU. On GPU though, it takes less than a
# minute.
#

model_ft = train_model(model_ft, criterion, optimizer_ft, exp_lr_scheduler,
                       num_epochs=25)

######################################################################
#

visualize_model(model_ft)


######################################################################
# ConvNet as fixed feature extractor
# ----------------------------------
#
# Here, we need to freeze all the network except the final layer. We need
# to set ``requires_grad = False`` to freeze the parameters so that the
# gradients are not computed in ``backward()``.
#
# You can read more about this in the documentation
# `here <https://pytorch.org/docs/notes/autograd.html#excluding-subgraphs-from-backward>`__.
#

model_conv = torchvision.models.resnet18(weights='IMAGENET1K_V1')
for param in model_conv.parameters():
    param.requires_grad = False

# Parameters of newly constructed modules have requires_grad=True by default
num_ftrs = model_conv.fc.in_features
model_conv.fc = nn.Linear(num_ftrs, len(class_names))

model_conv = model_conv.to(device)

criterion = nn.CrossEntropyLoss()

# Observe that only parameters of final layer are being optimized as
# opposed to before.
optimizer_conv = optim.SGD(model_conv.fc.parameters(), lr=0.001, momentum=0.9)

# Decay LR by a factor of 0.1 every 7 epochs
exp_lr_scheduler = lr_scheduler.StepLR(optimizer_conv, step_size=7, gamma=0.1)


######################################################################
# Train and evaluate
# ^^^^^^^^^^^^^^^^^^
#
# On CPU this will take about half the time compared to previous scenario.
# This is expected as gradients don't need to be computed for most of the
# network.
# However, forward does need to be computed.
#

model_conv = train_model(model_conv, criterion, optimizer_conv,
                         exp_lr_scheduler, num_epochs=25)

######################################################################
#

visualize_model(model_conv)

plt.ioff()
plt.show()


######################################################################
# Inference on custom images
# --------------------------
#
# Use the trained model to make predictions on custom images and visualize
# the predicted class labels along with the images.
#

def visualize_model_predictions(model, img_path):
    """Predict the class of one image file and plot it with its label.

    The image is run through the ``'val'`` transform pipeline, classified by
    ``model`` and drawn in the first cell of a 2x2 subplot grid. The model's
    previous training/eval mode is restored afterwards, even on error.

    Args:
        model: Trained network, expected to already be on ``device``.
        img_path: Path of the image file to classify.
    """
    was_training = model.training
    model.eval()

    try:
        # Close the file handle promptly, and force three channels so that
        # grayscale or RGBA files still fit the 3-channel ``Normalize`` step.
        with Image.open(img_path) as image_file:
            img = data_transforms['val'](image_file.convert('RGB'))
        # Add the batch dimension the model expects.
        img = img.unsqueeze(0)
        img = img.to(device)

        with torch.no_grad():
            outputs = model(img)
            _, preds = torch.max(outputs, 1)

            ax = plt.subplot(2, 2, 1)
            ax.axis('off')
            ax.set_title(f'Predicted: {class_names[preds[0]]}')
            imshow(img.cpu().data[0])
    finally:
        model.train(mode=was_training)

######################################################################
#

visualize_model_predictions(
    model_conv,
    img_path='data/hymenoptera_data/val/bees/72100438_73de9f17af.jpg'
)

plt.ioff()
plt.show()


######################################################################
# Further Learning
# ----------------
#
# If you would like to learn more about the applications of transfer learning,
# check out our `Quantized Transfer Learning for Computer Vision Tutorial <https://pytorch.org/tutorials/intermediate/quantized_transfer_learning_tutorial.html>`_.
#
#