Real-time collaboration for Jupyter Notebooks, Linux Terminals, LaTeX, VS Code, R IDE, and more,
all in one place.
Path: blob/main/recipes_source/recipes/benchmark.py
"""
PyTorch Benchmark
====================================
This recipe provides a quick-start guide to using the PyTorch
``benchmark`` module to measure and compare code performance.

Introduction
------------
Benchmarking is an important step in writing code. It helps
us validate that our code meets performance expectations,
compare different approaches to solving the same problem, and
prevent performance regressions.

There are many options when it comes to benchmarking PyTorch code,
including the Python built-in ``timeit`` module. However, benchmarking
PyTorch code has many caveats that can be easily overlooked, such as
managing the number of threads and synchronizing CUDA devices. Moreover,
generating Tensor inputs for benchmarking can be quite tedious.

This recipe demonstrates how to use the PyTorch ``benchmark`` module to avoid
common mistakes while making it easier to compare performance of
different code, generate input for benchmarking, and more.

Setup
-----
Before we begin, install ``torch`` if it isn’t already available.

::

   pip install torch

"""


######################################################################
# Steps
# -----
#
# 1. Defining functions to benchmark
# 2. Benchmarking with ``timeit.Timer``
# 3. Benchmarking with ``torch.utils.benchmark.Timer``
# 4. Benchmarking with ``Blocked Autorange``
# 5. Comparing benchmark results
# 6. Saving/Loading benchmark results
# 7. Generating inputs with ``Fuzzed Parameters``
# 8. Collecting instruction counts with ``Callgrind``
#
# 1. Defining functions to benchmark
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
#
# As of the time of this writing, `torch.dot <https://pytorch.org/docs/stable/generated/torch.dot.html?highlight=dot#torch.dot>`__
# does not support batched mode, so we will compare two approaches to
# implementing it using existing ``torch`` operators: one approach uses a
# combination of ``mul`` and ``sum`` while the other reduces the problem to ``bmm``.
#

import torch


def batched_dot_mul_sum(a, b):
    '''Computes batched dot by multiplying and summing'''
    return a.mul(b).sum(-1)


def batched_dot_bmm(a, b):
    '''Computes batched dot by reducing to ``bmm``'''
    a = a.reshape(-1, 1, a.shape[-1])
    b = b.reshape(-1, b.shape[-1], 1)
    return torch.bmm(a, b).flatten(-3)


# Input for benchmarking
x = torch.randn(10000, 64)

# Ensure that both functions compute the same output
assert batched_dot_mul_sum(x, x).allclose(batched_dot_bmm(x, x))


######################################################################
# 2. Benchmarking with ``timeit.Timer``
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
#
# First, let's benchmark the code using Python's built-in ``timeit`` module.
# We keep the benchmark code simple here so we can compare the defaults
# of ``timeit`` and ``torch.utils.benchmark``.
#

import timeit

t0 = timeit.Timer(
    stmt='batched_dot_mul_sum(x, x)',
    setup='from __main__ import batched_dot_mul_sum',
    globals={'x': x})

t1 = timeit.Timer(
    stmt='batched_dot_bmm(x, x)',
    setup='from __main__ import batched_dot_bmm',
    globals={'x': x})

print(f'mul_sum(x, x):  {t0.timeit(100) / 100 * 1e6:>5.1f} us')
print(f'bmm(x, x):      {t1.timeit(100) / 100 * 1e6:>5.1f} us')

######################################################################
# .. code-block:: none
#    :caption: Output
#
#     mul_sum(x, x):  111.6 us
#     bmm(x, x):       70.0 us
#


######################################################################
# 3. Benchmarking with ``torch.utils.benchmark.Timer``
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
#
# The PyTorch ``benchmark`` module was designed to be familiar to those who
# have used the ``timeit`` module before. However, its defaults make it
# easier and safer to use for benchmarking PyTorch code. Let's first
# compare the same basic API as above.
#

import torch.utils.benchmark as benchmark

t0 = benchmark.Timer(
    stmt='batched_dot_mul_sum(x, x)',
    setup='from __main__ import batched_dot_mul_sum',
    globals={'x': x})

t1 = benchmark.Timer(
    stmt='batched_dot_bmm(x, x)',
    setup='from __main__ import batched_dot_bmm',
    globals={'x': x})

print(t0.timeit(100))
print(t1.timeit(100))

