CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutSign UpSign In
pytorch

CoCalc provides the best real-time collaborative environment for Jupyter Notebooks, LaTeX documents, and SageMath, scalable from individual users to large groups and classes!

GitHub Repository: pytorch/tutorials
Path: blob/main/intermediate_source/inductor_debug_cpu.py
Views: 494
1
# -*- coding: utf-8 -*-
2
3
"""
4
Inductor CPU backend debugging and profiling
5
============================================
6
7
**Authors**: `Xuan Liao <https://github.com/Valentine233>`_, `Haozhe Zhu <https://github.com/zhuhaozhe>`_, `Jiong Gong <https://github.com/jgong5>`_, `Weihan Wang <https://github.com/EikanWang>`_
8
"""
9
10
#########################################################################
11
# Overview
12
# --------
13
#
14
# PyTorch 2.0 introduced the compilation API called ``torch.compile``.
15
# This new feature offers a significant speedup over eager mode execution through graph-level optimization powered by the default Inductor backend.
16
#
17
# This tutorial is intended to provide an in-depth introduction on the debugging
18
# and performance profiling on Inductor CPU backend by delving into the intricacies of ``torch.compile``.
19
#
20
# Meanwhile, you may also find related tutorials about ``torch.compile``
21
# around `basic usage <https://pytorch.org/tutorials/intermediate/torch_compile_tutorial.html>`_,
22
# comprehensive `troubleshooting <https://pytorch.org/docs/stable/dynamo/troubleshooting.html>`_
23
# and GPU-specific knowledge like `GPU performance profiling <https://github.com/pytorch/pytorch/blob/main/docs/source/compile/profiling_torch_compile.rst>`_.
24
#
25
# We will start debugging with a motivating example that triggers compilation issues and accuracy problems
26
# by demonstrating the process of debugging to pinpoint the problems.
27
#
28
# By enabling logging and exploring the underlying generated code,
29
# you can learn how to narrow down the failure step by step and finally figure out the route cause.
30
#
31
# Following that, we will proceed to discuss how to profile the compiled code and,
32
# through a performance comparison with eager mode,
33
# elaborate on the reasons why ``torch.compile`` can provide an additional performance boost compared to its eager counterpart.
34
35
36
######################################################################
37
# Debugging
38
# ---------
39
#
40
# Here is a simple example to run the ``torch.compile`` using Inductor and compare its result with eager mode:
41
42
import torch
43
44
def foo1(x1, x2):
45
a = torch.neg(x1)
46
b = torch.maximum(x2, a)
47
y = torch.cat([b], dim=0)
48
return y
49
50
x1 = torch.randint(256, (1, 8), dtype=torch.uint8)
51
x2 = torch.randint(256, (8390, 8), dtype=torch.uint8)
52
53
compiled_foo1 = torch.compile(foo1)
54
result = compiled_foo1(x1, x2)
55
56
######################################################################
57
# The correct implementation of ``neg`` in the ``cpp`` codegen is as follows:
58
59
def neg1(x):
60
return f"decltype({x})(-{x})"
61
62
######################################################################
63
# In order to demonstrate the debugging, we will modify the function to a wrong one later.
64
#
65
#
66
# Get more logging information
67
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
68
#
69
# No debugging information would be provided if you run this simple example by default. In order to get more useful debugging and logging information, we usually add a ``TORCH_COMPILE_DEBUG`` environment variable like below:
70
#
71
# .. code-block:: shell
72
#
73
# TORCH_COMPILE_DEBUG=1 python xx.py
74
#
75
# This would print more debug information in the output logs and also dump the intermediate IRs generated during the codegen process. You can find the dumped file paths in the log like below:
76
#
77
# .. code-block:: shell
78
#
79
# torch._inductor.debug: [WARNING] model___20 debug trace: /tmp/torchinductor_root/rx/crxfi2ybd7yp5sbj2pnhw33wfhtdw7wumvrobyp5sjvdui5ktjc2.debug
80
#
81
# In this directory, the following files are saved for debugging purposes:
82
#
83
# +-----------------------------+----------------------------------------------------------------+
84
# | File | Description |
85
# +=============================+================================================================+
86
# | ``fx_graph_runnable.py`` | Executable FX graph, after decomposition, before pattern match |
87
# +-----------------------------+----------------------------------------------------------------+
88
# | ``fx_graph_transformed.py`` | Transformed FX graph, after pattern match |
89
# +-----------------------------+----------------------------------------------------------------+
90
# | ``ir_pre_fusion.txt`` | Inductor IR before fusion |
91
# +-----------------------------+----------------------------------------------------------------+
92
# | ``ir_post_fusion.txt`` | Inductor IR after fusion |
93
# +-----------------------------+----------------------------------------------------------------+
94
# | ``output_code.py`` | Generated Python code for graph, with C++/Triton kernels |
95
# +-----------------------------+----------------------------------------------------------------+
96
#
97
# Note that ``fx_graph_runnable.py`` and ``output_code.py`` are both runnable and editable in order to make debugging easier.
98
# Here are the main parts of code extracted from the files and we correlate the C++ generated line with the FX code line.
99
#
100
# ``fx_graph_runnable``:
101
#
102
103
def forward1(self, arg0_1, arg1_1):
104
neg = torch.ops.aten.neg.default(arg0_1); arg0_1 = None
105
maximum = torch.ops.aten.maximum.default(arg1_1, neg); arg1_1 = neg = None
106
clone = torch.ops.aten.clone.default(maximum); maximum = None
107
return (clone,)
108
109
######################################################################
110
# C++ kernel in ``output_code``:
111
#
112
113
import torch
114
from torch._inductor.async_compile import AsyncCompile
115
async_compile = AsyncCompile()
116
117
cpp_fused_cat_maximum_neg_0 = async_compile.cpp('''
118
#include "/tmp/torchinductor_root/gv/cgv6n5aotqjo5w4vknjibhengeycuattfto532hkxpozszcgxr3x.h"
119
extern "C" void kernel(const unsigned char* in_ptr0,
120
const unsigned char* in_ptr1,
121
unsigned char* out_ptr0)
122
{
123
{
124
#pragma GCC ivdep
125
for(long i0=static_cast<long>(0L); i0<static_cast<long>(8390L); i0+=static_cast<long>(1L))
126
{
127
#pragma GCC ivdep
128
for(long i1=static_cast<long>(0L); i1<static_cast<long>(8L); i1+=static_cast<long>(1L))
129
{
130
auto tmp0 = in_ptr0[static_cast<long>(i1 + (8L*i0))];
131
auto tmp1 = in_ptr1[static_cast<long>(i1)];
132
// Corresponding FX code line: neg = torch.ops.aten.neg.default(arg0_1); arg0_1 = None
133
auto tmp2 = decltype(tmp1)(-tmp1);
134
// Corresponding FX code line: maximum = torch.ops.aten.maximum.default(arg1_1, neg); arg1_1 = neg = None
135
auto tmp3 = max_propagate_nan(tmp0, tmp2);
136
// Corresponding FX code line: clone = torch.ops.aten.clone.default(maximum); maximum = None
137
out_ptr0[static_cast<long>(i1 + (8L*i0))] = tmp3;
138
}
139
}
140
}
141
}''')
142
143
144
######################################################################
145
# Determine component of error
146
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
147
#
148
# When encountering errors or accuracy problems, a straightforward solution to find the bug is to narrow down the problem. The first thing to do is to determine the component where the error occurs. Luckily, it can be simply achieved by changing the backend of ``torch.compile``.
149
#
150
# +--------------------------------------------+-----------------------------------------+
151
# | Code | Description |
152
# +============================================+=========================================+
153
# | ``torch.compile(fn, backend="eager")`` | Enable Dynamo |
154
# +--------------------------------------------+-----------------------------------------+
155
# | ``torch.compile(fn, backend="aot_eager")`` | Enable Dynamo + AOT Autograd |
156
# +--------------------------------------------+-----------------------------------------+
157
# | ``torch.compile(fn, backend="inductor")`` | Enable Dynamo + AOT Autograd + Inductor |
158
# +--------------------------------------------+-----------------------------------------+
159
#
160
# If the model can successfully run when the backend is set to ``eager`` or ``aot_eager`` while it fails with ``inductor``, we can narrow down the failure to Inductor.
161
#
162
#
163
# Compilation error
164
# ^^^^^^^^^^^^^^^^^
165
#
166
# As we know, the evolved chain of graph-level optimization is like:
167
#
168
# .. code-block:: sh
169
#
170
# torch.neg (Python) -> torch.ops.aten.neg.default (within FX graph) -> ops.neg (within IR node) -> tmp2 = -tmp1 (within C++ kernel)
171
#
172
# If you encounter a compilation error, there is something wrong when compiling C++ kernels in the output code.
173
# This type of error indicates that bugs are introduced when lowering IR nodes to output code.
174
# The root cause of compilation error is usually shown in the traceback log.
175
#
176
# For example, the ``neg`` function is modified like this:
177
178
def neg2(x):
179
return f"-{x}"
180
181
######################################################################
182
# The logging gives the following compile error with a rather clear reason.
183
#
184
# .. code-block::
185
#
186
# torch._dynamo.exc.BackendCompilerFailed: backend='inductor' raised:
187
# CppCompileError: C++ compile error
188
# /tmp/torchinductor_root/xg/cxga5tk3b4lkwoxyigrtocjp5s7vc5cg2ikuscf6bk6pjqip2bhx.cpp: In function ‘void kernel(const unsigned char*, const unsigned char*, unsigned char*)’:
189
# /tmp/torchinductor_root/xg/cxga5tk3b4lkwoxyigrtocjp5s7vc5cg2ikuscf6bk6pjqip2bhx.cpp:17:57: error: no matching function for call to ‘max_propagate_nan(unsigned char&, int&)’
190
# 17 | auto tmp3 = max_propagate_nan(tmp0, tmp2);
191
# | ^
192
# In file included from /tmp/torchinductor_root/xg/cxga5tk3b4lkwoxyigrtocjp5s7vc5cg2ikuscf6bk6pjqip2bhx.cpp:2:
193
# /tmp/torchinductor_root/gv/cgv6n5aotqjo5w4vknjibhengeycuattfto532hkxpozszcgxr3x.h:27:17: note: candidate: ‘template<class scalar_t> scalar_t max_propagate_nan(scalar_t, scalar_t)’
194
# 27 | inline scalar_t max_propagate_nan(scalar_t a, scalar_t b) {
195
# | ^~~~~~~~~~~~~~~~~
196
# /tmp/torchinductor_root/gv/cgv6n5aotqjo5w4vknjibhengeycuattfto532hkxpozszcgxr3x.h:27:17: note: template argument deduction/substitution failed:
197
# /tmp/torchinductor_root/xg/cxga5tk3b4lkwoxyigrtocjp5s7vc5cg2ikuscf6bk6pjqip2bhx.cpp:17:57: note: deduced conflicting types for parameter ‘scalar_t’ (‘unsigned char’ and ‘int’)
198
# 17 | auto tmp3 = max_propagate_nan(tmp0, tmp2);
199
# | ^
200
#
201
#
202
# Let us also see the corresponding C++ kernel in output code and IR node.
203
#
204
# C++ kernel:
205
#
206
# .. code:: c
207
#
208
# include "/tmp/torchinductor_root/gv/cgv6n5aotqjo5w4vknjibhengeycuattfto532hkxpozszcgxr3x.h"
209
# extern "C" void kernel(const unsigned char* in_ptr0,
210
# const unsigned char* in_ptr1,
211
# unsigned char* out_ptr0)
212
# {
213
# {
214
# #pragma GCC ivdep
215
# for(long i0=static_cast<long>(0L); i0<static_cast<long>(8390L); i0+=static_cast<long>(1L))
216
# {
217
# #pragma GCC ivdep
218
# for(long i1=static_cast<long>(0L); i1<static_cast<long>(8L); i1+=static_cast<long>(1L))
219
# {
220
# auto tmp0 = in_ptr0[static_cast<long>(i1 + (8L*i0))];
221
# auto tmp1 = in_ptr1[static_cast<long>(i1)];
222
# auto tmp2 = -tmp1;
223
# auto tmp3 = max_propagate_nan(tmp0, tmp2);
224
# out_ptr0[static_cast<long>(i1 + (8L*i0))] = tmp3;
225
# }
226
# }
227
# }
228
# }
229
#
230
231
######################################################################
232
# IR node:
233
#
234
# .. code-block:: sh
235
#
236
# buf0: SchedulerNode(ComputedBuffer)
237
# buf0.writes = [MemoryDep('buf0', c0, {c0: 67120})]
238
# buf0.unmet_dependencies = []
239
# buf0.met_dependencies =
240
# [ MemoryDep('arg0_1', c1, {c0: 8390, c1: 8}),
241
# MemoryDep('arg1_1', c0, {c0: 67120})]
242
# buf0.users = [NodeUser(node=OUTPUT, can_inplace=False)]
243
# buf0.group.device = cpu
244
# buf0.group.iteration = ((8390, 8), ())
245
# buf0.sizes = ([8390, 8], [])
246
# class buf0_loop_body:
247
# var_ranges = {z0: 8390, z1: 8}
248
# index0 = 8*z0 + z1
249
# index1 = z1
250
# def body(self, ops):
251
# get_index = self.get_index('index0')
252
# load = ops.load('arg1_1', get_index)
253
# get_index_1 = self.get_index('index1')
254
# load_1 = ops.load('arg0_1', get_index_1)
255
# neg = ops.neg(load_1)
256
# maximum = ops.maximum(load, neg)
257
# get_index_2 = self.get_index('index0')
258
# store = ops.store('buf0', get_index_2, maximum, None)
259
# return store
260
#
261
262
######################################################################
263
# According to the traceback logging, the compilation error is caused by the data type inconsistency of ``max_propagate_nan``'s inputs.
264
# By checking the C++ kernel, we know that ``tmp2`` is no longer ``long`` after doing ``-`` as ``tmp0`` is ``long``.
265
# We can easily match ``-`` and ``max_propagate_nan`` in C++ kernel with ``ops.neg`` and ``ops.maximum`` in IR node respectively.
266
#
267
# Now we successfully find that the root cause is the implementation of ``ops.neg`` in ``cpp`` codegen, which silently changes the data type when doing ``neg``.
268
#
269
#
270
# Accuracy debugging
271
# ^^^^^^^^^^^^^^^^^^^
272
#
273
# Otherwise, if the model runs with other errors or accuracy problem, you can use the PyTorch debugging tool called `Minifier <https://pytorch.org/functorch/stable/notebooks/minifier.html>`_.
274
#
275
# The core idea of ``Minifier`` is to keep removing the nodes and inputs of graph until finding the minimal graph with problem.
276
# It helps to automatically generate a minified problematic graph through 4 strategies: truncating suffix, delta debugging, eliminating dead code and removing unused inputs.
277
#
278
#
279
# We will now show the debugging process for the accuracy problem with the help of ``Minifer``.
280
# The accuracy problem refers to the case where the outputs of backends eager and inductor are different.
281
#
282
# For instance, we modify the example like this:
283
284
from torch._dynamo.utils import same
285
286
def foo2(x1, x2):
287
a = torch.neg(x1)
288
b = torch.maximum(x2, a)
289
y = torch.cat([b], dim=0)
290
return y
291
292
x1 = torch.randn((1, 8), dtype=torch.float32)
293
x2 = torch.randn((8390, 8), dtype=torch.float32)
294
295
expected_result = foo2(x1, x2)
296
297
compiled_foo2 = torch.compile(foo2)
298
actual_result = compiled_foo2(x1, x2)
299
300
assert same(expected_result, actual_result) == True
301
302
######################################################################
303
# And also modify the ``neg`` function:
304
305
def neg3(x):
306
return f"decltype({x})(2 * {x})"
307
308
######################################################################
309
# An accuracy problem would be raised as follows:
310
#
311
# .. code-block:: sh
312
#
313
# torch._dynamo.utils: [ERROR] Accuracy failed: allclose not within tol=0.0001
314
# Traceback (most recent call last):
315
# File "test_script.py", line 18, in <module>
316
# assert same(expected_result, actual_result) == True
317
# AssertionError
318
#
319
# To debug an accuracy problem with Minifier, two environment variables are needed:
320
#
321
# .. code-block:: sh
322
#
323
# TORCHDYNAMO_REPRO_AFTER="aot" TORCHDYNAMO_REPRO_LEVEL=4 python xx.py
324
#
325
# Which gives us logging information that demonstrates the steps of minifying:
326
#
327
# .. code-block:: sh
328
#
329
# Started off with 6 nodes
330
#
331
# Trying granularity 2
332
# Strategy: Truncate suffix (G: 2) (6 nodes, 2 inputs)
333
# SUCCESS: Went from 6 to 4 nodes
334
#
335
# Trying granularity 4
336
# Strategy: Remove unused inputs (G: 4) (4 nodes, 2 inputs)
337
# SUCCESS: Went from 4 to 3 nodes
338
#
339
# After running, we get the final minified graph with the target node ``neg``:
340
341
def forward2(self, arg0_1):
342
neg = torch.ops.aten.neg.default(arg0_1); arg0_1 = None
343
return (neg,)
344
345
######################################################################
346
# For more usage details about Minifier, please refer to `Troubleshooting <https://pytorch.org/docs/stable/dynamo/troubleshooting.html>`_.
347
348
349
######################################################################
350
# Performance profiling
351
# ---------------------
352
#
353
# Within this section, we will demonstrate the process of conducting performance analysis for a model that has been compiled using the Inductor CPU backend.
354
# In the example below, we benchmark a Hugging Face Transformer model ``MobileBertForQuestionAnswering`` with both the eager mode and the Inductor graph mode.
355
# The execution time and the speedup ratio of Inductor are printed after the benchmark.
356
# We use Intel(R) Xeon(R) Platinum 8358 CPU @ 2.60GHz and run benchmark on the first socket to demonstrate the optimization within this section.
357
# We set following environment variable as a best practice to benchmark on Intel(R) CPU.
358
359
#########################################################
360
# .. code-block:: shell
361
#
362
# export KMP_BLOCKTIME=1
363
# export KMP_SETTINGS=1
364
# export KMP_AFFINITY=granularity=fine,compact,1,0
365
# export LD_PRELOAD=${CONDA_PREFIX:-"$(dirname $(which conda))/../"}/lib/libiomp5.so:${CONDA_PREFIX:-"$(dirname $(which conda))/../"}/lib/libjemalloc.so
366
# export MALLOC_CONF="oversize_threshold:1,background_thread:true,metadata_thp:auto,dirty_decay_ms:-1,muzzy_decay_ms:-1"
367
# numactl -C 0-31 -m 0 python bench.py
368
#
369
370
# bench.py
371
from transformers import MobileBertForQuestionAnswering
372
# Initialize an eager model
373
model = MobileBertForQuestionAnswering.from_pretrained("csarron/mobilebert-uncased-squad-v2")
374
seq_length = 128
375
bs = 128
376
vocab_size = model.config.vocab_size
377
input = torch.randint(0, vocab_size, (bs, seq_length), dtype=torch.int64)
378
input_dict = {"input_ids": input}
379
380
# Initialize the inductor model
381
compiled_model = torch.compile(model)
382
with torch.no_grad():
383
compiled_model(**input_dict)
384
385
NUM_ITERS=50
386
import timeit
387
with torch.no_grad():
388
# warmup
389
for _ in range(10):
390
model(**input_dict)
391
eager_t = timeit.timeit("model(**input_dict)", number=NUM_ITERS, globals=globals())
392
393
with torch.no_grad():
394
# warmup
395
for _ in range(10):
396
compiled_model(**input_dict)
397
inductor_t = timeit.timeit("compiled_model(**input_dict)", number=NUM_ITERS, globals=globals())
398
# print(f"eager use: {eager_t * 1000 / NUM_ITERS} ms/iter")
399
# print(f"inductor use: {inductor_t * 1000 / NUM_ITERS} ms/iter")
400
# print(f"speed up ratio: {eager_t / inductor_t}")
401
402
403
######################################################################
404
# Output:
405
#
406
# .. code-block:: shell
407
#
408
# eager use: 802.1023553796113 ms/iter
409
# inductor use: 339.95180135127157 ms/iter
410
# speed up ratio: 2.359459053287382
411
#
412
# In our own testing, we find the Inductor CPU backend speed up the model by around 2.355x.
413
#
414
#
415
# Next, let's dive deep into the performance at the operation level to understand where the speed-up comes from.
416
# `Pytorch Profiler <https://pytorch.org/tutorials/recipes/recipes/profiler_recipe.html>`_ is a good tool to help us.
417
# Inductor CPU backend has the support to report the time of the fusion kernels to the profiler with the ``enable_kernel_profile`` configuration option:
418
419
from torch._inductor import config
420
config.cpp.enable_kernel_profile = True
421
422
######################################################################
423
# Following the steps in `Pytorch Profiler <https://pytorch.org/tutorials/recipes/recipes/profiler_recipe.html>`_
424
# We are able to get the profiling table and trace files.
425
426
# bench.py
427
from torch.profiler import profile, schedule, ProfilerActivity
428
RESULT_DIR = "./prof_trace"
429
my_schedule = schedule(
430
skip_first=10,
431
wait=5,
432
warmup=5,
433
active=1,
434
repeat=5)
435
436
def trace_handler(p):
437
output = p.key_averages().table(sort_by="self_cpu_time_total", row_limit=20)
438
# print(output)
439
p.export_chrome_trace(f"{RESULT_DIR}/{p.step_num}.json")
440
441
for _ in range(10):
442
model(**input_dict) # compiled_model(**input_dict) to get inductor model profiling
443
444
total = 0
445
with profile(
446
activities=[ProfilerActivity.CPU],
447
schedule=my_schedule,
448
on_trace_ready=trace_handler
449
) as p:
450
for _ in range(50):
451
model(**input_dict) # compiled_model(**input_dict) to get inductor model profiling
452
p.step()
453
454
######################################################################
455
# We get the following performance profiling table for the eager-mode model (omitting some columns):
456
#
457
# .. code-block:: shell
458
#
459
# ------------------------- ------------ ------------ ------------
460
# Name CPU total % CPU total # of Calls
461
# ------------------------- ------------ ------------ ------------
462
# aten::addmm 45.73% 370.814ms 362
463
# aten::add 19.89% 161.276ms 363
464
# aten::copy_ 14.97% 121.416ms 488
465
# aten::mul 9.02% 73.154ms 194
466
# aten::clamp_min 8.81% 71.444ms 96
467
# aten::bmm 5.46% 44.258ms 48
468
# ProfilerStep* 100.00% 810.920ms 1
469
# aten::div 2.89% 23.447ms 24
470
# aten::_softmax 1.00% 8.087ms 24
471
# aten::linear 46.48% 376.888ms 362
472
# aten::clone 2.77% 22.430ms 98
473
# aten::t 0.31% 2.502ms 362
474
# aten::view 0.14% 1.161ms 850
475
# aten::transpose 0.17% 1.377ms 386
476
# aten::index_select 0.12% 952.000us 3
477
# aten::expand 0.12% 986.000us 458
478
# aten::matmul 8.31% 67.420ms 48
479
# aten::cat 0.09% 703.000us 1
480
# aten::as_strided 0.08% 656.000us 963
481
# aten::relu 8.86% 71.864ms 96
482
# ------------------------- ------------ ------------ ------------
483
# Self CPU time total: 810.920ms
484
#
485
486
######################################################################
487
#
488
# Similarly, we also get the table for the compiled model with Inductor (omitting some columns):
489
#
490
# .. code-block:: shell
491
#
492
# ----------------------------------------------- ------------ ------------ ------------
493
# Name CPU total % CPU total # of Calls
494
# ----------------------------------------------- ------------ ------------ ------------
495
# mkl::_mkl_linear 68.79% 231.573ms 362
496
# aten::bmm 8.02% 26.992ms 48
497
# ProfilerStep* 100.00% 336.642ms 1
498
# graph_0_cpp_fused_constant_pad_nd_embedding_0 0.27% 915.000us 1
499
# aten::empty 0.27% 911.000us 362
500
# graph_0_cpp_fused__mkl_linear_add_mul_relu_151 0.27% 901.000us 1
501
# graph_0_cpp_fused__mkl_linear_add_mul_relu_226 0.27% 899.000us 1
502
# graph_0_cpp_fused__mkl_linear_add_mul_relu_361 0.27% 898.000us 1
503
# graph_0_cpp_fused__mkl_linear_add_mul_relu_121 0.27% 895.000us 1
504
# graph_0_cpp_fused__mkl_linear_add_mul_relu_31 0.27% 893.000us 1
505
# graph_0_cpp_fused__mkl_linear_add_mul_relu_76 0.26% 892.000us 1
506
# graph_0_cpp_fused__mkl_linear_add_mul_relu_256 0.26% 892.000us 1
507
# graph_0_cpp_fused__mkl_linear_add_mul_relu_346 0.26% 892.000us 1
508
# graph_0_cpp_fused__mkl_linear_add_mul_relu_241 0.26% 891.000us 1
509
# graph_0_cpp_fused__mkl_linear_add_mul_relu_316 0.26% 891.000us 1
510
# graph_0_cpp_fused__mkl_linear_add_mul_relu_91 0.26% 890.000us 1
511
# graph_0_cpp_fused__mkl_linear_add_mul_relu_106 0.26% 890.000us 1
512
# graph_0_cpp_fused__mkl_linear_add_mul_relu_211 0.26% 890.000us 1
513
# graph_0_cpp_fused__mkl_linear_add_mul_relu_61 0.26% 889.000us 1
514
# graph_0_cpp_fused__mkl_linear_add_mul_relu_286 0.26% 889.000us 1
515
# ----------------------------------------------- ------------ ------------ ------------
516
# Self CPU time total: 336.642ms
517
#
518
# From the profiling table of the eager model, we can see the most time consumption ops are [``aten::addmm``, ``aten::add``, ``aten::copy_``, ``aten::mul``, ``aten::clamp_min``, ``aten::bmm``].
519
# Comparing with the inductor model profiling table, we notice an ``mkl::_mkl_linear`` entry and multiple fused kernels in the form ``graph_0_cpp_fused_*``. They are the major
520
# optimizations that the inductor model is doing. Let us discuss them separately.
521
#
522
# (1) Regarding ``mkl::_mkl_linear``: You may notice the number of calls to this kernel is 362, which is exactly the same as ``aten::linear`` in the eager model profiling table.
523
# The CPU total of ``aten::linear`` is 376.888ms, while it is 231.573ms for ``mkl::_mkl_linear``. This suggests a ~1.63x for the "linear" part.
524
# The speedup mainly comes from `packing the weight tensor to block memory format <https://www.intel.com/content/www/us/en/docs/onemkl/developer-reference-c/2023-1/cblas-gemm-pack-002.html>`_
525
# and invoking `cblas_sgemm_compute <https://www.intel.com/content/www/us/en/docs/onemkl/developer-reference-c/2023-1/cblas-gemm-compute-002.html>`_ within the Inductor CPU backend
526
# to have a better cache behavior during GEMM computation.
527
#
528
# (2) Regarding other memory-intensive ops: The end-to-end latency for the eager/inductor model is 802/339ms in our testing. So we can roughly infer that the speed up for the other memory-intensive ops is around 3.94x.
529
# Let's read the generated code to understand how the inductor achieves this impressive optimization. You can find the generated code by
530
# searching ``cpp_fused__mkl_linear_add_mul_relu_151`` in ``output_code.py``
531
#
532
533
534
cpp_fused__mkl_linear_add_mul_relu_151 = async_compile.cpp('''
535
#include <ATen/record_function.h>
536
#include "/tmp/torchinductor_root/lr/clrlgu27q4ggd472umdzwsu6qcpqxcuusjxqvx2hwitjbujiiz7z.h"
537
extern "C" void kernel(float* in_out_ptr0,
538
const float* in_ptr0,
539
const float* in_ptr1,
540
const float* in_ptr2,
541
const float* in_ptr3)
542
{
543
RECORD_FUNCTION("graph_0_cpp_fused__mkl_linear_add_mul_relu_151", c10::ArrayRef<c10::IValue>({}));
544
#pragma omp parallel num_threads(32)
545
{
546
{
547
#pragma omp for
548
for(long i0=static_cast<long>(0L); i0<static_cast<long>(16384L); i0+=static_cast<long>(1L))
549
{
550
for(long i1=static_cast<long>(0L); i1<static_cast<long>(512L); i1+=static_cast<long>(8L))
551
{
552
auto tmp0 = at::vec::Vectorized<float>::loadu(in_ptr0 + static_cast<long>(i1 + (512L*i0)));
553
auto tmp1 = at::vec::Vectorized<float>::loadu(in_ptr1 + static_cast<long>(i1));
554
auto tmp3 = at::vec::Vectorized<float>::loadu(in_out_ptr0 + static_cast<long>(i1 + (512L*i0)));
555
auto tmp5 = at::vec::Vectorized<float>::loadu(in_ptr2 + static_cast<long>(i1));
556
auto tmp7 = at::vec::Vectorized<float>::loadu(in_ptr3 + static_cast<long>(i1));
557
auto tmp2 = tmp0 + tmp1;
558
auto tmp4 = tmp2 + tmp3;
559
auto tmp6 = tmp4 * tmp5;
560
auto tmp8 = tmp6 + tmp7;
561
tmp8.store(in_out_ptr0 + static_cast<long>(i1 + (512L*i0)));
562
}
563
}
564
}
565
}
566
}''')
567
568
######################################################################
569
# From the generated code above, we can see this kernel has done a typical `Loop Fusion <https://en.wikipedia.org/wiki/Loop_fission_and_fusion>`_ on ``[add, add, mul, add]``.
570
# This is a memory-bound bottle neck preventing good performance. To get a more intuitive feeling about this optimization,
571
# we can infer the sizes and stride of the inputs and further benchmark this ``[add, add, mul, add]`` pattern.
572
573
# bench.py
574
def func(arg_0, arg_1, arg_2, arg_3, arg_4):
575
add_0 = arg_0 + arg_1
576
add_1 = add_0 + arg_2
577
mul_1 = add_1 * arg_3
578
add_2 = mul_1 + arg_4
579
arg_2 = add_2
580
return arg_2
581
582
arg_0 = torch.rand(16384, 512)
583
arg_1 = torch.rand(1, 512)
584
arg_2 = torch.zeros(16384, 512)
585
arg_3 = torch.rand(1, 512)
586
arg_4 = torch.rand(1, 512)
587
588
input = (arg_0, arg_1, arg_2, arg_3, arg_4)
589
inductor_func = torch.compile(func)
590
with torch.no_grad():
591
inductor_func(*input)
592
593
import timeit
594
NUM_ITERS=100
595
with torch.no_grad():
596
# warmup
597
for _ in range(10):
598
func(*input)
599
eager_t = timeit.timeit("func(*input)", number=NUM_ITERS, globals=globals())
600
601
with torch.no_grad():
602
# warmup
603
for _ in range(10):
604
inductor_func(*input)
605
inductor_t = timeit.timeit("inductor_func(*input)", number=NUM_ITERS, globals=globals())
606
# print(f"eager use: {eager_t * 1000 / NUM_ITERS} ms/iter")
607
# print(f"inductor use: {inductor_t * 1000 / NUM_ITERS} ms/iter")
608
# print(f"speed up ratio: {eager_t / inductor_t}")
609
610
######################################################################
611
# Output:
612
#
613
# .. code-block:: shell
614
#
615
# eager use: 5.780875144992024 ms/iter
616
# inductor use: 0.9588955780491233 ms/iter
617
# speed up ratio: 6.0286805751604735
618
#
619
#
620
# This is just an example. The profiling table shows all element-wise op are fused within the inductor automatically in this model. You can read more kernels in
621
# `output_code.py`
622
623
624
#########################################################################
625
# Conclusion
626
# ----------
627
#
628
# The document gives an in-depth tutorial for the Inductor CPU backend.
629
#
630
# With motivating examples, we walk through the process of debugging and profiling.
631
# The main idea is to narrow down the problem.
632
#
633
# We demonstrate step by step the way to delve deeper the issue and find the root cause of failures, with the help of debugging logging and the tool Minifier.
634
# Firstly determine which component the failure occurs in and then try to generate the smallest snippet of code that can reproduce the failure.
635
#
636
# When the performance with Inductor is better than that of eager mode, we provide a solid analytical method for performance profiling.
637
# We show how to find the time-consuming hotspot with PyTorch Profiler and figure out the operator-level or kernel-level reason to explain the phenomenon.
638
639