"""1PyTorch Profiler2====================================3This recipe explains how to use PyTorch profiler and measure the time and4memory consumption of the model's operators.56Introduction7------------8PyTorch includes a simple profiler API that is useful when user needs9to determine the most expensive operators in the model.1011In this recipe, we will use a simple Resnet model to demonstrate how to12use profiler to analyze model performance.1314Setup15-----16To install ``torch`` and ``torchvision`` use the following command:1718.. code-block:: sh1920pip install torch torchvision212223"""242526######################################################################27# Steps28# -----29#30# 1. Import all necessary libraries31# 2. Instantiate a simple Resnet model32# 3. Using profiler to analyze execution time33# 4. Using profiler to analyze memory consumption34# 5. Using tracing functionality35# 6. Examining stack traces36# 7. Using profiler to analyze long-running jobs37#38# 1. Import all necessary libraries39# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~40#41# In this recipe we will use ``torch``, ``torchvision.models``42# and ``profiler`` modules:43#4445import torch46import torchvision.models as models47from torch.profiler import profile, record_function, ProfilerActivity484950######################################################################51# 2. Instantiate a simple Resnet model52# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~53#54# Let's create an instance of a Resnet model and prepare an input55# for it:56#5758model = models.resnet18()59inputs = torch.randn(5, 3, 224, 224)6061######################################################################62# 3. Using profiler to analyze execution time63# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~64#65# PyTorch profiler is enabled through the context manager and accepts66# a number of parameters, some of the most useful are:67#68# - ``activities`` - a list of activities to profile:69# - ``ProfilerActivity.CPU`` - PyTorch operators, TorchScript functions and70# user-defined code labels (see ``record_function`` below);71# - ``ProfilerActivity.CUDA`` - on-device CUDA kernels;72# - ``ProfilerActivity.XPU`` - on-device XPU kernels;73# - ``record_shapes`` - whether to record shapes of the operator inputs;74# - ``profile_memory`` - whether to report amount of memory consumed by75# model's Tensors;76#77# Note: when using CUDA, profiler also shows the runtime CUDA events78# occurring on the host.7980######################################################################81# Let's see how we can use profiler to analyze the execution time:8283with profile(activities=[ProfilerActivity.CPU], record_shapes=True) as prof:84with record_function("model_inference"):85model(inputs)8687######################################################################88# Note that we can use ``record_function`` context manager to label89# arbitrary code ranges with user provided names90# (``model_inference`` is used as a label in the example above).91#92# Profiler allows one to check which operators were called during the93# execution of a code range wrapped with a profiler context manager.94# If multiple profiler ranges are active at the same time (e.g. 
######################################################################
# Profiler allows one to check which operators were called during the
# execution of a code range wrapped with a profiler context manager.
# If multiple profiler ranges are active at the same time (e.g. in
# parallel PyTorch threads), each profiling context manager tracks only
# the operators of its corresponding range.
# Profiler also automatically profiles the asynchronous tasks launched
# with ``torch.jit._fork`` and (in case of a backward pass)
# the backward pass operators launched with the ``backward()`` call.
#
# Let's print out the stats for the first ``model_inference`` execution above:

print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=10))

######################################################################
# The output will look like (omitting some columns):

# ---------------------------------  ------------  ------------  ------------  ------------
#                              Name      Self CPU     CPU total  CPU time avg    # of Calls
# ---------------------------------  ------------  ------------  ------------  ------------
#                   model_inference       5.509ms      57.503ms      57.503ms             1
#                      aten::conv2d     231.000us      31.931ms       1.597ms            20
#                 aten::convolution     250.000us      31.700ms       1.585ms            20
#                aten::_convolution     336.000us      31.450ms       1.573ms            20
#          aten::mkldnn_convolution      30.838ms      31.114ms       1.556ms            20
#                  aten::batch_norm     211.000us      14.693ms     734.650us            20
#      aten::_batch_norm_impl_index     319.000us      14.482ms     724.100us            20
#           aten::native_batch_norm       9.229ms      14.109ms     705.450us            20
#                        aten::mean     332.000us       2.631ms     125.286us            21
#                      aten::select       1.668ms       2.292ms       8.988us           255
# ---------------------------------  ------------  ------------  ------------  ------------
# Self CPU time total: 57.549ms
#

######################################################################
# Here we see that, as expected, most of the time is spent in convolution (and specifically in ``mkldnn_convolution``
# for PyTorch compiled with ``MKL-DNN`` support).
# Note the difference between self CPU time and CPU time - operators can call other operators; self CPU time excludes time
# spent in children operator calls, while total CPU time includes it. You can choose to sort by the self CPU time by passing
# ``sort_by="self_cpu_time_total"`` into the ``table`` call.
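######################################################################
# For example, the same averages can be printed sorted by self CPU time
# instead, using the ``sort_by`` value mentioned above:

print(prof.key_averages().table(sort_by="self_cpu_time_total", row_limit=10))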
######################################################################
# To get a finer granularity of results and include operator input shapes, pass ``group_by_input_shape=True``
# (note: this requires running the profiler with ``record_shapes=True``):

print(prof.key_averages(group_by_input_shape=True).table(sort_by="cpu_time_total", row_limit=10))

######################################################################
# The output might look like this (omitting some columns):
#
# .. code-block:: sh
#
#    ---------------------------------  ------------  -------------------------------------------
#                                 Name     CPU total                                  Input Shapes
#    ---------------------------------  ------------  -------------------------------------------
#                      model_inference      57.503ms                                            []
#                         aten::conv2d       8.008ms      [[5,64,56,56], [64,64,3,3], [], ..., []]
#                    aten::convolution       7.956ms      [[5,64,56,56], [64,64,3,3], [], ..., []]
#                   aten::_convolution       7.909ms      [[5,64,56,56], [64,64,3,3], [], ..., []]
#             aten::mkldnn_convolution       7.834ms      [[5,64,56,56], [64,64,3,3], [], ..., []]
#                         aten::conv2d       6.332ms     [[5,512,7,7], [512,512,3,3], [], ..., []]
#                    aten::convolution       6.303ms     [[5,512,7,7], [512,512,3,3], [], ..., []]
#                   aten::_convolution       6.273ms     [[5,512,7,7], [512,512,3,3], [], ..., []]
#             aten::mkldnn_convolution       6.233ms     [[5,512,7,7], [512,512,3,3], [], ..., []]
#                         aten::conv2d       4.751ms   [[5,256,14,14], [256,256,3,3], [], ..., []]
#    ---------------------------------  ------------  -------------------------------------------
#    Self CPU time total: 57.549ms
#

######################################################################
# Note the occurrence of ``aten::convolution`` twice with different input shapes.

######################################################################
# Profiler can also be used to analyze the performance of models executed
# on GPUs and XPUs:

# Users can switch between CPU, CUDA, and XPU
if torch.cuda.is_available():
    device = 'cuda'
elif torch.xpu.is_available():
    device = 'xpu'
else:
    print('Neither CUDA nor XPU devices are available to demonstrate profiling on acceleration devices')
    import sys
    sys.exit(0)

activities = [ProfilerActivity.CPU, ProfilerActivity.CUDA, ProfilerActivity.XPU]
sort_by_keyword = device + "_time_total"

model = models.resnet18().to(device)
inputs = torch.randn(5, 3, 224, 224).to(device)

with profile(activities=activities, record_shapes=True) as prof:
    with record_function("model_inference"):
        model(inputs)

print(prof.key_averages().table(sort_by=sort_by_keyword, row_limit=10))
######################################################################
# (Note: the first use of CUDA profiling may incur extra overhead.)

######################################################################
# The resulting table output (omitting some columns):
#
# .. code-block:: sh
#
#    -------------------------------------------------------  ------------  ------------
#                                                        Name     Self CUDA    CUDA total
#    -------------------------------------------------------  ------------  ------------
#                                             model_inference       0.000us      11.666ms
#                                                aten::conv2d       0.000us      10.484ms
#                                           aten::convolution       0.000us      10.484ms
#                                          aten::_convolution       0.000us      10.484ms
#                                  aten::_convolution_nogroup       0.000us      10.484ms
#                                           aten::thnn_conv2d       0.000us      10.484ms
#                                   aten::thnn_conv2d_forward      10.484ms      10.484ms
#    void at::native::im2col_kernel<float>(long, float co...       3.844ms       3.844ms
#                                           sgemm_32x32x32_NN       3.206ms       3.206ms
#                                       sgemm_32x32x32_NN_vec       3.093ms       3.093ms
#    -------------------------------------------------------  ------------  ------------
#    Self CPU time total: 23.015ms
#    Self CUDA time total: 11.666ms
#

######################################################################
# (Note: the first use of XPU profiling may incur extra overhead.)

######################################################################
# The resulting table output (omitting some columns):
#
# .. code-block:: sh
#
#    -------------------------------------------------------  ------------  ------------  ------------  ------------  ------------
#                                                        Name      Self XPU    Self XPU %     XPU total  XPU time avg    # of Calls
#    -------------------------------------------------------  ------------  ------------  ------------  ------------  ------------
#                                             model_inference       0.000us         0.00%       2.567ms       2.567ms             1
#                                                aten::conv2d       0.000us         0.00%       1.871ms      93.560us            20
#                                           aten::convolution       0.000us         0.00%       1.871ms      93.560us            20
#                                          aten::_convolution       0.000us         0.00%       1.871ms      93.560us            20
#                              aten::convolution_overrideable       1.871ms        72.89%       1.871ms      93.560us            20
#                                                    gen_conv       1.484ms        57.82%       1.484ms      74.216us            20
#                                            aten::batch_norm       0.000us         0.00%     432.640us      21.632us            20
#                                aten::_batch_norm_impl_index       0.000us         0.00%     432.640us      21.632us            20
#                                     aten::native_batch_norm     432.640us        16.85%     432.640us      21.632us            20
#                                                conv_reorder     386.880us        15.07%     386.880us       6.448us            60
#    -------------------------------------------------------  ------------  ------------  ------------  ------------  ------------
#    Self CPU time total: 712.486ms
#    Self XPU time total: 2.567ms
#

######################################################################
# Note the occurrence of on-device kernels in the output (e.g. ``sgemm_32x32x32_NN``).
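######################################################################
# Besides the formatted table, the averaged results can also be inspected
# programmatically. A minimal sketch, assuming the ``FunctionEventAvg``
# attributes ``key``, ``count`` and ``self_cpu_time_total`` (reported in
# microseconds):

# Pick the five operators with the largest self CPU time.
top_events = sorted(prof.key_averages(),
                    key=lambda evt: evt.self_cpu_time_total,
                    reverse=True)[:5]
for evt in top_events:
    print(f"{evt.key}: {evt.self_cpu_time_total:.0f}us over {evt.count} calls")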
######################################################################
# 4. Using profiler to analyze memory consumption
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
#
# PyTorch profiler can also show the amount of memory (used by the model's tensors)
# that was allocated (or released) during the execution of the model's operators.
# In the output below, 'self' memory corresponds to the memory allocated (released)
# by the operator, excluding children calls to other operators.
# To enable memory profiling functionality, pass ``profile_memory=True``.

model = models.resnet18()
inputs = torch.randn(5, 3, 224, 224)

with profile(activities=[ProfilerActivity.CPU],
        profile_memory=True, record_shapes=True) as prof:
    model(inputs)

print(prof.key_averages().table(sort_by="self_cpu_memory_usage", row_limit=10))

# (omitting some columns)
# ---------------------------------  ------------  ------------  ------------
#                              Name       CPU Mem  Self CPU Mem    # of Calls
# ---------------------------------  ------------  ------------  ------------
#                       aten::empty      94.79 Mb      94.79 Mb           121
#     aten::max_pool2d_with_indices      11.48 Mb      11.48 Mb             1
#                       aten::addmm      19.53 Kb      19.53 Kb             1
#               aten::empty_strided         572 b         572 b            25
#                     aten::resize_         240 b         240 b             6
#                         aten::abs         480 b         240 b             4
#                         aten::add         160 b         160 b            20
#               aten::masked_select         120 b         112 b             1
#                          aten::ne         122 b          53 b             6
#                          aten::eq          60 b          30 b             2
# ---------------------------------  ------------  ------------  ------------
# Self CPU time total: 53.064ms

print(prof.key_averages().table(sort_by="cpu_memory_usage", row_limit=10))

######################################################################
# The output might look like this (omitting some columns):
#
# .. code-block:: sh
#
#    ---------------------------------  ------------  ------------  ------------
#                                 Name       CPU Mem  Self CPU Mem    # of Calls
#    ---------------------------------  ------------  ------------  ------------
#                          aten::empty      94.79 Mb      94.79 Mb           121
#                     aten::batch_norm      47.41 Mb           0 b            20
#         aten::_batch_norm_impl_index      47.41 Mb           0 b            20
#              aten::native_batch_norm      47.41 Mb           0 b            20
#                         aten::conv2d      47.37 Mb           0 b            20
#                    aten::convolution      47.37 Mb           0 b            20
#                   aten::_convolution      47.37 Mb           0 b            20
#             aten::mkldnn_convolution      47.37 Mb           0 b            20
#                     aten::max_pool2d      11.48 Mb           0 b             1
#        aten::max_pool2d_with_indices      11.48 Mb      11.48 Mb             1
#    ---------------------------------  ------------  ------------  ------------
#    Self CPU time total: 53.064ms
#

######################################################################
# 5. Using tracing functionality
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
#
# Profiling results can be output as a ``.json`` trace file.
# To trace CUDA or XPU kernels, users can switch the device between
# ``cuda`` and ``xpu`` below:

device = 'cuda'

activities = [ProfilerActivity.CPU, ProfilerActivity.CUDA, ProfilerActivity.XPU]

model = models.resnet18().to(device)
inputs = torch.randn(5, 3, 224, 224).to(device)

with profile(activities=activities) as prof:
    model(inputs)

prof.export_chrome_trace("trace.json")

######################################################################
# You can examine the sequence of profiled operators and CUDA/XPU kernels
# in the Chrome trace viewer (``chrome://tracing``):
#
# .. image:: ../../_static/img/trace_img.png
#    :scale: 25 %
######################################################################
# 6. Examining stack traces
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
#
# Profiler can be used to analyze Python and TorchScript stack traces:

sort_by_keyword = "self_" + device + "_time_total"

with profile(
    activities=activities,
    with_stack=True,
) as prof:
    model(inputs)

# Print aggregated stats
print(prof.key_averages(group_by_stack_n=5).table(sort_by=sort_by_keyword, row_limit=2))

######################################################################
# The output might look like this (omitting some columns):
#
# .. code-block:: sh
#
#    -------------------------  -----------------------------------------------------------
#                         Name  Source Location
#    -------------------------  -----------------------------------------------------------
#    aten::thnn_conv2d_forward  .../torch/nn/modules/conv.py(439): _conv_forward
#                               .../torch/nn/modules/conv.py(443): forward
#                               .../torch/nn/modules/module.py(1051): _call_impl
#                               .../site-packages/torchvision/models/resnet.py(63): forward
#                               .../torch/nn/modules/module.py(1051): _call_impl
#
#    aten::thnn_conv2d_forward  .../torch/nn/modules/conv.py(439): _conv_forward
#                               .../torch/nn/modules/conv.py(443): forward
#                               .../torch/nn/modules/module.py(1051): _call_impl
#                               .../site-packages/torchvision/models/resnet.py(59): forward
#                               .../torch/nn/modules/module.py(1051): _call_impl
#    -------------------------  -----------------------------------------------------------
#    Self CPU time total: 34.016ms
#    Self CUDA time total: 11.659ms
#

######################################################################
# Note the two convolutions and the two call sites in the ``torchvision/models/resnet.py`` script.
#
# (Warning: stack tracing adds an extra profiling overhead.)
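######################################################################
# Because the profiler above was run with ``with_stack=True``, the collected
# stacks can also be exported to a plain-text file for external visualization
# tools such as FlameGraph. This is a minimal sketch: the output path is
# illustrative, and ``self_cpu_time_total`` is assumed as the aggregated metric:

prof.export_stacks("/tmp/profiler_stacks.txt", "self_cpu_time_total")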
######################################################################
# 7. Using profiler to analyze long-running jobs
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
#
# PyTorch profiler offers an additional API to handle long-running jobs
# (such as training loops). Tracing all of the execution can be
# slow and result in very large trace files. To avoid this, use optional
# arguments:
#
# - ``schedule`` - specifies a function that takes an integer argument (step number)
#   as an input and returns an action for the profiler; the best way to use this parameter
#   is the ``torch.profiler.schedule`` helper function, which can generate a schedule for you;
# - ``on_trace_ready`` - specifies a function that takes a reference to the profiler as
#   an input and is called by the profiler each time a new trace is ready.
#
# To illustrate how the API works, let's first consider the following example with
# the ``torch.profiler.schedule`` helper function:

from torch.profiler import schedule

my_schedule = schedule(
    skip_first=10,
    wait=5,
    warmup=1,
    active=3,
    repeat=2)

######################################################################
# Profiler assumes that the long-running job is composed of steps, numbered
# starting from zero. The example above defines the following sequence of actions
# for the profiler:
#
# 1. Parameter ``skip_first`` tells the profiler that it should ignore the first 10 steps
#    (the default value of ``skip_first`` is zero);
# 2. After the first ``skip_first`` steps, the profiler starts executing profiling cycles;
# 3. Each cycle consists of three phases:
#
#    - idling (``wait=5`` steps), during this phase the profiler is not active;
#    - warming up (``warmup=1`` steps), during this phase the profiler starts tracing, but
#      the results are discarded; this phase is used to discard the samples obtained by
#      the profiler at the beginning of the trace since they are usually skewed by an extra
#      overhead;
#    - active tracing (``active=3`` steps), during this phase the profiler traces and records data;
#
# 4. An optional ``repeat`` parameter specifies an upper bound on the number of cycles.
#    By default (zero value), the profiler will execute cycles as long as the job runs.

######################################################################
# Thus, in the example above, the profiler will skip the first 15 steps, spend the next step on the warm-up,
# actively record the next 3 steps, skip another 5 steps, spend the next step on the warm-up, and actively
# record another 3 steps. Since the ``repeat=2`` parameter value is specified, the profiler will stop
# the recording after the first two cycles.
#
# At the end of each cycle the profiler calls the specified ``on_trace_ready`` function and passes itself as
# an argument. This function is used to process the new trace - either by obtaining the table output or
# by saving the output on disk as a trace file.
#
# To signal to the profiler that the next step has started, call the ``prof.step()`` function.
# The current profiler step is stored in ``prof.step_num``.
#
# The following example shows how to use all of the concepts above for CUDA and XPU kernels:

sort_by_keyword = "self_" + device + "_time_total"

def trace_handler(p):
    output = p.key_averages().table(sort_by=sort_by_keyword, row_limit=10)
    print(output)
    p.export_chrome_trace("/tmp/trace_" + str(p.step_num) + ".json")

with profile(
    activities=activities,
    schedule=torch.profiler.schedule(
        wait=1,
        warmup=1,
        active=2),
    on_trace_ready=trace_handler
) as p:
    for idx in range(8):
        model(inputs)
        p.step()

######################################################################
# Learn More
# ----------
#
# Take a look at the following recipes/tutorials to continue your learning:
#
# - `PyTorch Benchmark <https://pytorch.org/tutorials/recipes/recipes/benchmark.html>`_
# - `Visualizing models, data, and training with TensorBoard <https://pytorch.org/tutorials/intermediate/tensorboard_tutorial.html>`_ tutorial
#