CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutSign UpSign In
pytorch

Real-time collaboration for Jupyter Notebooks, Linux Terminals, LaTeX, VS Code, R IDE, and more,
all in one place.

GitHub Repository: pytorch/tutorials
Path: blob/main/intermediate_source/fx_profiling_tutorial.py
Views: 712
1
# -*- coding: utf-8 -*-
2
"""
3
(beta) Building a Simple CPU Performance Profiler with FX
4
*********************************************************
5
**Author**: `James Reed <https://github.com/jamesr66a>`_
6
7
In this tutorial, we are going to use FX to do the following:
8
9
1) Capture PyTorch Python code in a way that we can inspect and gather
10
statistics about the structure and execution of the code
11
2) Build out a small class that will serve as a simple performance "profiler",
12
collecting runtime statistics about each part of the model from actual
13
runs.
14
15
"""
16
17
######################################################################
18
# For this tutorial, we are going to use the torchvision ResNet18 model
19
# for demonstration purposes.
20
21
import torch
22
import torch.fx
23
import torchvision.models as models
24
25
rn18 = models.resnet18()
26
rn18.eval()
27
28
######################################################################
29
# Now that we have our model, we want to inspect deeper into its
30
# performance. That is, for the following invocation, which parts
31
# of the model are taking the longest?
32
input = torch.randn(5, 3, 224, 224)
33
output = rn18(input)
34
35
######################################################################
36
# A common way of answering that question is to go through the program
37
# source, add code that collects timestamps at various points in the
38
# program, and compare the difference between those timestamps to see
39
# how long the regions between the timestamps take.
40
#
41
# That technique is certainly applicable to PyTorch code, however it
42
# would be nicer if we didn't have to copy over model code and edit it,
43
# especially code we haven't written (like this torchvision model).
44
# Instead, we are going to use FX to automate this "instrumentation"
45
# process without needing to modify any source.
46
47
######################################################################
48
# First, let's get some imports out of the way (we will be using all
49
# of these later in the code).
50
51
import statistics, tabulate, time
52
from typing import Any, Dict, List
53
from torch.fx import Interpreter
54
55
######################################################################
56
# .. note::
57
# ``tabulate`` is an external library that is not a dependency of PyTorch.
58
# We will be using it to more easily visualize performance data. Please
59
# make sure you've installed it from your favorite Python package source.
60
61
######################################################################
62
# Capturing the Model with Symbolic Tracing
63
# -----------------------------------------
64
# Next, we are going to use FX's symbolic tracing mechanism to capture
65
# the definition of our model in a data structure we can manipulate
66
# and examine.
67
68
traced_rn18 = torch.fx.symbolic_trace(rn18)
69
print(traced_rn18.graph)
70
71
######################################################################
72
# This gives us a Graph representation of the ResNet18 model. A Graph
73
# consists of a series of Nodes connected to each other. Each Node
74
# represents a call-site in the Python code (whether to a function,
75
# a module, or a method) and the edges (represented as ``args`` and ``kwargs``
76
# on each node) represent the values passed between these call-sites. More
77
# information about the Graph representation and the rest of FX's APIs ca
78
# be found at the FX documentation https://pytorch.org/docs/master/fx.html.
79
80
81
######################################################################
82
# Creating a Profiling Interpreter
83
# --------------------------------
84
# Next, we are going to create a class that inherits from ``torch.fx.Interpreter``.
85
# Though the ``GraphModule`` that ``symbolic_trace`` produces compiles Python code
86
# that is run when you call a ``GraphModule``, an alternative way to run a
87
# ``GraphModule`` is by executing each ``Node`` in the ``Graph`` one by one. That is
88
# the functionality that ``Interpreter`` provides: It interprets the graph node-
89
# by-node.
90
#
91
# By inheriting from ``Interpreter``, we can override various functionality and
92
# install the profiling behavior we want. The goal is to have an object to which
93
# we can pass a model, invoke the model 1 or more times, then get statistics about
94
# how long the model and each part of the model took during those runs.
95
#
96
# Let's define our ``ProfilingInterpreter`` class:
97
98
class ProfilingInterpreter(Interpreter):
99
def __init__(self, mod : torch.nn.Module):
100
# Rather than have the user symbolically trace their model,
101
# we're going to do it in the constructor. As a result, the
102
# user can pass in any ``Module`` without having to worry about
103
# symbolic tracing APIs
104
gm = torch.fx.symbolic_trace(mod)
105
super().__init__(gm)
106
107
# We are going to store away two things here:
108
#
109
# 1. A list of total runtimes for ``mod``. In other words, we are
110
# storing away the time ``mod(...)`` took each time this
111
# interpreter is called.
112
self.total_runtime_sec : List[float] = []
113
# 2. A map from ``Node`` to a list of times (in seconds) that
114
# node took to run. This can be seen as similar to (1) but
115
# for specific sub-parts of the model.
116
self.runtimes_sec : Dict[torch.fx.Node, List[float]] = {}
117
118
######################################################################
119
# Next, let's override our first method: ``run()``. ``Interpreter``'s ``run``
120
# method is the top-level entry point for execution of the model. We will
121
# want to intercept this so that we can record the total runtime of the
122
# model.
123
124
def run(self, *args) -> Any:
125
# Record the time we started running the model
126
t_start = time.time()
127
# Run the model by delegating back into Interpreter.run()
128
return_val = super().run(*args)
129
# Record the time we finished running the model
130
t_end = time.time()
131
# Store the total elapsed time this model execution took in the
132
# ``ProfilingInterpreter``
133
self.total_runtime_sec.append(t_end - t_start)
134
return return_val
135
136
######################################################################
137
# Now, let's override ``run_node``. ``Interpreter`` calls ``run_node`` each
138
# time it executes a single node. We will intercept this so that we
139
# can measure and record the time taken for each individual call in
140
# the model.
141
142
def run_node(self, n : torch.fx.Node) -> Any:
143
# Record the time we started running the op
144
t_start = time.time()
145
# Run the op by delegating back into Interpreter.run_node()
146
return_val = super().run_node(n)
147
# Record the time we finished running the op
148
t_end = time.time()
149
# If we don't have an entry for this node in our runtimes_sec
150
# data structure, add one with an empty list value.
151
self.runtimes_sec.setdefault(n, [])
152
# Record the total elapsed time for this single invocation
153
# in the runtimes_sec data structure
154
self.runtimes_sec[n].append(t_end - t_start)
155
return return_val
156
157
######################################################################
158
# Finally, we are going to define a method (one which doesn't override
159
# any ``Interpreter`` method) that provides us a nice, organized view of
160
# the data we have collected.
161
162
def summary(self, should_sort : bool = False) -> str:
163
# Build up a list of summary information for each node
164
node_summaries : List[List[Any]] = []
165
# Calculate the mean runtime for the whole network. Because the
166
# network may have been called multiple times during profiling,
167
# we need to summarize the runtimes. We choose to use the
168
# arithmetic mean for this.
169
mean_total_runtime = statistics.mean(self.total_runtime_sec)
170
171
# For each node, record summary statistics
172
for node, runtimes in self.runtimes_sec.items():
173
# Similarly, compute the mean runtime for ``node``
174
mean_runtime = statistics.mean(runtimes)
175
# For easier understanding, we also compute the percentage
176
# time each node took with respect to the whole network.
177
pct_total = mean_runtime / mean_total_runtime * 100
178
# Record the node's type, name of the node, mean runtime, and
179
# percent runtime.
180
node_summaries.append(
181
[node.op, str(node), mean_runtime, pct_total])
182
183
# One of the most important questions to answer when doing performance
184
# profiling is "Which op(s) took the longest?". We can make this easy
185
# to see by providing sorting functionality in our summary view
186
if should_sort:
187
node_summaries.sort(key=lambda s: s[2], reverse=True)
188
189
# Use the ``tabulate`` library to create a well-formatted table
190
# presenting our summary information
191
headers : List[str] = [
192
'Op type', 'Op', 'Average runtime (s)', 'Pct total runtime'
193
]
194
return tabulate.tabulate(node_summaries, headers=headers)
195
196
######################################################################
197
# .. note::
198
# We use Python's ``time.time`` function to pull wall clock
199
# timestamps and compare them. This is not the most accurate
200
# way to measure performance, and will only give us a first-
201
# order approximation. We use this simple technique only for the
202
# purpose of demonstration in this tutorial.
203
204
######################################################################
205
# Investigating the Performance of ResNet18
206
# -----------------------------------------
207
# We can now use ``ProfilingInterpreter`` to inspect the performance
208
# characteristics of our ResNet18 model;
209
210
interp = ProfilingInterpreter(rn18)
211
interp.run(input)
212
print(interp.summary(True))
213
214
######################################################################
215
# There are two things we should call out here:
216
#
217
# * ``MaxPool2d`` takes up the most time. This is a known issue:
218
# https://github.com/pytorch/pytorch/issues/51393
219
# * BatchNorm2d also takes up significant time. We can continue this
220
# line of thinking and optimize this in the Conv-BN Fusion with FX
221
# `tutorial <https://pytorch.org/tutorials/intermediate/fx_conv_bn_fuser.html>`_.
222
#
223
#
224
# Conclusion
225
# ----------
226
# As we can see, using FX we can easily capture PyTorch programs (even
227
# ones we don't have the source code for!) in a machine-interpretable
228
# format and use that for analysis, such as the performance analysis
229
# we've done here. FX opens up an exciting world of possibilities for
230
# working with PyTorch programs.
231
#
232
# Finally, since FX is still in beta, we would be happy to hear any
233
# feedback you have about using it. Please feel free to use the
234
# PyTorch Forums (https://discuss.pytorch.org/) and the issue tracker
235
# (https://github.com/pytorch/pytorch/issues) to provide any feedback
236
# you might have.
237
238