GitHub Repository: pytorch/tutorials
Path: blob/main/intermediate_source/parametrizations.py
# -*- coding: utf-8 -*-
"""
Parametrizations Tutorial
=========================
**Author**: `Mario Lezcano <https://github.com/lezcano>`_

Regularizing deep-learning models is a surprisingly challenging task.
Classical techniques such as penalty methods often fall short when applied
to deep models due to the complexity of the function being optimized.
This is particularly problematic when working with ill-conditioned models.
Examples of these are RNNs trained on long sequences and GANs. A number
of techniques have been proposed in recent years to regularize these
models and improve their convergence. For recurrent models, it has been
proposed to control the singular values of the recurrent kernel for the
RNN to be well-conditioned. This can be achieved, for example, by making
the recurrent kernel `orthogonal <https://en.wikipedia.org/wiki/Orthogonal_matrix>`_.
Another way to regularize recurrent models is via
"`weight normalization <https://pytorch.org/docs/stable/generated/torch.nn.utils.weight_norm.html>`_".
This approach proposes to decouple the learning of the parameters from the
learning of their norms. To do so, the parameter is divided by its
`Frobenius norm <https://en.wikipedia.org/wiki/Matrix_norm#Frobenius_norm>`_
and a separate parameter encoding its norm is learned.
A similar regularization was proposed for GANs under the name of
"`spectral normalization <https://pytorch.org/docs/stable/generated/torch.nn.utils.spectral_norm.html>`_". This method
controls the Lipschitz constant of the network by dividing its parameters by
their `spectral norm <https://en.wikipedia.org/wiki/Matrix_norm#Special_cases>`_,
rather than their Frobenius norm.

All these methods have a common pattern: they transform a parameter
in an appropriate way before using it. In the first case, they make it orthogonal by
using a function that maps matrices to orthogonal matrices. In the case of weight
and spectral normalization, they divide the original parameter by its norm.

More generally, all these examples use a function to put extra structure on the parameters.
In other words, they use a function to constrain the parameters.

In this tutorial, you will learn how to implement and use this pattern to put
constraints on your model. Doing so is as easy as writing your own ``nn.Module``.

Requirements: ``torch>=1.9.0``

Implementing parametrizations by hand
-------------------------------------

Assume that we want to have a square linear layer with symmetric weights, that is,
with weights ``X`` such that ``X = Xᵀ``. One way to do so is
to copy the upper-triangular part of the matrix into its lower-triangular part:
"""

import torch
import torch.nn as nn
import torch.nn.utils.parametrize as parametrize

def symmetric(X):
    return X.triu() + X.triu(1).transpose(-1, -2)

X = torch.rand(3, 3)
A = symmetric(X)
assert torch.allclose(A, A.T)  # A is symmetric
print(A)  # Quick visual check

###############################################################################
# We can then use this idea to implement a linear layer with symmetric weights:
class LinearSymmetric(nn.Module):
    def __init__(self, n_features):
        super().__init__()
        self.weight = nn.Parameter(torch.rand(n_features, n_features))

    def forward(self, x):
        A = symmetric(self.weight)
        return x @ A

###############################################################################
# The layer can then be used as a regular linear layer:
layer = LinearSymmetric(3)
out = layer(torch.rand(8, 3))

###############################################################################
# This implementation, although correct and self-contained, presents a number of problems:
#
# 1) It reimplements the layer. We had to implement the linear layer as ``x @ A``. This is
#    not very problematic for a linear layer, but imagine having to reimplement a CNN or a
#    Transformer...
# 2) It does not separate the layer and the parametrization. If the parametrization were
#    more difficult, we would have to rewrite its code for each layer that we want to use it
#    in.
# 3) It recomputes the parametrization every time we use the layer. If we use the layer
#    several times during the forward pass (imagine the recurrent kernel of an RNN), it
#    would compute the same ``A`` every time that the layer is called.
#
# Introduction to parametrizations
# --------------------------------
#
# Parametrizations can solve all these problems as well as others.
#
# Let's start by reimplementing the code above using ``torch.nn.utils.parametrize``.
# The only thing that we have to do is to write the parametrization as a regular ``nn.Module``:
class Symmetric(nn.Module):
    def forward(self, X):
        return X.triu() + X.triu(1).transpose(-1, -2)

###############################################################################
# This is all we need to do. Once we have this, we can transform any regular layer into a
# symmetric layer by doing:
layer = nn.Linear(3, 3)
parametrize.register_parametrization(layer, "weight", Symmetric())

###############################################################################
# Now, the matrix of the linear layer is symmetric:
A = layer.weight
assert torch.allclose(A, A.T)  # A is symmetric
print(A)  # Quick visual check
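###############################################################################
# As a quick check, gradients flow back to the unconstrained tensor that
# ``register_parametrization`` stores as ``parametrizations.weight.original``.
# Since ``Symmetric`` only reads the upper-triangular part of its input, the
# strictly lower-triangular part of that gradient should be exactly zero.
# This is a minimal sketch: ``Symmetric`` is repeated so that the cell runs on
# its own, and the name ``sym_layer`` is only illustrative.
import torch
import torch.nn as nn
import torch.nn.utils.parametrize as parametrize


class Symmetric(nn.Module):
    def forward(self, X):
        return X.triu() + X.triu(1).transpose(-1, -2)


sym_layer = nn.Linear(3, 3)
parametrize.register_parametrization(sym_layer, "weight", Symmetric())
sym_layer(torch.rand(8, 3)).sum().backward()
grad = sym_layer.parametrizations.weight.original.grad
print(grad)
# Entries below the diagonal are never read by Symmetric, so they get no gradient
print(torch.count_nonzero(grad.tril(-1)))  # tensor(0)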

###############################################################################
# We can do the same thing with any other layer. For example, we can create a CNN with
# `skew-symmetric <https://en.wikipedia.org/wiki/Skew-symmetric_matrix>`_ kernels.
# We use a similar parametrization, copying the upper-triangular part with signs
# reversed into the lower-triangular part and setting the diagonal to zero:
class Skew(nn.Module):
    def forward(self, X):
        A = X.triu(1)
        return A - A.transpose(-1, -2)


cnn = nn.Conv2d(in_channels=5, out_channels=8, kernel_size=3)
parametrize.register_parametrization(cnn, "weight", Skew())
# Print a few kernels
print(cnn.weight[0, 1])
print(cnn.weight[2, 2])
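###############################################################################
# A small check of this claim: ``triu`` acts on the last two dimensions, which
# for a ``Conv2d`` weight of shape ``(out_channels, in_channels, 3, 3)`` are the
# spatial dimensions of each kernel. Hence every 3x3 kernel satisfies
# ``K = -Kᵀ`` and has a zero diagonal. ``SkewSketch`` is a renamed copy of
# ``Skew`` so that this cell runs on its own; ``skew_cnn`` and ``kernels`` are
# illustrative names.
import torch
import torch.nn as nn
import torch.nn.utils.parametrize as parametrize


class SkewSketch(nn.Module):
    def forward(self, X):
        A = X.triu(1)
        return A - A.transpose(-1, -2)


skew_cnn = nn.Conv2d(in_channels=5, out_channels=8, kernel_size=3)
parametrize.register_parametrization(skew_cnn, "weight", SkewSketch())
kernels = skew_cnn.weight  # shape (8, 5, 3, 3)
print(torch.allclose(kernels, -kernels.transpose(-1, -2)))  # True
print(kernels.diagonal(dim1=-2, dim2=-1).abs().max())  # zero diagonal in every kernel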

###############################################################################
# Inspecting a parametrized module
# --------------------------------
#
# When a module is parametrized, we find that the module has changed in three ways:
#
# 1) ``module.weight`` is now a property
#
# 2) It has a new ``module.parametrizations`` attribute
#
# 3) The unparametrized weight has been moved to ``module.parametrizations.weight.original``
#
# |
# After parametrizing ``weight``, ``layer.weight`` is turned into a
# `Python property <https://docs.python.org/3/library/functions.html#property>`_.
# This property computes ``parametrization(weight)`` every time we request ``layer.weight``,
# just as we did in our implementation of ``LinearSymmetric`` above.
#
# Registered parametrizations are stored under a ``parametrizations`` attribute within the module.
layer = nn.Linear(3, 3)
print(f"Unparametrized:\n{layer}")
parametrize.register_parametrization(layer, "weight", Symmetric())
print(f"\nParametrized:\n{layer}")
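###############################################################################
# ``torch.nn.utils.parametrize`` also provides ``is_parametrized`` to query
# whether a module, or one specific tensor of it, carries a parametrization.
# A minimal sketch (``plain`` and ``constrained`` are illustrative names, and
# ``Symmetric`` is repeated so that the cell runs on its own):
import torch
import torch.nn as nn
import torch.nn.utils.parametrize as parametrize


class Symmetric(nn.Module):
    def forward(self, X):
        return X.triu() + X.triu(1).transpose(-1, -2)


plain = nn.Linear(3, 3)
constrained = nn.Linear(3, 3)
parametrize.register_parametrization(constrained, "weight", Symmetric())
print(parametrize.is_parametrized(plain))  # False
print(parametrize.is_parametrized(constrained))  # True
print(parametrize.is_parametrized(constrained, "weight"))  # True
print(parametrize.is_parametrized(constrained, "bias"))  # False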

###############################################################################
# This ``parametrizations`` attribute is an ``nn.ModuleDict``, and it can be accessed as such:
print(layer.parametrizations)
print(layer.parametrizations.weight)

###############################################################################
# Each element of this ``nn.ModuleDict`` is a ``ParametrizationList``, which behaves like an
# ``nn.Sequential``. This list will allow us to concatenate parametrizations on one weight.
# Since this is a list, we can access the parametrizations by indexing it. Here's
# where our ``Symmetric`` parametrization sits:
print(layer.parametrizations.weight[0])

###############################################################################
# The other thing that we notice is that, if we print the parameters, we see that the
# parameter ``weight`` has been moved:
print(dict(layer.named_parameters()))

###############################################################################
# It now sits under ``layer.parametrizations.weight.original``:
print(layer.parametrizations.weight.original)
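###############################################################################
# Since ``weight`` now lives under ``parametrizations.weight.original``, the keys
# of the module's ``state_dict`` change as well. As a consequence, a checkpoint
# saved before registering a parametrization will not, in general, load directly
# into the parametrized module with ``strict=True``. A minimal sketch
# (``ckpt_layer`` is an illustrative name, and ``Symmetric`` is repeated so that
# the cell runs on its own):
import torch
import torch.nn as nn
import torch.nn.utils.parametrize as parametrize


class Symmetric(nn.Module):
    def forward(self, X):
        return X.triu() + X.triu(1).transpose(-1, -2)


ckpt_layer = nn.Linear(3, 3)
print(sorted(ckpt_layer.state_dict().keys()))  # ['bias', 'weight']
parametrize.register_parametrization(ckpt_layer, "weight", Symmetric())
print(sorted(ckpt_layer.state_dict().keys()))  # ['bias', 'parametrizations.weight.original']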

###############################################################################
# Besides these three small differences, the parametrization is doing exactly the same
# as our manual implementation:
symmetric = Symmetric()
weight_orig = layer.parametrizations.weight.original
print(torch.dist(layer.weight, symmetric(weight_orig)))

###############################################################################
# Parametrizations are first-class citizens
# -----------------------------------------
#
# Since ``layer.parametrizations`` is an ``nn.ModuleDict``, the parametrizations
# are properly registered as submodules of the original module. As such, the same rules
# that apply to the parameters and buffers of a module also apply to those of a parametrization.
# For example, if a parametrization has parameters, these will be moved from CPU
# to CUDA when calling ``model = model.cuda()``.
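###############################################################################
# For instance, a parametrization may hold learnable tensors of its own. In the
# sketch below, the hypothetical ``Scaled`` parametrization multiplies the weight
# by a learnable scalar. Its parameter shows up in ``named_parameters()``, so an
# optimizer built from ``model.parameters()`` would update it. It also follows
# conversions such as ``.double()`` together with the rest of the module.
import torch
import torch.nn as nn
import torch.nn.utils.parametrize as parametrize


class Scaled(nn.Module):
    def __init__(self):
        super().__init__()
        # Learnable scalar, registered as a regular nn.Parameter
        self.scale = nn.Parameter(torch.ones(()))

    def forward(self, X):
        return self.scale * X


scaled_layer = nn.Linear(3, 3)
parametrize.register_parametrization(scaled_layer, "weight", Scaled())
print([name for name, _ in scaled_layer.named_parameters()])
scaled_layer = scaled_layer.double()
print(scaled_layer.parametrizations.weight[0].scale.dtype)  # torch.float64

###############################################################################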
#
# Caching the value of a parametrization
# --------------------------------------
#
# Parametrizations come with a built-in caching system via the context manager
# ``parametrize.cached()``:
class NoisyParametrization(nn.Module):
    def forward(self, X):
        print("Computing the Parametrization")
        return X

layer = nn.Linear(4, 4)
parametrize.register_parametrization(layer, "weight", NoisyParametrization())
print("Here, layer.weight is recomputed every time we call it")
foo = layer.weight + layer.weight.T
bar = layer.weight.sum()
with parametrize.cached():
    print("Here, it is computed just the first time layer.weight is called")
    foo = layer.weight + layer.weight.T
    bar = layer.weight.sum()
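###############################################################################
# To measure the effect of ``parametrize.cached()`` instead of reading printed
# messages, this sketch counts how many times a hypothetical
# ``CountingParametrization`` is evaluated. The counter is reset after
# registration, because ``register_parametrization`` may itself evaluate the
# parametrization while checking it. ``counted_layer`` is an illustrative name.
import torch
import torch.nn as nn
import torch.nn.utils.parametrize as parametrize


class CountingParametrization(nn.Module):
    def __init__(self):
        super().__init__()
        self.calls = 0

    def forward(self, X):
        self.calls += 1  # one increment per evaluation of the parametrization
        return X


counting = CountingParametrization()
counted_layer = nn.Linear(4, 4)
parametrize.register_parametrization(counted_layer, "weight", counting)

counting.calls = 0
_ = counted_layer.weight + counted_layer.weight.T
uncached_calls = counting.calls  # one evaluation per access to counted_layer.weight

counting.calls = 0
with parametrize.cached():
    _ = counted_layer.weight + counted_layer.weight.T
cached_calls = counting.calls  # the second access reuses the cached tensor
print(uncached_calls, cached_calls)  # 2 1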

###############################################################################
# Concatenating parametrizations
# ------------------------------
#
# Concatenating two parametrizations is as easy as registering them on the same tensor.
# They are applied in the order in which they are registered.
# We may use this to create more complex parametrizations from simpler ones. For example, the
# `Cayley map <https://en.wikipedia.org/wiki/Cayley_transform#Matrix_map>`_
# maps the skew-symmetric matrices to the orthogonal matrices of positive determinant. We can
# concatenate ``Skew`` and a parametrization that implements the Cayley map to get a layer with
# orthogonal weights:
class CayleyMap(nn.Module):
    def __init__(self, n):
        super().__init__()
        self.register_buffer("Id", torch.eye(n))

    def forward(self, X):
        # (I + X)(I - X)^{-1}
        return torch.linalg.solve(self.Id - X, self.Id + X)

layer = nn.Linear(3, 3)
parametrize.register_parametrization(layer, "weight", Skew())
parametrize.register_parametrization(layer, "weight", CayleyMap(3))
X = layer.weight
print(torch.dist(X.T @ X, torch.eye(3)))  # X is orthogonal
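###############################################################################
# Two further properties can be checked numerically in the sketch below. First,
# the resulting matrix has determinant ``+1``. Second, it stays orthogonal after
# an optimizer step, because the optimizer updates the unconstrained tensor
# ``parametrizations.weight.original`` and the constraint is re-applied each time
# ``weight`` is accessed. ``SkewSketch`` and ``CayleySketch`` are renamed copies
# of the classes above. The SGD learning rate is an arbitrary choice.
import torch
import torch.nn as nn
import torch.nn.utils.parametrize as parametrize


class SkewSketch(nn.Module):
    def forward(self, X):
        A = X.triu(1)
        return A - A.transpose(-1, -2)


class CayleySketch(nn.Module):
    def __init__(self, n):
        super().__init__()
        self.register_buffer("Id", torch.eye(n))

    def forward(self, X):
        # (I - X)^{-1}(I + X), equal to (I + X)(I - X)^{-1} because both factors commute
        return torch.linalg.solve(self.Id - X, self.Id + X)


orth_layer = nn.Linear(3, 3)
parametrize.register_parametrization(orth_layer, "weight", SkewSketch())
parametrize.register_parametrization(orth_layer, "weight", CayleySketch(3))
print(torch.linalg.det(orth_layer.weight))  # close to 1

opt = torch.optim.SGD(orth_layer.parameters(), lr=0.01)
loss = orth_layer(torch.randn(16, 3)).pow(2).sum()
loss.backward()
opt.step()  # updates parametrizations.weight.original and the bias

W = orth_layer.weight
print(torch.dist(W.T @ W, torch.eye(3)))  # still close to 0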

###############################################################################
# This may also be used to prune a parametrized module, or to reuse parametrizations. For example,
# the matrix exponential maps the symmetric matrices to the Symmetric Positive Definite (SPD) matrices,
# and it also maps the skew-symmetric matrices to the orthogonal matrices.
# Using these two facts, we may reuse the parametrizations defined before to our advantage:
class MatrixExponential(nn.Module):
    def forward(self, X):
        return torch.matrix_exp(X)

layer_orthogonal = nn.Linear(3, 3)
parametrize.register_parametrization(layer_orthogonal, "weight", Skew())
parametrize.register_parametrization(layer_orthogonal, "weight", MatrixExponential())
X = layer_orthogonal.weight
print(torch.dist(X.T @ X, torch.eye(3)))  # X is orthogonal

layer_spd = nn.Linear(3, 3)
parametrize.register_parametrization(layer_spd, "weight", Symmetric())
parametrize.register_parametrization(layer_spd, "weight", MatrixExponential())
X = layer_spd.weight
print(torch.dist(X, X.T))  # X is symmetric
print((torch.linalg.eigvalsh(X) > 0.).all())  # X is positive definite
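###############################################################################
# Why does ``MatrixExponential`` produce positive definite matrices here? Take a
# symmetric matrix ``S`` with eigenvalues ``λᵢ``. Then ``matrix_exp(S)`` is again
# symmetric, and its eigenvalues are ``exp(λᵢ)``, which are always positive. The
# sketch below checks this numerically on a random symmetric matrix. ``S`` and
# ``expS`` are illustrative names, and float64 keeps rounding errors small.
import torch

S = torch.randn(3, 3, dtype=torch.float64)
S = S + S.T  # symmetrize
expS = torch.matrix_exp(S)
# eigvalsh returns eigenvalues in ascending order, and exp is increasing,
# so both tensors list the eigenvalues in the same order
print(torch.linalg.eigvalsh(expS))
print(torch.linalg.eigvalsh(S).exp())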

###############################################################################
# Initializing parametrizations
# -----------------------------
#
# Parametrizations come with a mechanism to initialize them. If we implement a method
# ``right_inverse`` with signature
#
# .. code-block:: python
#
#     def right_inverse(self, X: Tensor) -> Tensor
#
# it will be used when assigning to the parametrized tensor. In recent versions of PyTorch,
# it is also used to compute the initial value of ``original`` when the first
# parametrization is registered.
#
# Let's upgrade our implementation of the ``Skew`` class to support this:
class Skew(nn.Module):
    def forward(self, X):
        A = X.triu(1)
        return A - A.transpose(-1, -2)

    def right_inverse(self, A):
        # We assume that A is skew-symmetric
        # We take the upper-triangular elements, as these are those used in the forward
        return A.triu(1)

###############################################################################
# We may now initialize a layer that is parametrized with ``Skew``:
layer = nn.Linear(3, 3)
parametrize.register_parametrization(layer, "weight", Skew())
X = torch.rand(3, 3)
X = X - X.T  # X is now skew-symmetric
layer.weight = X  # Initialize layer.weight to be X
print(torch.dist(layer.weight, X))  # layer.weight == X
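###############################################################################
# The assignment above recovers ``X`` exactly because ``X`` was already
# skew-symmetric. For a generic matrix ``Y``, the sketch below shows that
# ``forward(right_inverse(Y))`` keeps only the strictly upper-triangular entries
# of ``Y`` and rebuilds the rest from them. As a result, the weight is not equal
# to ``Y`` after the assignment. ``SkewSketch`` is a renamed copy of ``Skew``
# with ``right_inverse``, and ``skew_layer``, ``Y`` and ``W_skew`` are
# illustrative names.
import torch
import torch.nn as nn
import torch.nn.utils.parametrize as parametrize


class SkewSketch(nn.Module):
    def forward(self, X):
        A = X.triu(1)
        return A - A.transpose(-1, -2)

    def right_inverse(self, A):
        return A.triu(1)


skew_layer = nn.Linear(3, 3)
parametrize.register_parametrization(skew_layer, "weight", SkewSketch())
Y = torch.rand(3, 3)  # generic matrix, not skew-symmetric
skew_layer.weight = Y
W_skew = skew_layer.weight
print(torch.dist(W_skew, Y))  # not 0: the diagonal and lower part of Y are lost
print(torch.allclose(W_skew.triu(1), Y.triu(1)))  # True: the strict upper part is kept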

###############################################################################
# This ``right_inverse`` works as expected when we concatenate parametrizations.
# To see this, let's upgrade the Cayley parametrization to also support being initialized:
class CayleyMap(nn.Module):
    def __init__(self, n):
        super().__init__()
        self.register_buffer("Id", torch.eye(n))

    def forward(self, X):
        # Assume X skew-symmetric
        # (I + X)(I - X)^{-1}
        return torch.linalg.solve(self.Id - X, self.Id + X)

    def right_inverse(self, A):
        # Assume A orthogonal
        # See https://en.wikipedia.org/wiki/Cayley_transform#Matrix_map
        # (A - I)(A + I)^{-1}, computed as (A + I)^{-1}(A - I) since both factors commute
        return torch.linalg.solve(A + self.Id, A - self.Id)

layer_orthogonal = nn.Linear(3, 3)
parametrize.register_parametrization(layer_orthogonal, "weight", Skew())
parametrize.register_parametrization(layer_orthogonal, "weight", CayleyMap(3))
# Sample an orthogonal matrix with positive determinant
X = torch.empty(3, 3)
nn.init.orthogonal_(X)
if X.det() < 0.:
    X[0].neg_()
layer_orthogonal.weight = X
print(torch.dist(layer_orthogonal.weight, X))  # layer_orthogonal.weight == X
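
###############################################################################
# This initialization step can be written more succinctly as below. Note, however, that
# ``nn.init.orthogonal_`` may return a matrix with negative determinant. Such a matrix is not
# in the image of the Cayley map (for it, ``A + I`` is singular), so the sign correction above
# is still needed in general.
layer_orthogonal.weight = nn.init.orthogonal_(layer_orthogonal.weight)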

###############################################################################
# The name of this method comes from the fact that we would often expect
# that ``forward(right_inverse(X)) == X``. This is a direct way of saying that,
# after initializing the parametrized tensor with value ``X``, the forward should return ``X``.
# This constraint is not strongly enforced in practice. In fact, at times, it might be of
# interest to relax this relation. For example, consider the following implementation
# of a randomized pruning method:
class PruningParametrization(nn.Module):
    def __init__(self, X, p_drop=0.2):
        super().__init__()
        # sample zeros with probability p_drop
        mask = torch.full_like(X, 1.0 - p_drop)
        # register the mask as a buffer so that it moves with the module (e.g. with .cuda())
        self.register_buffer("mask", torch.bernoulli(mask))

    def forward(self, X):
        return X * self.mask

    def right_inverse(self, A):
        return A

###############################################################################
# In this case, it is not true that ``forward(right_inverse(A)) == A`` for every matrix ``A``.
# This is only true when the matrix ``A`` has zeros in the same positions as the mask.
# Even so, if we assign a tensor to a pruned parameter, it comes as no surprise
# that the tensor will, in fact, be pruned:
layer = nn.Linear(3, 4)
X = torch.rand_like(layer.weight)
print(f"Initialization matrix:\n{X}")
parametrize.register_parametrization(layer, "weight", PruningParametrization(layer.weight))
layer.weight = X
print(f"\nInitialized weight:\n{layer.weight}")
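###############################################################################
# We can also confirm that the pruned entries are exactly the zeros of the mask,
# and that they stay zero whatever value we assign. The sketch below uses
# ``PruningSketch``, a renamed copy of ``PruningParametrization`` whose mask is
# registered as a buffer. The names ``pruned_layer``, ``pruning`` and ``dropped``
# are illustrative.
import torch
import torch.nn as nn
import torch.nn.utils.parametrize as parametrize


class PruningSketch(nn.Module):
    def __init__(self, X, p_drop=0.2):
        super().__init__()
        # keep each entry with probability 1 - p_drop
        self.register_buffer("mask", torch.bernoulli(torch.full_like(X, 1.0 - p_drop)))

    def forward(self, X):
        return X * self.mask

    def right_inverse(self, A):
        return A


pruned_layer = nn.Linear(3, 4)
pruning = PruningSketch(pruned_layer.weight)
parametrize.register_parametrization(pruned_layer, "weight", pruning)
for _ in range(3):
    pruned_layer.weight = torch.rand(4, 3)  # assign a fresh dense matrix
    dropped = pruned_layer.weight[pruning.mask == 0]
    print(dropped.abs().sum().item())  # 0.0: entries dropped by the mask stay zero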

###############################################################################
# Removing parametrizations
# -------------------------
#
# We may remove all the parametrizations from a parameter or a buffer in a module
# by using ``parametrize.remove_parametrizations()``:
layer = nn.Linear(3, 3)
print("Before:")
print(layer)
print(layer.weight)
parametrize.register_parametrization(layer, "weight", Skew())
print("\nParametrized:")
print(layer)
print(layer.weight)
parametrize.remove_parametrizations(layer, "weight")
print("\nAfter. Weight has skew-symmetric values but it is unconstrained:")
print(layer)
print(layer.weight)
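###############################################################################
# After ``remove_parametrizations`` with the default ``leave_parametrized=True``,
# ``weight`` is an ordinary ``nn.Parameter`` again and the module is no longer
# parametrized. Nothing enforces the skew-symmetry any more, so an optimizer step
# can break it. A minimal sketch: ``SkewSketch`` is a renamed copy of ``Skew``,
# ``freed`` is an illustrative name, and the learning rate is arbitrary.
import torch
import torch.nn as nn
import torch.nn.utils.parametrize as parametrize


class SkewSketch(nn.Module):
    def forward(self, X):
        A = X.triu(1)
        return A - A.transpose(-1, -2)


freed = nn.Linear(3, 3)
parametrize.register_parametrization(freed, "weight", SkewSketch())
parametrize.remove_parametrizations(freed, "weight")
print(parametrize.is_parametrized(freed))  # False
print(isinstance(freed.weight, nn.Parameter))  # True
print(torch.allclose(freed.weight, -freed.weight.T))  # True: the values are still skew-symmetric

opt = torch.optim.SGD(freed.parameters(), lr=0.1)
freed(torch.rand(8, 3)).sum().backward()
opt.step()
# The gradient of this loss is not skew-symmetric, so the update breaks the structure
print(torch.allclose(freed.weight, -freed.weight.T))  # False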

###############################################################################
# When removing a parametrization, we may choose to leave the original parameter (i.e. the one in
# ``layer.parametrizations.weight.original``) rather than its parametrized version by setting
# the flag ``leave_parametrized=False``.
#
# Note that, in recent versions of PyTorch, ``original`` is initialized as ``right_inverse(weight)``
# when the first parametrization is registered. Since our ``Skew`` implements ``right_inverse``,
# ``original`` only keeps the strictly upper-triangular part of the initial weight, so the result
# agrees with ``Before`` above the diagonal and is zero elsewhere.
layer = nn.Linear(3, 3)
print("Before:")
print(layer)
print(layer.weight)
parametrize.register_parametrization(layer, "weight", Skew())
print("\nParametrized:")
print(layer)
print(layer.weight)
parametrize.remove_parametrizations(layer, "weight", leave_parametrized=False)
print("\nAfter. Same as the original tensor:")
print(layer)
print(layer.weight)
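###############################################################################
# With ``leave_parametrized=False``, the value left in ``weight`` is whatever was
# stored in ``parametrizations.weight.original`` at the time of removal. The
# sketch below checks this directly by keeping a copy of ``original`` before
# removing the parametrization. ``SkewSketch`` is a renamed copy of ``Skew``,
# and ``kept`` and ``original_copy`` are illustrative names.
import torch
import torch.nn as nn
import torch.nn.utils.parametrize as parametrize


class SkewSketch(nn.Module):
    def forward(self, X):
        A = X.triu(1)
        return A - A.transpose(-1, -2)

    def right_inverse(self, A):
        return A.triu(1)


kept = nn.Linear(3, 3)
parametrize.register_parametrization(kept, "weight", SkewSketch())
# Snapshot of the unconstrained tensor just before removal
original_copy = kept.parametrizations.weight.original.detach().clone()
parametrize.remove_parametrizations(kept, "weight", leave_parametrized=False)
print(torch.equal(kept.weight, original_copy))  # True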