######################################################################
# .. code-block:: none
#    :caption: Output
#
#     <torch.utils.benchmark.utils.common.Measurement object at 0x7fb10400d0f0>
#     batched_dot_mul_sum(x, x)
#     setup: from __main__ import batched_dot_mul_sum
#       379.29 us
#       1 measurement, 100 runs , 1 thread
#     <torch.utils.benchmark.utils.common.Measurement object at 0x7fb103d67048>
#     batched_dot_bmm(x, x)
#     setup: from __main__ import batched_dot_bmm
#       716.42 us
#       1 measurement, 100 runs , 1 thread
#

######################################################################
# Even though the APIs are the same for the basic functionality, there
# are some important differences. ``benchmark.Timer.timeit()`` returns the
# time per run as opposed to the total runtime like ``timeit.Timer.timeit()``
# does.
# The PyTorch ``benchmark`` module also provides formatted string
# representations for printing the results.
#
# Another important difference, and the reason why the results diverge,
# is that the PyTorch ``benchmark`` module runs in a single thread by default.
# We can change the number of threads with the ``num_threads`` argument.
#
# ``torch.utils.benchmark.Timer`` takes several additional arguments,
# including ``label``, ``sub_label``, ``description``, and ``env``, which change
# the ``__repr__`` of the returned measurement object and are used for
# grouping the results (more on this later).
#

num_threads = torch.get_num_threads()
print(f'Benchmarking on {num_threads} threads')

t0 = benchmark.Timer(
    stmt='batched_dot_mul_sum(x, x)',
    setup='from __main__ import batched_dot_mul_sum',
    globals={'x': x},
    num_threads=num_threads,
    label='Multithreaded batch dot',
    sub_label='Implemented using mul and sum')

t1 = benchmark.Timer(
    stmt='batched_dot_bmm(x, x)',
    setup='from __main__ import batched_dot_bmm',
    globals={'x': x},
    num_threads=num_threads,
    label='Multithreaded batch dot',
    sub_label='Implemented using bmm')

print(t0.timeit(100))
print(t1.timeit(100))

######################################################################
# .. code-block:: none
#    :caption: Output
#
#     Benchmarking on 40 threads
#     <torch.utils.benchmark.utils.common.Measurement object at 0x7fb103d54080>
#     Multithreaded batch dot: Implemented using mul and sum
#     setup: from __main__ import batched_dot_mul_sum
#       118.47 us
#       1 measurement, 100 runs , 40 threads
#     <torch.utils.benchmark.utils.common.Measurement object at 0x7fb16935d2e8>
#     Multithreaded batch dot: Implemented using bmm
#     setup: from __main__ import batched_dot_bmm
#       68.21 us
#       1 measurement, 100 runs , 40 threads

######################################################################
# Running ``benchmark`` with all threads available gives similar results
# to the ``timeit`` module. More importantly, which version is faster
# depends on how many threads we run the code with. This is why it's
# important to benchmark the code with thread settings that are
# representative of real use cases. Another important thing to remember
# is to synchronize CPU and CUDA when benchmarking on the GPU. Let's run
# the above benchmarks again on a CUDA tensor and see what happens.
#

x = torch.randn(10000, 1024, device='cuda')

t0 = timeit.Timer(
    stmt='batched_dot_mul_sum(x, x)',
    setup='from __main__ import batched_dot_mul_sum',
    globals={'x': x})

t1 = timeit.Timer(
    stmt='batched_dot_bmm(x, x)',
    setup='from __main__ import batched_dot_bmm',
    globals={'x': x})

# Run each twice to show the difference before/after warm-up
print(f'mul_sum(x, x):  {t0.timeit(100) / 100 * 1e6:>5.1f} us')
print(f'mul_sum(x, x):  {t0.timeit(100) / 100 * 1e6:>5.1f} us')
print(f'bmm(x, x):      {t1.timeit(100) / 100 * 1e6:>5.1f} us')
print(f'bmm(x, x):      {t1.timeit(100) / 100 * 1e6:>5.1f} us')

######################################################################
# .. code-block:: none
#    :caption: Output
#
#     mul_sum(x, x):   27.6 us
#     mul_sum(x, x):   25.3 us
#     bmm(x, x):      2775.5 us
#     bmm(x, x):       22.4 us
#

