GitHub Repository: pytorch/tutorials
Path: blob/main/beginner_source/nlp/sequence_models_tutorial.py
# -*- coding: utf-8 -*-
r"""
Sequence Models and Long Short-Term Memory Networks
===================================================

At this point, we have seen various feed-forward networks. That is,
there is no state maintained by the network at all. This might not be
the behavior we want. Sequence models are central to NLP: they are
models where there is some sort of dependence through time between your
inputs. The classical example of a sequence model is the Hidden Markov
Model for part-of-speech tagging. Another example is the conditional
random field.
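To make "dependence through time" concrete, here is a minimal sketch of Viterbi decoding for a toy HMM tagger in plain Python. The three tags, the transition and emission probabilities, and the function name ``viterbi`` are all made up for illustration. They are not part of this tutorial or of PyTorch. The score of each tag at step t depends on the tag chosen at step t - 1 through the transition probabilities.

```python
import math

# Toy HMM parameters (made up for illustration).
TAGS = ["DET", "NN", "V"]
START = {"DET": 0.6, "NN": 0.3, "V": 0.1}            # P(first tag)
TRANS = {                                              # P(next tag | previous tag)
    "DET": {"DET": 0.1, "NN": 0.8, "V": 0.1},
    "NN": {"DET": 0.1, "NN": 0.2, "V": 0.7},
    "V": {"DET": 0.6, "NN": 0.3, "V": 0.1},
}
EMIT = {                                               # P(word | tag)
    "DET": {"the": 0.8, "dog": 0.1, "ate": 0.1},
    "NN": {"the": 0.1, "dog": 0.8, "ate": 0.1},
    "V": {"the": 0.1, "dog": 0.1, "ate": 0.8},
}


def viterbi(words, tags, start_p, trans_p, emit_p):
    # Return the most likely tag sequence for `words`, working in log space.
    # best[t][tag] = (best log-prob of any path ending in `tag` at step t, previous tag)
    best = [{tag: (math.log(start_p[tag]) + math.log(emit_p[tag][words[0]]), None)
             for tag in tags}]
    for t in range(1, len(words)):
        column = {}
        for tag in tags:
            # The score at step t depends on the tag chosen at step t - 1.
            prev, score = max(
                ((p, best[t - 1][p][0] + math.log(trans_p[p][tag])) for p in tags),
                key=lambda pair: pair[1],
            )
            column[tag] = (score + math.log(emit_p[tag][words[t]]), prev)
        best.append(column)
    # Follow the back-pointers from the best final tag.
    path = [max(tags, key=lambda tag: best[-1][tag][0])]
    for t in range(len(words) - 1, 0, -1):
        path.append(best[t][path[-1]][1])
    return path[::-1]


print(viterbi("the dog ate".split(), TAGS, START, TRANS, EMIT))  # ['DET', 'NN', 'V']
```

The tagger built later in this tutorial does not do this kind of global decoding. It predicts each tag independently from the LSTM hidden state.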

A recurrent neural network is a network that maintains some kind of
state. For example, its output could be used as part of the next input,
so that information can propagate along as the network passes over the
sequence. In the case of an LSTM, for each element in the sequence,
there is a corresponding *hidden state* :math:`h_t`, which in principle
can contain information from arbitrary points earlier in the sequence.
We can use the hidden state to predict words in a language model,
part-of-speech tags, and a myriad of other things.
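The idea of "maintaining state" can be seen without any LSTM machinery. Below is a minimal sketch that uses a scalar Elman-style recurrence, h_t = tanh(w x_t + u h_{t-1}), instead of a full LSTM. The weights ``w`` and ``u`` and the helper name ``run_rnn`` are arbitrary illustrative choices. Only the first input differs between the two runs, yet the final hidden states differ. This is how h_t can carry information from earlier in the sequence.

```python
import math


def run_rnn(xs, w=0.5, u=0.9, h0=0.0):
    # Scalar Elman RNN: h_t = tanh(w * x_t + u * h_{t-1}). Returns every hidden state.
    h = h0
    states = []
    for x in xs:
        h = math.tanh(w * x + u * h)  # the new state depends on the input AND the old state
        states.append(h)
    return states


a = run_rnn([1.0, 0.0, 0.0, 0.0])
b = run_rnn([-1.0, 0.0, 0.0, 0.0])
print(a[-1], b[-1])  # the final states differ even though the last three inputs match
```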


LSTMs in PyTorch
~~~~~~~~~~~~~~~~~

Before getting to the example, note a few things. PyTorch's LSTM expects
all of its inputs to be 3D tensors. The semantics of the axes of these
tensors are important. By default (``batch_first=False``), the first axis
is the sequence itself, the second indexes instances in the mini-batch,
and the third indexes elements of the input. We haven't discussed
mini-batching, so let's just ignore that and assume the second axis
always has size 1. If we want to run the sequence model over the
sentence "The cow jumped", our input should look like

.. math::


   \begin{bmatrix}
   \overbrace{q_\text{The}}^\text{row vector} \\
   q_\text{cow} \\
   q_\text{jumped}
   \end{bmatrix}

where each :math:`q_w` is the vector representing word :math:`w`.
Remember, though, that there is an additional second axis of size 1, so
for vectors of length :math:`d` the actual input has shape
:math:`3 \times 1 \times d`.

Alternatively, you could go through the sequence one element at a time,
in which case the first axis will also have size 1.
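As a quick sanity check of these shapes, the sketch below builds the "The cow jumped" input from random 4-dimensional word vectors and feeds it to an ``nn.LSTM``. It does this twice: once as a whole sequence, and once for a single word. The dictionary ``q``, the dimension 4 and the hidden size 2 are arbitrary choices made for this illustration.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
words = ["The", "cow", "jumped"]
d = 4  # length of each word vector q_w (arbitrary for this sketch)
q = {w: torch.randn(d) for w in words}  # hypothetical word vectors

# Stack the row vectors along the first (sequence) axis, then add the batch axis of size 1.
seq = torch.stack([q[w] for w in words]).view(len(words), 1, d)
print(seq.shape)  # torch.Size([3, 1, 4]): (sequence length, batch size, input size)

lstm = nn.LSTM(input_size=d, hidden_size=2)
out, (h_n, c_n) = lstm(seq)  # the initial state defaults to zeros when omitted
print(out.shape)  # torch.Size([3, 1, 2]): one hidden state per word

# Feeding one word at a time: the first axis now has size 1 as well.
step_input = q["The"].view(1, 1, d)
step_out, _ = lstm(step_input)
print(step_out.shape)  # torch.Size([1, 1, 2])
```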

Let's see a quick example.
"""

# Author: Robert Guthrie

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

torch.manual_seed(1)

######################################################################

lstm = nn.LSTM(3, 3)  # Input dim is 3, hidden (output) dim is 3
inputs = [torch.randn(1, 3) for _ in range(5)]  # make a sequence of length 5

# Initialize the hidden state. For an LSTM this is a tuple (h_0, c_0) holding
# the hidden state and the cell state, each of shape
# (num_layers, batch, hidden_size) = (1, 1, 3).
hidden = (torch.randn(1, 1, 3),
          torch.randn(1, 1, 3))
for i in inputs:
    # Step through the sequence one element at a time.
    # After each step, hidden contains the hidden state.
    out, hidden = lstm(i.view(1, 1, -1), hidden)

# Alternatively, we can process the entire sequence all at once.
# The first value returned by the LSTM, "out", contains the hidden states for
# every step of the sequence. The second, "hidden", is the tuple (h_n, c_n)
# holding just the most recent hidden state and cell state
# (compare the last slice of "out" with hidden[0] below: they are the same).
# The reason for returning both is that:
# "out" gives you access to all hidden states in the sequence;
# "hidden" lets you continue the sequence and backpropagate,
# by passing it as an argument to the LSTM at a later time.
# Concatenate the inputs and add the extra 2nd (batch) dimension.
inputs = torch.cat(inputs).view(len(inputs), 1, -1)
hidden = (torch.randn(1, 1, 3), torch.randn(1, 1, 3))  # clean out hidden state
out, hidden = lstm(inputs, hidden)
print(out)
print(hidden)
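The claims in the comments above can be checked directly. This standalone sketch creates its own small LSTM, random inputs and zero initial state, so it does not reuse the variables above. It verifies two things. First, the last slice of ``out`` equals the final hidden state ``h_n``. Second, feeding the sequence in two chunks, passing the returned ``(h_n, c_n)`` along, reproduces the all-at-once result.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
lstm = nn.LSTM(3, 3)
seq = torch.randn(5, 1, 3)  # (sequence length, batch, input size)
h0 = (torch.zeros(1, 1, 3), torch.zeros(1, 1, 3))

# Whole sequence at once.
out, (h_n, c_n) = lstm(seq, h0)
print(torch.allclose(out[-1], h_n[0]))  # last output slice vs. final hidden state

# Same sequence in two chunks, carrying the state across the boundary.
out_a, state = lstm(seq[:2], h0)
out_b, (h_n2, c_n2) = lstm(seq[2:], state)
print(torch.allclose(torch.cat([out_a, out_b]), out, atol=1e-6))
print(torch.allclose(h_n2, h_n, atol=1e-6))
```

Carrying the state across chunks like this is what the comment means by "continue the sequence".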


######################################################################
# Example: An LSTM for Part-of-Speech Tagging
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
#
# In this section, we will use an LSTM to get part-of-speech tags. We will
# not use Viterbi or Forward-Backward or anything like that, but as a
# (challenging) exercise to the reader, think about how Viterbi could be
# used after you have seen what is going on. In this example, we also refer
# to embeddings. If you are unfamiliar with embeddings, you can read up
# about them `here <https://pytorch.org/tutorials/beginner/nlp/word_embeddings_tutorial.html>`__.
#
# The model is as follows: let our input sentence be
# :math:`w_1, \dots, w_M`, where :math:`w_i \in V`, our vocab. Also, let
# :math:`T` be our tag set, and :math:`y_i` the tag of word :math:`w_i`.
# Denote our prediction of the tag of word :math:`w_i` by
# :math:`\hat{y}_i`.
#
# This is a structured prediction model, where our output is a sequence
# :math:`\hat{y}_1, \dots, \hat{y}_M`, where :math:`\hat{y}_i \in T`.
#
# To do the prediction, pass an LSTM over the sentence. Denote the hidden
# state at timestep :math:`i` as :math:`h_i`. Also, assign each tag a
# unique index (like how we had word\_to\_ix in the word embeddings
# section). Then our prediction rule for :math:`\hat{y}_i` is
#
# .. math:: \hat{y}_i = \text{argmax}_j \ (\log \text{Softmax}(Ah_i + b))_j
#
# That is, take the log softmax of the affine map of the hidden state,
# and the predicted tag is the tag that has the maximum value in this
# vector. Note this implies immediately that the dimensionality of the
# target space of :math:`A` is :math:`|T|`, i.e. :math:`Ah_i + b` has one
# entry per tag.
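A small worked example of this prediction rule in plain Python. The matrix ``A``, bias ``b`` and hidden state ``h_i`` are made-up numbers with |T| = 3 tags and a 2-dimensional hidden state, so ``A`` has 3 rows and 2 columns. Log softmax subtracts the same constant (the log-sum-exp) from every entry. It therefore does not change which entry is largest, and the argmax of the log-probabilities equals the argmax of the raw scores Ah_i + b.

```python
import math

A = [[1.0, -0.5],    # one row per tag, so A has |T| = 3 rows
     [0.2, 0.8],
     [-1.0, 0.3]]
b = [0.0, 0.1, -0.2]
h_i = [0.5, 1.5]     # hypothetical hidden state at timestep i
tags = ["DET", "NN", "V"]

# Affine map: scores_j = sum_k A[j][k] * h_i[k] + b[j]
scores = [sum(a * h for a, h in zip(row, h_i)) + bj for row, bj in zip(A, b)]

# Log softmax: s_j - log(sum_k exp(s_k)), computed stably by shifting by max(s).
m = max(scores)
log_z = m + math.log(sum(math.exp(s - m) for s in scores))
log_probs = [s - log_z for s in scores]

pred = max(range(len(tags)), key=lambda j: log_probs[j])
print(tags[pred])  # NN
```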
#
#
# Prepare data:

def prepare_sequence(seq, to_ix):
    # Map a list of tokens to a LongTensor of their indices.
    idxs = [to_ix[w] for w in seq]
    return torch.tensor(idxs, dtype=torch.long)


training_data = [
    # Tags are: DET - determiner; NN - noun; V - verb
    # For example, the word "The" is a determiner
    ("The dog ate the apple".split(), ["DET", "NN", "V", "DET", "NN"]),
    ("Everybody read that book".split(), ["NN", "V", "DET", "NN"])
]
word_to_ix = {}
# For each words-list (sentence) and tags-list in each tuple of training_data
for sent, tags in training_data:
    for word in sent:
        if word not in word_to_ix:  # word has not been assigned an index yet
            word_to_ix[word] = len(word_to_ix)  # Assign each word a unique index
print(word_to_ix)
tag_to_ix = {"DET": 0, "NN": 1, "V": 2}  # Assign each tag a unique index
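Note that the indexing above is case-sensitive: "The" and "the" receive different indices, so the vocabulary has 9 entries rather than 8. The standalone sketch below re-creates the same loop in plain Python to show the effect. The helper name ``build_vocab`` is hypothetical, and its ``lowercase`` option is an illustrative variant that the tutorial itself does not use.

```python
def build_vocab(sentences, lowercase=False):
    # Assign each distinct word a unique index, in order of first appearance.
    word_to_ix = {}
    for sent in sentences:
        for word in sent:
            key = word.lower() if lowercase else word
            if key not in word_to_ix:
                word_to_ix[key] = len(word_to_ix)
    return word_to_ix


sentences = ["The dog ate the apple".split(), "Everybody read that book".split()]
print(build_vocab(sentences))                  # "The" and "the" get separate indices
print(build_vocab(sentences, lowercase=True))  # folding case merges them
```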

# These will usually be more like 32 or 64 dimensional.
# We will keep them small, so we can see how the weights change as we train.
EMBEDDING_DIM = 6
HIDDEN_DIM = 6

######################################################################
# Create the model:


class LSTMTagger(nn.Module):

    def __init__(self, embedding_dim, hidden_dim, vocab_size, tagset_size):
        super().__init__()
        self.hidden_dim = hidden_dim

        self.word_embeddings = nn.Embedding(vocab_size, embedding_dim)

        # The LSTM takes word embeddings as inputs, and outputs hidden states
        # with dimensionality hidden_dim.
        self.lstm = nn.LSTM(embedding_dim, hidden_dim)

        # The linear layer that maps from hidden state space to tag space
        self.hidden2tag = nn.Linear(hidden_dim, tagset_size)

    def forward(self, sentence):
        embeds = self.word_embeddings(sentence)  # (seq_len, embedding_dim)
        # Add the batch axis of size 1: (seq_len, 1, embedding_dim).
        lstm_out, _ = self.lstm(embeds.view(len(sentence), 1, -1))
        # Drop the batch axis again, then map (seq_len, hidden_dim) to (seq_len, tagset_size).
        tag_space = self.hidden2tag(lstm_out.view(len(sentence), -1))
        tag_scores = F.log_softmax(tag_space, dim=1)
        return tag_scores
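To see how tensor shapes flow through ``forward``, here is a standalone trace of the same steps. It uses a vocabulary of 9 words, 3 tags, and embedding and hidden sizes of 6, matching the toy data above. The word indices in ``sentence`` are made up, and the layers are freshly initialized rather than taken from a trained ``LSTMTagger``. Each row of the result is a vector of log-probabilities over tags, so exponentiating a row gives probabilities that sum to 1.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
vocab_size, tagset_size, embedding_dim, hidden_dim = 9, 3, 6, 6
embed = nn.Embedding(vocab_size, embedding_dim)
lstm = nn.LSTM(embedding_dim, hidden_dim)
hidden2tag = nn.Linear(hidden_dim, tagset_size)

sentence = torch.tensor([0, 1, 2, 3, 4], dtype=torch.long)  # hypothetical word indices
embeds = embed(sentence)                                     # (5, 6): one vector per word
lstm_out, _ = lstm(embeds.view(len(sentence), 1, -1))        # (5, 1, 6): batch axis added
tag_space = hidden2tag(lstm_out.view(len(sentence), -1))     # (5, 3): one score per tag
tag_scores = F.log_softmax(tag_space, dim=1)                 # (5, 3): log-probabilities

print(embeds.shape, lstm_out.shape, tag_space.shape, tag_scores.shape)
print(tag_scores.exp().sum(dim=1))  # each row sums to 1
```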

######################################################################
# Train the model:


model = LSTMTagger(EMBEDDING_DIM, HIDDEN_DIM, len(word_to_ix), len(tag_to_ix))
loss_function = nn.NLLLoss()
optimizer = optim.SGD(model.parameters(), lr=0.1)

# See what the scores are before training
# Note that element i,j of the output is the score for tag j for word i.
# We don't need gradients here, so the code is wrapped in torch.no_grad()
with torch.no_grad():
    inputs = prepare_sequence(training_data[0][0], word_to_ix)
    tag_scores = model(inputs)
    print(tag_scores)

for epoch in range(300):  # again, normally you would NOT do 300 epochs; this is toy data
    for sentence, tags in training_data:
        # Step 1. Remember that PyTorch accumulates gradients.
        # We need to clear them out before each instance.
        model.zero_grad()

        # Step 2. Get our inputs ready for the network, that is, turn them into
        # Tensors of word indices.
        sentence_in = prepare_sequence(sentence, word_to_ix)
        targets = prepare_sequence(tags, tag_to_ix)

        # Step 3. Run our forward pass.
        tag_scores = model(sentence_in)

        # Step 4. Compute the loss, gradients, and update the parameters by
        # calling optimizer.step()
        loss = loss_function(tag_scores, targets)
        loss.backward()
        optimizer.step()

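The model outputs log-probabilities (via ``F.log_softmax``), and that is why ``nn.NLLLoss`` is the matching loss. With its default ``reduction='mean'``, it averages the negative log-probability of each word's target tag. The sketch below uses random scores and targets, chosen only for illustration. It checks that this combination equals ``nn.CrossEntropyLoss`` applied to the raw scores, and it also recomputes the loss by hand.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
tag_space = torch.randn(5, 3)            # raw scores: 5 words, 3 tags
targets = torch.tensor([0, 1, 2, 0, 1])  # gold tag index for each word

log_probs = F.log_softmax(tag_space, dim=1)
nll = nn.NLLLoss()(log_probs, targets)
ce = nn.CrossEntropyLoss()(tag_space, targets)

# NLLLoss with reduction='mean' averages -log_probs[i, targets[i]] over the words.
manual = -log_probs[torch.arange(5), targets].mean()
print(nll.item(), ce.item(), manual.item())
```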
# See what the scores are after training
with torch.no_grad():
    inputs = prepare_sequence(training_data[0][0], word_to_ix)
    tag_scores = model(inputs)

    # The sentence is "The dog ate the apple". Element i,j is the score for
    # tag j for word i. The predicted tag is the maximum-scoring tag.
    # Here, we can see the predicted sequence below is 0 1 2 0 1,
    # since 0 is the index of the maximum value in row 1,
    # 1 is the index of the maximum value in row 2, etc.
    # This is DET NN V DET NN, the correct sequence!
    print(tag_scores)
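Reading the predicted tags off by eye gets tedious for longer sentences. A small sketch of automating it: invert ``tag_to_ix`` and take the argmax of each row. The score matrix here is a made-up stand-in for ``tag_scores``, written as plain nested lists rather than a tensor, and is not output from the model above. With a real tensor, ``tag_scores.argmax(dim=1)`` gives the same indices.

```python
tag_to_ix = {"DET": 0, "NN": 1, "V": 2}
ix_to_tag = {ix: tag for tag, ix in tag_to_ix.items()}  # invert the mapping

# Hypothetical log-probabilities: one row per word, one column per tag.
tag_scores = [
    [-0.1, -2.9, -3.5],
    [-3.0, -0.1, -2.8],
    [-3.2, -2.7, -0.1],
    [-0.2, -2.5, -3.1],
    [-2.9, -0.1, -3.3],
]


def decode(scores, ix_to_tag):
    # Return the highest-scoring tag for each row of `scores`.
    return [ix_to_tag[max(range(len(row)), key=row.__getitem__)] for row in scores]


words = "The dog ate the apple".split()
print(list(zip(words, decode(tag_scores, ix_to_tag))))
```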


######################################################################
# Exercise: Augmenting the LSTM part-of-speech tagger with character-level features
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
#
# In the example above, each word had an embedding, which served as the
# input to our sequence model. Let's augment the word embeddings with a
# representation derived from the characters of the word. We expect that
# this should help significantly, since character-level information like
# affixes has a large bearing on part-of-speech. For example, words with
# the suffix *-ly* are very often tagged as adverbs in English.
#
# To do this, let :math:`c_w` be the character-level representation of
# word :math:`w`. Let :math:`x_w` be the word embedding as before. Then
# the input to our sequence model is the concatenation of :math:`x_w` and
# :math:`c_w`. So if :math:`x_w` has dimension 5, and :math:`c_w`
# dimension 3, then our LSTM should accept an input of dimension 8.
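The dimension arithmetic is just concatenation along the feature axis. Here is a minimal sketch with ``torch.cat``, using random stand-ins for x_w and c_w with the dimensions from the example (5 and 3). The hidden size of 6 is an arbitrary choice.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
x_w = torch.randn(5)  # word embedding (dimension 5), random stand-in
c_w = torch.randn(3)  # character-level representation (dimension 3), random stand-in
lstm_input = torch.cat([x_w, c_w])  # concatenate along the feature axis
print(lstm_input.shape)  # torch.Size([8])

# The word-level LSTM must therefore be built with input_size = 5 + 3.
lstm = nn.LSTM(input_size=8, hidden_size=6)
out, _ = lstm(lstm_input.view(1, 1, -1))  # one word: (seq_len=1, batch=1, 8)
print(out.shape)  # torch.Size([1, 1, 6])
```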
#
# To get the character-level representation, run an LSTM over the
# characters of a word, and let :math:`c_w` be the final hidden state of
# this LSTM. Hints:
#
# * There are going to be two LSTMs in your new model:
#   the original one that outputs POS tag scores, and the new one that
#   outputs a character-level representation of each word.
# * To do a sequence model over characters, you will have to embed characters.
#   The character embeddings will be the input to the character LSTM.
#

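One possible way to compute c_w is sketched below under our own assumptions. The class name ``CharEncoder``, the character vocabulary ``char_to_ix`` and the sizes are all hypothetical, and the sketch covers only the character half of the exercise. Wiring it into ``LSTMTagger`` is still left to you. It embeds the characters, runs an LSTM over them, and returns that LSTM's final hidden state as c_w.

```python
import torch
import torch.nn as nn


class CharEncoder(nn.Module):
    # Map a word's character indices to the final hidden state of a char-level LSTM.

    def __init__(self, num_chars, char_embedding_dim, char_hidden_dim):
        super().__init__()
        self.char_embeddings = nn.Embedding(num_chars, char_embedding_dim)
        self.char_lstm = nn.LSTM(char_embedding_dim, char_hidden_dim)

    def forward(self, char_idxs):
        embeds = self.char_embeddings(char_idxs)  # (num_chars_in_word, char_embedding_dim)
        _, (h_n, _) = self.char_lstm(embeds.view(len(char_idxs), 1, -1))
        return h_n.view(-1)  # c_w: (char_hidden_dim,)


# Hypothetical character vocabulary built from the training words.
words = "The dog ate the apple Everybody read that book".split()
char_to_ix = {}
for word in words:
    for ch in word:
        char_to_ix.setdefault(ch, len(char_to_ix))

torch.manual_seed(0)
encoder = CharEncoder(len(char_to_ix), char_embedding_dim=4, char_hidden_dim=3)
chars = torch.tensor([char_to_ix[ch] for ch in "apple"], dtype=torch.long)
c_w = encoder(chars)
print(c_w.shape)  # torch.Size([3])
```

In a full solution, c_w for each word would be concatenated with that word's embedding before the word-level LSTM, as described in the exercise.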