# -*- coding: utf-8 -*-
"""
(beta) Building a Convolution/Batch Norm fuser in FX
*******************************************************
**Author**: `Horace He <https://github.com/chillee>`_

In this tutorial, we are going to use FX, a toolkit for composable function
transformations of PyTorch, to do the following:

1) Find patterns of conv/batch norm in the data dependencies.
2) For the patterns found in 1), fold the batch norm statistics into the convolution weights.

Note that this optimization only works for models in inference mode (i.e. after
calling ``model.eval()``).

We will be building the fuser that exists here:
https://github.com/pytorch/pytorch/blob/orig/release/1.8/torch/fx/experimental/fuser.py

"""


######################################################################
# First, let's get some imports out of the way (we will be using all
# of these later in the code).

from typing import Any, Dict, Tuple
import copy
import torch.fx as fx
import torch
import torch.nn as nn

######################################################################
# For this tutorial, we are going to create a model consisting of convolutions
# and batch norms.
# Note that this model has some tricky components: some of
# the conv/batch norm patterns are hidden within ``nn.Sequential`` containers,
# and one of the ``BatchNorm2d`` layers is wrapped in another Module.

class WrappedBatchNorm(nn.Module):
    def __init__(self):
        super().__init__()
        self.mod = nn.BatchNorm2d(1)
    def forward(self, x):
        return self.mod(x)

class M(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 1, 1)
        self.bn1 = nn.BatchNorm2d(1)
        self.conv2 = nn.Conv2d(1, 1, 1)
        self.nested = nn.Sequential(
            nn.BatchNorm2d(1),
            nn.Conv2d(1, 1, 1),
        )
        self.wrapped = WrappedBatchNorm()

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.conv2(x)
        x = self.nested(x)
        x = self.wrapped(x)
        return x

model = M()

model.eval()

######################################################################
# Fusing Convolution with Batch Norm
# -----------------------------------------
# One of the primary challenges with trying to automatically fuse convolution
# and batch norm in PyTorch is that PyTorch does not provide an easy way of
# accessing the computational graph. FX resolves this problem by symbolically
# tracing the actual operations called, so that we can track the computations
# through the ``forward`` call, nested within Sequential modules, or wrapped in
# a user-defined module.

traced_model = torch.fx.symbolic_trace(model)
print(traced_model.graph)

######################################################################
# This gives us a graph representation of our model. Note that both the modules
# hidden within the ``Sequential`` and the wrapped Module have been inlined
# into the graph. This is the default level of abstraction, but it can be
# configured by the pass writer.
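######################################################################
# The level of abstraction mentioned above is controlled by the tracer. As a
# minimal sketch that is not part of the original fuser, we can subclass
# ``fx.Tracer`` and override its ``is_leaf_module`` method so that a chosen
# user-defined module stays a single ``call_module`` node instead of being
# traced through. ``LeafWrapper``, ``TinyNet``, ``KeepWrapperTracer`` and
# ``call_module_targets`` are hypothetical names used only for this illustration.

import torch
import torch.fx as fx
import torch.nn as nn


class LeafWrapper(nn.Module):
    # Hypothetical wrapper around a batch norm, analogous to WrappedBatchNorm.
    def __init__(self):
        super().__init__()
        self.mod = nn.BatchNorm2d(1)

    def forward(self, x):
        return self.mod(x)


class TinyNet(nn.Module):
    # Hypothetical model: a convolution followed by the wrapped batch norm.
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(1, 1, 1)
        self.wrapped = LeafWrapper()

    def forward(self, x):
        return self.wrapped(self.conv(x))


class KeepWrapperTracer(fx.Tracer):
    # Hypothetical tracer: treat LeafWrapper as opaque, otherwise use the
    # default rule (standard torch.nn modules are leaves).
    def is_leaf_module(self, m: nn.Module, module_qualified_name: str) -> bool:
        if isinstance(m, LeafWrapper):
            return True
        return super().is_leaf_module(m, module_qualified_name)


def call_module_targets(gm: fx.GraphModule):
    # Qualified names of all submodules called in the graph, in order.
    return [n.target for n in gm.graph.nodes if n.op == 'call_module']


net = TinyNet().eval()
default_gm = fx.symbolic_trace(net)
custom_gm = fx.GraphModule(net, KeepWrapperTracer().trace(net))

print(call_module_targets(default_gm))  # ['conv', 'wrapped.mod']
print(call_module_targets(custom_gm))   # ['conv', 'wrapped']

######################################################################
# A fuser that matches on ``nn.BatchNorm2d`` would no longer see the wrapped
# batch norm in the second graph, so this kind of choice directly affects
# which patterns a pass can find.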
######################################################################
# More information can be found in the FX overview:
# https://pytorch.org/docs/master/fx.html#module-torch.fx


####################################
# Fusing Convolution with Batch Norm
# ----------------------------------
# Unlike some other fusions, fusion of convolution with batch norm does not
# require any new operators. Instead, as batch norm during inference
# consists of a pointwise add and multiply, these operations can be "baked"
# into the preceding convolution's weights. Concretely, with the per-channel
# scale ``s = bn_w / sqrt(bn_rv + bn_eps)``, the fused weight is ``conv_w``
# with each output channel multiplied by ``s``, and the fused bias is
# ``(conv_b - bn_rm) * s + bn_b`` (using the argument names of
# ``fuse_conv_bn_weights``). This allows us to remove the batch
# norm entirely from our model! Read
# https://nenadmarkus.com/p/fusing-batchnorm-and-conv/ for further details. The
# code here is copied from
# https://github.com/pytorch/pytorch/blob/orig/release/1.8/torch/nn/utils/fusion.py
# for clarity.
def fuse_conv_bn_eval(conv, bn):
    """
    Given a conv Module `A` and a batch_norm module `B`, returns a conv
    module `C` such that C(x) == B(A(x)) in inference mode.
    """
    assert not (conv.training or bn.training), "Fusion only for eval!"
    fused_conv = copy.deepcopy(conv)

    fused_conv.weight, fused_conv.bias = \
        fuse_conv_bn_weights(fused_conv.weight, fused_conv.bias,
                             bn.running_mean, bn.running_var, bn.eps, bn.weight, bn.bias)

    return fused_conv

def fuse_conv_bn_weights(conv_w, conv_b, bn_rm, bn_rv, bn_eps, bn_w, bn_b):
    if conv_b is None:
        conv_b = torch.zeros_like(bn_rm)
    if bn_w is None:
        bn_w = torch.ones_like(bn_rm)
    if bn_b is None:
        bn_b = torch.zeros_like(bn_rm)
    bn_var_rsqrt = torch.rsqrt(bn_rv + bn_eps)

    conv_w = conv_w * (bn_w * bn_var_rsqrt).reshape([-1] + [1] * (len(conv_w.shape) - 1))
    conv_b = (conv_b - bn_rm) * bn_var_rsqrt * bn_w + bn_b

    return torch.nn.Parameter(conv_w), torch.nn.Parameter(conv_b)


####################################
# FX Fusion Pass
# ----------------------------------
# Now that we have our computational graph as well as a method for fusing
# convolution and batch norm, all that remains is to iterate over
# the FX graph and apply the desired fusions.


def _parent_name(target: str) -> Tuple[str, str]:
    """
    Splits a ``qualname`` into parent path and last atom.
    For example, `foo.bar.baz` -> (`foo.bar`, `baz`)
    """
    *parent, name = target.rsplit('.', 1)
    return parent[0] if parent else '', name

def replace_node_module(node: fx.Node, modules: Dict[str, Any], new_module: torch.nn.Module):
    # Swap the submodule called by `node` for `new_module` on its parent module
    # (an empty parent name refers to the root module).
    assert isinstance(node.target, str)
    parent_name, name = _parent_name(node.target)
    setattr(modules[parent_name], name, new_module)


def fuse(model: torch.nn.Module) -> torch.nn.Module:
    model = copy.deepcopy(model)
    # The first step of most FX passes is to symbolically trace our model to
    # obtain a `GraphModule`. This `GraphModule` is functionally identical to
    # our original model, except that we now also have a graph representation
    # of our forward pass.
    fx_model: fx.GraphModule = fx.symbolic_trace(model)
    modules = dict(fx_model.named_modules())

    # The primary representations for working with FX are the `Graph` and the
    # `Node`. Each `GraphModule` has a `Graph` associated with it; this
    # `Graph` is also what generates `GraphModule.code`.
    # The `Graph` itself is represented as a list of `Node` objects. Thus, to
    # iterate through all of the operations in our graph, we iterate over each
    # `Node` in our `Graph`.
    for node in fx_model.graph.nodes:
        # The FX IR contains several types of nodes, which generally represent
        # call sites to modules, functions, or methods. The type of node is
        # determined by `Node.op`.
        if node.op != 'call_module':  # If our current node isn't calling a Module then we can ignore it.
            continue
        # For call sites, `Node.target` represents the module/function/method
        # that's being called.
        # Here, we check `Node.target` to see if it's a batch norm module, and
        # then check `Node.args[0]` to see if the input `Node` is a call to a
        # convolution module. Checking that the input is a `call_module` node
        # first avoids a `KeyError` when the batch norm's input is, for
        # example, a graph input or a function call rather than a submodule.
        input_node = node.args[0]
        if (type(modules[node.target]) is nn.BatchNorm2d
                and isinstance(input_node, fx.Node)
                and input_node.op == 'call_module'
                and type(modules[input_node.target]) is nn.Conv2d):
            if len(input_node.users) > 1:  # Output of conv is used by other nodes
                continue
            conv = modules[input_node.target]
            bn = modules[node.target]
            fused_conv = fuse_conv_bn_eval(conv, bn)
            replace_node_module(input_node, modules, fused_conv)
            # As we've folded the batch norm into the conv, we need to replace all uses
            # of the batch norm with the conv.
            node.replace_all_uses_with(input_node)
            # Now that all uses of the batch norm have been replaced, we can
            # safely remove the batch norm.
            fx_model.graph.erase_node(node)
    fx_model.graph.lint()
    # After we've modified our graph, we need to recompile our graph in order
    # to keep the generated code in sync.
    fx_model.recompile()
    return fx_model


######################################################################
# .. note::
#     We make some simplifications here for demonstration purposes, such as only
#     matching 2D convolutions. View
#     https://github.com/pytorch/pytorch/blob/master/torch/fx/experimental/fuser.py
#     for a more usable pass.

######################################################################
# Testing out our Fusion Pass
# -----------------------------------------
# We can now run this fusion pass on our initial toy model and verify that our
# results are identical.
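######################################################################
# Every batch norm in the toy model still has its freshly initialized
# statistics (running mean 0, running variance 1, weight 1, bias 0), so in
# inference mode each one is almost an identity, and an end-to-end comparison
# on that model only weakly exercises the folding arithmetic. As a hedged
# extra check that is not part of the original tutorial, the sketch below
# gives a standalone ``Conv2d``/``BatchNorm2d`` pair random statistics and
# compares the fused convolution against running the two modules in sequence.
# It calls ``torch.nn.utils.fusion.fuse_conv_bn_eval``, the PyTorch utility
# that this tutorial's ``fuse_conv_bn_eval`` was copied from; the helper
# ``check_conv_bn_fusion`` is a hypothetical name used only here.

import torch
import torch.nn as nn
from torch.nn.utils import fusion


def check_conv_bn_fusion(fuse_fn, channels: int = 3, seed: int = 0) -> float:
    # Hypothetical helper: returns the max absolute difference between
    # fuse_fn(conv, bn)(x) and bn(conv(x)) for a randomized batch norm.
    torch.manual_seed(seed)
    conv = nn.Conv2d(channels, channels, kernel_size=3, padding=1)
    bn = nn.BatchNorm2d(channels)
    # Non-trivial statistics and affine parameters, so that a wrong folding
    # formula would show up as a large error.
    with torch.no_grad():
        bn.running_mean.uniform_(-1.0, 1.0)
        bn.running_var.uniform_(0.5, 2.0)
        bn.weight.uniform_(0.5, 1.5)
        bn.bias.uniform_(-0.5, 0.5)
    conv.eval()
    bn.eval()  # the folding is only valid in inference mode

    x = torch.randn(4, channels, 8, 8)
    fused_conv = fuse_fn(conv, bn)
    with torch.no_grad():
        expected = bn(conv(x))
        actual = fused_conv(x)
    return (expected - actual).abs().max().item()


max_err = check_conv_bn_fusion(fusion.fuse_conv_bn_eval)
print(f"max abs error with random statistics: {max_err:.2e}")

######################################################################
# Passing this tutorial's own ``fuse_conv_bn_eval`` to ``check_conv_bn_fusion``
# exercises the copied code in the same way; since it applies the same formula,
# it should likewise report only floating-point rounding error.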
######################################################################
# When we run the pass on the toy model, we can also print out the code for
# the fused model and verify that there are no more batch norms.


fused_model = fuse(model)
print(fused_model.code)
inp = torch.randn(5, 1, 1, 1)
torch.testing.assert_close(fused_model(inp), model(inp))


######################################################################
# Benchmarking our Fusion on ResNet18
# -----------------------------------
# We can test our fusion pass on a larger model like ResNet18 and see how much
# this pass improves inference performance.
import torchvision.models as models
import time

rn18 = models.resnet18()
rn18.eval()

inp = torch.randn(10, 3, 224, 224)
output = rn18(inp)

def benchmark(model, iters=20):
    with torch.no_grad():  # inference only, so skip autograd bookkeeping
        for _ in range(10):  # warm-up iterations
            model(inp)
        begin = time.perf_counter()
        for _ in range(iters):
            model(inp)
        elapsed = time.perf_counter() - begin
    return str(elapsed)

fused_rn18 = fuse(rn18)
print("Unfused time: ", benchmark(rn18))
print("Fused time: ", benchmark(fused_rn18))
######################################################################
# As we saw earlier, the output of our FX transformation is (TorchScript-compatible)
# PyTorch code, so we can easily ``jit.script`` the output to try to
# increase our performance even more. In this way, our FX model
# transformation composes with TorchScript with no issues.
jit_rn18 = torch.jit.script(fused_rn18)
print("jit time: ", benchmark(jit_rn18))


############
# Conclusion
# ----------
# As we can see, using FX we can easily write static graph transformations on
# PyTorch code.
#
# Since FX is still in beta, we would be happy to hear any
# feedback you have about using it. Please feel free to use the
# PyTorch Forums (https://discuss.pytorch.org/) and the issue tracker
# (https://github.com/pytorch/pytorch/issues) to provide any feedback
# you might have.