t0 = benchmark.Timer(
    stmt='batched_dot_mul_sum(x, x)',
    setup='from __main__ import batched_dot_mul_sum',
    globals={'x': x})

t1 = benchmark.Timer(
    stmt='batched_dot_bmm(x, x)',
    setup='from __main__ import batched_dot_bmm',
    globals={'x': x})

# Run only once since the benchmark module does warm-up for us
print(t0.timeit(100))
print(t1.timeit(100))

######################################################################
# .. code-block:: none
#    :caption: Output
#
#     <torch.utils.benchmark.utils.common.Measurement object at 0x7fb10400d080>
#     batched_dot_mul_sum(x, x)
#     setup: from __main__ import batched_dot_mul_sum
#       232.93 us
#       1 measurement, 100 runs , 1 thread
#     <torch.utils.benchmark.utils.common.Measurement object at 0x7fb10400d0f0>
#     batched_dot_bmm(x, x)
#     setup: from __main__ import batched_dot_bmm
#       181.04 us
#       1 measurement, 100 runs , 1 thread
#

######################################################################
# The results reveal something interesting. The first run of the ``bmm``
# version using the ``timeit`` module takes much longer than the second
# run. This is because ``bmm`` calls into cuBLAS, which needs to be
# loaded the first time it's called, and that takes some time. This is why
# it's important to do a warm-up run before benchmarking. Luckily for
# us, PyTorch's ``benchmark`` module takes care of that.
#
# The difference in the results between the ``timeit`` and ``benchmark`` modules
# is because the ``timeit`` module does not synchronize CUDA and is thus only
# measuring the time it takes to launch the kernel. PyTorch's ``benchmark``
# module does the synchronization for us.


######################################################################
# 4. Benchmarking with ``Blocked Autorange``
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
#
# While ``timeit.Timer.autorange`` takes a single continuous measurement
# of at least 0.2 seconds, ``torch.utils.benchmark.Timer.blocked_autorange``
# takes many measurements whose times total at least 0.2 seconds (which
# can be changed by the ``min_run_time`` parameter), subject to the constraint
# that timing overhead is a small fraction of the overall measurement.
# This is accomplished by first running with an increasing number of runs
# per loop until the runtime is much larger than measurement overhead
# (which also serves as a warm-up), and then taking measurements until
# the target time is reached. This has the useful properties that it wastes
# less data and allows us to compute statistics to estimate the reliability
# of the measurements.
#

m0 = t0.blocked_autorange()
m1 = t1.blocked_autorange()

print(m0)
print(m1)

######################################################################
# .. code-block:: none
#    :caption: Output
#
#     <torch.utils.benchmark.utils.common.Measurement object at 0x7fb10400d0f0>
#     batched_dot_mul_sum(x, x)
#     setup: from __main__ import batched_dot_mul_sum
#       231.79 us
#       1 measurement, 1000 runs , 1 thread
#     <torch.utils.benchmark.utils.common.Measurement object at 0x7fb10400d080>
#     batched_dot_bmm(x, x)
#     setup: from __main__ import batched_dot_bmm
#       Median: 162.08 us
#       2 measurements, 1000 runs per measurement, 1 thread
#

######################################################################
# We can also inspect the individual statistics from the returned
# measurements object.

print(f"Mean:   {m0.mean * 1e6:6.2f} us")
print(f"Median: {m0.median * 1e6:6.2f} us")

######################################################################
# .. code-block:: none
#    :caption: Output
#
#     Mean:   231.79 us
#     Median: 231.79 us
#

######################################################################
# 5. Comparing benchmark results
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
#
# So far we've been comparing our two versions of batched dot against a
# single input. In practice, we want to try a combination of inputs as
# well as different numbers of threads. The ``Compare`` class helps display
# the results of many measurements in a formatted table. It uses the
# annotations described above (``label``, ``sub_label``, ``num_threads``, etc.) as
# well as ``description`` to group and organize the table. Let's use
# ``Compare`` to see how our functions perform for different input sizes
# and numbers of threads.
#

from itertools import product

# Compare takes a list of measurements which we'll save in results.
results = []

