# -*- coding: utf-8 -*-
"""
(beta) Building a Convolution/Batch Norm fuser in FX
*******************************************************
**Author**: `Horace He <https://github.com/chillee>`_

In this tutorial, we are going to use FX, a toolkit for composable function
transformations of PyTorch, to do the following:

1) Find patterns of conv/batch norm in the data dependencies.
2) For the patterns found in 1), fold the batch norm statistics into the convolution weights.

Note that this optimization only works for models in inference mode (i.e. ``model.eval()``).

We will be building the fuser that exists here:
https://github.com/pytorch/pytorch/blob/orig/release/1.8/torch/fx/experimental/fuser.py

"""

######################################################################
# First, let's get some imports out of the way (we will be using all
# of these later in the code).

from typing import Type, Dict, Any, Tuple, Iterable
import copy
import torch.fx as fx
import torch
import torch.nn as nn

######################################################################
# For this tutorial, we are going to create a model consisting of convolutions
# and batch norms. Note that this model has some tricky components - some of
# the conv/batch norm patterns are hidden within Sequentials, and one of the
# ``BatchNorm`` modules is wrapped in another Module.

class WrappedBatchNorm(nn.Module):
    def __init__(self):
        super().__init__()
        self.mod = nn.BatchNorm2d(1)

    def forward(self, x):
        return self.mod(x)

class M(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 1, 1)
        self.bn1 = nn.BatchNorm2d(1)
        self.conv2 = nn.Conv2d(1, 1, 1)
        self.nested = nn.Sequential(
            nn.BatchNorm2d(1),
            nn.Conv2d(1, 1, 1),
        )
        self.wrapped = WrappedBatchNorm()

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.conv2(x)
        x = self.nested(x)
        x = self.wrapped(x)
        return x

model = M()
model.eval()

######################################################################
# Symbolically Tracing our Model
# -----------------------------------------
# One of the primary challenges with trying to automatically fuse convolution
# and batch norm in PyTorch is that PyTorch does not provide an easy way of
# accessing the computational graph. FX resolves this problem by symbolically
# tracing the actual operations called, so that we can track the computations
# through the `forward` call, nested within Sequential modules, or wrapped in
# a user-defined module.

traced_model = torch.fx.symbolic_trace(model)
print(traced_model.graph)

######################################################################
# This gives us a graph representation of our model. Note that both the
# modules hidden within the Sequential and the wrapped Module have been
# inlined into the graph. This is the default level of abstraction, but it
# can be configured by the pass writer. More information can be found in the
# FX overview: https://pytorch.org/docs/master/fx.html#module-torch.fx
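
######################################################################
# Along the same lines (a small aside, not part of the original tutorial),
# the traced ``GraphModule`` also regenerates runnable Python source from
# its graph, which is another handy way to see the inlined calls:

# Print the Python code that FX generated for the traced forward pass.
print(traced_model.code)
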
####################################
# Fusing Convolution with Batch Norm
# ----------------------------------
# Unlike some other fusions, fusion of convolution with batch norm does not
# require any new operators. Instead, as batch norm during inference
# consists of a pointwise add and multiply, these operations can be "baked"
# into the preceding convolution's weights. This allows us to remove the batch
# norm entirely from our model! Read
# https://nenadmarkus.com/p/fusing-batchnorm-and-conv/ for further details. The
# code here is copied from
# https://github.com/pytorch/pytorch/blob/orig/release/1.8/torch/nn/utils/fusion.py
# for clarity purposes.
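
######################################################################
# Concretely, if the batch norm has running statistics :math:`\mu` and
# :math:`\sigma^2`, affine parameters :math:`\gamma` and :math:`\beta`, and
# epsilon :math:`\epsilon`, then the fused convolution uses
#
# .. math::
#
#     W' = W \cdot \frac{\gamma}{\sqrt{\sigma^2 + \epsilon}}, \qquad
#     b' = (b - \mu) \cdot \frac{\gamma}{\sqrt{\sigma^2 + \epsilon}} + \beta
#
# which is what ``fuse_conv_bn_weights`` below computes per output channel.
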
def fuse_conv_bn_eval(conv, bn):
    """
    Given a conv Module `A` and a batch norm module `B`, returns a conv
    module `C` such that C(x) == B(A(x)) in inference mode.
    """
    assert not (conv.training or bn.training), "Fusion only for eval!"
    fused_conv = copy.deepcopy(conv)

    fused_conv.weight, fused_conv.bias = \
        fuse_conv_bn_weights(fused_conv.weight, fused_conv.bias,
                             bn.running_mean, bn.running_var, bn.eps, bn.weight, bn.bias)

    return fused_conv

def fuse_conv_bn_weights(conv_w, conv_b, bn_rm, bn_rv, bn_eps, bn_w, bn_b):
    # Fill in defaults for a conv without a bias and a batch norm without
    # affine parameters.
    if conv_b is None:
        conv_b = torch.zeros_like(bn_rm)
    if bn_w is None:
        bn_w = torch.ones_like(bn_rm)
    if bn_b is None:
        bn_b = torch.zeros_like(bn_rm)
    bn_var_rsqrt = torch.rsqrt(bn_rv + bn_eps)

    # Scale each output channel of the conv weight by gamma / sqrt(var + eps),
    # and fold the batch norm's shift into the conv bias.
    conv_w = conv_w * (bn_w * bn_var_rsqrt).reshape([-1] + [1] * (len(conv_w.shape) - 1))
    conv_b = (conv_b - bn_rm) * bn_var_rsqrt * bn_w + bn_b

    return torch.nn.Parameter(conv_w), torch.nn.Parameter(conv_b)
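
######################################################################
# As a quick sanity check (a sketch added for this writeup, not part of the
# original fuser; it assumes ``torch.testing.assert_close``, available in
# recent PyTorch), we can verify the helper on a standalone conv/batch norm
# pair with nontrivial running statistics:

_conv = nn.Conv2d(3, 8, 3).eval()
_bn = nn.BatchNorm2d(8).eval()
_bn.running_mean.uniform_(-1, 1)   # give the batch norm nontrivial statistics
_bn.running_var.uniform_(0.5, 1.5)
_fused = fuse_conv_bn_eval(_conv, _bn)
_x = torch.randn(1, 3, 16, 16)
torch.testing.assert_close(_fused(_x), _bn(_conv(_x)))
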
####################################
# FX Fusion Pass
# ----------------------------------
# Now that we have our computational graph as well as a method for fusing
# convolution and batch norm, all that remains is to iterate over the FX graph
# and apply the desired fusions.


def _parent_name(target : str) -> Tuple[str, str]:
    """
    Splits a ``qualname`` into parent path and last atom.
    For example, `foo.bar.baz` -> (`foo.bar`, `baz`)
    """
    *parent, name = target.rsplit('.', 1)
    return parent[0] if parent else '', name

def replace_node_module(node: fx.Node, modules: Dict[str, Any], new_module: torch.nn.Module):
    assert isinstance(node.target, str)
    parent_name, name = _parent_name(node.target)
    setattr(modules[parent_name], name, new_module)
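
######################################################################
# For example, when the pass below fuses the conv at ``nested.1`` with the
# batch norm that follows it, ``_parent_name('nested.1')`` yields
# ``('nested', '1')``, and ``setattr(modules['nested'], '1', new_module)``
# swaps the fused conv into the ``nn.Sequential`` in place. This is why we
# set the attribute on the parent module rather than on the root:
# ``setattr(model, 'nested.1', ...)`` would create a literal attribute named
# ``"nested.1"`` instead of replacing the submodule.
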
def fuse(model: torch.nn.Module) -> torch.nn.Module:
    model = copy.deepcopy(model)
    # The first step of most FX passes is to symbolically trace our model to
    # obtain a `GraphModule`. This is a representation of our original model
    # that is functionally identical to our original model, except that we now
    # also have a graph representation of our forward pass.
    fx_model: fx.GraphModule = fx.symbolic_trace(model)
    modules = dict(fx_model.named_modules())

    # The primary representations for working with FX are the `Graph` and the
    # `Node`. Each `GraphModule` has a `Graph` associated with it - this
    # `Graph` is also what generates `GraphModule.code`.
    # The `Graph` itself is represented as a list of `Node` objects. Thus, to
    # iterate through all of the operations in our graph, we iterate over each
    # `Node` in our `Graph`.
    for node in fx_model.graph.nodes:
        # The FX IR contains several types of nodes, which generally represent
        # call sites to modules, functions, or methods. The type of node is
        # determined by `Node.op`.
        if node.op != 'call_module': # If our current node isn't calling a Module then we can ignore it.
            continue
        # For call sites, `Node.target` represents the module/function/method
        # that's being called. Here, we check `Node.target` to see if it's a
        # batch norm module, and then check `Node.args[0].target` to see if the
        # input `Node` is a convolution.
        if type(modules[node.target]) is nn.BatchNorm2d and type(modules[node.args[0].target]) is nn.Conv2d:
            if len(node.args[0].users) > 1:  # Output of conv is used by other nodes
                continue
            conv = modules[node.args[0].target]
            bn = modules[node.target]
            fused_conv = fuse_conv_bn_eval(conv, bn)
            replace_node_module(node.args[0], modules, fused_conv)
            # As we've folded the batch norm into the conv, we need to replace all uses
            # of the batch norm with the conv.
            node.replace_all_uses_with(node.args[0])
            # Now that all uses of the batch norm have been replaced, we can
            # safely remove the batch norm.
            fx_model.graph.erase_node(node)
    fx_model.graph.lint()
    # After we've modified our graph, we need to recompile our graph in order
    # to keep the generated code in sync.
    fx_model.recompile()
    return fx_model


######################################################################
# .. note::
#     We make some simplifications here for demonstration purposes, such as only
#     matching 2D convolutions. View
#     https://github.com/pytorch/pytorch/blob/master/torch/fx/experimental/fuser.py
#     for a more usable pass.

######################################################################
# Testing out our Fusion Pass
# -----------------------------------------
# We can now run this fusion pass on our initial toy model and verify that our
# results are identical. In addition, we can print out the code for our fused
# model and verify that there are no more batch norms.


fused_model = fuse(model)
print(fused_model.code)
inp = torch.randn(5, 1, 1, 1)
torch.testing.assert_close(fused_model(inp), model(inp))


######################################################################
# Benchmarking our Fusion on ResNet18
# -----------------------------------
# We can test our fusion pass on a larger model like ResNet18 and see how much
# this pass improves inference performance.
import torchvision.models as models
import time

rn18 = models.resnet18()
rn18.eval()

inp = torch.randn(10, 3, 224, 224)
output = rn18(inp)

def benchmark(model, iters=20):
    # Warm up before timing.
    for _ in range(10):
        model(inp)
    begin = time.time()
    for _ in range(iters):
        model(inp)
    return time.time() - begin

fused_rn18 = fuse(rn18)
print("Unfused time: ", benchmark(rn18))
print("Fused time: ", benchmark(fused_rn18))
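
######################################################################
# (An aside not in the original tutorial: since we are benchmarking
# inference, running the measurements under ``torch.no_grad()`` avoids
# building autograd state and often gives more representative numbers.)

with torch.no_grad():
    print("Fused time (no_grad): ", benchmark(fused_rn18))
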
######################################################################
# As we previously saw, the output of our FX transformation is
# ("torchscriptable") PyTorch code, so we can easily ``jit.script`` the output
# to try to increase our performance even more. In this way, our FX model
# transformation composes with TorchScript with no issues.
jit_rn18 = torch.jit.script(fused_rn18)
print("jit time: ", benchmark(jit_rn18))


############
# Conclusion
# ----------
# As we can see, using FX we can easily write static graph transformations on
# PyTorch code.
#
# Since FX is still in beta, we would be happy to hear any
# feedback you have about using it. Please feel free to use the
# PyTorch Forums (https://discuss.pytorch.org/) and the issue tracker
# (https://github.com/pytorch/pytorch/issues) to provide any feedback
# you might have.