# -*- coding: utf-8 -*-
"""
Pruning Tutorial
=====================================
**Author**: `Michela Paganini <https://github.com/mickypaganini>`_

State-of-the-art deep learning techniques rely on over-parametrized models
that are hard to deploy. In contrast, biological neural networks are known
to use efficient sparse connectivity. Identifying optimal techniques to
compress models by reducing the number of parameters is important for
reducing memory, battery, and hardware consumption without sacrificing
accuracy. This in turn allows you to deploy lightweight models on device,
and to guarantee privacy with private on-device computation. On the
research front, pruning is used to investigate the differences in learning
dynamics between over-parametrized and under-parametrized networks, to
study the role of lucky sparse subnetworks and initializations
("`lottery tickets <https://arxiv.org/abs/1803.03635>`_") as a destructive
neural architecture search technique, and more.

In this tutorial, you will learn how to use ``torch.nn.utils.prune`` to
sparsify your neural networks, and how to extend it to implement your
own custom pruning technique.

Requirements
------------
``"torch>=1.4.0a0+8e8a5e0"``

"""
import torch
from torch import nn
import torch.nn.utils.prune as prune
import torch.nn.functional as F

######################################################################
# Create a model
# --------------
#
# In this tutorial, we use the `LeNet
# <http://yann.lecun.com/exdb/publis/pdf/lecun-98.pdf>`_ architecture from
# LeCun et al., 1998.

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


class LeNet(nn.Module):
    def __init__(self):
        super(LeNet, self).__init__()
        # 1 input image channel, 6 output channels, 5x5 square conv kernel
        self.conv1 = nn.Conv2d(1, 6, 5)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)  # 5x5 image dimension
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = F.max_pool2d(F.relu(self.conv1(x)), (2, 2))
        x = F.max_pool2d(F.relu(self.conv2(x)), 2)
        x = x.view(-1, int(x.nelement() / x.shape[0]))
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

model = LeNet().to(device=device)
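######################################################################
# As a quick sanity check (an editorial addition to the original tutorial),
# we can run a random batch through the unpruned model. The 32x32 input size
# is inferred from the architecture above: two 5x5 convolutions and two 2x2
# poolings reduce it to the 16 * 5 * 5 features expected by ``fc1``.
dummy_input = torch.randn(1, 1, 32, 32, device=device)
print(model(dummy_input).shape)  # expected: torch.Size([1, 10])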
######################################################################
# Inspect a Module
# ----------------
#
# Let's inspect the (unpruned) ``conv1`` layer in our LeNet model. For now,
# it will contain two parameters, ``weight`` and ``bias``, and no buffers.
module = model.conv1
print(list(module.named_parameters()))

######################################################################
print(list(module.named_buffers()))

######################################################################
# Pruning a Module
# ----------------
#
# To prune a module (in this example, the ``conv1`` layer of our LeNet
# architecture), first select a pruning technique among those available in
# ``torch.nn.utils.prune`` (or
# `implement <#extending-torch-nn-utils-prune-with-custom-pruning-functions>`_
# your own by subclassing ``BasePruningMethod``). Then, specify the module
# and the name of the parameter to prune within that module. Finally,
# specify the pruning parameters using the keyword arguments required by
# the selected pruning technique.
#
# In this example, we will prune 30% of the connections in the parameter
# named ``weight`` in the ``conv1`` layer, selected at random.
# The module is passed as the first argument to the function; ``name``
# identifies the parameter within that module using its string identifier;
# and ``amount`` indicates either the fraction of connections to prune (if
# it is a float between 0. and 1.), or the absolute number of connections
# to prune (if it is a non-negative integer).
prune.random_unstructured(module, name="weight", amount=0.3)

######################################################################
# Pruning acts by removing ``weight`` from the parameters and replacing it
# with a new parameter called ``weight_orig`` (i.e. appending ``"_orig"``
# to the initial parameter ``name``). ``weight_orig`` stores the unpruned
# version of the tensor. The ``bias`` was not pruned, so it remains intact.
print(list(module.named_parameters()))

######################################################################
# The pruning mask generated by the pruning technique selected above is
# saved as a module buffer named ``weight_mask`` (i.e. appending
# ``"_mask"`` to the initial parameter ``name``).
print(list(module.named_buffers()))

######################################################################
# For the forward pass to work without modification, the ``weight``
# attribute needs to exist. The pruning techniques implemented in
# ``torch.nn.utils.prune`` compute the pruned version of the weight (by
# combining the mask with the original parameter) and store it in the
# attribute ``weight``. Note that this is no longer a parameter of the
# ``module``; it is now simply an attribute.
print(module.weight)

######################################################################
# Finally, pruning is applied prior to each forward pass using PyTorch's
# ``forward_pre_hooks``. Specifically, when the ``module`` is pruned, as
# we have done here, it acquires a ``forward_pre_hook`` for each of its
# parameters that gets pruned. In this case, since we have so far only
# pruned the original parameter named ``weight``, only one hook will be
# present.
print(module._forward_pre_hooks)
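######################################################################
# We can verify the re-parametrization directly (a small check added here
# for illustration): the ``weight`` attribute is the element-wise product
# of ``weight_orig`` and the binary ``weight_mask``.
assert torch.allclose(module.weight, module.weight_orig * module.weight_mask)
print("Fraction of conv1 weights pruned: {:.2f}".format(
    float(torch.sum(module.weight == 0)) / module.weight.nelement()))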
######################################################################
# For completeness, we can now prune the ``bias`` too, to see how the
# parameters, buffers, hooks, and attributes of the ``module`` change.
# Just for the sake of trying out another pruning technique, here we prune
# the 3 smallest entries in the bias by L1 norm, as implemented in the
# ``l1_unstructured`` pruning function.
prune.l1_unstructured(module, name="bias", amount=3)

######################################################################
# We now expect the named parameters to include both ``weight_orig`` (from
# before) and ``bias_orig``. The buffers will include ``weight_mask`` and
# ``bias_mask``. The pruned versions of the two tensors will exist as
# module attributes, and the module will now have two ``forward_pre_hooks``.
print(list(module.named_parameters()))

######################################################################
print(list(module.named_buffers()))

######################################################################
print(module.bias)

######################################################################
print(module._forward_pre_hooks)

######################################################################
# Iterative Pruning
# -----------------
#
# The same parameter in a module can be pruned multiple times; the effect
# of successive pruning calls is the combination of the various masks
# applied in series. The combination of a new mask with the old mask is
# handled by the ``PruningContainer``'s ``compute_mask`` method.
#
# Say, for example, that we now want to further prune ``module.weight``,
# this time using structured pruning along the 0th axis of the tensor (the
# 0th axis corresponds to the output channels of the convolutional layer
# and has dimensionality 6 for ``conv1``), based on the channels' L2 norm.
# This can be achieved using the ``ln_structured`` function, with ``n=2``
# and ``dim=0``.
prune.ln_structured(module, name="weight", amount=0.5, n=2, dim=0)

# As we can verify, this zeroes out all the connections corresponding to
# 50% (3 out of 6) of the channels, while preserving the action of the
# previous mask.
print(module.weight)

############################################################################
# The corresponding hook will now be of type
# ``torch.nn.utils.prune.PruningContainer``, and will store the history of
# pruning applied to the ``weight`` parameter.
for hook in module._forward_pre_hooks.values():
    if hook._tensor_name == "weight":  # select out the correct hook
        break

print(list(hook))  # pruning history in the container
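######################################################################
# To see which output channels the structured step zeroed out (an
# illustrative check, not part of the original tutorial), we can look for
# rows of the combined mask that are zero across the entire channel.
channel_is_pruned = module.weight_mask.sum(dim=(1, 2, 3)) == 0
print(channel_is_pruned)  # 3 of the 6 channels should be True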
######################################################################
# Serializing a pruned model
# --------------------------
# All relevant tensors, including the mask buffers and the original
# parameters used to compute the pruned tensors, are stored in the model's
# ``state_dict`` and can therefore be easily serialized and saved, if
# needed.
print(model.state_dict().keys())


######################################################################
# Remove pruning re-parametrization
# ---------------------------------
#
# To make the pruning permanent, i.e. to remove the re-parametrization in
# terms of ``weight_orig`` and ``weight_mask``, and to remove the
# ``forward_pre_hook``, we can use the ``remove`` functionality from
# ``torch.nn.utils.prune``. Note that this doesn't undo the pruning as if
# it had never happened; instead, it makes the pruning permanent by
# reassigning the pruned version of the tensor to the parameter ``weight``.

######################################################################
# Prior to removing the re-parametrization:
print(list(module.named_parameters()))
######################################################################
print(list(module.named_buffers()))
######################################################################
print(module.weight)

######################################################################
# After removing the re-parametrization:
prune.remove(module, 'weight')
print(list(module.named_parameters()))
######################################################################
print(list(module.named_buffers()))

######################################################################
# Pruning multiple parameters in a model
# --------------------------------------
#
# By specifying the desired pruning technique and parameters, we can
# easily prune multiple tensors in a network, perhaps according to their
# type, as we will see in this example.

new_model = LeNet()
for name, module in new_model.named_modules():
    # prune 20% of connections in all 2D-conv layers
    if isinstance(module, torch.nn.Conv2d):
        prune.l1_unstructured(module, name='weight', amount=0.2)
    # prune 40% of connections in all linear layers
    elif isinstance(module, torch.nn.Linear):
        prune.l1_unstructured(module, name='weight', amount=0.4)

print(dict(new_model.named_buffers()).keys())  # to verify that all masks exist
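######################################################################
# If we later want to make these prunings permanent, the same loop pattern
# works with ``prune.remove`` (a sketch added here; it mirrors the
# ``remove`` step shown above on a single module).
for name, module in new_model.named_modules():
    if isinstance(module, (torch.nn.Conv2d, torch.nn.Linear)):
        prune.remove(module, 'weight')

print(dict(new_model.named_buffers()).keys())  # the masks are gone now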
######################################################################
# Global pruning
# --------------
#
# So far, we have only looked at what is usually referred to as "local"
# pruning, i.e. the practice of pruning tensors in a model one by one, by
# comparing the statistics (weight magnitude, activation, gradient, etc.)
# of each entry exclusively to the other entries in that tensor. However,
# a common and perhaps more powerful technique is to prune the model all
# at once, by removing (for example) the lowest 20% of connections across
# the whole model, instead of removing the lowest 20% of connections in
# each layer. This is likely to result in different pruning percentages
# per layer. Let's see how to do that using ``global_unstructured`` from
# ``torch.nn.utils.prune``.

model = LeNet()

parameters_to_prune = (
    (model.conv1, 'weight'),
    (model.conv2, 'weight'),
    (model.fc1, 'weight'),
    (model.fc2, 'weight'),
    (model.fc3, 'weight'),
)

prune.global_unstructured(
    parameters_to_prune,
    pruning_method=prune.L1Unstructured,
    amount=0.2,
)

######################################################################
# Now we can check the sparsity induced in every pruned parameter, which
# will not be equal to 20% in each layer. However, the global sparsity
# will be (approximately) 20%.
print(
    "Sparsity in conv1.weight: {:.2f}%".format(
        100. * float(torch.sum(model.conv1.weight == 0))
        / float(model.conv1.weight.nelement())
    )
)
print(
    "Sparsity in conv2.weight: {:.2f}%".format(
        100. * float(torch.sum(model.conv2.weight == 0))
        / float(model.conv2.weight.nelement())
    )
)
print(
    "Sparsity in fc1.weight: {:.2f}%".format(
        100. * float(torch.sum(model.fc1.weight == 0))
        / float(model.fc1.weight.nelement())
    )
)
print(
    "Sparsity in fc2.weight: {:.2f}%".format(
        100. * float(torch.sum(model.fc2.weight == 0))
        / float(model.fc2.weight.nelement())
    )
)
print(
    "Sparsity in fc3.weight: {:.2f}%".format(
        100. * float(torch.sum(model.fc3.weight == 0))
        / float(model.fc3.weight.nelement())
    )
)
print(
    "Global sparsity: {:.2f}%".format(
        100. * float(
            torch.sum(model.conv1.weight == 0)
            + torch.sum(model.conv2.weight == 0)
            + torch.sum(model.fc1.weight == 0)
            + torch.sum(model.fc2.weight == 0)
            + torch.sum(model.fc3.weight == 0)
        )
        / float(
            model.conv1.weight.nelement()
            + model.conv2.weight.nelement()
            + model.fc1.weight.nelement()
            + model.fc2.weight.nelement()
            + model.fc3.weight.nelement()
        )
    )
)
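######################################################################
# The repeated computation above can be folded into a small helper
# (``sparsity`` is a convenience function introduced here, not part of the
# ``torch.nn.utils.prune`` API).
def sparsity(tensor):
    """Percentage of entries in ``tensor`` that are exactly zero."""
    return 100. * float(torch.sum(tensor == 0)) / float(tensor.nelement())

print("Sparsity in fc3.weight: {:.2f}%".format(sparsity(model.fc3.weight)))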
######################################################################
# Extending ``torch.nn.utils.prune`` with custom pruning functions
# ------------------------------------------------------------------
# To implement your own pruning function, you can extend the
# ``nn.utils.prune`` module by subclassing the ``BasePruningMethod``
# base class, the same way all other pruning methods do. The base class
# implements the following methods for you: ``__call__``, ``apply_mask``,
# ``apply``, ``prune``, and ``remove``. Beyond some special cases, you
# shouldn't have to reimplement these methods for your new pruning
# technique. You will, however, have to implement ``__init__`` (the
# constructor), and ``compute_mask`` (the instructions on how to compute
# the mask for the given tensor according to the logic of your pruning
# technique). In addition, you will have to specify which type of pruning
# this technique implements (supported options are ``global``,
# ``structured``, and ``unstructured``). This is needed to determine how
# to combine masks in the case in which pruning is applied iteratively.
# In other words, when pruning a pre-pruned parameter, the current pruning
# technique is expected to act on the unpruned portion of the parameter.
# Specifying the ``PRUNING_TYPE`` will enable the ``PruningContainer``
# (which handles the iterative application of pruning masks) to correctly
# identify the slice of the parameter to prune.
#
# Let's assume, for example, that you want to implement a pruning
# technique that prunes every other entry in a tensor (or -- if the tensor
# has previously been pruned -- in the remaining unpruned portion of the
# tensor). This will be of ``PRUNING_TYPE='unstructured'`` because it acts
# on individual connections in a layer and not on entire units/channels
# (``'structured'``), or across different parameters (``'global'``).

class FooBarPruningMethod(prune.BasePruningMethod):
    """Prune every other entry in a tensor
    """
    PRUNING_TYPE = 'unstructured'

    def compute_mask(self, t, default_mask):
        mask = default_mask.clone()
        mask.view(-1)[::2] = 0
        return mask

######################################################################
# Now, to apply this to a parameter in an ``nn.Module``, you should also
# provide a simple function that instantiates the method and applies it.
def foobar_unstructured(module, name):
    """Prunes the tensor corresponding to the parameter called `name` in
    `module` by removing every other entry in the tensor.
    Modifies the module in place (and also returns the modified module) by
    adding a named buffer called `name+'_mask'` corresponding to the
    binary mask applied to the parameter `name` by the pruning method.
    The parameter `name` is replaced by its pruned version, while the
    original (unpruned) parameter is stored in a new parameter named
    `name+'_orig'`.

    Args:
        module (nn.Module): module containing the tensor to prune
        name (str): parameter name within `module` on which pruning
            will act.

    Returns:
        module (nn.Module): modified (i.e. pruned) version of the input
            module

    Examples:
        >>> m = nn.Linear(3, 4)
        >>> foobar_unstructured(m, name='bias')
    """
    FooBarPruningMethod.apply(module, name)
    return module

######################################################################
# Let's try it out!
model = LeNet()
foobar_unstructured(model.fc3, name='bias')

print(model.fc3.bias_mask)
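######################################################################
# As a final check (added for illustration), the pruned ``bias`` itself
# now has every other entry zeroed, and the ``forward_pre_hook`` keeps it
# that way on each forward pass. We reuse the ``sparsity`` helper defined
# earlier.
print(model.fc3.bias)
print("Sparsity in fc3.bias: {:.2f}%".format(sparsity(model.fc3.bias)))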