# -*- coding: utf-8 -*-
"""
Parametrizations Tutorial
=========================
**Author**: `Mario Lezcano <https://github.com/lezcano>`_

Regularizing deep-learning models is a surprisingly challenging task.
Classical techniques such as penalty methods often fall short when applied
to deep models due to the complexity of the function being optimized.
This is particularly problematic when working with ill-conditioned models.
Examples of these are RNNs trained on long sequences and GANs. A number
of techniques have been proposed in recent years to regularize these
models and improve their convergence. On recurrent models, it has been
proposed to control the singular values of the recurrent kernel for the
RNN to be well-conditioned. This can be achieved, for example, by making
the recurrent kernel `orthogonal <https://en.wikipedia.org/wiki/Orthogonal_matrix>`_.
Another way to regularize recurrent models is via
"`weight normalization <https://pytorch.org/docs/stable/generated/torch.nn.utils.weight_norm.html>`_".
This approach proposes to decouple the learning of the parameters from the
learning of their norms. To do so, the parameter is divided by its
`Frobenius norm <https://en.wikipedia.org/wiki/Matrix_norm#Frobenius_norm>`_
and a separate parameter encoding its norm is learned.
A similar regularization was proposed for GANs under the name of
"`spectral normalization <https://pytorch.org/docs/stable/generated/torch.nn.utils.spectral_norm.html>`_". This method
controls the Lipschitz constant of the network by dividing its parameters by
their `spectral norm <https://en.wikipedia.org/wiki/Matrix_norm#Special_cases>`_,
rather than their Frobenius norm.

All these methods have a common pattern: they all transform a parameter
in an appropriate way before using it. In the first case, they make it orthogonal by
using a function that maps matrices to orthogonal matrices. In the case of weight
and spectral normalization, they divide the original parameter by its norm.

More generally, all these examples use a function to put extra structure on the parameters.
In other words, they use a function to constrain the parameters.

In this tutorial, you will learn how to implement and use this pattern to put
constraints on your model. Doing so is as easy as writing your own ``nn.Module``.

Requirements: ``torch>=1.9.0``

Implementing parametrizations by hand
-------------------------------------

Assume that we want to have a square linear layer with symmetric weights, that is,
with weights ``X`` such that ``X = Xᵀ``. One way to do so is
to copy the upper-triangular part of the matrix into its lower-triangular part
"""

import torch
import torch.nn as nn
import torch.nn.utils.parametrize as parametrize

def symmetric(X):
    return X.triu() + X.triu(1).transpose(-1, -2)

X = torch.rand(3, 3)
A = symmetric(X)
assert torch.allclose(A, A.T)  # A is symmetric
print(A)                       # Quick visual check

###############################################################################
# We can then use this idea to implement a linear layer with symmetric weights
class LinearSymmetric(nn.Module):
    def __init__(self, n_features):
        super().__init__()
        self.weight = nn.Parameter(torch.rand(n_features, n_features))

    def forward(self, x):
        A = symmetric(self.weight)
        return x @ A

###############################################################################
# The layer can then be used as a regular linear layer
layer = LinearSymmetric(3)
out = layer(torch.rand(8, 3))
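###############################################################################
# Note that only the matrix used inside ``forward`` is symmetric; the stored
# ``weight`` parameter itself remains a free, unconstrained matrix. A quick check
# using the ``layer`` and the ``symmetric`` function defined above:
print(torch.allclose(layer.weight, layer.weight.T))  # False (almost surely): the raw parameter is not symmetric
print(torch.allclose(symmetric(layer.weight), symmetric(layer.weight).T))  # True: the matrix used in forward is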
###############################################################################
# This implementation, although correct and self-contained, presents a number of problems:
#
# 1) It reimplements the layer. We had to implement the linear layer as ``x @ A``. This is
#    not very problematic for a linear layer, but imagine having to reimplement a CNN or a
#    Transformer...
# 2) It does not separate the layer and the parametrization. If the parametrization were
#    more difficult, we would have to rewrite its code for each layer that we want to use it
#    in.
# 3) It recomputes the parametrization every time we use the layer. If we use the layer
#    several times during the forward pass (imagine the recurrent kernel of an RNN), it
#    would compute the same ``A`` every time that the layer is called.
#
# Introduction to parametrizations
# --------------------------------
#
# Parametrizations can solve all these problems as well as others.
#
# Let's start by reimplementing the code above using ``torch.nn.utils.parametrize``.
# The only thing that we have to do is to write the parametrization as a regular ``nn.Module``
class Symmetric(nn.Module):
    def forward(self, X):
        return X.triu() + X.triu(1).transpose(-1, -2)

###############################################################################
# This is all we need to do. Once we have this, we can transform any regular layer into a
# symmetric layer by doing
layer = nn.Linear(3, 3)
parametrize.register_parametrization(layer, "weight", Symmetric())

###############################################################################
# Now, the matrix of the linear layer is symmetric
A = layer.weight
assert torch.allclose(A, A.T)  # A is symmetric
print(A)                       # Quick visual check
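###############################################################################
# Because the constraint now lives in the module itself, we can train the parametrized
# layer just like any other layer: gradient steps act on the underlying unconstrained
# tensor, and ``layer.weight`` stays symmetric. A minimal sketch with random data and a
# throwaway SGD optimizer:
opt = torch.optim.SGD(layer.parameters(), lr=0.1)
loss = layer(torch.rand(8, 3)).pow(2).sum()
loss.backward()
opt.step()
opt.zero_grad()
assert torch.allclose(layer.weight, layer.weight.T)  # still symmetric after the update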
###############################################################################
# We can do the same thing with any other layer. For example, we can create a CNN with
# `skew-symmetric <https://en.wikipedia.org/wiki/Skew-symmetric_matrix>`_ kernels.
# We use a similar parametrization, copying the upper-triangular part with signs
# reversed into the lower-triangular part
class Skew(nn.Module):
    def forward(self, X):
        A = X.triu(1)
        return A - A.transpose(-1, -2)


cnn = nn.Conv2d(in_channels=5, out_channels=8, kernel_size=3)
parametrize.register_parametrization(cnn, "weight", Skew())
# Print a few kernels
print(cnn.weight[0, 1])
print(cnn.weight[2, 2])

###############################################################################
# Inspecting a parametrized module
# --------------------------------
#
# When a module is parametrized, we find that the module has changed in three ways:
#
# 1) ``model.weight`` is now a property
#
# 2) It has a new ``module.parametrizations`` attribute
#
# 3) The unparametrized weight has been moved to ``module.parametrizations.weight.original``
#
# |
# After parametrizing ``weight``, ``layer.weight`` is turned into a
# `Python property <https://docs.python.org/3/library/functions.html#property>`_.
# This property computes ``parametrization(weight)`` every time we request ``layer.weight``,
# just as we did in our implementation of ``LinearSymmetric`` above.
#
# Registered parametrizations are stored under a ``parametrizations`` attribute within the module.
layer = nn.Linear(3, 3)
print(f"Unparametrized:\n{layer}")
parametrize.register_parametrization(layer, "weight", Symmetric())
print(f"\nParametrized:\n{layer}")

###############################################################################
# This ``parametrizations`` attribute is an ``nn.ModuleDict``, and it can be accessed as such
print(layer.parametrizations)
print(layer.parametrizations.weight)

###############################################################################
# Each element of this ``nn.ModuleDict`` is a ``ParametrizationList``, which behaves like an
# ``nn.Sequential``. This list will allow us to concatenate parametrizations on one weight.
# Since this is a list, we can access the parametrizations by indexing it. Here's
# where our ``Symmetric`` parametrization sits
print(layer.parametrizations.weight[0])

###############################################################################
# The other thing that we notice is that, if we print the parameters, we see that the
# parameter ``weight`` has been moved
print(dict(layer.named_parameters()))

###############################################################################
# It now sits under ``layer.parametrizations.weight.original``
print(layer.parametrizations.weight.original)

###############################################################################
# Besides these three small differences, the parametrization is doing exactly the same
# as our manual implementation
symmetric = Symmetric()
weight_orig = layer.parametrizations.weight.original
print(torch.dist(layer.weight, symmetric(weight_orig)))
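###############################################################################
# If we only want to know whether a module, or one of its tensors, is parametrized,
# ``torch.nn.utils.parametrize`` also provides ``is_parametrized``. A quick sketch on the
# layer above:
print(parametrize.is_parametrized(layer))            # True
print(parametrize.is_parametrized(layer, "weight"))  # True
print(parametrize.is_parametrized(layer, "bias"))    # False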
###############################################################################
# Parametrizations are first-class citizens
# -----------------------------------------
#
# Since ``layer.parametrizations`` is an ``nn.ModuleDict``, it means that the parametrizations
# are properly registered as submodules of the original module. As such, the same rules
# for registering parameters in a module apply to registering a parametrization.
# For example, if a parametrization has parameters, these will be moved from CPU
# to CUDA when calling ``model = model.cuda()``.
#
# Caching the value of a parametrization
# --------------------------------------
#
# Parametrizations come with an inbuilt caching system via the context manager
# ``parametrize.cached()``
class NoisyParametrization(nn.Module):
    def forward(self, X):
        print("Computing the Parametrization")
        return X

layer = nn.Linear(4, 4)
parametrize.register_parametrization(layer, "weight", NoisyParametrization())
print("Here, layer.weight is recomputed every time we call it")
foo = layer.weight + layer.weight.T
bar = layer.weight.sum()
with parametrize.cached():
    print("Here, it is computed just the first time layer.weight is called")
    foo = layer.weight + layer.weight.T
    bar = layer.weight.sum()

###############################################################################
# Concatenating parametrizations
# ------------------------------
#
# Concatenating two parametrizations is as easy as registering them on the same tensor.
# We may use this to create more complex parametrizations from simpler ones. For example, the
# `Cayley map <https://en.wikipedia.org/wiki/Cayley_transform#Matrix_map>`_
# maps the skew-symmetric matrices to the orthogonal matrices of positive determinant. We can
# concatenate ``Skew`` and a parametrization that implements the Cayley map to get a layer with
# orthogonal weights
class CayleyMap(nn.Module):
    def __init__(self, n):
        super().__init__()
        self.register_buffer("Id", torch.eye(n))

    def forward(self, X):
        # (I + X)(I - X)^{-1}
        return torch.linalg.solve(self.Id - X, self.Id + X)

layer = nn.Linear(3, 3)
parametrize.register_parametrization(layer, "weight", Skew())
parametrize.register_parametrization(layer, "weight", CayleyMap(3))
X = layer.weight
print(torch.dist(X.T @ X, torch.eye(3)))  # X is orthogonal

###############################################################################
# This may also be used to prune a parametrized module, or to reuse parametrizations. For example,
# the matrix exponential maps the symmetric matrices to the Symmetric Positive Definite (SPD) matrices,
# but it also maps the skew-symmetric matrices to the orthogonal matrices.
# Using these two facts, we may reuse the parametrizations from before to our advantage
class MatrixExponential(nn.Module):
    def forward(self, X):
        return torch.matrix_exp(X)

layer_orthogonal = nn.Linear(3, 3)
parametrize.register_parametrization(layer_orthogonal, "weight", Skew())
parametrize.register_parametrization(layer_orthogonal, "weight", MatrixExponential())
X = layer_orthogonal.weight
print(torch.dist(X.T @ X, torch.eye(3)))  # X is orthogonal

layer_spd = nn.Linear(3, 3)
parametrize.register_parametrization(layer_spd, "weight", Symmetric())
parametrize.register_parametrization(layer_spd, "weight", MatrixExponential())
X = layer_spd.weight
print(torch.dist(X, X.T))                     # X is symmetric
print((torch.linalg.eigvalsh(X) > 0.).all())  # X is positive definite
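###############################################################################
# One reason orthogonal weights are interesting (recall the RNN discussion in the
# introduction) is that they preserve norms, so applying the layer repeatedly neither
# blows up nor shrinks its input. A rough numerical check on ``layer_orthogonal`` above,
# ignoring the bias:
x = torch.rand(5, 3)
y = x @ layer_orthogonal.weight.T  # the linear part of what the layer computes
print(torch.dist(x.norm(dim=-1), y.norm(dim=-1)))  # ~0: row norms are preserved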
###############################################################################
# Initializing parametrizations
# -----------------------------
#
# Parametrizations come with a mechanism to initialize them. If we implement a method
# ``right_inverse`` with signature
#
# .. code-block:: python
#
#     def right_inverse(self, X: Tensor) -> Tensor
#
# it will be used when assigning to the parametrized tensor.
#
# Let's upgrade our implementation of the ``Skew`` class to support this
class Skew(nn.Module):
    def forward(self, X):
        A = X.triu(1)
        return A - A.transpose(-1, -2)

    def right_inverse(self, A):
        # We assume that A is skew-symmetric
        # We take the upper-triangular elements, as these are those used in the forward
        return A.triu(1)

###############################################################################
# We may now initialize a layer that is parametrized with ``Skew``
layer = nn.Linear(3, 3)
parametrize.register_parametrization(layer, "weight", Skew())
X = torch.rand(3, 3)
X = X - X.T                         # X is now skew-symmetric
layer.weight = X                    # Initialize layer.weight to be X
print(torch.dist(layer.weight, X))  # layer.weight == X

###############################################################################
# This ``right_inverse`` works as expected when we concatenate parametrizations.
# To see this, let's upgrade the Cayley parametrization to also support being initialized
class CayleyMap(nn.Module):
    def __init__(self, n):
        super().__init__()
        self.register_buffer("Id", torch.eye(n))

    def forward(self, X):
        # Assume X skew-symmetric
        # (I + X)(I - X)^{-1}
        return torch.linalg.solve(self.Id - X, self.Id + X)

    def right_inverse(self, A):
        # Assume A orthogonal
        # See https://en.wikipedia.org/wiki/Cayley_transform#Matrix_map
        # (A - I)(A + I)^{-1}
        return torch.linalg.solve(A + self.Id, A - self.Id)

layer_orthogonal = nn.Linear(3, 3)
parametrize.register_parametrization(layer_orthogonal, "weight", Skew())
parametrize.register_parametrization(layer_orthogonal, "weight", CayleyMap(3))
# Sample an orthogonal matrix with positive determinant
X = torch.empty(3, 3)
nn.init.orthogonal_(X)
if X.det() < 0.:
    X[0].neg_()
layer_orthogonal.weight = X
print(torch.dist(layer_orthogonal.weight, X))  # layer_orthogonal.weight == X

###############################################################################
# This initialization step can be written more succinctly as
layer_orthogonal.weight = nn.init.orthogonal_(layer_orthogonal.weight)
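###############################################################################
# Under the hood, an assignment like the one above passes the value through the
# ``right_inverse`` of each parametrization in reverse registration order and stores
# the result as ``original``. A small sketch to peek at that stored tensor, which here
# is the strictly upper-triangular matrix that ``Skew`` expands:
S = layer_orthogonal.parametrizations.weight.original
print(S)  # strictly upper-triangular: these are the only free entries
print(torch.dist(layer_orthogonal.weight, CayleyMap(3)(Skew()(S))))  # reconstructs the weight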
###############################################################################
# The name of this method comes from the fact that we would often expect
# that ``forward(right_inverse(X)) == X``. This is a direct way of rewriting that
# the forward after the initialization with value ``X`` should return the value ``X``.
# This constraint is not strongly enforced in practice. In fact, at times, it might be of
# interest to relax this relation. For example, consider the following implementation
# of a randomized pruning method:
class PruningParametrization(nn.Module):
    def __init__(self, X, p_drop=0.2):
        super().__init__()
        # sample zeros with probability p_drop
        mask = torch.full_like(X, 1.0 - p_drop)
        self.mask = torch.bernoulli(mask)

    def forward(self, X):
        return X * self.mask

    def right_inverse(self, A):
        return A

###############################################################################
# In this case, it is not true that ``forward(right_inverse(A)) == A`` for every matrix ``A``.
# This is only true when the matrix ``A`` has zeros in the same positions as the mask.
# Even then, if we assign a tensor to a pruned parameter, it will come as no surprise
# that that tensor will, in fact, be pruned
layer = nn.Linear(3, 4)
X = torch.rand_like(layer.weight)
print(f"Initialization matrix:\n{X}")
parametrize.register_parametrization(layer, "weight", PruningParametrization(layer.weight))
layer.weight = X
print(f"\nInitialized weight:\n{layer.weight}")

###############################################################################
# Removing parametrizations
# -------------------------
#
# We may remove all the parametrizations from a parameter or a buffer in a module
# by using ``parametrize.remove_parametrizations()``
layer = nn.Linear(3, 3)
print("Before:")
print(layer)
print(layer.weight)
parametrize.register_parametrization(layer, "weight", Skew())
print("\nParametrized:")
print(layer)
print(layer.weight)
parametrize.remove_parametrizations(layer, "weight")
print("\nAfter. Weight has skew-symmetric values but it is unconstrained:")
print(layer)
print(layer.weight)

###############################################################################
# When removing a parametrization, we may choose to leave the original parameter (i.e. that in
# ``layer.parametrizations.weight.original``) rather than its parametrized version by setting
# the flag ``leave_parametrized=False``
layer = nn.Linear(3, 3)
print("Before:")
print(layer)
print(layer.weight)
parametrize.register_parametrization(layer, "weight", Skew())
print("\nParametrized:")
print(layer)
print(layer.weight)
parametrize.remove_parametrizations(layer, "weight", leave_parametrized=False)
print("\nAfter. Same as Before:")
print(layer)
print(layer.weight)
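###############################################################################
# A final practical note on serialization: while a module is parametrized, its
# ``state_dict`` contains the original tensor rather than the parametrized value, and
# after removing the parametrization it contains a plain ``weight`` again. A short
# sketch, re-registering ``Skew`` on the layer above:
parametrize.register_parametrization(layer, "weight", Skew())
print(sorted(layer.state_dict().keys()))  # ['bias', 'parametrizations.weight.original']
parametrize.remove_parametrizations(layer, "weight")
print(sorted(layer.state_dict().keys()))  # ['bias', 'weight']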