"""
Extension points in ``nn.Module`` for ``load_state_dict`` and tensor subclasses
===============================================================================
**Author:** `Mikayla Gawarecki <https://github.com/mikaylagawarecki>`_

This recipe introduces a new utility function ``torch.utils.swap_tensors``
as well as two new extension points where it has been integrated in
``nn.Module``:

* ``nn.Module.to()`` and related methods
* ``nn.Module.load_state_dict()``

.. note::
    This recipe requires PyTorch 2.3.0 or later.
"""

###############################################################################
# ``torch.utils.swap_tensors``
# ----------------------------
# ``torch.utils.swap_tensors`` (hereafter referred to as ``swap_tensors``) is a
# utility function that takes in two Python tensors and swaps them.

import torch
import torch.nn as nn
t1 = torch.arange(2)
t2 = torch.arange(3)
print(f"Before swapping, t1: {t1}, t2: {t2}")
torch.utils.swap_tensors(t1, t2)
print(f"After swapping, t1: {t1}, t2: {t2}")

################################################################################
# More specifically, ``swap_tensors`` swaps the Python ``__class__``, ``__dict__``
# and ``__slots__`` of the two tensors, as well as their associated ``at::Tensor``.
#
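# As a quick sanity check (purely illustrative, nothing later in this recipe
# depends on it), the two variable names keep pointing at the same Python
# objects after the swap, while the metadata and anything stashed in
# ``__dict__`` travel with it.

t1 = torch.zeros(2, dtype=torch.float32)
t2 = torch.ones(3, dtype=torch.float64)
t1.note = "originally t1"  # plain attribute, stored in t1.__dict__
t1_id, t2_id = id(t1), id(t2)
torch.utils.swap_tensors(t1, t2)
print(id(t1) == t1_id and id(t2) == t2_id)  # True: same Python objects
print(t1.dtype, tuple(t1.shape))            # torch.float64 (3,): metadata swapped
print(t2.note)                              # "originally t1": __dict__ moved too
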
################################################################################
# Application to ``nn.Module``
# ----------------------------
# This utility is pertinent to ``nn.Module`` when a Python object outside
# of the module holds a reference to parameters of the module. If an ``nn.Module``
# modifies any of its parameters out of place, the object holding references to
# the parameters will not see the change. A classic example of this is the
# optimizer, which holds a reference to the parameters of the ``nn.Module``.
# This leads to a silent correctness issue where ``optimizer.step()`` will
# run without error, but the weights of the ``nn.Module`` will not be updated.

mod = torch.nn.Linear(1, 2, bias=False)
optimizer = torch.optim.SGD(mod.parameters())
print(f"weight in mod: {mod.weight}")
print(f"weight in optimizer: {optimizer.param_groups[0]['params']}")
mod.weight = torch.nn.Parameter(2 * mod.weight)
print(f"weight in mod: {mod.weight}")
print(f"weight in optimizer: {optimizer.param_groups[0]['params']}")
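
# Purely illustrative (not part of the original recipe, and not needed for the
# rest of it): the gradient lands on the new parameter, but the optimizer only
# knows about the old one (whose .grad is still None), so the step below
# silently does nothing and mod.weight is left unchanged.
mod(torch.ones(1)).sum().backward()
optimizer.step()  # runs without error
print(f"weight in mod after optimizer.step(): {mod.weight}")
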
################################################################################
# ``nn.Module.to()`` and related methods
# --------------------------------------
# This includes methods that change the device of the module (such as ``nn.Module.cpu()``),
# methods that change the ``dtype`` of the module (such as ``nn.Module.float()``),
# as well as methods that allow the module to be materialized
# (such as ``nn.Module.to_empty()``).
#
# At first glance, it might be non-intuitive that these methods are able to
# modify the parameters of the module in-place. The existing approach has been
# to use a nasty hack dating back to the first days of PyTorch.
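# Roughly speaking, that hack swaps out the payload of each parameter through
# its ``.data`` attribute. The following is a simplified sketch for intuition,
# not the actual ``nn.Module._apply`` implementation:
#
# .. code-block:: python
#
#     with torch.no_grad():
#         # fn stands for the conversion, e.g. moving to a device
#         # or changing the dtype
#         param.data = fn(param)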
#
# Notably, the existing approach does not work in these cases:
#
# * when using ``__torch_dispatch__`` subclasses
# * when ``param`` and ``new_param`` do not have the same Python ``type()``
# * for tensors with special C++ representations (such as sparse tensors and ``XLA`` tensors)
#
# In the following part of this recipe, we will define a toy ``__torch_dispatch__``
# subclass ``MyQuantizedLinearWeight`` that represents quantized linear weights.
# This subclass will be used for illustration purposes throughout the rest of
# the tutorial. For brevity, we omit most of the ``__torch_dispatch__``
# implementation.
aten = torch.ops.aten

class MyQuantizedLinearWeight(torch.Tensor):
    @staticmethod
    def __new__(cls, elem, scale):
        return torch.Tensor._make_wrapper_subclass(
            cls,
            elem.shape,
            dtype=elem.dtype,
            layout=elem.layout,
            device=elem.device,
            strides=elem.stride(),
            storage_offset=elem.storage_offset())

    def __init__(self, elem: torch.Tensor, scale: float):
        self.elem = elem
        self.scale = scale

    def __repr__(self):
        return f"MyQuantizedLinearWeight({self.elem}, scale={self.scale})"

    @classmethod
    def __torch_dispatch__(cls, func, types, args, kwargs):
        if func in (aten.detach.default, aten._to_copy.default):
            new_elem = func(args[0].elem, *args[1:], **kwargs)
            return cls(new_elem, args[0].scale)
        # Implementations for certain ops would be added to ``OP_TABLE``.
        # We omit this for brevity.
        OP_TABLE = dict()
        if func in OP_TABLE:
            return OP_TABLE[func](func, args, kwargs)
        raise NotImplementedError(f"Unsupported function {func}")

#################################################################################
# Let us create an ``nn.Linear`` layer of ``dtype`` ``torch.float32`` where the weight is
# a ``MyQuantizedLinearWeight`` and try to convert it to ``torch.bfloat16``.
# Observe that the weight's ``dtype`` changes as expected. However, the ``dtype``
# of the subclass' payload (``elem``) does not change.

m = nn.Linear(3, 5, dtype=torch.float32)
m.weight = torch.nn.Parameter(MyQuantizedLinearWeight(m.weight, 0.5))
print(f"Before: id(m.weight)={id(m.weight)}, id(m.bias)={id(m.bias)}")
m.bfloat16()
print(f"After: id(m.weight)={id(m.weight)}, id(m.bias)={id(m.bias)}")
print(f"m.weight.dtype: {m.weight.dtype}")
print(f"m.weight.elem.dtype: {m.weight.elem.dtype}")
print(f"m.bias.dtype: {m.bias.dtype}")

################################################################################
# To address this, we introduce a global config
# ``torch.__future__.set_swap_module_params_on_conversion`` that will use
# ``swap_tensors`` to swap the parameters of the module instead of setting
# ``.data``, thereby preserving outside references to them. When this config
# is set, ``swap_tensors`` will be used during the conversion, which ensures
# that the ``dtype`` of the payload is properly converted.

torch.__future__.set_swap_module_params_on_conversion(True)
m = nn.Linear(3, 5, dtype=torch.float32)
m.weight = torch.nn.Parameter(MyQuantizedLinearWeight(m.weight, 0.5))
print(f"Before: id(m.weight)={id(m.weight)}, id(m.bias)={id(m.bias)}")
m.bfloat16()
print(f"After: id(m.weight)={id(m.weight)}, id(m.bias)={id(m.bias)}")
print(f"m.weight.dtype: {m.weight.dtype}")
print(f"m.weight.elem.dtype: {m.weight.elem.dtype}")
print(f"m.bias.dtype: {m.bias.dtype}")
torch.__future__.set_swap_module_params_on_conversion(False)

################################################################################
# ``nn.Module.load_state_dict()``
# --------------------------------
# Depending on the value of the ``assign`` keyword argument passed
# to ``load_state_dict()``, there are two ways to load the ``state_dict``:
#
# * ``assign=False``: preserves the properties of ``module.param`` and only takes the values
#   from ``state_dict['param_name']``
# * ``assign=True``: preserves the properties and values of ``state_dict['param_name']``.
#
#
# Previously, these were implemented with in-place ``copy_`` and ``__setattr__`` respectively.
# With the existing implementation, each approach had its own limitations -- ``assign=False``
# imposes the constraint that the type of the parameter in the ``state_dict`` must
# be the same as the type of the parameter in the module, while ``assign=True`` imposes
# the constraint that anything that holds references to the module's parameters must
# be initialized after ``nn.Module.load_state_dict()``.
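# For example (a sketch with hypothetical ``module``/``state_dict`` objects,
# not code from this recipe), the ``assign=True`` ordering constraint means the
# optimizer has to be created only after loading:
#
# .. code-block:: python
#
#     module.load_state_dict(state_dict, assign=True)
#     # create the optimizer *after* the load; otherwise it would keep
#     # references to the old, now-replaced parameters
#     optimizer = torch.optim.SGD(module.parameters(), lr=1e-2)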
#
# Now, we address both constraints by adding a ``swap_tensors`` path to ``load_state_dict()``
# and introducing a new extension point, ``torch.Tensor.module_load(self, other, assign=False)``.
# When the ``swap_tensors`` path is enabled via the ``__future__`` mentioned above,
# we can use a ``__torch_function__`` handler for ``module_load`` to apply a
# custom transformation to the value in the ``state_dict``. The result of this
# transformation will be swapped with the parameter in the module.
#
# In the following example, we will use the ``MyQuantizedLinearWeight`` subclass
# defined above to illustrate how we can use these features to apply a
# custom quantization scheme to the weights of a linear layer when
# loading the ``state_dict``.
#
# Recall that the ``__torch_function__`` handler for ``module_load`` will be
# invoked if either ``self`` or ``other`` (in this case ``param`` or
# ``state_dict[param_key]``) is a ``MyQuantizedLinearWeight`` subclass.
#
# Assume that we expect the ``state_dict`` to contain plain tensors and the
# module to contain ``MyQuantizedLinearWeight`` parameters, where we want the
# tensors in the ``state_dict`` to be transformed into the subclass. Then we
# can define a ``__torch_function__`` handler for ``torch.Tensor.module_load``
# as follows:

@classmethod
def custom_torch_function(cls, func, types, args=(), kwargs=None):
    kwargs = {} if kwargs is None else kwargs

    if func is torch.Tensor.module_load:
        dest, src = args[0], args[1]
        assert type(dest) == cls and type(src) == torch.Tensor
        return MyQuantizedLinearWeight(src, dest.scale)
    else:
        with torch._C.DisableTorchFunctionSubclass():
            return func(*args, **kwargs)

MyQuantizedLinearWeight.__torch_function__ = custom_torch_function

#################################################################################
# First, let us create a skeleton of a model on the meta device to avoid
# materializing storages. We convert all weights in the modules to
# ``MyQuantizedLinearWeight`` subclasses while leaving biases intact.

def fn(m):
    if isinstance(m, nn.Linear):
        requires_grad = m.weight.requires_grad
        m.weight = torch.nn.Parameter(
            MyQuantizedLinearWeight(m.weight, 0.5), requires_grad=requires_grad
        )

with torch.device("meta"):
    m = nn.Linear(3, 5)
    m.apply(fn)

#################################################################################
# We can then load the ``state_dict``. Observe that we use ``assign=True`` because
# for biases, we want to preserve the properties of the tensor in the ``state_dict``
# (for example, we do not want the bias to be on the ``meta`` device after loading).

torch.__future__.set_swap_module_params_on_conversion(True)
print(f"Before: id(weight)={id(m.weight)}, id(bias)={id(m.bias)}")
print(f"m.state_dict() before load_state_dict():\n {m.state_dict()}")
state_dict = nn.Linear(3, 5).state_dict()
print(f"state_dict:\n {state_dict}")
m.load_state_dict(state_dict, assign=True)
print(f"After: id(weight)={id(m.weight)}, id(bias)={id(m.bias)}")
print(f"m.state_dict() after load_state_dict():\n {m.state_dict()}")

#################################################################################
# The above is a toy example of how we can use the new extension point in
# ``nn.Module.load_state_dict()``. One can also imagine alternate scenarios, such
# as when we have tensor subclasses in the ``state_dict`` and plain ``nn.Parameter``/
# tensors in the module, or when both are tensor subclasses. Based on the use
# case, we can define the ``__torch_function__`` handler for ``module_load``
# to apply the transforms as needed.
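# For instance (a hypothetical variant, not implemented in this recipe), if the
# ``state_dict`` held ``MyQuantizedLinearWeight`` entries and the module held
# plain parameters, the ``module_load`` branch of a ``__torch_function__``
# handler like the one defined earlier could unwrap the subclass instead of
# constructing one:
#
# .. code-block:: python
#
#     if func is torch.Tensor.module_load:
#         dest, src = args[0], args[1]
#         if isinstance(src, MyQuantizedLinearWeight):
#             # unwrap the payload so the module ends up with a plain tensor
#             return src.elem.detach().clone()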
#
# Conclusion
# ----------
# In this recipe, we learned about ``swap_tensors``, the importance
# of preserving references for parameters in ``nn.Module``, as well as how to
# use the two new extension points that are gated by
# ``torch.__future__.set_swap_module_params_on_conversion``.