"""
PyTorch Benchmark
====================================
This recipe provides a quick-start guide to using the PyTorch
``benchmark`` module to measure and compare code performance.

Introduction
------------
Benchmarking is an important step in writing code. It helps
us validate that our code meets performance expectations,
compare different approaches to solving the same problem, and
prevent performance regressions.

There are many options when it comes to benchmarking PyTorch code,
including the Python builtin ``timeit`` module. However, benchmarking
PyTorch code has many caveats that can be easily overlooked, such as
managing the number of threads and synchronizing CUDA devices. Moreover,
generating Tensor inputs for benchmarking can be quite tedious.

This recipe demonstrates how to use the PyTorch ``benchmark`` module to avoid
common mistakes while making it easier to compare the performance of
different code, generate inputs for benchmarking, and more.

Setup
-----
Before we begin, install ``torch`` if it isn't already available.

::

   pip install torch

"""


######################################################################
# Steps
# -----
#
# 1. Defining functions to benchmark
# 2. Benchmarking with ``timeit.Timer``
# 3. Benchmarking with ``torch.utils.benchmark.Timer``
# 4. Benchmarking with ``Blocked Autorange``
# 5. Comparing benchmark results
# 6. Saving/Loading benchmark results
# 7. Generating inputs with ``Fuzzed Parameters``
# 8. Collecting instruction counts with ``Callgrind``
#
# 1. Defining functions to benchmark
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
#
# As of the time of this writing, `torch.dot <https://pytorch.org/docs/stable/generated/torch.dot.html?highlight=dot#torch.dot>`__
# does not support batched mode, so we will compare two approaches to
# implementing it using existing ``torch`` operators: one approach uses a
# combination of ``mul`` and ``sum``, while the other reduces the problem to ``bmm``.
#

import torch


def batched_dot_mul_sum(a, b):
    '''Computes batched dot by multiplying and summing'''
    return a.mul(b).sum(-1)


def batched_dot_bmm(a, b):
    '''Computes batched dot by reducing to ``bmm``'''
    a = a.reshape(-1, 1, a.shape[-1])
    b = b.reshape(-1, b.shape[-1], 1)
    return torch.bmm(a, b).flatten(-3)


# Input for benchmarking
x = torch.randn(10000, 64)

# Ensure that both functions compute the same output
assert batched_dot_mul_sum(x, x).allclose(batched_dot_bmm(x, x))

######################################################################
# 2. Benchmarking with ``timeit.Timer``
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
#
# First, let's benchmark the code using Python's builtin ``timeit`` module.
# We keep the benchmark code simple here so we can compare the defaults
# of ``timeit`` and ``torch.utils.benchmark``.
#

import timeit

t0 = timeit.Timer(
    stmt='batched_dot_mul_sum(x, x)',
    setup='from __main__ import batched_dot_mul_sum',
    globals={'x': x})

t1 = timeit.Timer(
    stmt='batched_dot_bmm(x, x)',
    setup='from __main__ import batched_dot_bmm',
    globals={'x': x})

print(f'mul_sum(x, x):  {t0.timeit(100) / 100 * 1e6:>5.1f} us')
print(f'bmm(x, x):      {t1.timeit(100) / 100 * 1e6:>5.1f} us')

######################################################################
# .. code-block:: none
#    :caption: Output
#
#    mul_sum(x, x):  111.6 us
#    bmm(x, x):       70.0 us
#

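######################################################################
# With plain ``timeit`` it is also common to take several repeats and report
# the best (minimum) time to reduce the influence of other processes. The
# snippet below is a minimal sketch of that pattern using
# ``timeit.Timer.repeat``; the repeat and run counts are arbitrary choices,
# not values from this recipe.

raw_times = t0.repeat(repeat=5, number=100)   # five measurements of 100 runs each
best_us = min(raw_times) / 100 * 1e6
print(f'mul_sum(x, x) best of 5: {best_us:>5.1f} us')
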

######################################################################
# 3. Benchmarking with ``torch.utils.benchmark.Timer``
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
#
# The PyTorch ``benchmark`` module was designed to be familiar to those who
# have used the ``timeit`` module before. However, its defaults make it
# easier and safer to use for benchmarking PyTorch code. Let's first
# compare the same basic API as above.
#

import torch.utils.benchmark as benchmark

t0 = benchmark.Timer(
    stmt='batched_dot_mul_sum(x, x)',
    setup='from __main__ import batched_dot_mul_sum',
    globals={'x': x})

t1 = benchmark.Timer(
    stmt='batched_dot_bmm(x, x)',
    setup='from __main__ import batched_dot_bmm',
    globals={'x': x})

print(t0.timeit(100))
print(t1.timeit(100))

######################################################################
# .. code-block:: none
#    :caption: Output
#
#    <torch.utils.benchmark.utils.common.Measurement object at 0x7fb10400d0f0>
#    batched_dot_mul_sum(x, x)
#    setup: from __main__ import batched_dot_mul_sum
#      379.29 us
#      1 measurement, 100 runs , 1 thread
#    <torch.utils.benchmark.utils.common.Measurement object at 0x7fb103d67048>
#    batched_dot_bmm(x, x)
#    setup: from __main__ import batched_dot_bmm
#      716.42 us
#      1 measurement, 100 runs , 1 thread
#

153
######################################################################
154
# Even though the APIs are the same for the basic functionality, there
155
# are some important differences. ``benchmark.Timer.timeit()`` returns the
156
# time per run as opposed to the total runtime like ``timeit.Timer.timeit()``
157
# does. PyTorch ``benchmark`` module also provides formatted string
158
# representations for printing the results.
159
#
160
# Another important difference, and the reason why the results diverge
161
# is that PyTorch benchmark module runs in a single thread by default.
162
# We can change the number of threads with the ``num_threads`` argument.
163
#
164
# ``torch.utils.benchmark.Timer`` takes several additional arguments
165
# including: ``label``, ``sub_label``, ``description`` and ``env`` which change
166
# the __repr__ of the measurement object returned and are used for
167
# grouping the results (more on this later).
168
#
169
170
num_threads = torch.get_num_threads()
print(f'Benchmarking on {num_threads} threads')

t0 = benchmark.Timer(
    stmt='batched_dot_mul_sum(x, x)',
    setup='from __main__ import batched_dot_mul_sum',
    globals={'x': x},
    num_threads=num_threads,
    label='Multithreaded batch dot',
    sub_label='Implemented using mul and sum')

t1 = benchmark.Timer(
    stmt='batched_dot_bmm(x, x)',
    setup='from __main__ import batched_dot_bmm',
    globals={'x': x},
    num_threads=num_threads,
    label='Multithreaded batch dot',
    sub_label='Implemented using bmm')

print(t0.timeit(100))
print(t1.timeit(100))

######################################################################
# .. code-block:: none
#    :caption: Output
#
#    Benchmarking on 40 threads
#    <torch.utils.benchmark.utils.common.Measurement object at 0x7fb103d54080>
#    Multithreaded batch dot: Implemented using mul and sum
#    setup: from __main__ import batched_dot_mul_sum
#      118.47 us
#      1 measurement, 100 runs , 40 threads
#    <torch.utils.benchmark.utils.common.Measurement object at 0x7fb16935d2e8>
#    Multithreaded batch dot: Implemented using bmm
#    setup: from __main__ import batched_dot_bmm
#      68.21 us
#      1 measurement, 100 runs , 40 threads

######################################################################
# Running ``benchmark`` with all threads available gives results similar
# to the ``timeit`` module. More importantly, which version is faster
# depends on how many threads we run the code with. This is why it's
# important to benchmark the code with thread settings that are
# representative of real use cases. Another important thing to remember
# is to synchronize CPU and CUDA when benchmarking on the GPU. Let's run
# the above benchmarks again on a CUDA tensor and see what happens.
#

x = torch.randn(10000, 1024, device='cuda')

t0 = timeit.Timer(
    stmt='batched_dot_mul_sum(x, x)',
    setup='from __main__ import batched_dot_mul_sum',
    globals={'x': x})

t1 = timeit.Timer(
    stmt='batched_dot_bmm(x, x)',
    setup='from __main__ import batched_dot_bmm',
    globals={'x': x})

# Run each twice to show the difference before/after warm-up
print(f'mul_sum(x, x):  {t0.timeit(100) / 100 * 1e6:>5.1f} us')
print(f'mul_sum(x, x):  {t0.timeit(100) / 100 * 1e6:>5.1f} us')
print(f'bmm(x, x):      {t1.timeit(100) / 100 * 1e6:>5.1f} us')
print(f'bmm(x, x):      {t1.timeit(100) / 100 * 1e6:>5.1f} us')

######################################################################
# .. code-block:: none
#    :caption: Output
#
#    mul_sum(x, x):    27.6 us
#    mul_sum(x, x):    25.3 us
#    bmm(x, x):      2775.5 us
#    bmm(x, x):        22.4 us
#

t0 = benchmark.Timer(
    stmt='batched_dot_mul_sum(x, x)',
    setup='from __main__ import batched_dot_mul_sum',
    globals={'x': x})

t1 = benchmark.Timer(
    stmt='batched_dot_bmm(x, x)',
    setup='from __main__ import batched_dot_bmm',
    globals={'x': x})

# Run only once since the benchmark module does the warm-up for us
print(t0.timeit(100))
print(t1.timeit(100))

######################################################################
# .. code-block:: none
#    :caption: Output
#
#    <torch.utils.benchmark.utils.common.Measurement object at 0x7fb10400d080>
#    batched_dot_mul_sum(x, x)
#    setup: from __main__ import batched_dot_mul_sum
#      232.93 us
#      1 measurement, 100 runs , 1 thread
#    <torch.utils.benchmark.utils.common.Measurement object at 0x7fb10400d0f0>
#    batched_dot_bmm(x, x)
#    setup: from __main__ import batched_dot_bmm
#      181.04 us
#      1 measurement, 100 runs , 1 thread
#

######################################################################
# The results reveal something interesting. The first run of the ``bmm``
# version using the ``timeit`` module takes much longer than the second
# run. This is because ``bmm`` calls into ``cuBLAS``, which needs to be
# loaded the first time it's called, and that loading takes some time. This
# is why it's important to do a warm-up run before benchmarking; luckily for
# us, PyTorch's ``benchmark`` module takes care of that.
#
# The difference in the results between the ``timeit`` and ``benchmark`` modules
# is because the ``timeit`` module is not synchronizing CUDA and is therefore only
# timing the time to launch the kernel. PyTorch's ``benchmark`` module does
# the synchronization for us.

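######################################################################
# If you do stick with raw ``timeit`` for GPU code, the usual workaround is to
# synchronize manually inside the timed statement. The snippet below is a
# minimal sketch of that idea, reusing the CUDA tensor ``x`` from above; with
# the ``benchmark`` module this extra step is unnecessary.

if torch.cuda.is_available():
    t0_sync = timeit.Timer(
        stmt='batched_dot_mul_sum(x, x); torch.cuda.synchronize()',
        setup='from __main__ import batched_dot_mul_sum',
        globals={'x': x, 'torch': torch})
    t0_sync.timeit(10)  # warm-up
    print(f'mul_sum(x, x) with manual sync: {t0_sync.timeit(100) / 100 * 1e6:>5.1f} us')
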
######################################################################
# 4. Benchmarking with ``Blocked Autorange``
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
#
# While ``timeit.Timer.autorange`` takes a single continuous measurement
# of at least 0.2 seconds, ``torch.utils.benchmark.Timer.blocked_autorange``
# takes many measurements whose times total at least 0.2 seconds (which
# can be changed by the ``min_run_time`` parameter), subject to the constraint
# that timing overhead is a small fraction of the overall measurement.
# This is accomplished by first running with an increasing number of runs
# per loop until the runtime is much larger than the measurement overhead
# (which also serves as a warm-up), and then taking measurements until
# the target time is reached. This has the useful properties of wasting
# less data and allowing us to compute statistics to estimate the reliability
# of the measurements.
#

m0 = t0.blocked_autorange()
m1 = t1.blocked_autorange()

print(m0)
print(m1)

######################################################################
# .. code-block:: none
#    :caption: Output
#
#    <torch.utils.benchmark.utils.common.Measurement object at 0x7fb10400d0f0>
#    batched_dot_mul_sum(x, x)
#    setup: from __main__ import batched_dot_mul_sum
#      231.79 us
#      1 measurement, 1000 runs , 1 thread
#    <torch.utils.benchmark.utils.common.Measurement object at 0x7fb10400d080>
#    batched_dot_bmm(x, x)
#    setup: from __main__ import batched_dot_bmm
#      Median: 162.08 us
#      2 measurements, 1000 runs per measurement, 1 thread
#

######################################################################
# We can also inspect the individual statistics from the returned
# measurement object.

print(f"Mean:   {m0.mean * 1e6:6.2f} us")
print(f"Median: {m0.median * 1e6:6.2f} us")

######################################################################
# .. code-block:: none
#    :caption: Output
#
#    Mean:   231.79 us
#    Median: 231.79 us
#

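######################################################################
# Beyond the mean and median, a ``Measurement`` keeps the per-block times it
# collected, which can be used to gauge the spread of the results. The sketch
# below uses ``m1`` (which contains two measurements); treat the attribute
# names (``times``, ``iqr``) as assumptions that may vary by PyTorch version.

print(f"Number of measurements: {len(m1.times)}")
print(f"IQR: {m1.iqr * 1e6:6.2f} us")  # interquartile range of the collected times
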
######################################################################
# 5. Comparing benchmark results
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
#
# So far we've been comparing our two versions of batched dot against a
# single input. In practice, we want to try a combination of inputs as
# well as different numbers of threads. The ``Compare`` class helps display
# the results of many measurements in a formatted table. It uses the
# annotations described above (``label``, ``sub_label``, ``num_threads``, etc.) as
# well as ``description`` to group and organize the table. Let's use
# ``Compare`` to see how our functions perform for different input sizes
# and numbers of threads.
#

from itertools import product

# Compare takes a list of measurements which we'll save in results.
results = []

sizes = [1, 64, 1024, 10000]
for b, n in product(sizes, sizes):
    # label and sub_label are the rows
    # description is the column
    label = 'Batched dot'
    sub_label = f'[{b}, {n}]'
    x = torch.ones((b, n))
    for num_threads in [1, 4, 16, 32]:
        results.append(benchmark.Timer(
            stmt='batched_dot_mul_sum(x, x)',
            setup='from __main__ import batched_dot_mul_sum',
            globals={'x': x},
            num_threads=num_threads,
            label=label,
            sub_label=sub_label,
            description='mul/sum',
        ).blocked_autorange(min_run_time=1))
        results.append(benchmark.Timer(
            stmt='batched_dot_bmm(x, x)',
            setup='from __main__ import batched_dot_bmm',
            globals={'x': x},
            num_threads=num_threads,
            label=label,
            sub_label=sub_label,
            description='bmm',
        ).blocked_autorange(min_run_time=1))

compare = benchmark.Compare(results)
compare.print()

######################################################################
# .. code-block:: none
#    :caption: Output
#
#    [--------------- Batched dot ----------------]
#                          |  mul/sum   |    bmm
#    1 threads: -----------------------------------
#          [1, 1]          |       5.9  |      11.2
#          [1, 64]         |       6.4  |      11.4
#          [1, 1024]       |       6.7  |      14.2
#          [1, 10000]      |      10.2  |      23.7
#          [64, 1]         |       6.3  |      11.5
#          [64, 64]        |       8.6  |      15.4
#          [64, 1024]      |      39.4  |     204.4
#          [64, 10000]     |     274.9  |     748.5
#          [1024, 1]       |       7.7  |      17.8
#          [1024, 64]      |      40.3  |      76.4
#          [1024, 1024]    |     432.4  |    2795.9
#          [1024, 10000]   |   22657.3  |   11899.5
#          [10000, 1]      |      16.9  |      74.8
#          [10000, 64]     |     300.3  |     609.4
#          [10000, 1024]   |   23098.6  |   27246.1
#          [10000, 10000]  |  267073.7  |  118823.7
#    4 threads: -----------------------------------
#          [1, 1]          |       6.0  |      11.5
#          [1, 64]         |       6.2  |      11.2
#          [1, 1024]       |       6.8  |      14.3
#          [1, 10000]      |      10.2  |      23.7
#          [64, 1]         |       6.3  |      16.2
#          [64, 64]        |       8.8  |      18.2
#          [64, 1024]      |      41.5  |     189.1
#          [64, 10000]     |      91.7  |     849.1
#          [1024, 1]       |       7.6  |      17.4
#          [1024, 64]      |      43.5  |      33.5
#          [1024, 1024]    |     135.4  |    2782.3
#          [1024, 10000]   |    7471.1  |   11874.0
#          [10000, 1]      |      16.8  |      33.9
#          [10000, 64]     |     118.7  |     173.2
#          [10000, 1024]   |    7264.6  |   27824.7
#          [10000, 10000]  |  100060.9  |  121499.0
#    16 threads: ----------------------------------
#          [1, 1]          |       6.0  |      11.3
#          [1, 64]         |       6.2  |      11.2
#          [1, 1024]       |       6.9  |      14.2
#          [1, 10000]      |      10.3  |      23.8
#          [64, 1]         |       6.4  |      24.1
#          [64, 64]        |       9.0  |      23.8
#          [64, 1024]      |      54.1  |     188.5
#          [64, 10000]     |      49.9  |     748.0
#          [1024, 1]       |       7.6  |      23.4
#          [1024, 64]      |      55.5  |      28.2
#          [1024, 1024]    |      66.9  |    2773.9
#          [1024, 10000]   |    6111.5  |   12833.7
#          [10000, 1]      |      16.9  |      27.5
#          [10000, 64]     |      59.5  |      73.7
#          [10000, 1024]   |    6295.9  |   27062.0
#          [10000, 10000]  |   71804.5  |  120365.8
#    32 threads: ----------------------------------
#          [1, 1]          |       5.9  |      11.3
#          [1, 64]         |       6.2  |      11.3
#          [1, 1024]       |       6.7  |      14.2
#          [1, 10000]      |      10.5  |      23.8
#          [64, 1]         |       6.3  |      31.7
#          [64, 64]        |       9.1  |      30.4
#          [64, 1024]      |      72.0  |     190.4
#          [64, 10000]     |     103.1  |     746.9
#          [1024, 1]       |       7.6  |      28.4
#          [1024, 64]      |      70.5  |      31.9
#          [1024, 1024]    |      65.6  |    2804.6
#          [1024, 10000]   |    6764.0  |   11871.4
#          [10000, 1]      |      17.8  |      31.8
#          [10000, 64]     |     110.3  |      56.0
#          [10000, 1024]   |    6640.2  |   27592.2
#          [10000, 10000]  |   73003.4  |  120083.2
#
#    Times are in microseconds (us).
#

######################################################################
# The results above indicate that the version which reduces to ``bmm``
# is better for larger tensors running on multiple threads, while for
# smaller and/or single-threaded code, the other version is better.
#
# ``Compare`` also provides functions for changing the table format:
#

compare.trim_significant_figures()
compare.colorize()
compare.print()

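######################################################################
# The formatted table is also available as a plain string via ``str(compare)``
# (the same representation used in the round-trip check in the next section),
# which is handy for logging results. A small sketch; the file name below is
# just an example.

with open('batched_dot_results.txt', 'w') as table_file:
    table_file.write(str(compare))
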
######################################################################
# 6. Saving/Loading benchmark results
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
#
# ``Measurements`` (and ``CallgrindStats``, which are described in section 8)
# can be serialized by the ``pickle`` module. This makes A/B testing easy, as you can collect
# measurements from two separate environments, pickle them, and then
# load both in a single environment. ``Timer`` even takes an ``env``
# constructor argument so that such A/B testing works seamlessly.
#
# Let's imagine that rather than two Python functions, the mul/sum
# and ``bmm`` approaches were in two different builds of PyTorch.
# The example below demonstrates how one might A/B test them. For
# simplicity, we only use a subset of shapes, and simply round trip
# results through pickle rather than actually using multiple environments
# and writing results to disk.
#

import pickle

ab_test_results = []
for env in ('environment A: mul/sum', 'environment B: bmm'):
    for b, n in ((1, 1), (1024, 10000), (10000, 1)):
        x = torch.ones((b, n))
        dot_fn = (batched_dot_mul_sum if env == 'environment A: mul/sum' else batched_dot_bmm)
        m = benchmark.Timer(
            stmt='batched_dot(x, x)',
            globals={'x': x, 'batched_dot': dot_fn},
            num_threads=1,
            label='Batched dot',
            description=f'[{b}, {n}]',
            env=env,
        ).blocked_autorange(min_run_time=1)
        ab_test_results.append(pickle.dumps(m))

ab_results = [pickle.loads(i) for i in ab_test_results]
compare = benchmark.Compare(ab_results)
compare.trim_significant_figures()
compare.colorize()
compare.print()

######################################################################
# .. code-block:: none
#    :caption: Output
#
#    [------------------------------------- Batched dot -------------------------------------]
#                                                   |  [1, 1]  |  [1024, 10000]  |  [10000, 1]
#    1 threads: ------------------------------------------------------------------------------
#      (environment A: mul/sum)  batched_dot(x, x)  |     7    |      36000      |      21
#      (environment B: bmm)      batched_dot(x, x)  |    14    |      40000      |      85
#
#    Times are in microseconds (us).
#

# And just to show that we can round trip all of the results from earlier:
round_tripped_results = pickle.loads(pickle.dumps(results))
assert str(benchmark.Compare(results)) == str(benchmark.Compare(round_tripped_results))

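######################################################################
# In a real A/B test, each environment would write its pickled measurements to
# disk and a separate process would load and compare them. The sketch below
# mimics that flow with a temporary file; the file handling is just an
# illustration, not part of the original workflow above.

import tempfile

with tempfile.NamedTemporaryFile(suffix='.pkl', delete=False) as out_file:
    pickle.dump(ab_results, out_file)            # each environment writes its measurements

with open(out_file.name, 'rb') as in_file:
    loaded_results = pickle.load(in_file)        # the analysis environment loads them

benchmark.Compare(loaded_results).print()
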
######################################################################
# 7. Generating inputs with ``Fuzzed Parameters``
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
#
# As we've seen in the previous section, there can be some stark
# performance differences depending on the input tensors. Hence, it
# is a good idea to run benchmarks on a number of different inputs.
# However, creating all these input tensors can be tedious, which is
# where ``torch.utils.benchmark.Fuzzer`` and related classes come in.
# Let's take a look at how we can use the ``Fuzzer`` to create some test
# cases for the benchmark.
#

from torch.utils.benchmark import Fuzzer, FuzzedParameter, FuzzedTensor, ParameterAlias

# Generates random tensors with 128 to 10000000 elements and sizes k0 and k1 chosen from a
# ``loguniform`` distribution in [1, 10000], 40% of which will be discontiguous on average.
example_fuzzer = Fuzzer(
    parameters = [
        FuzzedParameter('k0', minval=1, maxval=10000, distribution='loguniform'),
        FuzzedParameter('k1', minval=1, maxval=10000, distribution='loguniform'),
    ],
    tensors = [
        FuzzedTensor('x', size=('k0', 'k1'), min_elements=128, max_elements=10000000, probability_contiguous=0.6)
    ],
    seed=0,
)

results = []
for tensors, tensor_params, params in example_fuzzer.take(10):
    # description is the column label
    sub_label = f"{params['k0']:<6} x {params['k1']:<4} {'' if tensor_params['x']['is_contiguous'] else '(discontiguous)'}"
    results.append(benchmark.Timer(
        stmt='batched_dot_mul_sum(x, x)',
        setup='from __main__ import batched_dot_mul_sum',
        globals=tensors,
        label='Batched dot',
        sub_label=sub_label,
        description='mul/sum',
    ).blocked_autorange(min_run_time=1))
    results.append(benchmark.Timer(
        stmt='batched_dot_bmm(x, x)',
        setup='from __main__ import batched_dot_bmm',
        globals=tensors,
        label='Batched dot',
        sub_label=sub_label,
        description='bmm',
    ).blocked_autorange(min_run_time=1))

compare = benchmark.Compare(results)
compare.trim_significant_figures()
compare.print()

######################################################################
# .. code-block:: none
#    :caption: Output
#
#    [--------------------- Batched dot ---------------------]
#                                       |  mul/sum  |    bmm
#    1 threads: ----------------------------------------------
#          725    x 257                 |      87   |    180
#          49     x 383                 |      15   |     30
#          34     x 1468                |      30   |    118
#          187    x 5039                |     400   |   1200
#          2140   x 1296 (discontiguous)|    2000   |  41000
#          78     x 1598                |      74   |    310
#          519    x 763                 |     190   |   1500
#          141    x 1082                |      87   |    500
#          78     x 5    (discontiguous)|       9   |     20
#          187    x 1                   |      12   |     10
#
#    Times are in microseconds (us).
#

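######################################################################
# To get a feel for what the fuzzer is producing, you can also iterate over it
# directly and inspect the generated tensors before benchmarking them. A small
# sketch using the same ``example_fuzzer`` defined above:

for tensors, tensor_params, params in example_fuzzer.take(3):
    x = tensors['x']
    print(f"shape: {tuple(x.shape)}, contiguous: {x.is_contiguous()}, "
          f"k0={params['k0']}, k1={params['k1']}")
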

######################################################################
# There is a lot of flexibility for defining your own ``fuzzers``, which
# is great for creating a powerful set of inputs to benchmark. But to
# make things even simpler, the PyTorch benchmark module comes with some
# built-in ``fuzzers`` for common benchmarking needs. Let's take a look at
# how we can use one of these built-in ``fuzzers``.
#

from torch.utils.benchmark.op_fuzzers import binary

results = []
for tensors, tensor_params, params in binary.BinaryOpFuzzer(seed=0).take(10):
    sub_label = f"{params['k0']:<6} x {params['k1']:<4} {'' if tensor_params['x']['is_contiguous'] else '(discontiguous)'}"
    results.append(benchmark.Timer(
        stmt='batched_dot_mul_sum(x, x)',
        setup='from __main__ import batched_dot_mul_sum',
        globals=tensors,
        label='Batched dot',
        sub_label=sub_label,
        description='mul/sum',
    ).blocked_autorange(min_run_time=1))
    results.append(benchmark.Timer(
        stmt='batched_dot_bmm(x, x)',
        setup='from __main__ import batched_dot_bmm',
        globals=tensors,
        label='Batched dot',
        sub_label=sub_label,
        description='bmm',
    ).blocked_autorange(min_run_time=1))

compare = benchmark.Compare(results)
compare.trim_significant_figures()
compare.colorize(rowwise=True)
compare.print()

######################################################################
# .. code-block:: none
#    :caption: Output
#
#    [----------------------- Batched dot ------------------------]
#                                            |  mul/sum  |    bmm
#    1 threads: ---------------------------------------------------
#          64     x 473      (discontiguous) |   10000   |   40000
#          16384  x 12642115 (discontiguous) |      31   |      78
#          8192   x 892                      |    4800   |   20400
#          512    x 64       (discontiguous) |  110000   |  400000
#          493    x 27       (discontiguous) |    1100   |    2440
#          118    x 32       (discontiguous) |     870   |    2030
#          16     x 495      (discontiguous) |   23600   |   24000
#          488    x 62374                    |   90000   |  100000
#          240372 x 69                       |   40000   |   16000
#          40156  x 32       (discontiguous) |    2670   |    5000
#
#    Times are in microseconds (us).
#

######################################################################
674
# 8. Collecting instruction counts with ``Callgrind``
675
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
676
#
677
# One of the challenges of optimizing code is the variation and opacity of
678
# wall time. There are many sources of non-determinism, from adaptive clock
679
# speeds to resource contention with other processes. Furthermore, end-to-end
680
# time gives no insight into where time is being spent, which is really what
681
# we're interested in when optimizing code.
682
#
683
# A complementary approach is to also collect instruction counts. These counts
684
# are a proxy metric and do not capture all aspects of performance
685
# (e.g. memory or I/O bound tasks), however they do have several useful
686
# properties. Instruction counts are reproducible, insensitive to environmental
687
# variation, and offer fine grained insight into where a program is spending
688
# cycles.
689
#
690
# To see the utility of instruction counts, let us look at how we might
691
# reduce the overhead of `batched_dot_mul_sum`. The obvious solution is to
692
# move it to C++, so we avoid going between Python and C++ multiple times.
693
#
694
# Fortunately, the source is nearly identical. One question that we have to ask
695
# in C++ is whether we should take arguments by value or reference.
696
#
697
698
batched_dot_src = """\
699
/* ---- Python ---- */
700
// def batched_dot_mul_sum(a, b):
701
// return a.mul(b).sum(-1)
702
703
torch::Tensor batched_dot_mul_sum_v0(
704
const torch::Tensor a,
705
const torch::Tensor b) {
706
return a.mul(b).sum(-1);
707
}
708
709
torch::Tensor batched_dot_mul_sum_v1(
710
const torch::Tensor& a,
711
const torch::Tensor& b) {
712
return a.mul(b).sum(-1);
713
}
714
"""
715
716
717
# PyTorch makes it easy to test our C++ implementations by providing a utility
718
# to JIT compile C++ source into Python extensions:
719
import os
720
from torch.utils import cpp_extension
721
cpp_lib = cpp_extension.load_inline(
722
name='cpp_lib',
723
cpp_sources=batched_dot_src,
724
extra_cflags=['-O3'],
725
extra_include_paths=[
726
# `load_inline` needs to know where to find ``pybind11`` headers.
727
os.path.join(os.getenv('CONDA_PREFIX'), 'include')
728
],
729
functions=['batched_dot_mul_sum_v0', 'batched_dot_mul_sum_v1']
730
)
731
732
# `load_inline` will create a shared object that is loaded into Python. When we collect
733
# instruction counts Timer will create a subprocess, so we need to re-import it. The
734
# import process is slightly more complicated for C extensions, but that's all we're
735
# doing here.
736
module_import_str = f"""\
737
# https://stackoverflow.com/questions/67631/how-to-import-a-module-given-the-full-path
738
import importlib.util
739
spec = importlib.util.spec_from_file_location("cpp_lib", {repr(cpp_lib.__file__)})
740
cpp_lib = importlib.util.module_from_spec(spec)
741
spec.loader.exec_module(cpp_lib)"""
742
743
import textwrap
744
def pretty_print(result):
745
"""Import machinery for ``cpp_lib.so`` can get repetitive to look at."""
746
print(repr(result).replace(textwrap.indent(module_import_str, " "), " import cpp_lib"))
747
748
749
t_baseline = benchmark.Timer(
750
stmt='batched_dot_mul_sum(x, x)',
751
setup='''\
752
from __main__ import batched_dot_mul_sum
753
x = torch.randn(2, 2)''')
754
755
t0 = benchmark.Timer(
756
stmt='cpp_lib.batched_dot_mul_sum_v0(x, x)',
757
setup=f'''\
758
{module_import_str}
759
x = torch.randn(2, 2)''')
760
761
t1 = benchmark.Timer(
762
stmt='cpp_lib.batched_dot_mul_sum_v1(x, x)',
763
setup=f'''\
764
{module_import_str}
765
x = torch.randn(2, 2)''')
766
767
# Moving to C++ did indeed reduce overhead, but it's hard to tell which
768
# calling convention is more efficient. v1 (call with references) seems to
769
# be a bit faster, but it's within measurement error.
770
pretty_print(t_baseline.blocked_autorange())
771
pretty_print(t0.blocked_autorange())
772
pretty_print(t1.blocked_autorange())
773
######################################################################
# .. code-block:: none
#    :caption: Output
#
#    <torch.utils.benchmark.utils.common.Measurement object at 0x7fb16935d2e8>
#    batched_dot_mul_sum(x, x)
#    setup:
#      from __main__ import batched_dot_mul_sum
#      x = torch.randn(2, 2)
#
#      6.92 us
#      1 measurement, 100000 runs , 1 thread
#    <torch.utils.benchmark.utils.common.Measurement object at 0x7fb16935d2e8>
#    cpp_lib.batched_dot_mul_sum_v0(x, x)
#    setup:
#      import cpp_lib
#      x = torch.randn(2, 2)
#
#      5.29 us
#      1 measurement, 100000 runs , 1 thread
#    <torch.utils.benchmark.utils.common.Measurement object at 0x7fb16935d2e8>
#    cpp_lib.batched_dot_mul_sum_v1(x, x)
#    setup:
#      import cpp_lib
#      x = torch.randn(2, 2)
#
#      5.22 us
#      1 measurement, 100000 runs , 1 thread
#

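######################################################################
# Note that ``collect_callgrind`` runs the statement under Callgrind in a
# subprocess, so it relies on ``valgrind`` being installed on the system. A
# simple up-front check (this guard is illustrative, not part of the recipe's
# original workflow):

import shutil

if shutil.which('valgrind') is None:
    raise RuntimeError('valgrind is required for Timer.collect_callgrind()')
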
# Let's use ``Callgrind`` to determine which is better.
stats_v0 = t0.collect_callgrind()
stats_v1 = t1.collect_callgrind()

pretty_print(stats_v0)
pretty_print(stats_v1)

# `.as_standardized` removes file names and some path prefixes, and makes
# it easier to read the function symbols.
stats_v0 = stats_v0.as_standardized()
stats_v1 = stats_v1.as_standardized()

# `.delta` diffs the instruction counts, and `.denoise` removes several
# functions in the Python interpreter that are known to have significant
# jitter.
delta = stats_v1.delta(stats_v0).denoise()

# `.transform` is a convenience API for transforming function names. It is
# useful for increasing cancellation when diffing instructions, as well as
# just generally improving readability.
replacements = (
    ("???:void pybind11", "pybind11"),
    ("batched_dot_mul_sum_v0", "batched_dot_mul_sum_v1"),
    ("at::Tensor, at::Tensor", "..."),
    ("at::Tensor const&, at::Tensor const&", "..."),
    ("auto torch::detail::wrap_pybind_function_impl_", "wrap_pybind_function_impl_"),
)
for before, after in replacements:
    delta = delta.transform(lambda l: l.replace(before, after))

# We can use print options to control how much of the function to display.
torch.set_printoptions(linewidth=160)

# Once parsed, the instruction counts make clear that passing `a` and `b`
# by reference is more efficient as it skips some ``c10::TensorImpl`` bookkeeping
# for the intermediate Tensors, and also works better with ``pybind11``. This
# is consistent with our noisy wall time observations.
print(delta)

######################################################################
# .. code-block:: none
#
#    <torch.utils.benchmark.utils.valgrind_wrapper.timer_interface.CallgrindStats object at 0x7fb0f06e7630>
#    cpp_lib.batched_dot_mul_sum_v0(x, x)
#    setup:
#      import cpp_lib
#      x = torch.randn(2, 2)
#                               All          Noisy symbols removed
#        Instructions:      2392671                2392671
#        Baseline:             4367                   4367
#    100 runs per measurement, 1 thread
#    Warning: PyTorch was not built with debug symbols.
#             Source information may be limited. Rebuild with
#             REL_WITH_DEB_INFO=1 for more detailed results.
#    <torch.utils.benchmark.utils.valgrind_wrapper.timer_interface.CallgrindStats object at 0x7fb10400d208>
#    cpp_lib.batched_dot_mul_sum_v1(x, x)
#    setup:
#      import cpp_lib
#      x = torch.randn(2, 2)
#                               All          Noisy symbols removed
#        Instructions:      2378978                2378978
#        Baseline:             4367                   4367
#    100 runs per measurement, 1 thread
#    Warning: PyTorch was not built with debug symbols.
#             Source information may be limited. Rebuild with
#             REL_WITH_DEB_INFO=1 for more detailed results.
#    <torch.utils.benchmark.utils.valgrind_wrapper.timer_interface.FunctionCounts object at 0x7fb1000ab358>
#          86  ???:0x000000000020d9e0
#          56  ???:0x000000000020db10
#       -1100  pybind11::cpp_function::initialize<wrap_pybind_function_impl_<at::Tensor ... r (&)(...), std::integer_sequence<unsigned long, 0ul, 1ul>)::{lambda(...)
#       -1600  ???:wrap_pybind_function_impl_<at::Tensor (&)(...), 0ul, 1ul>(at::Tensor (&)(...), std::integer_sequence<unsigned long, 0ul, 1ul>)::{lambda(...)
#       -5200  ???:c10::intrusive_ptr<c10::TensorImpl, c10::UndefinedTensorImpl>::reset_()
#       -5935  ???:0x000000000022c0e0
#    Total: -13693
#

######################################################################
# Learn More
# ----------
#
# Take a look at these other recipes to continue your learning:
#
# - `PyTorch Profiler <https://pytorch.org/tutorials/recipes/recipes/profiler.html>`_
#