"""1`Learn the Basics <intro.html>`_ ||2`Quickstart <quickstart_tutorial.html>`_ ||3`Tensors <tensorqs_tutorial.html>`_ ||4**Datasets & DataLoaders** ||5`Transforms <transforms_tutorial.html>`_ ||6`Build Model <buildmodel_tutorial.html>`_ ||7`Autograd <autogradqs_tutorial.html>`_ ||8`Optimization <optimization_tutorial.html>`_ ||9`Save & Load Model <saveloadrun_tutorial.html>`_1011Datasets & DataLoaders12======================1314"""1516#################################################################17# Code for processing data samples can get messy and hard to maintain; we ideally want our dataset code18# to be decoupled from our model training code for better readability and modularity.19# PyTorch provides two data primitives: ``torch.utils.data.DataLoader`` and ``torch.utils.data.Dataset``20# that allow you to use pre-loaded datasets as well as your own data.21# ``Dataset`` stores the samples and their corresponding labels, and ``DataLoader`` wraps an iterable around22# the ``Dataset`` to enable easy access to the samples.23#24# PyTorch domain libraries provide a number of pre-loaded datasets (such as FashionMNIST) that25# subclass ``torch.utils.data.Dataset`` and implement functions specific to the particular data.26# They can be used to prototype and benchmark your model. You can find them27# here: `Image Datasets <https://pytorch.org/vision/stable/datasets.html>`_,28# `Text Datasets <https://pytorch.org/text/stable/datasets.html>`_, and29# `Audio Datasets <https://pytorch.org/audio/stable/datasets.html>`_30#3132############################################################33# Loading a Dataset34# -------------------35#36# Here is an example of how to load the `Fashion-MNIST <https://research.zalando.com/project/fashion_mnist/fashion_mnist/>`_ dataset from TorchVision.37# Fashion-MNIST is a dataset of Zalando’s article images consisting of 60,000 training examples and 10,000 test examples.38# Each example comprises a 28×28 grayscale image and an associated label from one of 10 classes.39#40# We load the `FashionMNIST Dataset <https://pytorch.org/vision/stable/datasets.html#fashion-mnist>`_ with the following parameters:41# - ``root`` is the path where the train/test data is stored,42# - ``train`` specifies training or test dataset,43# - ``download=True`` downloads the data from the internet if it's not available at ``root``.44# - ``transform`` and ``target_transform`` specify the feature and label transformations454647import torch48from torch.utils.data import Dataset49from torchvision import datasets50from torchvision.transforms import ToTensor51import matplotlib.pyplot as plt525354training_data = datasets.FashionMNIST(55root="data",56train=True,57download=True,58transform=ToTensor()59)6061test_data = datasets.FashionMNIST(62root="data",63train=False,64download=True,65transform=ToTensor()66)676869#################################################################70# Iterating and Visualizing the Dataset71# -------------------------------------72#73# We can index ``Datasets`` manually like a list: ``training_data[index]``.74# We use ``matplotlib`` to visualize some samples in our training data.7576labels_map = {770: "T-Shirt",781: "Trouser",792: "Pullover",803: "Dress",814: "Coat",825: "Sandal",836: "Shirt",847: "Sneaker",858: "Bag",869: "Ankle Boot",87}88figure = plt.figure(figsize=(8, 8))89cols, rows = 3, 390for i in range(1, cols * rows + 1):91sample_idx = torch.randint(len(training_data), size=(1,)).item()92img, label = training_data[sample_idx]93figure.add_subplot(rows, cols, 
#################################################################
# Iterating and Visualizing the Dataset
# -------------------------------------
#
# We can index ``Datasets`` manually like a list: ``training_data[index]``.
# We use ``matplotlib`` to visualize some samples in our training data.

labels_map = {
    0: "T-Shirt",
    1: "Trouser",
    2: "Pullover",
    3: "Dress",
    4: "Coat",
    5: "Sandal",
    6: "Shirt",
    7: "Sneaker",
    8: "Bag",
    9: "Ankle Boot",
}
figure = plt.figure(figsize=(8, 8))
cols, rows = 3, 3
for i in range(1, cols * rows + 1):
    sample_idx = torch.randint(len(training_data), size=(1,)).item()
    img, label = training_data[sample_idx]
    figure.add_subplot(rows, cols, i)
    plt.title(labels_map[label])
    plt.axis("off")
    plt.imshow(img.squeeze(), cmap="gray")
plt.show()

#################################################################
# .. figure:: /_static/img/basics/fashion_mnist.png
#    :alt: fashion_mnist


######################################################################
# --------------
#


#################################################################
# Creating a Custom Dataset for your files
# ----------------------------------------
#
# A custom ``Dataset`` class must implement three functions: ``__init__``, ``__len__``, and ``__getitem__``.
# Take a look at this implementation; the FashionMNIST images are stored
# in a directory ``img_dir``, and their labels are stored separately in a CSV file ``annotations_file``.
#
# In the next sections, we'll break down what's happening in each of these functions.


import os
import pandas as pd
from torchvision.io import read_image

class CustomImageDataset(Dataset):
    def __init__(self, annotations_file, img_dir, transform=None, target_transform=None):
        self.img_labels = pd.read_csv(annotations_file)
        self.img_dir = img_dir
        self.transform = transform
        self.target_transform = target_transform

    def __len__(self):
        return len(self.img_labels)

    def __getitem__(self, idx):
        img_path = os.path.join(self.img_dir, self.img_labels.iloc[idx, 0])
        image = read_image(img_path)
        label = self.img_labels.iloc[idx, 1]
        if self.transform:
            image = self.transform(image)
        if self.target_transform:
            label = self.target_transform(label)
        return image, label
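#################################################################
# A minimal usage sketch (not part of the files downloaded by this tutorial):
# assuming a hypothetical ``annotations.csv`` file and ``images/`` directory laid
# out as described above, the class could be used like this. The existence check
# keeps the script runnable even though those placeholder files are not created here.

if os.path.isfile("annotations.csv") and os.path.isdir("images"):
    custom_data = CustomImageDataset(annotations_file="annotations.csv", img_dir="images")
    print(f"Custom dataset size: {len(custom_data)}")
    first_image, first_label = custom_data[0]
    print(f"First sample: image shape {tuple(first_image.shape)}, label {first_label}")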
#################################################################
# ``__init__``
# ^^^^^^^^^^^^
#
# The ``__init__`` function is run once when instantiating the Dataset object. We initialize
# the directory containing the images, the annotations file, and both transforms (covered
# in more detail in the next section).
#
# The labels.csv file looks like: ::
#
#     tshirt1.jpg, 0
#     tshirt2.jpg, 0
#     ......
#     ankleboot999.jpg, 9


def __init__(self, annotations_file, img_dir, transform=None, target_transform=None):
    self.img_labels = pd.read_csv(annotations_file)
    self.img_dir = img_dir
    self.transform = transform
    self.target_transform = target_transform


#################################################################
# ``__len__``
# ^^^^^^^^^^^
#
# The ``__len__`` function returns the number of samples in our dataset.
#
# Example:


def __len__(self):
    return len(self.img_labels)


#################################################################
# ``__getitem__``
# ^^^^^^^^^^^^^^^
#
# The ``__getitem__`` function loads and returns a sample from the dataset at the given index ``idx``.
# Based on the index, it identifies the image's location on disk, converts that to a tensor using ``read_image``, retrieves the
# corresponding label from the csv data in ``self.img_labels``, calls the transform functions on them (if applicable), and returns the
# tensor image and corresponding label in a tuple.


def __getitem__(self, idx):
    img_path = os.path.join(self.img_dir, self.img_labels.iloc[idx, 0])
    image = read_image(img_path)
    label = self.img_labels.iloc[idx, 1]
    if self.transform:
        image = self.transform(image)
    if self.target_transform:
        label = self.target_transform(label)
    return image, label


######################################################################
# --------------
#


#################################################################
# Preparing your data for training with DataLoaders
# -------------------------------------------------
#
# The ``Dataset`` retrieves our dataset's features and labels one sample at a time. While training a model, we typically want to
# pass samples in "minibatches", reshuffle the data at every epoch to reduce model overfitting, and use Python's ``multiprocessing`` to
# speed up data retrieval.
#
# ``DataLoader`` is an iterable that abstracts this complexity for us in an easy API.

from torch.utils.data import DataLoader

train_dataloader = DataLoader(training_data, batch_size=64, shuffle=True)
test_dataloader = DataLoader(test_data, batch_size=64, shuffle=True)
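#################################################################
# The ``multiprocessing`` mentioned above is controlled by the ``num_workers``
# argument of ``DataLoader``. The line below is an illustrative sketch only: the
# worker count of 2 and the variable name are arbitrary choices, and the default
# ``num_workers=0`` loads data in the main process.

# Spawn two worker processes to load batches in parallel.
train_dataloader_workers = DataLoader(training_data, batch_size=64, shuffle=True, num_workers=2)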
###########################
# Iterate through the DataLoader
# -------------------------------
#
# We have loaded that dataset into the ``DataLoader`` and can iterate through the dataset as needed.
# Each iteration below returns a batch of ``train_features`` and ``train_labels`` (containing ``batch_size=64`` features and labels, respectively).
# Because we specified ``shuffle=True``, the data is reshuffled after we iterate over all batches (for finer-grained control over
# the data loading order, take a look at `Samplers <https://pytorch.org/docs/stable/data.html#data-loading-order-and-sampler>`_).

# Display image and label.
train_features, train_labels = next(iter(train_dataloader))
print(f"Feature batch shape: {train_features.size()}")
print(f"Labels batch shape: {train_labels.size()}")
img = train_features[0].squeeze()
label = train_labels[0]
plt.imshow(img, cmap="gray")
plt.show()
print(f"Label: {label}")

######################################################################
# --------------
#

#################################################################
# Further Reading
# ----------------
#
# - `torch.utils.data API <https://pytorch.org/docs/stable/data.html>`_