# -*- coding: utf-8 -*-
"""
NLP From Scratch: Classifying Names with a Character-Level RNN
**************************************************************
**Author**: `Sean Robertson <https://github.com/spro>`_

This tutorial is part of a three-part series:

* `NLP From Scratch: Classifying Names with a Character-Level RNN <https://pytorch.org/tutorials/intermediate/char_rnn_classification_tutorial.html>`__
* `NLP From Scratch: Generating Names with a Character-Level RNN <https://pytorch.org/tutorials/intermediate/char_rnn_generation_tutorial.html>`__
* `NLP From Scratch: Translation with a Sequence to Sequence Network and Attention <https://pytorch.org/tutorials/intermediate/seq2seq_translation_tutorial.html>`__

We will be building and training a basic character-level Recurrent Neural
Network (RNN) to classify words. This tutorial, along with the two other
Natural Language Processing (NLP) "from scratch" tutorials
:doc:`/intermediate/char_rnn_generation_tutorial` and
:doc:`/intermediate/seq2seq_translation_tutorial`, shows how to
preprocess data for NLP modeling. In particular, these tutorials show how
NLP preprocessing and modeling work at a low level.

A character-level RNN reads words as a series of characters -
outputting a prediction and "hidden state" at each step, feeding its
previous hidden state into each next step. We take the final prediction
to be the output, i.e. which class the word belongs to.
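
Conceptually, the per-character loop looks roughly like the sketch below. This
is only pseudocode: ``rnn_step`` and ``classify`` are hypothetical placeholders
for a single recurrent update and the final output layer, not functions defined
in this tutorial::

    hidden = initial_hidden()
    for char in name:
        # each step yields a prediction and an updated hidden state
        output, hidden = rnn_step(char, hidden)
    predicted_class = classify(output)  # the final prediction decides the class
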

Specifically, we'll train on a few thousand surnames from 18 languages
of origin, and predict which language a name is from based on the
spelling.

Recommended Preparation
=======================

Before starting this tutorial it is recommended that you have installed PyTorch,
and have a basic understanding of the Python programming language and Tensors:

- https://pytorch.org/ For installation instructions
- :doc:`/beginner/deep_learning_60min_blitz` to get started with PyTorch in general
  and learn the basics of Tensors
- :doc:`/beginner/pytorch_with_examples` for a wide and deep overview
- :doc:`/beginner/former_torchies_tutorial` if you are a former Lua Torch user

It would also be useful to know about RNNs and how they work:

- `The Unreasonable Effectiveness of Recurrent Neural
  Networks <https://karpathy.github.io/2015/05/21/rnn-effectiveness/>`__
  shows a bunch of real life examples
- `Understanding LSTM
  Networks <https://colah.github.io/posts/2015-08-Understanding-LSTMs/>`__
  is about LSTMs specifically but also informative about RNNs in
  general
"""
######################################################################
# Preparing Torch
# ==========================
#
# Set up torch to default to the right device and use GPU acceleration
# depending on your hardware (CPU or CUDA).
#

import torch

# Check if CUDA is available
device = torch.device('cpu')
if torch.cuda.is_available():
    device = torch.device('cuda')

torch.set_default_device(device)
print(f"Using device = {torch.get_default_device()}")

######################################################################
# Preparing the Data
# ==================
#
# Download the data from `here <https://download.pytorch.org/tutorial/data.zip>`__
# and extract it to the current directory.
#
# Included in the ``data/names`` directory are 18 text files named as
# ``[Language].txt``. Each file contains a bunch of names, one name per
# line, mostly romanized (but we still need to convert from Unicode to
# ASCII).
#
# The first step is to define and clean our data. Initially, we need to convert Unicode to plain ASCII to
# limit the RNN input layers. This is accomplished by converting Unicode strings to ASCII and keeping only a small set of allowed characters.

import string
import unicodedata

# We can use "_" to represent an out-of-vocabulary character, that is, any character we are not handling in our model
allowed_characters = string.ascii_letters + " .,;'" + "_"
n_letters = len(allowed_characters)

# Turn a Unicode string to plain ASCII, thanks to https://stackoverflow.com/a/518232/2809427
def unicodeToAscii(s):
    return ''.join(
        c for c in unicodedata.normalize('NFD', s)
        if unicodedata.category(c) != 'Mn'
        and c in allowed_characters
    )

#########################
# Here's an example of converting a Unicode name to plain ASCII. This simplifies the input layer.
#

print(f"converting 'Ślusàrski' to {unicodeToAscii('Ślusàrski')}")

######################################################################
# Turning Names into Tensors
# ==========================
#
# Now that we have all the names organized, we need to turn them into
# Tensors to make any use of them.
#
# To represent a single letter, we use a "one-hot vector" of size
# ``<1 x n_letters>``. A one-hot vector is filled with 0s except for a 1
# at the index of the current letter, e.g. ``"b" = <0 1 0 0 0 ...>``.
#
# To make a word we join a bunch of those into a tensor of shape
# ``<line_length x 1 x n_letters>``.
#
# That extra 1 dimension is because PyTorch assumes everything is in
# batches - we're just using a batch size of 1 here.

# Find the letter index from allowed_characters, e.g. "a" = 0
def letterToIndex(letter):
    # return our out-of-vocabulary character if we encounter a letter unknown to our model
    if letter not in allowed_characters:
        return allowed_characters.find("_")
    else:
        return allowed_characters.find(letter)

# Turn a line into a <line_length x 1 x n_letters> tensor,
# or an array of one-hot letter vectors
def lineToTensor(line):
    tensor = torch.zeros(len(line), 1, n_letters)
    for li, letter in enumerate(line):
        tensor[li][0][letterToIndex(letter)] = 1
    return tensor

#########################
# Here are some examples of how to use ``lineToTensor()`` for a single-character and a multi-character string.

print(f"The letter 'a' becomes {lineToTensor('a')}") # notice that the first position in the tensor = 1
print(f"The name 'Ahn' becomes {lineToTensor('Ahn')}") # notice 'A' sets the 27th index to 1

#########################
# Congratulations, you have built the foundational tensor objects for this learning task! You can use a similar approach
# for other RNN tasks with text.
#
# Next, we need to combine all our examples into a dataset so we can train, test and validate our models. For this,
# we will use the `Dataset and DataLoader <https://pytorch.org/tutorials/beginner/basics/data_tutorial.html>`__ classes
# to hold our dataset. Each Dataset needs to implement three functions: ``__init__``, ``__len__``, and ``__getitem__``.
from io import open
import glob
import os
import time

import torch
from torch.utils.data import Dataset

class NamesDataset(Dataset):

    def __init__(self, data_dir):
        self.data_dir = data_dir #for provenance of the dataset
        self.load_time = time.localtime() #for provenance of the dataset
        labels_set = set() #set of all classes

        self.data = []
        self.data_tensors = []
        self.labels = []
        self.labels_tensors = []

        #read all the ``.txt`` files in the specified directory
        text_files = glob.glob(os.path.join(data_dir, '*.txt'))
        for filename in text_files:
            label = os.path.splitext(os.path.basename(filename))[0]
            labels_set.add(label)
            lines = open(filename, encoding='utf-8').read().strip().split('\n')
            for name in lines:
                self.data.append(name)
                self.data_tensors.append(lineToTensor(name))
                self.labels.append(label)

        #Cache the tensor representation of the labels
        self.labels_uniq = list(labels_set)
        for idx in range(len(self.labels)):
            temp_tensor = torch.tensor([self.labels_uniq.index(self.labels[idx])], dtype=torch.long)
            self.labels_tensors.append(temp_tensor)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        data_item = self.data[idx]
        data_label = self.labels[idx]
        data_tensor = self.data_tensors[idx]
        label_tensor = self.labels_tensors[idx]

        return label_tensor, data_tensor, data_label, data_item

#########################
# Here we can load our example data into the ``NamesDataset``

alldata = NamesDataset("data/names")
print(f"loaded {len(alldata)} items of data")
print(f"example = {alldata[0]}")

#########################
# Using the dataset object allows us to easily split the data into train and test sets. Here we create an 85/15
# split, but the ``torch.utils.data`` package has more useful utilities. Here we specify a generator since we need to use the
# same device that PyTorch defaults to above.

train_set, test_set = torch.utils.data.random_split(alldata, [.85, .15], generator=torch.Generator(device=device).manual_seed(2024))

print(f"train examples = {len(train_set)}, validation examples = {len(test_set)}")

#########################
# Now we have a basic dataset containing **20074** examples where each example is a pairing of label and name. We have also
# split the dataset into training and testing so we can validate the model that we build.

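#########################
# Optionally, we can sanity-check a single training example. Each element of ``train_set``
# unpacks into the same four fields returned by ``NamesDataset.__getitem__``: the label
# tensor, the name tensor, the label string, and the raw name string.

label_tensor, text_tensor, label, text = train_set[0]
print(f"name = {text}, label = {label}")
print(f"name tensor shape = {text_tensor.shape}, label tensor = {label_tensor}")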
######################################################################
# Creating the Network
# ====================
#
# Before autograd, creating a recurrent neural network in Torch involved
# cloning the parameters of a layer over several timesteps. The layers
# held hidden state and gradients which are now entirely handled by the
# graph itself. This means you can implement an RNN in a very "pure" way,
# as regular feed-forward layers.
#
# This ``CharRNN`` class implements an RNN with three components.
# First, we use the `nn.RNN implementation <https://pytorch.org/docs/stable/generated/torch.nn.RNN.html>`__.
# Next, we define a layer that maps the RNN hidden layers to our output. Finally, we apply a ``softmax`` function. Using ``nn.RNN``
# leads to a significant improvement in performance, such as cuDNN-accelerated kernels, versus implementing
# each layer as a ``nn.Linear``. It also simplifies the implementation in ``forward()``.
#

import torch.nn as nn
import torch.nn.functional as F

class CharRNN(nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super(CharRNN, self).__init__()

        self.rnn = nn.RNN(input_size, hidden_size)
        self.h2o = nn.Linear(hidden_size, output_size)
        self.softmax = nn.LogSoftmax(dim=1)

    def forward(self, line_tensor):
        rnn_out, hidden = self.rnn(line_tensor)
        output = self.h2o(hidden[0])
        output = self.softmax(output)

        return output


###########################
# We can then create an RNN with 58 input nodes, 128 hidden nodes, and 18 outputs:

n_hidden = 128
rnn = CharRNN(n_letters, n_hidden, len(alldata.labels_uniq))
print(rnn)

######################################################################
# After that we can pass our Tensor to the RNN to obtain a predicted output. Subsequently,
# we use a helper function, ``label_from_output``, to derive a text label for the class.

def label_from_output(output, output_labels):
    top_n, top_i = output.topk(1)
    label_i = top_i[0].item()
    return output_labels[label_i], label_i

input = lineToTensor('Albert')
output = rnn(input) #this is equivalent to ``output = rnn.forward(input)``
print(output)
print(label_from_output(output, alldata.labels_uniq))

######################################################################
#
# Training
# ========


######################################################################
# Training the Network
# --------------------
#
# Now all it takes to train this network is to show it a bunch of examples,
# have it make guesses, and tell it if it's wrong.
#
# We do this by defining a ``train()`` function which trains the model on a given dataset using minibatches.
# RNNs are trained similarly to other networks; therefore, for completeness, we include a batched training method here.
# The loop (``for i in batch``) computes the losses for each of the items in the batch before adjusting the
# weights. This operation is repeated until the number of epochs is reached.

import random
import numpy as np

def train(rnn, training_data, n_epoch = 10, n_batch_size = 64, report_every = 50, learning_rate = 0.2, criterion = nn.NLLLoss()):
    """
    Train the model on training_data for a specified number of epochs, reporting the average loss every report_every epochs
    """
    # Keep track of losses for plotting
    current_loss = 0
    all_losses = []
    rnn.train()
    optimizer = torch.optim.SGD(rnn.parameters(), lr=learning_rate)

    start = time.time()
    print(f"training on data set with n = {len(training_data)}")

    for iter in range(1, n_epoch + 1):
        rnn.zero_grad() # clear the gradients

        # create some minibatches
        # we cannot use dataloaders because each of our names is a different length
        batches = list(range(len(training_data)))
        random.shuffle(batches)
        batches = np.array_split(batches, len(batches) // n_batch_size)

        for idx, batch in enumerate(batches):
            batch_loss = 0
            for i in batch: # for each example in this batch
                (label_tensor, text_tensor, label, text) = training_data[i]
                output = rnn.forward(text_tensor)
                loss = criterion(output, label_tensor)
                batch_loss += loss

            # optimize parameters
            batch_loss.backward()
            nn.utils.clip_grad_norm_(rnn.parameters(), 3)
            optimizer.step()
            optimizer.zero_grad()

            current_loss += batch_loss.item() / len(batch)

        all_losses.append(current_loss / len(batches))
        if iter % report_every == 0:
            print(f"{iter} ({iter / n_epoch:.0%}): \t average batch loss = {all_losses[-1]}")
        current_loss = 0

    return all_losses

##########################################################################
# We can now train the model on our dataset with minibatches for a specified number of epochs. The number of epochs for this
# example is reduced to speed up the build. You can get better results with different parameters.

start = time.time()
all_losses = train(rnn, train_set, n_epoch=27, learning_rate=0.15, report_every=5)
end = time.time()
print(f"training took {end-start}s")

######################################################################
# Plotting the Results
# --------------------
#
# Plotting the historical loss from ``all_losses`` shows the network
# learning:
#

import matplotlib.pyplot as plt
import matplotlib.ticker as ticker

plt.figure()
plt.plot(all_losses)
plt.show()

######################################################################
# Evaluating the Results
# ======================
#
# To see how well the network performs on different categories, we will
# create a confusion matrix, indicating for every actual language (rows)
# which language the network guesses (columns). To calculate the confusion
# matrix, a bunch of samples are run through the network with
# ``evaluate()``, which is the same as ``train()`` minus the backprop.
#

def evaluate(rnn, testing_data, classes):
    confusion = torch.zeros(len(classes), len(classes))

    rnn.eval() #set to eval mode
    with torch.no_grad(): # do not record the gradients during eval phase
        for i in range(len(testing_data)):
            (label_tensor, text_tensor, label, text) = testing_data[i]
            output = rnn(text_tensor)
            guess, guess_i = label_from_output(output, classes)
            label_i = classes.index(label)
            confusion[label_i][guess_i] += 1

    # Normalize by dividing every row by its sum
    for i in range(len(classes)):
        denom = confusion[i].sum()
        if denom > 0:
            confusion[i] = confusion[i] / denom

    # Set up plot
    fig = plt.figure()
    ax = fig.add_subplot(111)
    cax = ax.matshow(confusion.cpu().numpy()) #numpy uses cpu here so we need to use a cpu version
    fig.colorbar(cax)

    # Set up axes
    ax.set_xticks(np.arange(len(classes)), labels=classes, rotation=90)
    ax.set_yticks(np.arange(len(classes)), labels=classes)

    # Force label at every tick
    ax.xaxis.set_major_locator(ticker.MultipleLocator(1))
    ax.yaxis.set_major_locator(ticker.MultipleLocator(1))

    # sphinx_gallery_thumbnail_number = 2
    plt.show()


evaluate(rnn, test_set, classes=alldata.labels_uniq)

######################################################################
# You can pick out bright spots off the main axis that show which
# languages it guesses incorrectly, e.g. Chinese for Korean, and Spanish
# for Italian. It seems to do very well with Greek, and very poorly with
# English (perhaps because of overlap with other languages).
#

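#########################
# To spot-check individual predictions with the trained model, here is a minimal
# sketch that reuses ``lineToTensor`` and ``label_from_output`` from above. The
# ``print_prediction`` helper is introduced here purely for illustration.

def print_prediction(name):
    rnn.eval()
    with torch.no_grad(): # inference only, no gradients needed
        output = rnn(lineToTensor(name))
    guess, _ = label_from_output(output, alldata.labels_uniq)
    print(f"{name} -> {guess}")

print_prediction('Suzuki')
print_prediction('Garcia')
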
######################################################################
# Exercises
# =========
#
# - Get better results with a bigger and/or better shaped network
#
#    - Adjust the hyperparameters to enhance performance, such as changing the number of epochs, batch size, and learning rate
#    - Try the ``nn.LSTM`` and ``nn.GRU`` layers (see the sketch after this list)
#    - Modify the size of the layers, such as increasing or decreasing the number of hidden nodes or adding additional linear layers
#    - Combine multiple of these RNNs as a higher level network
#
# - Try with a different dataset of line -> label, for example:
#
#    - Any word -> language
#    - First name -> gender
#    - Character name -> writer
#    - Page title -> blog or subreddit
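
#########################
# For the LSTM exercise, one possible starting point is the sketch below. It is
# only a sketch, assuming the rest of the pipeline stays unchanged: ``nn.LSTM``
# returns a tuple of (hidden state, cell state), so the final hidden state is
# unpacked slightly differently than in ``CharRNN``. ``CharLSTM`` is a
# hypothetical name introduced here, not part of the tutorial above.

class CharLSTM(nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super(CharLSTM, self).__init__()

        self.lstm = nn.LSTM(input_size, hidden_size)
        self.h2o = nn.Linear(hidden_size, output_size)
        self.softmax = nn.LogSoftmax(dim=1)

    def forward(self, line_tensor):
        # lstm_out holds outputs for every character; (hidden, cell) is the final state
        lstm_out, (hidden, cell) = self.lstm(line_tensor)
        output = self.h2o(hidden[0])
        return self.softmax(output)

# It could be trained and evaluated with the same helpers defined above, e.g.:
# lstm_model = CharLSTM(n_letters, n_hidden, len(alldata.labels_uniq))
# lstm_losses = train(lstm_model, train_set, n_epoch=27, learning_rate=0.15, report_every=5)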
444