# -*- coding: utf-8 -*-
"""
Writing Custom Datasets, DataLoaders and Transforms
===================================================
**Author**: `Sasank Chilamkurthy <https://chsasank.github.io>`_

A lot of effort in solving any machine learning problem goes into
preparing the data. PyTorch provides many tools to make data loading
easy and, hopefully, to make your code more readable. In this tutorial,
we will see how to load and preprocess/augment data from a nontrivial
dataset.

To run this tutorial, please make sure the following packages are
installed:

- ``scikit-image``: For image I/O and transforms
- ``pandas``: For easier CSV parsing

"""

import os
import torch
import pandas as pd
from skimage import io, transform
import numpy as np
import matplotlib.pyplot as plt
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, utils

# Ignore warnings
import warnings
warnings.filterwarnings("ignore")

plt.ion()  # interactive mode

######################################################################
# The dataset we are going to deal with is that of facial pose.
# This means that a face is annotated like this:
#
# .. figure:: /_static/img/landmarked_face2.png
#    :width: 400
#
# Overall, 68 different landmark points are annotated for each face.
#
# .. note::
#     Download the dataset from `here <https://download.pytorch.org/tutorial/faces.zip>`_
#     so that the images are in a directory named 'data/faces/'.
#     This dataset was actually
#     generated by applying dlib's excellent `pose
#     estimation <https://blog.dlib.net/2014/08/real-time-face-pose-estimation.html>`__
#     to a few images from ImageNet tagged as 'face'.
#
# The dataset comes with a ``.csv`` file of annotations which looks like this:
#
# .. code-block:: sh
#
#     image_name,part_0_x,part_0_y,part_1_x,part_1_y,part_2_x, ... ,part_67_x,part_67_y
#     0805personali01.jpg,27,83,27,98, ... 84,134
#     1084239450_e76e00b7e7.jpg,70,236,71,257, ... ,128,312
#
# Let's take a single image name and its annotations from the CSV, in this
# case row index number 65 for person-7.jpg, just as an example.
# Read it, store the image name in ``img_name`` and store its
# annotations in an (L, 2) array ``landmarks``, where L is the number of
# landmarks in that row.
#

landmarks_frame = pd.read_csv('data/faces/face_landmarks.csv')

n = 65
img_name = landmarks_frame.iloc[n, 0]
landmarks = landmarks_frame.iloc[n, 1:]
landmarks = np.asarray(landmarks, dtype=float).reshape(-1, 2)

print('Image name: {}'.format(img_name))
print('Landmarks shape: {}'.format(landmarks.shape))
print('First 4 Landmarks: {}'.format(landmarks[:4]))


######################################################################
# Let's write a simple helper function to show an image and its landmarks
# and use it to show a sample.
#

def show_landmarks(image, landmarks):
    """Show image with landmarks"""
    plt.imshow(image)
    plt.scatter(landmarks[:, 0], landmarks[:, 1], s=10, marker='.', c='r')
    plt.pause(0.001)  # pause a bit so that plots are updated

plt.figure()
show_landmarks(io.imread(os.path.join('data/faces/', img_name)),
               landmarks)
plt.show()


######################################################################
# Dataset class
# -------------
#
# ``torch.utils.data.Dataset`` is an abstract class representing a
# dataset.
# Your custom dataset should inherit ``Dataset`` and override the following
# methods:
#
# - ``__len__`` so that ``len(dataset)`` returns the size of the dataset.
# - ``__getitem__`` to support indexing such that ``dataset[i]`` can
#   be used to get the :math:`i`\ th sample.
#
# Let's create a dataset class for our face landmarks dataset. We will
# read the CSV in ``__init__`` but leave the reading of images to
# ``__getitem__``. This is memory-efficient because the images are not
# all stored in memory at once but read as required.
#
# A sample of our dataset will be a dict
# ``{'image': image, 'landmarks': landmarks}``. Our dataset will take an
# optional argument ``transform`` so that any required processing can be
# applied to the sample. We will see the usefulness of ``transform`` in the
# next section.
#

class FaceLandmarksDataset(Dataset):
    """Face Landmarks dataset."""

    def __init__(self, csv_file, root_dir, transform=None):
        """
        Arguments:
            csv_file (string): Path to the csv file with annotations.
            root_dir (string): Directory with all the images.
            transform (callable, optional): Optional transform to be applied
                on a sample.
        """
        self.landmarks_frame = pd.read_csv(csv_file)
        self.root_dir = root_dir
        self.transform = transform

    def __len__(self):
        return len(self.landmarks_frame)

    def __getitem__(self, idx):
        if torch.is_tensor(idx):
            idx = idx.tolist()

        img_name = os.path.join(self.root_dir,
                                self.landmarks_frame.iloc[idx, 0])
        image = io.imread(img_name)
        landmarks = self.landmarks_frame.iloc[idx, 1:]
        landmarks = np.asarray(landmarks, dtype=float).reshape(-1, 2)
        sample = {'image': image, 'landmarks': landmarks}

        if self.transform:
            sample = self.transform(sample)

        return sample


######################################################################
# Let's instantiate this class and iterate through the data samples.
# We will print the sizes of the first 4 samples and show their landmarks.
#

face_dataset = FaceLandmarksDataset(csv_file='data/faces/face_landmarks.csv',
                                    root_dir='data/faces/')

fig = plt.figure()

for i, sample in enumerate(face_dataset):
    print(i, sample['image'].shape, sample['landmarks'].shape)

    ax = plt.subplot(1, 4, i + 1)
    plt.tight_layout()
    ax.set_title('Sample #{}'.format(i))
    ax.axis('off')
    show_landmarks(**sample)

    if i == 3:
        plt.show()
        break


######################################################################
# Transforms
# ----------
#
# One issue we can see from the above is that the samples are not of the
# same size. Most neural networks expect images of a fixed size.
# Therefore, we will need to write some preprocessing code.
# Let's create three transforms:
#
# - ``Rescale``: to scale the image
# - ``RandomCrop``: to crop from the image randomly. This is data
#   augmentation.
# - ``ToTensor``: to convert the numpy images to torch images (we need to
#   swap axes).
#
# We will write them as callable classes instead of simple functions so
# that the parameters of the transform need not be passed every time it's
# called. For this, we just need to implement the ``__call__`` method and,
# if required, the ``__init__`` method. We can then use a transform like this:
#
# .. code-block:: python
#
#     tsfm = Transform(params)
#     transformed_sample = tsfm(sample)
#
# Observe below how these transforms have to be applied to both the image
# and the landmarks.
#

class Rescale(object):
    """Rescale the image in a sample to a given size.

    Args:
        output_size (tuple or int): Desired output size. If tuple, output is
            matched to output_size. If int, the smaller of the image edges is
            matched to output_size, keeping the aspect ratio the same.
    """

    def __init__(self, output_size):
        assert isinstance(output_size, (int, tuple))
        self.output_size = output_size

    def __call__(self, sample):
        image, landmarks = sample['image'], sample['landmarks']

        h, w = image.shape[:2]
        if isinstance(self.output_size, int):
            if h > w:
                new_h, new_w = self.output_size * h / w, self.output_size
            else:
                new_h, new_w = self.output_size, self.output_size * w / h
        else:
            new_h, new_w = self.output_size

        new_h, new_w = int(new_h), int(new_w)

        img = transform.resize(image, (new_h, new_w))

        # h and w are swapped for landmarks because for images,
        # x and y axes are axis 1 and 0 respectively
        landmarks = landmarks * [new_w / w, new_h / h]

        return {'image': img, 'landmarks': landmarks}


class RandomCrop(object):
    """Randomly crop the image in a sample.

    Args:
        output_size (tuple or int): Desired output size. If int, a square
            crop is made.
    """

    def __init__(self, output_size):
        assert isinstance(output_size, (int, tuple))
        if isinstance(output_size, int):
            self.output_size = (output_size, output_size)
        else:
            assert len(output_size) == 2
            self.output_size = output_size

    def __call__(self, sample):
        image, landmarks = sample['image'], sample['landmarks']

        h, w = image.shape[:2]
        new_h, new_w = self.output_size

        top = np.random.randint(0, h - new_h + 1)
        left = np.random.randint(0, w - new_w + 1)

        image = image[top: top + new_h,
                      left: left + new_w]

        landmarks = landmarks - [left, top]

        return {'image': image, 'landmarks': landmarks}


class ToTensor(object):
    """Convert ndarrays in sample to Tensors."""

    def __call__(self, sample):
        image, landmarks = sample['image'], sample['landmarks']

        # swap color axis because
        # numpy image: H x W x C
        # torch image: C x H x W
        image = image.transpose((2, 0, 1))
        return {'image': torch.from_numpy(image),
                'landmarks': torch.from_numpy(landmarks)}

######################################################################
# .. note::
#     In the example above, ``RandomCrop`` uses an external library's random
#     number generator (in this case, NumPy's ``np.random.randint``). This can
#     result in unexpected behavior with ``DataLoader``
#     (see `here <https://pytorch.org/docs/stable/notes/faq.html#my-data-loader-workers-return-identical-random-numbers>`_).
#     In practice, it is safer to stick to PyTorch's random number generator,
#     e.g. by using ``torch.randint`` instead.
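
######################################################################
# As a minimal sketch of that advice (not part of the original tutorial),
# ``RandomCrop.__call__`` could draw its offsets from PyTorch's generator
# instead; the rest of the class would stay unchanged:
#
# .. code-block:: python
#
#     # draw crop offsets from PyTorch's RNG, which DataLoader seeds per worker
#     top = torch.randint(0, h - new_h + 1, size=(1,)).item()
#     left = torch.randint(0, w - new_w + 1, size=(1,)).item()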
######################################################################
# Compose transforms
# ~~~~~~~~~~~~~~~~~~
#
# Now, we apply the transforms to a sample.
#
# Let's say we want to rescale the shorter side of the image to 256 and
# then randomly crop a square of size 224 from it. That is, we want to
# compose the ``Rescale`` and ``RandomCrop`` transforms.
# ``torchvision.transforms.Compose`` is a simple callable class which allows us
# to do this.
#

scale = Rescale(256)
crop = RandomCrop(128)
composed = transforms.Compose([Rescale(256),
                               RandomCrop(224)])

# Apply each of the above transforms to the sample.
fig = plt.figure()
sample = face_dataset[65]
for i, tsfrm in enumerate([scale, crop, composed]):
    transformed_sample = tsfrm(sample)

    ax = plt.subplot(1, 3, i + 1)
    plt.tight_layout()
    ax.set_title(type(tsfrm).__name__)
    show_landmarks(**transformed_sample)

plt.show()


######################################################################
# Iterating through the dataset
# -----------------------------
#
# Let's put this all together to create a dataset with composed
# transforms.
# To summarize, every time this dataset is sampled:
#
# - An image is read from the file on the fly
# - Transforms are applied to the read image
# - Since one of the transforms is random, data is augmented on
#   sampling
#
# We can iterate over the created dataset with a simple ``for`` loop
# as before.
#

transformed_dataset = FaceLandmarksDataset(csv_file='data/faces/face_landmarks.csv',
                                           root_dir='data/faces/',
                                           transform=transforms.Compose([
                                               Rescale(256),
                                               RandomCrop(224),
                                               ToTensor()
                                           ]))

for i, sample in enumerate(transformed_dataset):
    print(i, sample['image'].size(), sample['landmarks'].size())

    if i == 3:
        break


######################################################################
# However, we are losing a lot of features by using a simple ``for`` loop to
# iterate over the data. In particular, we are missing out on:
#
# - Batching the data
# - Shuffling the data
# - Loading the data in parallel using ``multiprocessing`` workers
#
# ``torch.utils.data.DataLoader`` is an iterable which provides all these
# features. The parameters used below should be clear. One parameter of
# interest is ``collate_fn``. You can specify how exactly the samples need
# to be batched using ``collate_fn``. However, the default collate should
# work fine for most use cases.
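#
# For illustration only (``landmarks_collate`` is a hypothetical helper, not
# part of the original tutorial), a custom ``collate_fn`` for this dataset's
# dict samples might look like this:
#
# .. code-block:: python
#
#     def landmarks_collate(batch):
#         # stack each sample's tensors into batched tensors
#         return {'image': torch.stack([s['image'] for s in batch]),
#                 'landmarks': torch.stack([s['landmarks'] for s in batch])}
#
# It would be passed via ``DataLoader(..., collate_fn=landmarks_collate)``,
# and mirrors what the default collate already does for dict samples.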
#

dataloader = DataLoader(transformed_dataset, batch_size=4,
                        shuffle=True, num_workers=0)


# Helper function to show a batch
def show_landmarks_batch(sample_batched):
    """Show image with landmarks for a batch of samples."""
    images_batch, landmarks_batch = \
            sample_batched['image'], sample_batched['landmarks']
    batch_size = len(images_batch)
    im_size = images_batch.size(2)
    grid_border_size = 2

    grid = utils.make_grid(images_batch)
    plt.imshow(grid.numpy().transpose((1, 2, 0)))

    for i in range(batch_size):
        plt.scatter(landmarks_batch[i, :, 0].numpy() + i * im_size + (i + 1) * grid_border_size,
                    landmarks_batch[i, :, 1].numpy() + grid_border_size,
                    s=10, marker='.', c='r')

    plt.title('Batch from dataloader')

# If you are using Windows, uncomment the next line and indent the for loop.
# You might also need to keep ``num_workers`` at 0.

# if __name__ == '__main__':
for i_batch, sample_batched in enumerate(dataloader):
    print(i_batch, sample_batched['image'].size(),
          sample_batched['landmarks'].size())

    # observe the 4th batch and stop.
    if i_batch == 3:
        plt.figure()
        show_landmarks_batch(sample_batched)
        plt.axis('off')
        plt.ioff()
        plt.show()
        break

######################################################################
# Afterword: torchvision
# ----------------------
#
# In this tutorial, we have seen how to write and use datasets, transforms,
# and the dataloader. The ``torchvision`` package provides some common
# datasets and transforms. You might not even have to write custom classes.
# One of the more generic datasets available in torchvision is
# ``ImageFolder``.
# It assumes that images are organized in the following way:
#
# .. code-block:: sh
#
#     root/ants/xxx.png
#     root/ants/xxy.jpeg
#     root/ants/xxz.png
#     .
#     .
#     .
#     root/bees/123.jpg
#     root/bees/nsdf3.png
#     root/bees/asd932_.png
#
# where 'ants', 'bees' etc. are class labels. Similarly, generic transforms
# which operate on ``PIL.Image``, such as ``RandomHorizontalFlip`` and
# ``Resize``, are also available. You can use these to write a dataloader
# like this:
#
# .. code-block:: python
#
#     import torch
#     from torchvision import transforms, datasets
#
#     data_transform = transforms.Compose([
#         transforms.RandomResizedCrop(224),
#         transforms.RandomHorizontalFlip(),
#         transforms.ToTensor(),
#         transforms.Normalize(mean=[0.485, 0.456, 0.406],
#                              std=[0.229, 0.224, 0.225])
#     ])
#     hymenoptera_dataset = datasets.ImageFolder(root='hymenoptera_data/train',
#                                                transform=data_transform)
#     dataset_loader = torch.utils.data.DataLoader(hymenoptera_dataset,
#                                                  batch_size=4, shuffle=True,
#                                                  num_workers=4)
#
# For an example with training code, please see
# :doc:`transfer_learning_tutorial`.