"""
Profiling your PyTorch Module
-----------------------------

**Author:** `Suraj Subramanian <https://github.com/suraj813>`_

PyTorch includes a profiler API that is useful to identify the time and
memory costs of various PyTorch operations in your code. Profiler can be
easily integrated into your code, and the results can be printed as a table
or returned in a JSON trace file.

.. note::
    Profiler supports multithreaded models. Profiler runs in the
    same thread as the operation, but it will also profile child operators
    that might run in another thread. Concurrently-running profilers will be
    scoped to their own thread to prevent mixing of results.

.. note::
    PyTorch 1.8 introduces a new API that will replace the older profiler API
    in future releases. Check the new API at `this page <https://pytorch.org/docs/master/profiler.html>`__.

Head on over to `this
recipe <https://pytorch.org/tutorials/recipes/recipes/profiler_recipe.html>`__
for a quicker walkthrough of Profiler API usage.


--------------
"""

import torch
import numpy as np
from torch import nn
import torch.autograd.profiler as profiler


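######################################################################
# A note on the newer ``torch.profiler`` API
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
#
# As mentioned in the note above, PyTorch 1.8 ships a newer ``torch.profiler``
# module that will eventually replace ``torch.autograd.profiler``. The snippet
# below is a minimal sketch of what the profiling runs in this tutorial would
# look like with that API (shown for reference only; the rest of this tutorial
# keeps the ``torch.autograd.profiler`` API):
#
# .. code-block:: python
#
#     from torch.profiler import profile, record_function, ProfilerActivity
#
#     with profile(activities=[ProfilerActivity.CPU], profile_memory=True) as prof:
#         with record_function("model_inference"):
#             model(input, mask)
#     print(prof.key_averages().table(sort_by="self_cpu_time_total", row_limit=5))
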
######################################################################
# Performance debugging using Profiler
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
#
# Profiler can be useful to identify performance bottlenecks in your
# models. In this example, we build a custom module that performs two
# sub-tasks:
#
# - a linear transformation on the input, and
# - a lookup of indices on a mask tensor, using the result of the transformation.
#
# We wrap the code for each sub-task in separate labelled context managers using
# ``profiler.record_function("label")``. In the profiler output, the
# aggregate performance metrics of all operations in the sub-task will
# show up under its corresponding label.
#
# Note that using Profiler incurs some overhead, and it is best used only for
# investigating code. Remember to remove it if you are benchmarking runtimes.
#

class MyModule(nn.Module):
    def __init__(self, in_features: int, out_features: int, bias: bool = True):
        super(MyModule, self).__init__()
        self.linear = nn.Linear(in_features, out_features, bias)

    def forward(self, input, mask):
        with profiler.record_function("LINEAR PASS"):
            out = self.linear(input)

        with profiler.record_function("MASK INDICES"):
            threshold = out.sum(axis=1).mean().item()
            hi_idx = np.argwhere(mask.cpu().numpy() > threshold)
            hi_idx = torch.from_numpy(hi_idx).cuda()

        return out, hi_idx


######################################################################
# Profile the forward pass
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
#
# We initialize random input and mask tensors, and the model.
#
# Before we run the profiler, we warm up CUDA to ensure accurate
# performance benchmarking. We wrap the forward pass of our module in the
# ``profiler.profile`` context manager. The ``with_stack=True`` parameter appends the
# file and line number of the operation to the trace.
#
# .. warning::
#     ``with_stack=True`` incurs an additional overhead, and is better suited for investigating code.
#     Remember to remove it if you are benchmarking performance.
#

model = MyModule(500, 10).cuda()
input = torch.rand(128, 500).cuda()
mask = torch.rand((500, 500, 500), dtype=torch.double).cuda()

# warm-up
model(input, mask)

with profiler.profile(with_stack=True, profile_memory=True) as prof:
    out, idx = model(input, mask)


######################################################################
# Print profiler results
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
#
# Finally, we print the profiler results. ``profiler.key_averages``
# aggregates the results by operator name, and optionally by input
# shapes and/or stack trace events.
# Grouping by input shapes is useful to identify which tensor shapes
# are utilized by the model.
#
# Here, we use ``group_by_stack_n=5``, which aggregates runtimes by the
# operation and its traceback (truncated to the most recent 5 events), and
# displays the events in the order they are registered. The table can also
# be sorted by passing a ``sort_by`` argument (refer to the
# `docs <https://pytorch.org/docs/stable/autograd.html#profiler>`__ for
# valid sorting keys).
#
# .. note::
#    When running Profiler in a notebook, you might see entries like
#    ``<ipython-input-18-193a910735e8>(13): forward`` instead of filenames in the
#    stack trace. These correspond to ``<notebook-cell>(line number): calling-function``.

print(prof.key_averages(group_by_stack_n=5).table(sort_by='self_cpu_time_total', row_limit=5))

"""
125
(Some columns are omitted)
126
127
------------- ------------ ------------ ------------ ---------------------------------
128
Name Self CPU % Self CPU Self CPU Mem Source Location
129
------------- ------------ ------------ ------------ ---------------------------------
130
MASK INDICES 87.88% 5.212s -953.67 Mb /mnt/xarfuse/.../torch/au
131
<ipython-input-...>(10): forward
132
/mnt/xarfuse/.../torch/nn
133
<ipython-input-...>(9): <module>
134
/mnt/xarfuse/.../IPython/
135
136
aten::copy_ 12.07% 715.848ms 0 b <ipython-input-...>(12): forward
137
/mnt/xarfuse/.../torch/nn
138
<ipython-input-...>(9): <module>
139
/mnt/xarfuse/.../IPython/
140
/mnt/xarfuse/.../IPython/
141
142
LINEAR PASS 0.01% 350.151us -20 b /mnt/xarfuse/.../torch/au
143
<ipython-input-...>(7): forward
144
/mnt/xarfuse/.../torch/nn
145
<ipython-input-...>(9): <module>
146
/mnt/xarfuse/.../IPython/
147
148
aten::addmm 0.00% 293.342us 0 b /mnt/xarfuse/.../torch/nn
149
/mnt/xarfuse/.../torch/nn
150
/mnt/xarfuse/.../torch/nn
151
<ipython-input-...>(8): forward
152
/mnt/xarfuse/.../torch/nn
153
154
aten::mean 0.00% 235.095us 0 b <ipython-input-...>(11): forward
155
/mnt/xarfuse/.../torch/nn
156
<ipython-input-...>(9): <module>
157
/mnt/xarfuse/.../IPython/
158
/mnt/xarfuse/.../IPython/
159
160
----------------------------- ------------ ---------- ----------------------------------
161
Self CPU time total: 5.931s
162
163
"""
164
165
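######################################################################
# Two optional knobs are worth a quick aside here (the calls below sketch
# standard ``torch.autograd.profiler`` usage and are not run in the rest of
# this tutorial): grouping the averages by input shape requires recording
# shapes in the profile, and the raw trace can be written out as the JSON
# trace file mentioned in the introduction:
#
# .. code-block:: python
#
#     with profiler.profile(record_shapes=True) as prof:
#         out, idx = model(input, mask)
#
#     # aggregate by operator name *and* input shape
#     print(prof.key_averages(group_by_input_shape=True).table(sort_by='self_cpu_time_total', row_limit=5))
#
#     # export the full trace as JSON, viewable in the Chrome trace viewer
#     prof.export_chrome_trace("trace.json")
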
######################################################################
# Improve memory performance
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
#
# Note that the most expensive operations, in terms of both memory and time,
# are at ``forward (10)``, representing the operations within MASK INDICES. Let's try to
# tackle the memory consumption first. We can see that the ``.to()``
# operation at line 12 consumes 953.67 Mb. This operation copies ``mask`` to the CPU.
# ``mask`` is initialized with a ``torch.double`` datatype. Can we reduce the memory
# footprint by casting it to ``torch.float`` instead?
#

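######################################################################
# As a quick back-of-the-envelope check (an added aside, not part of the
# original analysis): ``mask`` holds 500 * 500 * 500 elements, so the CPU copy
# needs 8 bytes per element as ``torch.double`` but only 4 bytes per element as
# ``torch.float``. This matches the ~953.67 Mb reported above and the ~476.84 Mb
# we expect after the cast:

n_elements = 500 * 500 * 500
print(n_elements * 8 / 2 ** 20)  # ~953.67 MiB as torch.double
print(n_elements * 4 / 2 ** 20)  # ~476.84 MiB as torch.float
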
model = MyModule(500, 10).cuda()
input = torch.rand(128, 500).cuda()
mask = torch.rand((500, 500, 500), dtype=torch.float).cuda()

# warm-up
model(input, mask)

with profiler.profile(with_stack=True, profile_memory=True) as prof:
    out, idx = model(input, mask)

print(prof.key_averages(group_by_stack_n=5).table(sort_by='self_cpu_time_total', row_limit=5))

"""
189
(Some columns are omitted)
190
191
----------------- ------------ ------------ ------------ --------------------------------
192
Name Self CPU % Self CPU Self CPU Mem Source Location
193
----------------- ------------ ------------ ------------ --------------------------------
194
MASK INDICES 93.61% 5.006s -476.84 Mb /mnt/xarfuse/.../torch/au
195
<ipython-input-...>(10): forward
196
/mnt/xarfuse/ /torch/nn
197
<ipython-input-...>(9): <module>
198
/mnt/xarfuse/.../IPython/
199
200
aten::copy_ 6.34% 338.759ms 0 b <ipython-input-...>(12): forward
201
/mnt/xarfuse/.../torch/nn
202
<ipython-input-...>(9): <module>
203
/mnt/xarfuse/.../IPython/
204
/mnt/xarfuse/.../IPython/
205
206
aten::as_strided 0.01% 281.808us 0 b <ipython-input-...>(11): forward
207
/mnt/xarfuse/.../torch/nn
208
<ipython-input-...>(9): <module>
209
/mnt/xarfuse/.../IPython/
210
/mnt/xarfuse/.../IPython/
211
212
aten::addmm 0.01% 275.721us 0 b /mnt/xarfuse/.../torch/nn
213
/mnt/xarfuse/.../torch/nn
214
/mnt/xarfuse/.../torch/nn
215
<ipython-input-...>(8): forward
216
/mnt/xarfuse/.../torch/nn
217
218
aten::_local 0.01% 268.650us 0 b <ipython-input-...>(11): forward
219
_scalar_dense /mnt/xarfuse/.../torch/nn
220
<ipython-input-...>(9): <module>
221
/mnt/xarfuse/.../IPython/
222
/mnt/xarfuse/.../IPython/
223
224
----------------- ------------ ------------ ------------ --------------------------------
225
Self CPU time total: 5.347s
226
227
"""
228
229
######################################################################
#
# The CPU memory footprint for this operation has halved.
#
# Improve time performance
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# While the time consumed has also reduced a bit, it's still too high.
# It turns out that copying a matrix from CUDA to CPU is quite expensive!
# The ``aten::copy_`` operator in ``forward (12)`` copies ``mask`` to the CPU
# so that it can use the NumPy ``argwhere`` function, and the ``aten::copy_`` at ``forward (13)``
# copies the array back to CUDA as a tensor. We can eliminate both of these copies by using
# the ``torch`` function ``nonzero()`` here instead, as sketched in the short comparison below.
#

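######################################################################
# As a minimal illustration of that replacement (an added aside, using a tiny
# made-up tensor): ``(t > threshold).nonzero()`` produces the same index pairs
# as ``np.argwhere(t.cpu().numpy() > threshold)``, but without ever leaving the GPU.

small = torch.tensor([[0.1, 0.9], [0.8, 0.2]]).cuda()
print(np.argwhere(small.cpu().numpy() > 0.5))  # round-trips through the CPU: [[0 1], [1 0]]
print((small > 0.5).nonzero())                 # stays on CUDA: tensor([[0, 1], [1, 0]], device='cuda:0')
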
class MyModule(nn.Module):
    def __init__(self, in_features: int, out_features: int, bias: bool = True):
        super(MyModule, self).__init__()
        self.linear = nn.Linear(in_features, out_features, bias)

    def forward(self, input, mask):
        with profiler.record_function("LINEAR PASS"):
            out = self.linear(input)

        with profiler.record_function("MASK INDICES"):
            threshold = out.sum(axis=1).mean()
            hi_idx = (mask > threshold).nonzero(as_tuple=True)

        return out, hi_idx


model = MyModule(500, 10).cuda()
input = torch.rand(128, 500).cuda()
mask = torch.rand((500, 500, 500), dtype=torch.float).cuda()

# warm-up
model(input, mask)

with profiler.profile(with_stack=True, profile_memory=True) as prof:
    out, idx = model(input, mask)

print(prof.key_averages(group_by_stack_n=5).table(sort_by='self_cpu_time_total', row_limit=5))

"""
272
(Some columns are omitted)
273
274
-------------- ------------ ------------ ------------ ---------------------------------
275
Name Self CPU % Self CPU Self CPU Mem Source Location
276
-------------- ------------ ------------ ------------ ---------------------------------
277
aten::gt 57.17% 129.089ms 0 b <ipython-input-...>(12): forward
278
/mnt/xarfuse/.../torch/nn
279
<ipython-input-...>(25): <module>
280
/mnt/xarfuse/.../IPython/
281
/mnt/xarfuse/.../IPython/
282
283
aten::nonzero 37.38% 84.402ms 0 b <ipython-input-...>(12): forward
284
/mnt/xarfuse/.../torch/nn
285
<ipython-input-...>(25): <module>
286
/mnt/xarfuse/.../IPython/
287
/mnt/xarfuse/.../IPython/
288
289
INDEX SCORE 3.32% 7.491ms -119.21 Mb /mnt/xarfuse/.../torch/au
290
<ipython-input-...>(10): forward
291
/mnt/xarfuse/.../torch/nn
292
<ipython-input-...>(25): <module>
293
/mnt/xarfuse/.../IPython/
294
295
aten::as_strided 0.20% 441.587us 0 b <ipython-input-...>(12): forward
296
/mnt/xarfuse/.../torch/nn
297
<ipython-input-...>(25): <module>
298
/mnt/xarfuse/.../IPython/
299
/mnt/xarfuse/.../IPython/
300
301
aten::nonzero
302
_numpy 0.18% 395.602us 0 b <ipython-input-...>(12): forward
303
/mnt/xarfuse/.../torch/nn
304
<ipython-input-...>(25): <module>
305
/mnt/xarfuse/.../IPython/
306
/mnt/xarfuse/.../IPython/
307
-------------- ------------ ------------ ------------ ---------------------------------
308
Self CPU time total: 225.801ms
309
310
"""
311
312
313
######################################################################
# Further Reading
# ~~~~~~~~~~~~~~~~~
# We have seen how Profiler can be used to investigate time and memory bottlenecks in PyTorch models.
# Read more about Profiler here:
#
# - `Profiler Usage Recipe <https://pytorch.org/tutorials/recipes/recipes/profiler.html>`__
# - `Profiling RPC-Based Workloads <https://pytorch.org/tutorials/recipes/distributed_rpc_profiling.html>`__
# - `Profiler API Docs <https://pytorch.org/docs/stable/autograd.html?highlight=profiler#profiler>`__