sizes = [1, 64, 1024, 10000]
for b, n in product(sizes, sizes):
    # label is the table title, sub_label is the row
    # and description is the column
    label = 'Batched dot'
    sub_label = f'[{b}, {n}]'
    x = torch.ones((b, n))
    for num_threads in [1, 4, 16, 32]:
        results.append(benchmark.Timer(
            stmt='batched_dot_mul_sum(x, x)',
            setup='from __main__ import batched_dot_mul_sum',
            globals={'x': x},
            num_threads=num_threads,
            label=label,
            sub_label=sub_label,
            description='mul/sum',
        ).blocked_autorange(min_run_time=1))
        results.append(benchmark.Timer(
            stmt='batched_dot_bmm(x, x)',
            setup='from __main__ import batched_dot_bmm',
            globals={'x': x},
            num_threads=num_threads,
            label=label,
            sub_label=sub_label,
            description='bmm',
        ).blocked_autorange(min_run_time=1))

compare = benchmark.Compare(results)
compare.print()

######################################################################
# .. code-block:: none
#    :caption: Output
#
#     [--------------- Batched dot ----------------]
#                           |  mul/sum   |    bmm
#     1 threads: -----------------------------------
#           [1, 1]          |       5.9  |      11.2
#           [1, 64]         |       6.4  |      11.4
#           [1, 1024]       |       6.7  |      14.2
#           [1, 10000]      |      10.2  |      23.7
#           [64, 1]         |       6.3  |      11.5
#           [64, 64]        |       8.6  |      15.4
#           [64, 1024]      |      39.4  |     204.4
#           [64, 10000]     |     274.9  |     748.5
#           [1024, 1]       |       7.7  |      17.8
#           [1024, 64]      |      40.3  |      76.4
#           [1024, 1024]    |     432.4  |    2795.9
#           [1024, 10000]   |   22657.3  |   11899.5
#           [10000, 1]      |      16.9  |      74.8
#           [10000, 64]     |     300.3  |     609.4
#           [10000, 1024]   |   23098.6  |   27246.1
#           [10000, 10000]  |  267073.7  |  118823.7
#     4 threads: -----------------------------------
#           [1, 1]          |       6.0  |      11.5
#           [1, 64]         |       6.2  |      11.2
#           [1, 1024]       |       6.8  |      14.3
#           [1, 10000]      |      10.2  |      23.7
#           [64, 1]         |       6.3  |      16.2
#           [64, 64]        |       8.8  |      18.2
#           [64, 1024]      |      41.5  |     189.1
#           [64, 10000]     |      91.7  |     849.1
#           [1024, 1]       |       7.6  |      17.4
#           [1024, 64]      |      43.5  |      33.5
#           [1024, 1024]    |     135.4  |    2782.3
#           [1024, 10000]   |    7471.1  |   11874.0
#           [10000, 1]      |      16.8  |      33.9
#           [10000, 64]     |     118.7  |     173.2
#           [10000, 1024]   |    7264.6  |   27824.7
#           [10000, 10000]  |  100060.9  |  121499.0
#     16 threads: ----------------------------------
#           [1, 1]          |       6.0  |      11.3
#           [1, 64]         |       6.2  |      11.2
#           [1, 1024]       |       6.9  |      14.2
#           [1, 10000]      |      10.3  |      23.8
#           [64, 1]         |       6.4  |      24.1
#           [64, 64]        |       9.0  |      23.8
#           [64, 1024]      |      54.1  |     188.5
#           [64, 10000]     |      49.9  |     748.0
#           [1024, 1]       |       7.6  |      23.4
#           [1024, 64]      |      55.5  |      28.2
#           [1024, 1024]    |      66.9  |    2773.9
#           [1024, 10000]   |    6111.5  |   12833.7
#           [10000, 1]      |      16.9  |      27.5
#           [10000, 64]     |      59.5  |      73.7
#           [10000, 1024]   |    6295.9  |   27062.0
#           [10000, 10000]  |   71804.5  |  120365.8
#     32 threads: ----------------------------------
#           [1, 1]          |       5.9  |      11.3
#           [1, 64]         |       6.2  |      11.3
#           [1, 1024]       |       6.7  |      14.2
#           [1, 10000]      |      10.5  |      23.8
#           [64, 1]         |       6.3  |      31.7
#           [64, 64]        |       9.1  |      30.4
#           [64, 1024]      |      72.0  |     190.4
#           [64, 10000]     |     103.1  |     746.9
#           [1024, 1]       |       7.6  |      28.4
#           [1024, 64]      |      70.5  |      31.9
#           [1024, 1024]    |      65.6  |    2804.6
#           [1024, 10000]   |    6764.0  |   11871.4
#           [10000, 1]      |      17.8  |      31.8
#           [10000, 64]     |     110.3  |      56.0
#           [10000, 1024]   |    6640.2  |   27592.2
#           [10000, 10000]  |   73003.4  |  120083.2
#
#     Times are in microseconds (us).
#

