CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutSign UpSign In
pytorch

Real-time collaboration for Jupyter Notebooks, Linux Terminals, LaTeX, VS Code, R IDE, and more,
all in one place.

GitHub Repository: pytorch/tutorials
Path: blob/main/beginner_source/data_loading_tutorial.py
Views: 712
1
# -*- coding: utf-8 -*-
2
"""
3
Writing Custom Datasets, DataLoaders and Transforms
4
===================================================
5
**Author**: `Sasank Chilamkurthy <https://chsasank.github.io>`_
6
7
A lot of effort in solving any machine learning problem goes into
8
preparing the data. PyTorch provides many tools to make data loading
9
easy and hopefully, to make your code more readable. In this tutorial,
10
we will see how to load and preprocess/augment data from a non-trivial
11
dataset.
12
13
To run this tutorial, please make sure the following packages are
14
installed:
15
16
- ``scikit-image``: For image io and transforms
17
- ``pandas``: For easier csv parsing
18
19
"""
20
21
import os
22
import torch
23
import pandas as pd
24
from skimage import io, transform
25
import numpy as np
26
import matplotlib.pyplot as plt
27
from torch.utils.data import Dataset, DataLoader
28
from torchvision import transforms, utils
29
30
# Silence all warnings so the tutorial output stays readable.
import warnings

warnings.simplefilter("ignore")

# Turn on matplotlib's interactive mode.
plt.ion()
35
36
######################################################################
37
# The dataset we are going to deal with is that of facial pose.
38
# This means that a face is annotated like this:
39
#
40
# .. figure:: /_static/img/landmarked_face2.png
41
# :width: 400
42
#
43
# Overall, 68 different landmark points are annotated for each face.
44
#
45
# .. note::
46
# Download the dataset from `here <https://download.pytorch.org/tutorial/faces.zip>`_
47
# so that the images are in a directory named 'data/faces/'.
48
# This dataset was actually
49
# generated by applying excellent `dlib's pose
50
# estimation <https://blog.dlib.net/2014/08/real-time-face-pose-estimation.html>`__
51
# on a few images from imagenet tagged as 'face'.
52
#
53
# The dataset comes with a ``.csv`` annotations file that looks like this:
54
#
55
# .. code-block:: sh
56
#
57
# image_name,part_0_x,part_0_y,part_1_x,part_1_y,part_2_x, ... ,part_67_x,part_67_y
58
# 0805personali01.jpg,27,83,27,98, ... 84,134
59
# 1084239450_e76e00b7e7.jpg,70,236,71,257, ... ,128,312
60
#
61
# Let's take a single image name and its annotations from the CSV, in this case row index number 65
62
# for person-7.jpg just as an example. Read it, store the image name in ``img_name`` and store its
63
# annotations in an (L, 2) array ``landmarks`` where L is the number of landmarks in that row.
64
#
65
66
# Parse the annotation CSV and pull out a single example row.
landmarks_frame = pd.read_csv('data/faces/face_landmarks.csv')

n = 65
img_name = landmarks_frame.iloc[n, 0]
# Remaining columns alternate x, y -> fold them into an (L, 2) array.
landmarks = (
    np.asarray(landmarks_frame.iloc[n, 1:], dtype=float)
    .reshape(-1, 2)
)

print(f'Image name: {img_name}')
print(f'Landmarks shape: {landmarks.shape}')
print(f'First 4 Landmarks: {landmarks[:4]}')
76
77
78
######################################################################
79
# Let's write a simple helper function to show an image and its landmarks
80
# and use it to show a sample.
81
#
82
83
def show_landmarks(image, landmarks):
    """Draw ``image`` and overlay its ``landmarks`` as small red dots.

    Args:
        image: image array passed straight to ``plt.imshow``.
        landmarks: 2-D array whose columns 0 and 1 are the x and y
            coordinates of each point.
    """
    xs, ys = landmarks[:, 0], landmarks[:, 1]
    plt.imshow(image)
    plt.scatter(xs, ys, s=10, marker='.', c='r')
    # A tiny pause gives matplotlib a chance to redraw the figure.
    plt.pause(0.001)
