Real-time collaboration for Jupyter Notebooks, Linux Terminals, LaTeX, VS Code, R IDE, and more,
all in one place.
Real-time collaboration for Jupyter Notebooks, Linux Terminals, LaTeX, VS Code, R IDE, and more,
all in one place.
Path: blob/main/recipes_source/recipes/swap_tensors.py
Views: 713
"""1Extension points in ``nn.Module`` for ``load_state_dict`` and tensor subclasses2===============================================================================3**Author:** `Mikayla Gawarecki <https://github.com/mikaylagawarecki>`_45This recipe introduces a new utility function ``torch.utils.swap_tensors``6as well as two new extension points where it has been integrated in7``nn.Module``:89* ``nn.Module.to()`` and related methods10* ``nn.Module.load_state_dict()``1112.. note::13This recipe requires PyTorch 2.3.0 or later.14"""1516###############################################################################17# ``torch.utils.swap_tensors``18# ----------------------------19# ``torch.utils.swap_tensors`` (hereafter referred to as ``swap_tensors``) is a20# utility function that takes in two Python tensors and swaps them.2122import torch23import torch.nn as nn24t1 = torch.arange(2)25t2 = torch.arange(3)26print(f"Before swapping, t1: {t1}, t2: {t2}")27torch.utils.swap_tensors(t1, t2)28print(f"After swapping, t1: {t1}, t2: {t2}")2930################################################################################31# More specifically, ``swap_tensors`` swaps the Python ``__class__``, ``__dict__``32# and ``__slots__`` of the two tensors, as well as their associated ``at::Tensor``.33#34#35# Application to ``nn.Module``36# ----------------------------37# This utility is pertinent to ``nn.Module`` when a Python object outside38# of the module holds a reference to parameters of the module. If an ``nn.Module``39# modifies any of its parameters out of place, the object holding references to40# the parameters will not see the change. A classic example of this is the41# optimizer, which holds a reference to the parameters of the ``nn.Module``.42# This leads to a silent correctness issue where the ``optimizer.step()`` will43# run without error but the weights of the ``nn.Module`` will not be updated.4445mod = torch.nn.Linear(1, 2, bias=False)46optimizer = torch.optim.SGD(mod.parameters())47print(f"weight in mod: {mod.weight}")48print(f"weight in optimizer: {optimizer.param_groups[0]['params']}")49mod.weight = torch.nn.Parameter(2 * mod.weight)50print(f"weight in mod: {mod.weight}")51print(f"weight in optimizer: {optimizer.param_groups[0]['params']}")5253################################################################################54# ``nn.Module.to()`` and related methods55# --------------------------------------56# This includes methods that change the device of the module (such as ``nn.Module.cpu()``),57# methods that change the ``dtype`` of the module (such as ``nn.Module.float()``)58# as well as methods that allow the module to be materialized59# (such as ``nn.Module.to_empty()``).60#61# At first glance, it might be non-intuitive that these methods are able to62# modify the parameters of the module in-place. The existing approach has been63# to use a nasty hack dating back from the first days of PyTorch.64#65# Notably, the existing approach does not work in these cases:66#67# * when using ``__torch_dispatch__`` subclasses68# * when ``param`` and ``new_param`` do not have the same Python ``type()``69# * For tensors with special C++ representations (such as sparse tensors and ``XLA`` tensors)70#71# In the following part of this recipe, we will define a toy ``__torch_dispatch__``72# subclass ``MyQuantizedLinearWeight`` that represents quantized linear weights.73# This subclass will be used for illustration purposes throughout the rest of74# the tutorial. For brevity, we omit most of the ``__torch_dispatch__``75# implementation.76aten = torch.ops.aten7778class MyQuantizedLinearWeight(torch.Tensor):79@staticmethod80def __new__(cls, elem, scale):81return torch.Tensor._make_wrapper_subclass(82cls,83elem.shape,84dtype=elem.dtype,85layout=elem.layout,86device=elem.device,87strides=elem.stride(),88storage_offset=elem.storage_offset())8990def __init__(self, elem: torch.Tensor, scale: float):91self.elem = elem92self.scale = scale9394def __repr__(self):95return f"MyQuantizedLinearWeight({self.elem}, scale={self.scale})"9697@classmethod98def __torch_dispatch__(cls, func, types, args, kwargs):99if func in (aten.detach.default, aten._to_copy.default):100new_elem = func(args[0].elem, *args[1:], **kwargs)101return cls(new_elem, args[0].scale)102# Implementations for certain ops would be added to ``OP_TABLE``.103# We omit this for brevity.104OP_TABLE = dict()105if func in OP_TABLE:106return OP_TABLE[func](func, args, kwargs)107raise NotImplementedError(f"Unsupported function {func}")108109#################################################################################110# Let us create an ``nn.Linear`` layer of ``dtype`` ``torch.float32`` where the weight is111# a ``MyQuantizedLinearWeight`` and try to convert it to ``torch.bfloat16``.112# Observe that the weight's ``dtype`` changes as expected. However, the ``dtype``113# of the subclass' payload (``elem``) does not change.114115m = nn.Linear(3, 5, dtype=torch.float32)116m.weight = torch.nn.Parameter(MyQuantizedLinearWeight(m.weight, 0.5))117print(f"Before: id(m.weight)={id(m.weight)}, id(m.bias)={id(m.bias)}")118m.bfloat16()119print(f"After: id(m.weight)={id(m.weight)}, id(m.bias)={id(m.bias)}")120print(f"m.weight.dtype: {m.weight.dtype}")121print(f"m.weight.elem.dtype: {m.weight.elem.dtype}")122print(f"m.bias.dtype: {m.bias.dtype}")123124################################################################################125# To this end, we introduce a global config126# ``torch.__future__.set_swap_module_params_on_conversion`` that will use127# ``swap_tensors`` to swap the parameters of the module while preserving128# references in place of ``.data`` setting. When this config is set,129# ``swap_tensors`` will be used during the conversion, which ensures that130# the ``dtype`` of the payload is properly converted.131132torch.__future__.set_swap_module_params_on_conversion(True)133m = nn.Linear(3, 5, dtype=torch.float32)134m.weight = torch.nn.Parameter(MyQuantizedLinearWeight(m.weight, 0.5))135print(f"Before: id(m.weight)={id(m.weight)}, id(m.bias)={id(m.bias)}")136m.bfloat16()137print(f"After: id(m.weight)={id(m.weight)}, id(m.bias)={id(m.bias)}")138print(f"m.weight.dtype: {m.weight.dtype}")139print(f"m.weight.elem.dtype: {m.weight.elem.dtype}")140print(f"m.bias.dtype: {m.bias.dtype}")141torch.__future__.set_swap_module_params_on_conversion(False)142143################################################################################144# ``nn.Module.load_state_dict()``145# --------------------------------146# Depending on the value of the ``assign`` keyword argument passed147# to ``load_state_dict()``, there are two ways to load the ``state_dict``:148#149# * ``assign=False``: preserves the properties of ``module.param`` and only takes the values150# from ``state_dict['param_name']``151# * ``assign=True``: preserves the properties and values of ``state_dict['param_name']``.152#153#154# Previously, these were implemented with in-place ``copy_`` and ``__setattr__`` respectively.155# With the existing implementation, each approach had its own limitations -- ``assign=False``156# imposes the constraint that the type of the parameter in the ``state_dict`` must157# be the same as the type of the parameter in the module while ``assign=True`` imposes158# the constraint that anything that holds references to the module's parameters must159# be initialized after ``nn.Module.load_state_dict()``.160#161# Now, we address both constraints by adding a ``swap_tensors`` path to ``load_state_dict()``162# and introducing a new extension point ``torch.Tensor.module_load(self, other, assign=False)``.163# When the ``swap_tensors`` path is enabled via the ``__future__`` mentioned above,164# we can use a ``__torch_function__`` handler for ``module_load`` to apply a165# custom transformation to the value in the ``state_dict``. The result of this166# transformation will be swapped with the parameter in the module.167#168# In the following example, we will use the ``MyQuantizedLinearWeight`` subclass169# defined above to illustrate how we can use these features to apply a170# custom quantization scheme to the weights of a linear layer when171# loading the ``state_dict``.172#173# Recall that the ``__torch_function__`` handler for ``module_load`` will be174# invoked if either ``self`` or ``other`` (in this case ``param`` or175# ``state_dict[param_key]``) are ``MyQuantizedLinearWeight`` subclasses.176#177# Assume that we expect the ``state_dict`` to contain plain tensors and the178# module to contain ``MyQuantizedLinearWeight`` parameters where we want the179# tensors in the ``state_dict`` to be transformed into the subclass. Then we180# can define a ``__torch_function__`` handler for ``torch.Tensor.module_load``181# as such:182183@classmethod184def custom_torch_function(cls, func, types, args=(), kwargs=None):185kwargs = {} if kwargs is None else kwargs186187if func is torch.Tensor.module_load:188dest, src = args[0], args[1]189assert type(dest) == cls and type(src) == torch.Tensor190return MyQuantizedLinearWeight(src, dest.scale)191else:192with torch._C.DisableTorchFunctionSubclass():193return func(*args, **kwargs)194195MyQuantizedLinearWeight.__torch_function__ = custom_torch_function196197#################################################################################198# First, let us create a skeleton of a model on the meta device to avoid199# materializing storages. We convert all weights in the modules to200# ``MyQuantizedLinearWeight`` subclasses while leaving biases intact.201202def fn(m):203if isinstance(m, nn.Linear):204requires_grad = m.weight.requires_grad205m.weight = torch.nn.Parameter(206MyQuantizedLinearWeight(m.weight, 0.5), requires_grad=requires_grad207)208209with torch.device("meta"):210m = nn.Linear(3, 5)211m.apply(fn)212213#################################################################################214# We can then load the ``state_dict``. Observe that we use ``assign=True`` because215# for biases, we want to preserve the properties of the tensor in the ``state_dict``216# (for example, we do not want the bias to be on the ``meta`` device after loading).217218torch.__future__.set_swap_module_params_on_conversion(True)219print(f"Before: id(weight)={id(m.weight)}, id(bias)={id(m.bias)}")220print(f"m.state_dict() before load_state_dict():\n {m.state_dict()}")221state_dict = nn.Linear(3, 5).state_dict()222print(f"state_dict:\n {state_dict}")223m.load_state_dict(state_dict, assign=True)224print(f"After: id(weight)={id(m.weight)}, id(bias)={id(m.bias)}")225print(f"m.state_dict() after load_state_dict():\n {m.state_dict()}")226227#################################################################################228# The above is a toy example of how we can use the new extension point in229# ``nn.Module.load_state_dict()``. One can also imagine alternate scenarios such230# as when we have tensor subclasses in the ``state_dict`` and plain ``nn.Parameters``/231# tensors in the module or when both are tensor subclasses. Based on the use232# case, we can define the ``__torch_function__`` handler for ``module_load``233# to apply the transforms as needed.234#235# Conclusion236# ----------237# In this recipe, we learned about ``swap_tensors``, the importance238# of preserving references for parameters in ``nn.Module`` as well as how to239# use the two new extension points that are gated by240# ``torch.__future__.set_swap_module_params_on_conversion``.241242243