######################################################################
# The results above show that which version is faster depends on both the
# input shape and the number of threads. In this run, the ``mul``/``sum``
# version is faster for most inputs, and for all of the small ones, while the
# version that reduces to ``bmm`` wins for the largest inputs on a single
# thread (``[1024, 10000]`` and ``[10000, 10000]``) and for some tall, narrow
# inputs when multiple threads are used (e.g. ``[1024, 64]`` with 4 or more
# threads).
#
# ``Compare`` also provides functions for changing the table format:
#

compare.trim_significant_figures()
compare.colorize()
compare.print()


######################################################################
# 6. Saving/Loading benchmark results
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
#
# ``Measurement`` objects (and ``CallgrindStats``, which are described in section 8)
# can be serialized by the ``pickle`` module. This makes A/B testing easy, as you can collect
# measurements from two separate environments, pickle them, and then
# load both in a single environment. ``Timer`` even takes an ``env``
# constructor argument so that such A/B testing works seamlessly.
#
# Let's imagine that rather than two Python functions, the mul/sum
# and ``bmm`` approaches were in two different builds of PyTorch.
# The example below demonstrates how one might A/B test them.
# For simplicity, we only use a subset of shapes, and simply round trip
# results through pickle rather than actually using multiple environments
# and writing results to disk.
#

import pickle

ab_test_results = []
for env in ('environment A: mul/sum', 'environment B: bmm'):
    for b, n in ((1, 1), (1024, 10000), (10000, 1)):
        x = torch.ones((b, n))
        dot_fn = (batched_dot_mul_sum if env == 'environment A: mul/sum' else batched_dot_bmm)
        m = benchmark.Timer(
            stmt='batched_dot(x, x)',
            globals={'x': x, 'batched_dot': dot_fn},
            num_threads=1,
            label='Batched dot',
            description=f'[{b}, {n}]',
            env=env,
        ).blocked_autorange(min_run_time=1)
        ab_test_results.append(pickle.dumps(m))

ab_results = [pickle.loads(i) for i in ab_test_results]
compare = benchmark.Compare(ab_results)
compare.trim_significant_figures()
compare.colorize()
compare.print()

######################################################################
# .. code-block:: none
#    :caption: Output
#
#     [------------------------------------- Batched dot -------------------------------------]
#                                                    |  [1, 1]  |  [1024, 10000]  |  [10000, 1]
#     1 threads: ------------------------------------------------------------------------------
#       (environment A: mul/sum)  batched_dot(x, x)  |     7    |      36000      |      21
#       (environment B: bmm)      batched_dot(x, x)  |    14    |      40000      |      85
#
#     Times are in microseconds (us).
#

# And just to show that we can round trip all of the results from earlier:
round_tripped_results = pickle.loads(pickle.dumps(results))
assert str(benchmark.Compare(results)) == str(benchmark.Compare(round_tripped_results))


######################################################################
# 7. Generating inputs with ``Fuzzed Parameters``
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
#
# As we've seen in the previous section, there can be some stark
# performance differences depending on the input tensors.
# Hence, it
# is a good idea to run benchmarks on a number of different inputs.
# However, creating all these input tensors can be tedious, which is
# where ``torch.utils.benchmark.Fuzzer`` and related classes come in.
# Let's take a look at how we can use the ``Fuzzer`` to create some test
# cases for the benchmark.
#

from torch.utils.benchmark import Fuzzer, FuzzedParameter, FuzzedTensor, ParameterAlias

# Generates random tensors with 128 to 10000000 elements and sizes k0 and k1 chosen from a
# ``loguniform`` distribution in [1, 10000], 40% of which will be discontiguous on average.
example_fuzzer = Fuzzer(
    parameters=[
        FuzzedParameter('k0', minval=1, maxval=10000, distribution='loguniform'),
        FuzzedParameter('k1', minval=1, maxval=10000, distribution='loguniform'),
    ],
    tensors=[
        FuzzedTensor('x', size=('k0', 'k1'), min_elements=128, max_elements=10000000, probability_contiguous=0.6)
    ],
    seed=0,
)

