# -*- coding: utf-8 -*-
"""
NLP From Scratch: Classifying Names with a Character-Level RNN
**************************************************************
**Author**: `Sean Robertson <https://github.com/spro>`_

This tutorial is part of a three-part series:

* `NLP From Scratch: Classifying Names with a Character-Level RNN <https://pytorch.org/tutorials/intermediate/char_rnn_classification_tutorial.html>`__
* `NLP From Scratch: Generating Names with a Character-Level RNN <https://pytorch.org/tutorials/intermediate/char_rnn_generation_tutorial.html>`__
* `NLP From Scratch: Translation with a Sequence to Sequence Network and Attention <https://pytorch.org/tutorials/intermediate/seq2seq_translation_tutorial.html>`__

We will be building and training a basic character-level Recurrent Neural
Network (RNN) to classify words. This tutorial, along with the two other
Natural Language Processing (NLP) "from scratch" tutorials
:doc:`/intermediate/char_rnn_generation_tutorial` and
:doc:`/intermediate/seq2seq_translation_tutorial`, shows how to
preprocess data for NLP modeling. In particular, these tutorials show how
NLP preprocessing works at a low level.

A character-level RNN reads words as a series of characters -
outputting a prediction and "hidden state" at each step, feeding its
previous hidden state into each next step. We take the final prediction
to be the output, i.e. which class the word belongs to.

Specifically, we'll train on a few thousand surnames from 18 languages
of origin, and predict which language a name is from based on the
spelling:

.. code-block:: sh

    $ python predict.py Hinton
    (-0.47) Scottish
    (-1.52) English
    (-3.57) Irish

    $ python predict.py Schmidhuber
    (-0.19) German
    (-2.48) Czech
    (-2.68) Dutch


Recommended Preparation
=======================

Before starting this tutorial it is recommended that you have installed PyTorch,
and have a basic understanding of the Python programming language and Tensors:

- https://pytorch.org/ For installation instructions
- :doc:`/beginner/deep_learning_60min_blitz` to get started with PyTorch in general
  and learn the basics of Tensors
- :doc:`/beginner/pytorch_with_examples` for a wide and deep overview
- :doc:`/beginner/former_torchies_tutorial` if you are a former Lua Torch user

It would also be useful to know about RNNs and how they work:

- `The Unreasonable Effectiveness of Recurrent Neural
  Networks <https://karpathy.github.io/2015/05/21/rnn-effectiveness/>`__
  shows a bunch of real life examples
- `Understanding LSTM
  Networks <https://colah.github.io/posts/2015-08-Understanding-LSTMs/>`__
  is about LSTMs specifically but also informative about RNNs in
  general

Preparing the Data
==================

.. note::
   Download the data from
   `here <https://download.pytorch.org/tutorial/data.zip>`_
   and extract it to the current directory.

Included in the ``data/names`` directory are 18 text files named as
``[Language].txt``. Each file contains a bunch of names, one name per
line, mostly romanized (but we still need to convert from Unicode to
ASCII).

We'll end up with a dictionary of lists of names per language,
``{language: [names ...]}``. The generic variables "category" and "line"
(for language and name in our case) are used for later extensibility.
"""
from io import open
import glob
import os

def findFiles(path): return glob.glob(path)

print(findFiles('data/names/*.txt'))

import unicodedata
import string

all_letters = string.ascii_letters + " .,;'"
n_letters = len(all_letters)

# Turn a Unicode string into plain ASCII, thanks to https://stackoverflow.com/a/518232/2809427
def unicodeToAscii(s):
    return ''.join(
        c for c in unicodedata.normalize('NFD', s)
        if unicodedata.category(c) != 'Mn'
        and c in all_letters
    )

print(unicodeToAscii('Ślusàrski'))
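# Expected output, given the ``all_letters`` set above: Slusarski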

# Build the category_lines dictionary, a list of names per language
category_lines = {}
all_categories = []

# Read a file and split into lines
def readLines(filename):
    lines = open(filename, encoding='utf-8').read().strip().split('\n')
    return [unicodeToAscii(line) for line in lines]

for filename in findFiles('data/names/*.txt'):
    category = os.path.splitext(os.path.basename(filename))[0]
    all_categories.append(category)
    lines = readLines(filename)
    category_lines[category] = lines

n_categories = len(all_categories)

######################################################################
# Now we have ``category_lines``, a dictionary mapping each category
# (language) to a list of lines (names). We also kept track of
# ``all_categories`` (just a list of languages) and ``n_categories`` for
# later reference.
#

print(category_lines['Italian'][:5])

######################################################################
# Turning Names into Tensors
# --------------------------
#
# Now that we have all the names organized, we need to turn them into
# Tensors to make any use of them.
#
# To represent a single letter, we use a "one-hot vector" of size
# ``<1 x n_letters>``. A one-hot vector is filled with 0s except for a 1
# at the index of the current letter, e.g. ``"b" = <0 1 0 0 0 ...>``.
#
# To make a word we join a bunch of those into a tensor of size
# ``<line_length x 1 x n_letters>``.
#
# That extra 1 dimension is because PyTorch assumes everything is in
# batches - we're just using a batch size of 1 here.
#

import torch

# Find letter index from all_letters, e.g. "a" = 0
def letterToIndex(letter):
    return all_letters.find(letter)

# Just for demonstration, turn a letter into a <1 x n_letters> Tensor
def letterToTensor(letter):
    tensor = torch.zeros(1, n_letters)
    tensor[0][letterToIndex(letter)] = 1
    return tensor

# Turn a line into a <line_length x 1 x n_letters> tensor,
# or an array of one-hot letter vectors
def lineToTensor(line):
    tensor = torch.zeros(len(line), 1, n_letters)
    for li, letter in enumerate(line):
        tensor[li][0][letterToIndex(letter)] = 1
    return tensor

print(letterToTensor('J'))

print(lineToTensor('Jones').size())

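######################################################################
# As a quick sanity check, the same encoding can also be built with
# ``torch.nn.functional.one_hot``; ``lineToTensorOneHot`` is a small helper
# sketch used only for this comparison.
#

import torch.nn.functional as F

def lineToTensorOneHot(line):
    indices = torch.tensor([letterToIndex(c) for c in line])
    # one_hot yields <line_length x n_letters>; unsqueeze adds the batch dim
    return F.one_hot(indices, num_classes=n_letters).float().unsqueeze(1)

print(torch.equal(lineToTensor('Jones'), lineToTensorOneHot('Jones')))  # expected: True
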
######################################################################
# Creating the Network
# ====================
#
# Before autograd, creating a recurrent neural network in Torch involved
# cloning the parameters of a layer over several timesteps. The layers
# held hidden state and gradients which are now entirely handled by the
# graph itself. This means you can implement an RNN in a very "pure" way,
# as regular feed-forward layers.
#
# This RNN module implements a "vanilla RNN" and is just 3 linear layers
# which operate on an input and hidden state, with a ``LogSoftmax`` layer
# after the output.
#
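######################################################################
# Concretely, this is just a restatement of the module defined below: at each
# step the hidden state is updated as
#
# .. math:: h_t = \tanh(W_{ih} x_t + b_{ih} + W_{hh} h_{t-1} + b_{hh})
#
# and the prediction is ``LogSoftmax(W_{ho} h_t + b_{ho})``, where
# :math:`W_{ih}`, :math:`W_{hh}` and :math:`W_{ho}` are the ``i2h``, ``h2h``
# and ``h2o`` linear layers.
#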
import torch.nn as nn
import torch.nn.functional as F

class RNN(nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super(RNN, self).__init__()

        self.hidden_size = hidden_size

        self.i2h = nn.Linear(input_size, hidden_size)
        self.h2h = nn.Linear(hidden_size, hidden_size)
        self.h2o = nn.Linear(hidden_size, output_size)
        self.softmax = nn.LogSoftmax(dim=1)

    def forward(self, input, hidden):
        hidden = F.tanh(self.i2h(input) + self.h2h(hidden))
        output = self.h2o(hidden)
        output = self.softmax(output)
        return output, hidden

    def initHidden(self):
        return torch.zeros(1, self.hidden_size)

n_hidden = 128
rnn = RNN(n_letters, n_hidden, n_categories)

######################################################################
# To run a step of this network we need to pass an input (in our case, the
# Tensor for the current letter) and a previous hidden state (which we
# initialize as zeros at first). We'll get back the output (the
# log-probability of each language) and a next hidden state (which we keep
# for the next step).
#

input = letterToTensor('A')
hidden = torch.zeros(1, n_hidden)

output, next_hidden = rnn(input, hidden)

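# The output is a <1 x n_categories> Tensor - with the 18 name files from
# ``data.zip`` this prints torch.Size([1, 18]).
print(output.size())
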
######################################################################
# For the sake of efficiency we don't want to be creating a new Tensor for
# every step, so we will use ``lineToTensor`` instead of
# ``letterToTensor`` and use slices. This could be further optimized by
# precomputing batches of Tensors.
#

input = lineToTensor('Albert')
hidden = torch.zeros(1, n_hidden)

output, next_hidden = rnn(input[0], hidden)
print(output)

######################################################################
# As you can see the output is a ``<1 x n_categories>`` Tensor, where
# every item is the log-likelihood of that category (higher is more likely).
#

######################################################################
#
# Training
# ========
#
# Preparing for Training
# ----------------------
#
# Before going into training we should make a few helper functions. The
# first is to interpret the output of the network, which we know to be a
# likelihood of each category. We can use ``Tensor.topk`` to get the index
# of the greatest value:
#

def categoryFromOutput(output):
    top_n, top_i = output.topk(1)
    category_i = top_i[0].item()
    return all_categories[category_i], category_i

print(categoryFromOutput(output))

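# ``output.topk(1)`` returns both the largest value and its index along the
# last dimension; the index is what maps back to a category. A quick
# equivalent sketch using ``argmax``:
print(all_categories[output.argmax(dim=1).item()])
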

######################################################################
# We will also want a quick way to get a training example (a name and its
# language):
#

import random

def randomChoice(l):
    return l[random.randint(0, len(l) - 1)]

def randomTrainingExample():
    category = randomChoice(all_categories)
    line = randomChoice(category_lines[category])
    category_tensor = torch.tensor([all_categories.index(category)], dtype=torch.long)
    line_tensor = lineToTensor(line)
    return category, line, category_tensor, line_tensor

for i in range(10):
    category, line, category_tensor, line_tensor = randomTrainingExample()
    print('category =', category, '/ line =', line)

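######################################################################
# Note that because a category is chosen uniformly before a name is chosen
# from it, every language is sampled equally often during training,
# regardless of how many names its file contains.
#
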
######################################################################
# Training the Network
# --------------------
#
# Now all it takes to train this network is to show it a bunch of examples,
# have it make guesses, and tell it if it's wrong.
#
# For the loss function ``nn.NLLLoss`` is appropriate, since the last
# layer of the RNN is ``nn.LogSoftmax``.
#

criterion = nn.NLLLoss()

######################################################################
# Each loop of training will:
#
# - Create input and target tensors
# - Create a zeroed initial hidden state
# - Read each letter in and
#
#    - Keep hidden state for next letter
#
# - Compare final output to target
# - Back-propagate
# - Return the output and loss
#

learning_rate = 0.005 # If you set this too high, it might explode. If too low, it might not learn

def train(category_tensor, line_tensor):
    hidden = rnn.initHidden()

    rnn.zero_grad()

    for i in range(line_tensor.size()[0]):
        output, hidden = rnn(line_tensor[i], hidden)

    loss = criterion(output, category_tensor)
    loss.backward()

    # Add parameters' gradients to their values, multiplied by learning rate
    for p in rnn.parameters():
        p.data.add_(p.grad.data, alpha=-learning_rate)

    return output, loss.item()

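######################################################################
# The manual parameter update above is plain stochastic gradient descent.
# As a sketch (assuming the same ``rnn`` and ``criterion``), the equivalent
# step written with ``torch.optim.SGD`` would look like the hypothetical
# ``train_with_optimizer`` below, shown for comparison only and not used by
# the rest of this tutorial.
#

import torch.optim as optim

sgd = optim.SGD(rnn.parameters(), lr=learning_rate)

def train_with_optimizer(category_tensor, line_tensor):
    hidden = rnn.initHidden()
    sgd.zero_grad()
    for i in range(line_tensor.size()[0]):
        output, hidden = rnn(line_tensor[i], hidden)
    loss = criterion(output, category_tensor)
    loss.backward()
    sgd.step()  # performs p -= learning_rate * p.grad, matching the manual loop
    return output, loss.item()
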
######################################################################
# Now we just have to run that with a bunch of examples. Since the
# ``train`` function returns both the output and loss we can print its
# guesses and also keep track of loss for plotting. Since there are 1000s
# of examples we print only every ``print_every`` examples, and take an
# average of the loss.
#

import time
import math

n_iters = 100000
print_every = 5000
plot_every = 1000


# Keep track of losses for plotting
current_loss = 0
all_losses = []

def timeSince(since):
    now = time.time()
    s = now - since
    m = math.floor(s / 60)
    s -= m * 60
    return '%dm %ds' % (m, s)

start = time.time()

for iter in range(1, n_iters + 1):
    category, line, category_tensor, line_tensor = randomTrainingExample()
    output, loss = train(category_tensor, line_tensor)
    current_loss += loss

    # Print ``iter`` number, loss, name and guess
    if iter % print_every == 0:
        guess, guess_i = categoryFromOutput(output)
        correct = '✓' if guess == category else '✗ (%s)' % category
        print('%d %d%% (%s) %.4f %s / %s %s' % (iter, iter / n_iters * 100, timeSince(start), loss, line, guess, correct))

    # Add current loss avg to list of losses
    if iter % plot_every == 0:
        all_losses.append(current_loss / plot_every)
        current_loss = 0

######################################################################
# Plotting the Results
# --------------------
#
# Plotting the historical loss from ``all_losses`` shows the network
# learning:
#

import matplotlib.pyplot as plt
import matplotlib.ticker as ticker

plt.figure()
plt.plot(all_losses)

######################################################################
# Evaluating the Results
# ======================
#
# To see how well the network performs on different categories, we will
# create a confusion matrix, indicating for every actual language (rows)
# which language the network guesses (columns). To calculate the confusion
# matrix a bunch of samples are run through the network with
# ``evaluate()``, which is the same as ``train()`` minus the backprop.
#

# Keep track of correct guesses in a confusion matrix
confusion = torch.zeros(n_categories, n_categories)
n_confusion = 10000

# Just return an output given a line
def evaluate(line_tensor):
    hidden = rnn.initHidden()

    for i in range(line_tensor.size()[0]):
        output, hidden = rnn(line_tensor[i], hidden)

    return output

# Go through a bunch of examples and record which are correctly guessed
for i in range(n_confusion):
    category, line, category_tensor, line_tensor = randomTrainingExample()
    output = evaluate(line_tensor)
    guess, guess_i = categoryFromOutput(output)
    category_i = all_categories.index(category)
    confusion[category_i][guess_i] += 1

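# As a quick aggregate check before normalizing: the diagonal of the raw
# count matrix holds the correct guesses, so this is the overall accuracy
# over the sampled names.
print('accuracy: %.3f' % (confusion.diag().sum().item() / n_confusion))
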
# Normalize by dividing every row by its sum
for i in range(n_categories):
    confusion[i] = confusion[i] / confusion[i].sum()

# Set up plot
fig = plt.figure()
ax = fig.add_subplot(111)
cax = ax.matshow(confusion.numpy())
fig.colorbar(cax)

# Set up axes
ax.set_xticklabels([''] + all_categories, rotation=90)
ax.set_yticklabels([''] + all_categories)

# Force label at every tick
ax.xaxis.set_major_locator(ticker.MultipleLocator(1))
ax.yaxis.set_major_locator(ticker.MultipleLocator(1))

# sphinx_gallery_thumbnail_number = 2
plt.show()

######################################################################
# You can pick out bright spots off the main diagonal that show which
# languages it guesses incorrectly, e.g. Chinese for Korean, and Spanish
# for Italian. It seems to do very well with Greek, and very poorly with
# English (perhaps because of overlap with other languages).
#

######################################################################
# Running on User Input
# ---------------------
#

def predict(input_line, n_predictions=3):
    print('\n> %s' % input_line)
    with torch.no_grad():
        output = evaluate(lineToTensor(input_line))

        # Get top N categories
        topv, topi = output.topk(n_predictions, 1, True)
        predictions = []

        for i in range(n_predictions):
            value = topv[0][i].item()
            category_index = topi[0][i].item()
            print('(%.2f) %s' % (value, all_categories[category_index]))
            predictions.append([value, all_categories[category_index]])

predict('Dovesky')
predict('Jackson')
predict('Satoshi')

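######################################################################
# The scores printed by ``predict`` are log-probabilities (the output of the
# final ``LogSoftmax``); exponentiating them recovers probabilities. A small
# check sketch:
#

with torch.no_grad():
    probs = evaluate(lineToTensor('Satoshi')).exp()
    print(probs.sum().item())  # should be very close to 1.0
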
######################################################################
# The final versions of the scripts `in the Practical PyTorch
# repo <https://github.com/spro/practical-pytorch/tree/master/char-rnn-classification>`__
# split the above code into a few files:
#
# - ``data.py`` (loads files)
# - ``model.py`` (defines the RNN)
# - ``train.py`` (runs training)
# - ``predict.py`` (runs ``predict()`` with command line arguments)
# - ``server.py`` (serves prediction as a JSON API with ``bottle.py``)
#
# Run ``train.py`` to train and save the network.
#
# Run ``predict.py`` with a name to view predictions:
#
# .. code-block:: sh
#
#     $ python predict.py Hazaki
#     (-0.42) Japanese
#     (-1.39) Polish
#     (-3.51) Czech
#
# Run ``server.py`` and visit http://localhost:5533/Yourname to get JSON
# output of predictions.
#

######################################################################
# Exercises
# =========
#
# - Try with a different dataset of line -> category, for example:
#
#    - Any word -> language
#    - First name -> gender
#    - Character name -> writer
#    - Page title -> blog or subreddit
#
# - Get better results with a bigger and/or better shaped network
#
#    - Add more linear layers
#    - Try the ``nn.LSTM`` and ``nn.GRU`` layers
#    - Combine multiple of these RNNs as a higher level network
#