GitHub Repository: pytorch/tutorials
Path: blob/main/beginner_source/nlp/advanced_tutorial.py
# -*- coding: utf-8 -*-
r"""
Advanced: Making Dynamic Decisions and the Bi-LSTM CRF
======================================================

Dynamic versus Static Deep Learning Toolkits
--------------------------------------------

PyTorch is a *dynamic* neural network kit. Another example of a dynamic
kit is `DyNet <https://github.com/clab/dynet>`__ (I mention this because
working with PyTorch and DyNet is similar: if you see an example in
DyNet, it will probably help you implement it in PyTorch). The opposite
is the *static* toolkit, which includes Theano, Keras, TensorFlow 1.x, etc.
The core difference is the following:

* In a static toolkit, you define a computation graph once, compile it,
  and then stream instances to it.
* In a dynamic toolkit, you define a computation graph *for each
  instance*. It is never compiled and is executed on the fly.

Without a lot of experience, it is difficult to appreciate the
difference. As an example, suppose we want to build a deep
constituent parser, and that our model involves roughly the following
steps:

* We build the tree bottom up.
* We tag the leaf nodes (the words of the sentence).
* From there, we use a neural network and the embeddings
  of the words to find combinations that form constituents. Whenever we
  form a new constituent, we use some technique to get an embedding
  of the constituent. In this case, our network architecture will depend
  completely on the input sentence. In the sentence "The green cat
  scratched the wall", at some point in the model, we will want to combine
  the span :math:`(i,j,r) = (1, 3, \text{NP})` (that is, an NP constituent
  spans word 1 to word 3, in this case "The green cat").

However, another sentence might be "Somewhere, the big fat cat scratched
the wall". In this sentence, we will want to form the constituent
:math:`(2, 5, \text{NP})` ("the big fat cat", not counting the comma as a
word) at some point. The constituents we will want to form
depend on the instance. If we just compile the computation graph
once, as in a static toolkit, it will be exceptionally difficult or
impossible to program this logic. In a dynamic toolkit, though, there
isn't just one predefined computation graph. There can be a new
computation graph for each instance, so this problem goes away.
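
To make the contrast concrete, here is a toy, framework-free sketch of the
"define-by-run" idea in plain Python (no PyTorch). Ordinary control flow,
driven by each sentence's spans, decides which composition steps run.
``trace`` records the operations that a dynamic toolkit would build into
that instance's graph. The ``combine`` function, the 2-dimensional
"embeddings", and the hand-supplied spans are all hypothetical stand-ins
for learned components and a real parser's decisions.

```python
import math


def combine(left, right):
    # Hypothetical composition function: elementwise tanh of the sum.
    return [math.tanh(a + b) for a, b in zip(left, right)]


def compose_spans(words, spans, embed):
    # Build an embedding for each (i, j, label) span (1-indexed, inclusive).
    # The loop bounds come from the instance itself, so the sequence of
    # operations is decided at run time, once per sentence.
    trace, out = [], {}
    for i, j, label in spans:
        vec = embed[words[i - 1]]
        for k in range(i, j):  # fold in words i+1 .. j, left to right
            vec = combine(vec, embed[words[k]])
            trace.append((label, i, k + 1))  # span (i, k+1) is now composed
        out[(i, j, label)] = vec
    return out, trace


sent1 = "The green cat scratched the wall".split()
sent2 = "Somewhere the big fat cat scratched the wall".split()  # comma dropped
# Toy 2-d vectors derived from word length, purely for illustration.
embed = {w: [0.1 * len(w), -0.05 * len(w)] for w in sent1 + sent2}

_, trace1 = compose_spans(sent1, [(1, 3, "NP")], embed)
_, trace2 = compose_spans(sent2, [(2, 5, "NP")], embed)
print(len(trace1), len(trace2))  # prints: 2 3
```

Each sentence simply runs as many ``combine`` calls as its spans require.
A static toolkit would instead have to express this variable-length loop
inside a single precompiled graph.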

Dynamic toolkits also have the advantage of being easier to debug, and
their code more closely resembles the host language (by that I mean that
PyTorch and DyNet look more like actual Python code than Keras or
Theano).

Bi-LSTM Conditional Random Field Discussion
-------------------------------------------

For this section, we will see a full, complicated example of a Bi-LSTM
Conditional Random Field for named-entity recognition. The LSTM tagger
from earlier in this tutorial is typically sufficient for part-of-speech
tagging, but a sequence model like the CRF is really essential for strong
performance on NER. Familiarity with CRFs is assumed. Although the name
sounds scary, the model is just a CRF in which an LSTM provides the
features. This is an advanced model, though, far more complicated than
any earlier model in this tutorial. If you want to skip it, that is fine.
To see if you're ready, see if you can:

- Write the recurrence for the Viterbi variable at step :math:`i` for tag :math:`k`.
- Modify the above recurrence to compute the forward variables instead.
- Modify the recurrence again to compute the forward variables in
  log space (hint: log-sum-exp).

If you can do those three things, you should be able to understand the
code below. Recall that the CRF computes a conditional probability. Let
:math:`y` be a tag sequence and :math:`x` an input sequence of words.
Then we compute

.. math:: P(y|x) = \frac{\exp(\text{Score}(x, y))}{\sum_{y'} \exp(\text{Score}(x, y'))}

where the score is determined by defining some log potentials
:math:`\log \psi_i(x,y)` such that

.. math:: \text{Score}(x,y) = \sum_i \log \psi_i(x,y)

To make the partition function (the denominator above) tractable, the
potentials must look only at local features.

In the Bi-LSTM CRF, we define two kinds of potentials: emission and
transition. The emission potential for the word at index :math:`i` comes
from the hidden state of the Bi-LSTM at timestep :math:`i`, projected by
a linear layer to one score per tag (the vector :math:`h_i` below). The
transition scores are stored in a :math:`|T| \times |T|` matrix
:math:`\textbf{P}`, where :math:`T` is the tag set. In my
implementation, :math:`\textbf{P}_{j,k}` is the score of transitioning
to tag :math:`j` from tag :math:`k`. So:

.. math:: \text{Score}(x,y) = \sum_i \left( \log \psi_\text{EMIT}(y_i \rightarrow x_i) + \log \psi_\text{TRANS}(y_{i-1} \rightarrow y_i) \right)

.. math:: = \sum_i \left( h_i[y_i] + \textbf{P}_{y_i, y_{i-1}} \right)

where in this second expression, we think of the tags as being assigned
unique non-negative indices. The implementation also treats :math:`y_0`
as a special START tag and adds one final transition score
:math:`\textbf{P}_{\text{STOP}, y_n}` into a special STOP tag.
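
As a sanity check on these definitions, the sketch below computes
:math:`P(y|x)` by brute force for a 3-word sentence. The emission scores
:math:`h_i[k]` and the transition matrix are made-up numbers, not
produced by any model. The tag indices mirror the ``tag_to_ix`` used
in the code below (B=0, I=1, O=2, START=3, STOP=4). As in the
implementation, the score starts from START and ends with a transition
into STOP. Positions only take the three real tags.

```python
import itertools
import math

N_REAL, START, STOP = 3, 3, 4  # B=0, I=1, O=2 are the real tags

# trans[j][k]: score of moving *to* tag j *from* tag k (the P_{j,k} above).
# Entries that can never be used here (into START, out of STOP, or
# START directly to STOP) are set to -9.0.
trans = [
    [0.5, -1.0, 0.2, 1.0, -9.0],     # to B
    [1.5, 0.8, -2.0, -1.5, -9.0],    # to I
    [-0.3, 0.1, 0.4, 0.6, -9.0],     # to O
    [-9.0, -9.0, -9.0, -9.0, -9.0],  # to START (never used)
    [0.2, 0.7, 0.3, -9.0, -9.0],     # to STOP
]
# feats[i][k] = h_i[k]: emission score of real tag k at word i.
feats = [[1.0, 0.1, 0.3], [0.2, 1.2, 0.5], [0.0, 0.3, 1.1]]


def score(feats, tags):
    # Score(x, y) = sum_i h_i[y_i] + P[y_i, y_{i-1}], with y_0 = START,
    # plus the final transition into STOP.
    total, prev = 0.0, START
    for h, tag in zip(feats, tags):
        total += trans[tag][prev] + h[tag]
        prev = tag
    return total + trans[STOP][prev]


def brute_force_log_z(feats):
    # log sum_{y'} exp(Score(x, y')) over all N_REAL ** n tag sequences.
    scores = [score(feats, y)
              for y in itertools.product(range(N_REAL), repeat=len(feats))]
    m = max(scores)  # subtract the max before exp for numerical stability
    return m + math.log(sum(math.exp(s - m) for s in scores))


log_z = brute_force_log_z(feats)
gold = (0, 1, 2)  # "B I O"
prob = math.exp(score(feats, gold) - log_z)
print(f"Score = {score(feats, gold):.2f}, P(B I O | x) = {prob:.4f}")
```

Enumeration is fine for :math:`3^3 = 27` sequences. However, the count
grows as :math:`|T|^n`, which is already :math:`3^{11} = 177147` for an
11-word sentence. This is why the local structure of the potentials
matters: it lets dynamic programming compute the same sum instead.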

If the above discussion was too brief, you can check out
`this <http://www.cs.columbia.edu/%7Emcollins/crf.pdf>`__ write-up from
Michael Collins on CRFs.
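
For reference, here is one way to write the recurrences from the
checklist above in this section's notation. Let :math:`v_i(k)` be the
Viterbi variable, the best score of any tag prefix ending in tag :math:`k`
at word :math:`i`. Let :math:`\alpha_i(k)` be the forward variable, the sum
of :math:`\exp(\text{score})` over all such prefixes. Both start from the
START tag. The code approximates :math:`-\infty` with :math:`-10000`:

.. math:: v_0(k) = \log \alpha_0(k) = \begin{cases} 0 & k = \text{START} \\ -\infty & \text{otherwise} \end{cases}

.. math:: v_i(k) = h_i[k] + \max_{k'} \left( \textbf{P}_{k,k'} + v_{i-1}(k') \right)

.. math:: \alpha_i(k) = e^{h_i[k]} \sum_{k'} e^{\textbf{P}_{k,k'}} \, \alpha_{i-1}(k')

.. math:: \log \alpha_i(k) = h_i[k] + \log \sum_{k'} \exp\left( \textbf{P}_{k,k'} + \log \alpha_{i-1}(k') \right)

The sentence ends with a transition into STOP. For a sentence of length
:math:`n`, the log partition function (the log of the denominator of
:math:`P(y|x)`) is therefore
:math:`\log Z(x) = \log \sum_k \exp\left( \log \alpha_n(k) + \textbf{P}_{\text{STOP},k} \right)`,
and the best path scores :math:`\max_k \left( v_n(k) + \textbf{P}_{\text{STOP},k} \right)`.
Each log-sum-exp is evaluated stably by factoring out the largest term
:math:`m = \max_j a_j`:

.. math:: \log \sum_j e^{a_j} = m + \log \sum_j e^{a_j - m}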

Implementation Notes
--------------------

The example below implements the forward algorithm in log space to
compute the partition function, and the Viterbi algorithm to decode.
Backpropagation will compute the gradients automatically for us. We
don't have to do anything by hand.
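
Here is a minimal, framework-free sketch of those two algorithms in plain
Python, so the recursions can be checked against brute-force enumeration.
It is not the code used below. The scores are made-up toy numbers, and the
emissions cover only the real tags (B, I, O). The recursions also start
directly from the START transition instead of using the :math:`-10000`
initialization trick, so no masking is needed.

```python
import itertools
import math

N_REAL, START, STOP = 3, 3, 4  # B=0, I=1, O=2, plus START/STOP indices
# trans[j][k]: score of moving to tag j from tag k (toy numbers).
trans = [
    [0.5, -1.0, 0.2, 1.0, -9.0],
    [1.5, 0.8, -2.0, -1.5, -9.0],
    [-0.3, 0.1, 0.4, 0.6, -9.0],
    [-9.0, -9.0, -9.0, -9.0, -9.0],
    [0.2, 0.7, 0.3, -9.0, -9.0],
]
feats = [[1.0, 0.1, 0.3], [0.2, 1.2, 0.5], [0.0, 0.3, 1.1]]  # h_i[k]
TAGS = range(N_REAL)


def log_sum_exp(values):
    m = max(values)  # factor out the max for numerical stability
    return m + math.log(sum(math.exp(v - m) for v in values))


def forward_log_z(feats):
    # alpha[k] = log of the summed exp-scores of all prefixes ending in tag k.
    alpha = [trans[k][START] + feats[0][k] for k in TAGS]
    for h in feats[1:]:
        alpha = [h[k] + log_sum_exp([trans[k][p] + alpha[p] for p in TAGS])
                 for k in TAGS]
    return log_sum_exp([alpha[k] + trans[STOP][k] for k in TAGS])


def viterbi(feats):
    # Same recursion with max in place of log-sum-exp, plus backpointers.
    v = [trans[k][START] + feats[0][k] for k in TAGS]
    backpointers = []
    for h in feats[1:]:
        best_prev = [max(TAGS, key=lambda p, k=k: trans[k][p] + v[p])
                     for k in TAGS]
        v = [h[k] + trans[k][best_prev[k]] + v[best_prev[k]] for k in TAGS]
        backpointers.append(best_prev)
    final = [v[k] + trans[STOP][k] for k in TAGS]
    best = max(TAGS, key=lambda k: final[k])
    path = [best]
    for bp in reversed(backpointers):  # walk back from the last word
        path.append(bp[path[-1]])
    return final[best], path[::-1]


def score(tags):
    # Brute-force score of one full tag sequence, for checking.
    total, prev = 0.0, START
    for h, tag in zip(feats, tags):
        total += trans[tag][prev] + h[tag]
        prev = tag
    return total + trans[STOP][prev]


all_paths = list(itertools.product(TAGS, repeat=len(feats)))
best_score, best_path = viterbi(feats)
print(forward_log_z(feats), log_sum_exp([score(y) for y in all_paths]))
print(best_path, best_score, max(all_paths, key=score))
```

The forward algorithm touches :math:`n \cdot |T|^2` transition terms
instead of :math:`|T|^n` full sequences. With these toy scores, the
decoded path should be ``[0, 1, 1]`` (B I I), matching the brute-force
maximum.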

The implementation is not optimized. If you understand what is going on,
you'll probably quickly see that iterating over the next tag in the
forward algorithm could be done in one big operation. I wanted
the code to be more readable. If you want to make the relevant change,
you could probably use this tagger for real tasks.
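
If you want to try that change, the sketch below shows one way to
collapse the inner loop over ``next_tag`` into a single broadcasted
log-sum-exp. It is written in NumPy so that it runs without PyTorch, and
it uses random toy scores. It mirrors the structure of ``_forward_alg``
below: the full tag set including START and STOP, the :math:`-10000`
initialization, and the same ``transitions[next, prev]`` indexing. It
checks itself against the tag-by-tag loop.

```python
import numpy as np


def log_sum_exp(x, axis):
    # Numerically stable log(sum(exp(x))) along one axis.
    m = np.max(x, axis=axis, keepdims=True)
    return np.squeeze(m, axis=axis) + np.log(np.sum(np.exp(x - m), axis=axis))


def forward_loop(feats, transitions, start, stop):
    # One log-sum-exp per next tag, as in the tutorial's _forward_alg.
    n_tags = transitions.shape[0]
    forward_var = np.full(n_tags, -10000.0)
    forward_var[start] = 0.0
    for feat in feats:
        forward_var = np.array([
            log_sum_exp(forward_var + transitions[next_tag] + feat[next_tag], axis=0)
            for next_tag in range(n_tags)])
    return log_sum_exp(forward_var + transitions[stop], axis=0)


def forward_vectorized(feats, transitions, start, stop):
    # scores[next, prev] = forward_var[prev] + transitions[next, prev];
    # reducing over prev (axis=1) handles every next tag at once.
    forward_var = np.full(transitions.shape[0], -10000.0)
    forward_var[start] = 0.0
    for feat in feats:
        forward_var = log_sum_exp(forward_var[None, :] + transitions, axis=1) + feat
    return log_sum_exp(forward_var + transitions[stop], axis=0)


rng = np.random.default_rng(0)
START, STOP = 3, 4
transitions = rng.normal(size=(5, 5))
transitions[START, :] = -10000.0  # never move into START
transitions[:, STOP] = -10000.0   # never move out of STOP
feats = rng.normal(size=(6, 5))   # 6 words, one emission score per tag

print(forward_loop(feats, transitions, START, STOP),
      forward_vectorized(feats, transitions, START, STOP))
```

In the PyTorch code below, the corresponding step would presumably be
``forward_var = (torch.logsumexp(forward_var + self.transitions, dim=1) + feat).view(1, -1)``.
Here ``forward_var`` has shape ``(1, T)`` and broadcasts across the rows
of ``self.transitions``; that variant is untested here. The loop over
words remains, because each step depends on the previous one.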
"""
# Author: Robert Guthrie

import torch
import torch.nn as nn
import torch.optim as optim

torch.manual_seed(1)

#####################################################################
# Helper functions to make the code more readable.


def argmax(vec):
    # Return the index of the largest entry along dim 1 as a Python int
    _, idx = torch.max(vec, 1)
    return idx.item()


def prepare_sequence(seq, to_ix):
    # Convert a list of words into a LongTensor of their indices
    idxs = [to_ix[w] for w in seq]
    return torch.tensor(idxs, dtype=torch.long)


# Compute log sum exp in a numerically stable way for the forward algorithm
def log_sum_exp(vec):
    max_score = vec[0, argmax(vec)]
    max_score_broadcast = max_score.view(1, -1).expand(1, vec.size()[1])
    return max_score + \
        torch.log(torch.sum(torch.exp(vec - max_score_broadcast)))

#####################################################################
# Create model


class BiLSTM_CRF(nn.Module):

    def __init__(self, vocab_size, tag_to_ix, embedding_dim, hidden_dim):
        super(BiLSTM_CRF, self).__init__()
        self.embedding_dim = embedding_dim
        self.hidden_dim = hidden_dim
        self.vocab_size = vocab_size
        self.tag_to_ix = tag_to_ix
        self.tagset_size = len(tag_to_ix)

        self.word_embeds = nn.Embedding(vocab_size, embedding_dim)
        self.lstm = nn.LSTM(embedding_dim, hidden_dim // 2,
                            num_layers=1, bidirectional=True)

        # Maps the output of the LSTM into tag space.
        self.hidden2tag = nn.Linear(hidden_dim, self.tagset_size)

        # Matrix of transition parameters. Entry i,j is the score of
        # transitioning *to* i *from* j.
        self.transitions = nn.Parameter(
            torch.randn(self.tagset_size, self.tagset_size))

        # These two statements enforce the constraint that we never transfer
        # to the start tag and we never transfer from the stop tag
        self.transitions.data[tag_to_ix[START_TAG], :] = -10000
        self.transitions.data[:, tag_to_ix[STOP_TAG]] = -10000

        self.hidden = self.init_hidden()

    def init_hidden(self):
        return (torch.randn(2, 1, self.hidden_dim // 2),
                torch.randn(2, 1, self.hidden_dim // 2))

    def _forward_alg(self, feats):
        # Do the forward algorithm to compute the partition function
        init_alphas = torch.full((1, self.tagset_size), -10000.)
        # START_TAG has all of the score.
        init_alphas[0][self.tag_to_ix[START_TAG]] = 0.

        # No Variable wrapper is needed: autograd tracks the operations below
        # because they involve self.transitions and feats.
        forward_var = init_alphas

        # Iterate through the sentence
        for feat in feats:
            alphas_t = []  # The forward tensors at this timestep
            for next_tag in range(self.tagset_size):
                # broadcast the emission score: it is the same regardless of
                # the previous tag
                emit_score = feat[next_tag].view(
                    1, -1).expand(1, self.tagset_size)
                # the ith entry of trans_score is the score of transitioning to
                # next_tag from i
                trans_score = self.transitions[next_tag].view(1, -1)
                # The ith entry of next_tag_var is the value for the
                # edge (i -> next_tag) before we do log-sum-exp
                next_tag_var = forward_var + trans_score + emit_score
                # The forward variable for this tag is log-sum-exp of all the
                # scores.
                alphas_t.append(log_sum_exp(next_tag_var).view(1))
            forward_var = torch.cat(alphas_t).view(1, -1)
        terminal_var = forward_var + self.transitions[self.tag_to_ix[STOP_TAG]]
        alpha = log_sum_exp(terminal_var)
        return alpha

    def _get_lstm_features(self, sentence):
        self.hidden = self.init_hidden()
        embeds = self.word_embeds(sentence).view(len(sentence), 1, -1)
        lstm_out, self.hidden = self.lstm(embeds, self.hidden)
        lstm_out = lstm_out.view(len(sentence), self.hidden_dim)
        lstm_feats = self.hidden2tag(lstm_out)
        return lstm_feats

    def _score_sentence(self, feats, tags):
        # Gives the score of a provided tag sequence
        score = torch.zeros(1)
        tags = torch.cat([torch.tensor([self.tag_to_ix[START_TAG]], dtype=torch.long), tags])
        for i, feat in enumerate(feats):
            score = score + \
                self.transitions[tags[i + 1], tags[i]] + feat[tags[i + 1]]
        score = score + self.transitions[self.tag_to_ix[STOP_TAG], tags[-1]]
        return score

    def _viterbi_decode(self, feats):
        backpointers = []

        # Initialize the Viterbi variables in log space
        init_vvars = torch.full((1, self.tagset_size), -10000.)
        init_vvars[0][self.tag_to_ix[START_TAG]] = 0

        # forward_var at step i holds the Viterbi variables for step i-1
        forward_var = init_vvars
        for feat in feats:
            bptrs_t = []  # holds the backpointers for this step
            viterbivars_t = []  # holds the Viterbi variables for this step

            for next_tag in range(self.tagset_size):
                # next_tag_var[i] holds the Viterbi variable for tag i at the
                # previous step, plus the score of transitioning
                # from tag i to next_tag.
                # We don't include the emission scores here because the max
                # does not depend on them (we add them in below)
                next_tag_var = forward_var + self.transitions[next_tag]
                best_tag_id = argmax(next_tag_var)
                bptrs_t.append(best_tag_id)
                viterbivars_t.append(next_tag_var[0][best_tag_id].view(1))
            # Now add in the emission scores, and assign forward_var to the set
            # of Viterbi variables we just computed
            forward_var = (torch.cat(viterbivars_t) + feat).view(1, -1)
            backpointers.append(bptrs_t)

        # Transition to STOP_TAG
        terminal_var = forward_var + self.transitions[self.tag_to_ix[STOP_TAG]]
        best_tag_id = argmax(terminal_var)
        path_score = terminal_var[0][best_tag_id]

        # Follow the back pointers to decode the best path.
        best_path = [best_tag_id]
        for bptrs_t in reversed(backpointers):
            best_tag_id = bptrs_t[best_tag_id]
            best_path.append(best_tag_id)
        # Pop off the start tag (we don't want to return that to the caller)
        start = best_path.pop()
        assert start == self.tag_to_ix[START_TAG]  # Sanity check
        best_path.reverse()
        return path_score, best_path

    def neg_log_likelihood(self, sentence, tags):
        feats = self._get_lstm_features(sentence)
        forward_score = self._forward_alg(feats)
        gold_score = self._score_sentence(feats, tags)
        return forward_score - gold_score

    def forward(self, sentence):  # don't confuse this with _forward_alg above.
        # Get the emission scores from the BiLSTM
        lstm_feats = self._get_lstm_features(sentence)

        # Find the best path, given the features.
        score, tag_seq = self._viterbi_decode(lstm_feats)
        return score, tag_seq

#####################################################################
# Run training


START_TAG = "<START>"
STOP_TAG = "<STOP>"
EMBEDDING_DIM = 5
HIDDEN_DIM = 4

# Make up some training data
training_data = [(
    "the wall street journal reported today that apple corporation made money".split(),
    "B I I I O O O B I O O".split()
), (
    "georgia tech is a university in georgia".split(),
    "B I O O O O B".split()
)]

word_to_ix = {}
for sentence, tags in training_data:
    for word in sentence:
        if word not in word_to_ix:
            word_to_ix[word] = len(word_to_ix)

tag_to_ix = {"B": 0, "I": 1, "O": 2, START_TAG: 3, STOP_TAG: 4}

model = BiLSTM_CRF(len(word_to_ix), tag_to_ix, EMBEDDING_DIM, HIDDEN_DIM)
optimizer = optim.SGD(model.parameters(), lr=0.01, weight_decay=1e-4)

# Check predictions before training
with torch.no_grad():
    precheck_sent = prepare_sequence(training_data[0][0], word_to_ix)
    precheck_tags = torch.tensor([tag_to_ix[t] for t in training_data[0][1]], dtype=torch.long)
    print(model(precheck_sent))

# prepare_sequence is the helper defined near the top of this file
for epoch in range(300):  # again, normally you would NOT do 300 epochs, it is toy data
    for sentence, tags in training_data:
        # Step 1. Remember that PyTorch accumulates gradients.
        # We need to clear them out before each instance
        model.zero_grad()

        # Step 2. Get our inputs ready for the network, that is,
        # turn them into Tensors of word indices.
        sentence_in = prepare_sequence(sentence, word_to_ix)
        targets = torch.tensor([tag_to_ix[t] for t in tags], dtype=torch.long)

        # Step 3. Run our forward pass, which returns the negative
        # log-likelihood loss.
        loss = model.neg_log_likelihood(sentence_in, targets)

        # Step 4. Compute the gradients, and update the parameters by
        # calling optimizer.step()
        loss.backward()
        optimizer.step()

# Check predictions after training
with torch.no_grad():
    precheck_sent = prepare_sequence(training_data[0][0], word_to_ix)
    print(model(precheck_sent))
# We got it!


######################################################################
# Exercise: A new loss function for discriminative tagging
# --------------------------------------------------------
#
# It wasn't really necessary for us to create a computation graph when
# doing decoding, since we do not backpropagate from the Viterbi path
# score. Since we have it anyway, try training the tagger where the loss
# function is the difference between the Viterbi path score and the score
# of the gold-standard path. It should be clear that this function is
# non-negative and 0 when the predicted tag sequence is the correct tag
# sequence. This is essentially *structured perceptron*.
#
# This modification should be short, since ``_viterbi_decode`` and
# ``_score_sentence`` are already implemented. This is an example of the
# shape of the computation graph *depending on the training instance*.
# Although I haven't tried implementing this in a static toolkit, I
# imagine that it is possible but much less straightforward.
#
# Pick up some real data and do a comparison!
#