results = []
for tensors, tensor_params, params in example_fuzzer.take(10):
    # description is the column label
    sub_label = f"{params['k0']:<6} x {params['k1']:<4} {'' if tensor_params['x']['is_contiguous'] else '(discontiguous)'}"
    results.append(benchmark.Timer(
        stmt='batched_dot_mul_sum(x, x)',
        setup='from __main__ import batched_dot_mul_sum',
        globals=tensors,
        label='Batched dot',
        sub_label=sub_label,
        description='mul/sum',
    ).blocked_autorange(min_run_time=1))
    results.append(benchmark.Timer(
        stmt='batched_dot_bmm(x, x)',
        setup='from __main__ import batched_dot_bmm',
        globals=tensors,
        label='Batched dot',
        sub_label=sub_label,
        description='bmm',
    ).blocked_autorange(min_run_time=1))

compare = benchmark.Compare(results)
compare.trim_significant_figures()
compare.print()

######################################################################
# .. code-block:: none
#    :caption: Output
#
#     [--------------------- Batched dot ---------------------]
#                                              |  mul/sum  |   bmm
#     1 threads: ----------------------------------------------
#           725    x 257                   |      87   |    180
#           49     x 383                   |      15   |     30
#           34     x 1468                  |      30   |    118
#           187    x 5039                  |     400   |   1200
#           2140   x 1296 (discontiguous)  |    2000   |  41000
#           78     x 1598                  |      74   |    310
#           519    x 763                   |     190   |   1500
#           141    x 1082                  |      87   |    500
#           78     x 5    (discontiguous)  |       9   |     20
#           187    x 1                     |      12   |     10
#
#     Times are in microseconds (us).
#

######################################################################
# There is a lot of flexibility for defining your own ``fuzzers``, which
# is great for creating a powerful set of inputs to benchmark. But to
# make things even simpler, the PyTorch ``benchmark`` module comes with some
# built-in ``fuzzers`` for common benchmarking needs. Let's take a look at
# how we can use one of these built-in ``fuzzers``.
#

from torch.utils.benchmark.op_fuzzers import binary

results = []
for tensors, tensor_params, params in binary.BinaryOpFuzzer(seed=0).take(10):
    sub_label = f"{params['k0']:<6} x {params['k1']:<4} {'' if tensor_params['x']['is_contiguous'] else '(discontiguous)'}"
    results.append(benchmark.Timer(
        stmt='batched_dot_mul_sum(x, x)',
        setup='from __main__ import batched_dot_mul_sum',
        globals=tensors,
        label='Batched dot',
        sub_label=sub_label,
        description='mul/sum',
    ).blocked_autorange(min_run_time=1))
    results.append(benchmark.Timer(
        stmt='batched_dot_bmm(x, x)',
        setup='from __main__ import batched_dot_bmm',
        globals=tensors,
        label='Batched dot',
        sub_label=sub_label,
        description='bmm',
    ).blocked_autorange(min_run_time=1))

compare = benchmark.Compare(results)
compare.trim_significant_figures()
compare.colorize(rowwise=True)
compare.print()

######################################################################
# .. code-block:: none
#    :caption: Output
#
#     [----------------------- Batched dot ------------------------]
#                                                  |  mul/sum  |    bmm
#     1 threads: ---------------------------------------------------
#           64     x 473  (discontiguous)      |   10000   |   40000
#           16384  x 12642115 (discontiguous)  |      31   |      78
#           8192   x 892                       |    4800   |   20400
#           512    x 64   (discontiguous)      |  110000   |  400000
#           493    x 27   (discontiguous)      |    1100   |    2440
#           118    x 32   (discontiguous)      |     870   |    2030
#           16     x 495  (discontiguous)      |   23600   |   24000
#           488    x 62374                     |   90000   |  100000
#           240372 x 69                        |   40000   |   16000
#           40156  x 32   (discontiguous)      |    2670   |    5000
#
#     Times are in microseconds (us).
#

