Path: blob/main/intermediate_source/autograd_saved_tensors_hooks_tutorial.py
"""1Hooks for autograd saved tensors2================================34"""567######################################################################8# PyTorch typically computes gradients using backpropagation. However,9# certain operations require intermediary results to be saved in order to10# perform backpropagation. This tutorial walks through how these tensors11# are saved/retrieved and how you can define hooks to control the12# packing/unpacking process.13#14# This tutorial assumes you are familiar with how backpropagation works in15# theory. If not, read `this <https://colab.research.google.com/drive/1aWNdmYt7RcHMbUk-Xz2Cv5-cGFSWPXe0#scrollTo=AHcEJ6nXUb7W>`_ first.16#171819######################################################################20# Saved tensors21# -------------22#232425######################################################################26# Training a model usually consumes more memory than running it for27# inference. Broadly speaking, one can say that it is because “PyTorch28# needs to save the computation graph, which is needed to call29# ``backward``”, hence the additional memory usage. One goal of this30# tutorial is to finetune this understanding.31#32# In fact, the graph in itself sometimes does not consume much more memory33# as it never copies any tensors. However, the graph can keep *references*34# to tensors that would otherwise have gone out of scope: those are35# referred to as **saved tensors**.36#373839######################################################################40# Why does training a model (typically) requires more memory than evaluating it?41# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~42#434445######################################################################46# We start with a simple example: :math:`y = a \cdot b` , for which47# we know the gradients of :math:`y` with respect to :math:`a` and48# :math:`b`:49#50# .. math:: \frac{\partial y}{\partial a} = b51#52# .. math:: \frac{\partial y}{\partial b} = a53#5455import torch5657a = torch.randn(5, requires_grad=True)58b = torch.ones(5, requires_grad=True)59y = a * b6061#################################################################62# Using a torchviz, we can visualize the computation graph63#64# .. figure:: https://user-images.githubusercontent.com/8019486/130124513-72e016a3-c36f-42b9-88e2-53baf3e016c5.png65# :width: 30066# :align: center676869######################################################################70# In this example, PyTorch saves intermediary values :math:`a` and71# :math:`b` in order to compute the gradient during the backward.72#73# .. figure:: https://user-images.githubusercontent.com/8019486/130124538-3da50977-6f0b-46d0-8909-5456ade9b598.png74# :width: 30075# :align: center767778######################################################################79# Those intermediary values (in orange above) can be accessed (for80# debugging purposes) by looking for attributes of the ``grad_fn`` of81# ``y`` which start with the prefix ``_saved``:82#8384print(y.grad_fn._saved_self)85print(y.grad_fn._saved_other)868788######################################################################89# As the computation graph grows in depth, it will store more *saved90# tensors*. Meanwhile, those tensors would have gone out of scope if not91# for the graph.92#9394def f(x):95return x * x9697x = torch.randn(5, requires_grad=True)98y = f(f(f(x)))99100######################################################################101# .. 
######################################################################
# The concept of packing / unpacking
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
#


######################################################################
# Going back to the first example: ``y.grad_fn._saved_self`` and
# ``y.grad_fn._saved_other`` point to the original tensor objects,
# respectively ``a`` and ``b``.
#

a = torch.randn(5, requires_grad=True)
b = torch.ones(5, requires_grad=True)
y = a * b

print(y.grad_fn._saved_self is a)   # True
print(y.grad_fn._saved_other is b)  # True


######################################################################
# However, that may not always be the case.
#

a = torch.randn(5, requires_grad=True)
y = torch.exp(a)
print(y.grad_fn._saved_result.equal(y))  # True
print(y.grad_fn._saved_result is y)      # False


######################################################################
# Under the hood, PyTorch has **packed** and **unpacked** the tensor
# ``y`` to prevent reference cycles.
#
# As a rule of thumb, you should *not* rely on the fact that accessing
# the tensor saved for backward will yield the same tensor object as the
# original tensor. They will, however, share the same *storage*.
#
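######################################################################
# We can check the shared-storage claim directly: the unpacked tensor is
# a different Python object, but it points at the same underlying memory
# as ``y``.
#

print(y.grad_fn._saved_result.data_ptr() == y.data_ptr())  # True: same storage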
######################################################################
# Saved tensors hooks
# -------------------
#


######################################################################
# PyTorch provides an API to control how saved tensors should be packed /
# unpacked.
#

def pack_hook(x):
    print("Packing", x)
    return x

def unpack_hook(x):
    print("Unpacking", x)
    return x

a = torch.ones(5, requires_grad=True)
b = torch.ones(5, requires_grad=True) * 2

with torch.autograd.graph.saved_tensors_hooks(pack_hook, unpack_hook):
    y = a * b

y.sum().backward()


######################################################################
# The ``pack_hook`` function will be called every time an operation saves
# a tensor for backward.
# The output of ``pack_hook`` is then stored in the computation graph
# instead of the original tensor.
# The ``unpack_hook`` uses that return value to compute a new tensor,
# which is the one actually used during the backward pass.
# In general, you want ``unpack_hook(pack_hook(t))`` to be equal to
# ``t``.
#

x = torch.randn(5, requires_grad=True)
with torch.autograd.graph.saved_tensors_hooks(lambda x: x * 4, lambda x: x / 4):
    y = torch.pow(x, 2)
y.sum().backward()
assert(x.grad.equal(2 * x))


######################################################################
# One thing to note is that the output of ``pack_hook`` can be *any Python
# object*, as long as ``unpack_hook`` can derive a tensor with the correct
# value from it.
#
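######################################################################
# For instance, this flexibility can be used to compress saved tensors.
# The sketch below (our addition, not one of the original examples)
# stores activations in half precision, halving their memory footprint.
# Note that ``unpack_hook(pack_hook(t))`` is then only *approximately*
# equal to ``t``, so the resulting gradients are approximate as well.
#

def pack_fp16(x):
    # Remember the original dtype and store the tensor in half precision.
    return (x.dtype, x.to(torch.float16))

def unpack_fp16(packed):
    # Convert back to the original dtype (lossy round-trip).
    dtype, tensor = packed
    return tensor.to(dtype)

x = torch.randn(5, requires_grad=True)
with torch.autograd.graph.saved_tensors_hooks(pack_fp16, unpack_fp16):
    y = x * x
y.sum().backward()
assert torch.allclose(x.grad, 2 * x, atol=1e-2)  # approximate due to float16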
######################################################################
# Some unconventional examples
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~
#


######################################################################
# First, some silly examples to illustrate what is possible, even though
# you probably never want to do it.
#

######################################################################
# Returning an ``int``
# ^^^^^^^^^^^^^^^^^^^^
#
# Returning the index into a Python list
# Relatively harmless but with debatable usefulness

storage = []

def pack(x):
    storage.append(x)
    return len(storage) - 1

def unpack(x):
    return storage[x]

x = torch.randn(5, requires_grad=True)
with torch.autograd.graph.saved_tensors_hooks(pack, unpack):
    y = x * x
y.sum().backward()

assert(x.grad.equal(2 * x))

######################################################################
# Returning a tuple
# ^^^^^^^^^^^^^^^^^
#
# Returning some tensor and a function describing how to unpack it
# Quite unlikely to be useful in its current form

def pack(x):
    delta = torch.randn(*x.size())
    return x - delta, lambda x: x + delta

def unpack(packed):
    x, f = packed
    return f(x)


x = torch.randn(5, requires_grad=True)
with torch.autograd.graph.saved_tensors_hooks(pack, unpack):
    y = x * x
y.sum().backward()

assert(torch.allclose(x.grad, 2 * x))

######################################################################
# Returning a ``str``
# ^^^^^^^^^^^^^^^^^^^
#
# Returning the ``__repr__`` of the tensor
# Probably never do this

x = torch.randn(5, requires_grad=True)
with torch.autograd.graph.saved_tensors_hooks(lambda x: repr(x), lambda x: eval("torch." + x)):
    y = x * x
y.sum().backward()
assert(torch.all(x.grad - 2 * x <= 1e-4))


######################################################################
# Although those examples will not be useful in practice, they
# illustrate that the output of ``pack_hook`` can really be any Python
# object, as long as it contains enough information to retrieve the
# content of the original tensor.
# In the next sections, we focus on more useful applications.
#


######################################################################
# Saving tensors to CPU
# ~~~~~~~~~~~~~~~~~~~~~
#


######################################################################
# Very often, the tensors involved in the computation graph live on GPU.
# Keeping a reference to those tensors in the graph is what causes most
# models to run out of GPU memory during training, while they would have
# done fine during evaluation.
#
# Hooks provide a very simple way to implement that: save the tensors to
# CPU during the forward pass, and move them back to the original device
# when they are needed for backward.
#

def pack_hook(x):
    return (x.device, x.cpu())

def unpack_hook(packed):
    device, tensor = packed
    return tensor.to(device)

x = torch.randn(5, requires_grad=True)
with torch.autograd.graph.saved_tensors_hooks(pack_hook, unpack_hook):
    y = x * x
y.sum().backward()

assert torch.allclose(x.grad, 2 * x)


######################################################################
# In fact, PyTorch provides an API to conveniently use those hooks (as
# well as the ability to use pinned memory).
#

import torch.nn as nn

class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.w = nn.Parameter(torch.randn(5))

    def forward(self, x):
        with torch.autograd.graph.save_on_cpu(pin_memory=True):
            # some computation
            return self.w * x

x = torch.randn(5)
model = Model()
loss = model(x).sum()
loss.backward()


######################################################################
# In practice, on an A100 GPU, for a ResNet-152 with batch size 256, this
# corresponds to a GPU memory usage reduction from 48GB to 5GB, at the
# cost of a 6x slowdown.
#
# Of course, you can modulate the tradeoff by saving to CPU only certain
# parts of the network.
#
# For instance, you could define a special ``nn.Module`` that wraps any
# module and saves its tensors to CPU.
#

class SaveToCpu(nn.Module):
    def __init__(self, module):
        super().__init__()
        self.module = module

    def forward(self, *args, **kwargs):
        with torch.autograd.graph.save_on_cpu(pin_memory=True):
            return self.module(*args, **kwargs)

model = nn.Sequential(
    nn.Linear(10, 100),
    SaveToCpu(nn.Linear(100, 100)),
    nn.Linear(100, 10),
)

x = torch.randn(10)
loss = model(x).sum()
loss.backward()
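######################################################################
# To see the effect on a real device, you can compare peak GPU memory
# with and without ``save_on_cpu``. The snippet below is a rough sketch
# of our own (the stack of linear layers and the sizes are made up for
# illustration, and it only runs when CUDA is available); exact numbers
# depend on your hardware and model.
#

if torch.cuda.is_available():
    device = torch.device("cuda")
    big_model = nn.Sequential(*[nn.Linear(1024, 1024) for _ in range(10)]).to(device)
    inp = torch.randn(256, 1024, device=device)

    # Baseline: saved activations stay on the GPU.
    torch.cuda.reset_peak_memory_stats(device)
    big_model(inp).sum().backward()
    print("Peak memory (default):    ", torch.cuda.max_memory_allocated(device))

    # Offloaded: saved activations are moved to (pinned) CPU memory.
    torch.cuda.reset_peak_memory_stats(device)
    with torch.autograd.graph.save_on_cpu(pin_memory=True):
        loss = big_model(inp).sum()
    loss.backward()
    print("Peak memory (save_on_cpu):", torch.cuda.max_memory_allocated(device))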
Again, this is380# achievable with those hooks.381#382383384######################################################################385# A naive version would look like this.386#387388# Naive version - HINT: Don't do this389390import uuid391tmp_dir = "temp"392393def pack_hook(tensor):394name = os.path.join(tmp_dir, str(uuid.uuid4()))395torch.save(tensor, name)396return name397398def unpack_hook(name):399return torch.load(name, weights_only=True)400401402######################################################################403# The reason the above code is bad is that we are leaking files on the404# disk and they are never cleared. Fixing this is not as trivial as it405# seems.406#407408# Incorrect version - HINT: Don't do this409410import uuid411import os412import tempfile413tmp_dir_obj = tempfile.TemporaryDirectory()414tmp_dir = tmp_dir_obj.name415416def pack_hook(tensor):417name = os.path.join(tmp_dir, str(uuid.uuid4()))418torch.save(tensor, name)419return name420421def unpack_hook(name):422tensor = torch.load(name, weights_only=True)423os.remove(name)424return tensor425426427######################################################################428# The reason the above code doesn’t work is that ``unpack_hook`` can be429# called multiple times. If we delete the file during unpacking the first430# time, it will not be available when the saved tensor is accessed a431# second time, which will raise an error.432#433434x = torch.ones(5, requires_grad=True)435with torch.autograd.graph.saved_tensors_hooks(pack_hook, unpack_hook):436y = x.pow(2)437print(y.grad_fn._saved_self)438try:439print(y.grad_fn._saved_self)440print("Double access succeeded!")441except:442print("Double access failed!")443444445######################################################################446# To fix this, we can write a version of those hooks that takes advantage447# of the fact that PyTorch automatically releases (deletes) the saved data448# when it is no longer needed.449#450451class SelfDeletingTempFile():452def __init__(self):453self.name = os.path.join(tmp_dir, str(uuid.uuid4()))454455def __del__(self):456os.remove(self.name)457458def pack_hook(tensor):459temp_file = SelfDeletingTempFile()460torch.save(tensor, temp_file.name)461return temp_file462463def unpack_hook(temp_file):464return torch.load(temp_file.name, weights_only=True)465466467######################################################################468# When we call ``backward``, the output of ``pack_hook`` will be deleted,469# which causes the file to be removed, so we’re no longer leaking the470# files.471#472# This can then be used in your model, in the following way:473#474475# Only save on disk tensors that have size >= 1000476SAVE_ON_DISK_THRESHOLD = 1000477478def pack_hook(x):479if x.numel() < SAVE_ON_DISK_THRESHOLD:480return x481temp_file = SelfDeletingTempFile()482torch.save(tensor, temp_file.name)483return temp_file484485def unpack_hook(tensor_or_sctf):486if isinstance(tensor_or_sctf, torch.Tensor):487return tensor_or_sctf488return torch.load(tensor_or_sctf.name)489490class SaveToDisk(nn.Module):491def __init__(self, module):492super().__init__()493self.module = module494495def forward(self, *args, **kwargs):496with torch.autograd.graph.saved_tensors_hooks(pack_hook, unpack_hook):497return self.module(*args, **kwargs)498499net = nn.DataParallel(SaveToDisk(Model()))500501502######################################################################503# In this last example, we also demonstrate how to filter which tensors504# should be saved (here, 
######################################################################
# If you’ve made it this far, congratulations! You now know how to use
# saved tensor hooks and how they can be useful in a few scenarios to
# trade off memory for compute.
#