GitHub Repository: pytorch/tutorials
Path: blob/main/recipes_source/torch_compile_user_defined_triton_kernel_tutorial.py
# -*- coding: utf-8 -*-

"""
Using User-Defined Triton Kernels with ``torch.compile``
=========================================================
**Author:** `Oguz Ulgen <https://github.com/oulgen>`_
"""

######################################################################
# User-defined Triton kernels can be used to optimize specific parts of your
# model's computation. These kernels are written in Triton's language, which is designed
# to make it easier to achieve peak hardware performance. By using user-defined Triton
# kernels with ``torch.compile``, you can integrate these optimized computations into
# your PyTorch model, potentially achieving significant performance improvements.
#
# This recipe demonstrates how you can use user-defined Triton kernels with ``torch.compile``.
#
# Prerequisites
# -------------------
#
# Before starting this recipe, make sure that you have the following:
#
# * Basic understanding of ``torch.compile`` and Triton. See:
#
#   * `torch.compiler API documentation <https://pytorch.org/docs/stable/torch.compiler.html#torch-compiler>`__
#   * `Introduction to torch.compile <https://pytorch.org/tutorials/intermediate/torch_compile_tutorial.html>`__
#   * `Triton language documentation <https://triton-lang.org/main/index.html>`__
#
# * PyTorch 2.3 or later
# * A GPU that supports Triton
#

import torch
from torch.utils._triton import has_triton

######################################################################
# Basic Usage
# --------------------
#
# In this example, we will use a simple vector addition kernel from the Triton documentation
# with ``torch.compile``.
# For reference, see `Triton documentation <https://triton-lang.org/main/getting-started/tutorials/01-vector-add.html>`__.
#

if not has_triton():
    print("Skipping because triton is not supported on this device.")
else:
    import triton
    from triton import language as tl

    @triton.jit
    def add_kernel(
        in_ptr0,
        in_ptr1,
        out_ptr,
        n_elements,
        BLOCK_SIZE: "tl.constexpr",
    ):
        pid = tl.program_id(axis=0)
        block_start = pid * BLOCK_SIZE
        offsets = block_start + tl.arange(0, BLOCK_SIZE)
        mask = offsets < n_elements
        x = tl.load(in_ptr0 + offsets, mask=mask)
        y = tl.load(in_ptr1 + offsets, mask=mask)
        output = x + y
        tl.store(out_ptr + offsets, output, mask=mask)

    @torch.compile(fullgraph=True)
    def add_fn(x, y):
        output = torch.zeros_like(x)
        n_elements = output.numel()
        grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
        add_kernel[grid](x, y, output, n_elements, BLOCK_SIZE=4)
        return output

    x = torch.randn(4, device="cuda")
    y = torch.randn(4, device="cuda")
    out = add_fn(x, y)
    print(f"Vector addition of\nX:\t{x}\nY:\t{y}\nis equal to\n{out}")

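######################################################################
# As a quick sanity check (a small addition to the original recipe), we can
# compare the compiled result against plain eager-mode addition. This assumes
# the CUDA tensors ``x``, ``y``, and ``out`` from the cell above are in scope.

if not has_triton():
    print("Skipping because triton is not supported on this device.")
else:
    # The compiled Triton kernel should agree with eager ``x + y``.
    torch.testing.assert_close(out, x + y)
    print("Compiled Triton kernel matches eager addition.")
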
######################################################################
# Advanced Usage
# -------------------------------------------------------------------
#
# Triton's autotune feature is a powerful tool that automatically optimizes the configuration
# parameters of your Triton kernels. It explores a range of possible configurations and
# selects the one that delivers the best performance for your specific use case.
#
# When used with ``torch.compile``, ``triton.autotune`` can help ensure that your PyTorch
# model is running as efficiently as possible. Here is an example of using ``torch.compile``
# and ``triton.autotune``.
#
# .. note::
#
#   ``torch.compile`` only supports configs and key arguments to ``triton.autotune``.

if not has_triton():
    print("Skipping because triton is not supported on this device.")
else:
    import triton
    from triton import language as tl

    @triton.autotune(
        configs=[
            triton.Config({"BLOCK_SIZE": 4}, num_stages=3, num_warps=8),
            triton.Config({"BLOCK_SIZE": 4}, num_stages=4, num_warps=4),
            triton.Config({"BLOCK_SIZE": 2}, num_stages=3, num_warps=8),
            triton.Config({"BLOCK_SIZE": 2}, num_stages=4, num_warps=4),
        ],
        key=[],
    )
    @triton.jit
    def add_kernel_autotuned(
        in_ptr0,
        in_ptr1,
        out_ptr,
        n_elements,
        BLOCK_SIZE: "tl.constexpr",
    ):
        pid = tl.program_id(axis=0)
        block_start = pid * BLOCK_SIZE
        offsets = block_start + tl.arange(0, BLOCK_SIZE)
        mask = offsets < n_elements
        x = tl.load(in_ptr0 + offsets, mask=mask)
        y = tl.load(in_ptr1 + offsets, mask=mask)
        output = x + y
        tl.store(out_ptr + offsets, output, mask=mask)

    @torch.compile(fullgraph=True)
    def add_fn(x, y):
        output = torch.zeros_like(x)
        n_elements = output.numel()
        grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
        add_kernel_autotuned[grid](x, y, output, n_elements)
        return output

    x = torch.randn(4, device="cuda")
    y = torch.randn(4, device="cuda")
    out = add_fn(x, y)
    print(f"Vector addition of\nX:\t{x}\nY:\t{y}\nis equal to\n{out}")

######################################################################
# Composability and Limitations
# --------------------------------------------------------------------
#
# As of PyTorch 2.3, the support for user-defined Triton kernels in ``torch.compile``
# includes dynamic shapes, ``torch.autograd.Function``, JIT inductor, and AOT inductor.
# You can use these features together to build complex, high-performance models,
# as the sketch below illustrates.
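#
# For example, a user-defined Triton kernel can drive the forward pass of a
# ``torch.autograd.Function``. This is a minimal sketch of our own (not part
# of the original recipe); it reuses ``add_kernel`` from the Basic Usage
# section, and since the gradient of ``x + y`` with respect to each input is
# the identity, the backward simply passes the upstream gradient through.

if not has_triton():
    print("Skipping because triton is not supported on this device.")
else:
    class TritonAdd(torch.autograd.Function):
        @staticmethod
        def forward(ctx, x, y):
            output = torch.zeros_like(x)
            n_elements = output.numel()
            grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
            add_kernel[grid](x, y, output, n_elements, BLOCK_SIZE=4)
            return output

        @staticmethod
        def backward(ctx, grad_output):
            # d(x + y)/dx = d(x + y)/dy = 1, so both inputs receive grad_output.
            return grad_output, grad_output

    @torch.compile(fullgraph=True)
    def add_fn_autograd(x, y):
        return TritonAdd.apply(x, y)

    a = torch.randn(4, device="cuda", requires_grad=True)
    b = torch.randn(4, device="cuda", requires_grad=True)
    add_fn_autograd(a, b).sum().backward()
    print(f"Gradients:\na.grad:\t{a.grad}\nb.grad:\t{b.grad}")

######################################################################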
#
# However, there are certain limitations to be aware of:
#
# * **Tensor Subclasses:** Currently, there is no support for
#   tensor subclasses and other advanced features.
# * **Triton Features:** While ``triton.heuristics`` can be used either standalone or
#   before ``triton.autotune``, it cannot be used after ``triton.autotune``. This
#   implies that if ``triton.heuristics`` and ``triton.autotune`` are to be used
#   together, ``triton.heuristics`` must be used first, as shown in the sketch
#   after this list.
#
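# The following is a minimal sketch of our own (not part of the original
# recipe) showing the allowed decorator ordering: ``triton.heuristics`` sits
# below ``triton.autotune`` in the stack, so it is applied to the kernel
# first. The ``EVEN_SIZE`` flag is a hypothetical heuristic value used purely
# for illustration.

if not has_triton():
    print("Skipping because triton is not supported on this device.")
else:
    @triton.autotune(
        configs=[
            triton.Config({"BLOCK_SIZE": 4}, num_stages=3, num_warps=8),
            triton.Config({"BLOCK_SIZE": 2}, num_stages=4, num_warps=4),
        ],
        key=[],
    )
    @triton.heuristics(
        # Hypothetical heuristic: flag whether the number of elements is even.
        values={"EVEN_SIZE": lambda args: args["n_elements"] % 2 == 0}
    )
    @triton.jit
    def add_kernel_heuristics(
        in_ptr0,
        in_ptr1,
        out_ptr,
        n_elements,
        BLOCK_SIZE: "tl.constexpr",
        EVEN_SIZE: "tl.constexpr",  # filled in by ``triton.heuristics``
    ):
        pid = tl.program_id(axis=0)
        block_start = pid * BLOCK_SIZE
        offsets = block_start + tl.arange(0, BLOCK_SIZE)
        mask = offsets < n_elements
        x = tl.load(in_ptr0 + offsets, mask=mask)
        y = tl.load(in_ptr1 + offsets, mask=mask)
        tl.store(out_ptr + offsets, x + y, mask=mask)

######################################################################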
# Conclusion
# -----------
# In this recipe, we explored how to utilize user-defined Triton kernels
# with ``torch.compile``. We delved into the basic usage of a simple
# vector addition kernel and advanced usage involving Triton's autotune
# feature. We also discussed the composability of user-defined Triton
# kernels with other PyTorch features and highlighted some current limitations.
#
# See Also
# ---------
#
# * `Compiling the Optimizers <https://pytorch.org/tutorials/recipes/compiling_optimizer.html>`__
# * `Implementing High-Performance Transformers with Scaled Dot Product Attention <https://pytorch.org/tutorials/intermediate/scaled_dot_product_attention_tutorial.html>`__