# -*- coding: utf-8 -*-
"""
NLP From Scratch: Classifying Names with a Character-Level RNN
**************************************************************
**Author**: `Sean Robertson <https://github.com/spro>`_

This tutorial is part of a three-part series:

* `NLP From Scratch: Classifying Names with a Character-Level RNN <https://pytorch.org/tutorials/intermediate/char_rnn_classification_tutorial.html>`__
* `NLP From Scratch: Generating Names with a Character-Level RNN <https://pytorch.org/tutorials/intermediate/char_rnn_generation_tutorial.html>`__
* `NLP From Scratch: Translation with a Sequence to Sequence Network and Attention <https://pytorch.org/tutorials/intermediate/seq2seq_translation_tutorial.html>`__

We will be building and training a basic character-level Recurrent Neural
Network (RNN) to classify words. This tutorial, along with the two other
Natural Language Processing (NLP) "from scratch" tutorials
:doc:`/intermediate/char_rnn_generation_tutorial` and
:doc:`/intermediate/seq2seq_translation_tutorial`, shows how to
preprocess data for NLP modeling. In particular, these tutorials show how
that preprocessing works at a low level.

A character-level RNN reads words as a series of characters -
outputting a prediction and "hidden state" at each step, feeding its
previous hidden state into each next step. We take the final prediction
to be the output, i.e. which class the word belongs to.

Specifically, we'll train on a few thousand surnames from 18 languages
of origin, and predict which language a name is from based on the
spelling.

Recommended Preparation
=======================

Before starting this tutorial it is recommended that you have installed PyTorch,
and have a basic understanding of the Python programming language and Tensors:

-  https://pytorch.org/ For installation instructions
-  :doc:`/beginner/deep_learning_60min_blitz` to get started with PyTorch in general
   and learn the basics of Tensors
-  :doc:`/beginner/pytorch_with_examples` for a wide and deep overview
-  :doc:`/beginner/former_torchies_tutorial` if you are a former Lua Torch user

It would also be useful to know about RNNs and how they work:

-  `The Unreasonable Effectiveness of Recurrent Neural
   Networks <https://karpathy.github.io/2015/05/21/rnn-effectiveness/>`__
   shows a bunch of real life examples
-  `Understanding LSTM
   Networks <https://colah.github.io/posts/2015-08-Understanding-LSTMs/>`__
   is about LSTMs specifically but also informative about RNNs in
   general
"""
######################################################################
# Preparing Torch
# ==========================
#
# Set up torch to default to the right device and use GPU acceleration depending on your hardware (CPU or CUDA).
#

import torch

# Check if CUDA is available
device = torch.device('cpu')
if torch.cuda.is_available():
    device = torch.device('cuda')

torch.set_default_device(device)
print(f"Using device = {torch.get_default_device()}")
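
######################################################################
# As a quick, optional sanity check (not part of the original tutorial flow):
# once ``torch.set_default_device(device)`` has been called, any new tensor is
# allocated on that device unless you say otherwise. The snippet below simply
# confirms this for the device selected above.

check = torch.zeros(2, 3)  # created on the default device set above
print(f"new tensors are created on: {check.device}")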

######################################################################
# Preparing the Data
# ==================
#
# Download the data from `here <https://download.pytorch.org/tutorial/data.zip>`__
# and extract it to the current directory.
#
# Included in the ``data/names`` directory are 18 text files named as
# ``[Language].txt``. Each file contains a bunch of names, one name per
# line, mostly romanized (but we still need to convert from Unicode to
# ASCII).
#
# The first step is to define and clean our data. Initially, we need to convert Unicode to plain ASCII to
# limit the RNN input layers. This is accomplished by converting Unicode strings to ASCII and keeping only a small set of allowed characters.

import string
import unicodedata

allowed_characters = string.ascii_letters + " .,;'"
n_letters = len(allowed_characters)

# Turn a Unicode string to plain ASCII, thanks to https://stackoverflow.com/a/518232/2809427
def unicodeToAscii(s):
    return ''.join(
        c for c in unicodedata.normalize('NFD', s)
        if unicodedata.category(c) != 'Mn'
        and c in allowed_characters
    )

#########################
# Here's an example of converting a Unicode alphabet name to plain ASCII. This simplifies the input layer.
#

print(f"converting 'Ślusàrski' to {unicodeToAscii('Ślusàrski')}")

######################################################################
# Turning Names into Tensors
# ==========================
#
# Now that we have all the names organized, we need to turn them into
# Tensors to make any use of them.
#
# To represent a single letter, we use a "one-hot vector" of size
# ``<1 x n_letters>``. A one-hot vector is filled with 0s except for a 1
# at the index of the current letter, e.g. ``"b" = <0 1 0 0 0 ...>``.
#
# To make a word we join a bunch of those into a tensor of size
# ``<line_length x 1 x n_letters>``.
#
# That extra 1 dimension is because PyTorch assumes everything is in
# batches - we're just using a batch size of 1 here.

# Find the letter index from allowed_characters, e.g. "a" = 0
def letterToIndex(letter):
    return allowed_characters.find(letter)

# Turn a line into a <line_length x 1 x n_letters> tensor,
# or an array of one-hot letter vectors
def lineToTensor(line):
    tensor = torch.zeros(len(line), 1, n_letters)
    for li, letter in enumerate(line):
        tensor[li][0][letterToIndex(letter)] = 1
    return tensor

#########################
# Here are some examples of how to use ``lineToTensor()`` for a single-character and a multi-character string.

print(f"The letter 'a' becomes {lineToTensor('a')}") # notice that the first position in the tensor = 1
print(f"The name 'Ahn' becomes {lineToTensor('Ahn')}") # notice 'A' sets the 27th index to 1
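
#########################
# As an aside (not part of the original tutorial), the same one-hot encoding can be
# produced with ``torch.nn.functional.one_hot``. The helper below,
# ``lineToTensorOneHot``, is a hypothetical equivalent you can use to cross-check
# ``lineToTensor()``; it assumes the ``letterToIndex`` and ``n_letters`` defined above.

import torch.nn.functional as F

def lineToTensorOneHot(line):
    # stack the letter indices, one-hot encode them, and add the batch dimension of 1
    indices = torch.tensor([letterToIndex(letter) for letter in line])
    return F.one_hot(indices, num_classes=n_letters).to(torch.float32).unsqueeze(1)

print(torch.equal(lineToTensorOneHot('Ahn'), lineToTensor('Ahn')))  # expected: True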

#########################
# Congratulations, you have built the foundational tensor objects for this learning task! You can use a similar approach
# for other RNN tasks with text.
#
# Next, we need to combine all our examples into a dataset so we can train, test and validate our models. For this,
# we will use the `Dataset and DataLoader <https://pytorch.org/tutorials/beginner/basics/data_tutorial.html>`__ classes
# to hold our dataset. Each Dataset needs to implement three functions: ``__init__``, ``__len__``, and ``__getitem__``.

from io import open
import glob
import os
import time

import torch
from torch.utils.data import Dataset

class NamesDataset(Dataset):

    def __init__(self, data_dir):
        self.data_dir = data_dir #for provenance of the dataset
        self.load_time = time.localtime() #for provenance of the dataset
        labels_set = set() #set of all classes

        self.data = []
        self.data_tensors = []
        self.labels = []
        self.labels_tensors = []

        #read all the ``.txt`` files in the specified directory
        text_files = glob.glob(os.path.join(data_dir, '*.txt'))
        for filename in text_files:
            label = os.path.splitext(os.path.basename(filename))[0]
            labels_set.add(label)
            lines = open(filename, encoding='utf-8').read().strip().split('\n')
            for name in lines:
                self.data.append(name)
                self.data_tensors.append(lineToTensor(name))
                self.labels.append(label)

        #Cache the tensor representation of the labels
        self.labels_uniq = list(labels_set)
        for idx in range(len(self.labels)):
            temp_tensor = torch.tensor([self.labels_uniq.index(self.labels[idx])], dtype=torch.long)
            self.labels_tensors.append(temp_tensor)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        data_item = self.data[idx]
        data_label = self.labels[idx]
        data_tensor = self.data_tensors[idx]
        label_tensor = self.labels_tensors[idx]

        return label_tensor, data_tensor, data_label, data_item


#########################
# Here we can load our example data into the ``NamesDataset``

alldata = NamesDataset("data/names")
print(f"loaded {len(alldata)} items of data")
print(f"example = {alldata[0]}")

#########################
# Using the dataset object allows us to easily split the data into train and test sets. Here we create an 85/15
# split, but ``torch.utils.data`` has more useful utilities. We specify a generator since we need to use the
# same device that PyTorch defaults to above.

train_set, test_set = torch.utils.data.random_split(alldata, [.85, .15], generator=torch.Generator(device=device).manual_seed(2024))

print(f"train examples = {len(train_set)}, validation examples = {len(test_set)}")

#########################
# Now we have a basic dataset containing **20074** examples where each example is a pairing of label and name. We have also
# split the dataset into training and testing so we can validate the model that we build.
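
#########################
# As an optional sanity check (not part of the original tutorial), you can count how
# many examples each language contributes. The index ``[2]`` below relies on the
# ordering returned by ``NamesDataset.__getitem__``: (label tensor, data tensor,
# label string, name string).

from collections import Counter

label_counts = Counter(alldata[i][2] for i in range(len(alldata)))
print(label_counts.most_common(5)) # the five languages with the most examples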

######################################################################
# Creating the Network
# ====================
#
# Before autograd, creating a recurrent neural network in Torch involved
# cloning the parameters of a layer over several timesteps. The layers
# held hidden state and gradients which are now entirely handled by the
# graph itself. This means you can implement an RNN in a very "pure" way,
# as regular feed-forward layers.
#
# This CharRNN class implements an RNN with three components.
# First, we use the `nn.RNN implementation <https://pytorch.org/docs/stable/generated/torch.nn.RNN.html>`__.
# Next, we define a layer that maps the RNN hidden layers to our output. And finally, we apply a ``softmax`` function.
# Using ``nn.RNN`` leads to a significant performance improvement (for example, from cuDNN-accelerated kernels) versus implementing
# each layer as an ``nn.Linear``. It also simplifies the implementation in ``forward()``.
#

import torch.nn as nn
import torch.nn.functional as F

class CharRNN(nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super(CharRNN, self).__init__()

        self.rnn = nn.RNN(input_size, hidden_size)
        self.h2o = nn.Linear(hidden_size, output_size)
        self.softmax = nn.LogSoftmax(dim=1)

    def forward(self, line_tensor):
        rnn_out, hidden = self.rnn(line_tensor)
        output = self.h2o(hidden[0])
        output = self.softmax(output)

        return output


###########################
# We can then create an RNN with 57 input nodes, 128 hidden nodes, and 18 outputs:

n_hidden = 128
rnn = CharRNN(n_letters, n_hidden, len(alldata.labels_uniq))
print(rnn)

######################################################################
# After that we can pass our Tensor to the RNN to obtain a predicted output. Subsequently,
# we use a helper function, ``label_from_output``, to derive a text label for the class.

def label_from_output(output, output_labels):
    top_n, top_i = output.topk(1)
    label_i = top_i[0].item()
    return output_labels[label_i], label_i

input = lineToTensor('Albert')
output = rnn(input) #this is equivalent to ``output = rnn.forward(input)``
print(output)
print(label_from_output(output, alldata.labels_uniq))
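
######################################################################
# Optional shape walkthrough (not part of the original tutorial): ``nn.RNN`` returns
# both the per-timestep outputs and the final hidden state, and ``CharRNN.forward``
# keeps only the latter via ``hidden[0]``. The prints below assume the ``rnn`` model
# and ``lineToTensor`` defined above, with a batch size of 1.

example_input = lineToTensor('Albert')    # shape: (6, 1, 57) - 6 letters, batch of 1
rnn_out, hidden = rnn.rnn(example_input)  # call the underlying nn.RNN module directly
print(rnn_out.shape)                      # torch.Size([6, 1, 128]) - one output per letter
print(hidden.shape)                       # torch.Size([1, 1, 128]) - final hidden state
print(rnn(example_input).shape)           # torch.Size([1, 18]) - log-probabilities per language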

######################################################################
#
# Training
# ========


######################################################################
# Training the Network
# --------------------
#
# Now all it takes to train this network is show it a bunch of examples,
# have it make guesses, and tell it if it's wrong.
#
# We do this by defining a ``train()`` function which trains the model on a given dataset using minibatches.
# RNNs are trained similarly to other networks; therefore, for completeness, we include a batched training method here.
# The loop (``for i in batch``) computes the losses for each of the items in the batch before adjusting the
# weights. This operation is repeated until the number of epochs is reached.

import random
import numpy as np

def train(rnn, training_data, n_epoch = 10, n_batch_size = 64, report_every = 50, learning_rate = 0.2, criterion = nn.NLLLoss()):
    """
    Learn on a batch of training_data for a specified number of iterations and reporting thresholds
    """
    # Keep track of losses for plotting
    current_loss = 0
    all_losses = []
    rnn.train()
    optimizer = torch.optim.SGD(rnn.parameters(), lr=learning_rate)

    start = time.time()
    print(f"training on data set with n = {len(training_data)}")

    for iter in range(1, n_epoch + 1):
        rnn.zero_grad() # clear the gradients

        # create some minibatches
        # we cannot use dataloaders because each of our names is a different length
        batches = list(range(len(training_data)))
        random.shuffle(batches)
        batches = np.array_split(batches, len(batches) // n_batch_size)

        for idx, batch in enumerate(batches):
            batch_loss = 0
            for i in batch: #for each example in this batch
                (label_tensor, text_tensor, label, text) = training_data[i]
                output = rnn.forward(text_tensor)
                loss = criterion(output, label_tensor)
                batch_loss += loss

            # optimize parameters
            batch_loss.backward()
            nn.utils.clip_grad_norm_(rnn.parameters(), 3)
            optimizer.step()
            optimizer.zero_grad()

            current_loss += batch_loss.item() / len(batch)

        all_losses.append(current_loss / len(batches))
        if iter % report_every == 0:
            print(f"{iter} ({iter / n_epoch:.0%}): \t average batch loss = {all_losses[-1]}")
        current_loss = 0

    return all_losses

##########################################################################
# We can now train a dataset with minibatches for a specified number of epochs. The number of epochs for this
# example is reduced to speed up the build. You can get better results with different parameters.

start = time.time()
all_losses = train(rnn, train_set, n_epoch=27, learning_rate=0.15, report_every=5)
end = time.time()
print(f"training took {end-start}s")

######################################################################
# Plotting the Results
# --------------------
#
# Plotting the historical loss from ``all_losses`` shows the network
# learning:
#

import matplotlib.pyplot as plt
import matplotlib.ticker as ticker

plt.figure()
plt.plot(all_losses)
plt.show()
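
######################################################################
# Before digging into per-class behavior in the next section, a single overall
# accuracy number on the held-out set is a quick check that training worked. This
# helper is not part of the original tutorial; it simply reuses
# ``label_from_output()`` and the same no-grad evaluation pattern used by
# ``evaluate()`` below.

def overall_accuracy(rnn, data, classes):
    rnn.eval() # switch to eval mode; a no-op for this model, but good practice
    correct = 0
    with torch.no_grad():
        for i in range(len(data)):
            (label_tensor, text_tensor, label, text) = data[i]
            guess, guess_i = label_from_output(rnn(text_tensor), classes)
            correct += int(guess == label)
    return correct / len(data)

print(f"overall test accuracy = {overall_accuracy(rnn, test_set, alldata.labels_uniq):.2%}")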

######################################################################
# Evaluating the Results
# ======================
#
# To see how well the network performs on different categories, we will
# create a confusion matrix, indicating for every actual language (rows)
# which language the network guesses (columns). To calculate the confusion
# matrix a bunch of samples are run through the network with
# ``evaluate()``, which is the same as ``train()`` minus the backprop.
#

def evaluate(rnn, testing_data, classes):
    confusion = torch.zeros(len(classes), len(classes))

    rnn.eval() #set to eval mode
    with torch.no_grad(): # do not record the gradients during eval phase
        for i in range(len(testing_data)):
            (label_tensor, text_tensor, label, text) = testing_data[i]
            output = rnn(text_tensor)
            guess, guess_i = label_from_output(output, classes)
            label_i = classes.index(label)
            confusion[label_i][guess_i] += 1

    # Normalize by dividing every row by its sum
    for i in range(len(classes)):
        denom = confusion[i].sum()
        if denom > 0:
            confusion[i] = confusion[i] / denom

    # Set up plot
    fig = plt.figure()
    ax = fig.add_subplot(111)
    cax = ax.matshow(confusion.cpu().numpy()) #numpy uses cpu here so we need to use a cpu version
    fig.colorbar(cax)

    # Set up axes
    ax.set_xticks(np.arange(len(classes)), labels=classes, rotation=90)
    ax.set_yticks(np.arange(len(classes)), labels=classes)

    # Force label at every tick
    ax.xaxis.set_major_locator(ticker.MultipleLocator(1))
    ax.yaxis.set_major_locator(ticker.MultipleLocator(1))

    # sphinx_gallery_thumbnail_number = 2
    plt.show()



evaluate(rnn, test_set, classes=alldata.labels_uniq)


######################################################################
# You can pick out bright spots off the main axis that show which
# languages it guesses incorrectly, e.g. Chinese for Korean, and Spanish
# for Italian. It seems to do very well with Greek, and very poorly with
# English (perhaps because of overlap with other languages).
#


######################################################################
# Exercises
# =========
#
# - Get better results with a bigger and/or better shaped network
#
#   - Adjust the hyperparameters to enhance performance, such as changing the number of epochs, batch size, and learning rate
#   - Try the ``nn.LSTM`` and ``nn.GRU`` layers (see the sketch at the end of this tutorial)
#   - Modify the size of the layers, such as increasing or decreasing the number of hidden nodes, or adding additional linear layers
#   - Combine multiple of these RNNs as a higher level network
#
# - Try with a different dataset of line -> label, for example:
#
#   - Any word -> language
#   - First name -> gender
#   - Character name -> writer
#   - Page title -> blog or subreddit
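
######################################################################
# As a starting point for the ``nn.LSTM`` exercise, here is a minimal sketch (not
# part of the original tutorial) of a drop-in ``CharLSTM`` variant. It keeps the
# same interface as ``CharRNN``, so it can be trained and evaluated with the same
# ``train()`` and ``evaluate()`` functions defined above. ``CharLSTM`` is a
# hypothetical name, and the hyperparameters shown are untuned.

class CharLSTM(nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super().__init__()
        self.lstm = nn.LSTM(input_size, hidden_size)
        self.h2o = nn.Linear(hidden_size, output_size)
        self.softmax = nn.LogSoftmax(dim=1)

    def forward(self, line_tensor):
        # nn.LSTM returns the per-step outputs and a (hidden, cell) tuple;
        # as with CharRNN, only the final hidden state feeds the output layer
        lstm_out, (hidden, cell) = self.lstm(line_tensor)
        return self.softmax(self.h2o(hidden[0]))

# Example usage (uncomment to run; training takes a while):
# lstm = CharLSTM(n_letters, n_hidden, len(alldata.labels_uniq))
# lstm_losses = train(lstm, train_set, n_epoch=27, learning_rate=0.15, report_every=5)
# evaluate(lstm, test_set, classes=alldata.labels_uniq)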