"""
Hooks for autograd saved tensors
================================

"""


######################################################################
# PyTorch typically computes gradients using backpropagation. However,
# certain operations require intermediary results to be saved in order to
# perform backpropagation. This tutorial walks through how these tensors
# are saved/retrieved and how you can define hooks to control the
# packing/unpacking process.
#
# This tutorial assumes you are familiar with how backpropagation works in
# theory. If not, read `this <https://colab.research.google.com/drive/1aWNdmYt7RcHMbUk-Xz2Cv5-cGFSWPXe0#scrollTo=AHcEJ6nXUb7W>`_ first.
#


######################################################################
# Saved tensors
# -------------
#


######################################################################
# Training a model usually consumes more memory than running it for
# inference. Broadly speaking, one can say that it is because “PyTorch
# needs to save the computation graph, which is needed to call
# ``backward``”, hence the additional memory usage. One goal of this
# tutorial is to refine this understanding.
#
# In fact, the graph in itself sometimes does not consume much more memory
# as it never copies any tensors. However, the graph can keep *references*
# to tensors that would otherwise have gone out of scope: those are
# referred to as **saved tensors**.
#


######################################################################
# Why does training a model (typically) require more memory than evaluating it?
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
#


######################################################################
# We start with a simple example: :math:`y = a \cdot b`, for which
# we know the gradients of :math:`y` with respect to :math:`a` and
# :math:`b`:
#
# .. math:: \frac{\partial y}{\partial a} = b
#
# .. math:: \frac{\partial y}{\partial b} = a
#

import torch

a = torch.randn(5, requires_grad=True)
b = torch.ones(5, requires_grad=True)
y = a * b

#################################################################
# Using torchviz, we can visualize the computation graph
#
# .. figure:: https://user-images.githubusercontent.com/8019486/130124513-72e016a3-c36f-42b9-88e2-53baf3e016c5.png
#    :width: 300
#    :align: center

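######################################################################
# The graph visualizations in this tutorial can be reproduced with the
# third-party ``torchviz`` package. A minimal sketch, assuming ``torchviz``
# is installed (the snippet is skipped otherwise):
#

try:
    from torchviz import make_dot  # optional third-party dependency
    dot = make_dot(y, params={"a": a, "b": b})  # graphviz Digraph of y's autograd graph
except ImportError:
    dot = None  # torchviz is not installed; skip the visualization
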
######################################################################
# In this example, PyTorch saves intermediary values :math:`a` and
# :math:`b` in order to compute the gradient during the backward.
#
# .. figure:: https://user-images.githubusercontent.com/8019486/130124538-3da50977-6f0b-46d0-8909-5456ade9b598.png
#    :width: 300
#    :align: center


######################################################################
# Those intermediary values (in orange above) can be accessed (for
# debugging purposes) by looking for attributes of the ``grad_fn`` of
# ``y`` which start with the prefix ``_saved``:
#

print(y.grad_fn._saved_self)
print(y.grad_fn._saved_other)


######################################################################
# As the computation graph grows in depth, it will store more *saved
# tensors*. Meanwhile, those tensors would have gone out of scope if not
# for the graph.
#

def f(x):
    return x * x

x = torch.randn(5, requires_grad=True)
y = f(f(f(x)))

######################################################################
# .. figure:: https://user-images.githubusercontent.com/8019486/130124570-f1074098-1bb3-459e-bf5a-03bf6f65b403.png
#    :width: 500
#    :align: center


######################################################################
# In the example above, executing without grad would only have kept ``x``
# and ``y`` in scope, but the graph additionally stores ``f(x)`` and
# ``f(f(x))``. Hence, running a forward pass during training will be more
# costly in memory usage than during evaluation (more precisely, when
# autograd is not required).
#

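######################################################################
# A quick sketch to illustrate this: if we run the same computation with
# autograd disabled, no graph is built, and hence nothing is saved.
#

with torch.no_grad():
    y_eval = f(f(f(x)))
print(y_eval.grad_fn)  # None: no graph was recorded, so no saved tensors
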
######################################################################
# The concept of packing / unpacking
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
#


######################################################################
# Going back to the first example: ``y.grad_fn._saved_self`` and
# ``y.grad_fn._saved_other`` point to the original tensor objects,
# respectively ``a`` and ``b``.
#

a = torch.randn(5, requires_grad=True)
b = torch.ones(5, requires_grad=True)
y = a * b

print(y.grad_fn._saved_self is a)   # True
print(y.grad_fn._saved_other is b)  # True


######################################################################
# However, that may not always be the case.
#

a = torch.randn(5, requires_grad=True)
y = torch.exp(a)
print(y.grad_fn._saved_result.equal(y))  # True
print(y.grad_fn._saved_result is y)      # False


######################################################################
# Under the hood, PyTorch has **packed** and **unpacked** the tensor
# ``y`` to prevent reference cycles.
#
# As a rule of thumb, you should *not* rely on the fact that accessing
# the tensor saved for backward will yield the same tensor object as the
# original tensor. They will, however, share the same *storage*.
#

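######################################################################
# A quick sketch of how to check this: since the saved tensor and the
# output ``y`` share the same underlying storage, their data pointers match.
#

print(y.grad_fn._saved_result.data_ptr() == y.data_ptr())  # True: same storage
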
######################################################################
# Saved tensors hooks
# -------------------
#


######################################################################
# PyTorch provides an API to control how saved tensors should be packed /
# unpacked.
#

def pack_hook(x):
    print("Packing", x)
    return x

def unpack_hook(x):
    print("Unpacking", x)
    return x

a = torch.ones(5, requires_grad=True)
b = torch.ones(5, requires_grad=True) * 2

with torch.autograd.graph.saved_tensors_hooks(pack_hook, unpack_hook):
    y = a * b

y.sum().backward()


######################################################################
# The ``pack_hook`` function will be called every time an operation saves
# a tensor for backward.
# The output of ``pack_hook`` is then stored in the computation graph
# instead of the original tensor.
# The ``unpack_hook`` uses that return value to compute a new tensor,
# which is the one actually used during the backward pass.
# In general, you want ``unpack_hook(pack_hook(t))`` to be equal to
# ``t``.
#

x = torch.randn(5, requires_grad=True)
with torch.autograd.graph.saved_tensors_hooks(lambda x: x * 4, lambda x: x / 4):
    y = torch.pow(x, 2)
y.sum().backward()
assert(x.grad.equal(2 * x))


######################################################################
# One thing to note is that the output of ``pack_hook`` can be *any Python
# object*, as long as ``unpack_hook`` can derive a tensor with the correct
# value from it.
#

######################################################################
# Some unconventional examples
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~
#


######################################################################
# First, some silly examples to illustrate what is possible, even though
# you probably won't ever want to do it.
#

######################################################################
# Returning an ``int``
# ^^^^^^^^^^^^^^^^^^^^
#
# Returning the index into a Python list.
# Relatively harmless but of debatable usefulness.

storage = []

def pack(x):
    storage.append(x)
    return len(storage) - 1

def unpack(x):
    return storage[x]

x = torch.randn(5, requires_grad=True)
with torch.autograd.graph.saved_tensors_hooks(pack, unpack):
    y = x * x
y.sum().backward()

assert(x.grad.equal(2 * x))

######################################################################
# Returning a tuple
# ^^^^^^^^^^^^^^^^^
#
# Returning a tensor together with a function describing how to unpack it.
# Quite unlikely to be useful in its current form.

def pack(x):
    delta = torch.randn(*x.size())
    return x - delta, lambda x: x + delta

def unpack(packed):
    x, f = packed
    return f(x)


x = torch.randn(5, requires_grad=True)
with torch.autograd.graph.saved_tensors_hooks(pack, unpack):
    y = x * x
y.sum().backward()

assert(torch.allclose(x.grad, 2 * x))

######################################################################
# Returning a ``str``
# ^^^^^^^^^^^^^^^^^^^
#
# Returning the ``__repr__`` of the tensor.
# You should probably never do this.

x = torch.randn(5, requires_grad=True)
with torch.autograd.graph.saved_tensors_hooks(lambda x: repr(x), lambda x: eval("torch." + x)):
    y = x * x
y.sum().backward()
assert(torch.all(torch.abs(x.grad - 2 * x) <= 1e-4))

######################################################################
# Although those examples will not be useful in practice, they
# illustrate that the output of ``pack_hook`` can really be any Python
# object as long as it contains enough information to retrieve the
# content of the original tensor.
# In the next sections, we focus on more useful applications.
#


######################################################################
# Saving tensors to CPU
# ~~~~~~~~~~~~~~~~~~~~~
#


######################################################################
# Very often, the tensors involved in the computation graph live on GPU.
# Keeping a reference to those tensors in the graph is what causes most
# models to run out of GPU memory during training, while they would have
# done fine during evaluation.
#
# Hooks provide a very simple way to implement that: save the tensors to
# CPU when packing, and move them back to the original device when
# unpacking for the backward pass.
#

def pack_hook(x):
    return (x.device, x.cpu())

def unpack_hook(packed):
    device, tensor = packed
    return tensor.to(device)

x = torch.randn(5, requires_grad=True)
with torch.autograd.graph.saved_tensors_hooks(pack_hook, unpack_hook):
    y = x * x
y.sum().backward()

assert(torch.allclose(x.grad, 2 * x))

######################################################################
# In fact, PyTorch provides an API to conveniently use those hooks (as
# well as the ability to use pinned memory).
#

import torch.nn as nn

class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.w = nn.Parameter(torch.randn(5))

    def forward(self, x):
        with torch.autograd.graph.save_on_cpu(pin_memory=True):
            # some computation
            return self.w * x

x = torch.randn(5)
model = Model()
loss = model(x).sum()
loss.backward()

######################################################################
# In practice, on an A100 GPU, for a ResNet-152 with batch size 256, this
# corresponds to a GPU memory usage reduction from 48GB to 5GB, at the
# cost of a 6x slowdown.
#
# Of course, you can modulate the tradeoff by saving only certain parts
# of the network to CPU.
#
# For instance, you could define a special ``nn.Module`` that wraps any
# module and saves its tensors to CPU.
#

class SaveToCpu(nn.Module):
    def __init__(self, module):
        super().__init__()
        self.module = module

    def forward(self, *args, **kwargs):
        with torch.autograd.graph.save_on_cpu(pin_memory=True):
            return self.module(*args, **kwargs)

model = nn.Sequential(
    nn.Linear(10, 100),
    SaveToCpu(nn.Linear(100, 100)),
    nn.Linear(100, 10),
)

x = torch.randn(10)
loss = model(x).sum()
loss.backward()

######################################################################
# Saving tensors to disk
# ~~~~~~~~~~~~~~~~~~~~~~
#


######################################################################
# Similarly, you may want to save those tensors to disk. Again, this is
# achievable with those hooks.
#


######################################################################
# A naive version would look like this.
#

# Naive version - HINT: Don't do this

import os
import uuid

tmp_dir = "temp"

def pack_hook(tensor):
    name = os.path.join(tmp_dir, str(uuid.uuid4()))
    torch.save(tensor, name)
    return name

def unpack_hook(name):
    return torch.load(name, weights_only=True)

######################################################################
# The reason the above code is bad is that we are leaking files on the
# disk and they are never cleared. Fixing this is not as trivial as it
# seems.
#

# Incorrect version - HINT: Don't do this

import uuid
import os
import tempfile

tmp_dir_obj = tempfile.TemporaryDirectory()
tmp_dir = tmp_dir_obj.name

def pack_hook(tensor):
    name = os.path.join(tmp_dir, str(uuid.uuid4()))
    torch.save(tensor, name)
    return name

def unpack_hook(name):
    tensor = torch.load(name, weights_only=True)
    os.remove(name)
    return tensor

######################################################################
# The reason the above code doesn’t work is that ``unpack_hook`` can be
# called multiple times. If we delete the file during unpacking the first
# time, it will not be available when the saved tensor is accessed a
# second time, which will raise an error.
#

x = torch.ones(5, requires_grad=True)
with torch.autograd.graph.saved_tensors_hooks(pack_hook, unpack_hook):
    y = x.pow(2)
print(y.grad_fn._saved_self)
try:
    print(y.grad_fn._saved_self)
    print("Double access succeeded!")
except Exception:
    print("Double access failed!")

######################################################################
# To fix this, we can write a version of those hooks that takes advantage
# of the fact that PyTorch automatically releases (deletes) the saved data
# when it is no longer needed.
#

class SelfDeletingTempFile():
    def __init__(self):
        self.name = os.path.join(tmp_dir, str(uuid.uuid4()))

    def __del__(self):
        os.remove(self.name)

def pack_hook(tensor):
    temp_file = SelfDeletingTempFile()
    torch.save(tensor, temp_file.name)
    return temp_file

def unpack_hook(temp_file):
    return torch.load(temp_file.name, weights_only=True)

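######################################################################
# A quick sketch to verify the fix: with the self-deleting temp files, the
# saved tensor can now be accessed multiple times without error.
#

x = torch.ones(5, requires_grad=True)
with torch.autograd.graph.saved_tensors_hooks(pack_hook, unpack_hook):
    y = x.pow(2)
print(y.grad_fn._saved_self)
print(y.grad_fn._saved_self)  # second access works: the file is only removed once no longer needed
y.sum().backward()
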
######################################################################
# When we call ``backward``, the output of ``pack_hook`` will be deleted,
# which causes the file to be removed, so we’re no longer leaking the
# files.
#
# This can then be used in your model, in the following way:
#

# Only save on disk tensors that have size >= 1000
SAVE_ON_DISK_THRESHOLD = 1000

def pack_hook(x):
    if x.numel() < SAVE_ON_DISK_THRESHOLD:
        return x
    temp_file = SelfDeletingTempFile()
    torch.save(x, temp_file.name)
    return temp_file

def unpack_hook(tensor_or_sctf):
    if isinstance(tensor_or_sctf, torch.Tensor):
        return tensor_or_sctf
    return torch.load(tensor_or_sctf.name, weights_only=True)

class SaveToDisk(nn.Module):
    def __init__(self, module):
        super().__init__()
        self.module = module

    def forward(self, *args, **kwargs):
        with torch.autograd.graph.saved_tensors_hooks(pack_hook, unpack_hook):
            return self.module(*args, **kwargs)

net = nn.DataParallel(SaveToDisk(Model()))

######################################################################
# In this last example, we also demonstrate how to filter which tensors
# should be saved to disk (here, those whose number of elements is at
# least 1000) and how to combine this feature with ``nn.DataParallel``.
#

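######################################################################
# As a minimal usage sketch, the ``SaveToDisk`` wrapper can also be used
# on its own (without ``nn.DataParallel``); here the saved tensors are
# small, so they stay in memory rather than being written to disk.
#

small_model = SaveToDisk(Model())
out = small_model(torch.randn(5))  # tensors with fewer than 1000 elements are kept in memory
out.sum().backward()
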
######################################################################
# If you’ve made it this far, congratulations! You now know how to use
# saved tensor hooks and how they can be useful in a few scenarios to
# trade off memory for compute.
#