"""1Zeroing out gradients in PyTorch2================================3It is beneficial to zero out gradients when building a neural network.4This is because by default, gradients are accumulated in buffers (i.e,5not overwritten) whenever ``.backward()`` is called.67Introduction8------------9When training your neural network, models are able to increase their10accuracy through gradient descent. In short, gradient descent is the11process of minimizing our loss (or error) by tweaking the weights and12biases in our model.1314``torch.Tensor`` is the central class of PyTorch. When you create a15tensor, if you set its attribute ``.requires_grad`` as ``True``, the16package tracks all operations on it. This happens on subsequent backward17passes. The gradient for this tensor will be accumulated into ``.grad``18attribute. The accumulation (or sum) of all the gradients is calculated19when .backward() is called on the loss tensor.2021There are cases where it may be necessary to zero-out the gradients of a22tensor. For example: when you start your training loop, you should zero23out the gradients so that you can perform this tracking correctly.24In this recipe, we will learn how to zero out gradients using the25PyTorch library. We will demonstrate how to do this by training a neural26network on the ``CIFAR10`` dataset built into PyTorch.2728Setup29-----30Since we will be training data in this recipe, if you are in a runnable31notebook, it is best to switch the runtime to GPU or TPU.32Before we begin, we need to install ``torch`` and ``torchvision`` if33they aren’t already available.3435.. code-block:: sh3637pip install torchvision383940"""414243######################################################################44# Steps45# -----46#47# Steps 1 through 4 set up our data and neural network for training. The48# process of zeroing out the gradients happens in step 5. If you already49# have your data and neural network built, skip to 5.50#51# 1. Import all necessary libraries for loading our data52# 2. Load and normalize the dataset53# 3. Build the neural network54# 4. Define the loss function55# 5. Zero the gradients while training the network56#57# 1. Import necessary libraries for loading our data58# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~59#60# For this recipe, we will just be using ``torch`` and ``torchvision`` to61# access the dataset.62#6364import torch6566import torch.nn as nn67import torch.nn.functional as F6869import torch.optim as optim7071import torchvision72import torchvision.transforms as transforms737475######################################################################76# 2. Load and normalize the dataset77# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~78#79# PyTorch features various built-in datasets (see the Loading Data recipe80# for more information).81#8283transform = transforms.Compose(84[transforms.ToTensor(),85transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])8687trainset = torchvision.datasets.CIFAR10(root='./data', train=True,88download=True, transform=transform)89trainloader = torch.utils.data.DataLoader(trainset, batch_size=4,90shuffle=True, num_workers=2)9192testset = torchvision.datasets.CIFAR10(root='./data', train=False,93download=True, transform=transform)94testloader = torch.utils.data.DataLoader(testset, batch_size=4,95shuffle=False, num_workers=2)9697classes = ('plane', 'car', 'bird', 'cat',98'deer', 'dog', 'frog', 'horse', 'ship', 'truck')99100101######################################################################102# 3. 

######################################################################
# 2. Load and normalize the dataset
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
#
# PyTorch features various built-in datasets (see the Loading Data recipe
# for more information).
#

transform = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
                                        download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=4,
                                          shuffle=True, num_workers=2)

testset = torchvision.datasets.CIFAR10(root='./data', train=False,
                                       download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=4,
                                         shuffle=False, num_workers=2)

classes = ('plane', 'car', 'bird', 'cat',
           'deer', 'dog', 'frog', 'horse', 'ship', 'truck')


######################################################################
# 3. Build the neural network
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
#
# We will use a convolutional neural network. To learn more, see the
# Defining a Neural Network recipe.
#

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(-1, 16 * 5 * 5)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x


######################################################################
# 4. Define a loss function and optimizer
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
#
# Let's use a classification cross-entropy loss and SGD with momentum.
#

net = Net()
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)


######################################################################
# 5. Zero the gradients while training the network
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
#
# This is when things start to get interesting. We simply have to loop
# over our data iterator, feed the inputs to the network, and optimize.
#
# Notice that for each batch of data, we zero out the gradients. This
# ensures that gradients left over from previous batches do not leak into
# the current parameter update when we train our neural network.
#

for epoch in range(2):  # loop over the dataset multiple times

    running_loss = 0.0
    for i, data in enumerate(trainloader, 0):
        # get the inputs; data is a list of [inputs, labels]
        inputs, labels = data

        # zero the parameter gradients
        optimizer.zero_grad()

        # forward + backward + optimize
        outputs = net(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # print statistics
        running_loss += loss.item()
        if i % 2000 == 1999:    # print every 2000 mini-batches
            print('[%d, %5d] loss: %.3f' %
                  (epoch + 1, i + 1, running_loss / 2000))
            running_loss = 0.0

print('Finished Training')


######################################################################
# You can also use ``model.zero_grad()``. This is the same as using
# ``optimizer.zero_grad()`` as long as all your model parameters are in
# that optimizer. Use your best judgment to decide which one to use.
#
# Congratulations! You have successfully zeroed out gradients in PyTorch.
#
# Learn More
# ----------
#
# Take a look at these other recipes to continue your learning:
#
# - `Loading data in PyTorch <https://pytorch.org/tutorials/beginner/basics/data_tutorial.html>`__
# - `Saving and loading models across devices in PyTorch <https://pytorch.org/tutorials/recipes/recipes/save_load_across_devices.html>`__
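
######################################################################
# As a closing sketch (not part of the original recipe), note that the
# accumulation behavior that makes zeroing necessary can also be used
# deliberately: calling ``zero_grad()`` only every few mini-batches sums
# gradients across several forward/backward passes before a single
# optimizer step, simulating a larger effective batch size. The
# ``accum_steps`` name below is illustrative. On recent versions of
# PyTorch, ``optimizer.zero_grad(set_to_none=True)`` resets each
# parameter's ``.grad`` to ``None`` instead of a zero tensor, which can
# modestly reduce memory traffic.
#

accum_steps = 4  # number of mini-batches to accumulate before stepping

optimizer.zero_grad()
for i, data in enumerate(trainloader, 0):
    inputs, labels = data
    outputs = net(inputs)

    # scale the loss so the summed gradients match a large-batch average
    loss = criterion(outputs, labels) / accum_steps
    loss.backward()  # gradients accumulate in .grad across iterations

    if (i + 1) % accum_steps == 0:
        optimizer.step()
        optimizer.zero_grad()  # reset the buffers only after stepping
        break  # one accumulated step is enough for this demonstration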