Reformer Efficient Attention: Ungraded Lab
The videos describe two 'reforms' made to the Transformer to make it more memory and compute efficient. The Reversible Layers reduce memory, and Locality Sensitive Hashing (LSH) reduces the cost of the Dot Product attention for large input sizes. This ungraded lab will look more closely at LSH and how it is used in the Reformer model.
Specifically, the notebook has three goals:
review dot-product self attention for reference
examine LSH based self attention
extend our understanding and familiarity with Trax infrastructure
Outline
Part 1.0 Trax Efficient Attention classes
Trax is similar to other popular NN development platforms such as Keras (now integrated into Tensorflow) and Pytorch in that it uses 'layers' as a useful level of abstraction. Layers are often represented as classes. We're going to improve our understanding of Trax by locally extending the classes used in the attention layers. We will extend only the 'forward' functions and utilize the existing attention layers as parent classes. The original code can be found at github:trax/layers/Research/Efficient_attention. This link references release 1.3.4, but note that this is under the 'research' directory, as this is an area of active research. When accessing the code on GitHub for review in this assignment, be sure you select the 1.3.4 release tag; the master copy may have new changes.
While Trax uses classes liberally, we have not built many classes in the course so far. Let's spend a few moments reviewing the classes we will be using.
Starting on the right in the diagram below you see EfficientAttentionBase. The parent of this class is base.layer, which has the routines used by all layers. EfficientAttentionBase leaves many routines to be overridden by child classes, but it has an important feature in its Forward routine: it supports a use_reference_code capability that selects implementations which limit some of the complexity, providing a more easily understood version of the algorithms. In particular, it implements a nested loop that treats each (example, head) independently. This simplifies our work, as we need only worry about matrix operations on one (example, head) at a time. This loop calls forward_unbatched, which is the child routine that we will be overriding.
On the top left are the outlines of the two child classes we will be using. The SelfAttention layer is a 'traditional' implementation of the dot product attention. We will be implementing the forward_unbatched version of this to highlight the differences between this and the LSH implementation.
Below that is the LSHSelfAttention. This is the routine used in the Reformer architecture. We will override the forward_unbatched section of this and some of the utility functions it uses to explore its implementation in more detail.
The code we will be working with is from the Trax source, and as such has implementation details that will make it a bit harder to follow. However, it will allow use of the results along with the rest of the Trax infrastructure. I will try to briefly describe these as they arise. The Trax documentation can also be referenced.
Part 1.2 Trax Details
The goal in this notebook is to override a few routines in the Trax classes with our own versions. To maintain their functionality in a full Trax environment, many of the details we might ignore in example versions of routines will be maintained in this code. Here are some of the considerations that may impact our code:
Trax operates with multiple back-end libraries; we will see special cases that utilize unique features.
'Fancy' numpy indexing is not supported in all backend environments and must be emulated in other ways.
Some operations don't have gradients for backprop and must be ignored or include forced re-evaluation.
Here are some of the functions we may see:
fastmath: Abstracted as fastmath, Trax supports multiple backends such as JAX and TensorFlow 2.
tie_in: Some non-numeric operations must be invoked during backpropagation. Normally, the gradient compute graph would determine invocation, but these functions are not included. To force re-evaluation, they are 'tied' to other numeric operations using tie_in.
stop_gradient: Some operations are intentionally excluded from backprop gradient calculations by setting their gradients to zero.
Below we will execute from trax.fastmath import numpy as np, which uses accelerated forms of numpy functions. This is, however, a subset of numpy.
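As a quick illustration (assuming a working Trax installation with, for example, a JAX backend), the accelerated numpy can be called just like the familiar API:

```python
# A small check that trax.fastmath's numpy behaves like standard numpy.
# With a JAX backend the arrays are device arrays rather than numpy.ndarray.
from trax.fastmath import numpy as np

x = np.arange(6).reshape(2, 3)
print(np.matmul(x, np.swapaxes(x, 0, 1)))  # numpy-style calls work unchanged
print(type(x))                             # backend array type, not numpy.ndarray
```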
Part 2 Full Dot-Product Self Attention
Part 2.1 Description
The diagram above shows many of the familiar data structures and operations related to attention and describes the routines in which they are implemented. We will start by working on our_simple_attend, our simpler version of the original attend function. We will review the steps in performing dot-product attention with more focus on the details of the operations and their significance. This is useful when comparing to LSH attention. Note that we will be discussing a single example/head unless otherwise specified.
The attend function receives Query and Key. As a reminder, they are produced by a matrix multiply of all the inputs with a single set of weights. We will describe the inputs as embeddings assuming an NLP application; however, this is not required. This matrix multiply is very much like a convolutional network, where a set of weights (a filter) slides across the input vectors, leaving behind a map of the similarity of the input to the filter. In this case, the filters are the weight matrices $W_Q$ and $W_K$. The resulting maps are Q and K. Q and K have dimensions of (n_seq, n_q) or (n_seq, n_k), where n_seq is the number of input embeddings and n_q or n_k is the selected size of the Q or K vectors. Note the shading of Q and K; this reflects the fact that each entry is associated with a particular input embedding. You will note later in the code that K is optional. Apparently, similar results can be achieved using Query alone, saving the compute and storage associated with K. In that case, the dot-product in attend is matmul(q, q). Note the resulting dot-product (Dot) entries describe a complete (n_seq, n_seq) map of the similarity of all entries of q vs all entries of k. This is reflected in the notation in the dot-product boxes, $w_n \cdot w_m$, representing word_n, word_m. Note that each row of Dot describes the relationship of an input embedding, say $w_0$, with every other input.
In some applications some values are masked. This can be used, for example, to exclude results that occur later in time (causal attention) or to mask padding or other inputs.
The routine below mask_self_attention implements a flexible masking capability. The masking is controlled by the information in q_info and kv_info.
A SoftMax is applied per row of the Dot matrix to scale the values in the row between 0 and 1.
This code uses a separable form of the softmax calculation. Recall the softmax:
$$\mathrm{softmax}(x_i) = \frac{e^{x_i}}{\sum_j e^{x_j}}$$
This can alternately be implemented as:
$$\mathrm{softmax}(x_i) = e^{x_i - \mathrm{logsumexp}(x)}, \quad \text{where } \mathrm{logsumexp}(x) = \log\Big(\sum_j e^{x_j}\Big)$$
The work below will maintain a copy of the logsumexp, allowing the softmax to be completed in sections. You will see how this is useful later in the LSHSelfAttention class. We'll create a routine to implement that here, with the addition of a passthrough. The matrix operations we will be working on below are easier to follow if we can maintain integer values, so for some tests we will skip the softmax.
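As a sketch of what such a routine might look like (plain numpy is used here for clarity; the notebook's actual our_softmax may differ in details such as its exact signature):

```python
import numpy as np

def our_softmax(x, passthrough=False):
    """Softmax over the last axis; also returns the per-row logsumexp.

    With passthrough=True the input is returned unscaled, which keeps the
    integer test values easy to follow."""
    if passthrough:
        return x, np.zeros(x.shape[:-1] + (1,), dtype=x.dtype)
    logsumexp = np.log(np.sum(np.exp(x), axis=-1, keepdims=True))
    return np.exp(x - logsumexp), logsumexp

# The two forms agree: exp(x - logsumexp(x)) == exp(x) / sum(exp(x))
x = np.array([1.0, 2.0, 3.0])
probs, lse = our_softmax(x)
print(probs, probs.sum())  # probabilities summing to 1.0
```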
Let's check our implementation.
The purpose of the dot-product is to 'focus attention' on some of the inputs. Dot now has entries appropriately scaled to enhance some values and reduce others. These are now applied to the $V$ entries.
$V$ is of size (n_seq, n_v). Note the shading in the diagram; this is to draw attention to the operation of the matrix multiplication, detailed below.
$V$ is formed by a matrix multiply of the input embedding with the weight matrix $W_V$, whose values were set by backpropagation. The row entries of $V$ are then related to the corresponding input embedding. The matrix multiply weights the first column of V, representing a section of each of the input embeddings, with the first row of Dot, representing the similarity of $w_0$ and each word of the input, and deposits the value in the first entry of the output.
Part 2.2 our_simple_attend
In this section we'll work on an implementation of attend whose operations you can see in figure 3. It is a slightly simplified version of the routine in efficient_attention.py. We will fill in a few lines of code. The main goal is to become familiar with the routine. You have implemented similar functionality in a previous assignment.
Instructions
Step 1: matrix multiply (np.matmul) q and the k 'transpose', kr.
Step 2: use our_softmax() to perform a softmax on the masked output of the dot product, dots.
Step 3: matrix multiply (np.matmul) dots and v.
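For reference, here is a minimal numpy sketch of these three steps (masking omitted, re-using the our_softmax sketch from Part 2.1; the real routine in efficient_attention.py carries additional arguments):

```python
import numpy as np

def our_softmax(x, passthrough=False):  # same sketch as in Part 2.1
    if passthrough:
        return x, np.zeros(x.shape[:-1] + (1,), dtype=x.dtype)
    lse = np.log(np.sum(np.exp(x), axis=-1, keepdims=True))
    return np.exp(x - lse), lse

def our_simple_attend(q, v, k=None, passthrough=False):
    k = q if k is None else k                      # K is optional; Q can stand in for it
    kr = np.swapaxes(k, -1, -2)                    # Step 1: 'transpose' of k
    dots = np.matmul(q, kr)                        # (n_seq, n_seq) similarity map
    dots, logits = our_softmax(dots, passthrough)  # Step 2: row-wise softmax
    out = np.matmul(dots, v)                       # Step 3: weight the values
    return out, logits
```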
Completed code for reference
This notebook is ungraded, so for reference, the completed code follows.
Part 2.3 Class OurSelfAttention
Here we create our own self-attention layer by creating a class OurSelfAttention. The parent class will be the tl.SelfAttention layer in Trax. We will only override the forward_unbatched routine. We're not asking you to modify anything in this routine; there are some comments to draw your attention to a few lines.
Part 3.0 Trax LSHSelfAttention
Part 3.1 Description
The larger the matrix multiply in the previous section is, the more context can be taken into account when making the next decision. However, the self attention dot product grows as the square of the input size. For example, if one wished to have an input size of 1024, that would result in $1024^2$, or over a million, dot products for each head! As a result, there has been significant research related to reducing the compute requirements. One such approach is Locality Sensitive Hashing (LSH) Self Attention.
You may recall, earlier in the course you utilized LSH to find similar tweets without resorting to calculating cosine similarity for each pair of embeddings. We will use a similar approach here. It may be best described with an example.
LSH Self attention uses Queries only, no Keys. Attention then generates a metric of the similarity of each value of Q relative to all the other values in Q. An earlier assignment demonstrated that values which hash to the same bucket are likely to be similar. Further, multiple random hashes can improve the chances of finding entries which are similar. This is the approach taken here, though the hash is implemented a bit differently. The values of Q are hashed into buckets using a randomly generated set of hash vectors. Multiple sets of hash vectors are used, generating multiple hash tables. In the figure above, we have 3 hash tables with 4 buckets in each table. Notionally, following the hash, the values of Q have been replicated 3 times and distributed to their appropriate bucket in each of the 3 tables. To find similarity then, one generates dot-products only between members of the same bucket. The result of this operation provides information on which entries are similar. As the operation has been distributed over multiple hash tables, these results need to be combined to form a complete picture, which can be used to generate a reduced dot-product attention array. It's clear that because we do not compare every value vs every other value, the size of Dots will be reduced.
The challenge in this approach is getting it to operate efficiently. You may recall from the earlier assignments that the buckets were lists of entries and had varying length. This will operate poorly on a vector processing machine such as a GPU or TPU. Ideally, operations are done in large blocks with uniform sizes. While it is straightforward to implement the hash algorithm this way, it is challenging to manage buckets and variable sized dot-products. This will be discussed further below. For now, we will examine and implement the hash function.
our_hash_vectors is a reimplementation of the Trax hash_vectors routine. It takes in an array of vectors, hashes the entries, and returns an array assigning each input vector to a bucket in each of the n_hashes hash tables. Hashing is performed by creating random rotations; see Practical and Optimal LSH for Angular Distance.
Note, in the diagram, sizes relate to our expected input while our_hash_vectors is written assuming a generic input vector
Instructions
Step 1: create an array of random normal vectors, which will be our hash vectors. Each vector will be hashed into a hash table and into rot_size//2 buckets. We use rot_size//2 to reduce computation; later in the routine we will form the negative rotations with a simple negation and concatenate to get a full rot_size number of rotations.
* use fastmath.random.normal and create an array of random vectors of shape (vec.shape[-1], n_hashes, rot_size//2)
Step 2: In this step we simply do the matrix multiply. jax has an accelerated version of einsum, but here we will utilize more conventional routines:
* 2a: np.reshape random_rotations into a 2-dimensional array ([-1, n_hashes * (rot_size // 2)])
* 2b: np.dot vecs and random_rotations, forming our rotated_vecs
* 2c: back to 3 dimensions with np.reshape [-1, n_hashes, rot_size//2]
* 2d: prepare for concatenating by swapping dimensions: np.transpose (1, 0, 2)
Step 3: Here we concatenate our rotation vectors, getting a full rot_size number of buckets (note, n_buckets = rot_size).
* use np.concatenate, [rotated_vecs, -rotated_vecs], axis=-1
Step 4: This is the exciting step! You have no doubt been wondering how we will turn these vectors into bucket indexes. By performing np.argmax over the rotations for a given entry, you get the index of the best match! We will use this as a bucket index.
* np.argmax(...).astype(np.int32); be sure to use the correct axis!
Step 5: In this style of hashing, items which land in bucket 0 of hash table 0 are not necessarily similar to those landing in bucket 0 of hash table 1, so we keep them separate. We do this by offsetting the bucket numbers by n_buckets.
* add buckets and offsets and reshape into a one-dimensional array. This will return a 1D array of size n_hashes * vec.shape[0].
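A minimal sketch of the whole routine, using plain numpy in place of fastmath for the random rotations (names such as our_hash_vectors, n_hashes and rot_size follow the lab's conventions; the lab's version instead draws its rotations via fastmath.random.normal with an rng key):

```python
import numpy as np

def our_hash_vectors(vecs, n_buckets, n_hashes, seed=0):
    rng = np.random.default_rng(seed)
    rot_size = n_buckets                    # n_buckets = rot_size after Step 3
    # Step 1: random rotation vectors, shape (d_model, n_hashes, rot_size // 2)
    random_rotations = rng.normal(size=(vecs.shape[-1], n_hashes, rot_size // 2))
    # Step 2: rotate the inputs (2a-2d)
    random_rotations = np.reshape(random_rotations, [-1, n_hashes * (rot_size // 2)])
    rotated_vecs = np.dot(vecs, random_rotations)
    rotated_vecs = np.reshape(rotated_vecs, [-1, n_hashes, rot_size // 2])
    rotated_vecs = np.transpose(rotated_vecs, (1, 0, 2))   # (n_hashes, n_seq, rot_size//2)
    # Step 3: append the negated rotations to get all rot_size directions
    rotated_vecs = np.concatenate([rotated_vecs, -rotated_vecs], axis=-1)
    # Step 4: the index of the best-matching rotation is the bucket id
    buckets = np.argmax(rotated_vecs, axis=-1).astype(np.int32)  # (n_hashes, n_seq)
    # Step 5: offset each hash table's buckets so they do not overlap, then flatten
    offsets = np.arange(n_hashes, dtype=np.int32).reshape(-1, 1) * n_buckets
    return np.reshape(buckets + offsets, (-1,))

# Example: 8 vectors of dimension 5, 2 hash tables, 4 buckets per table
print(our_hash_vectors(np.random.randn(8, 5), n_buckets=4, n_hashes=2))
```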
Great! Now that we have a hash function, we can work on sorting our buckets and performing our matrix operations. We'll walk through this algorithm in small steps:
sort_buckets - we'll perform the sort
softmax
dotandv - do the matrix math to form the dot-product and output
These routines will demonstrate a simplified version of the algorithm. We won't address masking and variable bucket sizes, but will consider how they would be handled.
sort_buckets
At this point, we have called the hash function and were returned the associated buckets. For example, if we started with q[n_seq, n_q], with n_hashes = 2; n_buckets = 4; n_seq = 8, we might be returned:
bucket = [0,1,2,3,0,1,2,3, 4,5,6,7,4,5,6,7]
Note that it is n_hashes * n_seq long and that the bucket values for the second hash have been offset by n_buckets so the numbers do not overlap. Going forward, we are going to sort this array of buckets to group together members of the same (hash, bucket) pair.
Instructions
Step 1: Our goal is to sort q rather than the bucket list, so we will need to track the association of the buckets to their elements in q.
* using np.arange, create ticker, just a sequence of numbers (0..n_hashes * seqlen) associating members of q with their bucket.
Step 2: This step is provided to you, as it is a bit difficult to describe. We want to disambiguate elements that map to the same bucket. When a sorting routine encounters a situation where multiple entries have the same value, it can correctly choose any entry to go first; this makes testing ambiguous. This prevents that. We multiply all the buckets by seqlen and then add ticker % seqlen.
Step 3: Here we are! Ready to sort. This is the exciting part.
* utilize fastmath.sort_key_val and sort buckets_and_t and ticker.
Step 4: We need to be able to undo the sort at the end to get things back into their correct locations.
* sort sticker and ticker to form the reverse map.
Step 5: Create our sorted q and sorted v.
* use np.take and st to grab the correct values in q for the sorted values, sq. Use axis=0.
Use the example code below the routine to check and help debug your results.
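The following is one possible sketch of sort_buckets in plain numpy, with np.argsort standing in for fastmath.sort_key_val (variable names such as ticker, buckets_and_t, sticker, undo_sort, st, sq and sv follow the lab's conventions; the real routine's signature may differ):

```python
import numpy as np

def sort_buckets(buckets, q, v, n_hashes, seqlen):
    # Step 1: ticker tracks which original element each bucket entry belongs to
    ticker = np.arange(n_hashes * seqlen)
    # Step 2: break ties so the sort order is deterministic
    buckets_and_t = seqlen * buckets + (ticker % seqlen)
    # Step 3: sort by (bucket, position); argsort plays the role of sort_key_val
    order = np.argsort(buckets_and_t)
    sbuckets_and_t, sticker = buckets_and_t[order], ticker[order]
    # Step 4: the inverse permutation lets us undo the sort later
    undo_sort = np.argsort(sticker)
    # Step 5: gather the sorted q and v; st maps sorted positions to rows of q
    st = sticker % seqlen
    sq = np.take(q, st, axis=0)
    sv = np.take(v, st, axis=0)
    return sq, sv, sticker, undo_sort
```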
Now let's create the dot product attention. We have sorted so that elements that the hash has determined are likely to be similar are adjacent to each other. We now want to perform the dot-product within those limited regions - in 'chunks'.
The example we have been working on is shown above, with sequences of 8, 2 hashes, 4 buckets and, conveniently, the content of Q was such that when sorted, there were 2 entries in each bucket. If we reshape Q into a (8,2,n_q), we can use numpy matmul to perform the operation. Numpy matmul will treat the inputs as a stack of matrices residing in the last two indexes. This will allow us to matrix multiply Q with itself in chunks and later can also be used to perform the matrix multiply with v.
We will perform a softmax on the output of the dot product of Q and Q, but in this case, there is a bit more to the story. Recall the output of the hash had multiple hash tables. We will perform softmax on those separately and then must combine them. This is where the form of softmax we defined at the top of the notebook comes into play. The routines below will utilize the logsumexp values that the our_softmax routine calculates.
There is a good deal of reshaping to get things into the right formats. The code has many print statements that match the expected values below. You can use those to check your work as you go along. If you don't do a lot of 3-dimensional matrix multiplications in your daily life, it might be worthwhile to open a spare cell and practice a few simple examples to get the hang of it! Here is one to start with:
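For example (using plain numpy here), the leading dimension of a 3-dimensional array is treated by np.matmul as a stack of independent matrix multiplies:

```python
import numpy as np

a = np.arange(24).reshape(4, 2, 3)  # 4 chunks of shape (2, 3)
b = np.arange(24).reshape(4, 3, 2)  # 4 chunks of shape (3, 2)
c = np.matmul(a, b)                 # one (2, 2) product per chunk
print(c.shape)                      # (4, 2, 2)
```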
Instructions
Step 1: Reshaping Q
* np.reshape sq (sorted q) to be 3-dimensional. The middle dimension is the size of the 'chunk' specified by kv_chunk_len
* np.swapaxes to perform a 'transpose' on the reshaped sq, but only on the last two dimensions
* np.matmul the two values.
Step 2
* use our_softmax to perform the softmax on the dot product. Don't forget passthrough.
Step 3
* np.reshape sv. Like sq, the middle dimension is the size of the 'chunk' specified by kv_chunk_len
* np.matmul dotlike and the reshaped sv
* np.reshape so to a two-dimensional array whose last dimension stays the same (so.shape[-1])
* logits also needs reshaping; we'll do that.
Step 4: Now we can undo the sort.
* use np.take and undo_sort and axis=0 to unsort so
* do the same with slogits.
Step 5: This step combines the results of multiple hashes. Recall, the softmax was only over the values in one hash; this extends it to all the hashes. Read through it, the code is provided. Note this is taking place after the matrix multiply with v, while the softmax output is used before the multiply. How does this achieve the correct result?
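Putting the steps together, here is a minimal numpy sketch of the chunked routine (called dotandv above). It re-uses the our_softmax sketch from Part 2.1 and the outputs of the sort_buckets sketch; the lab's actual routine may differ in its argument list and includes the print statements mentioned above:

```python
import numpy as np

def our_softmax(x, passthrough=False):  # same sketch as in Part 2.1
    if passthrough:
        return x, np.zeros(x.shape[:-1] + (1,), dtype=x.dtype)
    lse = np.log(np.sum(np.exp(x), axis=-1, keepdims=True))
    return np.exp(x - lse), lse

def dotandv(sq, sv, undo_sort, n_hashes, kv_chunk_len, passthrough=False):
    # Step 1: chunked q.q^T -- np.matmul treats the leading dim as a stack of matrices
    q_chunks = np.reshape(sq, (-1, kv_chunk_len, sq.shape[-1]))
    dotlike = np.matmul(q_chunks, np.swapaxes(q_chunks, -1, -2))
    # Step 2: softmax within each chunk, keeping the logsumexp for later combination
    dotlike, slogits = our_softmax(dotlike, passthrough)
    # Step 3: apply the attention weights to the chunked v, then flatten back
    v_chunks = np.reshape(sv, (-1, kv_chunk_len, sv.shape[-1]))
    so = np.reshape(np.matmul(dotlike, v_chunks), (-1, sv.shape[-1]))
    slogits = np.reshape(slogits, (-1,))
    # Step 4: undo the sort so rows line up with the original sequence order
    o = np.take(so, undo_sort, axis=0)
    logits = np.take(slogits, undo_sort, axis=0)
    # Step 5: combine the n_hashes copies of each position, weighted by their logits
    o = np.reshape(o, (n_hashes, -1, o.shape[-1]))
    logits = np.reshape(logits, (n_hashes, -1, 1))
    probs = np.exp(logits - np.log(np.sum(np.exp(logits), axis=0, keepdims=True)))
    return np.sum(o * probs, axis=0)
```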
Great! You have now written example code for most of the operations that are unique to the LSH version of self-attention. I'm sure at this point you are wondering what happens if the number of entries in a bucket is not evenly distributed the way our example is. It is possible, for example, for all of the seqlen entries to land in one bucket. Further, since the buckets are not aligned, our 'chunks' may be misaligned with the start of the bucket. The implementation addresses this by attending to adjacent chunks, as was described in the lecture:
Hopefully, having implemented parts of this, you will appreciate this diagram more fully.
Part 3.5 OurLSHSelfAttention
You can examine the full implementations below. Areas we did not 'attend to' in our implementations above include variable bucket sizes and masking. We will instantiate a layer of the full implementation below. We tried to use the same variable names above to make it easier to decipher the full version. Note that some of the functionality we implemented in our routines is split between attend and forward_unbatched. We've inserted our version of hash below, but use the original version of attend.
Congratulations! You have created a custom layer and have become familiar with LSHSelfAttention.