# -*- coding: utf-8 -*-
"""
NLP From Scratch: Classifying Names with a Character-Level RNN
**************************************************************
**Author**: `Sean Robertson <https://github.com/spro>`_

This tutorial is part of a three-part series:

* `NLP From Scratch: Classifying Names with a Character-Level RNN <https://pytorch.org/tutorials/intermediate/char_rnn_classification_tutorial.html>`__
* `NLP From Scratch: Generating Names with a Character-Level RNN <https://pytorch.org/tutorials/intermediate/char_rnn_generation_tutorial.html>`__
* `NLP From Scratch: Translation with a Sequence to Sequence Network and Attention <https://pytorch.org/tutorials/intermediate/seq2seq_translation_tutorial.html>`__

We will be building and training a basic character-level Recurrent Neural
Network (RNN) to classify words. This tutorial, along with the two other
Natural Language Processing (NLP) "from scratch" tutorials
:doc:`/intermediate/char_rnn_generation_tutorial` and
:doc:`/intermediate/seq2seq_translation_tutorial`, shows how to
preprocess data for NLP modeling. In particular, these tutorials show how
NLP preprocessing works at a low level.

A character-level RNN reads words as a series of characters -
outputting a prediction and "hidden state" at each step, feeding its
previous hidden state into each next step. We take the final prediction
to be the output, i.e. which class the word belongs to.

Specifically, we'll train on a few thousand surnames from 18 languages
of origin, and predict which language a name is from based on the
spelling:

.. code-block:: sh

    $ python predict.py Hinton
    (-0.47) Scottish
    (-1.52) English
    (-3.57) Irish

    $ python predict.py Schmidhuber
    (-0.19) German
    (-2.48) Czech
    (-2.68) Dutch


Recommended Preparation
=======================

Before starting this tutorial it is recommended that you have PyTorch installed,
and have a basic understanding of the Python programming language and Tensors:

-  https://pytorch.org/ for installation instructions
-  :doc:`/beginner/deep_learning_60min_blitz` to get started with PyTorch in general
   and learn the basics of Tensors
-  :doc:`/beginner/pytorch_with_examples` for a wide and deep overview
-  :doc:`/beginner/former_torchies_tutorial` if you are a former Lua Torch user

It would also be useful to know about RNNs and how they work:

-  `The Unreasonable Effectiveness of Recurrent Neural
   Networks <https://karpathy.github.io/2015/05/21/rnn-effectiveness/>`__
   shows a bunch of real-life examples
-  `Understanding LSTM
   Networks <https://colah.github.io/posts/2015-08-Understanding-LSTMs/>`__
   is about LSTMs specifically but also informative about RNNs in
   general

Preparing the Data
==================

.. note::
   Download the data from
   `here <https://download.pytorch.org/tutorial/data.zip>`_
   and extract it to the current directory.

Included in the ``data/names`` directory are 18 text files named
``[Language].txt``. Each file contains a bunch of names, one name per
line, mostly romanized (but we still need to convert from Unicode to
ASCII).

We'll end up with a dictionary of lists of names per language,
``{language: [names ...]}``. The generic variables "category" and "line"
(for language and name in our case) are used for later extensibility.
"""
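
######################################################################
# The download step from the note above can also be scripted; here is a
# minimal sketch using only the standard library (this snippet is an
# addition for convenience, not part of the original tutorial):
#

import os
import zipfile
import urllib.request

if not os.path.exists('data/names'):
    # Fetch and unpack data.zip into the current directory
    urllib.request.urlretrieve('https://download.pytorch.org/tutorial/data.zip', 'data.zip')
    with zipfile.ZipFile('data.zip') as zf:
        zf.extractall('.')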

from io import open
import glob
import os

def findFiles(path): return glob.glob(path)

print(findFiles('data/names/*.txt'))

import unicodedata
import string

all_letters = string.ascii_letters + " .,;'"
n_letters = len(all_letters)

# Turn a Unicode string to plain ASCII, thanks to https://stackoverflow.com/a/518232/2809427
def unicodeToAscii(s):
    return ''.join(
        c for c in unicodedata.normalize('NFD', s)
        if unicodedata.category(c) != 'Mn'
        and c in all_letters
    )

print(unicodeToAscii('Ślusàrski'))

# Build the category_lines dictionary, a list of names per language
category_lines = {}
all_categories = []

# Read a file and split into lines
def readLines(filename):
    lines = open(filename, encoding='utf-8').read().strip().split('\n')
    return [unicodeToAscii(line) for line in lines]

for filename in findFiles('data/names/*.txt'):
    category = os.path.splitext(os.path.basename(filename))[0]
    all_categories.append(category)
    lines = readLines(filename)
    category_lines[category] = lines

n_categories = len(all_categories)


######################################################################
# Now we have ``category_lines``, a dictionary mapping each category
# (language) to a list of lines (names). We also kept track of
# ``all_categories`` (just a list of languages) and ``n_categories`` for
# later reference.
#

print(category_lines['Italian'][:5])


######################################################################
# Turning Names into Tensors
# --------------------------
#
# Now that we have all the names organized, we need to turn them into
# Tensors to make any use of them.
#
# To represent a single letter, we use a "one-hot vector" of size
# ``<1 x n_letters>``. A one-hot vector is filled with 0s except for a 1
# at the index of the current letter, e.g. ``"b" = <0 1 0 0 0 ...>``.
#
# To make a word we join a bunch of those into a 2D matrix
# ``<line_length x 1 x n_letters>``.
#
# That extra 1 dimension is because PyTorch assumes everything is in
# batches - we're just using a batch size of 1 here.
#

import torch

# Find letter index from all_letters, e.g. "a" = 0
def letterToIndex(letter):
    return all_letters.find(letter)

# Just for demonstration, turn a letter into a <1 x n_letters> Tensor
def letterToTensor(letter):
    tensor = torch.zeros(1, n_letters)
    tensor[0][letterToIndex(letter)] = 1
    return tensor

# Turn a line into a <line_length x 1 x n_letters>,
# or an array of one-hot letter vectors
def lineToTensor(line):
    tensor = torch.zeros(len(line), 1, n_letters)
    for li, letter in enumerate(line):
        tensor[li][0][letterToIndex(letter)] = 1
    return tensor

print(letterToTensor('J'))

print(lineToTensor('Jones').size())
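
######################################################################
# As a side note (not part of the original tutorial), the same one-hot
# encoding can also be produced with ``torch.nn.functional.one_hot``. The
# helper name ``lineToTensorOneHot`` below is purely illustrative:
#

import torch.nn.functional as F

def lineToTensorOneHot(line):
    # Look up each character's index in all_letters, then one-hot encode
    indices = torch.tensor([letterToIndex(letter) for letter in line], dtype=torch.long)
    # one_hot gives <line_length x n_letters>; unsqueeze adds the batch dimension
    return F.one_hot(indices, num_classes=n_letters).unsqueeze(1).float()

print(lineToTensorOneHot('Jones').size())  # same shape as lineToTensor('Jones')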

######################################################################
# Creating the Network
# ====================
#
# Before autograd, creating a recurrent neural network in Torch involved
# cloning the parameters of a layer over several timesteps. The layers
# held hidden state and gradients which are now entirely handled by the
# graph itself. This means you can implement an RNN in a very "pure" way,
# as regular feed-forward layers.
#
# This RNN module implements a "vanilla RNN" and is just 3 linear layers
# which operate on an input and hidden state, with a ``LogSoftmax`` layer
# after the output.
#

import torch.nn as nn
import torch.nn.functional as F

class RNN(nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super(RNN, self).__init__()

        self.hidden_size = hidden_size

        self.i2h = nn.Linear(input_size, hidden_size)
        self.h2h = nn.Linear(hidden_size, hidden_size)
        self.h2o = nn.Linear(hidden_size, output_size)
        self.softmax = nn.LogSoftmax(dim=1)

    def forward(self, input, hidden):
        hidden = F.tanh(self.i2h(input) + self.h2h(hidden))
        output = self.h2o(hidden)
        output = self.softmax(output)
        return output, hidden

    def initHidden(self):
        return torch.zeros(1, self.hidden_size)

n_hidden = 128
rnn = RNN(n_letters, n_hidden, n_categories)


######################################################################
# To run a step of this network we need to pass an input (in our case, the
# Tensor for the current letter) and a previous hidden state (which we
# initialize as zeros at first). We'll get back the output (probability of
# each language) and a next hidden state (which we keep for the next
# step).
#

input = letterToTensor('A')
hidden = torch.zeros(1, n_hidden)

output, next_hidden = rnn(input, hidden)


######################################################################
# For the sake of efficiency we don't want to be creating a new Tensor for
# every step, so we will use ``lineToTensor`` instead of
# ``letterToTensor`` and use slices. This could be further optimized by
# precomputing batches of Tensors.
#

input = lineToTensor('Albert')
hidden = torch.zeros(1, n_hidden)

output, next_hidden = rnn(input[0], hidden)
print(output)


######################################################################
# As you can see the output is a ``<1 x n_categories>`` Tensor, where
# every item is the likelihood of that category (higher is more likely).
#
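
######################################################################
# The "precomputing batches of Tensors" idea mentioned above can be sketched
# with ``torch.nn.utils.rnn.pad_sequence``; the helper ``linesToPaddedBatch``
# is hypothetical and not part of the original tutorial:
#

from torch.nn.utils.rnn import pad_sequence

def linesToPaddedBatch(lines):
    # Each name becomes <line_length x n_letters>; pad_sequence zero-pads them
    # to a common length and stacks them into <max_length x batch_size x n_letters>
    return pad_sequence([lineToTensor(line).squeeze(1) for line in lines])

print(linesToPaddedBatch(['Albert', 'Jones']).size())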

######################################################################
#
# Training
# ========
# Preparing for Training
# ----------------------
#
# Before going into training we should make a few helper functions. The
# first is to interpret the output of the network, which we know to be a
# likelihood of each category. We can use ``Tensor.topk`` to get the index
# of the greatest value:
#

def categoryFromOutput(output):
    top_n, top_i = output.topk(1)
    category_i = top_i[0].item()
    return all_categories[category_i], category_i

print(categoryFromOutput(output))


######################################################################
# We will also want a quick way to get a training example (a name and its
# language):
#

import random

def randomChoice(l):
    return l[random.randint(0, len(l) - 1)]

def randomTrainingExample():
    category = randomChoice(all_categories)
    line = randomChoice(category_lines[category])
    category_tensor = torch.tensor([all_categories.index(category)], dtype=torch.long)
    line_tensor = lineToTensor(line)
    return category, line, category_tensor, line_tensor

for i in range(10):
    category, line, category_tensor, line_tensor = randomTrainingExample()
    print('category =', category, '/ line =', line)


######################################################################
# Training the Network
# --------------------
#
# Now all it takes to train this network is to show it a bunch of examples,
# have it make guesses, and tell it if it's wrong.
#
# For the loss function ``nn.NLLLoss`` is appropriate, since the last
# layer of the RNN is ``nn.LogSoftmax``.
#

criterion = nn.NLLLoss()


######################################################################
# Each loop of training will:
#
# -  Create input and target tensors
# -  Create a zeroed initial hidden state
# -  Read each letter in and
#
#    -  Keep hidden state for next letter
#
# -  Compare final output to target
# -  Back-propagate
# -  Return the output and loss
#

learning_rate = 0.005 # If you set this too high, it might explode. If too low, it might not learn

def train(category_tensor, line_tensor):
    hidden = rnn.initHidden()

    rnn.zero_grad()

    for i in range(line_tensor.size()[0]):
        output, hidden = rnn(line_tensor[i], hidden)

    loss = criterion(output, category_tensor)
    loss.backward()

    # Add parameters' gradients to their values, multiplied by learning rate
    for p in rnn.parameters():
        p.data.add_(p.grad.data, alpha=-learning_rate)

    return output, loss.item()
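
######################################################################
# As an aside (not part of the original tutorial), the manual parameter update
# above is plain stochastic gradient descent; a hypothetical variant of
# ``train`` could delegate the update to ``torch.optim.SGD`` instead:
#

import torch.optim as optim

optimizer = optim.SGD(rnn.parameters(), lr=learning_rate)

def trainWithOptimizer(category_tensor, line_tensor):
    hidden = rnn.initHidden()
    optimizer.zero_grad()

    for i in range(line_tensor.size()[0]):
        output, hidden = rnn(line_tensor[i], hidden)

    loss = criterion(output, category_tensor)
    loss.backward()
    optimizer.step()  # applies p <- p - learning_rate * p.grad to every parameter

    return output, loss.item()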

######################################################################
# Now we just have to run that with a bunch of examples. Since the
# ``train`` function returns both the output and loss we can print its
# guesses and also keep track of loss for plotting. Since there are 1000s
# of examples we print only every ``print_every`` examples, and take an
# average of the loss.
#

import time
import math

n_iters = 100000
print_every = 5000
plot_every = 1000


# Keep track of losses for plotting
current_loss = 0
all_losses = []

def timeSince(since):
    now = time.time()
    s = now - since
    m = math.floor(s / 60)
    s -= m * 60
    return '%dm %ds' % (m, s)

start = time.time()

for iter in range(1, n_iters + 1):
    category, line, category_tensor, line_tensor = randomTrainingExample()
    output, loss = train(category_tensor, line_tensor)
    current_loss += loss

    # Print ``iter`` number, loss, name and guess
    if iter % print_every == 0:
        guess, guess_i = categoryFromOutput(output)
        correct = '✓' if guess == category else '✗ (%s)' % category
        print('%d %d%% (%s) %.4f %s / %s %s' % (iter, iter / n_iters * 100, timeSince(start), loss, line, guess, correct))

    # Add current loss avg to list of losses
    if iter % plot_every == 0:
        all_losses.append(current_loss / plot_every)
        current_loss = 0


######################################################################
# Plotting the Results
# --------------------
#
# Plotting the historical loss from ``all_losses`` shows the network
# learning:
#

import matplotlib.pyplot as plt
import matplotlib.ticker as ticker

plt.figure()
plt.plot(all_losses)


######################################################################
# Evaluating the Results
# ======================
#
# To see how well the network performs on different categories, we will
# create a confusion matrix, indicating for every actual language (rows)
# which language the network guesses (columns). To calculate the confusion
# matrix a bunch of samples are run through the network with
# ``evaluate()``, which is the same as ``train()`` minus the backprop.
#

# Keep track of correct guesses in a confusion matrix
confusion = torch.zeros(n_categories, n_categories)
n_confusion = 10000

# Just return an output given a line
def evaluate(line_tensor):
    hidden = rnn.initHidden()

    for i in range(line_tensor.size()[0]):
        output, hidden = rnn(line_tensor[i], hidden)

    return output

# Go through a bunch of examples and record which are correctly guessed
for i in range(n_confusion):
    category, line, category_tensor, line_tensor = randomTrainingExample()
    output = evaluate(line_tensor)
    guess, guess_i = categoryFromOutput(output)
    category_i = all_categories.index(category)
    confusion[category_i][guess_i] += 1

# Normalize by dividing every row by its sum
for i in range(n_categories):
    confusion[i] = confusion[i] / confusion[i].sum()

# Set up plot
fig = plt.figure()
ax = fig.add_subplot(111)
cax = ax.matshow(confusion.numpy())
fig.colorbar(cax)

# Set up axes
ax.set_xticklabels([''] + all_categories, rotation=90)
ax.set_yticklabels([''] + all_categories)

# Force label at every tick
ax.xaxis.set_major_locator(ticker.MultipleLocator(1))
ax.yaxis.set_major_locator(ticker.MultipleLocator(1))

# sphinx_gallery_thumbnail_number = 2
plt.show()
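
######################################################################
# A small extra check (not part of the original tutorial): after the row
# normalization above, the diagonal of ``confusion`` holds per-language
# recall, so its mean gives a rough macro-averaged accuracy.
#

print('macro-averaged accuracy: %.3f' % confusion.diag().mean().item())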

######################################################################
# You can pick out bright spots off the main axis that show which
# languages it guesses incorrectly, e.g. Chinese for Korean, and Spanish
# for Italian. It seems to do very well with Greek, and very poorly with
# English (perhaps because of overlap with other languages).
#


######################################################################
# Running on User Input
# ---------------------
#

def predict(input_line, n_predictions=3):
    print('\n> %s' % input_line)
    with torch.no_grad():
        output = evaluate(lineToTensor(input_line))

        # Get top N categories
        topv, topi = output.topk(n_predictions, 1, True)
        predictions = []

        for i in range(n_predictions):
            value = topv[0][i].item()
            category_index = topi[0][i].item()
            print('(%.2f) %s' % (value, all_categories[category_index]))
            predictions.append([value, all_categories[category_index]])

predict('Dovesky')
predict('Jackson')
predict('Satoshi')


######################################################################
# The final versions of the scripts `in the Practical PyTorch
# repo <https://github.com/spro/practical-pytorch/tree/master/char-rnn-classification>`__
# split the above code into a few files:
#
# -  ``data.py`` (loads files)
# -  ``model.py`` (defines the RNN)
# -  ``train.py`` (runs training)
# -  ``predict.py`` (runs ``predict()`` with command line arguments)
# -  ``server.py`` (serves predictions as a JSON API with ``bottle.py``)
#
# Run ``train.py`` to train and save the network.
#
# Run ``predict.py`` with a name to view predictions:
#
# .. code-block:: sh
#
#     $ python predict.py Hazaki
#     (-0.42) Japanese
#     (-1.39) Polish
#     (-3.51) Czech
#
# Run ``server.py`` and visit http://localhost:5533/Yourname to get JSON
# output of predictions.
#


######################################################################
# Exercises
# =========
#
# -  Try with a different dataset of line -> category, for example:
#
#    -  Any word -> language
#    -  First name -> gender
#    -  Character name -> writer
#    -  Page title -> blog or subreddit
#
# -  Get better results with a bigger and/or better shaped network
#
#    -  Add more linear layers
#    -  Try the ``nn.LSTM`` and ``nn.GRU`` layers (see the sketch after this list)
#    -  Combine multiple of these RNNs as a higher level network
#
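
######################################################################
# As a starting point for the ``nn.LSTM`` exercise above, here is a minimal
# sketch of a classifier built around ``nn.LSTM`` (the class name
# ``LSTMClassifier`` and its design are illustrative, not part of the
# original tutorial):
#

class LSTMClassifier(nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super(LSTMClassifier, self).__init__()
        self.lstm = nn.LSTM(input_size, hidden_size)
        self.h2o = nn.Linear(hidden_size, output_size)
        self.softmax = nn.LogSoftmax(dim=1)

    def forward(self, line_tensor):
        # line_tensor is <line_length x 1 x n_letters>; nn.LSTM consumes the
        # whole sequence at once and returns the final hidden state of shape
        # <1 x 1 x hidden_size>
        _, (hidden, _) = self.lstm(line_tensor)
        return self.softmax(self.h2o(hidden[0]))

lstm_classifier = LSTMClassifier(n_letters, n_hidden, n_categories)
print(lstm_classifier(lineToTensor('Albert')).size())  # <1 x n_categories>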