# -*- coding: utf-8 -*-
"""
(beta) Building a Simple CPU Performance Profiler with FX
*********************************************************
**Author**: `James Reed <https://github.com/jamesr66a>`_

In this tutorial, we are going to use FX to do the following:

1) Capture PyTorch Python code in a way that we can inspect and gather
   statistics about the structure and execution of the code
2) Build out a small class that will serve as a simple performance "profiler",
   collecting runtime statistics about each part of the model from actual
   runs.

"""

######################################################################
# For this tutorial, we are going to use the torchvision ResNet18 model
# for demonstration purposes.

import torch
import torch.fx
import torchvision.models as models

rn18 = models.resnet18()
rn18.eval()

######################################################################
# Now that we have our model, we want to inspect its performance more
# deeply. That is, for the following invocation, which parts of the
# model are taking the longest?
input = torch.randn(5, 3, 224, 224)
output = rn18(input)

######################################################################
# A common way of answering that question is to go through the program
# source, add code that collects timestamps at various points in the
# program, and compare the difference between those timestamps to see
# how long the regions between the timestamps take.
#
# That technique is certainly applicable to PyTorch code; however, it
# would be nicer if we didn't have to copy over model code and edit it,
# especially code we haven't written (like this torchvision model).
# Instead, we are going to use FX to automate this "instrumentation"
# process without needing to modify any source.

######################################################################
# First, let's get some imports out of the way (we will be using all
# of these later in the code).

import statistics, tabulate, time
from typing import Any, Dict, List
from torch.fx import Interpreter

######################################################################
# .. note::
#     ``tabulate`` is an external library that is not a dependency of PyTorch.
#     We will be using it to more easily visualize performance data. Please
#     make sure you've installed it from your favorite Python package source.

######################################################################
# Capturing the Model with Symbolic Tracing
# -----------------------------------------
# Next, we are going to use FX's symbolic tracing mechanism to capture
# the definition of our model in a data structure we can manipulate
# and examine.

traced_rn18 = torch.fx.symbolic_trace(rn18)
print(traced_rn18.graph)

######################################################################
# This gives us a Graph representation of the ResNet18 model. A Graph
# consists of a series of Nodes connected to each other. Each Node
# represents a call-site in the Python code (whether to a function,
# a module, or a method) and the edges (represented as ``args`` and ``kwargs``
# on each node) represent the values passed between these call-sites. More
# information about the Graph representation and the rest of FX's APIs can
# be found at the FX documentation https://pytorch.org/docs/master/fx.html.
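
######################################################################
# To get a feel for this data structure, we can iterate over the Nodes in
# the traced Graph and print the first few nodes' opcode, target, and
# arguments. (This little loop is just an illustrative aside; the ``op``,
# ``target``, and ``args`` attributes come straight from ``torch.fx.Node``.)

for node in list(traced_rn18.graph.nodes)[:5]:
    # Each Node records what kind of call it is (placeholder, call_module,
    # call_function, ...), what it calls, and the values flowing into it.
    print(f"{node.op:<15} | target: {node.target} | args: {node.args}")
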
######################################################################
# Creating a Profiling Interpreter
# --------------------------------
# Next, we are going to create a class that inherits from ``torch.fx.Interpreter``.
# Though the ``GraphModule`` that ``symbolic_trace`` produces compiles Python code
# that is run when you call a ``GraphModule``, an alternative way to run a
# ``GraphModule`` is by executing each ``Node`` in the ``Graph`` one by one. That is
# the functionality that ``Interpreter`` provides: It interprets the graph
# node-by-node.
#
# By inheriting from ``Interpreter``, we can override various functionality and
# install the profiling behavior we want. The goal is to have an object to which
# we can pass a model, invoke the model one or more times, then get statistics
# about how long the model and each part of the model took during those runs.
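
######################################################################
# Before adding any profiling, it can help to see the base ``Interpreter``
# in action. The quick check below (an optional aside, not needed for the
# rest of the tutorial) runs the traced module node-by-node with a plain
# ``Interpreter`` and confirms that it computes the same result as calling
# the ``GraphModule`` directly.

# Should print True: ResNet18 in eval mode is deterministic, and the
# interpreter executes exactly the ops recorded in the traced graph.
base_interp_output = Interpreter(traced_rn18).run(input)
print(torch.allclose(base_interp_output, traced_rn18(input)))
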
######################################################################
# Let's define our ``ProfilingInterpreter`` class:

class ProfilingInterpreter(Interpreter):
    def __init__(self, mod : torch.nn.Module):
        # Rather than have the user symbolically trace their model,
        # we're going to do it in the constructor. As a result, the
        # user can pass in any ``Module`` without having to worry about
        # symbolic tracing APIs
        gm = torch.fx.symbolic_trace(mod)
        super().__init__(gm)

        # We are going to store away two things here:
        #
        # 1. A list of total runtimes for ``mod``. In other words, we are
        #    storing away the time ``mod(...)`` took each time this
        #    interpreter is called.
        self.total_runtime_sec : List[float] = []
        # 2. A map from ``Node`` to a list of times (in seconds) that
        #    node took to run. This can be seen as similar to (1) but
        #    for specific sub-parts of the model.
        self.runtimes_sec : Dict[torch.fx.Node, List[float]] = {}

    ######################################################################
    # Next, let's override our first method: ``run()``. ``Interpreter``'s ``run``
    # method is the top-level entry point for execution of the model. We will
    # want to intercept this so that we can record the total runtime of the
    # model.

    def run(self, *args) -> Any:
        # Record the time we started running the model
        t_start = time.time()
        # Run the model by delegating back into Interpreter.run()
        return_val = super().run(*args)
        # Record the time we finished running the model
        t_end = time.time()
        # Store the total elapsed time this model execution took in the
        # ``ProfilingInterpreter``
        self.total_runtime_sec.append(t_end - t_start)
        return return_val

    ######################################################################
    # Now, let's override ``run_node``. ``Interpreter`` calls ``run_node`` each
    # time it executes a single node. We will intercept this so that we
    # can measure and record the time taken for each individual call in
    # the model.

    def run_node(self, n : torch.fx.Node) -> Any:
        # Record the time we started running the op
        t_start = time.time()
        # Run the op by delegating back into Interpreter.run_node()
        return_val = super().run_node(n)
        # Record the time we finished running the op
        t_end = time.time()
        # If we don't have an entry for this node in our runtimes_sec
        # data structure, add one with an empty list value.
        self.runtimes_sec.setdefault(n, [])
        # Record the total elapsed time for this single invocation
        # in the runtimes_sec data structure
        self.runtimes_sec[n].append(t_end - t_start)
        return return_val

    ######################################################################
    # Finally, we are going to define a method (one which doesn't override
    # any ``Interpreter`` method) that provides us with a nice, organized view
    # of the data we have collected.

    def summary(self, should_sort : bool = False) -> str:
        # Build up a list of summary information for each node
        node_summaries : List[List[Any]] = []
        # Calculate the mean runtime for the whole network. Because the
        # network may have been called multiple times during profiling,
        # we need to summarize the runtimes. We choose to use the
        # arithmetic mean for this.
        mean_total_runtime = statistics.mean(self.total_runtime_sec)

        # For each node, record summary statistics
        for node, runtimes in self.runtimes_sec.items():
            # Similarly, compute the mean runtime for ``node``
            mean_runtime = statistics.mean(runtimes)
            # For easier understanding, we also compute the percentage
            # time each node took with respect to the whole network.
            pct_total = mean_runtime / mean_total_runtime * 100
            # Record the node's type, name of the node, mean runtime, and
            # percent runtime.
            node_summaries.append(
                [node.op, str(node), mean_runtime, pct_total])

        # One of the most important questions to answer when doing performance
        # profiling is "Which op(s) took the longest?". We can make this easy
        # to see by providing sorting functionality in our summary view.
        if should_sort:
            node_summaries.sort(key=lambda s: s[2], reverse=True)

        # Use the ``tabulate`` library to create a well-formatted table
        # presenting our summary information
        headers : List[str] = [
            'Op type', 'Op', 'Average runtime (s)', 'Pct total runtime'
        ]
        return tabulate.tabulate(node_summaries, headers=headers)

######################################################################
# .. note::
#     We use Python's ``time.time`` function to pull wall clock
#     timestamps and compare them. This is not the most accurate
#     way to measure performance, and will only give us a first-
#     order approximation. We use this simple technique only for the
#     purpose of demonstration in this tutorial.

######################################################################
# Investigating the Performance of ResNet18
# -----------------------------------------
# We can now use ``ProfilingInterpreter`` to inspect the performance
# characteristics of our ResNet18 model:

interp = ProfilingInterpreter(rn18)
interp.run(input)
print(interp.summary(True))
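
######################################################################
# Since wall-clock timings from a single run can be noisy, one optional
# refinement (a small aside, not part of the core tutorial flow) is to
# invoke the interpreter several more times; ``summary`` then reports the
# mean over all of the recorded runs.

for _ in range(10):
    interp.run(input)
print(interp.summary(True))
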
######################################################################
# There are two things we should call out here:
#
# * ``MaxPool2d`` takes up the most time. This is a known issue:
#   https://github.com/pytorch/pytorch/issues/51393
# * ``BatchNorm2d`` also takes up significant time. We can continue this
#   line of thinking and optimize it in the Conv-BN Fusion with FX
#   `tutorial <https://pytorch.org/tutorials/intermediate/fx_conv_bn_fuser.html>`_.
#
#
# Conclusion
# ----------
# As we can see, using FX we can easily capture PyTorch programs (even
# ones we don't have the source code for!) in a machine-interpretable
# format and use that for analysis, such as the performance analysis
# we've done here. FX opens up an exciting world of possibilities for
# working with PyTorch programs.
#
# Finally, since FX is still in beta, we would be happy to hear any
# feedback you have about using it. Please feel free to use the
# PyTorch Forums (https://discuss.pytorch.org/) and the issue tracker
# (https://github.com/pytorch/pytorch/issues) to provide any feedback
# you might have.