# -*- coding: utf-8 -*-
"""
Pruning Tutorial
=====================================
**Author**: `Michela Paganini <https://github.com/mickypaganini>`_

State-of-the-art deep learning techniques rely on over-parametrized models
that are hard to deploy. In contrast, biological neural networks are
known to use efficient sparse connectivity. Identifying optimal
techniques to compress models by reducing the number of parameters in them is
important in order to reduce memory, battery, and hardware consumption without
sacrificing accuracy. This in turn allows you to deploy lightweight models on
device, and to guarantee privacy with private on-device computation. On the
research front, pruning is used to investigate the differences in learning
dynamics between over-parametrized and under-parametrized networks, to study
the role of lucky sparse subnetworks and initializations
("`lottery tickets <https://arxiv.org/abs/1803.03635>`_") as a destructive
neural architecture search technique, and more.

In this tutorial, you will learn how to use ``torch.nn.utils.prune`` to
sparsify your neural networks, and how to extend it to implement your
own custom pruning technique.

Requirements
------------
``"torch>=1.4.0a0+8e8a5e0"``

"""
import torch
from torch import nn
import torch.nn.utils.prune as prune
import torch.nn.functional as F

######################################################################
# Create a model
# --------------
#
# In this tutorial, we use the `LeNet
# <http://yann.lecun.com/exdb/publis/pdf/lecun-98.pdf>`_ architecture from
# LeCun et al., 1998.

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

class LeNet(nn.Module):
    def __init__(self):
        super(LeNet, self).__init__()
        # 1 input image channel, 6 output channels, 5x5 square conv kernel
        self.conv1 = nn.Conv2d(1, 6, 5)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)  # 5x5 image dimension
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = F.max_pool2d(F.relu(self.conv1(x)), (2, 2))
        x = F.max_pool2d(F.relu(self.conv2(x)), 2)
        x = x.view(-1, int(x.nelement() / x.shape[0]))
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

model = LeNet().to(device=device)

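######################################################################
# As a quick sanity check, we can run a dummy batch through the network.
# With the layer sizes above, LeNet expects 32x32 single-channel inputs:
# two 5x5 convolutions and two 2x2 poolings reduce a 32x32 image to the
# 16 * 5 * 5 features consumed by ``fc1``.

dummy_input = torch.randn(1, 1, 32, 32, device=device)
print(model(dummy_input).shape)  # torch.Size([1, 10])
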
######################################################################
# Inspect a Module
# ----------------
#
# Let's inspect the (unpruned) ``conv1`` layer in our LeNet model. It will
# contain two parameters ``weight`` and ``bias``, and no buffers, for now.
module = model.conv1
print(list(module.named_parameters()))

######################################################################
print(list(module.named_buffers()))

######################################################################
# Pruning a Module
# ----------------
#
# To prune a module (in this example, the ``conv1`` layer of our LeNet
# architecture), first select a pruning technique among those available in
# ``torch.nn.utils.prune`` (or
# `implement <#extending-torch-nn-utils-pruning-with-custom-pruning-functions>`_
# your own by subclassing
# ``BasePruningMethod``). Then, specify the module and the name of the
# parameter to prune within that module. Finally, using the keyword arguments
# required by the selected pruning technique, specify the pruning parameters.
#
# In this example, we will randomly prune 30% of the connections in
# the parameter named ``weight`` in the ``conv1`` layer.
# The module is passed as the first argument to the function; ``name``
# identifies the parameter within that module using its string identifier; and
# ``amount`` indicates either the fraction of connections to prune (if it
# is a float between 0. and 1.), or the absolute number of connections to
# prune (if it is a non-negative integer).
prune.random_unstructured(module, name="weight", amount=0.3)

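######################################################################
# A minimal sketch of the two ``amount`` conventions, on a throwaway layer
# (the ``toy`` layer below is for illustration only): a float prunes a
# fraction of the entries, an integer prunes exactly that many.

toy = nn.Linear(4, 4)
prune.random_unstructured(toy, name="weight", amount=5)  # prune exactly 5 of 16 entries
print(int(toy.weight_mask.eq(0).sum()))                  # 5
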
######################################################################
# Pruning acts by removing ``weight`` from the parameters and replacing it with
# a new parameter called ``weight_orig`` (i.e. appending ``"_orig"`` to the
# initial parameter ``name``). ``weight_orig`` stores the unpruned version of
# the tensor. The ``bias`` was not pruned, so it will remain intact.
print(list(module.named_parameters()))

######################################################################
# The pruning mask generated by the pruning technique selected above is saved
# as a module buffer named ``weight_mask`` (i.e. appending ``"_mask"`` to the
# initial parameter ``name``).
print(list(module.named_buffers()))

######################################################################
# For the forward pass to work without modification, the ``weight`` attribute
# needs to exist. The pruning techniques implemented in
# ``torch.nn.utils.prune`` compute the pruned version of the weight (by
# combining the mask with the original parameter) and store it in the
# attribute ``weight``. Note that this is no longer a parameter of the
# ``module``; it is now simply an attribute.
print(module.weight)

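######################################################################
# Equivalently, the pruned tensor is the elementwise product of the original
# parameter and the mask, which we can verify directly:

print(torch.equal(module.weight, module.weight_orig * module.weight_mask))  # True
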
######################################################################
# Finally, pruning is applied prior to each forward pass using PyTorch's
# ``forward_pre_hooks``. Specifically, when the ``module`` is pruned, as we
# have done here, it will acquire a ``forward_pre_hook`` for each parameter
# associated with it that gets pruned. In this case, since we have so far
# only pruned the original parameter named ``weight``, only one hook will be
# present.
print(module._forward_pre_hooks)

######################################################################
# For completeness, we can now prune the ``bias`` too, to see how the
# parameters, buffers, hooks, and attributes of the ``module`` change.
# Just for the sake of trying out another pruning technique, here we prune the
# 3 smallest entries in the bias by L1 norm, as implemented in the
# ``l1_unstructured`` pruning function.
prune.l1_unstructured(module, name="bias", amount=3)

######################################################################
# We now expect the named parameters to include both ``weight_orig`` (from
# before) and ``bias_orig``. The buffers will include ``weight_mask`` and
# ``bias_mask``. The pruned versions of the two tensors will exist as
# module attributes, and the module will now have two ``forward_pre_hooks``.
print(list(module.named_parameters()))

######################################################################
print(list(module.named_buffers()))

######################################################################
print(module.bias)

######################################################################
print(module._forward_pre_hooks)

######################################################################
# Iterative Pruning
# -----------------
#
# The same parameter in a module can be pruned multiple times; the effect of
# the successive pruning calls is equal to the combination of the various
# masks applied in series.
# The combination of a new mask with the old mask is handled by the
# ``PruningContainer``'s ``compute_mask`` method.
#
# Say, for example, that we now want to further prune ``module.weight``, this
# time using structured pruning along the 0th axis of the tensor (the 0th axis
# corresponds to the output channels of the convolutional layer and has
# dimensionality 6 for ``conv1``), based on the channels' L2 norm. This can be
# achieved using the ``ln_structured`` function, with ``n=2`` and ``dim=0``.
prune.ln_structured(module, name="weight", amount=0.5, n=2, dim=0)

# As we can verify, this will zero out all the connections corresponding to
# 50% (3 out of 6) of the channels, while preserving the action of the
# previous mask.
print(module.weight)

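######################################################################
# A minimal check of the structured pruning: count how many of the six
# output channels of ``conv1`` are now entirely zero.

channel_is_zero = module.weight.flatten(1).eq(0).all(dim=1)
print(channel_is_zero)             # three of the six entries should be True
print(int(channel_is_zero.sum()))  # 3
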
############################################################################
# The corresponding hook will now be of type
# ``torch.nn.utils.prune.PruningContainer``, and will store the history of
# pruning applied to the ``weight`` parameter.
for hook in module._forward_pre_hooks.values():
    if hook._tensor_name == "weight":  # select out the correct hook
        break

print(list(hook))  # pruning history in the container

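######################################################################
# The container holds the two methods applied so far, in order:

print([type(h).__name__ for h in hook])  # ['RandomUnstructured', 'LnStructured']
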
######################################################################
# Serializing a pruned model
# --------------------------
# All relevant tensors, including the mask buffers and the original parameters
# used to compute the pruned tensors, are stored in the model's ``state_dict``
# and can therefore be easily serialized and saved, if needed.
print(model.state_dict().keys())

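######################################################################
# A minimal sketch of saving and restoring the pruned model (the file name is
# a placeholder). Because the ``state_dict`` contains ``weight_orig`` and
# ``weight_mask`` entries rather than a plain ``weight`` entry for the pruned
# parameters, one workable pattern is to re-apply the same pruning to a fresh
# model before loading the saved ``state_dict`` back into it.

torch.save(model.state_dict(), "pruned_lenet.pth")  # placeholder path

restored = LeNet().to(device=device)
# recreate the re-parametrization so the state_dict keys match
prune.random_unstructured(restored.conv1, name="weight", amount=0.3)
prune.ln_structured(restored.conv1, name="weight", amount=0.5, n=2, dim=0)
prune.l1_unstructured(restored.conv1, name="bias", amount=3)
restored.load_state_dict(torch.load("pruned_lenet.pth"))
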
######################################################################
# Remove pruning re-parametrization
# ---------------------------------
#
# To make the pruning permanent, i.e. to remove the re-parametrization in
# terms of ``weight_orig`` and ``weight_mask``, and to remove the
# ``forward_pre_hook``, we can use the ``remove`` functionality from
# ``torch.nn.utils.prune``.
# Note that this doesn't undo the pruning, as if it never happened. It simply
# makes it permanent by reassigning the pruned version of the tensor to the
# parameter ``weight``.

######################################################################
# Prior to removing the re-parametrization:
print(list(module.named_parameters()))
######################################################################
print(list(module.named_buffers()))
######################################################################
print(module.weight)

######################################################################
# After removing the re-parametrization:
prune.remove(module, 'weight')
print(list(module.named_parameters()))
######################################################################
print(list(module.named_buffers()))

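######################################################################
# After ``remove``, ``weight`` is an ordinary parameter again (with the zeros
# baked in), and only the hook for the still-pruned ``bias`` remains:

print("weight" in dict(module.named_parameters()))  # True
print(module._forward_pre_hooks)  # only the ``bias`` pruning hook is left
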
######################################################################
# Pruning multiple parameters in a model
# --------------------------------------
#
# By specifying the desired pruning technique and parameters, we can easily
# prune multiple tensors in a network, perhaps according to their type, as we
# will see in this example.

new_model = LeNet()
for name, module in new_model.named_modules():
    # prune 20% of connections in all 2D-conv layers
    if isinstance(module, torch.nn.Conv2d):
        prune.l1_unstructured(module, name='weight', amount=0.2)
    # prune 40% of connections in all linear layers
    elif isinstance(module, torch.nn.Linear):
        prune.l1_unstructured(module, name='weight', amount=0.4)

print(dict(new_model.named_buffers()).keys())  # to verify that all masks exist

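######################################################################
# A quick check that each layer was pruned by the requested amount: the mean
# of the zero entries in each mask gives the per-layer sparsity.

print(float(new_model.conv1.weight_mask.eq(0).float().mean()))  # 0.2
print(float(new_model.fc1.weight_mask.eq(0).float().mean()))    # 0.4
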
######################################################################
# Global pruning
# --------------
#
# So far, we have only looked at what is usually referred to as "local"
# pruning, i.e. the practice of pruning tensors in a model one by one, by
# comparing the statistics (weight magnitude, activation, gradient, etc.) of
# each entry exclusively to the other entries in that tensor. However, a
# common and perhaps more powerful technique is to prune the model all at
# once, by removing (for example) the lowest 20% of connections across the
# whole model, instead of removing the lowest 20% of connections in each
# layer. This is likely to result in different pruning percentages per layer.
# Let's see how to do that using ``global_unstructured`` from
# ``torch.nn.utils.prune``.

model = LeNet()

parameters_to_prune = (
    (model.conv1, 'weight'),
    (model.conv2, 'weight'),
    (model.fc1, 'weight'),
    (model.fc2, 'weight'),
    (model.fc3, 'weight'),
)

prune.global_unstructured(
    parameters_to_prune,
    pruning_method=prune.L1Unstructured,
    amount=0.2,
)

######################################################################
# Now we can check the sparsity induced in every pruned parameter, which will
# not be equal to 20% in each layer. However, the global sparsity will be
# (approximately) 20%.
print(
    "Sparsity in conv1.weight: {:.2f}%".format(
        100. * float(torch.sum(model.conv1.weight == 0))
        / float(model.conv1.weight.nelement())
    )
)
print(
    "Sparsity in conv2.weight: {:.2f}%".format(
        100. * float(torch.sum(model.conv2.weight == 0))
        / float(model.conv2.weight.nelement())
    )
)
print(
    "Sparsity in fc1.weight: {:.2f}%".format(
        100. * float(torch.sum(model.fc1.weight == 0))
        / float(model.fc1.weight.nelement())
    )
)
print(
    "Sparsity in fc2.weight: {:.2f}%".format(
        100. * float(torch.sum(model.fc2.weight == 0))
        / float(model.fc2.weight.nelement())
    )
)
print(
    "Sparsity in fc3.weight: {:.2f}%".format(
        100. * float(torch.sum(model.fc3.weight == 0))
        / float(model.fc3.weight.nelement())
    )
)
print(
    "Global sparsity: {:.2f}%".format(
        100. * float(
            torch.sum(model.conv1.weight == 0)
            + torch.sum(model.conv2.weight == 0)
            + torch.sum(model.fc1.weight == 0)
            + torch.sum(model.fc2.weight == 0)
            + torch.sum(model.fc3.weight == 0)
        )
        / float(
            model.conv1.weight.nelement()
            + model.conv2.weight.nelement()
            + model.fc1.weight.nelement()
            + model.fc2.weight.nelement()
            + model.fc3.weight.nelement()
        )
    )
)

######################################################################
# Extending ``torch.nn.utils.prune`` with custom pruning functions
# ------------------------------------------------------------------
# To implement your own pruning function, you can extend the
# ``nn.utils.prune`` module by subclassing the ``BasePruningMethod``
# base class, the same way all other pruning methods do. The base class
# implements the following methods for you: ``__call__``, ``apply_mask``,
# ``apply``, ``prune``, and ``remove``. Beyond some special cases, you
# shouldn't have to reimplement these methods for your new pruning technique.
# You will, however, have to implement ``__init__`` (the constructor)
# and ``compute_mask`` (the instructions on how to compute the mask
# for the given tensor according to the logic of your pruning
# technique). In addition, you will have to specify which type of
# pruning this technique implements (supported options are ``global``,
# ``structured``, and ``unstructured``). This is needed to determine
# how to combine masks when pruning is applied
# iteratively. In other words, when pruning a pre-pruned parameter,
# the current pruning technique is expected to act on the unpruned
# portion of the parameter. Specifying the ``PRUNING_TYPE`` will
# enable the ``PruningContainer`` (which handles the iterative
# application of pruning masks) to correctly identify the slice of the
# parameter to prune.
#
# Let's assume, for example, that you want to implement a pruning
# technique that prunes every other entry in a tensor (or -- if the
# tensor has previously been pruned -- in the remaining unpruned
# portion of the tensor). This will be of ``PRUNING_TYPE='unstructured'``
# because it acts on individual connections in a layer and not on entire
# units/channels (``'structured'``), or across different parameters
# (``'global'``).

class FooBarPruningMethod(prune.BasePruningMethod):
    """Prune every other entry in a tensor
    """
    PRUNING_TYPE = 'unstructured'

    def compute_mask(self, t, default_mask):
        mask = default_mask.clone()
        mask.view(-1)[::2] = 0
        return mask

######################################################################
# Now, to apply this to a parameter in an ``nn.Module``, you should
# also provide a simple function that instantiates the method and
# applies it.
def foobar_unstructured(module, name):
    """Prunes tensor corresponding to parameter called `name` in `module`
    by removing every other entry in the tensor.
    Modifies module in place (and also returns the modified module)
    by:
    1) adding a named buffer called `name+'_mask'` corresponding to the
    binary mask applied to the parameter `name` by the pruning method.
    The parameter `name` is replaced by its pruned version, while the
    original (unpruned) parameter is stored in a new parameter named
    `name+'_orig'`.

    Args:
        module (nn.Module): module containing the tensor to prune
        name (string): parameter name within `module` on which pruning
            will act.

    Returns:
        module (nn.Module): modified (i.e. pruned) version of the input
            module

    Examples:
        >>> m = nn.Linear(3, 4)
        >>> foobar_unstructured(m, name='bias')
    """
    FooBarPruningMethod.apply(module, name)
    return module

######################################################################
# Let's try it out!
model = LeNet()
foobar_unstructured(model.fc3, name='bias')

print(model.fc3.bias_mask)
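
######################################################################
# Since ``fc3.bias`` has 10 entries and ``compute_mask`` zeroes every other
# one starting from index 0, the printed mask should be
# ``tensor([0., 1., 0., 1., 0., 1., 0., 1., 0., 1.])``.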