From f7e015dac8d5159cbef261f2967bf270ee51fd35 Mon Sep 17 00:00:00 2001 From: Horace He Date: Thu, 11 Feb 2021 02:37:45 -0800 Subject: [PATCH 1/6] Added fuser tutorial --- intermediate_source/fx_conv_bn_fuser.py | 229 ++++++++++++++++++++++++ 1 file changed, 229 insertions(+) create mode 100644 intermediate_source/fx_conv_bn_fuser.py diff --git a/intermediate_source/fx_conv_bn_fuser.py b/intermediate_source/fx_conv_bn_fuser.py new file mode 100644 index 00000000000..33b38b83b5f --- /dev/null +++ b/intermediate_source/fx_conv_bn_fuser.py @@ -0,0 +1,229 @@ +# -*- coding: utf-8 -*- +""" +(beta) Building a Conv/BN fuser in FX +******************************************************* +**Author**: `Horace He `_ + +In this tutorial, we are going to use FX to do the following: + +1) Find patterns of conv/batch norm. +2) Fold the batch norm statistics into the convolution weights. + +Despite this being a fairly trivial graph rewrite, this has been surprisingly +difficult to do in PyTorch for quite some time. + +We will be building the fuser that exists here: +https://github.com/pytorch/pytorch/blob/master/torch/fx/experimental/fuser.py + +""" + + +###################################################################### +# First, let's get some imports out of the way (we will be using all +# of these later in the code). + +from typing import Type, Dict, Any, Tuple, Iterable +import copy +import torch.fx as fx +import torch +import torch.nn as nn + +###################################################################### +# For this tutorial, we are going to create a model consisting of convolutions +# and batch norms. Note that this model has some tricky components - some of +# the modules are hidden within sequentials and one of the modules is wrapped +# inside of another PyTorch module. + +class Wrapper(nn.Module): + def __init__(self): + super().__init__() + self.mod = nn.BatchNorm2d(1) + def forward(self, x): + return self.mod(x) + +class M(nn.Module): + def __init__(self): + super().__init__() + self.conv1 = nn.Conv2d(1, 1, 1) + self.bn1 = nn.BatchNorm2d(1) + self.conv2 = nn.Conv2d(1, 1, 1) + self.nested = nn.Sequential( + nn.BatchNorm2d(1), + nn.Conv2d(1, 1, 1), + ) + self.wrapped = Wrapper() + + def forward(self, x): + x = self.conv1(x) + x = self.bn1(x) + x = self.conv2(x) + x = self.nested(x) + x = self.wrapped(x) + return x + +model = M() + +model.eval() + +###################################################################### +# Fusing Convolution with Batch Norm +# ----------------------------------------- +# One of the primary challenges with trying to automatically fuse convolution +# and batch norm in PyTorch is that PyTorch does not provide an easy way of +# accessing the computational graph. FX resolves that problem. + +traced_model = torch.fx.symbolic_trace(model) +print(traced_model.graph) + +###################################################################### +# This gives us a graph representation of our model. Note that both the modules +# hidden within the sequential as well as the wrapped modue have been inlined +# into the graph. More information can be found at the FX documentation +# https://pytorch.org/docs/master/fx.html. + + +#################################### +# Fusing Convolution with Batch Norm +# ---------------------------------- +# Unlike some other fusions, fusion of convolution with batch norm does not +# require any additional kernels. 
Instead, as batch norm during inference
# consists of a pointwise add and multiply, these operations can be "baked"
# into the preceding convolution's weights. Read
# https://nenadmarkus.com/p/fusing-batchnorm-and-conv/ for further details. The
# code here is copied from
# https://github.com/pytorch/pytorch/blob/master/torch/nn/utils/fusion.py for
# clarity purposes.
def fuse_conv_bn_eval(conv, bn):
    assert(not (conv.training or bn.training)), "Fusion only for eval!"
    fused_conv = copy.deepcopy(conv)

    fused_conv.weight, fused_conv.bias = \
        fuse_conv_bn_weights(fused_conv.weight, fused_conv.bias,
                             bn.running_mean, bn.running_var, bn.eps, bn.weight, bn.bias)

    return fused_conv

def fuse_conv_bn_weights(conv_w, conv_b, bn_rm, bn_rv, bn_eps, bn_w, bn_b):
    if conv_b is None:
        conv_b = torch.zeros_like(bn_rm)
    if bn_w is None:
        bn_w = torch.ones_like(bn_rm)
    if bn_b is None:
        bn_b = torch.zeros_like(bn_rm)
    bn_var_rsqrt = torch.rsqrt(bn_rv + bn_eps)

    conv_w = conv_w * (bn_w * bn_var_rsqrt).reshape([-1] + [1] * (len(conv_w.shape) - 1))
    conv_b = (conv_b - bn_rm) * bn_var_rsqrt * bn_w + bn_b

    return torch.nn.Parameter(conv_w), torch.nn.Parameter(conv_b)


####################################
# FX Fusion Pass
# ----------------------------------
# Now that we have our computational graph as well as a method for fusing
# convolution and batch norm, all that remains is to iterate over the FX graph
# and apply the desired fusions.


def _parent_name(target : str) -> Tuple[str, str]:
    """
    Splits a qualname into parent path and last atom.
    For example, `foo.bar.baz` -> (`foo.bar`, `baz`)
    """
    *parent, name = target.rsplit('.', 1)
    return parent[0] if parent else '', name

def replace_node_module(node: fx.Node, modules: Dict[str, Any], new_module: torch.nn.Module):
    assert(isinstance(node.target, str))
    parent_name, name = _parent_name(node.target)
    setattr(modules[parent_name], name, new_module)


def fuse(model: torch.nn.Module) -> torch.nn.Module:
    model = copy.deepcopy(model)
    fx_model = fx.symbolic_trace(model) # We symbolically trace our model here
    modules = dict(fx_model.named_modules())

    for node in fx_model.graph.nodes:
        if node.op != 'call_module': # If our current node isn't calling a module then we can ignore it.
            continue
        if type(modules[node.target]) is nn.BatchNorm2d and type(modules[node.args[0].target]) is nn.Conv2d:
            if len(node.args[0].users) > 1:  # Output of conv is used by other nodes
                continue
            conv = modules[node.args[0].target]
            bn = modules[node.target]
            fused_conv = fuse_conv_bn_eval(conv, bn)
            replace_node_module(node.args[0], modules, fused_conv)
            # As we've folded the batch norm into the conv, we need to replace all uses
            # of the batch norm with the conv.
            node.replace_all_uses_with(node.args[0])
            # Now that all uses of the batch norm have been replaced, we can
            # safely remove the batch norm.
            fx_model.graph.erase_node(node)
    fx_model.graph.lint()
    fx_model.recompile()
    return fx_model


######################################################################
# .. note::
#       We make some simplifications here for demonstration purposes, such as only
#       matching 2D convolutions. View
#       https://github.com/pytorch/pytorch/blob/master/torch/fx/experimental/fuser.py
#       for a more usable pass.
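
######################################################################
# Before moving on, it is worth sanity-checking the folding algebra above on a
# pair of standalone modules. In eval mode, batch norm computes
# `(y - running_mean) / sqrt(running_var + eps) * weight + bias`, which is an
# affine function of `y` and can therefore be absorbed into the convolution
# that produced `y`. A minimal sketch of such a check (the channel sizes and
# the perturbation of the running statistics are arbitrary choices for
# illustration):

_conv = nn.Conv2d(3, 8, 3).eval()
_bn = nn.BatchNorm2d(8).eval()
# Perturb the running statistics so the check isn't trivially scaling by 1.
_bn.running_mean += 0.4
_bn.running_var *= 2.0
_x = torch.randn(2, 3, 16, 16)
torch.testing.assert_allclose(fuse_conv_bn_eval(_conv, _bn)(_x), _bn(_conv(_x)))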
+ +###################################################################### +# Testing out our Fusion Pass +# ----------------------------------------- +# We can now run this fusion pass on our initial toy model and verify that our +# results are identical. + + +fused_model = fuse(model) +inp = torch.randn(5, 1, 1, 1) +assert(abs(fused_model(inp).sum() - model(inp).sum()) < 1e-5) + + +# Benchmarking our Fusion on ResNet18 +# ---------- +# We can test our fusion pass on a larger model like ResNet18 and see how much +# this pass improves inference performance. +import torchvision.models as models +import time + +rn18 = models.resnet18() +rn18.eval() + +inp = torch.randn(10, 3, 224, 224) +output = rn18(inp) + +def benchmark(model, iters=20): + for _ in range(10): + model(inp) + begin = time.time() + for _ in range(iters): + model(inp) + return str(time.time()-begin) + +fused_rn18 = fuse(rn18) +print("Unfused time: ", benchmark(rn18)) +print("Fused time: ", benchmark(fused_rn18)) +# As FX is a source to source transformation, our transformation can still +# compose with Torchscript with no issues. So we can still script our model to +# try and increase our performance more. +jit_rn18 = torch.jit.script(fused_rn18) +print("jit time: ", benchmark(jit_rn18)) + + + +# Conclusion +# ---------- +# As we can see, using FX we can easily write static graph transformations on +# PyTorch code. +# +# Since FX is still in beta, we would be happy to hear any +# feedback you have about using it. Please feel free to use the +# PyTorch Forums (https://discuss.pytorch.org/) and the issue tracker +# (https://github.com/pytorch/pytorch/issues) to provide any feedback +# you might have. \ No newline at end of file From e5319e3231a188c865d173ea0c8e7ed460629987 Mon Sep 17 00:00:00 2001 From: Horace He Date: Thu, 11 Feb 2021 10:01:00 -0800 Subject: [PATCH 2/6] updated index.rst --- index.rst | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/index.rst b/index.rst index e66a7e3e615..9ab94c07e12 100644 --- a/index.rst +++ b/index.rst @@ -215,6 +215,15 @@ Welcome to PyTorch Tutorials :link: advanced/super_resolution_with_onnxruntime.html :tags: Production +.. Code Transformations with FX +.. customcarditem:: + :header: Building a Convolution/Batch Norm fuser in FX + :card_description: Build a simple FX interpreter to record the runtime of op, module, and function calls and report statistics + :image: _static/img/thumbnails/cropped/Deploying-PyTorch-in-Python-via-a-REST-API-with-Flask.png + :link: intermediate/fx_conv_bn_fuser.html + :tags: FX + + .. Frontend APIs .. 
customcarditem:: @@ -575,4 +584,4 @@ Additional Resources beginner/deeplabv3_on_ios.html beginner/deeplabv3_on_android.html - + From a49a8dff5a9bbb04a16b4a5a77f7ed7a5bccf8fe Mon Sep 17 00:00:00 2001 From: Horace He Date: Thu, 11 Feb 2021 10:40:49 -0800 Subject: [PATCH 3/6] fixed conclusion --- intermediate_source/fx_conv_bn_fuser.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/intermediate_source/fx_conv_bn_fuser.py b/intermediate_source/fx_conv_bn_fuser.py index 33b38b83b5f..e2539f02cdc 100644 --- a/intermediate_source/fx_conv_bn_fuser.py +++ b/intermediate_source/fx_conv_bn_fuser.py @@ -216,7 +216,7 @@ def benchmark(model, iters=20): print("jit time: ", benchmark(jit_rn18)) - +############ # Conclusion # ---------- # As we can see, using FX we can easily write static graph transformations on From 439139cd7cb9c0ca5313c6c4d63d2b014a198498 Mon Sep 17 00:00:00 2001 From: Horace He Date: Sat, 13 Feb 2021 18:38:00 -0800 Subject: [PATCH 4/6] responded to some comments --- intermediate_source/fx_conv_bn_fuser.py | 22 ++++++++++++---------- 1 file changed, 12 insertions(+), 10 deletions(-) diff --git a/intermediate_source/fx_conv_bn_fuser.py b/intermediate_source/fx_conv_bn_fuser.py index e2539f02cdc..cd8fd8680ef 100644 --- a/intermediate_source/fx_conv_bn_fuser.py +++ b/intermediate_source/fx_conv_bn_fuser.py @@ -6,11 +6,10 @@ In this tutorial, we are going to use FX to do the following: -1) Find patterns of conv/batch norm. +1) Find patterns of conv/batch norm in the data dependencies. 2) Fold the batch norm statistics into the convolution weights. -Despite this being a fairly trivial graph rewrite, this has been surprisingly -difficult to do in PyTorch for quite some time. +Note that this optimization only works for models in inference mode. We will be building the fuser that exists here: https://github.com/pytorch/pytorch/blob/master/torch/fx/experimental/fuser.py @@ -31,10 +30,10 @@ ###################################################################### # For this tutorial, we are going to create a model consisting of convolutions # and batch norms. Note that this model has some tricky components - some of -# the modules are hidden within sequentials and one of the modules is wrapped -# inside of another PyTorch module. +# the conv/batch norm patterns are hidden within Sequentials and one of the +# BatchNorms is wrapped in another module. -class Wrapper(nn.Module): +class WrappedBatchnorm(nn.Module): def __init__(self): super().__init__() self.mod = nn.BatchNorm2d(1) @@ -51,7 +50,7 @@ def __init__(self): nn.BatchNorm2d(1), nn.Conv2d(1, 1, 1), ) - self.wrapped = Wrapper() + self.wrapped = WrappedBatchnorm() def forward(self, x): x = self.conv1(x) @@ -77,7 +76,7 @@ def forward(self, x): ###################################################################### # This gives us a graph representation of our model. Note that both the modules -# hidden within the sequential as well as the wrapped modue have been inlined +# hidden within the sequential as well as the wrapped module have been inlined # into the graph. More information can be found at the FX documentation # https://pytorch.org/docs/master/fx.html. @@ -86,7 +85,7 @@ def forward(self, x): # Fusing Convolution with Batch Norm # ---------------------------------- # Unlike some other fusions, fusion of convolution with batch norm does not -# require any additional kernels. Instead, as batch norm during inference +# require any new operators. 
Instead, as batch norm during inference
# consists of a pointwise add and multiply, these operations can be "baked"
# into the preceding convolution's weights. Read
@@ -162,6 +161,8 @@ def fuse(model: torch.nn.Module) -> torch.nn.Module:
            # safely remove the batch norm.
            fx_model.graph.erase_node(node)
    fx_model.graph.lint()
+    # After we've modified our graph, we need to recompile our graph in order
+    # to keep the generated code in sync.
    fx_model.recompile()
    return fx_model
@@ -182,9 +183,10 @@ def fuse(model: torch.nn.Module) -> torch.nn.Module:

fused_model = fuse(model)
inp = torch.randn(5, 1, 1, 1)
-assert(abs(fused_model(inp).sum() - model(inp).sum()) < 1e-5)
+torch.testing.assert_allclose(fused_model(inp), model(inp))


+######################################################################
# Benchmarking our Fusion on ResNet18
# ----------
# We can test our fusion pass on a larger model like ResNet18 and see how much

From a9796b4d05a9dc89bc551e7ea499ef02fd0a46cf Mon Sep 17 00:00:00 2001
From: Horace He
Date: Tue, 16 Feb 2021 02:38:45 -0800
Subject: [PATCH 5/6] responded to comments

---
 index.rst                               |  2 +-
 intermediate_source/fx_conv_bn_fuser.py | 68 ++++++++++++++++++-------
 2 files changed, 50 insertions(+), 20 deletions(-)

diff --git a/index.rst b/index.rst
index 9ab94c07e12..c3946660454 100644
--- a/index.rst
+++ b/index.rst
@@ -218,7 +218,7 @@ Welcome to PyTorch Tutorials
 .. Code Transformations with FX
 .. customcarditem::
    :header: Building a Convolution/Batch Norm fuser in FX
-   :card_description: Build a simple FX interpreter to record the runtime of op, module, and function calls and report statistics
+   :card_description: Build a simple FX pass that fuses batch norm into convolution to improve performance during inference.
    :image: _static/img/thumbnails/cropped/Deploying-PyTorch-in-Python-via-a-REST-API-with-Flask.png
    :link: intermediate/fx_conv_bn_fuser.html
    :tags: FX

diff --git a/intermediate_source/fx_conv_bn_fuser.py b/intermediate_source/fx_conv_bn_fuser.py
index cd8fd8680ef..44472e4c36b 100644
--- a/intermediate_source/fx_conv_bn_fuser.py
+++ b/intermediate_source/fx_conv_bn_fuser.py
@@ -1,18 +1,18 @@
# -*- coding: utf-8 -*-
"""
-(beta) Building a Conv/BN fuser in FX
+(beta) Building a Convolution/Batch Norm fuser in FX
*******************************************************
**Author**: `Horace He `_

-In this tutorial, we are going to use FX to do the following:
+In this tutorial, we are going to use FX, a Py to do the following:

1) Find patterns of conv/batch norm in the data dependencies.
-2) Fold the batch norm statistics into the convolution weights.
+2) For the patterns found in 1), fold the batch norm statistics into the convolution weights.

-Note that this optimization only works for models in inference mode.
+Note that this optimization only works for models in inference mode (i.e. `model.eval()`).

We will be building the fuser that exists here:
-https://github.com/pytorch/pytorch/blob/master/torch/fx/experimental/fuser.py
+https://github.com/pytorch/pytorch/blob/orig/release/1.8/torch/fx/experimental/fuser.py

"""
@@ -31,9 +31,9 @@
# For this tutorial, we are going to create a model consisting of convolutions
# and batch norms. Note that this model has some tricky components - some of
# the conv/batch norm patterns are hidden within Sequentials and one of the
-# BatchNorms is wrapped in another module.
+# BatchNorms is wrapped in another Module.
-class WrappedBatchnorm(nn.Module):
+class WrappedBatchNorm(nn.Module):
     def __init__(self):
         super().__init__()
         self.mod = nn.BatchNorm2d(1)
@@ -51,7 +50,7 @@ def __init__(self):
             nn.BatchNorm2d(1),
             nn.Conv2d(1, 1, 1),
         )
-        self.wrapped = WrappedBatchnorm()
+        self.wrapped = WrappedBatchNorm()

     def forward(self, x):
         x = self.conv1(x)
@@ -69,16 +69,20 @@ def forward(self, x):
# -----------------------------------------
# One of the primary challenges with trying to automatically fuse convolution
# and batch norm in PyTorch is that PyTorch does not provide an easy way of
-# accessing the computational graph. FX resolves that problem.
+# accessing the computational graph. FX resolves this problem by symbolically
+# tracing the actual operations called, so that we can track the computations
+# through the `forward` call, nested within Sequential modules, or wrapped in
+# a user-defined module.

traced_model = torch.fx.symbolic_trace(model)
print(traced_model.graph)

######################################################################
# This gives us a graph representation of our model. Note that both the modules
-# hidden within the sequential as well as the wrapped module have been inlined
-# into the graph. More information can be found at the FX documentation
-# https://pytorch.org/docs/master/fx.html.
+# hidden within the sequential as well as the wrapped Module have been inlined
+# into the graph. This is the default level of abstraction, but it can be
+# configured by the pass writer. More information can be found at the FX
+# overview https://pytorch.org/docs/master/fx.html#module-torch.fx


####################################
@@ -87,12 +91,17 @@ def forward(self, x):
# Unlike some other fusions, fusion of convolution with batch norm does not
# require any new operators. Instead, as batch norm during inference
# consists of a pointwise add and multiply, these operations can be "baked"
-# into the preceding convolution's weights. Read
+# into the preceding convolution's weights. This allows us to remove the batch
+# norm entirely from our model! Read
# https://nenadmarkus.com/p/fusing-batchnorm-and-conv/ for further details. The
# code here is copied from
-# https://github.com/pytorch/pytorch/blob/master/torch/nn/utils/fusion.py for
+# https://github.com/pytorch/pytorch/blob/orig/release/1.8/torch/nn/utils/fusion.py for
# clarity purposes.
def fuse_conv_bn_eval(conv, bn):
+    """
+    Given a conv Module `A` and a batch_norm module `B`, returns a conv
+    module `C` such that C(x) == B(A(x)) in inference mode.
+    """
     assert(not (conv.training or bn.training)), "Fusion only for eval!"
     fused_conv = copy.deepcopy(conv)

@@ -141,12 +150,29 @@ def replace_node_module(node: fx.Node, modules: Dict[str, Any], new_module: torc

 def fuse(model: torch.nn.Module) -> torch.nn.Module:
     model = copy.deepcopy(model)
-    fx_model = fx.symbolic_trace(model) # We symbolically trace our model here
+    # The first step of most FX passes is to symbolically trace our model to
+    # obtain a `GraphModule`. This is a representation of our original model
+    # that is functionally identical to our original model, except that we now
+    # also have a graph representation of our forward pass.
+    fx_model: fx.GraphModule = fx.symbolic_trace(model)
     modules = dict(fx_model.named_modules())

+    # The primary representations for working with FX are the `Graph` and the
+    # `Node`. Each `GraphModule` has a `Graph` associated with it - this
+    # `Graph` is also what generates `GraphModule.code`.
+    # The `Graph` itself is represented as a list of `Node` objects. Thus, to
+    # iterate through all of the operations in our graph, we iterate over each
+    # `Node` in our `Graph`.
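+    # As an aside, one could visualize the graph at this point with
+    # `fx_model.graph.print_tabular()` (this requires the `tabulate` package),
+    # which prints each `Node`'s opcode, name, target, and args - handy when
+    # writing or debugging a pass.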
for node in fx_model.graph.nodes: - if node.op != 'call_module': # If our current node isn't calling a module then we can ignore it. + # The FX IR contains several types of nodes, which generally represent + # call sites to modules, functions, or methods. The type of node is + # determined by `Node.op`. + if node.op != 'call_module': # If our current node isn't calling a Module then we can ignore it. continue + # For call sites, `Node.target` represents the module/function/method + # that's being called. Here, we check `Node.target` to see if it's a + # batch norm module, and then check `Node.args[0].target` to see if the + # input `Node` is a convolution. if type(modules[node.target]) is nn.BatchNorm2d and type(modules[node.args[0].target]) is nn.Conv2d: if len(node.args[0].users) > 1: # Output of conv is used by other nodes continue @@ -178,10 +204,12 @@ def fuse(model: torch.nn.Module) -> torch.nn.Module: # Testing out our Fusion Pass # ----------------------------------------- # We can now run this fusion pass on our initial toy model and verify that our -# results are identical. +# results are identical. In addition, we can print out the code for our fused +# model and verify that there are no more batch norms. fused_model = fuse(model) +print(fused_model.code) inp = torch.randn(5, 1, 1, 1) torch.testing.assert_allclose(fused_model(inp), model(inp)) @@ -211,9 +239,11 @@ def benchmark(model, iters=20): fused_rn18 = fuse(rn18) print("Unfused time: ", benchmark(rn18)) print("Fused time: ", benchmark(fused_rn18)) -# As FX is a source to source transformation, our transformation can still -# compose with Torchscript with no issues. So we can still script our model to -# try and increase our performance more. +###################################################################### +# As we previously saw, the output of FX is (Torchscriptable) PyTorch code, we +# can easily `jit.script` the output to try and increase our performance even +# more. In this way, our FX model transformation composes with Torchscript with +# no issues. jit_rn18 = torch.jit.script(fused_rn18) print("jit time: ", benchmark(jit_rn18)) From d5cbcdcaa94720fa9d3ca8168e40fe972c6c0488 Mon Sep 17 00:00:00 2001 From: Horace He Date: Tue, 16 Feb 2021 17:30:26 -0800 Subject: [PATCH 6/6] respond --- intermediate_source/fx_conv_bn_fuser.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/intermediate_source/fx_conv_bn_fuser.py b/intermediate_source/fx_conv_bn_fuser.py index 44472e4c36b..93b89c08fec 100644 --- a/intermediate_source/fx_conv_bn_fuser.py +++ b/intermediate_source/fx_conv_bn_fuser.py @@ -4,7 +4,8 @@ ******************************************************* **Author**: `Horace He `_ -In this tutorial, we are going to use FX, a Py to do the following: +In this tutorial, we are going to use FX, a toolkit for composable function +transformations of PyTorch, to do the following: 1) Find patterns of conv/batch norm in the data dependencies. 2) For the patterns found in 1), fold the batch norm statistics into the convolution weights. @@ -240,10 +241,10 @@ def benchmark(model, iters=20): print("Unfused time: ", benchmark(rn18)) print("Fused time: ", benchmark(fused_rn18)) ###################################################################### -# As we previously saw, the output of FX is (Torchscriptable) PyTorch code, we -# can easily `jit.script` the output to try and increase our performance even -# more. In this way, our FX model transformation composes with Torchscript with -# no issues. 
As we previously saw, the output of our FX transformation is
+# (Torchscriptable) PyTorch code, so we can easily `jit.script` the output to
+# try and increase our performance even more. In this way, our FX model
+# transformation composes with TorchScript with no issues.
jit_rn18 = torch.jit.script(fused_rn18)
print("jit time: ", benchmark(jit_rn18))
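
######################################################################
# Putting the pass to work end to end: a minimal sketch of applying `fuse` to
# a fresh model, assuming the definitions from the tutorial above are in scope
# (the layer sizes here are arbitrary choices for illustration).

m = nn.Sequential(nn.Conv2d(3, 4, 3), nn.BatchNorm2d(4), nn.ReLU()).eval()
fused = fuse(m)
print(fused.code)  # the generated forward should no longer call a batch norm
x = torch.randn(1, 3, 8, 8)
torch.testing.assert_allclose(fused(x), m(x))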