"""
`Learn the Basics <intro.html>`_ ||
`Quickstart <quickstart_tutorial.html>`_ ||
`Tensors <tensorqs_tutorial.html>`_ ||
**Datasets & DataLoaders** ||
`Transforms <transforms_tutorial.html>`_ ||
`Build Model <buildmodel_tutorial.html>`_ ||
`Autograd <autogradqs_tutorial.html>`_ ||
`Optimization <optimization_tutorial.html>`_ ||
`Save & Load Model <saveloadrun_tutorial.html>`_

Datasets & DataLoaders
======================

"""

#################################################################
# Code for processing data samples can get messy and hard to maintain; we ideally want our dataset code
# to be decoupled from our model training code for better readability and modularity.
# PyTorch provides two data primitives: ``torch.utils.data.DataLoader`` and ``torch.utils.data.Dataset``
# that allow you to use pre-loaded datasets as well as your own data.
# ``Dataset`` stores the samples and their corresponding labels, and ``DataLoader`` wraps an iterable around
# the ``Dataset`` to enable easy access to the samples.
#
# PyTorch domain libraries provide a number of pre-loaded datasets (such as FashionMNIST) that
# subclass ``torch.utils.data.Dataset`` and implement functions specific to the particular data.
# They can be used to prototype and benchmark your model. You can find them
# here: `Image Datasets <https://pytorch.org/vision/stable/datasets.html>`_,
# `Text Datasets <https://pytorch.org/text/stable/datasets.html>`_, and
# `Audio Datasets <https://pytorch.org/audio/stable/datasets.html>`_.
#

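######################################################################
# To make the relationship between the two primitives concrete before loading
# a real dataset, here is a minimal sketch (not part of the original tutorial)
# of a tiny ``Dataset`` over in-memory Python lists, wrapped by a ``DataLoader``
# that batches it. The toy class and values are made up for illustration.

import torch
from torch.utils.data import Dataset, DataLoader


class ToyDataset(Dataset):
    """Pairs each feature in a Python list with its label."""

    def __init__(self, features, labels):
        self.features = features
        self.labels = labels

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        return self.features[idx], self.labels[idx]


toy_data = ToyDataset(features=[1.0, 2.0, 3.0, 4.0], labels=[0, 1, 0, 1])
# The DataLoader collates individual samples into batched tensors.
for batch_features, batch_labels in DataLoader(toy_data, batch_size=2):
    print(batch_features, batch_labels)
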
############################################################
# Loading a Dataset
# -------------------
#
# Here is an example of how to load the `Fashion-MNIST <https://research.zalando.com/project/fashion_mnist/fashion_mnist/>`_ dataset from TorchVision.
# Fashion-MNIST is a dataset of Zalando’s article images consisting of 60,000 training examples and 10,000 test examples.
# Each example comprises a 28×28 grayscale image and an associated label from one of 10 classes.
#
# We load the `FashionMNIST Dataset <https://pytorch.org/vision/stable/datasets.html#fashion-mnist>`_ with the following parameters:
#
#  - ``root`` is the path where the train/test data is stored,
#  - ``train`` specifies training or test dataset,
#  - ``download=True`` downloads the data from the internet if it's not available at ``root``,
#  - ``transform`` and ``target_transform`` specify the feature and label transformations.

import torch
from torch.utils.data import Dataset
from torchvision import datasets
from torchvision.transforms import ToTensor
import matplotlib.pyplot as plt


training_data = datasets.FashionMNIST(
    root="data",
    train=True,
    download=True,
    transform=ToTensor()
)

test_data = datasets.FashionMNIST(
    root="data",
    train=False,
    download=True,
    transform=ToTensor()
)

#################################################################
# Iterating and Visualizing the Dataset
# -------------------------------------
#
# We can index ``Datasets`` manually like a list: ``training_data[index]``.
# We use ``matplotlib`` to visualize some samples in our training data.

labels_map = {
    0: "T-Shirt",
    1: "Trouser",
    2: "Pullover",
    3: "Dress",
    4: "Coat",
    5: "Sandal",
    6: "Shirt",
    7: "Sneaker",
    8: "Bag",
    9: "Ankle Boot",
}
figure = plt.figure(figsize=(8, 8))
cols, rows = 3, 3
for i in range(1, cols * rows + 1):
    sample_idx = torch.randint(len(training_data), size=(1,)).item()
    img, label = training_data[sample_idx]
    figure.add_subplot(rows, cols, i)
    plt.title(labels_map[label])
    plt.axis("off")
    plt.imshow(img.squeeze(), cmap="gray")
plt.show()

#################################################################
# ..
#  .. figure:: /_static/img/basics/fashion_mnist.png
#    :alt: fashion_mnist


######################################################################
# --------------
#

#################################################################
# Creating a Custom Dataset for your files
# ---------------------------------------------------
#
# A custom Dataset class must implement three functions: `__init__`, `__len__`, and `__getitem__`.
# Take a look at this implementation; the FashionMNIST images are stored
# in a directory ``img_dir``, and their labels are stored separately in a CSV file ``annotations_file``.
#
# In the next sections, we'll break down what's happening in each of these functions.

import os
import pandas as pd
from torchvision.io import read_image


class CustomImageDataset(Dataset):
    def __init__(self, annotations_file, img_dir, transform=None, target_transform=None):
        self.img_labels = pd.read_csv(annotations_file)
        self.img_dir = img_dir
        self.transform = transform
        self.target_transform = target_transform

    def __len__(self):
        return len(self.img_labels)

    def __getitem__(self, idx):
        img_path = os.path.join(self.img_dir, self.img_labels.iloc[idx, 0])
        image = read_image(img_path)
        label = self.img_labels.iloc[idx, 1]
        if self.transform:
            image = self.transform(image)
        if self.target_transform:
            label = self.target_transform(label)
        return image, label

#################################################################
# ``__init__``
# ^^^^^^^^^^^^^^^^^^^^
#
# The __init__ function is run once when instantiating the Dataset object. We initialize
# the directory containing the images, the annotations file, and both transforms (covered
# in more detail in the next section).
#
# The labels.csv file looks like: ::
#
#     tshirt1.jpg, 0
#     tshirt2.jpg, 0
#     ......
#     ankleboot999.jpg, 9

def __init__(self, annotations_file, img_dir, transform=None, target_transform=None):
    self.img_labels = pd.read_csv(annotations_file)
    self.img_dir = img_dir
    self.transform = transform
    self.target_transform = target_transform

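######################################################################
# As an aside (not part of the original tutorial), an annotations file in the
# format shown above could be generated with ``pandas``; the file name and the
# rows below are made up purely for illustration.

# Write a two-row labels CSV in the "<file name>, <label>" format shown above.
example_annotations = pd.DataFrame({"file_name": ["tshirt1.jpg", "tshirt2.jpg"], "label": [0, 0]})
example_annotations.to_csv("example_labels.csv", index=False, header=False)
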
#################################################################
# ``__len__``
# ^^^^^^^^^^^^^^^^^^^^
#
# The __len__ function returns the number of samples in our dataset.
#
# Example:


def __len__(self):
    return len(self.img_labels)

#################################################################
# ``__getitem__``
# ^^^^^^^^^^^^^^^^^^^^
#
# The __getitem__ function loads and returns a sample from the dataset at the given index ``idx``.
# Based on the index, it identifies the image's location on disk, converts that to a tensor using ``read_image``, retrieves the
# corresponding label from the csv data in ``self.img_labels``, calls the transform functions on them (if applicable), and returns the
# tensor image and corresponding label in a tuple.

def __getitem__(self, idx):
    img_path = os.path.join(self.img_dir, self.img_labels.iloc[idx, 0])
    image = read_image(img_path)
    label = self.img_labels.iloc[idx, 1]
    if self.transform:
        image = self.transform(image)
    if self.target_transform:
        label = self.target_transform(label)
    return image, label

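######################################################################
# To see all three functions working together, here is a self-contained smoke
# test of ``CustomImageDataset`` (not part of the original tutorial): it
# fabricates a tiny image directory and a matching annotations file, then reads
# one sample back. The file and directory names are made up for illustration.

from torchvision.io import write_png

os.makedirs("toy_imgs", exist_ok=True)
toy_rows = [("toy0.png", 0), ("toy1.png", 9)]
for file_name, label_value in toy_rows:
    # Save a random 1x28x28 grayscale image into the toy image directory.
    write_png(torch.randint(0, 256, (1, 28, 28), dtype=torch.uint8),
              os.path.join("toy_imgs", file_name))
pd.DataFrame(toy_rows).to_csv("toy_labels.csv", index=False, header=False)

toy_dataset = CustomImageDataset(annotations_file="toy_labels.csv", img_dir="toy_imgs")
toy_image, toy_label = toy_dataset[0]
print(toy_image.shape, toy_label)
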
######################################################################
# --------------
#


#################################################################
# Preparing your data for training with DataLoaders
# -------------------------------------------------
# The ``Dataset`` retrieves our dataset's features and labels one sample at a time. While training a model, we typically want to
# pass samples in "minibatches", reshuffle the data at every epoch to reduce model overfitting, and use Python's ``multiprocessing`` to
# speed up data retrieval.
#
# ``DataLoader`` is an iterable that abstracts this complexity for us in an easy API.

from torch.utils.data import DataLoader

train_dataloader = DataLoader(training_data, batch_size=64, shuffle=True)
test_dataloader = DataLoader(test_data, batch_size=64, shuffle=True)

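######################################################################
# The ``multiprocessing`` mentioned above is enabled through the ``num_workers``
# argument of ``DataLoader``. As a sketch only: the worker count below is an
# arbitrary example, not a recommendation, and the variable name is made up.

# Worker processes load samples in parallel; they start when iteration begins.
train_dataloader_mp = DataLoader(training_data, batch_size=64, shuffle=True, num_workers=2)
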
###########################
# Iterate through the DataLoader
# -------------------------------
#
# We have loaded the dataset into the ``DataLoader`` and can iterate through the dataset as needed.
# Each iteration below returns a batch of ``train_features`` and ``train_labels`` (containing ``batch_size=64`` features and labels respectively).
# Because we specified ``shuffle=True``, after we iterate over all batches the data is shuffled (for finer-grained control over
# the data loading order, take a look at `Samplers <https://pytorch.org/docs/stable/data.html#data-loading-order-and-sampler>`_).

# Display image and label.
train_features, train_labels = next(iter(train_dataloader))
print(f"Feature batch shape: {train_features.size()}")
print(f"Labels batch shape: {train_labels.size()}")
img = train_features[0].squeeze()
label = train_labels[0]
plt.imshow(img, cmap="gray")
plt.show()
print(f"Label: {label}")

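######################################################################
# For context only (training is covered properly in the `Optimization
# <optimization_tutorial.html>`_ tutorial), the sketch below shows where the
# ``DataLoader`` sits in a training loop. The tiny model, loss, and optimizer
# here are placeholders invented for illustration, not part of this tutorial.

from torch import nn

placeholder_model = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 10))
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(placeholder_model.parameters(), lr=1e-3)

for batch, (X, y) in enumerate(train_dataloader):
    # Forward pass, loss, and one optimization step per batch.
    pred = placeholder_model(X)
    loss = loss_fn(pred, y)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    if batch == 2:
        # A few batches are enough for this sketch.
        break
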
######################################################################
# --------------
#


#################################################################
# Further Reading
# ----------------
# - `torch.utils.data API <https://pytorch.org/docs/stable/data.html>`_