# -*- coding: utf-8 -*-
r"""
Deep Learning with PyTorch
**************************

Deep Learning Building Blocks: Affine maps, non-linearities and objectives
==========================================================================

Deep learning consists of composing linearities with non-linearities in
clever ways. The introduction of non-linearities allows for powerful
models. In this section, we will play with these core components, make
up an objective function, and see how the model is trained.


Affine Maps
~~~~~~~~~~~

One of the core workhorses of deep learning is the affine map, which is
a function :math:`f(x)` where

.. math:: f(x) = Ax + b

for a matrix :math:`A` and vectors :math:`x, b`. The parameters to be
learned here are :math:`A` and :math:`b`. Often, :math:`b` is referred to
as the *bias* term.


PyTorch and most other deep learning frameworks do things a little
differently than traditional linear algebra: they map the rows of the
input instead of the columns. That is, the :math:`i`'th row of the
output below is the mapping of the :math:`i`'th row of the input under
:math:`A`, plus the bias term. Look at the example below.

"""

# Author: Robert Guthrie

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

torch.manual_seed(1)


######################################################################

lin = nn.Linear(5, 3)  # maps from R^5 to R^3, parameters A, b
# data is 2x5. A maps from 5 to 3... can we map "data" under A?
data = torch.randn(2, 5)
print(lin(data))  # yes
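
######################################################################
# One way to make the row-mapping point concrete: ``nn.Linear`` stores its
# weight with shape (out_features, in_features), so the call above is
# equivalent to multiplying each row of ``data`` by the transposed weight
# and adding the bias. A minimal check, reusing the objects defined above:

print(torch.allclose(lin(data), data @ lin.weight.t() + lin.bias))  # True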

######################################################################
# Non-Linearities
# ~~~~~~~~~~~~~~~
#
# First, note the following fact, which will explain why we need
# non-linearities in the first place. Suppose we have two affine maps
# :math:`f(x) = Ax + b` and :math:`g(x) = Cx + d`. What is
# :math:`f(g(x))`?
#
# .. math:: f(g(x)) = A(Cx + d) + b = ACx + (Ad + b)
#
# :math:`AC` is a matrix and :math:`Ad + b` is a vector, so we see that
# composing affine maps gives you an affine map.
#
# From this, you can see that if you build your neural network as a long
# chain of affine compositions, it adds no new power to your model
# compared to a single affine map (a quick numeric check of this appears
# after the ReLU example below).
#
# If we introduce non-linearities in between the affine layers, this is no
# longer the case, and we can build much more powerful models.
#
# There are a few core non-linearities.
# :math:`\tanh(x), \sigma(x), \text{ReLU}(x)` are the most common. You are
# probably wondering: "why these functions? I can think of plenty of other
# non-linearities." The reason for this is that they have gradients that
# are easy to compute, and computing gradients is essential for learning.
# For example
#
# .. math:: \frac{d\sigma}{dx} = \sigma(x)(1 - \sigma(x))
#
# A quick note: although you may have learned some neural networks in your
# intro to AI class where :math:`\sigma(x)` was the default non-linearity,
# typically people shy away from it in practice. This is because the
# gradient *vanishes* very quickly as the absolute value of the argument
# grows. Small gradients mean it is hard to learn. Most people default to
# tanh or ReLU.
#

# In PyTorch, most non-linearities are in torch.nn.functional (we have it imported as F)
# Note that non-linearities typically don't have parameters like affine maps do.
# That is, they don't have weights that are updated during training.
data = torch.randn(2, 2)
print(data)
print(F.relu(data))
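
######################################################################
# A quick numeric sketch of the composition claim above: stacking two
# linear layers with no non-linearity in between collapses into a single
# affine map with weight :math:`AC` and bias :math:`Ad + b` (the layer
# sizes below are arbitrary, chosen just for the illustration).

g_map = nn.Linear(5, 4)   # g(x) = Cx + d
f_map = nn.Linear(4, 3)   # f(y) = Ay + b
x = torch.randn(2, 5)

combined_weight = f_map.weight @ g_map.weight            # AC
combined_bias = f_map.weight @ g_map.bias + f_map.bias   # Ad + b
print(torch.allclose(f_map(g_map(x)),
                     x @ combined_weight.t() + combined_bias))  # True


######################################################################
# The vanishing-gradient remark can also be seen numerically: the
# derivative :math:`\sigma(x)(1 - \sigma(x))` shrinks fast as
# :math:`|x|` grows (the input values here are arbitrary).

x = torch.tensor([0.0, 2.0, 5.0, 10.0])
print(torch.sigmoid(x) * (1 - torch.sigmoid(x)))  # roughly 0.25, 0.10, 0.0066, 4.5e-5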

######################################################################
# Softmax and Probabilities
# ~~~~~~~~~~~~~~~~~~~~~~~~~
#
# The function :math:`\text{Softmax}(x)` is also just a non-linearity, but
# it is special in that it usually is the last operation done in a
# network. This is because it takes in a vector of real numbers and
# returns a probability distribution. Its definition is as follows. Let
# :math:`x` be a vector of real numbers (positive, negative, whatever,
# there are no constraints). Then the i'th component of
# :math:`\text{Softmax}(x)` is
#
# .. math:: \frac{\exp(x_i)}{\sum_j \exp(x_j)}
#
# It should be clear that the output is a probability distribution: each
# element is non-negative and the sum over all components is 1.
#
# You could also think of it as just applying an element-wise
# exponentiation operator to the input to make everything non-negative and
# then dividing by the normalization constant.
#

# Softmax is also in torch.nn.functional
data = torch.randn(5)
print(data)
print(F.softmax(data, dim=0))
print(F.softmax(data, dim=0).sum())  # Sums to 1 because it is a distribution!
print(F.log_softmax(data, dim=0))  # there's also log_softmax
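
######################################################################
# The same values, computed directly from the definition above (a small
# sketch, reusing the ``data`` vector from the previous cell):

manual_softmax = torch.exp(data) / torch.exp(data).sum()
print(torch.allclose(F.softmax(data, dim=0), manual_softmax))            # True
print(torch.allclose(F.log_softmax(data, dim=0), manual_softmax.log()))  # True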

######################################################################
# Objective Functions
# ~~~~~~~~~~~~~~~~~~~
#
# The objective function is the function that your network is being
# trained to minimize (in which case it is often called a *loss function*
# or *cost function*). This proceeds by first choosing a training
# instance, running it through your neural network, and then computing the
# loss of the output. The parameters of the model are then updated by
# taking the derivative of the loss function. Intuitively, if your model
# is completely confident in its answer, and its answer is wrong, your
# loss will be high. If it is very confident in its answer, and its answer
# is correct, the loss will be low.
#
# The idea behind minimizing the loss function on your training examples
# is that your network will hopefully generalize well and have small loss
# on unseen examples in your dev set, test set, or in production. An
# example loss function is the *negative log likelihood loss*, which is a
# very common objective for multi-class classification. For supervised
# multi-class classification, this means training the network to minimize
# the negative log probability of the correct output (or equivalently,
# maximize the log probability of the correct output).
#
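
######################################################################
# A short numeric sketch of the negative log likelihood idea, on made-up
# log probabilities for a single 3-class instance whose correct label is
# class 1 (both the scores and the label here are arbitrary):

toy_log_probs = F.log_softmax(torch.randn(1, 3), dim=1)
toy_target = torch.tensor([1])
print(nn.NLLLoss()(toy_log_probs, toy_target))  # the library value
print(-toy_log_probs[0, 1])                     # the same number, computed by hand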

######################################################################
# Optimization and Training
# =========================
#
# So suppose we can compute a loss function for an instance. What do we do
# with that? We saw earlier that Tensors know how to compute gradients
# with respect to the things that were used to compute them. Well,
# since our loss is a Tensor, we can compute gradients with
# respect to all of the parameters used to compute it! Then we can perform
# standard gradient updates. Let :math:`\theta` be our parameters,
# :math:`L(\theta)` the loss function, and :math:`\eta` a positive
# learning rate. Then:
#
# .. math:: \theta^{(t+1)} = \theta^{(t)} - \eta \nabla_\theta L(\theta)
#
# There is a huge collection of algorithms and active research in
# attempting to do something more than just this vanilla gradient update.
# Many attempt to vary the learning rate based on what is happening at
# train time. You don't need to worry about what specifically these
# algorithms are doing unless you are really interested. Torch provides
# many in the torch.optim package, and they are all completely
# transparent: using the simplest gradient update looks exactly the same
# in code as using the more complicated algorithms. Trying different update
# algorithms and different parameters for the update algorithms (like
# different initial learning rates) is important in optimizing your
# network's performance. Often, just replacing vanilla SGD with an
# optimizer like Adam or RMSProp will boost performance noticeably.
#
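
######################################################################
# The update rule above, sketched by hand on a throwaway parameter and a
# toy loss (the values are arbitrary; ``optim.SGD`` applies this kind of
# step to every registered parameter for you):

theta = torch.randn(3, requires_grad=True)
toy_loss = (theta ** 2).sum()   # a toy L(theta)
toy_loss.backward()             # fills theta.grad with the gradient of L
eta = 0.1
with torch.no_grad():
    theta -= eta * theta.grad   # theta <- theta - eta * grad L(theta)
print(theta)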

######################################################################
# Creating Network Components in PyTorch
# ======================================
#
# Before we move on to our focus on NLP, let's do an annotated example of
# building a network in PyTorch using only affine maps and
# non-linearities. We will also see how to compute a loss function, using
# PyTorch's built-in negative log likelihood, and update parameters by
# backpropagation.
#
# All network components should inherit from nn.Module and override the
# forward() method. That is about it, as far as the boilerplate is
# concerned. Inheriting from nn.Module provides functionality to your
# component. For example, it keeps track of its trainable parameters, and
# you can swap it between CPU and GPU with the ``.to(device)`` method,
# where device can be the CPU device ``torch.device("cpu")`` or a CUDA
# device ``torch.device("cuda:0")``.
#
# Let's write an annotated example of a network that takes in a sparse
# bag-of-words representation and outputs a probability distribution over
# two labels: "English" and "Spanish". This model is just logistic
# regression.
#
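
######################################################################
# A minimal sketch of that ``.to(device)`` call (using the CPU device so
# it runs anywhere; swap in ``torch.device("cuda:0")`` on a machine with
# a GPU):

device = torch.device("cpu")
tiny = nn.Linear(3, 2).to(device)
print(next(tiny.parameters()).device)  # cpu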

######################################################################
# Example: Logistic Regression Bag-of-Words classifier
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
#
# Our model will map a sparse BoW representation to log probabilities over
# labels. We assign each word in the vocab an index. For example, say our
# entire vocab is two words "hello" and "world", with indices 0 and 1
# respectively. The BoW vector for the sentence "hello hello hello hello"
# is
#
# .. math:: \left[ 4, 0 \right]
#
# For "hello world world hello", it is
#
# .. math:: \left[ 2, 2 \right]
#
# etc. In general, it is
#
# .. math:: \left[ \text{Count}(\text{hello}), \text{Count}(\text{world}) \right]
#
# Denote this BoW vector as :math:`x`. The output of our network is:
#
# .. math:: \log \text{Softmax}(Ax + b)
#
# That is, we pass the input through an affine map and then do log
# softmax.
#
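
######################################################################
# The toy two-word vocabulary above, as a quick executable sketch (the
# real vocabulary for this tutorial is built from the data in the next
# cell):

toy_vocab = {"hello": 0, "world": 1}

def toy_bow(sentence):
    vec = torch.zeros(len(toy_vocab))
    for word in sentence.split():
        vec[toy_vocab[word]] += 1
    return vec

print(toy_bow("hello hello hello hello"))  # tensor([4., 0.])
print(toy_bow("hello world world hello"))  # tensor([2., 2.])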
data = [("me gusta comer en la cafeteria".split(), "SPANISH"),
        ("Give it to me".split(), "ENGLISH"),
        ("No creo que sea una buena idea".split(), "SPANISH"),
        ("No it is not a good idea to get lost at sea".split(), "ENGLISH")]

test_data = [("Yo creo que si".split(), "SPANISH"),
             ("it is lost on me".split(), "ENGLISH")]

# word_to_ix maps each word in the vocab to a unique integer, which will be its
# index into the bag-of-words vector
word_to_ix = {}
for sent, _ in data + test_data:
    for word in sent:
        if word not in word_to_ix:
            word_to_ix[word] = len(word_to_ix)
print(word_to_ix)

VOCAB_SIZE = len(word_to_ix)
NUM_LABELS = 2


class BoWClassifier(nn.Module):  # inheriting from nn.Module!

    def __init__(self, num_labels, vocab_size):
        # Calls the init function of nn.Module. Don't get confused by the syntax,
        # just always do it in an nn.Module
        super(BoWClassifier, self).__init__()

        # Define the parameters that you will need. In this case, we need A and b,
        # the parameters of the affine mapping.
        # Torch defines nn.Linear(), which provides the affine map.
        # Make sure you understand why the input dimension is vocab_size
        # and the output is num_labels!
        self.linear = nn.Linear(vocab_size, num_labels)

        # NOTE! The non-linearity log softmax does not have parameters,
        # so we don't need to worry about that here.

    def forward(self, bow_vec):
        # Pass the input through the linear layer,
        # then pass that through log_softmax.
        # Many non-linearities and other functions are in torch.nn.functional
        return F.log_softmax(self.linear(bow_vec), dim=1)


def make_bow_vector(sentence, word_to_ix):
    vec = torch.zeros(len(word_to_ix))
    for word in sentence:
        vec[word_to_ix[word]] += 1
    return vec.view(1, -1)


def make_target(label, label_to_ix):
    return torch.LongTensor([label_to_ix[label]])

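
######################################################################
# A quick shape check of the helper above (a sketch on the first training
# sentence; the vector is 1 x vocabulary size so it can be fed straight
# into nn.Linear):

example_vec = make_bow_vector(data[0][0], word_to_ix)
print(example_vec.shape)  # torch.Size([1, len(word_to_ix)])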

model = BoWClassifier(NUM_LABELS, VOCAB_SIZE)

# The model knows its parameters. The first output below is A, the second is b.
# Whenever you assign a component to a class variable in the __init__ function
# of a module, which was done with the line
# self.linear = nn.Linear(...)
# then, through some Python magic from the PyTorch devs, your module
# (in this case, BoWClassifier) will store knowledge of the nn.Linear's parameters
for param in model.parameters():
    print(param)

# To run the model, pass in a BoW vector
# Here we don't need to train, so the code is wrapped in torch.no_grad()
with torch.no_grad():
    sample = data[0]
    bow_vector = make_bow_vector(sample[0], word_to_ix)
    log_probs = model(bow_vector)
    print(log_probs)

######################################################################
# Which of the above values corresponds to the log probability of ENGLISH,
# and which to SPANISH? We never defined it, but we need to if we want to
# train the thing.
#

label_to_ix = {"SPANISH": 0, "ENGLISH": 1}

######################################################################
# So let's train! To do this, we pass instances through to get log
# probabilities, compute a loss function, compute the gradient of the loss
# function, and then update the parameters with a gradient step. Loss
# functions are provided by Torch in the nn package. nn.NLLLoss() is the
# negative log likelihood loss we want. Torch also provides optimization
# functions in torch.optim. Here, we will just use SGD.
#
# Note that the *input* to NLLLoss is a vector of log probabilities, and a
# target label. It doesn't compute the log probabilities for us. This is
# why the last layer of our network is log softmax. The loss function
# nn.CrossEntropyLoss() is the same as NLLLoss(), except it does the log
# softmax for you.
#
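
######################################################################
# A quick numeric check of that last point, on made-up scores for a single
# two-label instance (the scores and the label are arbitrary):

raw_scores = torch.randn(1, 2)
toy_target = torch.tensor([0])
print(nn.CrossEntropyLoss()(raw_scores, toy_target))
print(nn.NLLLoss()(F.log_softmax(raw_scores, dim=1), toy_target))  # same value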

# Run on test data before we train, just to see a before-and-after
with torch.no_grad():
    for instance, label in test_data:
        bow_vec = make_bow_vector(instance, word_to_ix)
        log_probs = model(bow_vec)
        print(log_probs)

# Print the matrix column corresponding to "creo"
print(next(model.parameters())[:, word_to_ix["creo"]])

loss_function = nn.NLLLoss()
optimizer = optim.SGD(model.parameters(), lr=0.1)

# Usually you want to pass over the training data several times.
# 100 epochs is far more than you would use on a real dataset, but real
# datasets have more than two instances. Usually, somewhere between 5
# and 30 epochs is reasonable.
for epoch in range(100):
    for instance, label in data:
        # Step 1. Remember that PyTorch accumulates gradients.
        # We need to clear them out before each instance
        model.zero_grad()

        # Step 2. Make our BoW vector, and wrap the target label in a
        # Tensor as an integer. For example, if the target is SPANISH, then
        # we wrap the integer 0. The loss function then knows that the 0th
        # element of the log probabilities is the log probability
        # corresponding to SPANISH
        bow_vec = make_bow_vector(instance, word_to_ix)
        target = make_target(label, label_to_ix)

        # Step 3. Run our forward pass.
        log_probs = model(bow_vec)

        # Step 4. Compute the loss, gradients, and update the parameters by
        # calling optimizer.step()
        loss = loss_function(log_probs, target)
        loss.backward()
        optimizer.step()

with torch.no_grad():
    for instance, label in test_data:
        bow_vec = make_bow_vector(instance, word_to_ix)
        log_probs = model(bow_vec)
        print(log_probs)

# Index corresponding to Spanish goes up, English goes down!
print(next(model.parameters())[:, word_to_ix["creo"]])


######################################################################
# We got the right answer! You can see that on the test data the log
# probability for Spanish is much higher in the first example, and the
# log probability for English is much higher in the second, as it should be.
#
# Now you see how to make a PyTorch component, pass some data through it,
# and do gradient updates. We are ready to dig deeper into what deep NLP
# has to offer.
#