# -*- coding: utf-8 -*-
"""
NLP From Scratch: Classifying Names with a Character-Level RNN
**************************************************************
**Author**: `Sean Robertson <https://github.com/spro>`_

This tutorial is part of a three-part series:

* `NLP From Scratch: Classifying Names with a Character-Level RNN <https://pytorch.org/tutorials/intermediate/char_rnn_classification_tutorial.html>`__
* `NLP From Scratch: Generating Names with a Character-Level RNN <https://pytorch.org/tutorials/intermediate/char_rnn_generation_tutorial.html>`__
* `NLP From Scratch: Translation with a Sequence to Sequence Network and Attention <https://pytorch.org/tutorials/intermediate/seq2seq_translation_tutorial.html>`__

We will be building and training a basic character-level Recurrent Neural
Network (RNN) to classify words. This tutorial, along with two other
Natural Language Processing (NLP) "from scratch" tutorials,
:doc:`/intermediate/char_rnn_generation_tutorial` and
:doc:`/intermediate/seq2seq_translation_tutorial`, shows how to
preprocess data for NLP modeling. In particular, these tutorials show how
NLP preprocessing works at a low level.

A character-level RNN reads words as a series of characters -
outputting a prediction and "hidden state" at each step, feeding its
previous hidden state into each next step. We take the final prediction
to be the output, i.e. which class the word belongs to.

Specifically, we'll train on a few thousand surnames from 18 languages
of origin, and predict which language a name is from based on the
spelling.

Recommended Preparation
=======================

Before starting this tutorial it is recommended that you have installed PyTorch,
and have a basic understanding of the Python programming language and Tensors:

- https://pytorch.org/ for installation instructions
- :doc:`/beginner/deep_learning_60min_blitz` to get started with PyTorch in general
  and learn the basics of Tensors
- :doc:`/beginner/pytorch_with_examples` for a wide and deep overview
- :doc:`/beginner/former_torchies_tutorial` if you are a former Lua Torch user

It would also be useful to know about RNNs and how they work:

- `The Unreasonable Effectiveness of Recurrent Neural
  Networks <https://karpathy.github.io/2015/05/21/rnn-effectiveness/>`__
  shows a bunch of real life examples
- `Understanding LSTM
  Networks <https://colah.github.io/posts/2015-08-Understanding-LSTMs/>`__
  is about LSTMs specifically but also informative about RNNs in
  general
"""
######################################################################
# Preparing Torch
# ==========================
#
# Set up torch to default to the right device and use GPU acceleration depending on your hardware (CPU or CUDA).
#

import torch

# Check if CUDA is available
device = torch.device('cpu')
if torch.cuda.is_available():
    device = torch.device('cuda')

torch.set_default_device(device)
print(f"Using device = {torch.get_default_device()}")

######################################################################
# Preparing the Data
# ==================
#
# Download the data from `here <https://download.pytorch.org/tutorial/data.zip>`__
# and extract it to the current directory.
#
# Included in the ``data/names`` directory are 18 text files named as
# ``[Language].txt``. Each file contains a bunch of names, one name per
# line, mostly romanized (but we still need to convert from Unicode to
# ASCII).
#
# The first step is to define and clean our data. We need to convert Unicode to plain ASCII to
# limit the size of the RNN input layer. This is accomplished by normalizing Unicode strings and keeping only a small set of allowed characters.

import string
import unicodedata

allowed_characters = string.ascii_letters + " .,;'"
n_letters = len(allowed_characters)

# Turn a Unicode string to plain ASCII, thanks to https://stackoverflow.com/a/518232/2809427
def unicodeToAscii(s):
    return ''.join(
        c for c in unicodedata.normalize('NFD', s)
        if unicodedata.category(c) != 'Mn'
        and c in allowed_characters
    )

#########################
# Here's an example of converting a Unicode name to plain ASCII. This simplifies the input layer.
#

print(f"converting 'Ślusàrski' to {unicodeToAscii('Ślusàrski')}")

######################################################################
# Turning Names into Tensors
# ==========================
#
# Now that we have all the names organized, we need to turn them into
# Tensors to make any use of them.
#
# To represent a single letter, we use a "one-hot vector" of size
# ``<1 x n_letters>``. A one-hot vector is filled with 0s except for a 1
# at the index of the current letter, e.g. ``"b" = <0 1 0 0 0 ...>``.
#
# To make a word we join a bunch of those into a 3D tensor of shape
# ``<line_length x 1 x n_letters>``.
#
# That extra 1 dimension is because PyTorch assumes everything is in
# batches - we're just using a batch size of 1 here.

# Find letter index from allowed_characters, e.g. "a" = 0
def letterToIndex(letter):
    return allowed_characters.find(letter)

# Turn a line into a <line_length x 1 x n_letters> tensor,
# or an array of one-hot letter vectors
def lineToTensor(line):
    tensor = torch.zeros(len(line), 1, n_letters)
    for li, letter in enumerate(line):
        tensor[li][0][letterToIndex(letter)] = 1
    return tensor

#########################
# Here are some examples of how to use ``lineToTensor()`` for single- and multiple-character strings.

print(f"The letter 'a' becomes {lineToTensor('a')}") # notice that the first position in the tensor = 1
print(f"The name 'Ahn' becomes {lineToTensor('Ahn')}") # notice that 'A' sets index 26 (the 27th position) to 1

#########################
# Congratulations, you have built the foundational tensor objects for this learning task! You can use a similar approach
# for other RNN tasks with text.
#
# Next, we need to combine all our examples into a dataset so we can train, test and validate our models. For this,
# we will use the `Dataset and DataLoader <https://pytorch.org/tutorials/beginner/basics/data_tutorial.html>`__ classes
# to hold our dataset. Each Dataset needs to implement three functions: ``__init__``, ``__len__``, and ``__getitem__``.
from io import open
import glob
import os
import time

import torch
from torch.utils.data import Dataset

class NamesDataset(Dataset):

    def __init__(self, data_dir):
        self.data_dir = data_dir # for provenance of the dataset
        self.load_time = time.localtime() # for provenance of the dataset
        labels_set = set() # set of all classes

        self.data = []
        self.data_tensors = []
        self.labels = []
        self.labels_tensors = []

        # read all the ``.txt`` files in the specified directory
        text_files = glob.glob(os.path.join(data_dir, '*.txt'))
        for filename in text_files:
            label = os.path.splitext(os.path.basename(filename))[0]
            labels_set.add(label)
            lines = open(filename, encoding='utf-8').read().strip().split('\n')
            for name in lines:
                self.data.append(name)
                self.data_tensors.append(lineToTensor(name))
                self.labels.append(label)

        # Cache the tensor representation of the labels
        self.labels_uniq = list(labels_set)
        for idx in range(len(self.labels)):
            temp_tensor = torch.tensor([self.labels_uniq.index(self.labels[idx])], dtype=torch.long)
            self.labels_tensors.append(temp_tensor)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        data_item = self.data[idx]
        data_label = self.labels[idx]
        data_tensor = self.data_tensors[idx]
        label_tensor = self.labels_tensors[idx]

        return label_tensor, data_tensor, data_label, data_item

#########################
# Here we can load our example data into the ``NamesDataset``.

alldata = NamesDataset("data/names")
print(f"loaded {len(alldata)} items of data")
print(f"example = {alldata[0]}")

#########################
# Using the dataset object allows us to easily split the data into train and test sets. Here we create an 85/15
# split, but the ``torch.utils.data`` module has more useful utilities. Here we specify a generator since we need to use the
# same device that PyTorch defaults to above.

train_set, test_set = torch.utils.data.random_split(alldata, [.85, .15], generator=torch.Generator(device=device).manual_seed(2024))

print(f"train examples = {len(train_set)}, validation examples = {len(test_set)}")

#########################
# Now we have a basic dataset containing **20074** examples where each example is a pairing of label and name. We have also
# split the dataset into training and testing so we can validate the model that we build.

######################################################################
# Creating the Network
# ====================
#
# Before autograd, creating a recurrent neural network in Torch involved
# cloning the parameters of a layer over several timesteps. The layers
# held hidden state and gradients which are now entirely handled by the
# graph itself. This means you can implement an RNN in a very "pure" way,
# as regular feed-forward layers.
#
# This CharRNN class implements an RNN with three components.
# First, we use the `nn.RNN implementation <https://pytorch.org/docs/stable/generated/torch.nn.RNN.html>`__.
# Next, we define a layer that maps the RNN hidden layers to our output. Finally, we apply a ``softmax`` function. Using ``nn.RNN``
# leads to a significant performance improvement, such as cuDNN-accelerated kernels, compared with implementing
# each layer as an ``nn.Linear``. It also simplifies the implementation in ``forward()``.
#

import torch.nn as nn
import torch.nn.functional as F

class CharRNN(nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super(CharRNN, self).__init__()

        self.rnn = nn.RNN(input_size, hidden_size)
        self.h2o = nn.Linear(hidden_size, output_size)
        self.softmax = nn.LogSoftmax(dim=1)

    def forward(self, line_tensor):
        rnn_out, hidden = self.rnn(line_tensor)
        output = self.h2o(hidden[0])
        output = self.softmax(output)

        return output

###########################
# We can then create an RNN with 57 input nodes, 128 hidden nodes, and 18 outputs:

n_hidden = 128
rnn = CharRNN(n_letters, n_hidden, len(alldata.labels_uniq))
print(rnn)

######################################################################
# After that we can pass our Tensor to the RNN to obtain a predicted output. Subsequently,
# we use a helper function, ``label_from_output``, to derive a text label for the class.

def label_from_output(output, output_labels):
    top_n, top_i = output.topk(1)
    label_i = top_i[0].item()
    return output_labels[label_i], label_i

input = lineToTensor('Albert')
output = rnn(input) # this is equivalent to ``output = rnn.forward(input)``
print(output)
print(label_from_output(output, alldata.labels_uniq))
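
#########################
# Since the final layer is ``LogSoftmax``, the output above contains log-probabilities. As a brief
# optional check, exponentiating the output recovers per-class probabilities that sum to 1.

print(f"probabilities sum to {output.exp().sum().item():.4f}")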

######################################################################
#
# Training
# ========


######################################################################
# Training the Network
# --------------------
#
# Now all it takes to train this network is show it a bunch of examples,
# have it make guesses, and tell it if it's wrong.
#
# We do this by defining a ``train()`` function which trains the model on a given dataset using minibatches.
# RNNs are trained similarly to other networks; therefore, for completeness, we include a batched training method here.
# The loop (``for i in batch``) computes the losses for each of the items in the batch before adjusting the
# weights. This operation is repeated until the number of epochs is reached.

import random
import numpy as np

def train(rnn, training_data, n_epoch = 10, n_batch_size = 64, report_every = 50, learning_rate = 0.2, criterion = nn.NLLLoss()):
    """
    Train the model on training_data for the specified number of epochs, reporting the average batch loss at the given interval
    """
    # Keep track of losses for plotting
    current_loss = 0
    all_losses = []
    rnn.train()
    optimizer = torch.optim.SGD(rnn.parameters(), lr=learning_rate)

    start = time.time()
    print(f"training on data set with n = {len(training_data)}")

    for iter in range(1, n_epoch + 1):
        rnn.zero_grad() # clear the gradients

        # create some minibatches
        # we cannot use dataloaders because each of our names is a different length
        batches = list(range(len(training_data)))
        random.shuffle(batches)
        batches = np.array_split(batches, len(batches) // n_batch_size)

        for idx, batch in enumerate(batches):
            batch_loss = 0
            for i in batch: # for each example in this batch
                (label_tensor, text_tensor, label, text) = training_data[i]
                output = rnn.forward(text_tensor)
                loss = criterion(output, label_tensor)
                batch_loss += loss

            # optimize parameters
            batch_loss.backward()
            nn.utils.clip_grad_norm_(rnn.parameters(), 3)
            optimizer.step()
            optimizer.zero_grad()

            current_loss += batch_loss.item() / len(batch)

        all_losses.append(current_loss / len(batches))
        if iter % report_every == 0:
            print(f"{iter} ({iter / n_epoch:.0%}): \t average batch loss = {all_losses[-1]}")
        current_loss = 0

    return all_losses

##########################################################################
# We can now train on our dataset with minibatches for a specified number of epochs. The number of epochs for this
# example is reduced to speed up the build. You can get better results with different parameters.

start = time.time()
all_losses = train(rnn, train_set, n_epoch=27, learning_rate=0.15, report_every=5)
end = time.time()
print(f"training took {end-start}s")

######################################################################
# Plotting the Results
# --------------------
#
# Plotting the historical loss from ``all_losses`` shows the network
# learning:
#

import matplotlib.pyplot as plt
import matplotlib.ticker as ticker

plt.figure()
plt.plot(all_losses)
plt.show()

######################################################################
# Evaluating the Results
# ======================
#
# To see how well the network performs on different categories, we will
# create a confusion matrix, indicating for every actual language (rows)
# which language the network guesses (columns). To calculate the confusion
# matrix, a bunch of samples are run through the network with
# ``evaluate()``, which is the same as ``train()`` minus the backprop.
#

def evaluate(rnn, testing_data, classes):
    confusion = torch.zeros(len(classes), len(classes))

    rnn.eval() # set to eval mode
    with torch.no_grad(): # do not record the gradients during eval phase
        for i in range(len(testing_data)):
            (label_tensor, text_tensor, label, text) = testing_data[i]
            output = rnn(text_tensor)
            guess, guess_i = label_from_output(output, classes)
            label_i = classes.index(label)
            confusion[label_i][guess_i] += 1

    # Normalize by dividing every row by its sum
    for i in range(len(classes)):
        denom = confusion[i].sum()
        if denom > 0:
            confusion[i] = confusion[i] / denom

    # Set up plot
    fig = plt.figure()
    ax = fig.add_subplot(111)
    cax = ax.matshow(confusion.cpu().numpy()) # matplotlib/numpy need the data on the CPU
    fig.colorbar(cax)

    # Set up axes
    ax.set_xticks(np.arange(len(classes)), labels=classes, rotation=90)
    ax.set_yticks(np.arange(len(classes)), labels=classes)

    # Force label at every tick
    ax.xaxis.set_major_locator(ticker.MultipleLocator(1))
    ax.yaxis.set_major_locator(ticker.MultipleLocator(1))

    # sphinx_gallery_thumbnail_number = 2
    plt.show()


evaluate(rnn, test_set, classes=alldata.labels_uniq)

######################################################################
# You can pick out bright spots off the main axis that show which
# languages it guesses incorrectly, e.g. Chinese for Korean, and Spanish
# for Italian. It seems to do very well with Greek, and very poorly with
# English (perhaps because of overlap with other languages).
#
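
#########################
# As a small, optional extension, we can also run the trained model on individual names of our own.
# This sketch simply reuses ``lineToTensor()`` and ``label_from_output()`` from above; the sample
# names are arbitrary.

def predict(rnn, name, classes):
    rnn.eval()
    with torch.no_grad(): # inference only, no gradients needed
        output = rnn(lineToTensor(name))
        guess, _ = label_from_output(output, classes)
        return guess

for sample in ['Dovesky', 'Satoshi', 'Jackson']:
    print(f"{sample} -> {predict(rnn, sample, alldata.labels_uniq)}")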

######################################################################
# Exercises
# =========
#
# - Get better results with a bigger and/or better shaped network
#
#    - Adjust the hyperparameters to enhance performance, such as changing the number of epochs, batch size, and learning rate
#    - Try the ``nn.LSTM`` and ``nn.GRU`` layers (a brief sketch follows this list)
#    - Modify the size of the layers, such as increasing or decreasing the number of hidden nodes or adding additional linear layers
#    - Combine multiple of these RNNs as a higher level network
#
# - Try with a different dataset of line -> label, for example:
#
#    - Any word -> language
#    - First name -> gender
#    - Character name -> writer
#    - Page title -> blog or subreddit
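
#########################
# For the ``nn.LSTM`` exercise above, one possible starting point is a drop-in variant of ``CharRNN``
# that swaps ``nn.RNN`` for ``nn.LSTM``. This is only a sketch of one way to do it, not the only
# solution: ``nn.LSTM`` returns both a hidden state and a cell state, so ``forward()`` unpacks the
# result slightly differently.

class CharLSTM(nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super(CharLSTM, self).__init__()
        self.lstm = nn.LSTM(input_size, hidden_size)
        self.h2o = nn.Linear(hidden_size, output_size)
        self.softmax = nn.LogSoftmax(dim=1)

    def forward(self, line_tensor):
        lstm_out, (hidden, cell) = self.lstm(line_tensor)
        output = self.h2o(hidden[0])
        return self.softmax(output)

# It can be trained and evaluated with the same helpers defined above, for example:
# lstm = CharLSTM(n_letters, n_hidden, len(alldata.labels_uniq))
# lstm_losses = train(lstm, train_set, n_epoch=27, learning_rate=0.15, report_every=5)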
439