88
89
# Display the example row read above together with its landmarks.
plt.figure()
sample_image = io.imread(os.path.join('data/faces/', img_name))
show_landmarks(sample_image, landmarks)
plt.show()
93
94
95
######################################################################
96
# Dataset class
97
# -------------
98
#
99
# ``torch.utils.data.Dataset`` is an abstract class representing a
100
# dataset.
101
# Your custom dataset should inherit ``Dataset`` and override the following
102
# methods:
103
#
104
# - ``__len__`` so that ``len(dataset)`` returns the size of the dataset.
105
# - ``__getitem__`` to support the indexing such that ``dataset[i]`` can
106
# be used to get :math:`i`\ th sample.
107
#
108
# Let's create a dataset class for our face landmarks dataset. We will
109
# read the csv in ``__init__`` but leave the reading of images to
110
# ``__getitem__``. This is memory efficient because all the images are not
111
# stored in the memory at once but read as required.
112
#
113
# Each sample of our dataset will be a dict
114
# ``{'image': image, 'landmarks': landmarks}``. Our dataset will take an
115
# optional argument ``transform`` so that any required processing can be
116
# applied on the sample. We will see the usefulness of ``transform`` in the
117
# next section.
118
#
119
120
class FaceLandmarksDataset(Dataset):
    """Face Landmarks dataset.

    Each sample is a dict ``{'image': image, 'landmarks': landmarks}`` where
    ``landmarks`` is a float array of shape ``(L, 2)`` holding the (x, y)
    coordinates read from one CSV row. If a ``transform`` was given, the
    sample returned is whatever that transform produces instead.
    """

    def __init__(self, csv_file, root_dir, transform=None):
        """
        Arguments:
            csv_file (string or file-like): CSV with annotations. The first
                column is the image file name; the remaining columns are
                alternating x, y landmark coordinates.
            root_dir (string): Directory with all the images.
            transform (callable, optional): Optional transform to be applied
                on a sample.
        """
        self.landmarks_frame = pd.read_csv(csv_file)
        self.root_dir = root_dir
        self.transform = transform

    def __len__(self):
        return len(self.landmarks_frame)

    def _read_landmarks(self, idx):
        """Return the landmarks of row ``idx`` as a float ``(L, 2)`` array."""
        coords = self.landmarks_frame.iloc[idx, 1:]
        return np.asarray(coords, dtype=float).reshape(-1, 2)

    def __getitem__(self, idx):
        # Convert a tensor index to plain Python so pandas can index with it.
        if torch.is_tensor(idx):
            idx = idx.tolist()

        img_name = os.path.join(self.root_dir,
                                self.landmarks_frame.iloc[idx, 0])
        # Images are read lazily here rather than all at once in __init__.
        image = io.imread(img_name)
        landmarks = self._read_landmarks(idx)
        sample = {'image': image, 'landmarks': landmarks}

        if self.transform is not None:
            sample = self.transform(sample)

        return sample
153
154
155
######################################################################
156
# Let's instantiate this class and iterate through the data samples. We
157
# will print the sizes of first 4 samples and show their landmarks.
158
#
159
160
face_dataset = FaceLandmarksDataset(
    csv_file='data/faces/face_landmarks.csv',
    root_dir='data/faces/',
)

fig = plt.figure()

# Plot the first four samples side by side, printing their array shapes.
for i, sample in enumerate(face_dataset):
    img_shape, pts_shape = sample['image'].shape, sample['landmarks'].shape
    print(i, img_shape, pts_shape)

    ax = plt.subplot(1, 4, i + 1)
    plt.tight_layout()
    ax.set_title(f'Sample #{i}')
    ax.axis('off')
    show_landmarks(**sample)

    if i == 3:
        plt.show()
        break
177
178
179
######################################################################
180
# Transforms
181
# ----------
182
#
183
# One issue we can see from the above is that the samples are not of the
184
# same size. Most neural networks expect images of a fixed size.
185
# Therefore, we will need to write some preprocessing code.
186
# Let's create three transforms:
187
#
188
# - ``Rescale``: to scale the image
189
# - ``RandomCrop``: to crop from image randomly. This is data
190
# augmentation.
191
# - ``ToTensor``: to convert the numpy images to torch images (we need to
192
# swap axes).
193
#
194
# We will write them as callable classes instead of simple functions so
195
# that parameters of the transform need not be passed every time it's
196
# called. For this, we just need to implement ``__call__`` method and
197
# if required, ``__init__`` method. We can then use a transform like this:
198
#
199
# .. code-block:: python
200
#
201
# tsfm = Transform(params)
202
# transformed_sample = tsfm(sample)
203
#
204
# Observe below how these transforms had to be applied both on the image and
205
# landmarks.
206
#
207
208
class Rescale(object):
    """Resize the image of a sample and scale its landmarks to match.

    Args:
        output_size (tuple or int): Target size. A tuple is used as the
            exact ``(height, width)``. An int becomes the length of the
            shorter image edge, and the longer edge is scaled so the aspect
            ratio is kept.
    """

    def __init__(self, output_size):
        assert isinstance(output_size, (int, tuple))
        self.output_size = output_size

    def __call__(self, sample):
        image = sample['image']
        landmarks = sample['landmarks']

        h, w = image.shape[:2]
        size = self.output_size
        if not isinstance(size, int):
            new_h, new_w = size
        elif h > w:
            new_h, new_w = size * h / w, size
        else:
            new_h, new_w = size, size * w / h
        new_h, new_w = int(new_h), int(new_w)

        img = transform.resize(image, (new_h, new_w))

        # Landmarks are (x, y) pairs: x runs along image axis 1 (width) and
        # y along axis 0 (height), hence the (w, h) order of the factors.
        scale = [new_w / w, new_h / h]
        return {'image': img, 'landmarks': landmarks * scale}
