"""1(Beta) Implementing High-Performance Transformers with Scaled Dot Product Attention (SDPA)2==========================================================================================345**Author:** `Driss Guessous <https://github.com/drisspg>`_6"""78######################################################################9# Summary10# ~~~~~~~~11#12# In this tutorial, we want to highlight a new ``torch.nn.functional`` function13# that can be helpful for implementing transformer architectures. The14# function is named ``torch.nn.functional.scaled_dot_product_attention``.15# For detailed description of the function, see the `PyTorch documentation <https://pytorch.org/docs/master/generated/torch.nn.functional.scaled_dot_product_attention.html#torch.nn.functional.scaled_dot_product_attention>`__.16# This function has already been incorporated into ``torch.nn.MultiheadAttention`` and ``torch.nn.TransformerEncoderLayer``.17#18# Overview19# ~~~~~~~~~20# At a high level, this PyTorch function calculates the21# scaled dot product attention (SDPA) between query, key, and value according to22# the definition found in the paper `Attention is all you23# need <https://arxiv.org/abs/1706.03762>`__. While this function can24# be written in PyTorch using existing functions, a fused implementation can provide25# large performance benefits over a naive implementation.26#27# Fused implementations28# ~~~~~~~~~~~~~~~~~~~~~~29#30# For CUDA tensor inputs, the function will dispatch into one of the following31# implementations:32#33# * `FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness <https://arxiv.org/abs/2205.14135>`__34# * `Memory-Efficient Attention <https://github.com/facebookresearch/xformers>`__35# * A PyTorch implementation defined in C++36#37# .. note::38#39# This tutorial requires PyTorch 2.0.0 or later.40#4142import torch43import torch.nn as nn44import torch.nn.functional as F45device = "cuda" if torch.cuda.is_available() else "cpu"4647# Example Usage:48query, key, value = torch.randn(2, 3, 8, device=device), torch.randn(2, 3, 8, device=device), torch.randn(2, 3, 8, device=device)49F.scaled_dot_product_attention(query, key, value)505152######################################################################53# Explicit Dispatcher Control54# ~~~~~~~~~~~~~~~~~~~~~~~~~~~55#56# While the function will implicitly dispatch to one of the three57# implementations, the user can also explicitly control the dispatch via58# the use of a context manager. This context manager allows users to59# explicitly disable certain implementations. 

######################################################################
# Explicit Dispatcher Control
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~
#
# While the function will implicitly dispatch to one of the three
# implementations, the user can also explicitly control the dispatch via
# the use of a context manager. This context manager allows users to
# explicitly disable certain implementations. If a user wants to ensure
# the function is indeed using the fastest implementation for their
# specific inputs, the context manager can be used to sweep through all
# three, measuring the performance of each.
#

# Let's define a helpful benchmarking function:
import torch.utils.benchmark as benchmark

def benchmark_torch_function_in_microseconds(f, *args, **kwargs):
    t0 = benchmark.Timer(
        stmt="f(*args, **kwargs)", globals={"args": args, "kwargs": kwargs, "f": f}
    )
    return t0.blocked_autorange().mean * 1e6

# Let's define the hyper-parameters of our input
batch_size = 32
max_sequence_len = 1024
num_heads = 32
embed_dimension = 32

dtype = torch.float16

query = torch.rand(batch_size, num_heads, max_sequence_len, embed_dimension, device=device, dtype=dtype)
key = torch.rand(batch_size, num_heads, max_sequence_len, embed_dimension, device=device, dtype=dtype)
value = torch.rand(batch_size, num_heads, max_sequence_len, embed_dimension, device=device, dtype=dtype)

print(f"The default implementation runs in {benchmark_torch_function_in_microseconds(F.scaled_dot_product_attention, query, key, value):.3f} microseconds")

# Let's explore the speed of each of the 3 implementations
from torch.nn.attention import SDPBackend, sdpa_kernel

with sdpa_kernel(SDPBackend.MATH):
    math_time = benchmark_torch_function_in_microseconds(F.scaled_dot_product_attention, query, key, value)
    print(f"The math implementation runs in {math_time:.3f} microseconds")

with sdpa_kernel(SDPBackend.FLASH_ATTENTION):
    try:
        flash_time = benchmark_torch_function_in_microseconds(F.scaled_dot_product_attention, query, key, value)
        print(f"The flash attention implementation runs in {flash_time:.3f} microseconds")
    except RuntimeError:
        print("FlashAttention is not supported. See warnings for reasons.")

with sdpa_kernel(SDPBackend.EFFICIENT_ATTENTION):
    try:
        efficient_time = benchmark_torch_function_in_microseconds(F.scaled_dot_product_attention, query, key, value)
        print(f"The memory efficient implementation runs in {efficient_time:.3f} microseconds")
    except RuntimeError:
        print("EfficientAttention is not supported. See warnings for reasons.")
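
# ``sdpa_kernel`` also accepts a list of backends. Assuming your PyTorch build
# supports the list form (available in recent releases), the sketch below
# restricts dispatch to the two fused kernels and lets PyTorch pick whichever
# one supports the inputs:
with sdpa_kernel([SDPBackend.FLASH_ATTENTION, SDPBackend.EFFICIENT_ATTENTION]):
    try:
        fused_time = benchmark_torch_function_in_microseconds(F.scaled_dot_product_attention, query, key, value)
        print(f"The fastest available fused implementation runs in {fused_time:.3f} microseconds")
    except RuntimeError:
        print("Neither fused implementation is supported. See warnings for reasons.")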

######################################################################
# Hardware dependence
# ~~~~~~~~~~~~~~~~~~~
#
# Depending on what machine you ran the above cell on and what hardware is
# available, your results might be different.
#
# - If you don't have a GPU and are running on CPU, then with FP32 the context manager
#   will have no effect and all three runs should return similar timings.
# - Depending on what compute capability your graphics card supports,
#   flash attention or memory-efficient attention might have failed.


######################################################################
# Causal Self Attention
# ~~~~~~~~~~~~~~~~~~~~~
#
# Below is an example implementation of a multi-headed causal self
# attention block inspired by the
# `Andrej Karpathy NanoGPT <https://github.com/karpathy/nanoGPT>`__ repository.
#

class CausalSelfAttention(nn.Module):

    def __init__(self, num_heads: int, embed_dimension: int, bias: bool = False, is_causal: bool = False, dropout: float = 0.0):
        super().__init__()
        assert embed_dimension % num_heads == 0
        # key, query, value projections for all heads, but in a batch
        self.c_attn = nn.Linear(embed_dimension, 3 * embed_dimension, bias=bias)
        # output projection
        self.c_proj = nn.Linear(embed_dimension, embed_dimension, bias=bias)
        # regularization
        self.dropout = dropout
        self.resid_dropout = nn.Dropout(dropout)
        self.num_heads = num_heads
        self.embed_dimension = embed_dimension
        # Perform causal masking
        self.is_causal = is_causal

    def forward(self, x):
        # calculate query, key, values for all heads in batch and move head forward to be the batch dim
        query_projected = self.c_attn(x)

        batch_size = query_projected.size(0)
        embed_dim = query_projected.size(2)
        head_dim = embed_dim // (self.num_heads * 3)

        query, key, value = query_projected.chunk(3, -1)
        query = query.view(batch_size, -1, self.num_heads, head_dim).transpose(1, 2)
        key = key.view(batch_size, -1, self.num_heads, head_dim).transpose(1, 2)
        value = value.view(batch_size, -1, self.num_heads, head_dim).transpose(1, 2)

        if self.training:
            dropout = self.dropout
            is_causal = self.is_causal
        else:
            dropout = 0.0
            is_causal = False

        y = F.scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=dropout, is_causal=is_causal)
        y = y.transpose(1, 2).view(batch_size, -1, self.num_heads * head_dim)

        y = self.resid_dropout(self.c_proj(y))
        return y


num_heads = 8
heads_per_dim = 64
embed_dimension = num_heads * heads_per_dim
dtype = torch.float16
model = CausalSelfAttention(num_heads=num_heads, embed_dimension=embed_dimension, bias=False, is_causal=True, dropout=0.1).to("cuda").to(dtype).eval()
print(model)
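
# As a quick sanity check, the block maps any ``(batch, seq_len, embed_dimension)``
# input to an output of the same shape (the shapes here are arbitrary choices for
# illustration):
x_check = torch.rand(4, 128, embed_dimension, device="cuda", dtype=dtype)
print(model(x_check).shape)  # torch.Size([4, 128, 512])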

#####################################################################
# ``NestedTensor`` and Dense tensor support
# -----------------------------------------
#
# SDPA supports both ``NestedTensor`` and Dense tensor inputs. ``NestedTensors``
# handle the case where the input is a batch of variable length sequences
# without needing to pad each sequence to the maximum length in the batch.
# For more information about ``NestedTensors`` see
# `torch.nested <https://pytorch.org/docs/stable/nested.html>`__ and the
# `NestedTensors Tutorial <https://pytorch.org/tutorials/prototype/nestedtensor.html>`__.
#

import random

def generate_rand_batch(
    batch_size,
    max_sequence_len,
    embed_dimension,
    pad_percentage=None,
    dtype=torch.float16,
    device="cuda",
):
    if not pad_percentage:
        return (
            torch.randn(
                batch_size,
                max_sequence_len,
                embed_dimension,
                dtype=dtype,
                device=device,
            ),
            None,
        )
    # Random sequence lengths
    seq_len_list = [
        int(max_sequence_len * (1 - random.gauss(pad_percentage, 0.01)))
        for _ in range(batch_size)
    ]
    # Make random entry in the batch have max sequence length
    seq_len_list[random.randint(0, batch_size - 1)] = max_sequence_len
    return (
        torch.nested.nested_tensor(
            [
                torch.randn(seq_len, embed_dimension,
                            dtype=dtype, device=device)
                for seq_len in seq_len_list
            ]
        ),
        seq_len_list,
    )

random_nt, _ = generate_rand_batch(32, 512, embed_dimension, pad_percentage=0.5, dtype=dtype, device=device)
random_dense, _ = generate_rand_batch(32, 512, embed_dimension, pad_percentage=None, dtype=dtype, device=device)

# Currently the fused implementations don't support ``NestedTensor`` for training
model.eval()

with sdpa_kernel(SDPBackend.FLASH_ATTENTION):
    try:
        print(f"Random NT runs in {benchmark_torch_function_in_microseconds(model, random_nt):.3f} microseconds")
        print(f"Random Dense runs in {benchmark_torch_function_in_microseconds(model, random_dense):.3f} microseconds")
    except RuntimeError:
        print("FlashAttention is not supported. See warnings for reasons.")


######################################################################
# Using SDPA with ``torch.compile``
# =================================
#
# With the release of PyTorch 2.0, a new feature called
# ``torch.compile()`` has been introduced, which can provide
# significant performance improvements over eager mode.
# Scaled dot product attention is fully composable with ``torch.compile()``.
# To demonstrate this, let's compile the ``CausalSelfAttention`` module using
# ``torch.compile()`` and observe the resulting performance improvements.
#

batch_size = 32
max_sequence_len = 256
x = torch.rand(batch_size, max_sequence_len,
               embed_dimension, device=device, dtype=dtype)
print(
    f"The non-compiled module runs in {benchmark_torch_function_in_microseconds(model, x):.3f} microseconds")


compiled_model = torch.compile(model)
# Warm up the compiled module; the first call triggers compilation
compiled_model(x)
print(
    f"The compiled module runs in {benchmark_torch_function_in_microseconds(compiled_model, x):.3f} microseconds")
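
# Compilation should not change the numerics. As a quick check (the float16
# tolerances below are our own judgment call):
torch.testing.assert_close(model(x), compiled_model(x), atol=1e-3, rtol=1e-3)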


######################################################################
#
# The exact execution time depends on your machine; however, these were the
# results for mine: the non-compiled module ran in 166.616 microseconds and the
# compiled module in 166.726 microseconds. That is not what we were expecting.
# Let's dig a little deeper.
# PyTorch comes with an amazing built-in profiler that you can use to
# inspect the performance characteristics of your code.
#

from torch.profiler import profile, record_function, ProfilerActivity

activities = [ProfilerActivity.CPU]
if device == 'cuda':
    activities.append(ProfilerActivity.CUDA)

with profile(activities=activities, record_shapes=False) as prof:
    with record_function("Non-Compiled Causal Attention"):
        for _ in range(25):
            model(x)
print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))


with profile(activities=activities, record_shapes=False) as prof:
    with record_function("Compiled Causal Attention"):
        for _ in range(25):
            compiled_model(x)
print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))

# For even more insight, you can export the trace and use ``chrome://tracing`` to view the results
#
# .. code-block:: python
#
#    prof.export_chrome_trace("compiled_causal_attention_trace.json")


######################################################################
# The previous code snippet generates a report of the top 10 PyTorch functions
# that consumed the most GPU execution time, for both the compiled and non-compiled module.
# The analysis reveals that the majority of time spent on the GPU is concentrated
# on the same set of functions for both modules.
# The reason for this is that ``torch.compile`` is very good at removing the
# framework overhead associated with PyTorch. If your model is launching
# large, efficient CUDA kernels, which in this case ``CausalSelfAttention``
# is, then the overhead of PyTorch can be hidden.
#
# In reality, your module does not normally consist of a single
# ``CausalSelfAttention`` block. When experimenting with the `Andrej Karpathy
# NanoGPT <https://github.com/karpathy/nanoGPT>`__ repository, compiling
# the module took the time per train step from ``6090.49ms`` to
# ``3273.17ms``! This was done on commit ``ae3a8d5`` of NanoGPT, training on
# the Shakespeare dataset.
#
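
######################################################################
# To see the overhead argument in action, the sketch below stacks several
# ``CausalSelfAttention`` blocks (the depth of 8 is an arbitrary choice for
# illustration) and compares eager and compiled timings; with more Python-level
# work per forward pass, ``torch.compile`` has more overhead to remove:

deep_model = nn.Sequential(
    *[CausalSelfAttention(num_heads=num_heads, embed_dimension=embed_dimension,
                          bias=False, is_causal=True, dropout=0.1)
      for _ in range(8)]
).to("cuda").to(dtype).eval()

compiled_deep_model = torch.compile(deep_model)
compiled_deep_model(x)  # warm up; the first call triggers compilation
print(f"The non-compiled stack runs in {benchmark_torch_function_in_microseconds(deep_model, x):.3f} microseconds")
print(f"The compiled stack runs in {benchmark_torch_function_in_microseconds(compiled_deep_model, x):.3f} microseconds")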

######################################################################
# Using SDPA with ``attn_bias`` subclasses
# ========================================
#
# As of PyTorch 2.3, we have added a new submodule that contains tensor subclasses
# designed to be used with ``torch.nn.functional.scaled_dot_product_attention``.
# The module is named ``torch.nn.attention.bias`` and contains the following two
# utilities for generating causal attention variants:
#
# - ``torch.nn.attention.bias.causal_upper_left``
# - ``torch.nn.attention.bias.causal_lower_right``
#
# .. note::
#
#    The current argument ``is_causal`` in ``torch.nn.functional.scaled_dot_product_attention``
#    is the same as using ``torch.nn.attention.bias.causal_upper_left``.
#

from torch.nn.attention.bias import causal_lower_right, causal_upper_left

batch_size = 32
sequence_length_q = 2
sequence_length_kv = 10
num_heads = 16
embed_dimension = 32

dtype = torch.float16

query = torch.rand(batch_size, num_heads, sequence_length_q, embed_dimension, device=device, dtype=dtype)
key = torch.rand(batch_size, num_heads, sequence_length_kv, embed_dimension, device=device, dtype=dtype)
value = torch.rand(batch_size, num_heads, sequence_length_kv, embed_dimension, device=device, dtype=dtype)

upper_left_bias = causal_upper_left(sequence_length_q, sequence_length_kv)
lower_right_bias = causal_lower_right(sequence_length_q, sequence_length_kv)

print(type(upper_left_bias))
print(type(lower_right_bias))

assert type(upper_left_bias) == type(lower_right_bias)
assert issubclass(type(upper_left_bias), torch.Tensor)

# As you can see from the previous output, both biases are of the same type,
# ``torch.nn.attention.bias.CausalBias``, and both subclass ``torch.Tensor``

# Let's see what these tensors look like
print(upper_left_bias)
print(lower_right_bias)

# Upper left bias aligns the causal attention mask to the upper left corner of the
# attention scores matrix. This only has an impact when the attention scores matrix
# is not square, which is common for decoding use cases. Assuming the attention
# score matrix is two dimensional, ``attn_score[0][0]`` is the attention score
# between the 0th token in the query and the 0th token in the key.
# With upper left bias, the 0th token in the query is aligned to the 0th token in
# the key. With lower right bias, the query is aligned so that its last token lines
# up with the last token in the key: for example, ``attn_score[-1][-1]`` is always
# unmasked, since the last token in q is at the same position as the last token in k,
# even if the sequence lengths of q and k differ.

# These objects are intended to be used with ``F.scaled_dot_product_attention``
out_upper_left = F.scaled_dot_product_attention(query, key, value, upper_left_bias)
out_lower_right = F.scaled_dot_product_attention(query, key, value, lower_right_bias)
out_is_causal = F.scaled_dot_product_attention(query, key, value, is_causal=True)

assert torch.allclose(out_upper_left, out_is_causal)
assert not torch.allclose(out_upper_left, out_lower_right)

# These attention biases should also be compatible with ``torch.compile``
compiled_sdpa = torch.compile(F.scaled_dot_product_attention, fullgraph=True)
out_upper_left = compiled_sdpa(query, key, value, upper_left_bias)

######################################################################
# Conclusion
# ==========
#
# In this tutorial, we have demonstrated the basic usage of
# ``torch.nn.functional.scaled_dot_product_attention``. We have shown how
# the ``sdpa_kernel`` context manager can be used to assert that a certain
# implementation is used on GPU. We also built a simple
# ``CausalSelfAttention`` module that works with ``NestedTensor`` and is
# ``torch.compile`` compatible. In the process we have shown how the profiling
# tools can be used to explore the performance characteristics of a user-defined
# module.
#