GitHub Repository: pytorch/tutorials
Path: blob/main/recipes_source/recipes/zeroing_out_gradients.py
"""
Zeroing out gradients in PyTorch
================================
When training a neural network, you usually need to zero out the
gradients before each optimization step. This is because, by default,
gradients are accumulated in buffers (i.e., added to rather than
overwritten) whenever ``.backward()`` is called.

Introduction
------------
When you train a neural network, the model improves its accuracy
through gradient descent. In short, gradient descent is the process of
minimizing the loss (or error) by repeatedly adjusting the weights and
biases of the model in the direction that decreases the loss.

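As a minimal sketch of gradient descent, the example below minimizes a toy one-parameter loss, ``(w - 3) ** 2``, chosen purely for illustration. Autograd computes the gradient. Each step moves ``w`` a small amount against that gradient, and ``w.grad.zero_()`` resets the gradient before the next step.

```python
import torch

# Toy problem (illustrative assumption): minimize loss(w) = (w - 3)^2.
w = torch.tensor(0.0, requires_grad=True)
lr = 0.1  # learning rate

for _ in range(100):
    loss = (w - 3.0) ** 2
    loss.backward()            # w.grad now holds d(loss)/dw = 2 * (w - 3)
    with torch.no_grad():
        w -= lr * w.grad       # gradient descent update: w <- w - lr * grad
    w.grad.zero_()             # reset the accumulated gradient for the next step

print(w.item())  # close to 3.0
```

With this learning rate, each step shrinks the distance to the minimum by a factor of 0.8, so 100 steps are more than enough.
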
``torch.Tensor`` is the central class of PyTorch. When you create a
tensor and set its ``.requires_grad`` attribute to ``True``, PyTorch's
autograd engine records every operation performed on it during the
forward pass. When you then call ``.backward()`` on the loss tensor, the
gradient of the loss with respect to that tensor is computed and
accumulated (summed) into the tensor's ``.grad`` attribute.

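Here is a minimal sketch of this accumulation, using a scalar tensor chosen for illustration. Calling ``.backward()`` twice adds the second gradient to the first, and ``.grad.zero_()`` resets the buffer in place.

```python
import torch

x = torch.tensor(2.0, requires_grad=True)

# First backward pass: d(3x)/dx = 3, so x.grad becomes 3.
y = 3 * x
y.backward()
first = x.grad.clone()

# Second backward pass on a freshly built graph: the new gradient (3)
# is added to the existing x.grad instead of overwriting it.
y = 3 * x
y.backward()
accumulated = x.grad.clone()

# Zero the gradient buffer in place; the next backward starts from 0.
x.grad.zero_()

print(first, accumulated, x.grad)  # tensor(3.) tensor(6.) tensor(0.)
```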
There are cases where it may be necessary to zero out the gradients of a
tensor. For example, at the start of each iteration of your training
loop, you should zero out the gradients so that each update uses only
the gradients computed for the current batch.
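The sketch below shows the usual pattern. A hypothetical tiny linear model and random data stand in for the real network and dataset. ``optimizer.zero_grad()`` is called at the top of each iteration. The resulting ``.grad`` is then compared with the gradient of that batch's loss, which ``torch.autograd.grad`` computes separately.

```python
import torch
import torch.nn as nn
import torch.optim as optim

torch.manual_seed(0)

# Hypothetical tiny model and random data, used only for illustration.
model = nn.Linear(4, 1)
optimizer = optim.SGD(model.parameters(), lr=0.01)
loss_fn = nn.MSELoss()

matches = []
for _ in range(3):
    inputs, targets = torch.randn(8, 4), torch.randn(8, 1)

    optimizer.zero_grad()  # clear gradients left over from the previous iteration
    loss = loss_fn(model(inputs), targets)

    # Reference: gradients of this batch's loss only, computed without
    # touching .grad (retain_graph keeps the graph alive for backward()).
    expected = torch.autograd.grad(loss, list(model.parameters()), retain_graph=True)
    loss.backward()

    # Because .grad was cleared first, it holds exactly this batch's gradient.
    matches.append(all(torch.allclose(p.grad, g)
                       for p, g in zip(model.parameters(), expected)))
    optimizer.step()

print(matches)  # [True, True, True]
```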

In this recipe, we will learn how to zero out gradients using the
PyTorch library. We will demonstrate how to do this by training a neural
network on the ``CIFAR10`` dataset, which is available through
``torchvision``.

Setup
-----
Since this recipe trains a model, if you are in a runnable notebook it
can help to switch the runtime to a GPU, although the code below runs on
the CPU as written.
Before we begin, we need to install ``torch`` and ``torchvision`` if
they aren’t already available.

.. code-block:: sh

   pip install torch torchvision

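The code in this recipe runs on the CPU as written. If you do switch to a GPU runtime, one common pattern is to choose a device once and move both the model and every batch to it. This sketch is an addition, not part of the original recipe. ``nn.Linear(3, 2)`` and the random batch are placeholders for ``Net`` and a CIFAR10 batch.

```python
import torch
import torch.nn as nn

# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = nn.Linear(3, 2).to(device)        # move the parameters to the device
batch = torch.randn(4, 3, device=device)  # inputs must live on the same device
outputs = model(batch)

print(device, outputs.device)
```

In the training loop, the equivalent would be to call ``net.to(device)`` once and then run ``inputs, labels = inputs.to(device), labels.to(device)`` for each batch.
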
"""


######################################################################
# Steps
# -----
#
# Steps 1 through 4 set up our data and neural network for training. The
# process of zeroing out the gradients happens in step 5. If you already
# have your data and neural network built, skip to step 5.
#
# 1. Import all necessary libraries for loading our data
# 2. Load and normalize the dataset
# 3. Build the neural network
# 4. Define the loss function and optimizer
# 5. Zero the gradients while training the network
#
# 1. Import necessary libraries for loading our data
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
#
# For this recipe, we will just be using ``torch`` and ``torchvision`` to
# access the dataset.
#

import torch

import torch.nn as nn
import torch.nn.functional as F

import torch.optim as optim

import torchvision
import torchvision.transforms as transforms


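######################################################################
# 2. Load and normalize the dataset
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
#
# ``torchvision`` provides various built-in datasets, including CIFAR10
# (see the Loading Data recipe for more information). ``ToTensor``
# converts each image to a tensor with values in [0, 1], and
# ``Normalize`` with a mean and standard deviation of 0.5 per channel
# rescales those values to [-1, 1].
#
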
transform = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
                                        download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=4,
                                          shuffle=True, num_workers=2)

testset = torchvision.datasets.CIFAR10(root='./data', train=False,
                                       download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=4,
                                         shuffle=False, num_workers=2)

classes = ('plane', 'car', 'bird', 'cat',
           'deer', 'dog', 'frog', 'horse', 'ship', 'truck')


######################################################################
# 3. Build the neural network
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
#
# We will use a convolutional neural network. To learn more, see the
# Defining a Neural Network recipe.
#

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(-1, 16 * 5 * 5)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x


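# The ``16 * 5 * 5`` input size of ``fc1`` follows from the shape of a
# CIFAR10 image, 3x32x32. Each 5x5 convolution without padding shrinks
# each side by 4, and each 2x2 max-pool halves it. The sides therefore go
# 32 -> 28 -> 14 -> 10 -> 5, with 16 channels after ``conv2``. A quick
# sketch that checks this with a random input:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)

# The convolution and pooling layers of Net, rebuilt so this block runs alone.
conv1 = nn.Conv2d(3, 6, 5)
pool = nn.MaxPool2d(2, 2)
conv2 = nn.Conv2d(6, 16, 5)

x = torch.randn(1, 3, 32, 32)   # one random stand-in for a CIFAR10 image
x = pool(F.relu(conv1(x)))      # 32 -> 28 (5x5 conv) -> 14 (2x2 pool)
print(x.shape)                  # torch.Size([1, 6, 14, 14])
x = pool(F.relu(conv2(x)))      # 14 -> 10 (5x5 conv) -> 5 (2x2 pool)
print(x.shape)                  # torch.Size([1, 16, 5, 5])
flat = x.view(-1, 16 * 5 * 5)   # 16 * 5 * 5 = 400 features for fc1
print(flat.shape)               # torch.Size([1, 400])
```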
######################################################################
# 4. Define a loss function and optimizer
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
#
# Let’s use a classification cross-entropy loss (``nn.CrossEntropyLoss``)
# and stochastic gradient descent (SGD) with momentum.
#

net = Net()
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)


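# ``nn.CrossEntropyLoss`` takes raw, unnormalized scores (logits) of shape
# ``(batch, classes)`` and integer class indices as targets. This is why
# ``Net`` ends with a plain linear layer and no softmax. A small sketch
# with made-up logits compares the loss against log-softmax followed by
# the mean negative log-likelihood, computed by hand:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

criterion = nn.CrossEntropyLoss()

# Made-up logits for a batch of 2 samples and 3 classes, plus their labels.
logits = torch.tensor([[2.0, 0.5, -1.0],
                       [0.1, 0.2, 3.0]])
labels = torch.tensor([0, 2])

loss = criterion(logits, labels)

# Same value by hand: batch average of -log_softmax(logits)[label].
log_probs = F.log_softmax(logits, dim=1)
manual = -log_probs[torch.arange(len(labels)), labels].mean()

print(loss.item(), manual.item())
print(torch.allclose(loss, manual))  # True
```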
######################################################################
# 5. Zero the gradients while training the network
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
#
# This is when things start to get interesting. We simply have to loop
# over our data iterator, feed the inputs to the network, and optimize.
#
# Notice that we zero out the gradients for each mini-batch, before the
# backward pass. This ensures that each call to ``optimizer.step()`` uses
# only the gradients computed from the current mini-batch, instead of the
# sum of the gradients from every previous mini-batch.
#

for epoch in range(2):  # loop over the dataset multiple times

    running_loss = 0.0
    for i, data in enumerate(trainloader, 0):
        # get the inputs; data is a list of [inputs, labels]
        inputs, labels = data

        # zero the parameter gradients
        optimizer.zero_grad()

        # forward + backward + optimize
        outputs = net(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # print statistics
        running_loss += loss.item()
        if i % 2000 == 1999:    # print every 2000 mini-batches
            print('[%d, %5d] loss: %.3f' %
                  (epoch + 1, i + 1, running_loss / 2000))
            running_loss = 0.0

print('Finished Training')


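# To see why the ``optimizer.zero_grad()`` call matters, this sketch uses a
# hypothetical tiny model (not the ``Net`` above) and runs two backward
# passes on the same batch without zeroing in between. The second pass
# adds to ``.grad`` rather than replacing it. The gradient therefore
# doubles, and the next ``optimizer.step()`` would use that stale sum.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical tiny model and batch, used only for illustration.
model = nn.Linear(4, 1)
criterion = nn.MSELoss()
inputs, targets = torch.randn(8, 4), torch.randn(8, 1)

# First backward pass: gradients of this batch's loss.
criterion(model(inputs), targets).backward()
single = model.weight.grad.clone()

# Second backward pass WITHOUT zeroing: the new gradients are summed into .grad.
criterion(model(inputs), targets).backward()
doubled = torch.allclose(model.weight.grad, 2 * single)
print(doubled)  # True

# Zeroing first restores the per-batch gradient.
model.zero_grad()
criterion(model(inputs), targets).backward()
restored = torch.allclose(model.weight.grad, single)
print(restored)  # True
```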
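######################################################################
# You can also use ``model.zero_grad()``. It resets the gradients of all
# of the model's parameters, while ``optimizer.zero_grad()`` resets the
# gradients of the parameters that were passed to the optimizer. The two
# are equivalent as long as all your model parameters are in that
# optimizer. If they are not, ``optimizer.zero_grad()`` leaves the
# gradients of the remaining parameters untouched.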
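# The following sketch shows the difference. It assumes a hypothetical
# two-layer model whose optimizer receives only the first layer's
# parameters, as when only part of a model is trained. Depending on
# ``set_to_none`` (``True`` by default since PyTorch 2.0), ``zero_grad()``
# either sets each gradient to ``None`` or fills it with zeros. The helper
# ``is_cleared`` therefore accepts both.

```python
import torch
import torch.nn as nn
import torch.optim as optim

torch.manual_seed(0)


def is_cleared(p):
    # True if p's gradient was reset to None or to all zeros.
    return p.grad is None or bool(torch.all(p.grad == 0))


# Hypothetical model; the optimizer only holds the first layer's parameters.
model = nn.Sequential(nn.Linear(4, 3), nn.Linear(3, 1))
optimizer = optim.SGD(model[0].parameters(), lr=0.01)

model(torch.randn(2, 4)).sum().backward()
optimizer.zero_grad()  # resets model[0] only
first_layer = [is_cleared(p) for p in model[0].parameters()]
second_layer = [is_cleared(p) for p in model[1].parameters()]
print(first_layer, second_layer)  # [True, True] [False, False]

model(torch.randn(2, 4)).sum().backward()
model.zero_grad()  # resets every parameter of the model
print(all(is_cleared(p) for p in model.parameters()))  # True
```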
#
# Congratulations! You have successfully zeroed out gradients in PyTorch.
#
# Learn More
# ----------
#
# Take a look at these other recipes to continue your learning:
#
# - `Loading data in PyTorch <https://pytorch.org/tutorials/beginner/basics/data_tutorial.html>`__
# - `Saving and loading models across devices in PyTorch <https://pytorch.org/tutorials/recipes/recipes/save_load_across_devices.html>`__