"""1Timer quick start2=================34In this tutorial, we're going to cover the primary APIs of5`torch.utils.benchmark.Timer`. The PyTorch Timer is based on the6`timeit.Timer <https://docs.python.org/3/library/timeit.html#timeit.Timer>`__7API, with several PyTorch specific modifications. Familiarity with the8builtin `Timer` class is not required for this tutorial, however we assume9that the reader is familiar with the fundamentals of performance work.1011For a more comprehensive performance tuning tutorial, see12`PyTorch Benchmark <https://pytorch.org/tutorials/recipes/recipes/benchmark.html>`__.131415**Contents:**161. `Defining a Timer <#defining-a-timer>`__172. `Wall time: Timer.blocked_autorange(...) <#wall-time-timer-blocked-autorange>`__183. `C++ snippets <#c-snippets>`__194. `Instruction counts: Timer.collect_callgrind(...) <#instruction-counts-timer-collect-callgrind>`__205. `Instruction counts: Delving deeper <#instruction-counts-delving-deeper>`__216. `A/B testing with Callgrind <#a-b-testing-with-callgrind>`__227. `Wrapping up <#wrapping-up>`__238. `Footnotes <#footnotes>`__24"""252627###############################################################################28# 1. Defining a Timer29# ~~~~~~~~~~~~~~~~~~~30#31# A `Timer` serves as a task definition.32#3334from torch.utils.benchmark import Timer3536timer = Timer(37# The computation which will be run in a loop and timed.38stmt="x * y",3940# `setup` will be run before calling the measurement loop, and is used to41# populate any state which is needed by `stmt`42setup="""43x = torch.ones((128,))44y = torch.ones((128,))45""",4647# Alternatively, ``globals`` can be used to pass variables from the outer scope.48#49# globals={50# "x": torch.ones((128,)),51# "y": torch.ones((128,)),52# },5354# Control the number of threads that PyTorch uses. (Default: 1)55num_threads=1,56)5758###############################################################################59# 2. Wall time: ``Timer.blocked_autorange(...)``60# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~61#62# This method will handle details such as picking a suitable number if repeats,63# fixing the number of threads, and providing a convenient representation of64# the results.65#6667# Measurement objects store the results of multiple repeats, and provide68# various utility features.69from torch.utils.benchmark import Measurement7071m: Measurement = timer.blocked_autorange(min_run_time=1)72print(m)7374###############################################################################75# .. code-block:: none76# :caption: **Snippet wall time.**77#78# <torch.utils.benchmark.utils.common.Measurement object at 0x7f1929a38ed0>79# x * y80# setup:81# x = torch.ones((128,))82# y = torch.ones((128,))83#84# Median: 2.34 us85# IQR: 0.07 us (2.31 to 2.38)86# 424 measurements, 1000 runs per measurement, 1 thread87#8889###############################################################################90# 3. C++ snippets91# ~~~~~~~~~~~~~~~92#9394from torch.utils.benchmark import Language9596cpp_timer = Timer(97"x * y;",98"""99auto x = torch::ones({128});100auto y = torch::ones({128});101""",102language=Language.CPP,103)104105print(cpp_timer.blocked_autorange(min_run_time=1))106107###############################################################################108# .. 
###############################################################################
# 3. C++ snippets
# ~~~~~~~~~~~~~~~
#

from torch.utils.benchmark import Language

cpp_timer = Timer(
    "x * y;",
    """
        auto x = torch::ones({128});
        auto y = torch::ones({128});
    """,
    language=Language.CPP,
)

print(cpp_timer.blocked_autorange(min_run_time=1))

###############################################################################
# .. code-block:: none
#    :caption: **C++ snippet wall time.**
#
#        <torch.utils.benchmark.utils.common.Measurement object at 0x7f192b019ed0>
#        x * y;
#        setup:
#          auto x = torch::ones({128});
#          auto y = torch::ones({128});
#
#          Median: 1.21 us
#          IQR:    0.03 us (1.20 to 1.23)
#          83 measurements, 10000 runs per measurement, 1 thread
#

###############################################################################
# Unsurprisingly, the C++ snippet is both faster and shows lower variation.
#

###############################################################################
# 4. Instruction counts: ``Timer.collect_callgrind(...)``
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
#
# For deep dive investigations, ``Timer.collect_callgrind`` wraps
# `Callgrind <https://valgrind.org/docs/manual/cl-manual.html>`__ in order to
# collect instruction counts. These are useful as they offer fine grained and
# deterministic (or very low noise in the case of Python) insights into how a
# snippet is run.
#

from torch.utils.benchmark import CallgrindStats, FunctionCounts

stats: CallgrindStats = cpp_timer.collect_callgrind()
print(stats)

###############################################################################
# .. code-block:: none
#    :caption: **C++ Callgrind stats (summary)**
#
#        <torch.utils.benchmark.utils.valgrind_wrapper.timer_interface.CallgrindStats object at 0x7f1929a35850>
#        x * y;
#        setup:
#          auto x = torch::ones({128});
#          auto y = torch::ones({128});
#
#                                All          Noisy symbols removed
#          Instructions:      563600                 563600
#          Baseline:               0                      0
#        100 runs per measurement, 1 thread
#
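###############################################################################
# A sketch that is not part of the original recipe: the same call also works
# for the Python ``timer`` from section 1. Python snippet counts are not fully
# deterministic, but they are still very low noise. It is left commented out
# because ``collect_callgrind`` requires Valgrind to be installed and is much
# slower than wall-clock timing; the ``number`` argument (runs per measurement)
# is assumed to correspond to the "100 runs per measurement" shown in the
# summary above.
#
# python_stats: CallgrindStats = timer.collect_callgrind(number=100)
# print(python_stats)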
###############################################################################
# 5. Instruction counts: Delving deeper
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
#
# The string representation of ``CallgrindStats`` is similar to that of
# Measurement. `Noisy symbols` are a Python concept (removing calls in the
# CPython interpreter which are known to be noisy).
#
# For more detailed analysis, however, we will want to look at specific calls.
# ``CallgrindStats.stats()`` returns a ``FunctionCounts`` object to make this easier.
# Conceptually, ``FunctionCounts`` can be thought of as a tuple of pairs with some
# utility methods, where each pair is `(number of instructions, file path and
# function name)`.
#
# A note on paths:
#   One generally doesn't care about the absolute path. For instance, the full
#   path and function name for a multiply call is something like:
#
# .. code-block:: sh
#
#    /the/prefix/to/your/pytorch/install/dir/pytorch/build/aten/src/ATen/core/TensorMethods.cpp:at::Tensor::mul(at::Tensor const&) const [/the/path/to/your/conda/install/miniconda3/envs/ab_ref/lib/python3.7/site-packages/torch/lib/libtorch_cpu.so]
#
#   when in reality, all of the information that we're interested in can be
#   represented in:
#
# .. code-block:: sh
#
#    build/aten/src/ATen/core/TensorMethods.cpp:at::Tensor::mul(at::Tensor const&) const
#
#   ``CallgrindStats.as_standardized()`` makes a best effort to strip low signal
#   portions of the file path, as well as the shared object, and is generally
#   recommended.
#

inclusive_stats = stats.as_standardized().stats(inclusive=False)
print(inclusive_stats[:10])

###############################################################################
# .. code-block:: none
#    :caption: **C++ Callgrind stats (detailed)**
#
#        <torch.utils.benchmark.utils.valgrind_wrapper.timer_interface.FunctionCounts object at 0x7f192a6dfd90>
#          47264  ???:_int_free
#          25963  ???:_int_malloc
#          19900  build/../aten/src/ATen/TensorIter ... (at::TensorIteratorConfig const&)
#          18000  ???:__tls_get_addr
#          13500  ???:malloc
#          11300  build/../c10/util/SmallVector.h:a ... (at::TensorIteratorConfig const&)
#          10345  ???:_int_memalign
#          10000  build/../aten/src/ATen/TensorIter ... (at::TensorIteratorConfig const&)
#           9200  ???:free
#           8000  build/../c10/util/SmallVector.h:a ... IteratorBase::get_strides() const
#
#        Total: 173472
#

###############################################################################
# That's still quite a lot to digest. Let's use the ``FunctionCounts.transform``
# method to trim some of the function path, and discard the called function.
# When we do, the counts of any collisions (e.g. ``foo.h:a()`` and ``foo.h:b()``
# will both map to ``foo.h``) will be added together.
#

import os
import re

def group_by_file(fn_name: str):
    if fn_name.startswith("???"):
        fn_dir, fn_file = fn_name.split(":")[:2]
    else:
        fn_dir, fn_file = os.path.split(fn_name.split(":")[0])
        fn_dir = re.sub("^.*build/../", "", fn_dir)
        fn_dir = re.sub("^.*torch/", "torch/", fn_dir)

    return f"{fn_dir:<15} {fn_file}"

print(inclusive_stats.transform(group_by_file)[:10])

###############################################################################
# .. code-block:: none
#    :caption: **Callgrind stats (condensed)**
#
#        <torch.utils.benchmark.utils.valgrind_wrapper.timer_interface.FunctionCounts object at 0x7f192995d750>
#          118200  aten/src/ATen   TensorIterator.cpp
#           65000  c10/util        SmallVector.h
#           47264  ???             _int_free
#           25963  ???             _int_malloc
#           20900  c10/util        intrusive_ptr.h
#           18000  ???             __tls_get_addr
#           15900  c10/core        TensorImpl.h
#           15100  c10/core        CPUAllocator.cpp
#           13500  ???             malloc
#           12500  c10/core        TensorImpl.cpp
#
#        Total: 352327
#
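###############################################################################
# A small sketch that is not part of the original recipe: ``FunctionCounts``
# also supports ``len()``, iteration, and slicing, which makes ad hoc
# aggregation straightforward. The per-entry attributes ``count`` and
# ``function`` used below are an assumption about the record layout; check
# them against your PyTorch version if this snippet fails.

condensed_counts = inclusive_stats.transform(group_by_file)
unresolved = sum(c.count for c in condensed_counts if c.function.startswith("???"))
print(f"{len(condensed_counts)} grouped entries; "
      f"{unresolved} instructions in unresolved (???) symbols")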
###############################################################################
# 6. A/B testing with ``Callgrind``
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
#
# One of the most useful features of instruction counts is that they allow fine
# grained comparison of computation, which is critical when analyzing
# performance.
#
# To see this in action, let's compare our multiplication of two size 128
# Tensors with a {128} x {1} multiplication, which will broadcast the second
# Tensor:
#   result = {a0 * b0, a1 * b0, ..., a127 * b0}
#

broadcasting_stats = Timer(
    "x * y;",
    """
        auto x = torch::ones({128});
        auto y = torch::ones({1});
    """,
    language=Language.CPP,
).collect_callgrind().as_standardized().stats(inclusive=False)

###############################################################################
# Often we want to A/B test two different environments (e.g. testing a PR, or
# experimenting with compile flags). This is quite simple, as ``CallgrindStats``,
# ``FunctionCounts``, and ``Measurement`` are all pickleable. Simply save
# measurements from each environment, and load them in a single process for
# analysis.
#

import pickle

# Let's round trip `broadcasting_stats` just to show that we can.
broadcasting_stats = pickle.loads(pickle.dumps(broadcasting_stats))


# And now to diff the two tasks:
delta = broadcasting_stats - inclusive_stats

def extract_fn_name(fn: str):
    """Trim everything except the function name."""
    fn = ":".join(fn.split(":")[1:])
    return re.sub(r"\(.+\)", "(...)", fn)

# We use `.transform` to make the diff readable:
print(delta.transform(extract_fn_name))


###############################################################################
# .. code-block:: none
#    :caption: **Instruction count delta**
#
#        <torch.utils.benchmark.utils.valgrind_wrapper.timer_interface.FunctionCounts object at 0x7f192995d750>
#            17600  at::TensorIteratorBase::compute_strides(...)
#            12700  at::TensorIteratorBase::allocate_or_resize_outputs()
#            10200  c10::SmallVectorImpl<long>::operator=(...)
#             7400  at::infer_size(...)
#             6200  at::TensorIteratorBase::invert_perm(...) const
#             6064  _int_free
#             5100  at::TensorIteratorBase::reorder_dimensions()
#             4300  malloc
#             4300  at::TensorIteratorBase::compatible_stride(...) const
#              ...
#              -28  _int_memalign
#             -100  c10::impl::check_tensor_options_and_extract_memory_format(...)
#             -300  __memcmp_avx2_movbe
#             -400  at::detail::empty_cpu(...)
#            -1100  at::TensorIteratorBase::numel() const
#            -1300  void at::native::(...)
#            -2400  c10::TensorImpl::is_contiguous(...) const
#            -6100  at::TensorIteratorBase::compute_fast_setup_type(...)
#           -22600  at::TensorIteratorBase::fast_set_up(...)
#
#        Total: 58091
#

###############################################################################
# So the broadcasting version takes an extra 580 instructions per call (recall
# that we're collecting 100 runs per sample), or about 10%. There are quite a
# few ``TensorIterator`` calls, so let's drill down to those.
# ``FunctionCounts.filter`` makes this easy.
#

print(delta.transform(extract_fn_name).filter(lambda fn: "TensorIterator" in fn))

###############################################################################
# .. code-block:: none
#    :caption: **Instruction count delta (filter)**
#
#        <torch.utils.benchmark.utils.valgrind_wrapper.timer_interface.FunctionCounts object at 0x7f19299544d0>
#            17600  at::TensorIteratorBase::compute_strides(...)
#            12700  at::TensorIteratorBase::allocate_or_resize_outputs()
#             6200  at::TensorIteratorBase::invert_perm(...) const
#             5100  at::TensorIteratorBase::reorder_dimensions()
#             4300  at::TensorIteratorBase::compatible_stride(...) const
#             4000  at::TensorIteratorBase::compute_shape(...)
#             2300  at::TensorIteratorBase::coalesce_dimensions()
#             1600  at::TensorIteratorBase::build(...)
#            -1100  at::TensorIteratorBase::numel() const
#            -6100  at::TensorIteratorBase::compute_fast_setup_type(...)
#           -22600  at::TensorIteratorBase::fast_set_up(...)
#
#        Total: 24000
#

###############################################################################
# This makes plain what is going on: there is a fast path in ``TensorIterator``
# setup, but in the {128} x {1} case we miss it and have to do a more general
# analysis which is more expensive. The most prominent call omitted by the
# filter is ``c10::SmallVectorImpl<long>::operator=(...)``, which is also part
# of the more general setup.
#
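###############################################################################
# A sketch of the cross-environment workflow described above (the file names
# and the idea of running the same collection script under two different
# builds are illustrative assumptions, not part of the original recipe): in
# each environment, collect stats and dump them to disk; then, in a single
# analysis process, load both files and subtract one ``FunctionCounts`` from
# the other exactly as done for ``delta`` above.
#
# In environment A (and likewise in environment B, with a different file name):
#
#   with open("env_a_counts.pkl", "wb") as f:
#       pickle.dump(inclusive_stats, f)
#
# In the analysis process:
#
#   with open("env_a_counts.pkl", "rb") as f:
#       env_a_counts = pickle.load(f)
#   with open("env_b_counts.pkl", "rb") as f:
#       env_b_counts = pickle.load(f)
#   print((env_b_counts - env_a_counts).transform(extract_fn_name))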
###############################################################################
# 7. Wrapping up
# ~~~~~~~~~~~~~~
#
# In summary, use `Timer.blocked_autorange` to collect wall times. If timing
# variation is too high, increase `min_run_time`, or move to C++ snippets if
# convenient.
#
# For fine grained analysis, use `Timer.collect_callgrind` to measure
# instruction counts and `FunctionCounts.(__add__ / __sub__ / transform / filter)`
# to slice-and-dice them.
#

###############################################################################
# 8. Footnotes
# ~~~~~~~~~~~~
#
# - Implied ``import torch``
#     If `globals` does not contain "torch", Timer will automatically
#     populate it. This means that ``Timer("torch.empty(())")`` will work.
#     (Though other imports should be placed in `setup`,
#     e.g. ``Timer("np.zeros(())", "import numpy as np")``)
#
# - ``REL_WITH_DEB_INFO``
#     In order to provide full information about the PyTorch internals which
#     are executed, ``Callgrind`` needs access to C++ debug symbols. This is
#     accomplished by setting ``REL_WITH_DEB_INFO=1`` when building PyTorch.
#     Otherwise function calls will be opaque. (The resultant ``CallgrindStats``
#     will warn if debug symbols are missing.)