Putting the "Re" in Reformer: Ungraded Lab
This ungraded lab will explore Reversible Residual Networks. You will use these networks in this week's assignment, which utilizes the Reformer model. It is based on the Transformer model you already know, but with two unique features:
Locality Sensitive Hashing (LSH) Attention to reduce the compute cost of the dot product attention and
Reversible Residual Networks (RevNets) organization to reduce the storage requirements when doing backpropagation in training.
In this ungraded lab we'll start with a quick review of Residual Networks and their implementation in Trax. Then we will discuss the RevNet architecture and its use in the Reformer.
Outline
Part 1.0 Residual Networks
Part 1.1 Branch
Part 1.2 Residual Model
Part 2.0 Reversible Residual Networks
Part 2.1 Trax Reversible Layers
Part 1.0 Residual Networks
Deep Residual Networks (Resnets) were introduced to improve convergence in deep networks. Residual Networks introduce a shortcut connection around one or more layers in a deep network as shown in the diagram below from the original paper.
The Trax documentation describes an implementation of Resnets using branch. We'll explore that here by implementing a simple resnet built from simple function-based layers. Specifically, we'll build a 4-layer network based on two functions, 'F' and 'G'.
Part 1.1 Branch
Trax branch figures prominently in the residual network layer, so we will first examine it. You can see from the figure above that we will need a function that will copy an input and send it down multiple paths. This is accomplished with a branch layer, one of the Trax 'combinators'. Branch is a combinator that applies a list of layers in parallel to copies of inputs. Let's try it out! First we will need some layers to play with. Let's build some from functions.
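For example, here is a minimal sketch of function-based layers built with tl.Fn and fed to tl.Branch (the layer names Addone and Mult2 are illustrative choices, not fixed by this lab):

```python
import numpy as np
from trax import layers as tl

# Build simple layers from plain Python functions using tl.Fn.
def Addone():
    return tl.Fn("Addone", lambda x: x + 1)   # one input, one output

def Mult2():
    return tl.Fn("Mult2", lambda x: x * 2)    # one input, one output

x1 = np.array([3])

# Branch applies each of its layer arguments, in parallel, to copies of its input.
branch_layer = tl.Branch(Addone(), Mult2())
print(branch_layer(x1))   # expect something like (array([4]), array([6]))
```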
Trax uses the concept of a 'stack' to transfer data between layers. For each of its layer arguments, Branch copies the n_in inputs that layer requires from the stack and provides them to the layer, keeping track of max_n_in, the largest n_in required. It then pops max_n_in elements from the stack.
Each layer in the input list copies as many inputs from the stack as it needs, and their outputs are successively combined on the stack. Put another way, each element of the branch can have a differing number of inputs and outputs. Let's try a more complex example.
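Continuing the sketch above, here is one possible example of a branch whose layers consume different numbers of stack inputs (the two-input Addition layer is an illustrative assumption):

```python
# A layer that consumes two stack inputs and produces one output.
def Addition():
    return tl.Fn("Addition", lambda x, y: x + y)   # n_in = 2, n_out = 1

x1, x2 = np.array([1]), np.array([10])

# Addone uses only the top stack item; Addition uses the top two.
bl = tl.Branch(Addone(), Addition())
print(bl([x1, x2]))   # expect something like (array([2]), array([11]))
```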
In this case, the number of inputs being copied from the stack varies with the layer.
Branch has a special feature to support Residual Networks: if an argument is None, it will pull the top of the stack and push it (at its location in the sequence) onto the output stack.
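A small sketch of that behavior, reusing the Addone layer from above:

```python
x1 = np.array([1])

# None acts as a pass-through: the top of the stack is copied, unchanged,
# to the corresponding position in the output.
print(tl.Branch(None, Addone())(x1))   # expect something like (array([1]), array([2]))
```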
Part 1.2 Residual Model
OK, your turn. Write a function 'MyResidual' that uses tl.Branch and tl.Add to build a residual layer. If you are curious about the Trax implementation, you can see the code here.
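If you want a hint, here is one possible sketch (not the only way to write it), combining tl.Branch with a None pass-through and tl.Add:

```python
def MyResidual(layer):
    # Copy the input, run `layer` on one copy, pass the other copy through,
    # then add the two results back together.
    return tl.Serial(
        tl.Branch(layer, None),
        tl.Add(),
    )

# Quick check with the Addone layer from above; 'n' and 'm' simply ride along on the stack.
x1 = np.array([1])
print(MyResidual(Addone())([x1, "n", "m"]))
```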
Expected Result (array([3]), 'n', 'm')
Great! Now, let's build the 4-layer residual network in Figure 2. You can use MyResidual or, if you prefer, the tl.Residual layer in Trax, or a combination!
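Here is a sketch of one way to do it. Since Figure 2 is not reproduced here, we assume (consistently with the expected result below) that the 'F' layer doubles its input and the 'G' layer multiplies its input by 10:

```python
# Assumed definitions of the two functions from Figure 2.
def Fl():
    return tl.Fn("Fl", lambda x: 2 * x)    # assumed: F doubles its input

def Gl():
    return tl.Fn("Gl", lambda x: 10 * x)   # assumed: G multiplies its input by 10

# Four residual layers: F, G, F, G, each with a shortcut connection.
resfg = tl.Serial(
    MyResidual(Fl()),   # x -> x + 2x  = 3x
    MyResidual(Gl()),   # x -> x + 10x = 11x
    MyResidual(Fl()),
    MyResidual(Gl()),
)

x1 = np.array([1])
print(resfg([x1, "n", "m"]))   # 1 * 3 * 11 * 3 * 11 = 1089
```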
Expected Results (array([1089]), 'n', 'm')
Part 2.0 Reversible Residual Networks
The Reformer utilized RevNets to reduce the storage requirements for performing backpropagation.
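As a reminder, a reversible residual block splits its activations into two halves, $x_1$ and $x_2$, and computes (these are the standard RevNet equations):

$$
y_1 = x_1 + F(x_2), \qquad y_2 = x_2 + G(y_1)
$$

Because each step can be inverted,

$$
x_2 = y_2 - G(y_1), \qquad x_1 = y_1 - F(x_2),
$$

the layer inputs can be recomputed from the outputs during backpropagation rather than stored, which is where the memory savings come from.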
One thing to note is that the forward functions produced by the two networks are similar, but they are not equivalent. Note, for example, the asymmetry in the output equations after two stages of operation.
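One way to see the asymmetry: write the duplicated input as $x$, so $x_1 = x_2 = x$. A conventional residual stack applying $F$ and then $G$ produces

$$
z = x + F(x) + G\big(x + F(x)\big),
$$

while the reversible block produces

$$
y_1 = x + F(x), \qquad y_2 = x + G\big(x + F(x)\big),
$$

so $y_2$ is missing the $F(x)$ term that appears in $z$; the two paths are treated asymmetrically rather than being merged into a single output.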
Part 2.1 Trax Reversible Layers
Let's take a look at how this is used in the Reformer.
Eliminating some of the detail, we can see the structure of the network.
We'll review the Trax layers used to implement the reversible section of the Reformer. First, note that not all of the Reformer is reversible; only the section inside the ReversibleSerial layer is. In a large Reformer model, that section is repeated many times, making up the majority of the model.
The implementation starts by duplicating the input to allow the two paths that are part of the reversible residual organization with Dup. Note that this is accomplished by copying the top of stack and pushing two copies of it onto the stack. This then feeds into the ReversibleHalfResidual layer which we'll review in more detail below. This is followed by ReversibleSwap. As the name implies, this performs a swap, in this case, the two topmost entries in the stack. This pattern is repeated until we reach the end of the ReversibleSerial section. At that point, the topmost 2 entries of the stack represent the two paths through the network. These are concatenated and pushed onto the stack. The result is an entry that is twice the size of the non-reversible version.
Let's look more closely at the ReversibleHalfResidual. This layer is responsible for executing the layer or layers provided as arguments and adding the output of those layers, the 'residual', to the top of the stack. Below is the 'forward' routine which implements this.
Unlike the previous residual function, the value that is added comes from the second path rather than from the input to the set of sublayers in this layer. Note that the layers called by the ReversibleHalfResidual forward function are not modified to support reverse functionality; this layer provides them a 'normal' view of the stack and takes care of the reverse operation itself.
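Conceptually (this is a simplified sketch, not the actual trax source), the forward step of a ReversibleHalfResidual with sublayer F acts on the top two stack entries as follows:

```python
# Simplified view of ReversibleHalfResidual's forward step (not the trax source):
# the residual F is computed from the *second* stack entry and added to the first,
# while the second entry passes through unchanged.
def half_residual_forward(x1, x2, F):
    return x1 + F(x2), x2
```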
Let's try out some of these layers! We'll start with the ones that just operate on the stack, Dup() and Swap().
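For example, continuing with small numpy arrays on the stack:

```python
x1, x2 = np.array([1]), np.array([5])

print(tl.Dup()(x1))          # expect (array([1]), array([1])): the top of the stack is duplicated
print(tl.Swap()([x1, x2]))   # expect (array([5]), array([1])): the top two entries are swapped
```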
You are no doubt wondering, "How is ReversibleSwap different from Swap?" Good question! Let's look:
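A sketch of the difference: ReversibleSwap is a ReversibleLayer, so besides the usual forward pass it also exposes a reverse computation that reconstructs its inputs from its outputs (the exact reverse call shown here reflects the trax version this lab was written against and may differ in later versions):

```python
x1, x2 = np.array([1]), np.array([5])

rev_swap = tl.ReversibleSwap()
out = rev_swap([x1, x2])       # forward behaves just like Swap: (array([5]), array([1]))
print(out)
print(rev_swap.reverse(out))   # reverse recovers the original inputs: (array([1]), array([5]))
```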
Let's try ReversibleHalfResidual. First, we'll need some layers.
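We can reuse the same assumed layers as in Part 1.2: Fl doubles its input and Gl multiplies its input by 10 (again, these definitions are assumptions, chosen to be consistent with the outputs quoted below):

```python
def Fl():
    return tl.Fn("Fl", lambda x: 2 * x)    # assumed: doubles its input

def Gl():
    return tl.Fn("Gl", lambda x: 10 * x)   # assumed: multiplies its input by 10
```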
Just a note about ReversibleHalfResidual: as this is written, it resides in the Reformer model and is a layer. It is invoked a bit differently than other layers. Rather than tl.XYZ, it is just ReversibleHalfResidual(layers..) as shown below. This may change in the future.
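A sketch of trying it out. The import path below is an assumption based on where ReversibleHalfResidual lived in the trax version this lab targets (inside the Reformer model, as noted above); it may have moved in later versions:

```python
# Assumed import location; in later trax versions this layer may live in trax.layers.
from trax.models.reformer.reformer import ReversibleHalfResidual

x1 = np.array([1])

# The layer expects two stack entries; here we simply feed it two copies of x1.
half_res_F = ReversibleHalfResidual(Fl())
print(half_res_F([x1, x1]))
```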
Notice the output: (DeviceArray([3], dtype=int32), array([1])). The first value, DeviceArray([3], dtype=int32), is the output of the "Fl" layer and has been converted to a 'Jax' DeviceArray. The second, array([1]), is just passed through (recall the diagram of ReversibleHalfResidual above).
The final layer we need is the ReversibleSerial layer. This is the reversible equivalent of the Serial layer and is used in the same manner to build a sequence of layers.
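Here is a sketch of one way to assemble the reversible section described above from these pieces, using the assumed Fl and Gl layers and the Dup / ReversibleHalfResidual / ReversibleSwap / Concatenate pattern from Part 2.1:

```python
rev_model = tl.Serial(
    tl.Dup(),                  # duplicate the input to create the two paths
    tl.ReversibleSerial(
        ReversibleHalfResidual(Fl()),
        tl.ReversibleSwap(),
        ReversibleHalfResidual(Gl()),
        tl.ReversibleSwap(),
        ReversibleHalfResidual(Fl()),
        tl.ReversibleSwap(),
        ReversibleHalfResidual(Gl()),
        tl.ReversibleSwap(),
    ),
    tl.Concatenate(),          # merge the two paths back into a single output
)

x1 = np.array([1])
print(rev_model(x1))
```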
Expected Result: DeviceArray([ 65, 681], dtype=int32)
OK, now you have had a chance to try all the 'Reversible' functions in Trax. On to the Assignment!