The Transformer Decoder: Ungraded Lab Notebook
In this notebook, you'll explore the transformer decoder and how to implement it with Trax.
Background
In the last lecture notebook, you saw how to translate the mathematics of attention into NumPy code. Here, you'll see how multi-head causal attention fits into a GPT-2 transformer decoder, and how to build one with Trax layers. In the assignment notebook, you'll implement causal attention from scratch, but here, you'll exploit the handy-dandy tl.CausalAttention() layer.
The schematic below illustrates the components and flow of a transformer decoder. Note that while the algorithm diagram flows from the bottom to the top, the overview and the subsequent Trax layer code read from the top down.
Imports
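The import cell itself isn't shown here; a minimal set of imports for the code sketches in this notebook (assuming Trax is installed) would be:

```python
import numpy as np             # numerical arrays
from trax import layers as tl  # Trax layer building blocks
```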
Sentence gets embedded, add positional encoding
Embed the words, then create vectors representing each word's position in each sentence ∈ range(max_len), where max_len is the maximum sequence length.
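One way to express this step as a list of Trax layers is sketched below; the helper name PositionalEncoder and its arguments are illustrative, not necessarily the exact assignment code:

```python
def PositionalEncoder(vocab_size, d_model, dropout, max_len, mode):
    """Embed tokens, apply dropout, then add positional encodings."""
    return [
        tl.Embedding(vocab_size, d_model),                   # token ids -> d_model vectors
        tl.Dropout(rate=dropout, mode=mode),                 # regularization on embeddings
        tl.PositionalEncoding(max_len=max_len, mode=mode),   # add position information
    ]
```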
Multi-head causal attention
The layers and array dimensions involved in multi-head causal attention (which looks at previous words in the input text) are summarized in the figure below:
tl.CausalAttention() does all of this for you! You might be wondering, though, whether you need to pass in your input text 3 times, since for causal attention the queries Q, keys K, and values V all come from the same source. Fortunately, tl.CausalAttention() handles this as well by making use of the tl.Branch() combinator layer. In general, each branch within a tl.Branch() layer performs parallel operations on copies of the layer's inputs. For causal attention, each branch (representing Q, K, and V) applies a linear transformation (i.e. a dense layer without a subsequent activation) to its copy of the input, then splits that result into heads. You can see the syntax for this in the screenshot from the trax.layers.attention.py source code below:
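In practice you simply instantiate the built-in layer. The sketch below shows a typical call and a rough illustration of the Branch idea it relies on; the dimensions are made up for illustration, and this is not the actual trax source code:

```python
# Using the built-in layer (parameters here are illustrative):
causal_attention = tl.CausalAttention(d_feature=512, n_heads=8, mode='train')

# Conceptual sketch of the Q/K/V branching: three parallel Dense projections,
# each applied to its own copy of the same input, one per branch.
qkv_projections = tl.Branch(
    tl.Dense(512),  # queries Q
    tl.Dense(512),  # keys K
    tl.Dense(512),  # values V
)
```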
Feed-forward layer
Typically ends with a ReLU activation, but we'll leave open the possibility of a different activation
Most of the parameters are here
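A sketch of the feed-forward block as a list of Trax layers follows; the helper name FeedForward and its arguments are assumptions for illustration, with ff_activation standing in for, e.g., tl.Relu:

```python
def FeedForward(d_model, d_ff, dropout, mode, ff_activation):
    """LayerNorm, then a two-layer MLP with dropout; most of the parameters live here."""
    return [
        tl.LayerNorm(),                        # normalize inputs to the block
        tl.Dense(d_ff),                        # expand to the feed-forward width
        ff_activation(),                       # typically tl.Relu, but left configurable
        tl.Dropout(rate=dropout, mode=mode),
        tl.Dense(d_model),                     # project back to the model dimension
        tl.Dropout(rate=dropout, mode=mode),
    ]
```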
Decoder block
Here, we return a list containing two residual blocks. The first wraps around the causal attention layer, whose inputs are normalized and to which we apply dropout regularization. The second wraps around the feed-forward layer. You may notice that the second call to tl.Residual() doesn't call a normalization layer before calling the feed-forward layer. This is because the normalization layer is included in the feed-forward layer.
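Putting those pieces together, a decoder block can be sketched as two residual blocks; DecoderBlock is an illustrative name, and FeedForward refers to the sketch above rather than the exact assignment code:

```python
def DecoderBlock(d_model, d_ff, n_heads, dropout, mode, ff_activation):
    """One decoder block: residual causal attention, then a residual feed-forward block."""
    return [
        tl.Residual(
            tl.LayerNorm(),                                   # normalize attention inputs
            tl.CausalAttention(d_model, n_heads=n_heads,
                               dropout=dropout, mode=mode),   # multi-head causal attention
            tl.Dropout(rate=dropout, mode=mode),              # dropout regularization
        ),
        tl.Residual(
            FeedForward(d_model, d_ff, dropout, mode, ff_activation),  # includes its own LayerNorm
        ),
    ]
```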
The transformer decoder: putting it all together
A.k.a. repeat N times, dense layer and softmax for output
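A sketch of the full decoder as a Trax Serial model is shown below; the TransformerLM name and the default hyperparameter values are assumptions chosen for illustration, not the graded assignment's exact signature:

```python
def TransformerLM(vocab_size=33300, d_model=512, d_ff=2048, n_layers=6,
                  n_heads=8, dropout=0.1, max_len=4096, mode='train',
                  ff_activation=tl.Relu):
    """Transformer decoder LM: embed + positions, N decoder blocks, dense + log-softmax."""
    positional_encoder = [
        tl.Embedding(vocab_size, d_model),
        tl.Dropout(rate=dropout, mode=mode),
        tl.PositionalEncoding(max_len=max_len, mode=mode),
    ]
    decoder_blocks = [
        DecoderBlock(d_model, d_ff, n_heads, dropout, mode, ff_activation)
        for _ in range(n_layers)]                  # repeat the decoder block N times
    return tl.Serial(
        tl.ShiftRight(mode=mode),                  # shift targets right for next-token prediction
        positional_encoder,
        decoder_blocks,
        tl.LayerNorm(),
        tl.Dense(vocab_size),                      # logits over the vocabulary
        tl.LogSoftmax(),                           # log-probabilities for each next token
    )
```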
Concluding remarks
In this week's assignment, you'll see how to train a transformer decoder on the cnn_dailymail dataset, available from TensorFlow Datasets (TFDS). Because training such a model from scratch is time-intensive, you'll use a pre-trained model to summarize documents later in the assignment. Due to time and storage concerns, we will also not train the decoder on a different summarization dataset in this lab. If you have the time and space, we encourage you to explore the other summarization datasets at TensorFlow Datasets. Which of them might suit your purposes better than the cnn_dailymail dataset? Where else can you find datasets for text summarization models?