######################################################################
# 8. Collecting instruction counts with ``Callgrind``
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
#
# One of the challenges of optimizing code is the variation and opacity of
# wall time. There are many sources of non-determinism, from adaptive clock
# speeds to resource contention with other processes. Furthermore, end-to-end
# time gives no insight into where time is being spent, which is really what
# we're interested in when optimizing code.
#
# A complementary approach is to also collect instruction counts. These counts
# are a proxy metric and do not capture all aspects of performance
# (e.g. memory or I/O bound tasks); however, they do have several useful
# properties. Instruction counts are reproducible, insensitive to environmental
# variation, and offer fine-grained insight into where a program is spending
# cycles.
#
# To see the utility of instruction counts, let us look at how we might
# reduce the overhead of ``batched_dot_mul_sum``. The obvious solution is to
# move it to C++, so we avoid going between Python and C++ multiple times.
#
# Fortunately, the source is nearly identical.
# One question that we have to ask
# in C++ is whether we should take arguments by value or by reference.
#

batched_dot_src = """\
/* ---- Python ---- */
// def batched_dot_mul_sum(a, b):
//     return a.mul(b).sum(-1)

torch::Tensor batched_dot_mul_sum_v0(
    const torch::Tensor a,
    const torch::Tensor b) {
  return a.mul(b).sum(-1);
}

torch::Tensor batched_dot_mul_sum_v1(
    const torch::Tensor& a,
    const torch::Tensor& b) {
  return a.mul(b).sum(-1);
}
"""


# PyTorch makes it easy to test our C++ implementations by providing a utility
# to JIT compile C++ source into Python extensions:
import os
from torch.utils import cpp_extension
cpp_lib = cpp_extension.load_inline(
    name='cpp_lib',
    cpp_sources=batched_dot_src,
    extra_cflags=['-O3'],
    extra_include_paths=[
        # `load_inline` needs to know where to find ``pybind11`` headers.
        # This assumes a conda environment, i.e. that ``CONDA_PREFIX`` is set.
        os.path.join(os.getenv('CONDA_PREFIX'), 'include')
    ],
    functions=['batched_dot_mul_sum_v0', 'batched_dot_mul_sum_v1']
)

# `load_inline` will create a shared object that is loaded into Python. When we collect
# instruction counts, ``Timer`` will create a subprocess, so we need to re-import it.
# The
# import process is slightly more complicated for C extensions, but that's all we're
# doing here.
module_import_str = f"""\
# https://stackoverflow.com/questions/67631/how-to-import-a-module-given-the-full-path
import importlib.util
spec = importlib.util.spec_from_file_location("cpp_lib", {repr(cpp_lib.__file__)})
cpp_lib = importlib.util.module_from_spec(spec)
spec.loader.exec_module(cpp_lib)"""

import textwrap
def pretty_print(result):
    """Import machinery for ``cpp_lib.so`` can get repetitive to look at."""
    print(repr(result).replace(textwrap.indent(module_import_str, "  "), "  import cpp_lib"))


t_baseline = benchmark.Timer(
    stmt='batched_dot_mul_sum(x, x)',
    setup='''\
from __main__ import batched_dot_mul_sum
x = torch.randn(2, 2)''')

t0 = benchmark.Timer(
    stmt='cpp_lib.batched_dot_mul_sum_v0(x, x)',
    setup=f'''\
{module_import_str}
x = torch.randn(2, 2)''')

t1 = benchmark.Timer(
    stmt='cpp_lib.batched_dot_mul_sum_v1(x, x)',
    setup=f'''\
{module_import_str}
x = torch.randn(2, 2)''')

# Moving to C++ did indeed reduce overhead, but it's hard to tell which
# calling convention is more efficient. v1 (call with references) seems to
# be a bit faster, but it's within measurement error.
pretty_print(t_baseline.blocked_autorange())
pretty_print(t0.blocked_autorange())
pretty_print(t1.blocked_autorange())

######################################################################
# .. code-block:: none
#    :caption: Output
#
#     <torch.utils.benchmark.utils.common.Measurement object at 0x7fb16935d2e8>
#     batched_dot_mul_sum(x, x)
#     setup:
#       from __main__ import batched_dot_mul_sum
#       x = torch.randn(2, 2)
#
#       6.92 us
#       1 measurement, 100000 runs , 1 thread
#     <torch.utils.benchmark.utils.common.Measurement object at 0x7fb16935d2e8>
#     cpp_lib.batched_dot_mul_sum_v0(x, x)
#     setup:
#       import cpp_lib
#       x = torch.randn(2, 2)
#
#       5.29 us
#       1 measurement, 100000 runs , 1 thread
#     <torch.utils.benchmark.utils.common.Measurement object at 0x7fb16935d2e8>
#     cpp_lib.batched_dot_mul_sum_v1(x, x)
#     setup:
#       import cpp_lib
#       x = torch.randn(2, 2)
#
#       5.22 us
#       1 measurement, 100000 runs , 1 thread
#

