CoCalc provides the best real-time collaborative environment for Jupyter Notebooks, LaTeX documents, and SageMath, scalable from individual users to large groups and classes!
Path: blob/main/prototype_source/nestedtensor.py
"""12Getting Started with Nested Tensors3===============================================================45Nested tensors generalize the shape of regular dense tensors, allowing for representation6of ragged-sized data.78* for a regular tensor, each dimension is regular and has a size910* for a nested tensor, not all dimensions have regular sizes; some of them are ragged1112Nested tensors are a natural solution for representing sequential data within various domains:1314* in NLP, sentences can have variable lengths, so a batch of sentences forms a nested tensor1516* in CV, images can have variable shapes, so a batch of images forms a nested tensor1718In this tutorial, we will demonstrate basic usage of nested tensors and motivate their usefulness19for operating on sequential data of varying lengths with a real-world example. In particular,20they are invaluable for building transformers that can efficiently operate on ragged sequential21inputs. Below, we present an implementation of multi-head attention using nested tensors that,22combined usage of ``torch.compile``, out-performs operating naively on tensors with padding.2324Nested tensors are currently a prototype feature and are subject to change.25"""2627import numpy as np28import timeit29import torch30import torch.nn.functional as F3132from torch import nn3334torch.manual_seed(1)35np.random.seed(1)3637device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')3839######################################################################40# Nested tensor initialization41# ----------------------------42#43# From the Python frontend, a nested tensor can be created from a list of tensors.44# We denote nt[i] as the ith tensor component of a nestedtensor.45nt = torch.nested.nested_tensor([torch.arange(12).reshape(462, 6), torch.arange(18).reshape(3, 6)], dtype=torch.float, device=device)47print(f"{nt=}")4849######################################################################50# By padding every underlying tensor to the same shape,51# a nestedtensor can be converted to a regular tensor.52padded_out_tensor = torch.nested.to_padded_tensor(nt, padding=0.0)53print(f"{padded_out_tensor=}")5455######################################################################56# All tensors posses an attribute for determining if they are nested;57print(f"nt is nested: {nt.is_nested}")58print(f"padded_out_tensor is nested: {padded_out_tensor.is_nested}")5960######################################################################61# It is common to construct nestedtensors from batches of irregularly shaped tensors.62# i.e. 
######################################################################
# Nested Tensor Operations
# ------------------------
#
# As each operation must be explicitly implemented for nestedtensors,
# operation coverage for nestedtensors is currently narrower than that of regular tensors.
# For now, only basic operations such as index, dropout, softmax, transpose, reshape, linear, and bmm are covered.
# However, coverage is being expanded.
# If you need certain operations, please file an `issue <https://github.com/pytorch/pytorch>`__
# to help us prioritize coverage.
#
# **reshape**
#
# The reshape op is for changing the shape of a tensor.
# Its full semantics for regular tensors can be found
# `here <https://pytorch.org/docs/stable/generated/torch.reshape.html>`__.
# For regular tensors, when specifying the new shape,
# a single dimension may be -1, in which case it is inferred
# from the remaining dimensions and the number of elements.
#
# The semantics for nestedtensors are similar, except that -1 no longer infers.
# Instead, it inherits the old size (here 2 for ``nt[0]`` and 3 for ``nt[1]``).
# -1 is the only legal size to specify for a jagged dimension.
nt_reshaped = nt.reshape(2, -1, 2, 3)
print(f"{nt_reshaped=}")

######################################################################
# **transpose**
#
# The transpose op is for swapping two dimensions of a tensor.
# Its full semantics can be found
# `here <https://pytorch.org/docs/stable/generated/torch.transpose.html>`__.
# Note that for nestedtensors dimension 0 is special;
# it is assumed to be the batch dimension,
# so transposes involving nestedtensor dimension 0 are not supported.
nt_transposed = nt_reshaped.transpose(1, 2)
print(f"{nt_transposed=}")

######################################################################
# **others**
#
# Other operations have the same semantics as for regular tensors.
# Applying an operation on a nestedtensor is equivalent to
# applying the operation to the underlying tensor components,
# with the result being a nestedtensor as well.
nt_mm = torch.nested.nested_tensor([torch.randn((2, 3, 4)), torch.randn((2, 3, 5))], device=device)
nt3 = torch.matmul(nt_transposed, nt_mm)
print(f"Result of Matmul:\n {nt3}")

nt4 = F.dropout(nt3, 0.1)
print(f"Result of Dropout:\n {nt4}")

nt5 = F.softmax(nt4, -1)
print(f"Result of Softmax:\n {nt5}")

######################################################################
# Why Nested Tensor
# -----------------
#

######################################################################
# When data is sequential, it is often the case that each sample has a different length.
# For example, in a batch of sentences, each sentence has a different number of words.
# A common technique for handling varying sequences is to manually pad each data tensor
# to the same shape in order to form a batch.
# For example, we have 2 sentences with different lengths and a vocabulary.
# In order to represent this as a single tensor we pad with 0 to the max length in the batch.
sentences = [["goodbye", "padding"],
             ["embrace", "nested", "tensor"]]
vocabulary = {"goodbye": 1.0, "padding": 2.0,
              "embrace": 3.0, "nested": 4.0, "tensor": 5.0}
padded_sentences = torch.tensor([[1.0, 2.0, 0.0],
                                 [3.0, 4.0, 5.0]])
nested_sentences = torch.nested.nested_tensor([torch.tensor([1.0, 2.0]),
                                               torch.tensor([3.0, 4.0, 5.0])])
print(f"{padded_sentences=}")
print(f"{nested_sentences=}")

######################################################################
# This technique of padding a batch of data to its max length is not optimal.
# The padded data is not needed for computation and wastes memory by allocating
# larger tensors than necessary.
# Further, not all operations have the same semantics when applied to padded data.
# For matrix multiplications, in order to ignore the padded entries, one needs to pad
# with 0, while for softmax one has to pad with -inf to ignore specific entries.
# The primary objective of nested tensor is to facilitate operations on ragged
# data using the standard PyTorch tensor UX, thereby eliminating the need
# for inefficient and complex padding and masking.
padded_sentences_for_softmax = torch.tensor([[1.0, 2.0, float("-inf")],
                                             [3.0, 4.0, 5.0]])
print(F.softmax(padded_sentences_for_softmax, -1))
print(F.softmax(nested_sentences, -1))
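
######################################################################
# An added aside: if we instead apply softmax naively to the 0-padded batch,
# the padding position in the first row receives nonzero probability and skews
# the real entries -- exactly the masking bookkeeping nested tensors let us avoid.
print(F.softmax(padded_sentences, -1))  # row 0 assigns weight to the padded slot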
######################################################################
# Let us take a look at a practical example: the multi-head attention component
# utilized in `Transformers <https://arxiv.org/pdf/1706.03762.pdf>`__.
# We can implement this in such a way that it can operate on either padded
# or nested tensors.
class MultiHeadAttention(nn.Module):
    """
    Computes multi-head attention. Supports nested or padded tensors.

    Args:
        E_q (int): Size of embedding dim for query
        E_k (int): Size of embedding dim for key
        E_v (int): Size of embedding dim for value
        E_total (int): Total embedding dim of combined heads post input projection. Each head
            has dim E_total // nheads
        nheads (int): Number of heads
        dropout_p (float, optional): Dropout probability. Default: 0.0
    """
    def __init__(self, E_q: int, E_k: int, E_v: int, E_total: int,
                 nheads: int, dropout_p: float = 0.0):
        super().__init__()
        self.nheads = nheads
        self.dropout_p = dropout_p
        self.query_proj = nn.Linear(E_q, E_total)
        self.key_proj = nn.Linear(E_k, E_total)
        self.value_proj = nn.Linear(E_v, E_total)
        E_out = E_q
        self.out_proj = nn.Linear(E_total, E_out)
        assert E_total % nheads == 0, "Embedding dim is not divisible by nheads"
        self.E_head = E_total // nheads

    def forward(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor) -> torch.Tensor:
        """
        Forward pass; runs the following process:
            1. Apply input projection
            2. Split heads and prepare for SDPA
            3. Run SDPA
            4. Apply output projection

        Args:
            query (torch.Tensor): query of shape (N, L_t, E_q)
            key (torch.Tensor): key of shape (N, L_s, E_k)
            value (torch.Tensor): value of shape (N, L_s, E_v)

        Returns:
            attn_output (torch.Tensor): output of shape (N, L_t, E_q)
        """
        # Step 1. Apply input projection
        # TODO: demonstrate packed projection
        query = self.query_proj(query)
        key = self.key_proj(key)
        value = self.value_proj(value)

        # Step 2. Split heads and prepare for SDPA
        # reshape query, key, value to separate by head
        # (N, L_t, E_total) -> (N, L_t, nheads, E_head) -> (N, nheads, L_t, E_head)
        query = query.unflatten(-1, [self.nheads, self.E_head]).transpose(1, 2)
        # (N, L_s, E_total) -> (N, L_s, nheads, E_head) -> (N, nheads, L_s, E_head)
        key = key.unflatten(-1, [self.nheads, self.E_head]).transpose(1, 2)
        # (N, L_s, E_total) -> (N, L_s, nheads, E_head) -> (N, nheads, L_s, E_head)
        value = value.unflatten(-1, [self.nheads, self.E_head]).transpose(1, 2)

        # Step 3. Run SDPA
        # (N, nheads, L_t, E_head)
        attn_output = F.scaled_dot_product_attention(
            query, key, value, dropout_p=self.dropout_p, is_causal=True)
        # (N, nheads, L_t, E_head) -> (N, L_t, nheads, E_head) -> (N, L_t, E_total)
        attn_output = attn_output.transpose(1, 2).flatten(-2)

        # Step 4. Apply output projection
        # (N, L_t, E_total) -> (N, L_t, E_out)
        attn_output = self.out_proj(attn_output)

        return attn_output

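######################################################################
# A quick added sanity check (illustrative only; the names and sizes below are
# arbitrary, not part of the original tutorial): run the module on a tiny padded
# batch and confirm the output shape matches the (N, L_t, E_q) contract documented above.
_toy_mha = MultiHeadAttention(E_q=8, E_k=8, E_v=8, E_total=16, nheads=2)
_toy_out = _toy_mha(torch.randn(2, 5, 8), torch.randn(2, 5, 8), torch.randn(2, 5, 8))
print(f"toy attention output shape: {_toy_out.shape}")  # expected: torch.Size([2, 5, 8])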
######################################################################
# set hyperparameters following `the Transformer paper <https://arxiv.org/pdf/1706.03762.pdf>`__
N = 512
E_q, E_k, E_v, E_total = 512, 512, 512, 512
E_out = E_q
nheads = 8

######################################################################
# except for dropout probability: set to 0 for correctness check
dropout_p = 0.0

######################################################################
# Let us generate some realistic fake data from Zipf's law.
def zipf_sentence_lengths(alpha: float, batch_size: int) -> torch.Tensor:
    # generate fake corpus by unigram Zipf distribution
    # from wikitext-2 corpus, we get rank "." = 3, "!" = 386, "?" = 858
    sentence_lengths = np.empty(batch_size, dtype=int)
    for ibatch in range(batch_size):
        sentence_lengths[ibatch] = 1
        word = np.random.zipf(alpha)
        while word != 3 and word != 386 and word != 858:
            sentence_lengths[ibatch] += 1
            word = np.random.zipf(alpha)
    return torch.tensor(sentence_lengths)
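
######################################################################
# An added peek at how ragged the sampled lengths are (illustrative only; the
# numpy RNG state is saved and restored so the benchmark data below is unchanged).
_rng_state = np.random.get_state()
_sample_lengths = zipf_sentence_lengths(alpha=1.2, batch_size=32)
np.random.set_state(_rng_state)
print(f"sample sentence lengths: min={_sample_lengths.min().item()}, "
      f"max={_sample_lengths.max().item()}, "
      f"mean={_sample_lengths.float().mean().item():.1f}")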
######################################################################
# Create nested tensor batch inputs
def gen_batch(N, E_q, E_k, E_v, device):
    # generate semi-realistic data using Zipf distribution for sentence lengths
    sentence_lengths = zipf_sentence_lengths(alpha=1.2, batch_size=N)

    # Note: the torch.jagged layout is a nested tensor layout that supports a single ragged
    # dimension and works with torch.compile. The resulting nested tensors have shape
    # (B, S*, D), where B = batch size, S* = ragged sequence length, and D = embedding dimension.
    query = torch.nested.nested_tensor([
        torch.randn(l.item(), E_q, device=device)
        for l in sentence_lengths
    ], layout=torch.jagged)

    key = torch.nested.nested_tensor([
        torch.randn(s.item(), E_k, device=device)
        for s in sentence_lengths
    ], layout=torch.jagged)

    value = torch.nested.nested_tensor([
        torch.randn(s.item(), E_v, device=device)
        for s in sentence_lengths
    ], layout=torch.jagged)

    return query, key, value, sentence_lengths

query, key, value, sentence_lengths = gen_batch(N, E_q, E_k, E_v, device)

######################################################################
# Generate padded forms of query, key, value for comparison
def jagged_to_padded(jt, padding_val):
    # TODO: do jagged -> padded directly when this is supported
    return torch.nested.to_padded_tensor(
        torch.nested.nested_tensor(list(jt.unbind())),
        padding_val)

padded_query, padded_key, padded_value = (
    jagged_to_padded(t, 0.0) for t in (query, key, value)
)

######################################################################
# Construct the model
mha = MultiHeadAttention(E_q, E_k, E_v, E_total, nheads, dropout_p).to(device=device)

######################################################################
# Check correctness and performance
def benchmark(func, *args, **kwargs):
    torch.cuda.synchronize()
    begin = timeit.default_timer()
    output = func(*args, **kwargs)
    torch.cuda.synchronize()
    end = timeit.default_timer()
    return output, (end - begin)

output_nested, time_nested = benchmark(mha, query, key, value)
output_padded, time_padded = benchmark(mha, padded_query, padded_key, padded_value)

# padding-specific step: remove output projection bias from padded entries for fair comparison
for i, entry_length in enumerate(sentence_lengths):
    output_padded[i, entry_length:] = 0.0

print("=== without torch.compile ===")
print("nested and padded calculations differ by", (jagged_to_padded(output_nested, 0.0) - output_padded).abs().max().item())
print("nested tensor multi-head attention takes", time_nested, "seconds")
print("padded tensor multi-head attention takes", time_padded, "seconds")

# warm up compile first...
compiled_mha = torch.compile(mha)
compiled_mha(query, key, value)
# ...now benchmark
compiled_output_nested, compiled_time_nested = benchmark(
    compiled_mha, query, key, value)

# warm up compile first...
compiled_mha(padded_query, padded_key, padded_value)
# ...now benchmark
compiled_output_padded, compiled_time_padded = benchmark(
    compiled_mha, padded_query, padded_key, padded_value)

# padding-specific step: remove output projection bias from padded entries for fair comparison
for i, entry_length in enumerate(sentence_lengths):
    compiled_output_padded[i, entry_length:] = 0.0

print("=== with torch.compile ===")
print("nested and padded calculations differ by", (jagged_to_padded(compiled_output_nested, 0.0) - compiled_output_padded).abs().max().item())
print("nested tensor multi-head attention takes", compiled_time_nested, "seconds")
print("padded tensor multi-head attention takes", compiled_time_padded, "seconds")

######################################################################
# Note that without ``torch.compile``, the overhead of the Python subclass nested tensor
# can make it slower than the equivalent computation on padded tensors. However, once
# ``torch.compile`` is enabled, operating on nested tensors yields a speedup of several
# times over padded tensors.
# Avoiding wasted computation on padding becomes only more valuable as the percentage
# of padding in the batch increases.
print(f"Nested speedup: {compiled_time_padded / compiled_time_nested:.3f}")
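# For reference, the same ratio without ``torch.compile`` (an added line, not part of
# the original benchmark); it illustrates the eager-mode subclass overhead noted above.
print(f"Nested speedup without torch.compile: {time_padded / time_nested:.3f}")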

######################################################################
# Conclusion
# ----------
# In this tutorial, we have learned how to perform basic operations with nested tensors and
# how to implement multi-head attention for transformers in a way that avoids computation on padding.
# For more information, check out the docs for the
# `torch.nested <https://pytorch.org/docs/stable/nested.html>`__ namespace.