"""
`Introduction <introyt1_tutorial.html>`_ ||
`Tensors <tensors_deeper_tutorial.html>`_ ||
**Autograd** ||
`Building Models <modelsyt_tutorial.html>`_ ||
`TensorBoard Support <tensorboardyt_tutorial.html>`_ ||
`Training Models <trainingyt.html>`_ ||
`Model Understanding <captumyt.html>`_

The Fundamentals of Autograd
============================

Follow along with the video below or on `youtube <https://www.youtube.com/watch?v=M0fX15_-xrY>`__.

.. raw:: html

   <div style="margin-top:10px; margin-bottom:10px;">
     <iframe width="560" height="315" src="https://www.youtube.com/embed/M0fX15_-xrY" frameborder="0" allow="accelerometer; encrypted-media; gyroscope; picture-in-picture" allowfullscreen></iframe>
   </div>

PyTorch’s *Autograd* feature is part of what makes PyTorch flexible and
fast for building machine learning projects. It allows for the rapid and
easy computation of multiple partial derivatives (also referred to as
*gradients*) over a complex computation. This operation is central to
backpropagation-based neural network learning.

The power of autograd comes from the fact that it traces your
computation dynamically *at runtime,* meaning that if your model has
decision branches, or loops whose lengths are not known until runtime,
the computation will still be traced correctly, and you’ll get correct
gradients to drive learning. This, combined with the fact that your
models are built in Python, offers far more flexibility than frameworks
that rely on static analysis of a more rigidly structured model for
computing gradients.
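
For example, how many times the loop below runs depends on the value of
the data, yet autograd still recovers the correct gradient. A minimal
added sketch (not part of the original walkthrough):

.. code-block:: python

    import torch

    x = torch.tensor(0.5, requires_grad=True)
    y = x
    while y < 4:       # iteration count depends on the runtime value
        y = y * 2      # here: three doublings, so y = 8 * x
    y.backward()
    print(x.grad)      # tensor(8.)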

What Do We Need Autograd For?
-----------------------------

"""

###########################################################################
# A machine learning model is a *function*, with inputs and outputs. For
# this discussion, we’ll treat the inputs as an *i*-dimensional vector
# :math:`\vec{x}`, with elements :math:`x_{i}`. We can then express the
# model, *M*, as a vector-valued function of the input: :math:`\vec{y} =
# \vec{M}(\vec{x})`. (We treat the value of M’s output as
# a vector because, in general, a model may have any number of outputs.)
#
# Since we’ll mostly be discussing autograd in the context of training,
# our output of interest will be the model’s loss. The *loss function*
# :math:`L(\vec{y}) = L(\vec{M}(\vec{x}))` is a
# single-valued scalar function of the model’s output. This function
# expresses how far off our model’s prediction was from a particular
# input’s *ideal* output. *Note: After this point, we will often omit the
# vector sign where it should be contextually clear - e.g.,* :math:`y`
# instead of :math:`\vec y`.
#
# In training a model, we want to minimize the loss. In the idealized case
# of a perfect model, that means adjusting its learning weights - that is,
# the adjustable parameters of the function - such that loss is zero for
# all inputs. In the real world, it means an iterative process of nudging
# the learning weights until we see that we get a tolerable loss for a
# wide variety of inputs.
#
# How do we decide how far and in which direction to nudge the weights? We
# want to *minimize* the loss, which means making its first derivative
# with respect to the input equal to 0:
# :math:`\frac{\partial L}{\partial x} = 0`.
#
# Recall, though, that the loss is not *directly* derived from the input,
# but is a function of the model’s output (which is a function of the input
# directly): :math:`\frac{\partial L}{\partial x} =
# \frac{\partial L({\vec y})}{\partial x}`. By the chain rule of
# differential calculus, we have
# :math:`\frac{\partial L({\vec y})}{\partial x} =
# \frac{\partial L}{\partial y}\frac{\partial y}{\partial x} =
# \frac{\partial L}{\partial y}\frac{\partial M(x)}{\partial x}`.
#
# :math:`\frac{\partial M(x)}{\partial x}` is where things get complex.
# The partial derivatives of the model’s outputs with respect to its
# inputs, if we were to expand the expression using the chain rule again,
# would involve many local partial derivatives over every multiplied
# learning weight, every activation function, and every other mathematical
# transformation in the model. The full expression for each such partial
# derivative is a sum, over *every possible path* through the computation
# graph that ends at the variable whose gradient we are measuring, of the
# products of the local gradients along that path.
#
# In particular, the gradients over the learning weights are of interest
# to us - they tell us *what direction to change each weight* to get the
# loss function closer to zero.
#
# Since the number of such paths (each carrying its own set of local
# derivatives) tends to grow exponentially with the depth of a neural
# network, so does the complexity of computing them. This is where
# autograd comes in: It tracks the
# history of every computation. Every computed tensor in your PyTorch
# model carries a history of its input tensors and the function used to
# create it. Combined with the fact that PyTorch functions meant to act on
# tensors each have a built-in implementation for computing their own
# derivatives, this greatly speeds the computation of the local
# derivatives needed for learning.
#
# A Simple Example
# ----------------
#
# That was a lot of theory - but what does it look like to use autograd in
# practice?
#
# Let’s start with a straightforward example. First, we’ll do some imports
# to let us graph our results:
#

# %matplotlib inline

import torch

import matplotlib.pyplot as plt
import matplotlib.ticker as ticker
import math


#########################################################################
# Next, we’ll create an input tensor full of evenly spaced values on the
# interval :math:`[0, 2\pi]`, and specify ``requires_grad=True``. (Like
# most functions that create tensors, ``torch.linspace()`` accepts an
# optional ``requires_grad`` argument.) Setting this flag means that in
# every computation that follows, autograd will be accumulating the
# history of the computation in the output tensors of that computation.
#

a = torch.linspace(0., 2. * math.pi, steps=25, requires_grad=True)
print(a)


########################################################################
# Next, we’ll perform a computation, and plot its output in terms of its
# inputs:
#

b = torch.sin(a)
plt.plot(a.detach(), b.detach())


########################################################################
# Let’s have a closer look at the tensor ``b``. When we print it, we see
# an indicator that it is tracking its computation history:
#

print(b)


#######################################################################
# This ``grad_fn`` gives us a hint that when we execute the
# backpropagation step and compute gradients, we’ll need to compute the
# derivative of :math:`\sin(x)` for all this tensor’s inputs.
#
# Let’s perform some more computations:
#

c = 2 * b
print(c)

d = c + 1
print(d)


##########################################################################
# Finally, let’s compute a single-element output. When you call
# ``.backward()`` on a tensor with no arguments, it expects the calling
# tensor to contain only a single element, as is the case when computing a
# loss function.
#

out = d.sum()
print(out)


##########################################################################
# Each ``grad_fn`` stored with our tensors allows you to walk the
# computation all the way back to its inputs with its ``next_functions``
# property. We can see below that drilling down on this property on ``d``
# shows us the gradient functions for all the prior tensors. Note that
# ``a.grad_fn`` is reported as ``None``, indicating that this was an input
# to the function with no history of its own.
#

print('d:')
print(d.grad_fn)
print(d.grad_fn.next_functions)
print(d.grad_fn.next_functions[0][0].next_functions)
print(d.grad_fn.next_functions[0][0].next_functions[0][0].next_functions)
print(d.grad_fn.next_functions[0][0].next_functions[0][0].next_functions[0][0].next_functions)
print('\nc:')
print(c.grad_fn)
print('\nb:')
print(b.grad_fn)
print('\na:')
print(a.grad_fn)


######################################################################
# With all this machinery in place, how do we get derivatives out? You
# call the ``backward()`` method on the output, and check the input’s
# ``grad`` property to inspect the gradients:
#

out.backward()
print(a.grad)
plt.plot(a.detach(), a.grad.detach())


#########################################################################
# Recall the computation steps we took to get here:
#
# .. code-block:: python
#
#    a = torch.linspace(0., 2. * math.pi, steps=25, requires_grad=True)
#    b = torch.sin(a)
#    c = 2 * b
#    d = c + 1
#    out = d.sum()
#
# Adding a constant, as we did to compute ``d``, does not change the
# derivative. That leaves :math:`c = 2 * b = 2 * \sin(a)`, the derivative
# of which should be :math:`2 * \cos(a)`. Looking at the graph above,
# that’s just what we see.
#
# Be aware that only *leaf nodes* of the computation have their gradients
# computed. If you tried, for example, ``print(c.grad)``, you’d get back
# ``None``. In this simple example, only the input is a leaf node, so only
# it has gradients computed.
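#
# As a quick check on both claims (an added snippet): the input’s
# gradient should match :math:`2\cos(a)`, and reading ``.grad`` on the
# intermediate tensor ``c`` should give ``None`` (PyTorch may also warn
# about accessing ``.grad`` on a non-leaf tensor).
#

print(torch.allclose(a.grad, 2 * torch.cos(a.detach())))  # True
print(c.grad)                                             # None


##########################################################################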
# Autograd in Training
# --------------------
#
# We’ve had a brief look at how autograd works, but how does it look when
# it’s used for its intended purpose? Let’s define a small model and
# examine how it changes after a single training batch. First, define a
# few constants, our model, and some stand-ins for inputs and outputs:
#

BATCH_SIZE = 16
DIM_IN = 1000
HIDDEN_SIZE = 100
DIM_OUT = 10

class TinyModel(torch.nn.Module):

    def __init__(self):
        super(TinyModel, self).__init__()

        self.layer1 = torch.nn.Linear(DIM_IN, HIDDEN_SIZE)
        self.relu = torch.nn.ReLU()
        self.layer2 = torch.nn.Linear(HIDDEN_SIZE, DIM_OUT)

    def forward(self, x):
        x = self.layer1(x)
        x = self.relu(x)
        x = self.layer2(x)
        return x

some_input = torch.randn(BATCH_SIZE, DIM_IN, requires_grad=False)
ideal_output = torch.randn(BATCH_SIZE, DIM_OUT, requires_grad=False)

model = TinyModel()


##########################################################################
# One thing you might notice is that we never specify
# ``requires_grad=True`` for the model’s layers. Within a subclass of
# ``torch.nn.Module``, it’s assumed that we want to track gradients on the
# layers’ weights for learning.
#
# If we look at the layers of the model, we can examine the values of the
# weights, and verify that no gradients have been computed yet:
#

print(model.layer2.weight[0][0:10]) # just a small slice
print(model.layer2.weight.grad)
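
# As an added check: every parameter of an ``nn.Module`` is created with
# autograd enabled by default.
print(all(p.requires_grad for p in model.parameters()))   # True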


##########################################################################
# Let’s see how this changes when we run through one training batch. For a
# loss function, we’ll just use the square of the Euclidean distance
# between our ``prediction`` and the ``ideal_output``, and we’ll use a
# basic stochastic gradient descent optimizer.
#

optimizer = torch.optim.SGD(model.parameters(), lr=0.001)

prediction = model(some_input)

loss = (ideal_output - prediction).pow(2).sum()
print(loss)


######################################################################
# Now, let’s call ``loss.backward()`` and see what happens:
#

loss.backward()
print(model.layer2.weight[0][0:10])
print(model.layer2.weight.grad[0][0:10])


########################################################################
# We can see that the gradients have been computed for each learning
# weight, but the weights remain unchanged, because we haven’t run the
# optimizer yet. The optimizer is responsible for updating model weights
# based on the computed gradients.
#

optimizer.step()
print(model.layer2.weight[0][0:10])
print(model.layer2.weight.grad[0][0:10])


######################################################################
# You should see that ``layer2``\ ’s weights have changed.
#
# One important thing about the process: After calling
# ``optimizer.step()``, you need to call ``optimizer.zero_grad()``, or
# else every time you run ``loss.backward()``, the gradients on the
# learning weights will accumulate:
#

print(model.layer2.weight.grad[0][0:10])

for i in range(0, 5):
    prediction = model(some_input)
    loss = (ideal_output - prediction).pow(2).sum()
    loss.backward()

print(model.layer2.weight.grad[0][0:10])

optimizer.zero_grad(set_to_none=False)

print(model.layer2.weight.grad[0][0:10])


#########################################################################
# After running the cell above, you should see that after running
# ``loss.backward()`` multiple times, the magnitudes of most of the
# gradients will be much larger. Failing to zero the gradients before
# running your next training batch will cause the gradients to blow up in
# this manner, causing incorrect and unpredictable learning results.
#
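# A typical training loop therefore zeroes the gradients at the start of
# each iteration. A minimal sketch (an addition, reusing the model,
# optimizer, and data defined above):
#

for i in range(5):
    optimizer.zero_grad()                                # clear stale gradients first
    prediction = model(some_input)
    loss = (ideal_output - prediction).pow(2).sum()
    loss.backward()                                      # gradients from this batch only
    optimizer.step()                                     # apply the update


##########################################################################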
# Turning Autograd Off and On
# ---------------------------
#
# There are situations where you will need fine-grained control over
# whether autograd is enabled. There are multiple ways to do this,
# depending on the situation.
#
# The simplest is to change the ``requires_grad`` flag on a tensor
# directly:
#

a = torch.ones(2, 3, requires_grad=True)
print(a)

b1 = 2 * a
print(b1)

a.requires_grad = False
b2 = 2 * a
print(b2)


##########################################################################
# In the cell above, we see that ``b1`` has a ``grad_fn`` (i.e., a traced
# computation history), which is what we expect, since it was derived from
# a tensor, ``a``, that had autograd turned on. When we turn off autograd
# explicitly with ``a.requires_grad = False``, computation history is no
# longer tracked, as we see when we compute ``b2``.
#
# If you only need autograd turned off temporarily, a better way is to use
# ``torch.no_grad()``:
#

a = torch.ones(2, 3, requires_grad=True) * 2
b = torch.ones(2, 3, requires_grad=True) * 3

c1 = a + b
print(c1)

with torch.no_grad():
    c2 = a + b

print(c2)

c3 = a * b
print(c3)


##########################################################################
# ``torch.no_grad()`` can also be used as a function or method decorator:
#

def add_tensors1(x, y):
    return x + y

@torch.no_grad()
def add_tensors2(x, y):
    return x + y


a = torch.ones(2, 3, requires_grad=True) * 2
b = torch.ones(2, 3, requires_grad=True) * 3

c1 = add_tensors1(a, b)
print(c1)

c2 = add_tensors2(a, b)
print(c2)


##########################################################################
# There’s a corresponding context manager, ``torch.enable_grad()``, for
# turning autograd on when it isn’t already. It may also be used as a
# decorator.
#
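# As a minimal added sketch, ``torch.enable_grad()`` overrides an
# enclosing ``no_grad()`` block:
#

with torch.no_grad():
    with torch.enable_grad():
        c4 = a + b
print(c4)   # has a grad_fn, because enable_grad was innermost


##########################################################################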
# Finally, you may have a tensor that requires gradient tracking, but you
# want a copy that does not. For this we have the ``Tensor`` object’s
# ``detach()`` method - it creates a copy of the tensor that is *detached*
# from the computation history:
#

x = torch.rand(5, requires_grad=True)
y = x.detach()

print(x)
print(y)


#########################################################################
# We did this above when we wanted to graph some of our tensors. This is
# because ``matplotlib`` expects a NumPy array as input, and the implicit
# conversion from a PyTorch tensor to a NumPy array is not enabled for
# tensors with ``requires_grad=True``. Making a detached copy lets us move
# forward.
#
# Autograd and In-place Operations
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
#
# In every example in this notebook so far, we’ve used variables to
# capture the intermediate values of a computation. Autograd needs these
# intermediate values to perform gradient computations. *For this reason,
# you must be careful about using in-place operations when using
# autograd.* Doing so can destroy information you need to compute
# derivatives in the ``backward()`` call. PyTorch will even stop you if
# you attempt an in-place operation on a leaf variable that requires
# autograd, as shown below.
#
# .. note::
#    The following code cell throws a runtime error. This is expected.
#
#    .. code-block:: python
#
#       a = torch.linspace(0., 2. * math.pi, steps=25, requires_grad=True)
#       torch.sin_(a)
#

#########################################################################
# Autograd Profiler
# -----------------
#
# Autograd tracks every step of your computation in detail. Such a
# computation history, combined with timing information, would make a
# handy profiler - and autograd has that feature baked in. Here’s a quick
# example usage:
#

device = torch.device('cpu')
run_on_gpu = False
if torch.cuda.is_available():
    device = torch.device('cuda')
    run_on_gpu = True

x = torch.randn(2, 3, requires_grad=True)
y = torch.rand(2, 3, requires_grad=True)
z = torch.ones(2, 3, requires_grad=True)

with torch.autograd.profiler.profile(use_cuda=run_on_gpu) as prf:
    for _ in range(1000):
        z = (z / x) * y

print(prf.key_averages().table(sort_by='self_cpu_time_total'))


##########################################################################
# The profiler can also label individual sub-blocks of code, break out the
# data by input tensor shape, and export data as a Chrome tracing tools
# file. For full details of the API, see the
# `documentation <https://pytorch.org/docs/stable/autograd.html#profiler>`__.
#
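# For instance, ``record_function()`` attaches a label to a sub-block (a
# brief added sketch; the labeled row appears in the profile table):
#

with torch.autograd.profiler.profile() as prf:
    with torch.autograd.profiler.record_function('doubling'):
        z = z * 2

print(prf.key_averages().table(sort_by='self_cpu_time_total'))


##########################################################################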
# Advanced Topic: More Autograd Detail and the High-Level API
# -----------------------------------------------------------
#
# If you have a function with an n-dimensional input and m-dimensional
# output, :math:`\vec{y}=f(\vec{x})`, the complete gradient is a matrix of
# the derivative of every output with respect to every input, called the
# *Jacobian:*
#
# .. math::
#
#     J
#     =
#     \left(\begin{array}{ccc}
#     \frac{\partial y_{1}}{\partial x_{1}} & \cdots & \frac{\partial y_{1}}{\partial x_{n}}\\
#     \vdots & \ddots & \vdots\\
#     \frac{\partial y_{m}}{\partial x_{1}} & \cdots & \frac{\partial y_{m}}{\partial x_{n}}
#     \end{array}\right)
#
# If you have a second function, :math:`l=g\left(\vec{y}\right)`, that
# takes m-dimensional input (that is, the same dimensionality as the
# output above), and returns a scalar output, you can express its
# gradients with respect to :math:`\vec{y}` as a column vector,
# :math:`v=\left(\begin{array}{ccc}\frac{\partial l}{\partial y_{1}} & \cdots & \frac{\partial l}{\partial y_{m}}\end{array}\right)^{T}`
# - which is really just a one-column Jacobian.
#
# More concretely, imagine the first function as your PyTorch model (with
# potentially many inputs and many outputs) and the second function as a
# loss function (with the model’s output as input, and the loss value as
# the scalar output).
#
# If we multiply the first function’s Jacobian by the gradient of the
# second function, and apply the chain rule, we get:
#
# .. math::
#
#     J^{T}\cdot v=\left(\begin{array}{ccc}
#     \frac{\partial y_{1}}{\partial x_{1}} & \cdots & \frac{\partial y_{m}}{\partial x_{1}}\\
#     \vdots & \ddots & \vdots\\
#     \frac{\partial y_{1}}{\partial x_{n}} & \cdots & \frac{\partial y_{m}}{\partial x_{n}}
#     \end{array}\right)\left(\begin{array}{c}
#     \frac{\partial l}{\partial y_{1}}\\
#     \vdots\\
#     \frac{\partial l}{\partial y_{m}}
#     \end{array}\right)=\left(\begin{array}{c}
#     \frac{\partial l}{\partial x_{1}}\\
#     \vdots\\
#     \frac{\partial l}{\partial x_{n}}
#     \end{array}\right)
#
# Note: You could also use the equivalent operation :math:`v^{T}\cdot J`,
# and get back a row vector.
#
# The resulting column vector is the *gradient of the second function with
# respect to the inputs of the first* - or in the case of our model and
# loss function, the gradient of the loss with respect to the model
# inputs.
#
# **``torch.autograd`` is an engine for computing these products.** This
# is how we accumulate the gradients over the learning weights during the
# backward pass.
#
# For this reason, the ``backward()`` call can *also* take an optional
# vector input. This vector represents a set of gradients over the tensor,
# which are multiplied by the Jacobian of the autograd-traced tensor that
# precedes it. Let’s try a specific example with a small vector:
#

x = torch.randn(3, requires_grad=True)

y = x * 2
while y.data.norm() < 1000:
    y = y * 2

print(y)


##########################################################################
# If we tried to call ``y.backward()`` now, we’d get a runtime error and a
# message that gradients can only be *implicitly* computed for scalar
# outputs. For a multi-dimensional output, autograd expects us to provide
# gradients for those three outputs that it can multiply into the
# Jacobian:
#

v = torch.tensor([0.1, 1.0, 0.0001], dtype=torch.float) # stand-in for gradients
y.backward(v)

print(x.grad)


##########################################################################
# (Note that the output gradients are all related to powers of two - which
# we’d expect from a repeated doubling operation.)
#
# The High-Level API
# ~~~~~~~~~~~~~~~~~~
#
# There is an API on autograd that gives you direct access to important
# differential matrix and vector operations. In particular, it allows you
# to calculate the Jacobian and the *Hessian* matrices of a particular
# function for particular inputs. (The Hessian is like the Jacobian, but
# expresses all partial *second* derivatives.) It also provides methods
# for taking vector products with these matrices.
#
# Let’s take the Jacobian of a simple function, evaluated for two
# single-element inputs:
#

def exp_adder(x, y):
    return 2 * x.exp() + 3 * y

inputs = (torch.rand(1), torch.rand(1)) # arguments for the function
print(inputs)
torch.autograd.functional.jacobian(exp_adder, inputs)


########################################################################
# If you look closely, the first output should equal :math:`2e^x` (since
# the derivative of :math:`e^x` is :math:`e^x`), and the second value
# should be 3.
#
# You can, of course, do this with higher-order tensors:
#

inputs = (torch.rand(3), torch.rand(3)) # arguments for the function
print(inputs)
torch.autograd.functional.jacobian(exp_adder, inputs)


#########################################################################
# The ``torch.autograd.functional.hessian()`` method works identically
# (assuming your function is twice differentiable), but returns a matrix
# of all second derivatives.
#
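# For instance (an added example; ``hessian()`` expects a function with a
# single-element output, which ``exp_adder`` satisfies for single-element
# inputs):
#

inputs = (torch.rand(1), torch.rand(1))
torch.autograd.functional.hessian(exp_adder, inputs)


#########################################################################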
# There is also a function to directly compute the vector-Jacobian
# product, if you provide the vector:
#

def do_some_doubling(x):
    y = x * 2
    while y.data.norm() < 1000:
        y = y * 2
    return y

inputs = torch.randn(3)
my_gradients = torch.tensor([0.1, 1.0, 0.0001])
torch.autograd.functional.vjp(do_some_doubling, inputs, v=my_gradients)


##############################################################################
# The ``torch.autograd.functional.jvp()`` method performs the same matrix
# multiplication as ``vjp()`` with the operands reversed. The ``vhp()``
# and ``hvp()`` methods do the same for a vector-Hessian product.
#
# For more information, including performance notes on the `docs for the
653
# functional
654
# API <https://pytorch.org/docs/stable/autograd.html#functional-higher-level-api>`__
655
#
656
657