242
243
244
class RandomCrop(object):
    """Crop randomly the image in a sample.

    Landmarks are shifted into the crop's coordinate frame. They are not
    clipped, so points outside the crop end up with negative or
    out-of-range coordinates.

    Args:
        output_size (tuple or int): Desired output size ``(height, width)``.
            If int, square crop is made.

    Raises:
        ValueError: from ``__call__`` when the crop is larger than the image.
    """

    def __init__(self, output_size):
        assert isinstance(output_size, (int, tuple))
        if isinstance(output_size, int):
            self.output_size = (output_size, output_size)
        else:
            assert len(output_size) == 2
            self.output_size = output_size

    def __call__(self, sample):
        image, landmarks = sample['image'], sample['landmarks']

        h, w = image.shape[:2]
        new_h, new_w = self.output_size

        # Fail with a clear message instead of numpy's "high <= low" error
        # when the requested crop does not fit inside the image.
        if new_h > h or new_w > w:
            raise ValueError(
                'crop size {} is larger than image size {}'.format(
                    (new_h, new_w), (h, w)))

        # randint's upper bound is exclusive, hence the +1 so the crop can
        # touch the bottom/right edge.
        top = np.random.randint(0, h - new_h + 1)
        left = np.random.randint(0, w - new_w + 1)

        image = image[top: top + new_h,
                      left: left + new_w]

        # Landmarks are (x, y): x shifts by the column offset, y by the row.
        landmarks = landmarks - [left, top]

        return {'image': image, 'landmarks': landmarks}
275
276
277
class ToTensor(object):
    """Convert ndarrays in sample to Tensors.

    The image goes from numpy's ``H x W x C`` layout to torch's
    ``C x H x W``. A 2-D (grayscale) ``H x W`` image first gets a single
    channel axis, so it becomes ``1 x H x W``. Landmarks are converted
    as-is.
    """

    def __call__(self, sample):
        image, landmarks = sample['image'], sample['landmarks']

        # Grayscale images have no channel axis; add one so the 3-axis
        # transpose below does not fail.
        if image.ndim == 2:
            image = np.expand_dims(image, axis=2)

        # swap color axis because
        # numpy image: H x W x C
        # torch image: C x H x W
        image = image.transpose((2, 0, 1))
        return {'image': torch.from_numpy(image),
                'landmarks': torch.from_numpy(landmarks)}
289
290
######################################################################
291
# .. note::
292
# In the example above, ``RandomCrop`` uses an external library's random number generator
293
# (in this case, NumPy's ``np.random.randint``). This can result in unexpected behavior with ``DataLoader``
294
# (see `here <https://pytorch.org/docs/stable/notes/faq.html#my-data-loader-workers-return-identical-random-numbers>`_).
295
# In practice, it is safer to stick to PyTorch's random number generator, e.g. by using ``torch.randint`` instead.
296
297
######################################################################
298
# Compose transforms
299
# ~~~~~~~~~~~~~~~~~~
300
#
301
# Now, we apply the transforms on a sample.
302
#
303
# Let's say we want to rescale the shorter side of the image to 256 and
304
# then randomly crop a square of size 224 from it. i.e, we want to compose
305
# ``Rescale`` and ``RandomCrop`` transforms.
306
# ``torchvision.transforms.Compose`` is a simple callable class which allows us
307
# to do this.
308
#
309
310
# Three transforms to compare: a rescale, a random crop, and both chained.
scale = Rescale(256)
crop = RandomCrop(128)
composed = transforms.Compose([
    Rescale(256),
    RandomCrop(224),
])
demo_transforms = [scale, crop, composed]

# Apply each transform to the same sample and plot the results side by side.
fig = plt.figure()
sample = face_dataset[65]
for i, tsfrm in enumerate(demo_transforms):
    transformed_sample = tsfrm(sample)

    ax = plt.subplot(1, len(demo_transforms), i + 1)
    plt.tight_layout()
    ax.set_title(type(tsfrm).__name__)
    show_landmarks(**transformed_sample)

