"""
Getting Started with Nested Tensors
===============================================================

Nested tensors generalize the shape of regular dense tensors, allowing for representation
of ragged-sized data.

* for a regular tensor, each dimension is regular and has a size

* for a nested tensor, not all dimensions have regular sizes; some of them are ragged

Nested tensors are a natural solution for representing sequential data within various domains:

* in NLP, sentences can have variable lengths, so a batch of sentences forms a nested tensor

* in CV, images can have variable shapes, so a batch of images forms a nested tensor

In this tutorial, we will demonstrate basic usage of nested tensors and motivate their usefulness
for operating on sequential data of varying lengths with a real-world example. In particular,
they are invaluable for building transformers that can efficiently operate on ragged sequential
inputs. Below, we present an implementation of multi-head attention using nested tensors that,
combined with the use of ``torch.compile``, outperforms operating naively on tensors with padding.

Nested tensors are currently a prototype feature and are subject to change.
"""

import numpy as np
import timeit
import torch
import torch.nn.functional as F

from torch import nn

torch.manual_seed(1)
np.random.seed(1)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

######################################################################
# Nested tensor initialization
# ----------------------------
#
# From the Python frontend, a nested tensor can be created from a list of tensors.
# We denote nt[i] as the ith tensor component of a nestedtensor.
nt = torch.nested.nested_tensor([torch.arange(12).reshape(
    2, 6), torch.arange(18).reshape(3, 6)], dtype=torch.float, device=device)
print(f"{nt=}")

######################################################################
# By padding every underlying tensor to the same shape,
# a nestedtensor can be converted to a regular tensor.
padded_out_tensor = torch.nested.to_padded_tensor(nt, padding=0.0)
print(f"{padded_out_tensor=}")
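
######################################################################
# Optional sketch: ``torch.nested.to_padded_tensor`` also accepts an
# ``output_size`` argument (see the torch.nested docs) to pad out to a fixed
# shape larger than the minimal one, which can help when downstream code
# expects a specific sequence length. This is illustrative only and assumes
# the argument is available in your PyTorch build.
padded_to_fixed = torch.nested.to_padded_tensor(nt, padding=0.0, output_size=[2, 4, 8])
print(f"{padded_to_fixed.shape=}")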

######################################################################
# All tensors possess an attribute for determining if they are nested:
print(f"nt is nested: {nt.is_nested}")
print(f"padded_out_tensor is nested: {padded_out_tensor.is_nested}")

######################################################################
# It is common to construct nestedtensors from batches of irregularly shaped tensors.
# That is, dimension 0 is assumed to be the batch dimension.
# Indexing dimension 0 gives back the first underlying tensor component.
print("First underlying tensor component:", nt[0], sep='\n')
print("last column of 2nd underlying tensor component:", nt[1, :, -1], sep='\n')

# When indexing a nestedtensor's 0th dimension, the result is a regular tensor.
print(f"First underlying tensor component is nested: {nt[0].is_nested}")

######################################################################
# An important note is that slicing in dimension 0 is not yet supported,
# which means it is not currently possible to construct a view that combines
# the underlying tensor components. The components can instead be recovered
# individually, as sketched below.
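
######################################################################
# Workaround sketch (illustrative, not an official recipe): ``Tensor.unbind``
# returns the underlying components as regular tensors, so a "slice" along
# dimension 0 can be emulated by rebuilding a nested tensor from a subset of
# the components. Note that this copies data rather than creating a view.
nt_first_component_only = torch.nested.nested_tensor(list(nt.unbind())[:1], device=device)
print(f"{nt_first_component_only=}")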

######################################################################
# Nested Tensor Operations
# ------------------------
#
# As each operation must be explicitly implemented for nestedtensors,
# operation coverage for nestedtensors is currently narrower than that of regular tensors.
# For now, only basic operations such as index, dropout, softmax, transpose, reshape, linear, bmm are covered.
# However, coverage is being expanded.
# If you need certain operations, please file an `issue <https://github.com/pytorch/pytorch>`__
# to help us prioritize coverage.
#
# **reshape**
#
# The reshape op is for changing the shape of a tensor.
# Its full semantics for regular tensors can be found
# `here <https://pytorch.org/docs/stable/generated/torch.reshape.html>`__.
# For regular tensors, when specifying the new shape,
# a single dimension may be -1, in which case it is inferred
# from the remaining dimensions and the number of elements.
#
# The semantics for nestedtensors are similar, except that -1 no longer infers.
# Instead, it inherits the old size (here 2 for ``nt[0]`` and 3 for ``nt[1]``).
# -1 is the only legal size to specify for a jagged dimension.
nt_reshaped = nt.reshape(2, -1, 2, 3)
print(f"{nt_reshaped=}")
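
######################################################################
# Quick check (a minimal sketch): after the reshape, each component keeps its
# own size along the jagged dimension. ``unbind`` exposes the components so we
# can inspect their shapes directly.
for i, component in enumerate(nt_reshaped.unbind()):
    print(f"nt_reshaped component {i} shape: {tuple(component.shape)}")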

######################################################################
# **transpose**
#
# The transpose op is for swapping two dimensions of a tensor.
# Its full semantics can be found
# `here <https://pytorch.org/docs/stable/generated/torch.transpose.html>`__.
# Note that for nestedtensors dimension 0 is special;
# it is assumed to be the batch dimension,
# so transposes involving nestedtensor dimension 0 are not supported.
nt_transposed = nt_reshaped.transpose(1, 2)
print(f"{nt_transposed=}")

######################################################################
# **others**
#
# Other operations have the same semantics as for regular tensors.
# Applying the operation on a nestedtensor is equivalent to
# applying the operation to the underlying tensor components,
# with the result being a nestedtensor as well.
nt_mm = torch.nested.nested_tensor([torch.randn((2, 3, 4)), torch.randn((2, 3, 5))], device=device)
nt3 = torch.matmul(nt_transposed, nt_mm)
print(f"Result of Matmul:\n {nt3}")

nt4 = F.dropout(nt3, 0.1)
print(f"Result of Dropout:\n {nt4}")

nt5 = F.softmax(nt4, -1)
print(f"Result of Softmax:\n {nt5}")
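
######################################################################
# Sanity-check sketch: the per-component equivalence stated above can be
# verified directly by applying softmax to each unbound component of ``nt4``
# and comparing against the corresponding component of ``nt5``.
for component_in, component_out in zip(nt4.unbind(), nt5.unbind()):
    assert torch.allclose(F.softmax(component_in, -1), component_out)
print("per-component softmax matches the nested softmax")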

######################################################################
# Why Nested Tensor
# -----------------
#

######################################################################
# When data is sequential, it is often the case that each sample has a different length.
# For example, in a batch of sentences, each sentence has a different number of words.
# A common technique for handling varying sequences is to manually pad each data tensor
# to the same shape in order to form a batch.
# For example, we have 2 sentences with different lengths and a vocabulary.
# In order to represent this as a single tensor, we pad with 0 to the max length in the batch.
sentences = [["goodbye", "padding"],
             ["embrace", "nested", "tensor"]]
vocabulary = {"goodbye": 1.0, "padding": 2.0,
              "embrace": 3.0, "nested": 4.0, "tensor": 5.0}
padded_sentences = torch.tensor([[1.0, 2.0, 0.0],
                                 [3.0, 4.0, 5.0]])
nested_sentences = torch.nested.nested_tensor([torch.tensor([1.0, 2.0]),
                                               torch.tensor([3.0, 4.0, 5.0])])
print(f"{padded_sentences=}")
print(f"{nested_sentences=}")

######################################################################
# This technique of padding a batch of data to its max length is not optimal.
# The padded data is not needed for computation and wastes memory by allocating
# larger tensors than necessary.
# Further, not all operations have the same semantics when applied to padded data.
# For matrix multiplications, one needs to pad with 0 in order to ignore the padded entries,
# while for softmax one has to pad with -inf to ignore those entries.
# The primary objective of nested tensor is to facilitate operations on ragged
# data using the standard PyTorch tensor UX, thereby eliminating the need
# for inefficient and complex padding and masking.
padded_sentences_for_softmax = torch.tensor([[1.0, 2.0, float("-inf")],
                                             [3.0, 4.0, 5.0]])
print(F.softmax(padded_sentences_for_softmax, -1))
print(F.softmax(nested_sentences, -1))
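
######################################################################
# A small illustration of the memory point above (sketch): count how many
# elements the padded batch stores versus how many elements the data actually
# contains.
real_elements = sum(t.numel() for t in nested_sentences.unbind())
print(f"padded elements: {padded_sentences.numel()}, real elements: {real_elements}")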

######################################################################
# Let us take a look at a practical example: the multi-head attention component
# utilized in `Transformers <https://arxiv.org/pdf/1706.03762.pdf>`__.
# We can implement this in such a way that it can operate on either padded
# or nested tensors.
class MultiHeadAttention(nn.Module):
    """
    Computes multi-head attention. Supports nested or padded tensors.

    Args:
        E_q (int): Size of embedding dim for query
        E_k (int): Size of embedding dim for key
        E_v (int): Size of embedding dim for value
        E_total (int): Total embedding dim of combined heads post input projection. Each head
            has dim E_total // nheads
        nheads (int): Number of heads
        dropout_p (float, optional): Dropout probability. Default: 0.0
    """
    def __init__(self, E_q: int, E_k: int, E_v: int, E_total: int,
                 nheads: int, dropout_p: float = 0.0):
        super().__init__()
        self.nheads = nheads
        self.dropout_p = dropout_p
        self.query_proj = nn.Linear(E_q, E_total)
        self.key_proj = nn.Linear(E_k, E_total)
        self.value_proj = nn.Linear(E_v, E_total)
        E_out = E_q
        self.out_proj = nn.Linear(E_total, E_out)
        assert E_total % nheads == 0, "Embedding dim is not divisible by nheads"
        self.E_head = E_total // nheads

    def forward(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor) -> torch.Tensor:
        """
        Forward pass; runs the following process:
            1. Apply input projection
            2. Split heads and prepare for SDPA
            3. Run SDPA
            4. Apply output projection

        Args:
            query (torch.Tensor): query of shape (N, L_t, E_q)
            key (torch.Tensor): key of shape (N, L_s, E_k)
            value (torch.Tensor): value of shape (N, L_s, E_v)

        Returns:
            attn_output (torch.Tensor): output of shape (N, L_t, E_q)
        """
        # Step 1. Apply input projection
        # TODO: demonstrate packed projection
        query = self.query_proj(query)
        key = self.key_proj(key)
        value = self.value_proj(value)

        # Step 2. Split heads and prepare for SDPA
        # reshape query, key, value to separate by head
        # (N, L_t, E_total) -> (N, L_t, nheads, E_head) -> (N, nheads, L_t, E_head)
        query = query.unflatten(-1, [self.nheads, self.E_head]).transpose(1, 2)
        # (N, L_s, E_total) -> (N, L_s, nheads, E_head) -> (N, nheads, L_s, E_head)
        key = key.unflatten(-1, [self.nheads, self.E_head]).transpose(1, 2)
        # (N, L_s, E_total) -> (N, L_s, nheads, E_head) -> (N, nheads, L_s, E_head)
        value = value.unflatten(-1, [self.nheads, self.E_head]).transpose(1, 2)

        # Step 3. Run SDPA
        # (N, nheads, L_t, E_head)
        attn_output = F.scaled_dot_product_attention(
            query, key, value, dropout_p=self.dropout_p, is_causal=True)
        # (N, nheads, L_t, E_head) -> (N, L_t, nheads, E_head) -> (N, L_t, E_total)
        attn_output = attn_output.transpose(1, 2).flatten(-2)

        # Step 4. Apply output projection
        # (N, L_t, E_total) -> (N, L_t, E_out)
        attn_output = self.out_proj(attn_output)

        return attn_output

######################################################################
# set hyperparameters following `the Transformer paper <https://arxiv.org/pdf/1706.03762.pdf>`__
N = 512
E_q, E_k, E_v, E_total = 512, 512, 512, 512
E_out = E_q
nheads = 8

######################################################################
# except for dropout probability: set to 0 for correctness check
dropout_p = 0.0

######################################################################
# Let us generate some realistic fake data from Zipf's law.
def zipf_sentence_lengths(alpha: float, batch_size: int) -> torch.Tensor:
    # generate fake corpus by unigram Zipf distribution
    # from wikitext-2 corpus, we get rank "." = 3, "!" = 386, "?" = 858
    sentence_lengths = np.empty(batch_size, dtype=int)
    for ibatch in range(batch_size):
        sentence_lengths[ibatch] = 1
        word = np.random.zipf(alpha)
        while word != 3 and word != 386 and word != 858:
            sentence_lengths[ibatch] += 1
            word = np.random.zipf(alpha)
    return torch.tensor(sentence_lengths)
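
######################################################################
# Quick usage check (sketch): sample a handful of lengths to see what the
# sampler produces before building the full batch. Note this consumes a bit of
# RNG state, so exact values downstream will differ from a run without it.
preview_lengths = zipf_sentence_lengths(alpha=1.2, batch_size=8)
print(f"sample sentence lengths: {preview_lengths.tolist()}")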

######################################################################
# Create nested tensor batch inputs
def gen_batch(N, E_q, E_k, E_v, device):
    # generate semi-realistic data using Zipf distribution for sentence lengths
    sentence_lengths = zipf_sentence_lengths(alpha=1.2, batch_size=N)

    # Note: the torch.jagged layout is a nested tensor layout that supports a single ragged
    # dimension and works with torch.compile. Each resulting nested tensor has shape
    # (B, S*, D), where B = batch size, S* = ragged sequence length, and D = embedding dimension.
    query = torch.nested.nested_tensor([
        torch.randn(l.item(), E_q, device=device)
        for l in sentence_lengths
    ], layout=torch.jagged)

    key = torch.nested.nested_tensor([
        torch.randn(s.item(), E_k, device=device)
        for s in sentence_lengths
    ], layout=torch.jagged)

    value = torch.nested.nested_tensor([
        torch.randn(s.item(), E_v, device=device)
        for s in sentence_lengths
    ], layout=torch.jagged)

    return query, key, value, sentence_lengths

query, key, value, sentence_lengths = gen_batch(N, E_q, E_k, E_v, device)

######################################################################
# Generate padded forms of query, key, value for comparison
def jagged_to_padded(jt, padding_val):
    # TODO: do jagged -> padded directly when this is supported
    return torch.nested.to_padded_tensor(
        torch.nested.nested_tensor(list(jt.unbind())),
        padding_val)

padded_query, padded_key, padded_value = (
    jagged_to_padded(t, 0.0) for t in (query, key, value)
)
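
######################################################################
# For context (sketch): compare the number of slots the padded batch allocates
# against the number of real tokens in the jagged data. The gap is the memory
# and work that padding wastes.
padded_slots = padded_query.shape[0] * padded_query.shape[1]
real_tokens = int(sentence_lengths.sum())
print(f"padded slots: {padded_slots}, real tokens: {real_tokens}")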

######################################################################
# Construct the model
mha = MultiHeadAttention(E_q, E_k, E_v, E_total, nheads, dropout_p).to(device=device)

######################################################################
# Check correctness and performance
def benchmark(func, *args, **kwargs):
    # guard the CUDA sync so the benchmark also runs on CPU-only machines
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    begin = timeit.default_timer()
    output = func(*args, **kwargs)
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    end = timeit.default_timer()
    return output, (end - begin)

output_nested, time_nested = benchmark(mha, query, key, value)
output_padded, time_padded = benchmark(mha, padded_query, padded_key, padded_value)

# padding-specific step: remove output projection bias from padded entries for fair comparison
for i, entry_length in enumerate(sentence_lengths):
    output_padded[i, entry_length:] = 0.0

print("=== without torch.compile ===")
print("nested and padded calculations differ by", (jagged_to_padded(output_nested, 0.0) - output_padded).abs().max().item())
print("nested tensor multi-head attention takes", time_nested, "seconds")
print("padded tensor multi-head attention takes", time_padded, "seconds")

# warm up compile first...
compiled_mha = torch.compile(mha)
compiled_mha(query, key, value)
# ...now benchmark
compiled_output_nested, compiled_time_nested = benchmark(
    compiled_mha, query, key, value)

# warm up compile first...
compiled_mha(padded_query, padded_key, padded_value)
# ...now benchmark
compiled_output_padded, compiled_time_padded = benchmark(
    compiled_mha, padded_query, padded_key, padded_value)

# padding-specific step: remove output projection bias from padded entries for fair comparison
for i, entry_length in enumerate(sentence_lengths):
    compiled_output_padded[i, entry_length:] = 0.0

print("=== with torch.compile ===")
print("nested and padded calculations differ by", (jagged_to_padded(compiled_output_nested, 0.0) - compiled_output_padded).abs().max().item())
print("nested tensor multi-head attention takes", compiled_time_nested, "seconds")
print("padded tensor multi-head attention takes", compiled_time_padded, "seconds")

######################################################################
# Note that without ``torch.compile``, the overhead of the python subclass nested tensor
# can make it slower than the equivalent computation on padded tensors. However, once
# ``torch.compile`` is enabled, operating on nested tensors gives a several-fold speedup.
# Avoiding wasted computation on padding becomes only more valuable as the percentage
# of padding in the batch increases.
print(f"Nested speedup: {compiled_time_padded / compiled_time_nested:.3f}")
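
######################################################################
# For reference (sketch): the fraction of the padded batch that is padding in
# this particular run, which drives how much computation padding wastes.
padding_fraction = 1.0 - sentence_lengths.sum().item() / (N * sentence_lengths.max().item())
print(f"fraction of padded batch that is padding: {padding_fraction:.2%}")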

######################################################################
# Conclusion
# ----------
# In this tutorial, we have learned how to perform basic operations with nested tensors and
# how to implement multi-head attention for transformers in a way that avoids computation on padding.
# For more information, check out the docs for the
# `torch.nested <https://pytorch.org/docs/stable/nested.html>`__ namespace.