"""1`Introduction <introyt1_tutorial.html>`_ ||2`Tensors <tensors_deeper_tutorial.html>`_ ||3**Autograd** ||4`Building Models <modelsyt_tutorial.html>`_ ||5`TensorBoard Support <tensorboardyt_tutorial.html>`_ ||6`Training Models <trainingyt.html>`_ ||7`Model Understanding <captumyt.html>`_89The Fundamentals of Autograd10============================1112Follow along with the video below or on `youtube <https://www.youtube.com/watch?v=M0fX15_-xrY>`__.1314.. raw:: html1516<div style="margin-top:10px; margin-bottom:10px;">17<iframe width="560" height="315" src="https://www.youtube.com/embed/M0fX15_-xrY" frameborder="0" allow="accelerometer; encrypted-media; gyroscope; picture-in-picture" allowfullscreen></iframe>18</div>1920PyTorch’s *Autograd* feature is part of what make PyTorch flexible and21fast for building machine learning projects. It allows for the rapid and22easy computation of multiple partial derivatives (also referred to as23*gradients)* over a complex computation. This operation is central to24backpropagation-based neural network learning.2526The power of autograd comes from the fact that it traces your27computation dynamically *at runtime,* meaning that if your model has28decision branches, or loops whose lengths are not known until runtime,29the computation will still be traced correctly, and you’ll get correct30gradients to drive learning. This, combined with the fact that your31models are built in Python, offers far more flexibility than frameworks32that rely on static analysis of a more rigidly-structured model for33computing gradients.3435What Do We Need Autograd For?36-----------------------------3738"""3940###########################################################################41# A machine learning model is a *function*, with inputs and outputs. For42# this discussion, we’ll treat the inputs as an *i*-dimensional vector43# :math:`\vec{x}`, with elements :math:`x_{i}`. We can then express the44# model, *M*, as a vector-valued function of the input: :math:`\vec{y} =45# \vec{M}(\vec{x})`. (We treat the value of M’s output as46# a vector because in general, a model may have any number of outputs.)47#48# Since we’ll mostly be discussing autograd in the context of training,49# our output of interest will be the model’s loss. The *loss function*50# L(:math:`\vec{y}`) = L(:math:`\vec{M}`\ (:math:`\vec{x}`)) is a51# single-valued scalar function of the model’s output. This function52# expresses how far off our model’s prediction was from a particular53# input’s *ideal* output. *Note: After this point, we will often omit the54# vector sign where it should be contextually clear - e.g.,* :math:`y`55# instead of :math:`\vec y`.56#57# In training a model, we want to minimize the loss. In the idealized case58# of a perfect model, that means adjusting its learning weights - that is,59# the adjustable parameters of the function - such that loss is zero for60# all inputs. In the real world, it means an iterative process of nudging61# the learning weights until we see that we get a tolerable loss for a62# wide variety of inputs.63#64# How do we decide how far and in which direction to nudge the weights? 
# We want to *minimize* the loss, which means making its first derivative
# with respect to the input equal to 0:
# :math:`\frac{\partial L}{\partial x} = 0`.
#
# Recall, though, that the loss is not *directly* derived from the input,
# but is a function of the model’s output (which is in turn a function of
# the input): :math:`\frac{\partial L}{\partial x} =
# \frac{\partial {L({\vec y})}}{\partial x}`. By the chain rule of
# differential calculus, we have
# :math:`\frac{\partial {L({\vec y})}}{\partial x} =
# \frac{\partial L}{\partial y}\frac{\partial y}{\partial x} =
# \frac{\partial L}{\partial y}\frac{\partial M(x)}{\partial x}`.
#
# :math:`\frac{\partial M(x)}{\partial x}` is where things get complex.
# The partial derivatives of the model’s outputs with respect to its
# inputs, if we were to expand the expression using the chain rule again,
# would involve many local partial derivatives over every multiplied
# learning weight, every activation function, and every other mathematical
# transformation in the model. The full expression for each such partial
# derivative is the sum of the products of the local gradients along
# *every possible path* through the computation graph that ends with the
# variable whose gradient we are trying to measure.
#
# In particular, the gradients over the learning weights are of interest
# to us - they tell us *what direction to change each weight* to get the
# loss function closer to zero.
#
# Since the number of such local derivatives (each corresponding to a
# separate path through the model’s computation graph) tends to grow
# exponentially with the depth of a neural network, so does the complexity
# of computing them. This is where autograd comes in: It tracks the
# history of every computation. Every computed tensor in your PyTorch
# model carries a history of its input tensors and the function used to
# create it. Combined with the fact that PyTorch functions meant to act on
# tensors each have a built-in implementation for computing their own
# derivatives, this greatly speeds the computation of the local
# derivatives needed for learning.
#
# A Simple Example
# ----------------
#
# That was a lot of theory - but what does it look like to use autograd in
# practice?
#
# Let’s start with a straightforward example. First, we’ll do some imports
# to let us graph our results:
#

# %matplotlib inline

import torch

import matplotlib.pyplot as plt
import matplotlib.ticker as ticker
import math


#########################################################################
# Next, we’ll create an input tensor full of evenly spaced values on the
# interval :math:`[0, 2{\pi}]`, and specify ``requires_grad=True``. (Like
# most functions that create tensors, ``torch.linspace()`` accepts an
# optional ``requires_grad`` option.) Setting this flag means that in
# every computation that follows, autograd will be accumulating the
# history of the computation in the output tensors of that computation.
#

a = torch.linspace(0., 2. * math.pi, steps=25, requires_grad=True)
print(a)


########################################################################
# Next, we’ll perform a computation, and plot its output in terms of its
# inputs:
#

b = torch.sin(a)
plt.plot(a.detach(), b.detach())


########################################################################
# Let’s have a closer look at the tensor ``b``.
# When we print it, we see
# an indicator that it is tracking its computation history:
#

print(b)


#######################################################################
# This ``grad_fn`` gives us a hint that when we execute the
# backpropagation step and compute gradients, we’ll need to compute the
# derivative of :math:`\sin(x)` for all this tensor’s inputs.
#
# Let’s perform some more computations:
#

c = 2 * b
print(c)

d = c + 1
print(d)


##########################################################################
# Finally, let’s compute a single-element output. When you call
# ``.backward()`` on a tensor with no arguments, it expects the calling
# tensor to contain only a single element, as is the case when computing a
# loss function.
#

out = d.sum()
print(out)


##########################################################################
# Each ``grad_fn`` stored with our tensors allows you to walk the
# computation all the way back to its inputs with its ``next_functions``
# property. We can see below that drilling down on this property on ``d``
# shows us the gradient functions for all the prior tensors. Note that
# ``a.grad_fn`` is reported as ``None``, indicating that this was an input
# to the function with no history of its own.
#

print('d:')
print(d.grad_fn)
print(d.grad_fn.next_functions)
print(d.grad_fn.next_functions[0][0].next_functions)
print(d.grad_fn.next_functions[0][0].next_functions[0][0].next_functions)
print(d.grad_fn.next_functions[0][0].next_functions[0][0].next_functions[0][0].next_functions)
print('\nc:')
print(c.grad_fn)
print('\nb:')
print(b.grad_fn)
print('\na:')
print(a.grad_fn)


######################################################################
# With all this machinery in place, how do we get derivatives out? You
# call the ``backward()`` method on the output, and check the input’s
# ``grad`` property to inspect the gradients:
#

out.backward()
print(a.grad)
plt.plot(a.detach(), a.grad.detach())


#########################################################################
# Recall the computation steps we took to get here:
#
# .. code-block:: python
#
#    a = torch.linspace(0., 2. * math.pi, steps=25, requires_grad=True)
#    b = torch.sin(a)
#    c = 2 * b
#    d = c + 1
#    out = d.sum()
#
# Adding a constant, as we did to compute ``d``, does not change the
# derivative. That leaves :math:`c = 2 * b = 2 * \sin(a)`, the derivative
# of which should be :math:`2 * \cos(a)`. Looking at the graph above,
# that’s just what we see.
#
# Be aware that only *leaf nodes* of the computation have their gradients
# computed. If you tried, for example, ``print(c.grad)`` you’d get back
# ``None``. In this simple example, only the input is a leaf node, so only
# it has gradients computed.
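#
# If you do need the gradient of an intermediate (non-leaf) tensor, you
# can ask autograd to keep it with ``Tensor.retain_grad()``. A minimal
# sketch, reusing the input tensor ``a`` from above - note that
# ``retain_grad()`` must be called before ``backward()``:
#

b2 = torch.sin(a)    # a fresh forward pass so we can call backward() again
c2 = 2 * b2
c2.retain_grad()     # ask autograd to keep this non-leaf tensor's gradient
c2.sum().backward()  # note: this also accumulates further gradient into a.grad
print(c2.grad)       # all ones (d(sum)/dc2), instead of None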


##########################################################################
# Autograd in Training
# --------------------
#
# We’ve had a brief look at how autograd works, but how does it look when
# it’s used for its intended purpose? Let’s define a small model and
# examine how it changes after a single training batch. First, define a
# few constants, our model, and some stand-ins for inputs and outputs:
#

BATCH_SIZE = 16
DIM_IN = 1000
HIDDEN_SIZE = 100
DIM_OUT = 10

class TinyModel(torch.nn.Module):

    def __init__(self):
        super(TinyModel, self).__init__()

        self.layer1 = torch.nn.Linear(DIM_IN, HIDDEN_SIZE)
        self.relu = torch.nn.ReLU()
        self.layer2 = torch.nn.Linear(HIDDEN_SIZE, DIM_OUT)

    def forward(self, x):
        x = self.layer1(x)
        x = self.relu(x)
        x = self.layer2(x)
        return x

some_input = torch.randn(BATCH_SIZE, DIM_IN, requires_grad=False)
ideal_output = torch.randn(BATCH_SIZE, DIM_OUT, requires_grad=False)

model = TinyModel()


##########################################################################
# One thing you might notice is that we never specify
# ``requires_grad=True`` for the model’s layers. Within a subclass of
# ``torch.nn.Module``, it’s assumed that we want to track gradients on the
# layers’ weights for learning.
#
# If we look at the layers of the model, we can examine the values of the
# weights, and verify that no gradients have been computed yet:
#

print(model.layer2.weight[0][0:10]) # just a small slice
print(model.layer2.weight.grad)


##########################################################################
# Let’s see how this changes when we run through one training batch. For a
# loss function, we’ll just use the square of the Euclidean distance
# between our ``prediction`` and the ``ideal_output``, and we’ll use a
# basic stochastic gradient descent optimizer.
#

optimizer = torch.optim.SGD(model.parameters(), lr=0.001)

prediction = model(some_input)

loss = (ideal_output - prediction).pow(2).sum()
print(loss)


######################################################################
# Now, let’s call ``loss.backward()`` and see what happens:
#

loss.backward()
print(model.layer2.weight[0][0:10])
print(model.layer2.weight.grad[0][0:10])


########################################################################
# We can see that the gradients have been computed for each learning
# weight, but the weights remain unchanged, because we haven’t run the
# optimizer yet. The optimizer is responsible for updating model weights
# based on the computed gradients.
#

optimizer.step()
print(model.layer2.weight[0][0:10])
print(model.layer2.weight.grad[0][0:10])


######################################################################
# You should see that ``layer2``\ ’s weights have changed.
#
# One important thing about the process: After calling
# ``optimizer.step()``, you need to call ``optimizer.zero_grad()``, or
# else every time you run ``loss.backward()``, the gradients on the
# learning weights will accumulate:
#

print(model.layer2.weight.grad[0][0:10])

for i in range(0, 5):
    prediction = model(some_input)
    loss = (ideal_output - prediction).pow(2).sum()
    loss.backward()

print(model.layer2.weight.grad[0][0:10])

optimizer.zero_grad(set_to_none=False)

print(model.layer2.weight.grad[0][0:10])


#########################################################################
# After running the cell above, you should see that after running
# ``loss.backward()`` multiple times, the magnitudes of most of the
# gradients will be much larger. Failing to zero the gradients before
# running your next training batch will cause the gradients to blow up in
# this manner, causing incorrect and unpredictable learning results.
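#
# In a real training loop, the fix is simply to zero the gradients once
# per batch. A minimal sketch of that ordering, reusing the model,
# optimizer, and stand-in data from above:
#

for batch in range(3):
    optimizer.zero_grad()                            # clear gradients from the last batch
    prediction = model(some_input)                   # forward pass
    loss = (ideal_output - prediction).pow(2).sum()  # compute loss
    loss.backward()                                  # compute fresh gradients
    optimizer.step()                                 # update weights from those gradients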


##########################################################################
# Turning Autograd Off and On
# ---------------------------
#
# There are situations where you will need fine-grained control over
# whether autograd is enabled. There are multiple ways to do this,
# depending on the situation.
#
# The simplest is to change the ``requires_grad`` flag on a tensor
# directly:
#

a = torch.ones(2, 3, requires_grad=True)
print(a)

b1 = 2 * a
print(b1)

a.requires_grad = False
b2 = 2 * a
print(b2)


##########################################################################
# In the cell above, we see that ``b1`` has a ``grad_fn`` (i.e., a traced
# computation history), which is what we expect, since it was derived from
# a tensor, ``a``, that had autograd turned on. When we turn off autograd
# explicitly with ``a.requires_grad = False``, computation history is no
# longer tracked, as we see when we compute ``b2``.
#
# If you only need autograd turned off temporarily, a better way is to use
# ``torch.no_grad()``:
#

a = torch.ones(2, 3, requires_grad=True) * 2
b = torch.ones(2, 3, requires_grad=True) * 3

c1 = a + b
print(c1)

with torch.no_grad():
    c2 = a + b

print(c2)

c3 = a * b
print(c3)


##########################################################################
# ``torch.no_grad()`` can also be used as a function or method decorator:
#

def add_tensors1(x, y):
    return x + y

@torch.no_grad()
def add_tensors2(x, y):
    return x + y


a = torch.ones(2, 3, requires_grad=True) * 2
b = torch.ones(2, 3, requires_grad=True) * 3

c1 = add_tensors1(a, b)
print(c1)

c2 = add_tensors2(a, b)
print(c2)


##########################################################################
# There’s a corresponding context manager, ``torch.enable_grad()``, for
# turning autograd on when it isn’t already. It may also be used as a
# decorator.
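#
# A minimal sketch of ``torch.enable_grad()`` re-enabling autograd inside
# a ``no_grad()`` block - the innermost context wins:
#

x = torch.ones(2, 3, requires_grad=True)

with torch.no_grad():
    y1 = x * 2                # not tracked: no_grad is active here
    with torch.enable_grad():
        y2 = x * 2            # tracked: enable_grad overrides the outer no_grad

print(y1.requires_grad)       # False
print(y2.requires_grad)       # True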


##########################################################################
# Finally, you may have a tensor that requires gradient tracking, but you
# want a copy that does not. For this we have the ``Tensor`` object’s
# ``detach()`` method - it creates a copy of the tensor that is *detached*
# from the computation history:
#

x = torch.rand(5, requires_grad=True)
y = x.detach()

print(x)
print(y)


#########################################################################
# We did this above when we wanted to graph some of our tensors. This is
# because ``matplotlib`` expects a NumPy array as input, and the implicit
# conversion from a PyTorch tensor to a NumPy array is not enabled for
# tensors with ``requires_grad=True``. Making a detached copy lets us move
# forward.
#
# Autograd and In-place Operations
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
#
# In every example in this notebook so far, we’ve used variables to
# capture the intermediate values of a computation. Autograd needs these
# intermediate values to perform gradient computations. *For this reason,
# you must be careful about using in-place operations when using
# autograd.* Doing so can destroy information you need to compute
# derivatives in the ``backward()`` call. PyTorch will even stop you if
# you attempt an in-place operation on a leaf variable that requires
# autograd, as shown below.
#
# .. note::
#    The following code cell throws a runtime error. This is expected.
#
#    .. code-block:: python
#
#       a = torch.linspace(0., 2. * math.pi, steps=25, requires_grad=True)
#       torch.sin_(a)
#

#########################################################################
# Autograd Profiler
# -----------------
#
# Autograd tracks every step of your computation in detail. Such a
# computation history, combined with timing information, would make a
# handy profiler - and autograd has that feature baked in. Here’s a quick
# example usage:
#

device = torch.device('cpu')
run_on_gpu = False
if torch.cuda.is_available():
    device = torch.device('cuda')
    run_on_gpu = True

x = torch.randn(2, 3, requires_grad=True)
y = torch.rand(2, 3, requires_grad=True)
z = torch.ones(2, 3, requires_grad=True)

with torch.autograd.profiler.profile(use_cuda=run_on_gpu) as prf:
    for _ in range(1000):
        z = (z / x) * y

print(prf.key_averages().table(sort_by='self_cpu_time_total'))


##########################################################################
# The profiler can also label individual sub-blocks of code, break out the
# data by input tensor shape, and export data as a Chrome tracing tools
# file. For full details of the API, see the
# `documentation <https://pytorch.org/docs/stable/autograd.html#profiler>`__.
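#
# For instance, sub-blocks can be labeled with
# ``torch.autograd.profiler.record_function()``, and the labels then
# appear as rows in the profiler’s table. A quick sketch - the label
# strings here are arbitrary:
#

with torch.autograd.profiler.profile(use_cuda=run_on_gpu) as prf:
    with torch.autograd.profiler.record_function('division'):
        w = z / x
    with torch.autograd.profiler.record_function('multiplication'):
        w = w * y

print(prf.key_averages().table(sort_by='self_cpu_time_total'))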


##########################################################################
# Advanced Topic: More Autograd Detail and the High-Level API
# -----------------------------------------------------------
#
# If you have a function with an n-dimensional input and m-dimensional
# output, :math:`\vec{y}=f(\vec{x})`, the complete gradient is a matrix of
# the derivative of every output with respect to every input, called the
# *Jacobian:*
#
# .. math::
#
#    J
#    =
#    \left(\begin{array}{ccc}
#    \frac{\partial y_{1}}{\partial x_{1}} & \cdots & \frac{\partial y_{1}}{\partial x_{n}}\\
#    \vdots & \ddots & \vdots\\
#    \frac{\partial y_{m}}{\partial x_{1}} & \cdots & \frac{\partial y_{m}}{\partial x_{n}}
#    \end{array}\right)
#
# If you have a second function, :math:`l=g\left(\vec{y}\right)`, that
# takes m-dimensional input (that is, the same dimensionality as the
# output above) and returns a scalar output, you can express its
# gradients with respect to :math:`\vec{y}` as a column vector,
# :math:`v=\left(\begin{array}{ccc}\frac{\partial l}{\partial y_{1}} & \cdots & \frac{\partial l}{\partial y_{m}}\end{array}\right)^{T}`
# - which is really just a one-column Jacobian.
#
# More concretely, imagine the first function as your PyTorch model (with
# potentially many inputs and many outputs) and the second function as a
# loss function (with the model’s output as input, and the loss value as
# the scalar output).
#
# If we multiply the first function’s Jacobian by the gradient of the
# second function, and apply the chain rule, we get:
#
# .. math::
#
#    J^{T}\cdot v=\left(\begin{array}{ccc}
#    \frac{\partial y_{1}}{\partial x_{1}} & \cdots & \frac{\partial y_{m}}{\partial x_{1}}\\
#    \vdots & \ddots & \vdots\\
#    \frac{\partial y_{1}}{\partial x_{n}} & \cdots & \frac{\partial y_{m}}{\partial x_{n}}
#    \end{array}\right)\left(\begin{array}{c}
#    \frac{\partial l}{\partial y_{1}}\\
#    \vdots\\
#    \frac{\partial l}{\partial y_{m}}
#    \end{array}\right)=\left(\begin{array}{c}
#    \frac{\partial l}{\partial x_{1}}\\
#    \vdots\\
#    \frac{\partial l}{\partial x_{n}}
#    \end{array}\right)
#
# Note: You could also use the equivalent operation :math:`v^{T}\cdot J`,
# and get back a row vector.
#
# The resulting column vector is the *gradient of the second function with
# respect to the inputs of the first* - or in the case of our model and
# loss function, the gradient of the loss with respect to the model
# inputs.
#
# **torch.autograd is an engine for computing these products.** This
# is how we accumulate the gradients over the learning weights during the
# backward pass.
#
# For this reason, the ``backward()`` call can *also* take an optional
# vector input. This vector represents a set of gradients over the tensor,
# which are multiplied by the Jacobian of the autograd-traced tensor that
# precedes it. Let’s try a specific example with a small vector:
#

x = torch.randn(3, requires_grad=True)

y = x * 2
while y.data.norm() < 1000:
    y = y * 2

print(y)


##########################################################################
# If we tried to call ``y.backward()`` now, we’d get a runtime error and a
# message that gradients can only be *implicitly* computed for scalar
# outputs. For a multi-dimensional output, autograd expects us to provide
# gradients for those three outputs that it can multiply into the
# Jacobian:
#

v = torch.tensor([0.1, 1.0, 0.0001], dtype=torch.float) # stand-in for gradients
y.backward(v)

print(x.grad)


##########################################################################
# (Note that the output gradients are all related to powers of two - which
# we’d expect from a repeated doubling operation.)
#
# The High-Level API
# ~~~~~~~~~~~~~~~~~~
#
# There is an API on autograd that gives you direct access to important
# differential matrix and vector operations. In particular, it allows you
# to calculate the Jacobian and the *Hessian* matrices of a particular
# function for particular inputs. (The Hessian is like the Jacobian, but
# expresses all partial *second* derivatives.)
# It also provides methods
# for taking vector products with these matrices.
#
# Let’s take the Jacobian of a simple function, evaluated for two
# single-element inputs:
#

def exp_adder(x, y):
    return 2 * x.exp() + 3 * y

inputs = (torch.rand(1), torch.rand(1)) # arguments for the function
print(inputs)
torch.autograd.functional.jacobian(exp_adder, inputs)


########################################################################
# If you look closely, the first output should equal :math:`2e^x` (since
# the derivative of :math:`e^x` is :math:`e^x`), and the second value
# should be 3.
#
# You can, of course, do this with higher-order tensors:
#

inputs = (torch.rand(3), torch.rand(3)) # arguments for the function
print(inputs)
torch.autograd.functional.jacobian(exp_adder, inputs)


#########################################################################
# The ``torch.autograd.functional.hessian()`` method works identically
# (assuming your function is twice differentiable), but returns a matrix
# of all second derivatives.
#
# There is also a function to directly compute the vector-Jacobian
# product, if you provide the vector:
#

def do_some_doubling(x):
    y = x * 2
    while y.data.norm() < 1000:
        y = y * 2
    return y

inputs = torch.randn(3)
my_gradients = torch.tensor([0.1, 1.0, 0.0001])
torch.autograd.functional.vjp(do_some_doubling, inputs, v=my_gradients)


##############################################################################
# The ``torch.autograd.functional.jvp()`` method performs the same matrix
# multiplication as ``vjp()`` with the operands reversed. The ``vhp()``
# and ``hvp()`` methods do the same for a vector-Hessian product.
#
# For more information, including performance notes, see the
# `documentation for the functional
# API <https://pytorch.org/docs/stable/autograd.html#functional-higher-level-api>`__.
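#
# As a quick sketch of ``jvp()``, reusing ``do_some_doubling`` from above.
# Note that for ``jvp()``, the vector ``v`` matches the shape of the
# *input* rather than the output:
#

inputs = torch.randn(3)
v = torch.ones(3)  # a direction in input space
# returns a tuple: (the function's output, the Jacobian-vector product)
torch.autograd.functional.jvp(do_some_doubling, inputs, v=v)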