# Source: GitHub repository pytorch/tutorials
# Path: blob/main/beginner_source/basics/transforms_tutorial.py
"""
`Learn the Basics <intro.html>`_ ||
`Quickstart <quickstart_tutorial.html>`_ ||
`Tensors <tensorqs_tutorial.html>`_ ||
`Datasets & DataLoaders <data_tutorial.html>`_ ||
**Transforms** ||
`Build Model <buildmodel_tutorial.html>`_ ||
`Autograd <autogradqs_tutorial.html>`_ ||
`Optimization <optimization_tutorial.html>`_ ||
`Save & Load Model <saveloadrun_tutorial.html>`_

Transforms
===================

Data does not always come in the final processed form required for
training machine learning algorithms. We use **transforms** to manipulate
the data and make it suitable for training.

All TorchVision datasets have two parameters, ``transform`` to modify the features and
``target_transform`` to modify the labels, that accept callables containing the transformation logic.
The `torchvision.transforms <https://pytorch.org/vision/stable/transforms.html>`_ module offers
several commonly used transforms out of the box.

The FashionMNIST features are in PIL Image format, and the labels are integers.
For training, we need the features as normalized tensors, and the labels as one-hot encoded tensors.
To make these transformations, we use ``ToTensor`` and ``Lambda``.
"""

import torch
from torchvision import datasets
from torchvision.transforms import ToTensor, Lambda

ds = datasets.FashionMNIST(
    root="data",
    train=True,
    download=True,
    transform=ToTensor(),
    target_transform=Lambda(lambda y: torch.zeros(10, dtype=torch.float).scatter_(0, torch.tensor(y), value=1))
)
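
######################################################################
# As a rough illustration of how a dataset uses these two callables, the
# sketch below defines a small in-memory dataset. ``ToyDataset`` and its data
# are hypothetical and are not part of TorchVision; the assumption is that
# the built-in datasets follow a similar pattern, applying ``transform`` to
# the feature and ``target_transform`` to the label when a sample is accessed.

import torch
from torch.utils.data import Dataset


class ToyDataset(Dataset):
    """Hypothetical dataset that applies optional transforms on access."""

    def __init__(self, features, labels, transform=None, target_transform=None):
        self.features = features
        self.labels = labels
        self.transform = transform
        self.target_transform = target_transform

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        x, y = self.features[idx], self.labels[idx]
        if self.transform is not None:
            x = self.transform(x)  # modify the feature
        if self.target_transform is not None:
            y = self.target_transform(y)  # modify the label
        return x, y


toy = ToyDataset(
    features=[[0, 255], [128, 64]],
    labels=[0, 1],
    transform=lambda x: torch.tensor(x, dtype=torch.float) / 255,  # scale to [0, 1]
    target_transform=lambda y: y + 10,  # arbitrary label mapping, for illustration only
)
x0, y0 = toy[0]
print(x0, y0)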

#################################################
# ToTensor()
# -------------------------------
#
# `ToTensor <https://pytorch.org/vision/stable/transforms.html#torchvision.transforms.ToTensor>`_
# converts a PIL image or NumPy ``ndarray`` into a ``FloatTensor`` and scales
# the image's pixel intensity values to the range [0., 1.].
#
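# As a small check of this behavior, the sketch below applies ``ToTensor`` to a
# hand-made 2x2 single-channel ``uint8`` array. The array values and the names
# ``pixels`` and ``scaled`` are chosen for illustration only. The result should
# be a ``float32`` tensor in channels-first layout with values in [0., 1.].

import numpy as np
from torchvision.transforms import ToTensor

# Height x width x channels layout, with raw intensities in [0, 255].
pixels = np.array([[[0], [255]], [[128], [64]]], dtype=np.uint8)
scaled = ToTensor()(pixels)

print(scaled.shape, scaled.dtype)
print(scaled.min().item(), scaled.max().item())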

##############################################
# Lambda Transforms
# -------------------------------
#
# Lambda transforms apply any user-defined lambda function. Here, we define a function
# to turn the integer label into a one-hot encoded tensor.
# It first creates a zero tensor of size 10 (the number of labels in our dataset) and calls
# `scatter_ <https://pytorch.org/docs/stable/generated/torch.Tensor.scatter_.html>`_, which assigns
# ``value=1`` at the index given by the label ``y``.

target_transform = Lambda(lambda y: torch.zeros(
    10, dtype=torch.float).scatter_(dim=0, index=torch.tensor(y), value=1))
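
######################################################################
# To see what this transform produces, the sketch below applies the same
# expression to a single label and compares the result with
# ``torch.nn.functional.one_hot``, a built-in alternative that returns an
# integer tensor (so it is cast to float here). The label value 3 and the
# names ``one_hot_transform``, ``encoded`` and ``builtin`` are chosen for
# illustration.

import torch
import torch.nn.functional as F
from torchvision.transforms import Lambda

one_hot_transform = Lambda(lambda y: torch.zeros(
    10, dtype=torch.float).scatter_(dim=0, index=torch.tensor(y), value=1))

encoded = one_hot_transform(3)
print(encoded)  # index 3 is set to 1, all other entries stay 0

builtin = F.one_hot(torch.tensor(3), num_classes=10).float()
print(torch.equal(encoded, builtin))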

######################################################################
# --------------
#

#################################################################
# Further Reading
# ~~~~~~~~~~~~~~~~~
# - `torchvision.transforms API <https://pytorch.org/vision/stable/transforms.html>`_