# -*- coding: utf-8 -*-

"""
Inductor CPU backend debugging and profiling
============================================

**Authors**: `Xuan Liao <https://github.com/Valentine233>`_, `Haozhe Zhu <https://github.com/zhuhaozhe>`_, `Jiong Gong <https://github.com/jgong5>`_, `Weihan Wang <https://github.com/EikanWang>`_
"""

#########################################################################
# Overview
# --------
#
# PyTorch 2.0 introduced the compilation API called ``torch.compile``.
# This new feature offers a significant speedup over eager mode execution through graph-level optimization powered by the default Inductor backend.
#
# This tutorial provides an in-depth introduction to debugging
# and performance profiling on the Inductor CPU backend by delving into the intricacies of ``torch.compile``.
#
# Meanwhile, you may also find related tutorials about ``torch.compile``
# around `basic usage <https://pytorch.org/tutorials/intermediate/torch_compile_tutorial.html>`_,
# comprehensive `troubleshooting <https://pytorch.org/docs/stable/dynamo/troubleshooting.html>`_
# and GPU-specific knowledge like `GPU performance profiling <https://github.com/pytorch/pytorch/blob/main/docs/source/compile/profiling_torch_compile.rst>`_.
#
# We will start with a motivating example that triggers compilation issues and accuracy problems,
# and use it to demonstrate how to pinpoint the problems.
#
# By enabling logging and exploring the underlying generated code,
# you can learn how to narrow down the failure step by step and finally figure out the root cause.
#
# Following that, we will proceed to discuss how to profile the compiled code and,
# through a performance comparison with eager mode,
# elaborate on the reasons why ``torch.compile`` can provide an additional performance boost compared to its eager counterpart.


######################################################################
# Debugging
# ---------
#
# Here is a simple example that runs ``torch.compile`` with Inductor so that its result can be compared with eager mode:

import torch

def foo1(x1, x2):
    a = torch.neg(x1)
    b = torch.maximum(x2, a)
    y = torch.cat([b], dim=0)
    return y

x1 = torch.randint(256, (1, 8), dtype=torch.uint8)
x2 = torch.randint(256, (8390, 8), dtype=torch.uint8)

compiled_foo1 = torch.compile(foo1)
result = compiled_foo1(x1, x2)

######################################################################
# The correct implementation of ``neg`` in the ``cpp`` codegen is as follows:

def neg1(x):
    return f"decltype({x})(-{x})"

######################################################################
# In order to demonstrate the debugging, we will later replace this function with an incorrect one.
#
#
# Get more logging information
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
#
# No debugging information is provided if you run this simple example by default. To get more useful debugging and logging information, we usually set the ``TORCH_COMPILE_DEBUG`` environment variable like below:
#
# .. code-block:: shell
#
#     TORCH_COMPILE_DEBUG=1 python xx.py
#
# This prints more debug information in the output logs and also dumps the intermediate IRs generated during the codegen process. You can find the dumped file paths in the log like below:
#
# .. code-block:: shell
#
#     torch._inductor.debug: [WARNING] model___20 debug trace: /tmp/torchinductor_root/rx/crxfi2ybd7yp5sbj2pnhw33wfhtdw7wumvrobyp5sjvdui5ktjc2.debug
#
# In this directory, the following files are saved for debugging purposes:
#
# +-----------------------------+----------------------------------------------------------------+
# | File                        | Description                                                    |
# +=============================+================================================================+
# | ``fx_graph_runnable.py``    | Executable FX graph, after decomposition, before pattern match |
# +-----------------------------+----------------------------------------------------------------+
# | ``fx_graph_transformed.py`` | Transformed FX graph, after pattern match                      |
# +-----------------------------+----------------------------------------------------------------+
# | ``ir_pre_fusion.txt``       | Inductor IR before fusion                                      |
# +-----------------------------+----------------------------------------------------------------+
# | ``ir_post_fusion.txt``      | Inductor IR after fusion                                       |
# +-----------------------------+----------------------------------------------------------------+
# | ``output_code.py``          | Generated Python code for graph, with C++/Triton kernels       |
# +-----------------------------+----------------------------------------------------------------+
#
# Note that ``fx_graph_runnable.py`` and ``output_code.py`` are both runnable and editable in order to make debugging easier.
# Here are the main parts of the code extracted from these files, with each generated C++ line correlated to its FX code line.
#
# ``fx_graph_runnable``:
#

def forward1(self, arg0_1, arg1_1):
    neg = torch.ops.aten.neg.default(arg0_1);  arg0_1 = None
    maximum = torch.ops.aten.maximum.default(arg1_1, neg);  arg1_1 = neg = None
    clone = torch.ops.aten.clone.default(maximum);  maximum = None
    return (clone,)

######################################################################
# C++ kernel in
# ``output_code``:
#

import torch
from torch._inductor.async_compile import AsyncCompile
async_compile = AsyncCompile()

cpp_fused_cat_maximum_neg_0 = async_compile.cpp('''
#include "/tmp/torchinductor_root/gv/cgv6n5aotqjo5w4vknjibhengeycuattfto532hkxpozszcgxr3x.h"
extern "C" void kernel(const unsigned char* in_ptr0,
                       const unsigned char* in_ptr1,
                       unsigned char* out_ptr0)
{
    {
        #pragma GCC ivdep
        for(long i0=static_cast<long>(0L); i0<static_cast<long>(8390L); i0+=static_cast<long>(1L))
        {
            #pragma GCC ivdep
            for(long i1=static_cast<long>(0L); i1<static_cast<long>(8L); i1+=static_cast<long>(1L))
            {
                auto tmp0 = in_ptr0[static_cast<long>(i1 + (8L*i0))];
                auto tmp1 = in_ptr1[static_cast<long>(i1)];
                // Corresponding FX code line: neg = torch.ops.aten.neg.default(arg0_1);  arg0_1 = None
                auto tmp2 = decltype(tmp1)(-tmp1);
                // Corresponding FX code line: maximum = torch.ops.aten.maximum.default(arg1_1, neg);  arg1_1 = neg = None
                auto tmp3 = max_propagate_nan(tmp0, tmp2);
                // Corresponding FX code line: clone = torch.ops.aten.clone.default(maximum);  maximum = None
                out_ptr0[static_cast<long>(i1 + (8L*i0))] = tmp3;
            }
        }
    }
}''')


######################################################################
# Determine component of error
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
#
# When encountering errors or accuracy problems, a straightforward solution to find the bug is to narrow down the problem. The first thing to do is to determine the component where the error occurs.
# Luckily, this can be achieved simply by changing the backend of ``torch.compile``.
#
# +--------------------------------------------+-----------------------------------------+
# | Code                                       | Description                             |
# +============================================+=========================================+
# | ``torch.compile(fn, backend="eager")``     | Enable Dynamo                           |
# +--------------------------------------------+-----------------------------------------+
# | ``torch.compile(fn, backend="aot_eager")`` | Enable Dynamo + AOT Autograd            |
# +--------------------------------------------+-----------------------------------------+
# | ``torch.compile(fn, backend="inductor")``  | Enable Dynamo + AOT Autograd + Inductor |
# +--------------------------------------------+-----------------------------------------+
#
# If the model runs successfully when the backend is set to ``eager`` or ``aot_eager`` but fails with ``inductor``, we can narrow down the failure to Inductor.
#
#
# Compilation error
# ^^^^^^^^^^^^^^^^^
#
# As we know, the chain of graph-level lowering looks like this:
#
# .. code-block:: sh
#
#     torch.neg (Python) -> torch.ops.aten.neg.default (within FX graph) -> ops.neg (within IR node) -> tmp2 = -tmp1 (within C++ kernel)
#
# If you encounter a compilation error, something went wrong when compiling the C++ kernels in the output code.
# This type of error indicates that bugs are introduced when lowering IR nodes to output code.
# The root cause of a compilation error is usually shown in the traceback log.
#
# For example, the ``neg`` function is modified like this:

def neg2(x):
    return f"-{x}"

######################################################################
# The logging gives the following compile error with a rather clear reason.
#
# .. code-block::
#
#     torch._dynamo.exc.BackendCompilerFailed: backend='inductor' raised:
#     CppCompileError: C++ compile error
#     /tmp/torchinductor_root/xg/cxga5tk3b4lkwoxyigrtocjp5s7vc5cg2ikuscf6bk6pjqip2bhx.cpp: In function ‘void kernel(const unsigned char*, const unsigned char*, unsigned char*)’:
#     /tmp/torchinductor_root/xg/cxga5tk3b4lkwoxyigrtocjp5s7vc5cg2ikuscf6bk6pjqip2bhx.cpp:17:57: error: no matching function for call to ‘max_propagate_nan(unsigned char&, int&)’
#        17 |                 auto tmp3 = max_propagate_nan(tmp0, tmp2);
#           |                                                         ^
#     In file included from /tmp/torchinductor_root/xg/cxga5tk3b4lkwoxyigrtocjp5s7vc5cg2ikuscf6bk6pjqip2bhx.cpp:2:
#     /tmp/torchinductor_root/gv/cgv6n5aotqjo5w4vknjibhengeycuattfto532hkxpozszcgxr3x.h:27:17: note: candidate: ‘template<class scalar_t> scalar_t max_propagate_nan(scalar_t, scalar_t)’
#        27 | inline scalar_t max_propagate_nan(scalar_t a, scalar_t b) {
#           |                 ^~~~~~~~~~~~~~~~~
#     /tmp/torchinductor_root/gv/cgv6n5aotqjo5w4vknjibhengeycuattfto532hkxpozszcgxr3x.h:27:17: note: template argument deduction/substitution failed:
#     /tmp/torchinductor_root/xg/cxga5tk3b4lkwoxyigrtocjp5s7vc5cg2ikuscf6bk6pjqip2bhx.cpp:17:57: note: deduced conflicting types for parameter ‘scalar_t’ (‘unsigned char’ and ‘int’)
#        17 |                 auto tmp3 = max_propagate_nan(tmp0, tmp2);
#           |                                                         ^
#
#
# Let us also see the corresponding C++ kernel in output code and IR node.
#
# C++ kernel:
#
# .. code:: c
#
#     #include "/tmp/torchinductor_root/gv/cgv6n5aotqjo5w4vknjibhengeycuattfto532hkxpozszcgxr3x.h"
#     extern "C" void kernel(const unsigned char* in_ptr0,
#                            const unsigned char* in_ptr1,
#                            unsigned char* out_ptr0)
#     {
#         {
#             #pragma GCC ivdep
#             for(long i0=static_cast<long>(0L); i0<static_cast<long>(8390L); i0+=static_cast<long>(1L))
#             {
#                 #pragma GCC ivdep
#                 for(long i1=static_cast<long>(0L); i1<static_cast<long>(8L); i1+=static_cast<long>(1L))
#                 {
#                     auto tmp0 = in_ptr0[static_cast<long>(i1 + (8L*i0))];
#                     auto tmp1 = in_ptr1[static_cast<long>(i1)];
#                     auto tmp2 = -tmp1;
#                     auto tmp3 = max_propagate_nan(tmp0, tmp2);
#                     out_ptr0[static_cast<long>(i1 + (8L*i0))] = tmp3;
#                 }
#             }
#         }
#     }
#

######################################################################
# IR node:
#
# .. code-block:: sh
#
#     buf0: SchedulerNode(ComputedBuffer)
#     buf0.writes = [MemoryDep('buf0', c0, {c0: 67120})]
#     buf0.unmet_dependencies = []
#     buf0.met_dependencies =
#         [   MemoryDep('arg0_1', c1, {c0: 8390, c1: 8}),
#             MemoryDep('arg1_1', c0, {c0: 67120})]
#     buf0.users = [NodeUser(node=OUTPUT, can_inplace=False)]
#     buf0.group.device = cpu
#     buf0.group.iteration = ((8390, 8), ())
#     buf0.sizes = ([8390, 8], [])
#     class buf0_loop_body:
#         var_ranges = {z0: 8390, z1: 8}
#         index0 = 8*z0 + z1
#         index1 = z1
#         def body(self, ops):
#             get_index = self.get_index('index0')
#             load = ops.load('arg1_1', get_index)
#             get_index_1 = self.get_index('index1')
#             load_1 = ops.load('arg0_1', get_index_1)
#             neg = ops.neg(load_1)
#             maximum = ops.maximum(load, neg)
#             get_index_2 = self.get_index('index0')
#             store = ops.store('buf0', get_index_2, maximum, None)
#             return store
#

######################################################################
# According to the traceback logging, the compilation error is caused by the data type inconsistency of ``max_propagate_nan``'s inputs.
# By checking the C++ kernel,
# we know that ``tmp2`` is no longer ``unsigned char`` after applying ``-`` (the compiler reports it as ``int``), while ``tmp0`` is still ``unsigned char``.
# We can easily match ``-`` and ``max_propagate_nan`` in the C++ kernel with ``ops.neg`` and ``ops.maximum`` in the IR node respectively.
#
# We have now found the root cause: the implementation of ``ops.neg`` in the ``cpp`` codegen silently changes the data type when doing ``neg``.
#
#
# Accuracy debugging
# ^^^^^^^^^^^^^^^^^^^
#
# Otherwise, if the model runs into other errors or an accuracy problem, you can use the PyTorch debugging tool called `Minifier <https://pytorch.org/functorch/stable/notebooks/minifier.html>`_.
#
# The core idea of ``Minifier`` is to keep removing nodes and inputs of the graph until it finds the minimal graph that still shows the problem.
# It automatically generates a minified problematic graph through 4 strategies: truncating suffix, delta debugging, eliminating dead code and removing unused inputs.
#
#
# We will now show the debugging process for an accuracy problem with the help of ``Minifier``.
# An accuracy problem refers to the case where the outputs of the ``eager`` and ``inductor`` backends are different.
#
# For instance, we modify the example like this:

from torch._dynamo.utils import same

def foo2(x1, x2):
    a = torch.neg(x1)
    b = torch.maximum(x2, a)
    y = torch.cat([b], dim=0)
    return y

x1 = torch.randn((1, 8), dtype=torch.float32)
x2 = torch.randn((8390, 8), dtype=torch.float32)

expected_result = foo2(x1, x2)

compiled_foo2 = torch.compile(foo2)
actual_result = compiled_foo2(x1, x2)

assert same(expected_result, actual_result) == True

######################################################################
# And also modify the ``neg`` function:

def neg3(x):
    return f"decltype({x})(2 * {x})"

######################################################################
# An accuracy problem would be raised as follows:
#
# .. code-block:: sh
#
#     torch._dynamo.utils: [ERROR] Accuracy failed: allclose not within tol=0.0001
#     Traceback (most recent call last):
#       File "test_script.py", line 18, in <module>
#         assert same(expected_result, actual_result) == True
#     AssertionError
#
# To debug an accuracy problem with Minifier, two environment variables are needed:
#
# .. code-block:: sh
#
#     TORCHDYNAMO_REPRO_AFTER="aot" TORCHDYNAMO_REPRO_LEVEL=4 python xx.py
#
# This gives us logging information that shows the steps of minifying:
#
# .. code-block:: sh
#
#     Started off with 6 nodes
#
#     Trying granularity 2
#     Strategy: Truncate suffix (G: 2) (6 nodes, 2 inputs)
#     SUCCESS: Went from 6 to 4 nodes
#
#     Trying granularity 4
#     Strategy: Remove unused inputs (G: 4) (4 nodes, 2 inputs)
#     SUCCESS: Went from 4 to 3 nodes
#
# After running, we get the final minified graph with the target node ``neg``:

def forward2(self, arg0_1):
    neg = torch.ops.aten.neg.default(arg0_1);  arg0_1 = None
    return (neg,)

######################################################################
# For more usage details about Minifier, please refer to `Troubleshooting <https://pytorch.org/docs/stable/dynamo/troubleshooting.html>`_.


######################################################################
# Performance profiling
# ---------------------
#
# In this section, we demonstrate how to conduct performance analysis for a model that has been compiled using the Inductor CPU backend.
# In the example below, we benchmark a Hugging Face Transformer model ``MobileBertForQuestionAnswering`` with both the eager mode and the Inductor graph mode.
# The execution time and the speedup ratio of Inductor are printed after the benchmark.
# We use an Intel(R) Xeon(R) Platinum 8358 CPU @ 2.60GHz and run the benchmark on the first socket to demonstrate the optimization within this section.
# We set the following environment variables as a
# best practice for benchmarking on Intel(R) CPUs.

#########################################################
# .. code-block:: shell
#
#     export KMP_BLOCKTIME=1
#     export KMP_SETTINGS=1
#     export KMP_AFFINITY=granularity=fine,compact,1,0
#     export LD_PRELOAD=${CONDA_PREFIX:-"$(dirname $(which conda))/../"}/lib/libiomp5.so:${CONDA_PREFIX:-"$(dirname $(which conda))/../"}/lib/libjemalloc.so
#     export MALLOC_CONF="oversize_threshold:1,background_thread:true,metadata_thp:auto,dirty_decay_ms:-1,muzzy_decay_ms:-1"
#     numactl -C 0-31 -m 0 python bench.py
#

# bench.py
from transformers import MobileBertForQuestionAnswering
# Initialize an eager model
model = MobileBertForQuestionAnswering.from_pretrained("csarron/mobilebert-uncased-squad-v2")
seq_length = 128
bs = 128
vocab_size = model.config.vocab_size
input = torch.randint(0, vocab_size, (bs, seq_length), dtype=torch.int64)
input_dict = {"input_ids": input}

# Initialize the inductor model
compiled_model = torch.compile(model)
with torch.no_grad():
    compiled_model(**input_dict)

NUM_ITERS=50
import timeit
with torch.no_grad():
    # warmup
    for _ in range(10):
        model(**input_dict)
    eager_t = timeit.timeit("model(**input_dict)", number=NUM_ITERS, globals=globals())

with torch.no_grad():
    # warmup
    for _ in range(10):
        compiled_model(**input_dict)
    inductor_t = timeit.timeit("compiled_model(**input_dict)", number=NUM_ITERS, globals=globals())
# print(f"eager use: {eager_t * 1000 / NUM_ITERS} ms/iter")
# print(f"inductor use: {inductor_t * 1000 / NUM_ITERS} ms/iter")
# print(f"speed up ratio: {eager_t / inductor_t}")


######################################################################
# Output:
#
# .. code-block:: shell
#
#     eager use: 802.1023553796113 ms/iter
#     inductor use: 339.95180135127157 ms/iter
#     speed up ratio: 2.359459053287382
#
# In our own testing, we find that the Inductor CPU backend speeds up the model by around 2.355x.
#
#
# Next, let's dive deep into the performance at the operation level to understand where the speed-up comes from.
# `PyTorch Profiler <https://pytorch.org/tutorials/recipes/recipes/profiler_recipe.html>`_ is a good tool to help us.
# The Inductor CPU backend can report the time of the fused kernels to the profiler with the ``enable_kernel_profile`` configuration option:

from torch._inductor import config
config.cpp.enable_kernel_profile = True

######################################################################
# Following the steps in `PyTorch Profiler <https://pytorch.org/tutorials/recipes/recipes/profiler_recipe.html>`_,
# we are able to get the profiling table and trace files.

# bench.py
from torch.profiler import profile, schedule, ProfilerActivity
RESULT_DIR = "./prof_trace"
my_schedule = schedule(
    skip_first=10,
    wait=5,
    warmup=5,
    active=1,
    repeat=5)

def trace_handler(p):
    output = p.key_averages().table(sort_by="self_cpu_time_total", row_limit=20)
    # print(output)
    p.export_chrome_trace(f"{RESULT_DIR}/{p.step_num}.json")

for _ in range(10):
    model(**input_dict)  # compiled_model(**input_dict) to get inductor model profiling

total = 0
with profile(
    activities=[ProfilerActivity.CPU],
    schedule=my_schedule,
    on_trace_ready=trace_handler
) as p:
    for _ in range(50):
        model(**input_dict)  # compiled_model(**input_dict) to get inductor model profiling
        p.step()

######################################################################
# We get the following performance profiling table for the eager-mode model (omitting some columns):
#
# .. code-block:: shell
#
#     -------------------------  ------------  ------------  ------------
#                          Name   CPU total %     CPU total    # of Calls
#     -------------------------  ------------  ------------  ------------
#                   aten::addmm        45.73%     370.814ms           362
#                     aten::add        19.89%     161.276ms           363
#                   aten::copy_        14.97%     121.416ms           488
#                     aten::mul         9.02%      73.154ms           194
#               aten::clamp_min         8.81%      71.444ms            96
#                     aten::bmm         5.46%      44.258ms            48
#                 ProfilerStep*       100.00%     810.920ms             1
#                     aten::div         2.89%      23.447ms            24
#                aten::_softmax         1.00%       8.087ms            24
#                  aten::linear        46.48%     376.888ms           362
#                   aten::clone         2.77%      22.430ms            98
#                       aten::t         0.31%       2.502ms           362
#                    aten::view         0.14%       1.161ms           850
#               aten::transpose         0.17%       1.377ms           386
#            aten::index_select         0.12%     952.000us             3
#                  aten::expand         0.12%     986.000us           458
#                  aten::matmul         8.31%      67.420ms            48
#                     aten::cat         0.09%     703.000us             1
#              aten::as_strided         0.08%     656.000us           963
#                    aten::relu         8.86%      71.864ms            96
#     -------------------------  ------------  ------------  ------------
#     Self CPU time total: 810.920ms
#

######################################################################
#
# Similarly, we also get the table for the compiled model with Inductor (omitting some columns):
#
# .. code-block:: shell
#
#     -----------------------------------------------  ------------  ------------  ------------
#                                                Name   CPU total %     CPU total    # of Calls
#     -----------------------------------------------  ------------  ------------  ------------
#                                    mkl::_mkl_linear        68.79%     231.573ms           362
#                                           aten::bmm         8.02%      26.992ms            48
#                                       ProfilerStep*       100.00%     336.642ms             1
#       graph_0_cpp_fused_constant_pad_nd_embedding_0         0.27%     915.000us             1
#                                         aten::empty         0.27%     911.000us           362
#      graph_0_cpp_fused__mkl_linear_add_mul_relu_151         0.27%     901.000us             1
#      graph_0_cpp_fused__mkl_linear_add_mul_relu_226         0.27%     899.000us             1
#      graph_0_cpp_fused__mkl_linear_add_mul_relu_361         0.27%     898.000us             1
#      graph_0_cpp_fused__mkl_linear_add_mul_relu_121         0.27%     895.000us             1
#       graph_0_cpp_fused__mkl_linear_add_mul_relu_31         0.27%     893.000us             1
#       graph_0_cpp_fused__mkl_linear_add_mul_relu_76         0.26%     892.000us             1
#      graph_0_cpp_fused__mkl_linear_add_mul_relu_256         0.26%     892.000us             1
#      graph_0_cpp_fused__mkl_linear_add_mul_relu_346         0.26%     892.000us             1
#      graph_0_cpp_fused__mkl_linear_add_mul_relu_241         0.26%     891.000us             1
#      graph_0_cpp_fused__mkl_linear_add_mul_relu_316         0.26%     891.000us             1
#       graph_0_cpp_fused__mkl_linear_add_mul_relu_91         0.26%     890.000us             1
#      graph_0_cpp_fused__mkl_linear_add_mul_relu_106         0.26%     890.000us             1
#      graph_0_cpp_fused__mkl_linear_add_mul_relu_211         0.26%     890.000us             1
#       graph_0_cpp_fused__mkl_linear_add_mul_relu_61         0.26%     889.000us             1
#      graph_0_cpp_fused__mkl_linear_add_mul_relu_286         0.26%     889.000us             1
#     -----------------------------------------------  ------------  ------------  ------------
#     Self CPU time total: 336.642ms
#
# From the profiling table of the eager model, we can see that the most time-consuming ops are [``aten::addmm``, ``aten::add``, ``aten::copy_``, ``aten::mul``, ``aten::clamp_min``, ``aten::bmm``].
# Comparing with the inductor model profiling table, we notice an ``mkl::_mkl_linear`` entry and multiple fused kernels in the form ``graph_0_cpp_fused_*``.
# They are the major optimizations that the inductor model is doing. Let us discuss them separately.
#
# (1) Regarding ``mkl::_mkl_linear``: You may notice the number of calls to this kernel is 362, which is exactly the same as ``aten::linear`` in the eager model profiling table.
# The CPU total of ``aten::linear`` is 376.888ms, while it is 231.573ms for ``mkl::_mkl_linear``. This suggests a ~1.63x speedup for the "linear" part.
# The speedup mainly comes from `packing the weight tensor to block memory format <https://www.intel.com/content/www/us/en/docs/onemkl/developer-reference-c/2023-1/cblas-gemm-pack-002.html>`_
# and invoking `cblas_sgemm_compute <https://www.intel.com/content/www/us/en/docs/onemkl/developer-reference-c/2023-1/cblas-gemm-compute-002.html>`_ within the Inductor CPU backend
# to have a better cache behavior during GEMM computation.
#
# (2) Regarding other memory-intensive ops: The end-to-end latency for the eager/inductor model is 802/339ms in our testing.
# Subtracting the time spent in the linear ops from each latency, we can roughly infer that the speedup for the other memory-intensive ops is around 3.94x.
# Let's read the generated code to understand how the inductor achieves this impressive optimization.
# You can find the generated code by
# searching ``cpp_fused__mkl_linear_add_mul_relu_151`` in ``output_code.py``.
#


cpp_fused__mkl_linear_add_mul_relu_151 = async_compile.cpp('''
#include <ATen/record_function.h>
#include "/tmp/torchinductor_root/lr/clrlgu27q4ggd472umdzwsu6qcpqxcuusjxqvx2hwitjbujiiz7z.h"
extern "C" void kernel(float* in_out_ptr0,
                       const float* in_ptr0,
                       const float* in_ptr1,
                       const float* in_ptr2,
                       const float* in_ptr3)
{
    RECORD_FUNCTION("graph_0_cpp_fused__mkl_linear_add_mul_relu_151", c10::ArrayRef<c10::IValue>({}));
    #pragma omp parallel num_threads(32)
    {
        {
            #pragma omp for
            for(long i0=static_cast<long>(0L); i0<static_cast<long>(16384L); i0+=static_cast<long>(1L))
            {
                for(long i1=static_cast<long>(0L); i1<static_cast<long>(512L); i1+=static_cast<long>(8L))
                {
                    auto tmp0 = at::vec::Vectorized<float>::loadu(in_ptr0 + static_cast<long>(i1 + (512L*i0)));
                    auto tmp1 = at::vec::Vectorized<float>::loadu(in_ptr1 + static_cast<long>(i1));
                    auto tmp3 = at::vec::Vectorized<float>::loadu(in_out_ptr0 + static_cast<long>(i1 + (512L*i0)));
                    auto tmp5 = at::vec::Vectorized<float>::loadu(in_ptr2 + static_cast<long>(i1));
                    auto tmp7 = at::vec::Vectorized<float>::loadu(in_ptr3 + static_cast<long>(i1));
                    auto tmp2 = tmp0 + tmp1;
                    auto tmp4 = tmp2 + tmp3;
                    auto tmp6 = tmp4 * tmp5;
                    auto tmp8 = tmp6 + tmp7;
                    tmp8.store(in_out_ptr0 + static_cast<long>(i1 + (512L*i0)));
                }
            }
        }
    }
}''')

######################################################################
# From the generated code above, we can see this kernel has done a typical `Loop Fusion <https://en.wikipedia.org/wiki/Loop_fission_and_fusion>`_ on ``[add, add, mul, add]``.
# Without fusion, this pattern is a memory-bound bottleneck that prevents good performance.
# To get a more intuitive feeling about this optimization,
# we can infer the sizes and strides of the inputs and further benchmark this ``[add, add, mul, add]`` pattern.

# bench.py
def func(arg_0, arg_1, arg_2, arg_3, arg_4):
    add_0 = arg_0 + arg_1
    add_1 = add_0 + arg_2
    mul_1 = add_1 * arg_3
    add_2 = mul_1 + arg_4
    arg_2 = add_2
    return arg_2

arg_0 = torch.rand(16384, 512)
arg_1 = torch.rand(1, 512)
arg_2 = torch.zeros(16384, 512)
arg_3 = torch.rand(1, 512)
arg_4 = torch.rand(1, 512)

input = (arg_0, arg_1, arg_2, arg_3, arg_4)
inductor_func = torch.compile(func)
with torch.no_grad():
    inductor_func(*input)

import timeit
NUM_ITERS=100
with torch.no_grad():
    # warmup
    for _ in range(10):
        func(*input)
    eager_t = timeit.timeit("func(*input)", number=NUM_ITERS, globals=globals())

with torch.no_grad():
    # warmup
    for _ in range(10):
        inductor_func(*input)
    inductor_t = timeit.timeit("inductor_func(*input)", number=NUM_ITERS, globals=globals())
# print(f"eager use: {eager_t * 1000 / NUM_ITERS} ms/iter")
# print(f"inductor use: {inductor_t * 1000 / NUM_ITERS} ms/iter")
# print(f"speed up ratio: {eager_t / inductor_t}")

######################################################################
# Output:
#
# .. code-block:: shell
#
#     eager use: 5.780875144992024 ms/iter
#     inductor use: 0.9588955780491233 ms/iter
#     speed up ratio: 6.0286805751604735
#
#
# This is just an example. The profiling table shows that all element-wise ops in this model are fused automatically by the inductor.
# You can read more kernels in
# ``output_code.py``.


#########################################################################
# Conclusion
# ----------
#
# This document gives an in-depth tutorial for the Inductor CPU backend.
#
# With motivating examples, we walk through the process of debugging and profiling.
# The main idea is to narrow down the problem.
#
# We demonstrate step by step how to delve deeper into the issue and find the root cause of failures, with the help of debug logging and the Minifier tool.
# First, determine which component the failure occurs in, and then try to generate the smallest snippet of code that can reproduce the failure.
#
# When the performance with Inductor is better than that of eager mode, we provide a solid analytical method for performance profiling.
# We show how to find the time-consuming hotspots with PyTorch Profiler and figure out the operator-level or kernel-level reason to explain the phenomenon.
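######################################################################
# As a recap of the operator-level analysis, the sketch below recomputes the speedup
# figures discussed in the profiling section from the timings reported in this tutorial.
# The helper ``speedup`` and all variable names are illustrative assumptions, not part of any PyTorch API.
# Note that it subtracts profiler times from ``timeit`` latencies, so the "other ops" figure is only a rough estimate.

```python
# Timings copied from the benchmark output and profiling tables in this tutorial (milliseconds).
eager_total_ms = 802.10       # eager end-to-end latency per iteration
inductor_total_ms = 339.95    # Inductor end-to-end latency per iteration
eager_linear_ms = 376.888     # CPU total of aten::linear in the eager table
inductor_linear_ms = 231.573  # CPU total of mkl::_mkl_linear in the Inductor table


def speedup(baseline_ms, optimized_ms):
    """Return how many times faster the optimized run is than the baseline."""
    return baseline_ms / optimized_ms


end_to_end = speedup(eager_total_ms, inductor_total_ms)
linear_part = speedup(eager_linear_ms, inductor_linear_ms)
# Everything that is not a linear op is attributed to the "other ops" bucket.
other_part = speedup(eager_total_ms - eager_linear_ms,
                     inductor_total_ms - inductor_linear_ms)

print(f"end-to-end: {end_to_end:.2f}x")
print(f"linear:     {linear_part:.2f}x")
print(f"other ops:  {other_part:.2f}x")
```

# With the numbers above, the "other ops" estimate comes out close to the figure quoted in the profiling section;
# small differences are expected because the inputs are rounded and come from separate runs.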