# -*- coding: utf-8 -*-
r"""
Sequence Models and Long Short-Term Memory Networks
======================================================

At this point, we have seen various feed-forward networks. That is,
there is no state maintained by the network at all. This might not be
the behavior we want. Sequence models are central to NLP: they are
models where there is some sort of dependence through time between your
inputs. The classical example of a sequence model is the Hidden Markov
Model for part-of-speech tagging. Another example is the conditional
random field.

A recurrent neural network is a network that maintains some kind of
state. For example, its output could be used as part of the next input,
so that information can propagate along as the network passes over the
sequence. In the case of an LSTM, for each element in the sequence,
there is a corresponding *hidden state* :math:`h_t`, which in principle
can contain information from arbitrary points earlier in the sequence.
We can use the hidden state to predict words in a language model,
part-of-speech tags, and a myriad of other things.


LSTMs in PyTorch
~~~~~~~~~~~~~~~~~~

Before getting to the example, note a few things. PyTorch's LSTM expects
all of its inputs to be 3D tensors. The semantics of the axes of these
tensors are important. The first axis is the sequence itself, the second
indexes instances in the mini-batch, and the third indexes elements of
the input. We haven't discussed mini-batching, so let's just ignore that
and assume the second axis will always have size 1. If we want to run
the sequence model over the sentence "The cow jumped", our input should
look like

.. math::

   \begin{bmatrix}
   \overbrace{q_\text{The}}^\text{row vector} \\
   q_\text{cow} \\
   q_\text{jumped}
   \end{bmatrix}

Except remember there is an additional 2nd dimension with size 1.

In addition, you could go through the sequence one element at a time, in
which case the 1st axis will have size 1 as well.

Let's see a quick example.
"""

# Author: Robert Guthrie

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

torch.manual_seed(1)

######################################################################

lstm = nn.LSTM(3, 3)  # Input dim is 3, output dim is 3
inputs = [torch.randn(1, 3) for _ in range(5)]  # make a sequence of length 5

# initialize the hidden state.
hidden = (torch.randn(1, 1, 3),
          torch.randn(1, 1, 3))
for i in inputs:
    # Step through the sequence one element at a time.
    # After each step, hidden contains the hidden state.
    out, hidden = lstm(i.view(1, 1, -1), hidden)

# Alternatively, we can do the entire sequence all at once.
# The first value returned by LSTM is all of the hidden states throughout
# the sequence. The second is just the most recent hidden state
# (compare the last slice of "out" with "hidden" below, they are the same).
# The reason for this is that:
# "out" will give you access to all hidden states in the sequence;
# "hidden" will allow you to continue the sequence and backpropagate,
# by passing it as an argument to the lstm at a later time.
# Add the extra 2nd dimension.
inputs = torch.cat(inputs).view(len(inputs), 1, -1)
hidden = (torch.randn(1, 1, 3), torch.randn(1, 1, 3))  # clean out hidden state
out, hidden = lstm(inputs, hidden)
print(out)
print(hidden)
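
######################################################################
# The comment above suggests comparing the last slice of ``out`` with
# ``hidden``. The quick check below is not part of the original example;
# it simply confirms that, for this single-layer, unidirectional LSTM,
# the last time step of ``out`` matches the hidden state returned in
# ``hidden``.

print(torch.allclose(out[-1], hidden[0][0]))  # expected: True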

######################################################################
# Example: An LSTM for Part-of-Speech Tagging
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
#
# In this section, we will use an LSTM to get part-of-speech tags. We will
# not use Viterbi or Forward-Backward or anything like that, but as a
# (challenging) exercise to the reader, think about how Viterbi could be
# used after you have seen what is going on. In this example, we also refer
# to embeddings. If you are unfamiliar with embeddings, you can read up
# about them `here <https://pytorch.org/tutorials/beginner/nlp/word_embeddings_tutorial.html>`__.
#
# The model is as follows: let our input sentence be
# :math:`w_1, \dots, w_M`, where :math:`w_i \in V`, our vocab. Also, let
# :math:`T` be our tag set, and :math:`y_i` the tag of word :math:`w_i`.
# Denote our prediction of the tag of word :math:`w_i` by
# :math:`\hat{y}_i`.
#
# This is a structure prediction model, where our output is a sequence
# :math:`\hat{y}_1, \dots, \hat{y}_M`, where :math:`\hat{y}_i \in T`.
#
# To do the prediction, pass an LSTM over the sentence. Denote the hidden
# state at timestep :math:`i` as :math:`h_i`. Also, assign each tag a
# unique index (like how we had word\_to\_ix in the word embeddings
# section). Then our prediction rule for :math:`\hat{y}_i` is
#
# .. math:: \hat{y}_i = \text{argmax}_j \ (\log \text{Softmax}(Ah_i + b))_j
#
# That is, take the log softmax of the affine map of the hidden state,
# and the predicted tag is the tag that has the maximum value in this
# vector. Note that this immediately implies that the dimensionality of
# the target space of :math:`A` is :math:`|T|`.
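
######################################################################
# To make the prediction rule above concrete, here is a tiny sketch with
# made-up numbers. ``A`` and ``b`` below are hypothetical stand-ins for the
# learned affine map, and ``h_i`` for a hidden state at one timestep; they
# are not part of the tagger defined later. The point is only to show how
# the argmax over the log softmax selects a tag index.

A = torch.randn(3, 6)  # hypothetical affine map: |T| x hidden_dim
b = torch.randn(3)     # hypothetical bias: |T|
h_i = torch.randn(6)   # hypothetical hidden state at one timestep
log_probs = F.log_softmax(A @ h_i + b, dim=0)
print(torch.argmax(log_probs).item())  # index of the predicted tag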

######################################################################
# Prepare data:

def prepare_sequence(seq, to_ix):
    idxs = [to_ix[w] for w in seq]
    return torch.tensor(idxs, dtype=torch.long)


training_data = [
    # Tags are: DET - determiner; NN - noun; V - verb
    # For example, the word "The" is a determiner
    ("The dog ate the apple".split(), ["DET", "NN", "V", "DET", "NN"]),
    ("Everybody read that book".split(), ["NN", "V", "DET", "NN"])
]
word_to_ix = {}
# For each words-list (sentence) and tags-list in each tuple of training_data
for sent, tags in training_data:
    for word in sent:
        if word not in word_to_ix:  # word has not been assigned an index yet
            word_to_ix[word] = len(word_to_ix)  # Assign each word a unique index
print(word_to_ix)
tag_to_ix = {"DET": 0, "NN": 1, "V": 2}  # Assign each tag a unique index

# These will usually be more like 32 or 64 dimensional.
# We will keep them small, so we can see how the weights change as we train.
EMBEDDING_DIM = 6
HIDDEN_DIM = 6

######################################################################
# Create the model:


class LSTMTagger(nn.Module):

    def __init__(self, embedding_dim, hidden_dim, vocab_size, tagset_size):
        super(LSTMTagger, self).__init__()
        self.hidden_dim = hidden_dim

        self.word_embeddings = nn.Embedding(vocab_size, embedding_dim)

        # The LSTM takes word embeddings as inputs and outputs hidden states
        # with dimensionality hidden_dim.
        self.lstm = nn.LSTM(embedding_dim, hidden_dim)

        # The linear layer that maps from hidden state space to tag space
        self.hidden2tag = nn.Linear(hidden_dim, tagset_size)

    def forward(self, sentence):
        embeds = self.word_embeddings(sentence)
        lstm_out, _ = self.lstm(embeds.view(len(sentence), 1, -1))
        tag_space = self.hidden2tag(lstm_out.view(len(sentence), -1))
        tag_scores = F.log_softmax(tag_space, dim=1)
        return tag_scores

######################################################################
# Train the model:


model = LSTMTagger(EMBEDDING_DIM, HIDDEN_DIM, len(word_to_ix), len(tag_to_ix))
loss_function = nn.NLLLoss()
optimizer = optim.SGD(model.parameters(), lr=0.1)

# See what the scores are before training.
# Note that element i,j of the output is the score for tag j for word i.
# Here we don't need to train, so the code is wrapped in torch.no_grad()
with torch.no_grad():
    inputs = prepare_sequence(training_data[0][0], word_to_ix)
    tag_scores = model(inputs)
    print(tag_scores)

for epoch in range(300):  # again, normally you would NOT do 300 epochs, it is toy data
    for sentence, tags in training_data:
        # Step 1. Remember that PyTorch accumulates gradients.
        # We need to clear them out before each instance.
        model.zero_grad()

        # Step 2. Get our inputs ready for the network, that is, turn them into
        # Tensors of word indices.
        sentence_in = prepare_sequence(sentence, word_to_ix)
        targets = prepare_sequence(tags, tag_to_ix)

        # Step 3. Run our forward pass.
        tag_scores = model(sentence_in)

        # Step 4. Compute the loss, gradients, and update the parameters by
        # calling optimizer.step()
        loss = loss_function(tag_scores, targets)
        loss.backward()
        optimizer.step()

# See what the scores are after training.
with torch.no_grad():
    inputs = prepare_sequence(training_data[0][0], word_to_ix)
    tag_scores = model(inputs)

    # The sentence is "the dog ate the apple". Element i,j of tag_scores is
    # the score for tag j for word i. The predicted tag is the maximum-scoring
    # tag. Here, we can see the predicted sequence below is 0 1 2 0 1,
    # since 0 is the index of the maximum value of row 1,
    # 1 is the index of the maximum value of row 2, etc.
    # That is DET NOUN VERB DET NOUN, the correct sequence!
    print(tag_scores)
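
######################################################################
# To read the tags off those scores explicitly, take the argmax of each row
# and map the indices back to tag names. The inverse mapping built here is a
# small helper added for display; it is not part of the original example.

ix_to_tag = {ix: tag for tag, ix in tag_to_ix.items()}
predicted_tags = [ix_to_tag[i.item()] for i in torch.argmax(tag_scores, dim=1)]
print(predicted_tags)  # after training this should be ['DET', 'NN', 'V', 'DET', 'NN']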

######################################################################
# Exercise: Augmenting the LSTM part-of-speech tagger with character-level features
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
#
# In the example above, each word had an embedding, which served as the
# input to our sequence model. Let's augment the word embeddings with a
# representation derived from the characters of the word. We expect that
# this should help significantly, since character-level information like
# affixes has a large bearing on part-of-speech. For example, words with
# the affix *-ly* are almost always tagged as adverbs in English.
#
# To do this, let :math:`c_w` be the character-level representation of
# word :math:`w`. Let :math:`x_w` be the word embedding as before. Then
# the input to our sequence model is the concatenation of :math:`x_w` and
# :math:`c_w`. So if :math:`x_w` has dimension 5, and :math:`c_w`
# dimension 3, then our LSTM should accept an input of dimension 8.
#
# To get the character-level representation, do an LSTM over the
# characters of a word, and let :math:`c_w` be the final hidden state of
# this LSTM. Hints:
#
# * There are going to be two LSTMs in your new model.
#   The original one that outputs POS tag scores, and the new one that
#   outputs a character-level representation of each word.
# * To do a sequence model over characters, you will have to embed characters.
#   The character embeddings will be the input to the character LSTM.
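
######################################################################
# One possible skeleton for this exercise is sketched below. It is only a
# sketch, not the tutorial's reference solution, and the names
# (``CharLSTMTagger``, ``CHAR_EMBEDDING_DIM``, ``CHAR_HIDDEN_DIM``) are
# illustrative choices. The word-level LSTM receives the word embedding
# :math:`x_w` concatenated with :math:`c_w`, the final hidden state of a
# character-level LSTM run over the characters of each word.

CHAR_EMBEDDING_DIM = 3  # illustrative size for character embeddings
CHAR_HIDDEN_DIM = 3     # illustrative size for c_w


class CharLSTMTagger(nn.Module):

    def __init__(self, embedding_dim, char_embedding_dim, hidden_dim,
                 char_hidden_dim, vocab_size, charset_size, tagset_size):
        super(CharLSTMTagger, self).__init__()
        self.word_embeddings = nn.Embedding(vocab_size, embedding_dim)
        self.char_embeddings = nn.Embedding(charset_size, char_embedding_dim)

        # Character-level LSTM: its final hidden state is c_w.
        self.char_lstm = nn.LSTM(char_embedding_dim, char_hidden_dim)

        # Word-level LSTM: its input is the concatenation of x_w and c_w.
        self.lstm = nn.LSTM(embedding_dim + char_hidden_dim, hidden_dim)
        self.hidden2tag = nn.Linear(hidden_dim, tagset_size)

    def forward(self, sentence, sentence_chars):
        # sentence: tensor of word indices for one sentence.
        # sentence_chars: list with one tensor of character indices per word.
        char_reps = []
        for word_chars in sentence_chars:
            char_embeds = self.char_embeddings(word_chars).view(len(word_chars), 1, -1)
            _, (char_hidden, _) = self.char_lstm(char_embeds)
            char_reps.append(char_hidden.view(1, -1))
        char_reps = torch.cat(char_reps)  # shape: (len(sentence), char_hidden_dim)

        word_embeds = self.word_embeddings(sentence)  # shape: (len(sentence), embedding_dim)
        embeds = torch.cat([word_embeds, char_reps], dim=1)  # x_w concatenated with c_w

        lstm_out, _ = self.lstm(embeds.view(len(sentence), 1, -1))
        tag_space = self.hidden2tag(lstm_out.view(len(sentence), -1))
        return F.log_softmax(tag_space, dim=1)

# To train it, build a char_to_ix mapping over the characters seen in
# training_data (analogous to word_to_ix) and pass, for each sentence, a list
# of per-word character-index tensors alongside the word-index tensor.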