Path: blob/main/recipes_source/torch_compile_user_defined_triton_kernel_tutorial.py
# -*- coding: utf-8 -*-

"""
Using User-Defined Triton Kernels with ``torch.compile``
=========================================================
**Author:** `Oguz Ulgen <https://github.com/oulgen>`_
"""

######################################################################
# User-defined Triton kernels can be used to optimize specific parts of your
# model's computation. These kernels are written in Triton's language, which is designed
# to make it easier to achieve peak hardware performance. By using user-defined Triton
# kernels with ``torch.compile``, you can integrate these optimized computations into
# your PyTorch model, potentially achieving significant performance improvements.
#
# This recipe demonstrates how you can use user-defined Triton kernels with ``torch.compile``.
#
# Prerequisites
# -------------------
#
# Before starting this recipe, make sure that you have the following:
#
# * Basic understanding of ``torch.compile`` and Triton. See:
#
#   * `torch.compiler API documentation <https://pytorch.org/docs/stable/torch.compiler.html#torch-compiler>`__
#   * `Introduction to torch.compile <https://pytorch.org/tutorials/intermediate/torch_compile_tutorial.html>`__
#   * `Triton language documentation <https://triton-lang.org/main/index.html>`__
#
# * PyTorch 2.3 or later
# * A GPU that supports Triton
#

import torch
from torch.utils._triton import has_triton

######################################################################
# Basic Usage
# --------------------
#
# In this example, we will use a simple vector addition kernel from the Triton documentation
# with ``torch.compile``.
# For reference, see `Triton documentation <https://triton-lang.org/main/getting-started/tutorials/01-vector-add.html>`__.
#

if not has_triton():
    print("Skipping because triton is not supported on this device.")
else:
    import triton
    from triton import language as tl

    @triton.jit
    def add_kernel(
        in_ptr0,
        in_ptr1,
        out_ptr,
        n_elements,
        BLOCK_SIZE: "tl.constexpr",
    ):
        pid = tl.program_id(axis=0)
        block_start = pid * BLOCK_SIZE
        offsets = block_start + tl.arange(0, BLOCK_SIZE)
        mask = offsets < n_elements
        x = tl.load(in_ptr0 + offsets, mask=mask)
        y = tl.load(in_ptr1 + offsets, mask=mask)
        output = x + y
        tl.store(out_ptr + offsets, output, mask=mask)

    @torch.compile(fullgraph=True)
    def add_fn(x, y):
        output = torch.zeros_like(x)
        n_elements = output.numel()
        grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
        add_kernel[grid](x, y, output, n_elements, BLOCK_SIZE=4)
        return output

    x = torch.randn(4, device="cuda")
    y = torch.randn(4, device="cuda")
    out = add_fn(x, y)
    print(f"Vector addition of\nX:\t{x}\nY:\t{y}\nis equal to\n{out}")

######################################################################
# Advanced Usage
# -------------------------------------------------------------------
#
# Triton's autotune feature is a powerful tool that automatically optimizes the configuration
# parameters of your Triton kernels. It explores a range of possible configurations and
# selects the one that delivers the best performance for your specific use case.
#
# When used with ``torch.compile``, ``triton.autotune`` can help ensure that your PyTorch
# model is running as efficiently as possible. Here is an example of using ``torch.compile``
# and ``triton.autotune``.
#
# .. note::
#
#    ``torch.compile`` only supports the ``configs`` and ``key`` arguments to
#    ``triton.autotune``.

if not has_triton():
    print("Skipping because triton is not supported on this device.")
else:
    import triton
    from triton import language as tl

    @triton.autotune(
        configs=[
            triton.Config({"BLOCK_SIZE": 4}, num_stages=3, num_warps=8),
            triton.Config({"BLOCK_SIZE": 4}, num_stages=4, num_warps=4),
            triton.Config({"BLOCK_SIZE": 2}, num_stages=3, num_warps=8),
            triton.Config({"BLOCK_SIZE": 2}, num_stages=4, num_warps=4),
        ],
        key=[],
    )
    @triton.jit
    def add_kernel_autotuned(
        in_ptr0,
        in_ptr1,
        out_ptr,
        n_elements,
        BLOCK_SIZE: "tl.constexpr",
    ):
        pid = tl.program_id(axis=0)
        block_start = pid * BLOCK_SIZE
        offsets = block_start + tl.arange(0, BLOCK_SIZE)
        mask = offsets < n_elements
        x = tl.load(in_ptr0 + offsets, mask=mask)
        y = tl.load(in_ptr1 + offsets, mask=mask)
        output = x + y
        tl.store(out_ptr + offsets, output, mask=mask)

    @torch.compile(fullgraph=True)
    def add_fn(x, y):
        output = torch.zeros_like(x)
        n_elements = output.numel()
        grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
        add_kernel_autotuned[grid](x, y, output, n_elements)
        return output

    x = torch.randn(4, device="cuda")
    y = torch.randn(4, device="cuda")
    out = add_fn(x, y)
    print(f"Vector addition of\nX:\t{x}\nY:\t{y}\nis equal to\n{out}")

######################################################################
# Composability and Limitations
# --------------------------------------------------------------------
#
# As of PyTorch 2.3, the support for user-defined Triton kernels in ``torch.compile``
# includes dynamic shapes, ``torch.autograd.Function``, JIT inductor, and AOT inductor.
# You can use these features together to build complex, high-performance models.
#
# However, there are certain limitations to be aware of:
#
# * **Tensor Subclasses:** Currently, there is no support for
#   tensor subclasses and other advanced features.
# * **Triton Features:** While ``triton.heuristics`` can be used either standalone or
#   before ``triton.autotune``, it cannot be used after ``triton.autotune``. This
#   implies that if ``triton.heuristics`` and ``triton.autotune`` are to be used
#   together, ``triton.heuristics`` must be used first.
#
# Conclusion
# -----------
# In this recipe, we explored how to utilize user-defined Triton kernels
# with ``torch.compile``. We delved into the basic usage of a simple
# vector addition kernel and advanced usage involving Triton's autotune
# feature. We also discussed the composability of user-defined Triton
# kernels with other PyTorch features and highlighted some current limitations.
#
# See Also
# ---------
#
# * `Compiling the Optimizers <https://pytorch.org/tutorials/recipes/compiling_optimizer.html>`__
# * `Implementing High-Performance Transformers with Scaled Dot Product Attention <https://pytorch.org/tutorials/intermediate/scaled_dot_product_attention_tutorial.html>`__
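######################################################################
# Both ``add_fn`` variants in this recipe size their launch grid with
# ``triton.cdiv``, which is just ceiling division by ``BLOCK_SIZE``. As a
# rough sketch, that grid arithmetic can be checked on the CPU without a
# GPU or Triton installed; the ``cdiv`` helper below is a hypothetical
# stand-in for ``triton.cdiv``, not part of this recipe's API.

```python
def cdiv(x: int, y: int) -> int:
    # Ceiling division: the number of BLOCK_SIZE-sized blocks needed to
    # cover x elements (the kernel's mask guards the partial tail block).
    return (x + y - 1) // y

# 4 elements at BLOCK_SIZE=4 launch a single block; 5 elements need 2.
print(cdiv(4, 4), cdiv(4, 2), cdiv(5, 4))  # -> 1 2 2
```

With ``n_elements = 4`` and ``BLOCK_SIZE = 4``, the ``grid`` lambda in the
examples above therefore evaluates to ``(1,)``, a one-block launch.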