GitHub Repository: pytorch/tutorials
Path: blob/main/advanced_source/cuda_graph_annotations_tutorial.py
# -*- coding: utf-8 -*-
"""
.. _cuda-graph-annotations-tutorial:

CUDA Graph Kernel Annotations and Profiling
============================================

**Author**: `Shangdi Yu <https://github.com/yushangdi>`_

.. grid:: 2

    .. grid-item-card:: :octicon:`mortar-board;1em;` What you will learn
       :class-card: card-prerequisites

       * How to capture CUDA graphs with kernel annotations
       * How to profile annotated graphs
       * How to post-process traces with semantic kernel lanes
       * How to visualize graph execution with custom stream assignments
       * How to annotate communication collectives with the metadata
         (collective type, message size, group, rank) that eager NCCL
         traces expose but CUDA graphs drop

    .. grid-item-card:: :octicon:`list-unordered;1em;` Prerequisites
       :class-card: card-prerequisites

       * PyTorch 2.12+
       * CUDA-capable GPU
       * Driver/CUDA-compat >= 13.1 for annotation support
       * cuda-bindings >= 13.1.0
       * perfetto (``pip install perfetto``)

CUDA graphs are a powerful optimization technique that can significantly reduce
kernel launch overhead by capturing and replaying sequences of CUDA operations.
However, when profiling CUDA graphs, all kernels appear on the same stream,
making it difficult to understand the logical structure of your computation.

This tutorial demonstrates how to use **kernel annotations** to add semantic
labels to kernels within CUDA graphs. These annotations can be merged back into
profiler traces to create custom visualization lanes, making it easier to
understand and debug complex graph executions.

Annotations are not limited to compute kernels. One of the most valuable uses
is annotating **communication collectives**. In eager mode, the profiler
attaches rich metadata to every NCCL kernel -- the collective type, message
size, process group, and ranks -- so you can see exactly what each comm is
doing. Under CUDA graphs that metadata is lost: the collective replays as an
opaque kernel. This tutorial shows how to re-attach that metadata with
annotations so graphed comms read just like eager ones.
"""

###############################################################################
# Overview
# --------
#
# CUDA graph kernel annotations allow you to add semantic labels to kernels
# during graph capture. These labels help you understand what each kernel does
# when profiling, making it easy to identify which parts of your model (e.g.,
# attention, MLP, normalization) are executing at any given time.
#
# Without annotations, profiler traces show all kernels on a single stream with
# auto-generated names, making it difficult to understand the logical structure
# of your computation. With annotations, you can:
#
# 1. **Label kernel groups** with meaningful names during capture
# 2. **Assign custom stream IDs** for visual organization
# 3. **Merge labels into profiler traces** for semantic visualization
#
# The result is a profiler trace where kernels are labeled and organized by
# their function, making it much easier to identify performance bottlenecks
# and understand execution flow.
#
# **Before annotations:** All kernels appear on a single stream with
# auto-generated names, making it difficult to understand which operations
# belong to which logical component of your model.
#
# .. image:: /_static/img/cuda_graph_trace_before.png
#    :width: 80%
#    :alt: CUDA graph trace before annotations showing all kernels on one stream
#
# **After annotations:** Kernels are organized into semantic lanes (streams 61
# and 62) with meaningful labels like "attention" and "mlp", making it easy to
# identify different components and understand the execution structure.
#
# .. image:: /_static/img/cuda_graph_trace_after.png
#    :width: 80%
#    :alt: CUDA graph trace after annotations showing kernels organized by function
#
# As another example, here is an AllReduce kernel with annotated metadata:
#
# .. image:: /_static/img/annotated_cudagraph.png
#    :width: 80%
#    :alt: AllReduce kernel with annotated metadata
#
# Requirements
# ------------
#
# For this tutorial, you'll need:
#
# - PyTorch 2.12+
# - A CUDA GPU
# - Driver/CUDA-compat >= 13.1 for annotation support
# - The ``cuda-bindings`` package >= 13.1.0 (``pip install cuda-python``)
# - The ``perfetto`` package for writing the trace (``pip install perfetto``)
#
# The cuda-bindings package provides the Python bindings for CUDA runtime APIs.
# Version 13.1.0+ is required for the ``cudaGraphNodeGetToolsId`` API that
# enables kernel annotations. If you have an older version, the tutorial will
# run but annotations will be disabled with a warning message explaining how
# to upgrade.
#
# On older drivers or cuda-bindings versions, the capture and profiling will
# still work, but ``mark_kernels`` will be a no-op and no semantic lanes will
# appear in the final trace.
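
###############################################################################
# As a rough pre-flight check, the sketch below compares the installed
# ``cuda-bindings`` version against the 13.1.0 minimum listed above. The helper
# names ``_version_tuple`` and ``_cuda_bindings_is_new_enough`` are hypothetical
# (they are not part of PyTorch or cuda-python), and the check only looks at
# the Python package. It says nothing about the driver, so treat it as a
# convenience and not as a replacement for PyTorch's own availability check.

import re
from importlib import metadata


def _version_tuple(version):
    """Leading numeric components of a version string, padded to three."""
    parts = [int(p) for p in re.findall(r"\d+", version)[:3]]
    return tuple(parts + [0] * (3 - len(parts)))


def _cuda_bindings_is_new_enough(minimum=(13, 1, 0)):
    """True if the ``cuda-bindings`` distribution is installed and >= minimum."""
    try:
        installed = metadata.version("cuda-bindings")
    except metadata.PackageNotFoundError:
        # Package not installed at all: annotations cannot be enabled.
        return False
    return _version_tuple(installed) >= minimum


print(f"cuda-bindings >= 13.1.0: {_cuda_bindings_is_new_enough()}")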

import copy
import hashlib
import json
import math
import os
import pickle
import sys
from collections import Counter, defaultdict
from pathlib import Path

import torch
import torch.distributed as dist
import torch.multiprocessing
from torch.profiler import profile, ProfilerActivity
from torch.cuda._graph_annotations import (
    get_kernel_annotations,
    get_stream_for_pg,
    mark_kernels,
    _is_tools_id_unavailable,
)
from torch.cuda._annotate_cuda_graph_trace import (
    annotate_trace,
    load_trace,
)

###############################################################################
# Building a Model
# ----------------
#
# Let's create a simple transformer block as our example model. We'll annotate
# different parts of the computation (QKV projection, attention, output
# projection, MLP) to see them as separate lanes in the profiler.

def build_transformer_block():
    """Create a simple transformer block with parameters."""
    device = "cuda"
    torch.manual_seed(0)

    # Model dimensions
    batch_size, seq_len, dim, num_heads = 4, 256, 1024, 8
    head_dim = dim // num_heads

    # Initialize parameters
    params = {
        "x": torch.randn(batch_size, seq_len, dim, device=device),
        "Wqkv": torch.randn(dim, 3 * dim, device=device) / math.sqrt(dim),
        "Wo": torch.randn(dim, dim, device=device) / math.sqrt(dim),
        "W1": torch.randn(dim, 4 * dim, device=device) / math.sqrt(dim),
        "W2": torch.randn(4 * dim, dim, device=device) / math.sqrt(4 * dim),
    }

    def forward():
        """Forward pass with annotated regions."""
        B, T, D, H = batch_size, seq_len, dim, num_heads
        hd = head_dim

        # Annotate QKV projection
        with mark_kernels({"name": "qkv_proj"}):
            qkv = params["x"] @ params["Wqkv"]

        # Reshape for multi-head attention
        q, k, v = qkv.split(D, dim=-1)
        q = q.view(B, T, H, hd).transpose(1, 2)
        k = k.view(B, T, H, hd).transpose(1, 2)
        v = v.view(B, T, H, hd).transpose(1, 2)

        # Annotate attention computation (optionally on a custom stream)
        with mark_kernels({"name": "attention", "stream": 62}):
            scores = (q @ k.transpose(-1, -2)) / math.sqrt(hd)
            attn = torch.softmax(scores, dim=-1)
            ctx = (attn @ v).transpose(1, 2).reshape(B, T, D)

        # Annotate output projection
        with mark_kernels({"name": "out_proj"}):
            o = ctx @ params["Wo"]

        # Annotate MLP (on another custom stream)
        with mark_kernels({"name": "mlp", "stream": 61}):
            return torch.nn.functional.gelu(o @ params["W1"]) @ params["W2"]

    return forward

###############################################################################
# The ``mark_kernels`` Context Manager
# -------------------------------------
#
# The key API is ``mark_kernels()``, which takes a dictionary with:
#
# - ``name``: A string label for this kernel group (becomes the lane name)
# - ``stream`` (optional): A virtual stream ID for visualization
# - Any additional keys: carried along as kernel metadata (used for
#   collectives later in this tutorial)
#
# Any CUDA kernels launched within the context will be tagged with these
# annotations. Later, when we post-process the profiler trace, these tags
# will be used to organize kernels into custom lanes.

###############################################################################
# Capturing a CUDA Graph with Annotations
# ----------------------------------------
#
# To capture a graph with annotations enabled, we pass
# ``enable_annotations=True`` to ``torch.cuda.graph()``. This automatically
# handles the annotation lifecycle: enabling, resolving, and remapping.

def capture_graph_with_annotations(model_fn):
    """Capture the model into a CUDA graph with annotations enabled."""
    # Warm up on a side stream before capture
    warmup_stream = torch.cuda.Stream()
    warmup_stream.wait_stream(torch.cuda.current_stream())

    with torch.cuda.stream(warmup_stream):
        for _ in range(3):
            model_fn()

    torch.cuda.current_stream().wait_stream(warmup_stream)

    # Capture with annotations enabled
    graph = torch.cuda.CUDAGraph()
    with torch.cuda.graph(graph, enable_annotations=True):
        output = model_fn()

    num_annotations = len(get_kernel_annotations())
    print(f"Captured graph with {num_annotations} annotated nodes")

    return graph, output

###############################################################################
# Profiling the Graph
# -------------------
#
# After capturing the graph, we replay it a few times to warm up, then profile
# subsequent replays. The profiler will record kernel execution times, which
# we'll later merge with our annotations.

def profile_graph(graph, output_dir):
    """Profile graph replays and save the trace."""
    output_dir = Path(output_dir)
    output_dir.mkdir(exist_ok=True, parents=True)

    # Warm up replays
    for _ in range(3):
        graph.replay()
    torch.cuda.synchronize()

    # Profile several replays
    with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) as prof:
        for _ in range(5):
            graph.replay()
        torch.cuda.synchronize()

    # Export the raw trace
    trace_path = output_dir / "trace_raw.json.gz"
    prof.export_chrome_trace(str(trace_path))
    print(f"Saved raw trace to {trace_path}")

    return trace_path

###############################################################################
# Saving Annotation Metadata
# ---------------------------
#
# We need to save the annotation metadata in a pickle file that the
# post-processing tool can discover. The file should be named
# ``kernel_annotations_rank0_fwd_bwd.pkl`` and placed where the trace tool
# can find it.

def save_annotations(output_dir):
    """Save kernel annotations to a pickle file."""
    output_dir = Path(output_dir)
    output_dir.mkdir(exist_ok=True, parents=True)
    annotations_path = output_dir / "kernel_annotations_rank0_fwd_bwd.pkl"

    annotations = dict(get_kernel_annotations())
    with open(annotations_path, "wb") as f:
        pickle.dump(annotations, f)

    print(f"Saved {len(annotations)} annotations to {annotations_path}")
    return annotations_path

###############################################################################
# Post-Processing: Merging Annotations into Traces
# -------------------------------------------------
#
# The final step is to merge the annotations back into the trace. This involves:
#
# 1. Loading the raw trace and annotations
# 2. Calling ``annotate_trace()`` to apply the annotations
# 3. Emitting a native Perfetto ``.pftrace`` that preserves overlapping kernels
#    on their real stream
#
# The result is a trace where kernels are organized by your semantic labels.
#
# **Why a Perfetto protobuf trace (not Chrome JSON)?** A Chrome JSON trace --
# the format ``torch.profiler.export_chrome_trace`` produces -- has a
# fundamental limitation: a single track (a ``(pid, tid)`` row) can only show
# **properly nested** slices, never crossing/overlapping ones.
#
# Perfetto's native **protobuf** trace (``.pftrace``) solves this
# via the ``TrackDescriptor`` field ``sibling_merge_key``. We split
# overlapping slices across hidden *backing* tracks (so each protobuf
# begin/end stack stays validly nested), then give those backing tracks the
# **same** ``sibling_merge_key`` so the Perfetto UI merges them back into a
# single logical row. Nothing is relocated to a fake stream and no timestamp is
# clamped -- the overlap is shown faithfully on the kernel's real stream.
#
# This converter is adapted from Driss Guessous's `transformer_nuggets
# <https://github.com/drisspg/transformer_nuggets>`_
# (``transformer_nuggets/utils/track_event.py``); we inline a compact,
# self-contained version here. It needs the ``perfetto`` package
# (``pip install perfetto``).
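
###############################################################################
# To make the "properly nested" constraint concrete, here is a minimal sketch
# that checks whether a set of ``(start, end)`` intervals could share one
# track. ``_is_properly_nested`` is a hypothetical helper written for this
# illustration only; it is not part of PyTorch or Perfetto, and the intervals
# are made-up numbers. Two slices may be disjoint or one may contain the other,
# but a slice that starts inside another and ends after it "crosses" it.

def _is_properly_nested(intervals):
    """Return True if no two (start, end) intervals cross each other."""
    open_ends = []  # end times of currently open slices, innermost last
    # Earlier start first; for equal starts, the longer (outer) slice first.
    for start, end in sorted(intervals, key=lambda iv: (iv[0], -iv[1])):
        # Close every slice that has finished by the time this one starts.
        while open_ends and open_ends[-1] <= start:
            open_ends.pop()
        # A slice that outlives the slice enclosing it crosses that slice.
        if open_ends and end > open_ends[-1]:
            return False
        open_ends.append(end)
    return True


print(_is_properly_nested([(0, 10), (2, 5), (6, 9)]))  # True: children fit inside
print(_is_properly_nested([(0, 10), (5, 15)]))         # False: second crosses first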

def _stable_uuid(*parts):
    """A stable 60-bit track UUID derived from its identifying parts."""
    digest = hashlib.sha1(":".join(str(p) for p in parts).encode()).hexdigest()
    return int(digest[:15], 16)


def _assign_nesting_lanes(slices):
    """Split overlapping slices into backing lanes so each lane is nestable.

    A lane only holds slices that are either disjoint or fully contained, so a
    begin/end stack on that lane never has crossing slices. Returns
    ``(lane_of_index, lane_count)``. The lane is a *backing* track index, not a
    user-visible stream -- lanes sharing a stream are merged back in the UI.
    """
    order = sorted(
        range(len(slices)),
        key=lambda i: (slices[i]["ts"], -slices[i]["end"], slices[i]["index"]),
    )
    lane_of = {}
    lane_end_stacks = []
    for i in order:
        s = slices[i]
        assigned = None
        for lane, stack in enumerate(lane_end_stacks):
            while stack and stack[-1] <= s["ts"]:
                stack.pop()
            # Valid if the lane is free or this slice nests inside the open one.
            if not stack or s["end"] <= stack[-1]:
                stack.append(s["end"])
                assigned = lane
                break
        if assigned is None:
            lane_end_stacks.append([s["end"]])
            assigned = len(lane_end_stacks) - 1
        lane_of[i] = assigned
    return lane_of, len(lane_end_stacks)


def _add_debug_annotation(track_event, name, value):
    """Carry a Chrome event arg over as a typed Perfetto debug annotation."""
    ann = track_event.debug_annotations.add()
    ann.name = str(name)
    # bool must be checked before int (bool is a subclass of int in Python).
    if isinstance(value, bool):
        ann.bool_value = value
    elif isinstance(value, int):
        ann.int_value = value
    elif isinstance(value, float):
        ann.double_value = value
    elif value is None:
        ann.string_value = "null"
    elif isinstance(value, str):
        ann.string_value = value
    else:
        ann.legacy_json_value = json.dumps(value, default=str)


def write_perfetto_trace(trace, output_path):
    """Convert a Chrome JSON trace dict to a native Perfetto ``.pftrace``.

    Each Chrome ``(pid, tid)`` row becomes a ``TrackDescriptor``; each ``ph='X'``
    slice becomes a ``TYPE_SLICE_BEGIN`` / ``TYPE_SLICE_END`` pair. Overlapping
    slices are split across backing lanes that share a ``sibling_merge_key`` so
    the UI re-merges them onto their real stream.
    """
    from perfetto.trace_builder.proto_builder import TraceProtoBuilder
    from perfetto.protos.perfetto.trace.perfetto_trace_pb2 import (
        TrackDescriptor,
        TrackEvent,
    )

    events = trace["traceEvents"]

    # Collect the process/thread names emitted as metadata ('M') events.
    process_names, thread_names = {}, {}
    for e in events:
        if e.get("ph") == "M":
            if e.get("name") == "process_name":
                process_names[e.get("pid")] = e.get("args", {}).get("name", "")
            elif e.get("name") == "thread_name":
                key = (e.get("pid"), e.get("tid"))
                thread_names[key] = e.get("args", {}).get("name", "")

    # Group complete ('X') slices by their (pid, tid) track.
    slices_by_track = defaultdict(list)
    for i, e in enumerate(events):
        if e.get("ph") == "X":
            ts = float(e.get("ts", 0) or 0)
            dur = float(e.get("dur", 0) or 0)
            slices_by_track[(e.get("pid"), e.get("tid"))].append(
                {"event": e, "index": i, "ts": ts, "end": ts + dur}
            )

    def ts_us_to_ns(value):
        # Chrome trace timestamps are in microseconds; Perfetto packets use ns.
        return int(round(value * 1000.0))

    builder = TraceProtoBuilder()
    SEQ = 1

    # One descriptor per process.
    for pid in {pid for (pid, _tid) in slices_by_track}:
        pkt = builder.add_packet()
        desc = pkt.track_descriptor
        desc.uuid = _stable_uuid("process", pid)
        desc.name = process_names.get(pid, f"process {pid}")

    # One descriptor per backing lane; emit begin/end markers per slice.
    markers = []
    for (pid, tid), slices in slices_by_track.items():
        lane_of, lane_count = _assign_nesting_lanes(slices)
        name = thread_names.get((pid, tid), f"stream {tid}")
        lane_uuids = []
        for lane in range(lane_count):
            uuid = _stable_uuid("track", pid, tid, lane)
            lane_uuids.append(uuid)
            pkt = builder.add_packet()
            desc = pkt.track_descriptor
            desc.uuid = uuid
            desc.parent_uuid = _stable_uuid("process", pid)
            desc.name = name
            # Multiple lanes for one stream -> merge them into one UI row.
            if lane_count > 1:
                desc.sibling_merge_behavior = (
                    TrackDescriptor.SIBLING_MERGE_BEHAVIOR_BY_SIBLING_MERGE_KEY
                )
                desc.sibling_merge_key = f"{pid}:{tid}:{name}"
        for i, s in enumerate(slices):
            uuid = lane_uuids[lane_of[i]]
            markers.append((ts_us_to_ns(s["ts"]), 1, uuid, "begin", s["event"]))
            markers.append((ts_us_to_ns(s["end"]), 0, uuid, "end", s["event"]))

    # At the same timestamp, end markers (rank 0) sort before begin markers
    # (rank 1), so a slice ending at t is closed before the next one opens.
    # Zero-duration slices are not special-cased here.
    markers.sort(key=lambda m: (m[0], m[1]))
    for ts_ns, _rank, uuid, kind, event in markers:
        pkt = builder.add_packet()
        pkt.timestamp = ts_ns
        pkt.trusted_packet_sequence_id = SEQ
        track_event = pkt.track_event
        track_event.track_uuid = uuid
        if kind == "begin":
            track_event.type = TrackEvent.TYPE_SLICE_BEGIN
            track_event.name = event.get("name", "slice")
            if event.get("cat"):
                track_event.categories.append(event["cat"])
            for key, value in (event.get("args") or {}).items():
                _add_debug_annotation(track_event, key, value)
        else:
            track_event.type = TrackEvent.TYPE_SLICE_END

    Path(output_path).write_bytes(builder.serialize())
    return output_path


def post_process_trace(raw_trace_path, annotations_path, output_dir):
    """Merge annotations into the trace and emit a Perfetto ``.pftrace``."""
    output_dir = Path(output_dir)

    # Load raw trace and annotations
    raw_trace = load_trace(raw_trace_path)
    with open(annotations_path, "rb") as f:
        annotations = pickle.load(f)

    # Make a copy for post-processing
    annotated_trace = copy.deepcopy(raw_trace)

    # Apply annotations
    num_annotated = annotate_trace(annotated_trace, annotations)
    print(f"Annotated {num_annotated} kernels in the trace")

    # Emit a native Perfetto protobuf trace. Overlapping kernels are split onto
    # backing lanes that re-merge in the UI -- no kernel is relocated to a fake
    # stream and no timestamp is mutated.
    annotated_path = output_dir / "trace_annotated.pftrace"
    write_perfetto_trace(annotated_trace, annotated_path)
    print(f"Saved annotated trace to {annotated_path}")

    return annotated_path, raw_trace, annotated_trace

###############################################################################
# Comparing Before and After
# ---------------------------
#
# To see the impact of annotations, let's count how kernels are distributed
# across thread IDs (which represent visualization lanes in the trace).

def compare_traces(raw_trace, annotated_trace):
    """Compare kernel distribution before and after annotation."""
    def count_lanes(trace):
        """Count kernels per lane (tid)."""
        counter = Counter(
            event["tid"]
            for event in trace["traceEvents"]
            if event.get("cat") == "kernel"
        )
        return dict(sorted(counter.items()))

    raw_lanes = count_lanes(raw_trace)
    annotated_lanes = count_lanes(annotated_trace)

    print("\n" + "="*60)
    print("BEFORE annotation - kernels per lane (tid -> count):")
    for tid, count in raw_lanes.items():
        print(f" Stream {tid}: {count} kernels")

    print("\nAFTER annotation - kernels per lane (tid -> count):")
    for tid, count in annotated_lanes.items():
        print(f" Stream {tid}: {count} kernels")
    print("="*60)

###############################################################################
# Putting It All Together
# ------------------------
#
# Now let's run the complete workflow: build a model, capture it with
# annotations, profile it, and post-process the trace.

def main():
    """End-to-end CUDA graph annotation and profiling demo."""
    if not torch.cuda.is_available():
        raise SystemExit("CUDA required for this tutorial")

    # Check if annotation support is available
    # PyTorch will log a warning if cuda-bindings version is too old
    supported = not _is_tools_id_unavailable()
    print(f"Annotation support available: {supported}")
    if not supported:
        print("NOTE: Annotation API not available.")
        print("This could be due to:")
        print(" - Driver/CUDA-compat < 13.1")
        print(" - Outdated cuda-bindings (check PyTorch warnings above)")
        print("Annotations will not be recorded, but the demo will still run.")
        print("Kernels will stay on their default lane, not semantic lanes.\n")

    output_dir = Path("traces")

    # Build the model
    print("\n1. Building transformer block model...")
    model_fn = build_transformer_block()

    # Capture graph with annotations
    print("\n2. Capturing CUDA graph with annotations...")
    graph, output = capture_graph_with_annotations(model_fn)

    # Save annotations
    print("\n3. Saving annotation metadata...")
    annotations_path = save_annotations(output_dir)

    # Profile the graph
    print("\n4. Profiling graph replays...")
    raw_trace_path = profile_graph(graph, output_dir)

    # Post-process the trace
    print("\n5. Post-processing: merging annotations into trace...")
    annotated_path, raw_trace, annotated_trace = post_process_trace(
        raw_trace_path, annotations_path, output_dir
    )

    # Compare before and after
    print("\n6. Comparing traces...")
    compare_traces(raw_trace, annotated_trace)

    # Summary
    print("\n" + "="*60)
    print("SUMMARY")
    print("="*60)
    print(f"Raw trace: {raw_trace_path}")
    print(f"Annotated trace: {annotated_path}")
    print(f"Annotations: {annotations_path}")
    print("\nOpen the annotated trace in https://ui.perfetto.dev/ to visualize")
    print("the semantic kernel lanes.")
    print("="*60)

# Example output:
# if __name__ == "__main__":
#     main()
#
# Annotation support available: True
#
# 1. Building transformer block model...
#
# 2. Capturing CUDA graph with annotations...
# Captured graph with 13 annotated nodes
#
# 3. Saving annotation metadata...
# Saved 13 annotations to traces/kernel_annotations_rank0_fwd_bwd.pkl
#
# 4. Profiling graph replays...
# Saved raw trace to traces/trace_raw.json.gz
#
# 5. Post-processing: merging annotations into trace...
# Annotated 65 kernels in the trace
# Saved annotated trace to traces/trace_annotated.pftrace
#
# 6. Comparing traces...
#
# ============================================================
# BEFORE annotation - kernels per lane (tid -> count):
# Stream 7: 65 kernels
#
# AFTER annotation - kernels per lane (tid -> count):
# Stream 7: 10 kernels
# Stream 61: 15 kernels
# Stream 62: 40 kernels
# ============================================================
#
# ============================================================
# SUMMARY
# ============================================================
# Raw trace: traces/trace_raw.json.gz
# Annotated trace: traces/trace_annotated.pftrace
# Annotations: traces/kernel_annotations_rank0_fwd_bwd.pkl
#
# Open the annotated trace in https://ui.perfetto.dev/ to visualize
# the semantic kernel lanes.
# ============================================================

###############################################################################
# Annotating Communication Collectives
# -------------------------------------
#
# In eager mode the profiler **automatically intercepts** NCCL collectives and
# records rich metadata: collective type, input/output message sizes, the process
# group, its size, and the participating ranks.
#
# Under CUDA graphs that automatic interception stops working. The collective is
# captured once and then replayed as an opaque kernel node. The profiler cannot
# intercept graph replay, so it has nothing to attach the NCCL metadata to. The
# kernels still show up in the trace (e.g., ``ncclDevKernel_AllReduce_Sum_f32_RING_LL``),
# but beyond what the kernel name hints at they carry no metadata: the trace
# does not record the collective type, how many bytes moved, or which process
# group the kernel belongs to.
#
# Annotations close this gap. By wrapping the collective in ``mark_kernels``
# with the same fields the profiler auto-attaches in eager mode, we manually
# re-attach that metadata to the graphed kernel. After post-processing, a
# graphed collective reads just like an eager one. The helper below builds the
# metadata dict; using the field names the profiler uses in eager mode
# (``In msg nelems``, ``Group size``, ``Process Group Name``, ...) keeps the
# annotated trace consistent with non-graphed traces.

def annotate_collective(collective_name, input_tensor, output_tensor, group=None):
    """Annotate a collective with the metadata eager NCCL traces expose.

    Returns a ``mark_kernels`` context manager. Any kernels launched inside
    (i.e. the collective) are tagged with the collective type, message sizes,
    dtype, and the process group's name/description/ranks, and placed on a
    dedicated lane keyed by the process group so comms are visually separated
    from compute.

    The field names match the keys the profiler records for eager collectives
    (``In msg nelems``, ``Group size``, ``Process Group Name``, ...), so an
    annotated graphed collective reads exactly like a non-graphed one.
    """
    pg = group if group is not None else (dist.group.WORLD if dist.is_initialized() else None)
    ranks = dist.get_process_group_ranks(pg) if pg is not None else [0]
    group_name = getattr(pg, "group_name", "default")
    group_desc = getattr(pg, "group_desc", "default")

    # NCCL always uses its own internal stream, so key the lane on the process
    # group (name + description) and give it a stable id (>= 60).
    pg_key = f"{group_name}_{group_desc}"
    annotation = {
        "name": collective_name,
        "In msg nelems": input_tensor.numel(),
        "Out msg nelems": output_tensor.numel(),
        "Group size": len(ranks),
        "dtype": str(input_tensor.dtype).replace("torch.", ""),
        "Process Group Name": group_name,
        "Process Group Description": group_desc,
        "Process Group Ranks": ranks,
        "stream": get_stream_for_pg(pg_key),
    }
    return mark_kernels(annotation)

###############################################################################
# A Block That Mixes Compute and Communication
# ----------------------------------------------
#
# A tensor- or data-parallel layer interleaves matmuls with collectives. Here
# the projection output is all-reduced across the group, mirroring the comm in
# a tensor-parallel linear. The collective is annotated with
# ``annotate_collective`` and lands on its own lane.

def build_comm_block(group=None):
    """Create a compute + collective block annotated for profiling."""
    device = "cuda"
    torch.manual_seed(0)
    dim = 1024
    params = {
        "x": torch.randn(4, 256, dim, device=device),
        "W": torch.randn(dim, dim, device=device) / math.sqrt(dim),
    }

    def forward():
        with mark_kernels({"name": "proj", "stream": 61}):
            h = params["x"] @ params["W"]

        # All-reduce the projection output across the group (e.g. tensor
        # parallel). all_reduce is in-place, so the input and output tensors
        # are the same. The annotation re-attaches the NCCL metadata that a
        # CUDA graph would otherwise drop.
        if dist.is_available() and dist.is_initialized():
            with annotate_collective("all_reduce", h, h, group):
                # Reduce over the same group the annotation describes
                # (group=None means the default process group).
                dist.all_reduce(h, group=group)
        return h

    return forward

###############################################################################
# Running the Communication Demo
# -------------------------------
#
# The demo spawns one process per GPU (``WORLD_SIZE`` ranks). Every rank
# captures and replays the same graph so the all-reduce can complete, while
# only rank 0 profiles, saves, and post-processes the trace.

WORLD_SIZE = 2

def init_pg(rank, world_size):
    """Initialize an NCCL group for one rank of the spawned demo."""
    os.environ["MASTER_ADDR"] = "127.0.0.1"
    os.environ["MASTER_PORT"] = "29500"
    os.environ["RANK"] = str(rank)
    os.environ["WORLD_SIZE"] = str(world_size)
    # Use loopback interface for single-node setup
    os.environ["NCCL_SOCKET_IFNAME"] = "lo"
    dist.init_process_group("nccl", rank=rank, world_size=world_size)
    torch.cuda.set_device(rank)

def _comm_worker(rank, world_size):
    """Per-rank worker: build, capture, profile, and (on rank 0) post-process."""
    init_pg(rank, world_size)

    output_dir = Path("traces_comm")

    if rank == 0:
        print("\nBuilding compute + collective block...")
    model_fn = build_comm_block()

    if rank == 0:
        print("Capturing CUDA graph with annotations...")
    graph, _ = capture_graph_with_annotations(model_fn)

    # Every rank participates in the collective during profiling, but only
    # rank 0 saves and post-processes the trace.
    if rank == 0:
        annotations_path = save_annotations(output_dir)
        raw_trace_path = profile_graph(graph, output_dir)
        annotated_path, _, annotated_trace = post_process_trace(
            raw_trace_path, annotations_path, output_dir
        )

        # Print the args of the annotated collective kernel(s) to show that the
        # eager-style metadata is now attached to the graphed comm.
        print("\nAnnotated collective kernels (metadata restored):")
        for event in annotated_trace["traceEvents"]:
            args = event.get("args", {})
            if args.get("In msg nelems") is not None:
                print(f" {event.get('name', '?')[:40]}")
                for key in (
                    "In msg nelems",
                    "Out msg nelems",
                    "Group size",
                    "dtype",
                    "Process Group Name",
                    "Process Group Description",
                    "Process Group Ranks",
                    "stream",
                ):
                    if key in args:
                        print(f" {key}: {args[key]}")
        print(f"\nAnnotated trace: {annotated_path}")
    else:
        # Match rank 0's warmup + profiled replays so the collective completes.
        for _ in range(3):
            graph.replay()
        torch.cuda.synchronize()
        for _ in range(5):
            graph.replay()
        torch.cuda.synchronize()

    dist.destroy_process_group()

def comm_annotation_demo():
    """Spawn a ``world_size=2`` group and surface the comm metadata."""
    if not (dist.is_available() and torch.cuda.is_available()):
        print("Distributed/NCCL unavailable; skipping comm annotation demo.")
        return
    if torch.cuda.device_count() < WORLD_SIZE:
        print(f"Need {WORLD_SIZE} GPUs for the comm demo; skipping.")
        return

    torch.multiprocessing.spawn(
        _comm_worker, args=(WORLD_SIZE,), nprocs=WORLD_SIZE, join=True
    )
816
817
# Example output (2 GPUs):
# if __name__ == "__main__":
#     comm_annotation_demo()
#
# Building compute + collective block...
# Capturing CUDA graph with annotations...
# Captured graph with 2 annotated nodes
# Saved 2 annotations to traces_comm/kernel_annotations_rank0_fwd_bwd.pkl
# Saved raw trace to traces_comm/trace_raw.json.gz
# Annotated 5 kernels in the trace
# Saved annotated trace to traces_comm/trace_annotated.pftrace
#
# The all_reduce runs a real NCCL kernel
# (``ncclDevKernel_AllReduce_Sum_f32_RING_LL``) across the two ranks:
#
# Annotated collective kernels (metadata restored):
# ncclDevKernel_AllReduce_Sum_f32_RING_LL
# In msg nelems: 1048576
# Out msg nelems: 1048576
# Group size: 2
# dtype: float32
# Process Group Name: default
# Process Group Description: default
# Process Group Ranks: [0, 1]
# stream: 60
#
# In the trace viewer, the all-reduce sits on its own dedicated comm lane
# (stream 60), and selecting it shows the collective type, message sizes, group,
# and ranks -- the same fields you would see in an eager trace, now recovered
# for a CUDA-graphed collective. Without annotations, this metadata is lost.

###############################################################################
# How Overlapping Kernels Are Handled
# ------------------------------------
#
# Graphed CUDA kernels often overlap slightly, and a single trace track can
# only render properly nested slices. The Perfetto converter handles this
# faithfully:
#
# 1. ``_assign_nesting_lanes()``: For each stream, overlapping slices are split
#    across hidden *backing* lanes so that each lane's begin/end stack is validly
#    nested. A lane is a backing track index, **not** a user-visible stream.
#
# 2. ``sibling_merge_key``: All backing lanes for one stream are given the same
#    merge key, so the Perfetto UI merges them back into a single logical row.
#
# The result: overlaps render correctly on the kernel's **real** stream. No
# kernel is relocated to a fabricated stream, and no timestamp is mutated --
# unlike the legacy Chrome-JSON workaround, which had to do both.
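#
# The lane-splitting idea in step 1 can be illustrated with a small,
# self-contained sketch. This is **not** the converter's actual
# ``_assign_nesting_lanes()`` code: the helper name
# ``assign_nesting_lanes_sketch`` and the ``(start, end)`` tuple format are
# assumptions made for this example. It greedily places each slice on the
# first lane where it either follows or fully nests inside the slices already
# open there, and opens a new lane only on a partial overlap.


def assign_nesting_lanes_sketch(slices):
    """Return one lane index per ``(start, end)`` slice (hypothetical helper)."""
    # Visit slices by start time; on ties, the longer slice goes first so that
    # it can act as the parent of the shorter one.
    order = sorted(range(len(slices)), key=lambda i: (slices[i][0], -slices[i][1]))
    lanes = []  # lanes[k] is a stack of end times of slices still open on lane k
    result = [0] * len(slices)
    for i in order:
        start, end = slices[i]
        for lane_idx, stack in enumerate(lanes):
            # Close every slice on this lane that ended before this one starts.
            while stack and stack[-1] <= start:
                stack.pop()
            # The slice fits if the lane is idle or it nests inside the open slice.
            if not stack or end <= stack[-1]:
                stack.append(end)
                result[i] = lane_idx
                break
        else:
            # Partial overlap on every existing lane: open a new backing lane.
            lanes.append([end])
            result[i] = len(lanes) - 1
    return result


# (8, 12) partially overlaps (0, 10), so it is the only slice moved to lane 1.
print(assign_nesting_lanes_sketch([(0, 10), (2, 5), (8, 12), (12, 15)]))
# → [0, 0, 1, 0]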

###############################################################################
# Performance Considerations
# ---------------------------
#
# Kernel annotations add minimal overhead:
#
# - Annotation marking happens during graph capture (one-time cost)
# - Graph replay performance is identical to unannotated graphs
# - Post-processing is offline and doesn't affect runtime
#
# The main cost is the profiling itself, which you would do anyway when
# optimizing performance. Annotations simply make the profiler output more
# useful by adding semantic structure.
881
###############################################################################
882
# Troubleshooting
883
# ---------------
884
#
885
# **No annotations in the trace?**
886
#
887
# - Check that your driver/CUDA-compat >= 13.1
888
# - Verify that ``enable_annotations=True`` was passed to ``torch.cuda.graph()``
889
# - Ensure ``cuda-python`` is installed
890
#
891
# **Annotations not showing up in specific kernels?**
892
#
893
# - Some operations may not launch kernels (e.g., tensor views)
894
# - Only kernels launched within the ``mark_kernels`` context are annotated
895
# - Verify the operation actually produces CUDA kernels using ``torch.profiler``
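#
# The package-version item in the first checklist can be scripted. The sketch
# below is an illustration under stated assumptions: it assumes the bindings
# are installed under the distribution name ``cuda-bindings`` (as in the
# prerequisites), the helper names are hypothetical, and it only inspects the
# installed Python package, not the GPU driver.

import re
from importlib import metadata


def _parse_version_sketch(text):
    """Turn '13.1.0' or '13.1.0rc1' into a tuple of leading integers."""
    parts = []
    for piece in text.split("."):
        match = re.match(r"\d+", piece)
        if match is None:
            break  # stop at the first non-numeric component
        parts.append(int(match.group()))
    return tuple(parts)


def check_min_version_sketch(dist_name, minimum):
    """Return (installed_version_or_None, meets_minimum) for a distribution."""
    try:
        raw = metadata.version(dist_name)
    except metadata.PackageNotFoundError:
        return None, False
    return raw, _parse_version_sketch(raw) >= minimum


_raw, _ok = check_min_version_sketch("cuda-bindings", (13, 1))
print(f"cuda-bindings: {_raw or 'not installed'} (>= 13.1: {_ok})")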

###############################################################################
# Conclusion
# ----------
#
# CUDA graph kernel annotations provide a powerful way to add semantic
# structure to your profiling traces. By marking logical components of your
# model during graph capture and merging these annotations in post-processing,
# you can create visualizations that make it much easier to understand and
# optimize complex CUDA graph executions.
#
# Key takeaways:
#
# - Use ``mark_kernels()`` to label regions during graph capture
# - Enable annotations with ``enable_annotations=True``
# - Annotate communication collectives to recover the NCCL metadata
#   (collective type, message size, group, rank) that CUDA graphs drop but
#   eager traces expose
# - Post-process traces with ``annotate_trace()``
# - View results in https://ui.perfetto.dev/ for intuitive visualization
#
# This technique is especially valuable for large models with many components,
# distributed training setups, or any scenario where understanding the
# execution structure is critical for performance optimization.