"""
PyTorch Profiler
====================================
This recipe explains how to use the PyTorch profiler to measure the time and
memory consumption of a model's operators.

Introduction
------------
PyTorch includes a simple profiler API that is useful when the user needs
to determine the most expensive operators in the model.

In this recipe, we will use a simple Resnet model to demonstrate how to
use the profiler to analyze model performance.

Setup
-----
To install ``torch`` and ``torchvision`` use the following command:

.. code-block:: sh

   pip install torch torchvision


"""


######################################################################
# Steps
# -----
#
# 1. Import all necessary libraries
# 2. Instantiate a simple Resnet model
# 3. Using profiler to analyze execution time
# 4. Using profiler to analyze memory consumption
# 5. Using tracing functionality
# 6. Examining stack traces
# 7. Using profiler to analyze long-running jobs
#
# 1. Import all necessary libraries
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
#
# In this recipe we will use ``torch``, ``torchvision.models``
# and ``profiler`` modules:
#

import torch
import torchvision.models as models
from torch.profiler import profile, record_function, ProfilerActivity


######################################################################
# 2. Instantiate a simple Resnet model
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
#
# Let's create an instance of a Resnet model and prepare an input
# for it:
#

model = models.resnet18()
inputs = torch.randn(5, 3, 224, 224)

######################################################################
# 3. Using profiler to analyze execution time
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
#
# PyTorch profiler is enabled through the context manager and accepts
# a number of parameters; some of the most useful are:
#
# - ``activities`` - a list of activities to profile:
#
#   - ``ProfilerActivity.CPU`` - PyTorch operators, TorchScript functions and
#     user-defined code labels (see ``record_function`` below);
#   - ``ProfilerActivity.CUDA`` - on-device CUDA kernels;
#
# - ``record_shapes`` - whether to record shapes of the operator inputs;
# - ``profile_memory`` - whether to report the amount of memory consumed by
#   the model's Tensors;
#
# Note: when using CUDA, the profiler also shows the runtime CUDA events
# occurring on the host.

######################################################################
# Let's see how we can use profiler to analyze the execution time:

with profile(activities=[ProfilerActivity.CPU], record_shapes=True) as prof:
    with record_function("model_inference"):
        model(inputs)

######################################################################
# Note that we can use the ``record_function`` context manager to label
# arbitrary code ranges with user-provided names
# (``model_inference`` is used as a label in the example above).
#
# Profiler allows one to check which operators were called during the
# execution of a code range wrapped with a profiler context manager.
# If multiple profiler ranges are active at the same time (e.g. in
# parallel PyTorch threads), each profiling context manager tracks only
# the operators of its corresponding range.
# Profiler also automatically profiles the asynchronous tasks launched
# with ``torch.jit._fork`` and (in case of a backward pass)
# the backward pass operators launched with the ``backward()`` call.
#
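
######################################################################
# For instance (a minimal sketch; the ``preprocess`` and ``forward`` labels
# are illustrative and not part of the original recipe), separate
# ``record_function`` ranges can label the phases of a single run:

with profile(activities=[ProfilerActivity.CPU]) as prof_labels:
    with record_function("preprocess"):
        batch = inputs.clone()  # stand-in for a real preprocessing step
    with record_function("forward"):
        model(batch)
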
######################################################################
# Let's print out the stats for the ``model_inference`` execution above:

print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=10))

######################################################################
# The output will look like (omitting some columns):

# --------------------------------- ------------ ------------ ------------ ------------
#                              Name     Self CPU    CPU total CPU time avg   # of Calls
# --------------------------------- ------------ ------------ ------------ ------------
#                   model_inference      5.509ms     57.503ms     57.503ms            1
#                      aten::conv2d    231.000us     31.931ms      1.597ms           20
#                 aten::convolution    250.000us     31.700ms      1.585ms           20
#                aten::_convolution    336.000us     31.450ms      1.573ms           20
#          aten::mkldnn_convolution     30.838ms     31.114ms      1.556ms           20
#                  aten::batch_norm    211.000us     14.693ms    734.650us           20
#      aten::_batch_norm_impl_index    319.000us     14.482ms    724.100us           20
#           aten::native_batch_norm      9.229ms     14.109ms    705.450us           20
#                        aten::mean    332.000us      2.631ms    125.286us           21
#                      aten::select      1.668ms      2.292ms      8.988us          255
# --------------------------------- ------------ ------------ ------------ ------------
# Self CPU time total: 57.549ms
#

######################################################################
# Here we see that, as expected, most of the time is spent in convolution (and specifically in ``mkldnn_convolution``
# for PyTorch compiled with ``MKL-DNN`` support).
# Note the difference between self cpu time and cpu time - operators can call other operators; self cpu time excludes time
# spent in children operator calls, while total cpu time includes it. You can choose to sort by the self cpu time by passing
# ``sort_by="self_cpu_time_total"`` into the ``table`` call, as shown below.
#
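
######################################################################
# For example (a quick illustration reusing the same ``prof`` results as
# above; output omitted here), sorting by self CPU time instead:

print(prof.key_averages().table(sort_by="self_cpu_time_total", row_limit=10))
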
########################################################################################
# To get a finer granularity of results and include operator input shapes, pass ``group_by_input_shape=True``
# (note: this requires running the profiler with ``record_shapes=True``):

print(prof.key_averages(group_by_input_shape=True).table(sort_by="cpu_time_total", row_limit=10))

########################################################################################
# The output might look like this (omitting some columns):
#
# .. code-block:: sh
#
#    --------------------------------- ------------ -------------------------------------------
#                                 Name    CPU total Input Shapes
#    --------------------------------- ------------ -------------------------------------------
#                      model_inference     57.503ms []
#                         aten::conv2d      8.008ms [[5,64,56,56], [64,64,3,3], [], ..., []]
#                    aten::convolution      7.956ms [[5,64,56,56], [64,64,3,3], [], ..., []]
#                   aten::_convolution      7.909ms [[5,64,56,56], [64,64,3,3], [], ..., []]
#             aten::mkldnn_convolution      7.834ms [[5,64,56,56], [64,64,3,3], [], ..., []]
#                         aten::conv2d      6.332ms [[5,512,7,7], [512,512,3,3], [], ..., []]
#                    aten::convolution      6.303ms [[5,512,7,7], [512,512,3,3], [], ..., []]
#                   aten::_convolution      6.273ms [[5,512,7,7], [512,512,3,3], [], ..., []]
#             aten::mkldnn_convolution      6.233ms [[5,512,7,7], [512,512,3,3], [], ..., []]
#                         aten::conv2d      4.751ms [[5,256,14,14], [256,256,3,3], [], ..., []]
#    --------------------------------- ------------ -------------------------------------------
#    Self CPU time total: 57.549ms
#

######################################################################
# Note the occurrence of ``aten::convolution`` twice with different input shapes.

######################################################################
# Profiler can also be used to analyze performance of models executed on GPUs:

model = models.resnet18().cuda()
inputs = torch.randn(5, 3, 224, 224).cuda()

with profile(activities=[
        ProfilerActivity.CPU, ProfilerActivity.CUDA], record_shapes=True) as prof:
    with record_function("model_inference"):
        model(inputs)

print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))

######################################################################
# (Note: the first use of CUDA profiling may bring an extra overhead.)

######################################################################
# The resulting table output (omitting some columns):
#
# .. code-block:: sh
#
#    ------------------------------------------------------- ------------ ------------
#                                                        Name    Self CUDA   CUDA total
#    ------------------------------------------------------- ------------ ------------
#                                             model_inference      0.000us     11.666ms
#                                                aten::conv2d      0.000us     10.484ms
#                                           aten::convolution      0.000us     10.484ms
#                                          aten::_convolution      0.000us     10.484ms
#                                  aten::_convolution_nogroup      0.000us     10.484ms
#                                           aten::thnn_conv2d      0.000us     10.484ms
#                                   aten::thnn_conv2d_forward     10.484ms     10.484ms
#    void at::native::im2col_kernel<float>(long, float co...      3.844ms      3.844ms
#                                           sgemm_32x32x32_NN      3.206ms      3.206ms
#                                       sgemm_32x32x32_NN_vec      3.093ms      3.093ms
#    ------------------------------------------------------- ------------ ------------
#    Self CPU time total: 23.015ms
#    Self CUDA time total: 11.666ms
#

######################################################################
# Note the occurrence of on-device kernels in the output (e.g. ``sgemm_32x32x32_NN``).

######################################################################
# 4. Using profiler to analyze memory consumption
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
#
# PyTorch profiler can also show the amount of memory (used by the model's tensors)
# that was allocated (or released) during the execution of the model's operators.
# In the output below, 'self' memory corresponds to the memory allocated (released)
# by the operator, excluding the children calls to the other operators.
# To enable memory profiling functionality, pass ``profile_memory=True``.

model = models.resnet18()
inputs = torch.randn(5, 3, 224, 224)

with profile(activities=[ProfilerActivity.CPU],
        profile_memory=True, record_shapes=True) as prof:
    model(inputs)

print(prof.key_averages().table(sort_by="self_cpu_memory_usage", row_limit=10))

# (omitting some columns)
# --------------------------------- ------------ ------------ ------------
#                              Name      CPU Mem Self CPU Mem   # of Calls
# --------------------------------- ------------ ------------ ------------
#                       aten::empty     94.79 Mb     94.79 Mb          121
#     aten::max_pool2d_with_indices     11.48 Mb     11.48 Mb            1
#                       aten::addmm     19.53 Kb     19.53 Kb            1
#               aten::empty_strided        572 b        572 b           25
#                     aten::resize_        240 b        240 b            6
#                         aten::abs        480 b        240 b            4
#                         aten::add        160 b        160 b           20
#               aten::masked_select        120 b        112 b            1
#                          aten::ne        122 b         53 b            6
#                          aten::eq         60 b         30 b            2
# --------------------------------- ------------ ------------ ------------
# Self CPU time total: 53.064ms

print(prof.key_averages().table(sort_by="cpu_memory_usage", row_limit=10))

#############################################################################
# The output might look like this (omitting some columns):
#
# .. code-block:: sh
#
#    --------------------------------- ------------ ------------ ------------
#                                 Name      CPU Mem Self CPU Mem   # of Calls
#    --------------------------------- ------------ ------------ ------------
#                          aten::empty     94.79 Mb     94.79 Mb          121
#                     aten::batch_norm     47.41 Mb          0 b           20
#         aten::_batch_norm_impl_index     47.41 Mb          0 b           20
#              aten::native_batch_norm     47.41 Mb          0 b           20
#                         aten::conv2d     47.37 Mb          0 b           20
#                    aten::convolution     47.37 Mb          0 b           20
#                   aten::_convolution     47.37 Mb          0 b           20
#             aten::mkldnn_convolution     47.37 Mb          0 b           20
#                     aten::max_pool2d     11.48 Mb          0 b            1
#        aten::max_pool2d_with_indices     11.48 Mb     11.48 Mb            1
#    --------------------------------- ------------ ------------ ------------
#    Self CPU time total: 53.064ms
#

######################################################################
# 5. Using tracing functionality
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
#
# Profiling results can be output as a ``.json`` trace file:

model = models.resnet18().cuda()
inputs = torch.randn(5, 3, 224, 224).cuda()

with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) as prof:
    model(inputs)

prof.export_chrome_trace("trace.json")

######################################################################
# You can examine the sequence of profiled operators and CUDA kernels
# in Chrome trace viewer (``chrome://tracing``):
#
# .. image:: ../../_static/img/trace_img.png
#    :scale: 25 %

######################################################################
# 6. Examining stack traces
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
#
# Profiler can be used to analyze Python and TorchScript stack traces:

with profile(
    activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
    with_stack=True,
) as prof:
    model(inputs)

# Print aggregated stats
print(prof.key_averages(group_by_stack_n=5).table(sort_by="self_cuda_time_total", row_limit=2))

#################################################################################
# The output might look like this (omitting some columns):
#
# .. code-block:: sh
#
#    ------------------------- -----------------------------------------------------------
#                         Name Source Location
#    ------------------------- -----------------------------------------------------------
#    aten::thnn_conv2d_forward .../torch/nn/modules/conv.py(439): _conv_forward
#                              .../torch/nn/modules/conv.py(443): forward
#                              .../torch/nn/modules/module.py(1051): _call_impl
#                              .../site-packages/torchvision/models/resnet.py(63): forward
#                              .../torch/nn/modules/module.py(1051): _call_impl
#    aten::thnn_conv2d_forward .../torch/nn/modules/conv.py(439): _conv_forward
#                              .../torch/nn/modules/conv.py(443): forward
#                              .../torch/nn/modules/module.py(1051): _call_impl
#                              .../site-packages/torchvision/models/resnet.py(59): forward
#                              .../torch/nn/modules/module.py(1051): _call_impl
#    ------------------------- -----------------------------------------------------------
#    Self CPU time total: 34.016ms
#    Self CUDA time total: 11.659ms
#

######################################################################
# Note the two convolutions and the two call sites in the ``torchvision/models/resnet.py`` script.
#
# (Warning: stack tracing adds an extra profiling overhead.)

######################################################################
# 7. Using profiler to analyze long-running jobs
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
#
# PyTorch profiler offers an additional API to handle long-running jobs
# (such as training loops). Tracing all of the execution can be
# slow and result in very large trace files. To avoid this, use optional
# arguments:
#
# - ``schedule`` - specifies a function that takes an integer argument (step number)
#   as an input and returns an action for the profiler; the best way to use this parameter
#   is the ``torch.profiler.schedule`` helper function, which can generate a schedule for you;
# - ``on_trace_ready`` - specifies a function that takes a reference to the profiler as
#   an input and is called by the profiler each time a new trace is ready.
#
# To illustrate how the API works, let's first consider the following example with
# the ``torch.profiler.schedule`` helper function:

from torch.profiler import schedule

my_schedule = schedule(
    skip_first=10,
    wait=5,
    warmup=1,
    active=3,
    repeat=2)

######################################################################
# Profiler assumes that the long-running job is composed of steps, numbered
# starting from zero. The example above defines the following sequence of actions
# for the profiler:
#
# 1. Parameter ``skip_first`` tells the profiler that it should ignore the first 10 steps
#    (default value of ``skip_first`` is zero);
# 2. After the first ``skip_first`` steps, profiler starts executing profiler cycles;
# 3. Each cycle consists of three phases:
#
#    - idling (``wait=5`` steps), during this phase profiler is not active;
#    - warming up (``warmup=1`` steps), during this phase profiler starts tracing, but
#      the results are discarded; this phase is used to discard the samples obtained by
#      the profiler at the beginning of the trace since they are usually skewed by an extra
#      overhead;
#    - active tracing (``active=3`` steps), during this phase profiler traces and records data;
#
# 4. An optional ``repeat`` parameter specifies an upper bound on the number of cycles.
#    By default (zero value), profiler will execute cycles as long as the job runs.

######################################################################
# Thus, in the example above, profiler will skip the first 15 steps, spend the next step on the warm up,
# actively record the next 3 steps, skip another 5 steps, spend the next step on the warm up, actively
# record another 3 steps. Since the ``repeat=2`` parameter value is specified, the profiler will stop
# the recording after the first two cycles.
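
######################################################################
# As a quick sanity check (a minimal sketch, not part of the original
# recipe), the object returned by ``torch.profiler.schedule`` is a function
# mapping a step number to a profiler action, so we can query the schedule
# defined above directly:

for step in [0, 14, 15, 16, 18]:
    # expected (given skip_first=10, wait=5, warmup=1, active=3): NONE for
    # steps 0 and 14, WARMUP for 15, RECORD for 16, RECORD_AND_SAVE for 18
    print(step, my_schedule(step))
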
######################################################################
# At the end of each cycle profiler calls the specified ``on_trace_ready`` function and passes itself as
# an argument. This function is used to process the new trace - either by obtaining the table output or
# by saving the output on disk as a trace file.
#
# To send the signal to the profiler that the next step has started, call the ``prof.step()`` function.
# The current profiler step is stored in ``prof.step_num``.
#
# The following example shows how to use all of the concepts above:

def trace_handler(p):
    output = p.key_averages().table(sort_by="self_cuda_time_total", row_limit=10)
    print(output)
    p.export_chrome_trace("/tmp/trace_" + str(p.step_num) + ".json")

with profile(
    activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
    schedule=torch.profiler.schedule(
        wait=1,
        warmup=1,
        active=2),
    on_trace_ready=trace_handler
) as p:
    for idx in range(8):
        model(inputs)
        p.step()

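
######################################################################
# A follow-up observation (derived from the schedule semantics above, with
# the default ``repeat=0``): with ``wait=1, warmup=1, active=2`` each cycle
# spans 4 steps, so across the 8 iterations the profiler records steps 2-3
# and 6-7 and invokes ``trace_handler`` twice, once per completed cycle.
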
######################################################################
# Learn More
# ----------
#
# Take a look at the following recipes/tutorials to continue your learning:
#
# - `PyTorch Benchmark <https://pytorch.org/tutorials/recipes/recipes/benchmark.html>`_
# - `PyTorch Profiler with TensorBoard <https://pytorch.org/tutorials/intermediate/tensorboard_profiler_tutorial.html>`_ tutorial
# - `Visualizing models, data, and training with TensorBoard <https://pytorch.org/tutorials/intermediate/tensorboard_tutorial.html>`_ tutorial
#