# Let's use ``Callgrind`` to determine which is better.
stats_v0 = t0.collect_callgrind()
stats_v1 = t1.collect_callgrind()

pretty_print(stats_v0)
pretty_print(stats_v1)

# `.as_standardized` removes file names and some path prefixes, and makes
# it easier to read the function symbols.
stats_v0 = stats_v0.as_standardized()
stats_v1 = stats_v1.as_standardized()

# `.delta` diffs the instruction counts, and `.denoise` removes several
# functions in the Python interpreter that are known to have significant
# jitter.
delta = stats_v1.delta(stats_v0).denoise()

# `.transform` is a convenience API for transforming function names.
# It is useful for increasing cancellation when diffing instructions, as well as
# just generally improving readability.
replacements = (
    ("???:void pybind11", "pybind11"),
    ("batched_dot_mul_sum_v0", "batched_dot_mul_sum_v1"),
    ("at::Tensor, at::Tensor", "..."),
    ("at::Tensor const&, at::Tensor const&", "..."),
    ("auto torch::detail::wrap_pybind_function_impl_", "wrap_pybind_function_impl_"),
)
for before, after in replacements:
    delta = delta.transform(lambda l: l.replace(before, after))

# We can use print options to control how much of the function to display.
torch.set_printoptions(linewidth=160)

# Once parsed, the instruction counts make clear that passing `a` and `b`
# by reference is more efficient as it skips some ``c10::TensorImpl`` bookkeeping
# for the intermediate Tensors, and also works better with ``pybind11``. This
# is consistent with our noisy wall time observations.
print(delta)

######################################################################
# .. code-block:: none
#    :caption: Output
#
#     <torch.utils.benchmark.utils.valgrind_wrapper.timer_interface.CallgrindStats object at 0x7fb0f06e7630>
#     cpp_lib.batched_dot_mul_sum_v0(x, x)
#     setup:
#       import cpp_lib
#       x = torch.randn(2, 2)
#                                All          Noisy symbols removed
#         Instructions:      2392671                    2392671
#         Baseline:             4367                       4367
#     100 runs per measurement, 1 thread
#     Warning: PyTorch was not built with debug symbols.
#              Source information may be limited. Rebuild with
#              REL_WITH_DEB_INFO=1 for more detailed results.
#     <torch.utils.benchmark.utils.valgrind_wrapper.timer_interface.CallgrindStats object at 0x7fb10400d208>
#     cpp_lib.batched_dot_mul_sum_v1(x, x)
#     setup:
#       import cpp_lib
#       x = torch.randn(2, 2)
#                                All          Noisy symbols removed
#         Instructions:      2378978                    2378978
#         Baseline:             4367                       4367
#     100 runs per measurement, 1 thread
#     Warning: PyTorch was not built with debug symbols.
#              Source information may be limited. Rebuild with
#              REL_WITH_DEB_INFO=1 for more detailed results.
#     <torch.utils.benchmark.utils.valgrind_wrapper.timer_interface.FunctionCounts object at 0x7fb1000ab358>
#           86  ???:0x000000000020d9e0
#           56  ???:0x000000000020db10
#        -1100  pybind11::cpp_function::initialize<wrap_pybind_function_impl_<at::Tensor ... r (&)(...), std::integer_sequence<unsigned long, 0ul, 1ul>)::{lambda(...)
#        -1600  ???:wrap_pybind_function_impl_<at::Tensor (&)(...), 0ul, 1ul>(at::Tensor (&)(...), std::integer_sequence<unsigned long, 0ul, 1ul>)::{lambda(...)
#        -5200  ???:c10::intrusive_ptr<c10::TensorImpl, c10::UndefinedTensorImpl>::reset_()
#        -5935  ???:0x000000000022c0e0
#
#     Total: -13693
#


######################################################################
# Learn More
# ----------
#
# Take a look at these other recipes to continue your learning:
#
# - `PyTorch Profiler <https://pytorch.org/tutorials/recipes/recipes/profiler.html>`_
#