Path: blob/main/intermediate_source/char_rnn_classification_tutorial.py
# -*- coding: utf-8 -*-
"""
NLP From Scratch: Classifying Names with a Character-Level RNN
**************************************************************
**Author**: `Sean Robertson <https://github.com/spro>`_

This tutorial is part of a three-part series:

* `NLP From Scratch: Classifying Names with a Character-Level RNN <https://pytorch.org/tutorials/intermediate/char_rnn_classification_tutorial.html>`__
* `NLP From Scratch: Generating Names with a Character-Level RNN <https://pytorch.org/tutorials/intermediate/char_rnn_generation_tutorial.html>`__
* `NLP From Scratch: Translation with a Sequence to Sequence Network and Attention <https://pytorch.org/tutorials/intermediate/seq2seq_translation_tutorial.html>`__

We will be building and training a basic character-level Recurrent Neural
Network (RNN) to classify words. This tutorial, along with the other two
Natural Language Processing (NLP) "from scratch" tutorials
:doc:`/intermediate/char_rnn_generation_tutorial` and
:doc:`/intermediate/seq2seq_translation_tutorial`, shows how to
preprocess data for NLP modeling. In particular, these tutorials show how
preprocessing for NLP models works at a low level.

A character-level RNN reads words as a series of characters -
outputting a prediction and "hidden state" at each step, feeding its
previous hidden state into each next step. We take the final prediction
to be the output, i.e. which class the word belongs to.

Specifically, we'll train on a few thousand surnames from 18 languages
of origin, and predict which language a name is from based on the
spelling.

Recommended Preparation
=======================

Before starting this tutorial it is recommended that you have installed PyTorch,
and have a basic understanding of the Python programming language and Tensors:

- https://pytorch.org/ for installation instructions
- :doc:`/beginner/deep_learning_60min_blitz` to get started with PyTorch in general
  and learn the basics of Tensors
- :doc:`/beginner/pytorch_with_examples` for a wide and deep overview
- :doc:`/beginner/former_torchies_tutorial` if you are a former Lua Torch user

It would also be useful to know about RNNs and how they work:

- `The Unreasonable Effectiveness of Recurrent Neural
  Networks <https://karpathy.github.io/2015/05/21/rnn-effectiveness/>`__
  shows a bunch of real life examples
- `Understanding LSTM
  Networks <https://colah.github.io/posts/2015-08-Understanding-LSTMs/>`__
  is about LSTMs specifically but also informative about RNNs in
  general
"""
######################################################################
# Preparing Torch
# ==========================
#
# Set up torch to default to the right device, using GPU acceleration depending on your hardware (CPU or CUDA).
#

import torch

# Check if CUDA is available
device = torch.device('cpu')
if torch.cuda.is_available():
    device = torch.device('cuda')

torch.set_default_device(device)
print(f"Using device = {torch.get_default_device()}")
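######################################################################
# As a quick, optional sanity check, we can confirm that tensor factory
# functions such as ``torch.zeros`` now place new tensors on the default
# device selected above. This is only an illustrative probe; the exact
# device string printed depends on your hardware.

probe = torch.zeros(3)   # created on the current default device
print(f"probe tensor is on: {probe.device}")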
######################################################################
# Preparing the Data
# ==================
#
# Download the data from `here <https://download.pytorch.org/tutorial/data.zip>`__
# and extract it to the current directory.
#
# Included in the ``data/names`` directory are 18 text files named as
# ``[Language].txt``. Each file contains a bunch of names, one name per
# line, mostly romanized (but we still need to convert from Unicode to
# ASCII).
#
# The first step is to define and clean our data. Initially, we need to convert Unicode to plain ASCII to
# limit the RNN input layers. This is accomplished by converting Unicode strings to ASCII and keeping only a
# small set of allowed characters.

import string
import unicodedata

# We can use "_" to represent an out-of-vocabulary character, that is, any character we are not handling in our model
allowed_characters = string.ascii_letters + " .,;'" + "_"
n_letters = len(allowed_characters)

# Turn a Unicode string to plain ASCII, thanks to https://stackoverflow.com/a/518232/2809427
def unicodeToAscii(s):
    return ''.join(
        c for c in unicodedata.normalize('NFD', s)
        if unicodedata.category(c) != 'Mn'
        and c in allowed_characters
    )

#########################
# Here's an example of converting a Unicode alphabet name to plain ASCII. This simplifies the input layer.
#

print(f"converting 'Ślusàrski' to {unicodeToAscii('Ślusàrski')}")

######################################################################
# Turning Names into Tensors
# ==========================
#
# Now that we have all the names organized, we need to turn them into
# Tensors to make any use of them.
#
# To represent a single letter, we use a "one-hot vector" of size
# ``<1 x n_letters>``. A one-hot vector is filled with 0s except for a 1
# at the index of the current letter, e.g. ``"b" = <0 1 0 0 0 ...>``.
#
# To make a word we join a bunch of those into a 2D matrix
# ``<line_length x 1 x n_letters>``.
#
# That extra 1 dimension is because PyTorch assumes everything is in
# batches - we're just using a batch size of 1 here.

# Find letter index from allowed_characters, e.g. "a" = 0
def letterToIndex(letter):
    # return our out-of-vocabulary character if we encounter a letter unknown to our model
    if letter not in allowed_characters:
        return allowed_characters.find("_")
    else:
        return allowed_characters.find(letter)

# Turn a line into a <line_length x 1 x n_letters> tensor,
# or an array of one-hot letter vectors
def lineToTensor(line):
    tensor = torch.zeros(len(line), 1, n_letters)
    for li, letter in enumerate(line):
        tensor[li][0][letterToIndex(letter)] = 1
    return tensor

#########################
# Here are some examples of how to use ``lineToTensor()`` for a single-character and a multi-character string.

print(f"The letter 'a' becomes {lineToTensor('a')}") # notice that the first position in the tensor = 1
print(f"The name 'Ahn' becomes {lineToTensor('Ahn')}") # notice 'A' sets index 26 (the 27th position) to 1
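#########################
# A quick, optional shape check confirms the ``<line_length x 1 x n_letters>``
# layout described above: one one-hot row per character, a batch dimension of 1,
# and ``n_letters`` columns. This snippet is illustrative only.

name_tensor = lineToTensor('Ahn')
print(f"shape of the 'Ahn' tensor = {name_tensor.shape}")   # expected: torch.Size([3, 1, 58]), since n_letters is 58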
#########################
# Congratulations, you have built the foundational tensor objects for this learning task! You can use a similar approach
# for other RNN tasks with text.
#
# Next, we need to combine all our examples into a dataset so we can train, test and validate our models. For this,
# we will use the `Dataset and DataLoader <https://pytorch.org/tutorials/beginner/basics/data_tutorial.html>`__ classes
# to hold our dataset. Each Dataset needs to implement three functions: ``__init__``, ``__len__``, and ``__getitem__``.

from io import open
import glob
import os
import time

import torch
from torch.utils.data import Dataset

class NamesDataset(Dataset):

    def __init__(self, data_dir):
        self.data_dir = data_dir # for provenance of the dataset
        self.load_time = time.localtime() # for provenance of the dataset
        labels_set = set() # set of all classes

        self.data = []
        self.data_tensors = []
        self.labels = []
        self.labels_tensors = []

        # read all the ``.txt`` files in the specified directory
        text_files = glob.glob(os.path.join(data_dir, '*.txt'))
        for filename in text_files:
            label = os.path.splitext(os.path.basename(filename))[0]
            labels_set.add(label)
            lines = open(filename, encoding='utf-8').read().strip().split('\n')
            for name in lines:
                self.data.append(name)
                self.data_tensors.append(lineToTensor(name))
                self.labels.append(label)

        # Cache the tensor representation of the labels
        self.labels_uniq = list(labels_set)
        for idx in range(len(self.labels)):
            temp_tensor = torch.tensor([self.labels_uniq.index(self.labels[idx])], dtype=torch.long)
            self.labels_tensors.append(temp_tensor)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        data_item = self.data[idx]
        data_label = self.labels[idx]
        data_tensor = self.data_tensors[idx]
        label_tensor = self.labels_tensors[idx]

        return label_tensor, data_tensor, data_label, data_item


#########################
# Here we can load our example data into the ``NamesDataset``

alldata = NamesDataset("data/names")
print(f"loaded {len(alldata)} items of data")
print(f"example = {alldata[0]}")

#########################
# Using the dataset object allows us to easily split the data into train and test sets. Here we create an 85/15
# split, but ``torch.utils.data`` has more useful utilities. Here we specify a generator since we need to use the
# same device as PyTorch defaults to above.

train_set, test_set = torch.utils.data.random_split(alldata, [.85, .15], generator=torch.Generator(device=device).manual_seed(2024))

print(f"train examples = {len(train_set)}, validation examples = {len(test_set)}")

#########################
# Now we have a basic dataset containing **20074** examples where each example is a pairing of label and name. We have also
# split the dataset into training and testing so we can validate the model that we build.
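#########################
# As an optional aside, it can be useful to see how many names each language
# contributes, since the classes in this dataset are quite imbalanced. This is
# a minimal, illustrative sketch using ``collections.Counter`` over the labels
# we cached in the dataset above.

from collections import Counter

label_counts = Counter(alldata.labels)
for language, count in label_counts.most_common(5):
    print(f"{language}: {count} names")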
######################################################################
# Creating the Network
# ====================
#
# Before autograd, creating a recurrent neural network in Torch involved
# cloning the parameters of a layer over several timesteps. The layers
# held hidden state and gradients which are now entirely handled by the
# graph itself. This means you can implement an RNN in a very "pure" way,
# as regular feed-forward layers.
#
# This ``CharRNN`` class implements an RNN with three components.
# First, we use the `nn.RNN implementation <https://pytorch.org/docs/stable/generated/torch.nn.RNN.html>`__.
# Next, we define a layer that maps the RNN hidden layers to our output. And finally, we apply a ``softmax`` function. Using ``nn.RNN``
# leads to a significant improvement in performance, such as cuDNN-accelerated kernels, versus implementing
# each layer as a ``nn.Linear``. It also simplifies the implementation in ``forward()``.
#

import torch.nn as nn
import torch.nn.functional as F

class CharRNN(nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super(CharRNN, self).__init__()

        self.rnn = nn.RNN(input_size, hidden_size)
        self.h2o = nn.Linear(hidden_size, output_size)
        self.softmax = nn.LogSoftmax(dim=1)

    def forward(self, line_tensor):
        rnn_out, hidden = self.rnn(line_tensor)
        output = self.h2o(hidden[0])
        output = self.softmax(output)

        return output
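###########################
# For comparison only, here is a minimal sketch of the "pure" feed-forward
# formulation mentioned above, where the recurrence is written by hand with
# ``nn.Linear`` layers and an explicit hidden state. It is illustrative, is not
# used in the rest of this tutorial, and its class and layer names
# (``ManualCharRNN``, ``i2h``) are just placeholders; ``nn.RNN`` remains the
# recommended, faster implementation.

class ManualCharRNN(nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super(ManualCharRNN, self).__init__()
        self.hidden_size = hidden_size
        # one Linear layer combines the current input with the previous hidden state
        self.i2h = nn.Linear(input_size + hidden_size, hidden_size)
        self.h2o = nn.Linear(hidden_size, output_size)
        self.softmax = nn.LogSoftmax(dim=1)

    def forward(self, line_tensor):
        # start from a zero hidden state and step through the characters one at a time
        hidden = torch.zeros(1, self.hidden_size)
        for char_tensor in line_tensor:
            combined = torch.cat((char_tensor, hidden), dim=1)
            hidden = torch.tanh(self.i2h(combined))
        # classify from the final hidden state, like CharRNN above
        return self.softmax(self.h2o(hidden))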
###########################
# We can then create an RNN with 58 input nodes, 128 hidden nodes, and 18 outputs:

n_hidden = 128
rnn = CharRNN(n_letters, n_hidden, len(alldata.labels_uniq))
print(rnn)

######################################################################
# After that we can pass our Tensor to the RNN to obtain a predicted output. Subsequently,
# we use a helper function, ``label_from_output``, to derive a text label for the class.

def label_from_output(output, output_labels):
    top_n, top_i = output.topk(1)
    label_i = top_i[0].item()
    return output_labels[label_i], label_i

input = lineToTensor('Albert')
output = rnn(input) # this is equivalent to ``output = rnn.forward(input)``
print(output)
print(label_from_output(output, alldata.labels_uniq))

######################################################################
#
# Training
# ========


######################################################################
# Training the Network
# --------------------
#
# Now all it takes to train this network is to show it a bunch of examples,
# have it make guesses, and tell it if it's wrong.
#
# We do this by defining a ``train()`` function which trains the model on a given dataset using minibatches. RNNs
# are trained similarly to other networks; therefore, for completeness, we include a batched training method here.
# The loop (``for i in batch``) computes the losses for each of the items in the batch before adjusting the
# weights. This operation is repeated until the number of epochs is reached.

import random
import numpy as np

def train(rnn, training_data, n_epoch = 10, n_batch_size = 64, report_every = 50, learning_rate = 0.2, criterion = nn.NLLLoss()):
    """
    Learn on a batch of training_data for a specified number of iterations and reporting thresholds
    """
    # Keep track of losses for plotting
    current_loss = 0
    all_losses = []
    rnn.train()
    optimizer = torch.optim.SGD(rnn.parameters(), lr=learning_rate)

    start = time.time()
    print(f"training on data set with n = {len(training_data)}")

    for iter in range(1, n_epoch + 1):
        rnn.zero_grad() # clear the gradients

        # create some minibatches
        # we cannot use dataloaders because each of our names is a different length
        batches = list(range(len(training_data)))
        random.shuffle(batches)
        batches = np.array_split(batches, len(batches) // n_batch_size)

        for idx, batch in enumerate(batches):
            batch_loss = 0
            for i in batch: # for each example in this batch
                (label_tensor, text_tensor, label, text) = training_data[i]
                output = rnn.forward(text_tensor)
                loss = criterion(output, label_tensor)
                batch_loss += loss

            # optimize parameters
            batch_loss.backward()
            nn.utils.clip_grad_norm_(rnn.parameters(), 3)
            optimizer.step()
            optimizer.zero_grad()

            current_loss += batch_loss.item() / len(batch)

        all_losses.append(current_loss / len(batches))
        if iter % report_every == 0:
            print(f"{iter} ({iter / n_epoch:.0%}): \t average batch loss = {all_losses[-1]}")
        current_loss = 0

    return all_losses

##########################################################################
# We can now train a dataset with minibatches for a specified number of epochs. The number of epochs for this
# example is reduced to speed up the build. You can get better results with different parameters.

start = time.time()
all_losses = train(rnn, train_set, n_epoch=27, learning_rate=0.15, report_every=5)
end = time.time()
print(f"training took {end-start}s")

######################################################################
# Plotting the Results
# --------------------
#
# Plotting the historical loss from ``all_losses`` shows the network
# learning:
#

import matplotlib.pyplot as plt
import matplotlib.ticker as ticker

plt.figure()
plt.plot(all_losses)
plt.show()
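######################################################################
# In addition to the loss curve, a single overall accuracy number on the
# held-out set can be a handy summary. Here is a small, illustrative helper
# (``simple_accuracy`` is just a placeholder name) that reuses
# ``label_from_output``; the next section breaks the same evaluation down per
# class with a confusion matrix.

def simple_accuracy(rnn, testing_data, classes):
    correct = 0
    rnn.eval()
    with torch.no_grad():
        for i in range(len(testing_data)):
            (label_tensor, text_tensor, label, text) = testing_data[i]
            guess, guess_i = label_from_output(rnn(text_tensor), classes)
            if guess == label:
                correct += 1
    return correct / len(testing_data)

print(f"test accuracy = {simple_accuracy(rnn, test_set, alldata.labels_uniq):.1%}")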
######################################################################
# Evaluating the Results
# ======================
#
# To see how well the network performs on different categories, we will
# create a confusion matrix, indicating for every actual language (rows)
# which language the network guesses (columns). To calculate the confusion
# matrix a bunch of samples are run through the network with
# ``evaluate()``, which is the same as ``train()`` minus the backprop.
#

def evaluate(rnn, testing_data, classes):
    confusion = torch.zeros(len(classes), len(classes))

    rnn.eval() # set to eval mode
    with torch.no_grad(): # do not record the gradients during eval phase
        for i in range(len(testing_data)):
            (label_tensor, text_tensor, label, text) = testing_data[i]
            output = rnn(text_tensor)
            guess, guess_i = label_from_output(output, classes)
            label_i = classes.index(label)
            confusion[label_i][guess_i] += 1

    # Normalize by dividing every row by its sum
    for i in range(len(classes)):
        denom = confusion[i].sum()
        if denom > 0:
            confusion[i] = confusion[i] / denom

    # Set up plot
    fig = plt.figure()
    ax = fig.add_subplot(111)
    cax = ax.matshow(confusion.cpu().numpy()) # numpy uses the CPU here, so we need a CPU copy of the tensor
    fig.colorbar(cax)

    # Set up axes
    ax.set_xticks(np.arange(len(classes)), labels=classes, rotation=90)
    ax.set_yticks(np.arange(len(classes)), labels=classes)

    # Force label at every tick
    ax.xaxis.set_major_locator(ticker.MultipleLocator(1))
    ax.yaxis.set_major_locator(ticker.MultipleLocator(1))

    # sphinx_gallery_thumbnail_number = 2
    plt.show()


evaluate(rnn, test_set, classes=alldata.labels_uniq)


######################################################################
# You can pick out bright spots off the main axis that show which
# languages it guesses incorrectly, e.g. Chinese for Korean, and Spanish
# for Italian. It seems to do very well with Greek, and very poorly with
# English (perhaps because of overlap with other languages).
#


######################################################################
# Exercises
# =========
#
# - Get better results with a bigger and/or better shaped network
#
#    - Adjust the hyperparameters to enhance performance, such as changing the number of epochs, batch size, and learning rate
#    - Try the ``nn.LSTM`` and ``nn.GRU`` layers
#    - Modify the size of the layers, such as increasing or decreasing the number of hidden nodes or adding additional linear layers
#    - Combine multiple of these RNNs as a higher level network
#
# - Try with a different dataset of line -> label, for example:
#
#    - Any word -> language
#    - First name -> gender
#    - Character name -> writer
#    - Page title -> blog or subreddit
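######################################################################
# As a starting point for these exercises, here is a small, illustrative helper
# that runs the trained model on an arbitrary name and prints the top few
# guesses, reusing ``lineToTensor`` and ``labels_uniq`` from above. The
# ``predict`` name and the sample names are only placeholders.

def predict(rnn, text, classes, n_predictions=3):
    rnn.eval()
    with torch.no_grad():
        output = rnn(lineToTensor(text))
        # take the top few scores and their class indices
        top_values, top_indices = output.topk(n_predictions, dim=1)
        print(f"predictions for '{text}':")
        for value, idx in zip(top_values[0], top_indices[0]):
            # the outputs are log-probabilities because of ``nn.LogSoftmax``
            print(f"  {classes[idx.item()]} ({value.item():.2f})")

predict(rnn, 'Dovesky', alldata.labels_uniq)
predict(rnn, 'Jackson', alldata.labels_uniq)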