# -*- coding: utf-8 -*-
"""
.. _cuda-graph-annotations-tutorial:

CUDA Graph Kernel Annotations and Profiling
============================================

**Author**: `Shangdi Yu <https://github.com/yushangdi>`_

.. grid:: 2

    .. grid-item-card:: :octicon:`mortar-board;1em;` What you will learn
       :class-card: card-prerequisites

       * How to capture CUDA graphs with kernel annotations
       * How to profile annotated graphs
       * How to post-process traces with semantic kernel lanes
       * How to visualize graph execution with custom stream assignments
       * How to annotate communication collectives with the metadata
         (collective type, message size, group, rank) that eager NCCL
         traces expose but CUDA graphs drop

    .. grid-item-card:: :octicon:`list-unordered;1em;` Prerequisites
       :class-card: card-prerequisites

       * PyTorch 2.12+
       * CUDA-capable GPU
       * Driver/CUDA-compat >= 13.1 for annotation support
       * cuda-bindings >= 13.1.0
       * perfetto (``pip install perfetto``)

CUDA graphs are a powerful optimization technique that can significantly reduce
kernel launch overhead by capturing and replaying sequences of CUDA operations.
However, when profiling CUDA graphs, all kernels appear on the same stream,
making it difficult to understand the logical structure of your computation.

This tutorial demonstrates how to use **kernel annotations** to add semantic
labels to kernels within CUDA graphs. These annotations can be merged back into
profiler traces to create custom visualization lanes, making it easier to
understand and debug complex graph executions.

Annotations are not limited to compute kernels. One of the most valuable uses
is annotating **communication collectives**. In eager mode, the profiler
attaches rich metadata to every NCCL kernel -- the collective type, message
size, process group, and ranks -- so you can see exactly what each comm is
doing. Under CUDA graphs, that metadata is lost: the collective replays as an
opaque kernel.
This tutorial shows how to re-attach that metadata with
annotations so graphed comms read just like eager ones.
"""

###############################################################################
# Overview
# --------
#
# CUDA graph kernel annotations allow you to add semantic labels to kernels
# during graph capture. These labels help you understand what each kernel does
# when profiling, making it easy to identify which parts of your model (e.g.,
# attention, MLP, normalization) are executing at any given time.
#
# Without annotations, profiler traces show all kernels on a single stream with
# auto-generated names, making it difficult to understand the logical structure
# of your computation. With annotations, you can:
#
# 1. **Label kernel groups** with meaningful names during capture
# 2. **Assign custom stream IDs** for visual organization
# 3. **Merge labels into profiler traces** for semantic visualization
#
# The result is a profiler trace where kernels are labeled and organized by
# their function, making it much easier to identify performance bottlenecks
# and understand execution flow.
#
# **Before annotations:** All kernels appear on a single stream with
# auto-generated names, making it difficult to understand which operations
# belong to which logical component of your model.
#
# .. image:: /_static/img/cuda_graph_trace_before.png
#    :width: 80%
#    :alt: CUDA graph trace before annotations showing all kernels on one stream
#
# **After annotations:** Kernels are organized into semantic lanes (streams 61
# and 62) with meaningful labels like "attention" and "mlp", making it easy to
# identify different components and understand the execution structure.
#
# .. image:: /_static/img/cuda_graph_trace_after.png
#    :width: 80%
#    :alt: CUDA graph trace after annotations showing kernels organized by function
#
# As another example, here is an AllReduce kernel with annotated metadata:
#
# .. image:: /_static/img/annotated_cudagraph.png
#    :width: 80%
#    :alt: AllReduce kernel with annotated metadata
#
# Requirements
# ------------
#
# For this tutorial, you'll need:
#
# - PyTorch 2.12+
# - A CUDA GPU
# - Driver/CUDA-compat >= 13.1 for annotation support
# - The ``cuda-bindings`` package >= 13.1.0 (``pip install cuda-python``)
# - The ``perfetto`` package for writing the trace (``pip install perfetto``)
#
# The cuda-bindings package provides the Python bindings for CUDA runtime APIs.
# Version 13.1.0+ is required for the ``cudaGraphNodeGetToolsId`` API that
# enables kernel annotations. If you have an older version, the tutorial will
# run, but annotations will be disabled with a warning message explaining how
# to upgrade.
#
# On older drivers or cuda-bindings versions, the capture and profiling will
# still work, but ``mark_kernels`` will be a no-op and no semantic lanes will
# appear in the final trace.

import copy
import hashlib
import json
import math
import os
import pickle
from collections import Counter, defaultdict
from pathlib import Path

import torch
import torch.distributed as dist
import torch.multiprocessing
from torch.profiler import profile, ProfilerActivity
from torch.cuda._graph_annotations import (
    get_kernel_annotations,
    get_stream_for_pg,
    mark_kernels,
    _is_tools_id_unavailable,
)
from torch.cuda._annotate_cuda_graph_trace import (
    annotate_trace,
    load_trace,
)

###############################################################################
# Building a Model
# ----------------
#
# Let's create a simple transformer block as our example model.
# We'll annotate different parts of the computation (QKV projection, attention,
# output projection, MLP) to see them as separate lanes in the profiler.

def build_transformer_block():
    """Create a simple transformer block with parameters."""
    device = "cuda"
    torch.manual_seed(0)

    # Model dimensions
    batch_size, seq_len, dim, num_heads = 4, 256, 1024, 8
    head_dim = dim // num_heads

    # Initialize parameters
    params = {
        "x": torch.randn(batch_size, seq_len, dim, device=device),
        "Wqkv": torch.randn(dim, 3 * dim, device=device) / math.sqrt(dim),
        "Wo": torch.randn(dim, dim, device=device) / math.sqrt(dim),
        "W1": torch.randn(dim, 4 * dim, device=device) / math.sqrt(dim),
        "W2": torch.randn(4 * dim, dim, device=device) / math.sqrt(4 * dim),
    }

    def forward():
        """Forward pass with annotated regions."""
        B, T, D, H = batch_size, seq_len, dim, num_heads
        hd = head_dim

        # Annotate QKV projection
        with mark_kernels({"name": "qkv_proj"}):
            qkv = params["x"] @ params["Wqkv"]

        # Reshape for multi-head attention
        q, k, v = qkv.split(D, dim=-1)
        q = q.view(B, T, H, hd).transpose(1, 2)
        k = k.view(B, T, H, hd).transpose(1, 2)
        v = v.view(B, T, H, hd).transpose(1, 2)

        # Annotate attention computation (optionally on a custom stream)
        with mark_kernels({"name": "attention", "stream": 62}):
            scores = (q @ k.transpose(-1, -2)) / math.sqrt(hd)
            attn = torch.softmax(scores, dim=-1)
            ctx = (attn @ v).transpose(1, 2).reshape(B, T, D)

        # Annotate output projection
        with mark_kernels({"name": "out_proj"}):
            o = ctx @ params["Wo"]

        # Annotate MLP (on another custom stream)
        with mark_kernels({"name": "mlp", "stream": 61}):
            return torch.nn.functional.gelu(o @ params["W1"]) @ params["W2"]

    return forward

###############################################################################
# The ``mark_kernels`` Context Manager
# -------------------------------------
#
# The key API is ``mark_kernels()``, which
# takes a dictionary with:
#
# - ``name``: A string label for this kernel group (becomes the lane name)
# - ``stream`` (optional): A virtual stream ID for visualization
#
# Any CUDA kernels launched within the context will be tagged with these
# annotations. Later, when we post-process the profiler trace, these tags
# will be used to organize kernels into custom lanes.

###############################################################################
# Capturing a CUDA Graph with Annotations
# ----------------------------------------
#
# To capture a graph with annotations enabled, we pass
# ``enable_annotations=True`` to ``torch.cuda.graph()``. This automatically
# handles the annotation lifecycle: enabling, resolving, and remapping.

def capture_graph_with_annotations(model_fn):
    """Capture the model into a CUDA graph with annotations enabled."""
    # Warm up on a side stream before capture
    warmup_stream = torch.cuda.Stream()
    warmup_stream.wait_stream(torch.cuda.current_stream())

    with torch.cuda.stream(warmup_stream):
        for _ in range(3):
            model_fn()

    torch.cuda.current_stream().wait_stream(warmup_stream)

    # Capture with annotations enabled
    graph = torch.cuda.CUDAGraph()
    with torch.cuda.graph(graph, enable_annotations=True):
        output = model_fn()

    num_annotations = len(get_kernel_annotations())
    print(f"Captured graph with {num_annotations} annotated nodes")

    return graph, output

###############################################################################
# Profiling the Graph
# -------------------
#
# After capturing the graph, we replay it a few times to warm up, then profile
# subsequent replays.
# The profiler will record kernel execution times, which we'll later merge
# with our annotations.

def profile_graph(graph, output_dir):
    """Profile graph replays and save the trace."""
    output_dir = Path(output_dir)
    output_dir.mkdir(exist_ok=True, parents=True)

    # Warm up replays
    for _ in range(3):
        graph.replay()
    torch.cuda.synchronize()

    # Profile several replays
    with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) as prof:
        for _ in range(5):
            graph.replay()
        torch.cuda.synchronize()

    # Export the raw trace (a ".gz" suffix makes the profiler gzip the JSON)
    trace_path = output_dir / "trace_raw.json.gz"
    prof.export_chrome_trace(str(trace_path))
    print(f"Saved raw trace to {trace_path}")

    return trace_path

###############################################################################
# Saving Annotation Metadata
# ---------------------------
#
# We need to save the annotation metadata in a pickle file that the
# post-processing tool can discover. The file is named
# ``kernel_annotations_rank0_fwd_bwd.pkl``; in this tutorial it is written to
# the same ``output_dir`` as the raw trace, and its path is passed explicitly
# to the post-processing step.

def save_annotations(output_dir):
    """Save kernel annotations to a pickle file."""
    output_dir = Path(output_dir)
    output_dir.mkdir(exist_ok=True, parents=True)
    annotations_path = output_dir / "kernel_annotations_rank0_fwd_bwd.pkl"

    annotations = dict(get_kernel_annotations())
    with open(annotations_path, "wb") as f:
        pickle.dump(annotations, f)

    print(f"Saved {len(annotations)} annotations to {annotations_path}")
    return annotations_path

###############################################################################
# Post-Processing: Merging Annotations into Traces
# -------------------------------------------------
#
# The final step is to merge the annotations back into the trace. This involves:
#
# 1. Loading the raw trace and annotations
# 2. Calling ``annotate_trace()`` to apply the annotations
# 3. Emitting a native Perfetto ``.pftrace`` that preserves overlapping kernels
#    on their real streams
#
# The result is a trace where kernels are organized by your semantic labels.
#
# **Why a Perfetto protobuf trace (not Chrome JSON)?** A Chrome JSON trace --
# the format ``torch.profiler.export_chrome_trace`` produces -- has a
# fundamental limitation: a single track (a ``(pid, tid)`` row) can only show
# **properly nested** slices, never crossing/overlapping ones.
#
# Perfetto's native **protobuf** trace (``.pftrace``) solves this
# via the ``TrackDescriptor`` field ``sibling_merge_key``. We split
# overlapping slices across hidden *backing* tracks (so each protobuf
# begin/end stack stays validly nested), then give those backing tracks the
# **same** ``sibling_merge_key`` so the Perfetto UI merges them back into a
# single logical row. Nothing is relocated to a fake stream and no timestamp is
# clamped -- the overlap is shown faithfully on the kernel's real stream.
#
# This converter is adapted from Driss Guessous's `transformer_nuggets
# <https://github.com/drisspg/transformer_nuggets>`_
# (``transformer_nuggets/utils/track_event.py``); we inline a compact,
# self-contained version here. It needs the ``perfetto`` package
# (``pip install perfetto``).

def _stable_uuid(*parts):
    """A stable 60-bit track UUID derived from its identifying parts."""
    digest = hashlib.sha1(":".join(str(p) for p in parts).encode()).hexdigest()
    return int(digest[:15], 16)


def _assign_nesting_lanes(slices):
    """Split overlapping slices into backing lanes so each lane is nestable.

    A lane only holds slices that are either disjoint or fully contained, so a
    begin/end stack on that lane never has crossing slices. Returns
    ``(lane_of_index, lane_count)``.
    The lane is a *backing* track index, not a
    user-visible stream -- lanes sharing a stream are merged back in the UI.
    """
    order = sorted(
        range(len(slices)),
        key=lambda i: (slices[i]["ts"], -slices[i]["end"], slices[i]["index"]),
    )
    lane_of = {}
    lane_end_stacks = []
    for i in order:
        s = slices[i]
        assigned = None
        for lane, stack in enumerate(lane_end_stacks):
            while stack and stack[-1] <= s["ts"]:
                stack.pop()
            # Valid if the lane is free or this slice nests inside the open one.
            if not stack or s["end"] <= stack[-1]:
                stack.append(s["end"])
                assigned = lane
                break
        if assigned is None:
            lane_end_stacks.append([s["end"]])
            assigned = len(lane_end_stacks) - 1
        lane_of[i] = assigned
    return lane_of, len(lane_end_stacks)


def _add_debug_annotation(track_event, name, value):
    """Carry a Chrome event arg over as a typed Perfetto debug annotation."""
    ann = track_event.debug_annotations.add()
    ann.name = str(name)
    # bool must be checked before int (bool is a subclass of int in Python).
    if isinstance(value, bool):
        ann.bool_value = value
    elif isinstance(value, int):
        ann.int_value = value
    elif isinstance(value, float):
        ann.double_value = value
    elif value is None:
        ann.string_value = "null"
    elif isinstance(value, str):
        ann.string_value = value
    else:
        ann.legacy_json_value = json.dumps(value, default=str)


def write_perfetto_trace(trace, output_path):
    """Convert a Chrome JSON trace dict to a native Perfetto ``.pftrace``.

    Each Chrome ``(pid, tid)`` row becomes a ``TrackDescriptor``; each ``ph='X'``
    slice becomes a ``TYPE_SLICE_BEGIN`` / ``TYPE_SLICE_END`` pair.
    Overlapping slices are split across backing lanes that share a
    ``sibling_merge_key`` so the UI re-merges them onto their real stream.
    """
    from perfetto.trace_builder.proto_builder import TraceProtoBuilder
    from perfetto.protos.perfetto.trace.perfetto_trace_pb2 import (
        TrackDescriptor,
        TrackEvent,
    )

    events = trace["traceEvents"]

    # Collect the process/thread names emitted as metadata ('M') events.
    process_names, thread_names = {}, {}
    for e in events:
        if e.get("ph") == "M":
            if e.get("name") == "process_name":
                process_names[e.get("pid")] = e.get("args", {}).get("name", "")
            elif e.get("name") == "thread_name":
                key = (e.get("pid"), e.get("tid"))
                thread_names[key] = e.get("args", {}).get("name", "")

    # Group complete ('X') slices by their (pid, tid) track.
    slices_by_track = defaultdict(list)
    for i, e in enumerate(events):
        if e.get("ph") == "X":
            ts = float(e.get("ts", 0) or 0)
            dur = float(e.get("dur", 0) or 0)
            slices_by_track[(e.get("pid"), e.get("tid"))].append(
                {"event": e, "index": i, "ts": ts, "end": ts + dur}
            )

    def ts_us_to_ns(value):
        return int(round(value * 1000.0))

    builder = TraceProtoBuilder()
    SEQ = 1

    # One descriptor per process.
    for pid in {pid for (pid, _tid) in slices_by_track}:
        pkt = builder.add_packet()
        desc = pkt.track_descriptor
        desc.uuid = _stable_uuid("process", pid)
        desc.name = process_names.get(pid, f"process {pid}")

    # One descriptor per backing lane; emit begin/end markers per slice.
    markers = []
    for (pid, tid), slices in slices_by_track.items():
        lane_of, lane_count = _assign_nesting_lanes(slices)
        name = thread_names.get((pid, tid), f"stream {tid}")
        lane_uuids = []
        for lane in range(lane_count):
            uuid = _stable_uuid("track", pid, tid, lane)
            lane_uuids.append(uuid)
            pkt = builder.add_packet()
            desc = pkt.track_descriptor
            desc.uuid = uuid
            desc.parent_uuid = _stable_uuid("process", pid)
            desc.name = name
            # Multiple lanes for one stream
            # -> merge them into one UI row.
            if lane_count > 1:
                desc.sibling_merge_behavior = (
                    TrackDescriptor.SIBLING_MERGE_BEHAVIOR_BY_SIBLING_MERGE_KEY
                )
                desc.sibling_merge_key = f"{pid}:{tid}:{name}"
        for i, s in enumerate(slices):
            uuid = lane_uuids[lane_of[i]]
            begin_ns, end_ns = ts_us_to_ns(s["ts"]), ts_us_to_ns(s["end"])
            # A zero-duration slice must close after it opens, so its end sorts
            # after every begin at the same timestamp.
            end_rank = 2 if end_ns == begin_ns else 0
            markers.append((begin_ns, 1, -end_ns, uuid, "begin", s["event"]))
            markers.append((end_ns, end_rank, -begin_ns, uuid, "end", s["event"]))

    # At equal timestamps, ends of non-empty slices sort before begins, so a
    # slice that starts exactly where another stops is not nested inside it.
    # Among begins, outer (longer) slices open first; among ends, inner
    # (later-starting) slices close first.
    markers.sort(key=lambda m: m[:3])
    for ts_ns, _rank, _tiebreak, uuid, kind, event in markers:
        pkt = builder.add_packet()
        pkt.timestamp = ts_ns
        pkt.trusted_packet_sequence_id = SEQ
        track_event = pkt.track_event
        track_event.track_uuid = uuid
        if kind == "begin":
            track_event.type = TrackEvent.TYPE_SLICE_BEGIN
            track_event.name = event.get("name", "slice")
            if event.get("cat"):
                track_event.categories.append(event["cat"])
            for key, value in (event.get("args") or {}).items():
                _add_debug_annotation(track_event, key, value)
        else:
            track_event.type = TrackEvent.TYPE_SLICE_END

    Path(output_path).write_bytes(builder.serialize())
    return output_path


def post_process_trace(raw_trace_path, annotations_path, output_dir):
    """Merge annotations into the trace and emit a Perfetto ``.pftrace``."""
    output_dir = Path(output_dir)

    # Load raw trace and annotations
    raw_trace = load_trace(raw_trace_path)
    with open(annotations_path, "rb") as f:
        annotations = pickle.load(f)

    # Make a copy for post-processing
    annotated_trace = copy.deepcopy(raw_trace)

    # Apply annotations
    num_annotated = annotate_trace(annotated_trace, annotations)
    print(f"Annotated {num_annotated} kernels in the trace")

    # Emit a native Perfetto protobuf trace.
    # Overlapping kernels are split onto backing lanes that re-merge in the UI --
    # no kernel is relocated to a fake stream and no timestamp is mutated.
    annotated_path = output_dir / "trace_annotated.pftrace"
    write_perfetto_trace(annotated_trace, annotated_path)
    print(f"Saved annotated trace to {annotated_path}")

    return annotated_path, raw_trace, annotated_trace

###############################################################################
# Comparing Before and After
# ---------------------------
#
# To see the impact of annotations, let's count how kernels are distributed
# across thread IDs (which represent visualization lanes in the trace).

def compare_traces(raw_trace, annotated_trace):
    """Compare kernel distribution before and after annotation."""
    def count_lanes(trace):
        """Count kernels per lane (tid)."""
        counter = Counter(
            event["tid"]
            for event in trace["traceEvents"]
            if event.get("cat") == "kernel"
        )
        return dict(sorted(counter.items()))

    raw_lanes = count_lanes(raw_trace)
    annotated_lanes = count_lanes(annotated_trace)

    print("\n" + "=" * 60)
    print("BEFORE annotation - kernels per lane (tid -> count):")
    for tid, count in raw_lanes.items():
        print(f"  Stream {tid}: {count} kernels")

    print("\nAFTER annotation - kernels per lane (tid -> count):")
    for tid, count in annotated_lanes.items():
        print(f"  Stream {tid}: {count} kernels")
    print("=" * 60)

###############################################################################
# Putting It All Together
# ------------------------
#
# Now let's run the complete workflow: build a model, capture it with
# annotations, profile it, and post-process the trace.

def main():
    """End-to-end CUDA graph annotation and profiling demo."""
    if not torch.cuda.is_available():
        raise SystemExit("CUDA required for this tutorial")

    # Check if annotation support is available.
    # PyTorch will log a warning if the cuda-bindings version is too old.
    supported = not _is_tools_id_unavailable()
    print(f"Annotation support available: {supported}")
    if not supported:
        print("NOTE: Annotation API not available.")
        print("This could be due to:")
        print("  - Driver/CUDA-compat < 13.1")
        print("  - Outdated cuda-bindings (check PyTorch warnings above)")
        print("Annotations will not be recorded, but the demo will still run.")
        print("Kernels stay on their original stream instead of semantic lanes.\n")

    output_dir = Path("traces")

    # Build the model
    print("\n1. Building transformer block model...")
    model_fn = build_transformer_block()

    # Capture graph with annotations
    print("\n2. Capturing CUDA graph with annotations...")
    graph, output = capture_graph_with_annotations(model_fn)

    # Save annotations
    print("\n3. Saving annotation metadata...")
    annotations_path = save_annotations(output_dir)

    # Profile the graph
    print("\n4. Profiling graph replays...")
    raw_trace_path = profile_graph(graph, output_dir)

    # Post-process the trace
    print("\n5. Post-processing: merging annotations into trace...")
    annotated_path, raw_trace, annotated_trace = post_process_trace(
        raw_trace_path, annotations_path, output_dir
    )

    # Compare before and after
    print("\n6. Comparing traces...")
    compare_traces(raw_trace, annotated_trace)

    # Summary
    print("\n" + "=" * 60)
    print("SUMMARY")
    print("=" * 60)
    print(f"Raw trace: {raw_trace_path}")
    print(f"Annotated trace: {annotated_path}")
    print(f"Annotations: {annotations_path}")
    print("\nOpen the annotated trace in https://ui.perfetto.dev/ to visualize")
    print("the semantic kernel lanes.")
    print("=" * 60)

# Example output, produced by running:
#
#     if __name__ == "__main__":
#         main()
#
# Annotation support available: True
#
# 1. Building transformer block model...
#
# 2. Capturing CUDA graph with annotations...
# Captured graph with 13 annotated nodes
#
# 3. Saving annotation metadata...
# Saved 13 annotations to traces/kernel_annotations_rank0_fwd_bwd.pkl
#
# 4. Profiling graph replays...
# Saved raw trace to traces/trace_raw.json.gz
#
# 5. Post-processing: merging annotations into trace...
# Annotated 65 kernels in the trace
# Saved annotated trace to traces/trace_annotated.pftrace
#
# 6. Comparing traces...
#
# ============================================================
# BEFORE annotation - kernels per lane (tid -> count):
#   Stream 7: 65 kernels
#
# AFTER annotation - kernels per lane (tid -> count):
#   Stream 7: 10 kernels
#   Stream 61: 15 kernels
#   Stream 62: 40 kernels
# ============================================================
#
# ============================================================
# SUMMARY
# ============================================================
# Raw trace: traces/trace_raw.json.gz
# Annotated trace: traces/trace_annotated.pftrace
# Annotations: traces/kernel_annotations_rank0_fwd_bwd.pkl
#
# Open the annotated trace in https://ui.perfetto.dev/ to visualize
# the semantic kernel lanes.
# ============================================================

###############################################################################
# Annotating Communication Collectives
# -------------------------------------
#
# In eager mode the profiler **automatically intercepts** NCCL collectives and
# records rich metadata: collective type, input/output message sizes, the process
# group, its size, and the participating ranks.
#
# Under CUDA graphs that automatic interception stops working. The collective is
# captured once and then replayed as an opaque kernel node. The profiler cannot
# intercept graph replay, so it has nothing to attach the NCCL metadata to.
# The kernels still show up in the trace (e.g.,
# ``ncclDevKernel_AllReduce_Sum_f32_RING_LL``), but they are opaque: you cannot
# tell each kernel's collective type, how many bytes it moved, or which process
# group it belongs to.
#
# Annotations close this gap. By wrapping the collective in ``mark_kernels``
# with the same fields the profiler auto-attaches in eager mode, we manually
# re-attach that metadata to the graphed kernel. After post-processing, a
# graphed collective reads just like an eager one. The helper below builds the
# metadata dict; using the field names the profiler uses in eager mode
# (``In msg nelems``, ``Group size``, ``Process Group Name``, ...) keeps the
# annotated trace consistent with non-graphed traces.

def annotate_collective(collective_name, input_tensor, output_tensor, group=None):
    """Annotate a collective with the metadata eager NCCL traces expose.

    Returns a ``mark_kernels`` context manager. Any kernels launched inside
    (i.e. the collective) are tagged with the collective type, message sizes,
    dtype, and the process group's name/description/ranks, and placed on a
    dedicated lane keyed by the process group so comms are visually separated
    from compute.

    The field names match the keys the profiler records for eager collectives
    (``In msg nelems``, ``Group size``, ``Process Group Name``, ...), so an
    annotated graphed collective reads exactly like a non-graphed one.
    """
    pg = group if group is not None else (
        dist.group.WORLD if dist.is_initialized() else None
    )
    ranks = dist.get_process_group_ranks(pg) if pg is not None else [0]
    group_name = getattr(pg, "group_name", "default")
    group_desc = getattr(pg, "group_desc", "default")

    # NCCL always uses its own internal stream, so key the lane on the process
    # group (name + description) and give it a stable id (>= 60).
    pg_key = f"{group_name}_{group_desc}"
    annotation = {
        "name": collective_name,
        "In msg nelems": input_tensor.numel(),
        "Out msg nelems": output_tensor.numel(),
        "Group size": len(ranks),
        "dtype": str(input_tensor.dtype).replace("torch.", ""),
        "Process Group Name": group_name,
        "Process Group Description": group_desc,
        "Process Group Ranks": ranks,
        "stream": get_stream_for_pg(pg_key),
    }
    return mark_kernels(annotation)

###############################################################################
# A Block That Mixes Compute and Communication
# ----------------------------------------------
#
# A tensor- or data-parallel layer interleaves matmuls with collectives. Here
# the projection output is all-reduced across the group, mirroring the comm in
# a tensor-parallel linear. The collective is annotated with
# ``annotate_collective`` and lands on its own lane.

def build_comm_block(group=None):
    """Create a compute + collective block annotated for profiling."""
    device = "cuda"
    torch.manual_seed(0)
    dim = 1024
    params = {
        "x": torch.randn(4, 256, dim, device=device),
        "W": torch.randn(dim, dim, device=device) / math.sqrt(dim),
    }

    def forward():
        with mark_kernels({"name": "proj", "stream": 61}):
            h = params["x"] @ params["W"]

        # All-reduce the projection output across the group (e.g. tensor
        # parallel). all_reduce is in-place, so the input and output tensors
        # are the same.
        # The annotation re-attaches the NCCL metadata that a CUDA graph would
        # otherwise drop.
        if dist.is_available() and dist.is_initialized():
            with annotate_collective("all_reduce", h, h, group):
                # Reduce over the same group the annotation describes.
                dist.all_reduce(h, group=group)
        return h

    return forward

###############################################################################
# Running the Communication Demo
# -------------------------------
#
# The demo spawns ``WORLD_SIZE`` processes, one per GPU. Every rank captures
# and replays the same graph so that the collective can complete, while only
# rank 0 saves the annotations and post-processes the trace.

WORLD_SIZE = 2

def init_pg(rank, world_size):
    """Initialize an NCCL process group for one rank of the spawned demo."""
    os.environ["MASTER_ADDR"] = "127.0.0.1"
    os.environ["MASTER_PORT"] = "29500"
    os.environ["RANK"] = str(rank)
    os.environ["WORLD_SIZE"] = str(world_size)
    # Use loopback interface for single-node setup
    os.environ["NCCL_SOCKET_IFNAME"] = "lo"
    # Bind this process to its GPU before the NCCL communicator is created.
    torch.cuda.set_device(rank)
    dist.init_process_group("nccl", rank=rank, world_size=world_size)

def _comm_worker(rank, world_size):
    """Per-rank worker: build, capture, profile, and (on rank 0) post-process."""
    init_pg(rank, world_size)

    output_dir = Path("traces_comm")

    if rank == 0:
        print("\nBuilding compute + collective block...")
    model_fn = build_comm_block()

    if rank == 0:
        print("Capturing CUDA graph with annotations...")
    graph, _ = capture_graph_with_annotations(model_fn)

    # Every rank participates in the collective during profiling, but only
    # rank 0 saves and post-processes the trace.
    if rank == 0:
        annotations_path = save_annotations(output_dir)
        raw_trace_path = profile_graph(graph, output_dir)
        annotated_path, _, annotated_trace = post_process_trace(
            raw_trace_path, annotations_path, output_dir
        )

        # Print the args of the annotated collective kernel(s) to show that the
        # eager-style metadata is now attached to the graphed comm.
        print("\nAnnotated collective kernels (metadata restored):")
        for event in annotated_trace["traceEvents"]:
            args = event.get("args", {})
            if args.get("In msg nelems") is not None:
                print(f"  {event.get('name', '?')[:40]}")
                for key in (
                    "In msg nelems",
                    "Out msg nelems",
                    "Group size",
                    "dtype",
                    "Process Group Name",
                    "Process Group Description",
                    "Process Group Ranks",
                    "stream",
                ):
                    if key in args:
                        print(f"    {key}: {args[key]}")
        print(f"\nAnnotated trace: {annotated_path}")
    else:
        # Match rank 0's warmup + profiled replays so the collective completes.
        for _ in range(3):
            graph.replay()
        torch.cuda.synchronize()
        for _ in range(5):
            graph.replay()
        torch.cuda.synchronize()

    dist.destroy_process_group()

def comm_annotation_demo():
    """Spawn a ``world_size=2`` group and surface the comm metadata."""
    if not (
        dist.is_available()
        and dist.is_nccl_available()
        and torch.cuda.is_available()
    ):
        print("Distributed/NCCL unavailable; skipping comm annotation demo.")
        return
    if torch.cuda.device_count() < WORLD_SIZE:
        print(f"Need {WORLD_SIZE} GPUs for the comm demo; skipping.")
        return

    torch.multiprocessing.spawn(
        _comm_worker, args=(WORLD_SIZE,), nprocs=WORLD_SIZE, join=True
    )

# Example output (2 GPUs), produced by running:
#
#     if __name__ == "__main__":
#         comm_annotation_demo()
#
# Building compute + collective block...
# Capturing CUDA graph with annotations...
# Captured graph with 2 annotated nodes
# Saved 2 annotations to traces_comm/kernel_annotations_rank0_fwd_bwd.pkl
# Saved raw trace to traces_comm/trace_raw.json.gz
# Annotated 5 kernels in the trace
# Saved annotated trace to traces_comm/trace_annotated.pftrace
#
# The all_reduce runs a real NCCL kernel
# (``ncclDevKernel_AllReduce_Sum_f32_RING_LL``) across the two ranks:
#
# Annotated collective kernels (metadata restored):
#   ncclDevKernel_AllReduce_Sum_f32_RING_LL
#     In msg nelems: 1048576
#     Out msg nelems: 1048576
#     Group size: 2
#     dtype: float32
#     Process Group Name: default
#     Process Group Description: default
#     Process Group Ranks: [0, 1]
#     stream: 60
#
# In the trace viewer, the all-reduce sits on its own dedicated comm lane
# (stream 60), and selecting
# it shows the collective type, message sizes, group, and ranks -- the same
# fields you would see in an eager trace, now recovered for a CUDA-graphed
# collective. This metadata is LOST without annotations.

###############################################################################
# How Overlapping Kernels Are Handled
# ------------------------------------
#
# Graphed CUDA kernels often overlap slightly, and a single trace track can
# only render properly nested slices. The Perfetto converter handles this
# faithfully:
#
# 1. ``_assign_nesting_lanes()``: For each stream, overlapping slices are split
#    across hidden *backing* lanes so that each lane's begin/end stack is validly
#    nested. A lane is a backing track index, **not** a user-visible stream.
#
# 2. ``sibling_merge_key``: All backing lanes for one stream are given the same
#    merge key, so the Perfetto UI merges them back into a single logical row.
#
# The result: overlaps render correctly on the kernel's **real** stream. No
# kernel is relocated to a fabricated stream, and no timestamp is mutated --
# unlike the legacy Chrome-JSON workaround, which had to do both.

###############################################################################
# Performance Considerations
# ---------------------------
#
# Kernel annotations add minimal overhead:
#
# - Annotation marking happens during graph capture (one-time cost)
# - Graph replay performance is identical to unannotated graphs
# - Post-processing is offline and doesn't affect runtime
#
# The main cost is the profiling itself, which you would do anyway when
# optimizing performance.
# Annotations simply make the profiler output more useful by adding semantic
# structure.

###############################################################################
# Troubleshooting
# ---------------
#
# **No annotations in the trace?**
#
# - Check that your driver/CUDA-compat version is >= 13.1
# - Verify that ``enable_annotations=True`` was passed to ``torch.cuda.graph()``
# - Ensure ``cuda-bindings`` >= 13.1.0 is installed (``pip install cuda-python``)
#
# **Annotations not showing up in specific kernels?**
#
# - Some operations may not launch kernels (e.g., tensor views)
# - Only kernels launched within the ``mark_kernels`` context are annotated
# - Verify the operation actually produces CUDA kernels using ``torch.profiler``

###############################################################################
# Conclusion
# ----------
#
# CUDA graph kernel annotations provide a powerful way to add semantic
# structure to your profiling traces. By marking logical components of your
# model during graph capture and merging these annotations in post-processing,
# you can create visualizations that make it much easier to understand and
# optimize complex CUDA graph executions.
#
# Key takeaways:
#
# - Use ``mark_kernels()`` to label regions during graph capture
# - Enable annotations with ``enable_annotations=True`` in ``torch.cuda.graph()``
# - Annotate communication collectives to recover the NCCL metadata
#   (collective type, message size, group, rank) that CUDA graphs drop but
#   eager traces expose
# - Post-process traces with ``annotate_trace()``
# - View results in https://ui.perfetto.dev/ for intuitive visualization
#
# This technique is especially valuable for large models with many components,
# distributed training setups, or any scenario where understanding the
# execution structure is critical for performance optimization.
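
###############################################################################
# Appendix: Checking Whether Slices Can Share a Track
# ----------------------------------------------------
#
# The overlap handling described in "How Overlapping Kernels Are Handled"
# rests on one invariant: on each backing lane, the begin/end markers must form
# a valid stack, so every end closes the slice that was opened most recently.
# The cell below is a minimal, self-contained sketch of that check in plain
# Python (no GPU or ``perfetto`` package needed), assuming the slices on one
# track are given as ``(begin, end)`` pairs. The helper ``is_properly_nested``
# is a hypothetical illustration, not part of PyTorch or Perfetto. Crossing
# slices fail the check, which is why they have to be split onto separate
# backing lanes before being emitted.

def is_properly_nested(slices):
    """Return True if ``(begin, end)`` slices can share one begin/end track.

    Hypothetical illustration helper: replays begin/end markers on a stack
    and checks that every end closes the most recently opened slice.
    """
    markers = []
    for idx, (begin, end) in enumerate(slices):
        # Order at equal timestamps: ends of non-empty slices (rank 0), then
        # begins (rank 1, outer slices first), then ends of zero-duration
        # slices (rank 2), so a zero-duration slice opens before it closes.
        end_rank = 2 if end == begin else 0
        markers.append((begin, 1, -end, idx, "begin", idx))
        markers.append((end, end_rank, -begin, -idx, "end", idx))
    # Sort on the numeric fields only; inner slices end before outer ones.
    markers.sort(key=lambda m: m[:4])

    open_slices = []
    for _ts, _rank, _by_extent, _by_index, kind, idx in markers:
        if kind == "begin":
            open_slices.append(idx)
        elif not open_slices or open_slices.pop() != idx:
            return False  # this end would close a different, crossing slice
    return True


print(is_properly_nested([(0, 10), (2, 5), (10, 12)]))  # -> True
print(is_properly_nested([(0, 10), (5, 15)]))  # -> False

# The first call mixes a nested slice with one that starts exactly where its
# parent ends, so a single track can hold all three. In the second call the
# two slices cross, so no single begin/end stack can represent them.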