"""
(Beta) Implementing High-Performance Transformers with Scaled Dot Product Attention (SDPA)
==========================================================================================


**Author:** `Driss Guessous <https://github.com/drisspg>`_
"""

######################################################################
# Summary
# ~~~~~~~~
#
# In this tutorial, we want to highlight a new ``torch.nn.functional`` function
# that can be helpful for implementing transformer architectures. The
# function is named ``torch.nn.functional.scaled_dot_product_attention``.
# For a detailed description of the function, see the `PyTorch documentation <https://pytorch.org/docs/master/generated/torch.nn.functional.scaled_dot_product_attention.html#torch.nn.functional.scaled_dot_product_attention>`__.
# This function has already been incorporated into ``torch.nn.MultiheadAttention`` and ``torch.nn.TransformerEncoderLayer``.
#
# Overview
# ~~~~~~~~~
# At a high level, this PyTorch function calculates the
# scaled dot product attention (SDPA) between query, key, and value according to
# the definition found in the paper `Attention is all you
# need <https://arxiv.org/abs/1706.03762>`__. While this function can
# be written in PyTorch using existing functions, a fused implementation can provide
# large performance benefits over a naive implementation.
#
# Fused implementations
# ~~~~~~~~~~~~~~~~~~~~~~
#
# For CUDA tensor inputs, the function will dispatch into one of the following
# implementations:
#
# * `FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness <https://arxiv.org/abs/2205.14135>`__
# * `Memory-Efficient Attention <https://github.com/facebookresearch/xformers>`__
# * A PyTorch implementation defined in C++
#
# .. note::
#
#    This tutorial requires PyTorch 2.0.0 or later.
#

import torch
import torch.nn as nn
import torch.nn.functional as F
device = "cuda" if torch.cuda.is_available() else "cpu"

# Example Usage:
query, key, value = torch.randn(2, 3, 8, device=device), torch.randn(2, 3, 8, device=device), torch.randn(2, 3, 8, device=device)
F.scaled_dot_product_attention(query, key, value)
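
# The overview notes that SDPA can be written using existing PyTorch ops.
# Below is a minimal, non-fused sketch of that composition (ignoring masking,
# dropout, and the other options the fused function supports), included only
# to make the math concrete; ``naive_scaled_dot_product_attention`` is not
# part of the PyTorch API.
import math

def naive_scaled_dot_product_attention(query, key, value):
    # Scale by 1/sqrt(E) and softmax over the key dimension, then weight the values.
    scale = 1.0 / math.sqrt(query.size(-1))
    attn_weight = torch.softmax(query @ key.transpose(-2, -1) * scale, dim=-1)
    return attn_weight @ value

# The naive composition should closely match the fused result.
assert torch.allclose(
    naive_scaled_dot_product_attention(query, key, value),
    F.scaled_dot_product_attention(query, key, value),
    atol=1e-5,
)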


######################################################################
# Explicit Dispatcher Control
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~
#
# While the function will implicitly dispatch to one of the three
# implementations, the user can also explicitly control the dispatch via
# the use of a context manager. This context manager allows users to
# explicitly disable certain implementations. If a user wants to ensure
# the function is indeed using the fastest implementation for their
# specific inputs, the context manager can be used to sweep through and
# measure the performance of each implementation.
#

# Let's define a helpful benchmarking function:
import torch.utils.benchmark as benchmark
def benchmark_torch_function_in_microseconds(f, *args, **kwargs):
    t0 = benchmark.Timer(
        stmt="f(*args, **kwargs)", globals={"args": args, "kwargs": kwargs, "f": f}
    )
    return t0.blocked_autorange().mean * 1e6

# Let's define the hyper-parameters of our input
batch_size = 32
max_sequence_len = 1024
num_heads = 32
embed_dimension = 32

dtype = torch.float16

query = torch.rand(batch_size, num_heads, max_sequence_len, embed_dimension, device=device, dtype=dtype)
key = torch.rand(batch_size, num_heads, max_sequence_len, embed_dimension, device=device, dtype=dtype)
value = torch.rand(batch_size, num_heads, max_sequence_len, embed_dimension, device=device, dtype=dtype)

print(f"The default implementation runs in {benchmark_torch_function_in_microseconds(F.scaled_dot_product_attention, query, key, value):.3f} microseconds")

# Let's explore the speed of each of the 3 implementations
from torch.nn.attention import SDPBackend, sdpa_kernel


with sdpa_kernel(SDPBackend.MATH):
    math_time = benchmark_torch_function_in_microseconds(F.scaled_dot_product_attention, query, key, value)
    print(f"The math implementation runs in {math_time:.3f} microseconds")

with sdpa_kernel(SDPBackend.FLASH_ATTENTION):
    try:
        flash_time = benchmark_torch_function_in_microseconds(F.scaled_dot_product_attention, query, key, value)
        print(f"The flash attention implementation runs in {flash_time:.3f} microseconds")
    except RuntimeError:
        print("FlashAttention is not supported. See warnings for reasons.")

with sdpa_kernel(SDPBackend.EFFICIENT_ATTENTION):
    try:
        efficient_time = benchmark_torch_function_in_microseconds(F.scaled_dot_product_attention, query, key, value)
        print(f"The memory efficient implementation runs in {efficient_time:.3f} microseconds")
    except RuntimeError:
        print("EfficientAttention is not supported. See warnings for reasons.")


######################################################################
# Hardware dependence
# ~~~~~~~~~~~~~~~~~~~
#
# Depending on what machine you ran the above cell on and what hardware is
# available, your results might be different.
#
# - If you don’t have a GPU and are running on CPU, then with FP32 the context manager
#   will have no effect and all three runs should return similar timings.
# - Depending on what compute capability your graphics card supports,
#   flash attention or memory-efficient attention might have failed.
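
# A quick way to see what your setup supports: check for a GPU and, if present,
# its compute capability and the dtype you are using. This is only a diagnostic
# sketch; the exact hardware and dtype requirements of each fused kernel depend
# on your PyTorch version.
if torch.cuda.is_available():
    major, minor = torch.cuda.get_device_capability(device)
    print(f"GPU: {torch.cuda.get_device_name(device)}, compute capability {major}.{minor}, dtype {dtype}")
else:
    print("No GPU available; SDPA will run on CPU.")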


######################################################################
# Causal Self Attention
# ~~~~~~~~~~~~~~~~~~~~~
#
# Below is an example implementation of a multi-headed causal self
# attention block inspired by Andrej Karpathy's
# `NanoGPT <https://github.com/karpathy/nanoGPT>`__ repository.
#

class CausalSelfAttention(nn.Module):

    def __init__(self, num_heads: int, embed_dimension: int, bias: bool = False, is_causal: bool = False, dropout: float = 0.0):
        super().__init__()
        assert embed_dimension % num_heads == 0
        # key, query, value projections for all heads, but in a batch
        self.c_attn = nn.Linear(embed_dimension, 3 * embed_dimension, bias=bias)
        # output projection
        self.c_proj = nn.Linear(embed_dimension, embed_dimension, bias=bias)
        # regularization
        self.dropout = dropout
        self.resid_dropout = nn.Dropout(dropout)
        self.num_heads = num_heads
        self.embed_dimension = embed_dimension
        # Perform causal masking
        self.is_causal = is_causal

    def forward(self, x):
        # calculate query, key, values for all heads in batch and move head forward to be the batch dim
        query_projected = self.c_attn(x)

        batch_size = query_projected.size(0)
        embed_dim = query_projected.size(2)
        # query_projected's last dim is 3 * embed_dimension, hence the factor of 3
        head_dim = embed_dim // (self.num_heads * 3)

        query, key, value = query_projected.chunk(3, -1)
        query = query.view(batch_size, -1, self.num_heads, head_dim).transpose(1, 2)
        key = key.view(batch_size, -1, self.num_heads, head_dim).transpose(1, 2)
        value = value.view(batch_size, -1, self.num_heads, head_dim).transpose(1, 2)

        if self.training:
            dropout = self.dropout
            is_causal = self.is_causal
        else:
            dropout = 0.0
            is_causal = False

        y = F.scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=dropout, is_causal=is_causal)
        y = y.transpose(1, 2).view(batch_size, -1, self.num_heads * head_dim)

        y = self.resid_dropout(self.c_proj(y))
        return y


num_heads = 8
heads_per_dim = 64
embed_dimension = num_heads * heads_per_dim
dtype = torch.float16
model = CausalSelfAttention(num_heads=num_heads, embed_dimension=embed_dimension, bias=False, is_causal=True, dropout=0.1).to("cuda").to(dtype).eval()
print(model)
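
# A quick sanity check of the block (this assumes a CUDA device, since the model
# was moved to ``"cuda"`` above): the output shape should match the input shape.
sanity_input = torch.rand(4, 128, embed_dimension, device="cuda", dtype=dtype)
print(f"Input shape: {tuple(sanity_input.shape)}, output shape: {tuple(model(sanity_input).shape)}")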


######################################################################
# ``NestedTensor`` and Dense tensor support
# -----------------------------------------
#
# SDPA supports both ``NestedTensor`` and Dense tensor inputs. ``NestedTensors`` handle the case where the input is a batch of variable length sequences
# without needing to pad each sequence to the maximum length in the batch. For more information about ``NestedTensors`` see
# `torch.nested <https://pytorch.org/docs/stable/nested.html>`__ and `NestedTensors Tutorial <https://pytorch.org/tutorials/prototype/nestedtensor.html>`__.
#

import random
def generate_rand_batch(
    batch_size,
    max_sequence_len,
    embed_dimension,
    pad_percentage=None,
    dtype=torch.float16,
    device="cuda",
):
    if not pad_percentage:
        return (
            torch.randn(
                batch_size,
                max_sequence_len,
                embed_dimension,
                dtype=dtype,
                device=device,
            ),
            None,
        )
    # Random sequence lengths
    seq_len_list = [
        int(max_sequence_len * (1 - random.gauss(pad_percentage, 0.01)))
        for _ in range(batch_size)
    ]
    # Make random entry in the batch have max sequence length
    seq_len_list[random.randint(0, batch_size - 1)] = max_sequence_len
    return (
        torch.nested.nested_tensor(
            [
                torch.randn(seq_len, embed_dimension,
                            dtype=dtype, device=device)
                for seq_len in seq_len_list
            ]
        ),
        seq_len_list,
    )

random_nt, _ = generate_rand_batch(32, 512, embed_dimension, pad_percentage=0.5, dtype=dtype, device=device)
random_dense, _ = generate_rand_batch(32, 512, embed_dimension, pad_percentage=None, dtype=dtype, device=device)

# Currently the fused implementations don't support ``NestedTensor`` for training
model.eval()

with sdpa_kernel(SDPBackend.FLASH_ATTENTION):
    try:
        print(f"Random NT runs in {benchmark_torch_function_in_microseconds(model, random_nt):.3f} microseconds")
        print(f"Random Dense runs in {benchmark_torch_function_in_microseconds(model, random_dense):.3f} microseconds")
    except RuntimeError:
        print("FlashAttention is not supported. See warnings for reasons.")


######################################################################
# Using SDPA with ``torch.compile``
# =================================
#
# With the release of PyTorch 2.0, a new feature called
# ``torch.compile()`` has been introduced, which can provide
# significant performance improvements over eager mode.
# Scaled dot product attention is fully composable with ``torch.compile()``.
# To demonstrate this, let's compile the ``CausalSelfAttention`` module using
# ``torch.compile()`` and observe the resulting performance improvements.
#

batch_size = 32
max_sequence_len = 256
x = torch.rand(batch_size, max_sequence_len,
               embed_dimension, device=device, dtype=dtype)
print(
    f"The non compiled module runs in {benchmark_torch_function_in_microseconds(model, x):.3f} microseconds")


compiled_model = torch.compile(model)
# Warm up the compiled module; the first call triggers compilation.
compiled_model(x)
print(
    f"The compiled module runs in {benchmark_torch_function_in_microseconds(compiled_model, x):.3f} microseconds")


######################################################################
#
# The exact execution time is dependent on your machine; however, here are
# the results for mine:
# The non compiled module runs in 166.616 microseconds
# The compiled module runs in 166.726 microseconds
# That is not what we were expecting. Let's dig a little deeper.
# PyTorch comes with an amazing built-in profiler that you can use to
# inspect the performance characteristics of your code.
#

from torch.profiler import profile, record_function, ProfilerActivity
activities = [ProfilerActivity.CPU]
if device == 'cuda':
    activities.append(ProfilerActivity.CUDA)

with profile(activities=activities, record_shapes=False) as prof:
    with record_function("Non-Compiled Causal Attention"):
        for _ in range(25):
            model(x)
print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))


with profile(activities=activities, record_shapes=False) as prof:
    with record_function("Compiled Causal Attention"):
        for _ in range(25):
            compiled_model(x)
print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))

# For even more insights, you can export the trace and use ``chrome://tracing`` to view the results
#
# .. code-block:: python
#
#    prof.export_chrome_trace("compiled_causal_attention_trace.json")


######################################################################
# The previous code snippet generates a report of the top 10 PyTorch functions
# that consumed the most GPU execution time, for both the compiled and non-compiled module.
# The analysis reveals that the majority of time spent on the GPU is concentrated
# on the same set of functions for both modules.
# The reason for this is that ``torch.compile`` is very good at removing the
# framework overhead associated with PyTorch. If your model is launching
# large, efficient CUDA kernels, as ``CausalSelfAttention`` does in this case,
# then the overhead of PyTorch can be hidden.
#
# In reality, your module does not normally consist of a single
# ``CausalSelfAttention`` block. When experimenting with Andrej Karpathy's `NanoGPT <https://github.com/karpathy/nanoGPT>`__ repository, compiling
# the module reduced the time per train step from ``6090.49ms`` to
# ``3273.17ms``! This was done on commit ``ae3a8d5`` of NanoGPT, training on
# the Shakespeare dataset.
#

######################################################################
# Using SDPA with attn_bias subclasses
# ====================================
#
# As of PyTorch 2.3, we have added a new submodule that contains tensor subclasses
# designed to be used with ``torch.nn.functional.scaled_dot_product_attention``.
# The module is named ``torch.nn.attention.bias`` and contains the following two
# utilities for generating causal attention variants:
#
# - ``torch.nn.attention.bias.causal_upper_left``
# - ``torch.nn.attention.bias.causal_lower_right``
#
# .. note::
#    The current argument ``is_causal`` in ``torch.nn.functional.scaled_dot_product_attention``
#    is the same as using ``torch.nn.attention.bias.causal_upper_left``.
#

from torch.nn.attention.bias import causal_lower_right, causal_upper_left

batch_size = 32
sequence_length_q = 2
sequence_length_kv = 10
num_heads = 16
embed_dimension = 32

dtype = torch.float16

query = torch.rand(batch_size, num_heads, sequence_length_q, embed_dimension, device=device, dtype=dtype)
key = torch.rand(batch_size, num_heads, sequence_length_kv, embed_dimension, device=device, dtype=dtype)
value = torch.rand(batch_size, num_heads, sequence_length_kv, embed_dimension, device=device, dtype=dtype)

upper_left_bias = causal_upper_left(sequence_length_q, sequence_length_kv)
lower_right_bias = causal_lower_right(sequence_length_q, sequence_length_kv)

print(type(upper_left_bias))
print(type(lower_right_bias))

assert type(upper_left_bias) == type(lower_right_bias)
assert issubclass(type(upper_left_bias), torch.Tensor)

# As you can see from the previous output, both biases are of the same type,
# ``torch.nn.attention.bias.CausalBias``, and both subclass ``torch.Tensor``.

# Let's see what these tensors look like
print(upper_left_bias)
print(lower_right_bias)

# Upper left bias aligns the causal attention mask to the upper left corner of the attention scores matrix.
# This only has an impact when the attention scores matrix is not square, which is common for decoding use cases.
# Another way of thinking about this concept: with upper left bias, the 0th token in the query is aligned to the
# 0th token in the key. Assuming the attention score matrix is two dimensional, ``attn_score[0][0]`` is then the
# attention score between the 0th token in the query and the 0th token in the key.
# With lower right bias, the query is instead aligned so that its last token matches the last token in the key
# (for example, ``attn_score[-1][-1]`` is always unmasked, since the last token in q is at the same position as
# the last token in k), even if the sequence lengths of q and k are different.

# These objects are intended to be used with SDPA
out_upper_left = F.scaled_dot_product_attention(query, key, value, upper_left_bias)
out_lower_right = F.scaled_dot_product_attention(query, key, value, lower_right_bias)
out_is_causal = F.scaled_dot_product_attention(query, key, value, is_causal=True)

assert torch.allclose(out_upper_left, out_is_causal)
assert not torch.allclose(out_upper_left, out_lower_right)
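
# For intuition, the lower-right alignment can also be expressed as an explicit
# boolean mask built with ``torch.tril`` and a diagonal offset. This is only an
# illustrative sketch, not how ``CausalBias`` is implemented internally, and the
# explicit mask may dispatch to a different kernel, so we print the difference
# rather than assert exact equality.
explicit_lower_right = torch.tril(
    torch.ones(sequence_length_q, sequence_length_kv, dtype=torch.bool, device=device),
    diagonal=sequence_length_kv - sequence_length_q,
)
out_explicit = F.scaled_dot_product_attention(query, key, value, attn_mask=explicit_lower_right)
print(f"Max difference vs. lower_right_bias: {(out_explicit - out_lower_right).abs().max().item():.2e}")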

# These attention biases should also be compatible with torch.compile
compiled_sdpa = torch.compile(F.scaled_dot_product_attention, fullgraph=True)
out_upper_left = compiled_sdpa(query, key, value, upper_left_bias)

######################################################################
# Conclusion
# ==========
#
# In this tutorial, we have demonstrated the basic usage of
# ``torch.nn.functional.scaled_dot_product_attention``. We have shown how
# the ``sdpa_kernel`` context manager can be used to assert that a certain
# implementation is used on the GPU. We also built a simple
# ``CausalSelfAttention`` module that works with ``NestedTensor`` and is
# ``torch.compile``-able. In the process, we have shown how the profiling tools
# can be used to explore the performance characteristics of a user-defined
# module.
#