plt.show()
327
328
329
######################################################################
330
# Iterating through the dataset
331
# -----------------------------
332
#
333
# Let's put this all together to create a dataset with composed
334
# transforms.
335
# To summarize, every time this dataset is sampled:
336
#
337
# - An image is read from the file on the fly
338
# - Transforms are applied on the read image
339
# - Since one of the transforms is random, data is augmented on
340
# sampling
341
#
342
# We can iterate over the created dataset with a plain ``for``
343
# loop as before.
344
#
345
346
# Rescale -> random crop -> tensor conversion, applied on every access.
train_transform = transforms.Compose([
    Rescale(256),
    RandomCrop(224),
    ToTensor(),
])
transformed_dataset = FaceLandmarksDataset(
    csv_file='data/faces/face_landmarks.csv',
    root_dir='data/faces/',
    transform=train_transform,
)

# Print the tensor sizes of the first four transformed samples.
for i, sample in enumerate(transformed_dataset):
    print(i, sample['image'].size(), sample['landmarks'].size())
    if i == 3:
        break
359
360
361
######################################################################
362
# However, we are losing a lot of features by using a simple ``for`` loop to
363
# iterate over the data. In particular, we are missing out on:
364
#
365
# - Batching the data
366
# - Shuffling the data
367
# - Load the data in parallel using ``multiprocessing`` workers.
368
#
369
# ``torch.utils.data.DataLoader`` is an iterable which provides all these
370
# features. Parameters used below should be clear. One parameter of
371
# interest is ``collate_fn``. You can specify how exactly the samples need
372
# to be batched using ``collate_fn``. However, default collate should work
373
# fine for most use cases.
374
#
375
376
# Batches of 4 shuffled samples, loaded in the main process.
dataloader = DataLoader(
    transformed_dataset,
    batch_size=4,
    shuffle=True,
    num_workers=0,
)
378
379
380
# Helper function to show a batch
381
def show_landmarks_batch(sample_batched):
382
"""Show image with landmarks for a batch of samples."""
383
images_batch, landmarks_batch = \
384
sample_batched['image'], sample_batched['landmarks']
385
batch_size = len(images_batch)
386
im_size = images_batch.size(2)
387
grid_border_size = 2
388
389
grid = utils.make_grid(images_batch)
390
plt.imshow(grid.numpy().transpose((1, 2, 0)))
391
392
for i in range(batch_size):
393
plt.scatter(landmarks_batch[i, :, 0].numpy() + i * im_size + (i + 1) * grid_border_size,
394
landmarks_batch[i, :, 1].numpy() + grid_border_size,
395
s=10, marker='.', c='r')
396
397
plt.title('Batch from dataloader')
398
399
# On Windows, DataLoader worker processes (``num_workers > 0``) are started
# by re-importing this script, so uncomment the guard on the next line and
# indent the loop under it. With ``num_workers=0`` no guard is needed.

# if __name__ == '__main__':
for i_batch, sample_batched in enumerate(dataloader):
    images, points = sample_batched['image'], sample_batched['landmarks']
    print(i_batch, images.size(), points.size())

    # Plot the 4th batch, then stop iterating.
    if i_batch != 3:
        continue
    plt.figure()
    show_landmarks_batch(sample_batched)
    plt.axis('off')
    plt.ioff()
    plt.show()
    break
415
416
######################################################################
417
# Afterword: torchvision
418
# ----------------------
419
#
420
# In this tutorial, we have seen how to write and use datasets, transforms
421
# and dataloader. ``torchvision`` package provides some common datasets and
422
# transforms. You might not even have to write custom classes. One of the
423
# more generic datasets available in torchvision is ``ImageFolder``.
424
# It assumes that images are organized in the following way:
425
#
426
# .. code-block:: sh
427
#
428
# root/ants/xxx.png
429
# root/ants/xxy.jpeg
430
# root/ants/xxz.png
431
# .
432
# .
433
# .
434
# root/bees/123.jpg
435
# root/bees/nsdf3.png
436
# root/bees/asd932_.png
437
#
438
# where 'ants', 'bees' etc. are class labels. Similarly generic transforms
439
# which operate on ``PIL.Image`` like ``RandomHorizontalFlip``, ``Resize``,
440
# are also available. You can use these to write a dataloader like this:
441
#
442
# .. code-block:: python
443
#
444
# import torch
445
# from torchvision import transforms, datasets
446
#
447
# data_transform = transforms.Compose([
448
# transforms.RandomResizedCrop(224),
449
# transforms.RandomHorizontalFlip(),
450
# transforms.ToTensor(),
451
# transforms.Normalize(mean=[0.485, 0.456, 0.406],
452
# std=[0.229, 0.224, 0.225])
453
# ])
454
# hymenoptera_dataset = datasets.ImageFolder(root='hymenoptera_data/train',
455
# transform=data_transform)
456
# dataset_loader = torch.utils.data.DataLoader(hymenoptera_dataset,
457
# batch_size=4, shuffle=True,
458
# num_workers=4)
459
#
460
# For an example with training code, please see
461
# :doc:`transfer_learning_tutorial`.
462
463