# -*- coding: utf-8 -*-
"""
NLP From Scratch: Generating Names with a Character-Level RNN
*************************************************************
**Author**: `Sean Robertson <https://github.com/spro>`_

This tutorial is part of a three-part series:

* `NLP From Scratch: Classifying Names with a Character-Level RNN <https://pytorch.org/tutorials/intermediate/char_rnn_classification_tutorial.html>`__
* `NLP From Scratch: Generating Names with a Character-Level RNN <https://pytorch.org/tutorials/intermediate/char_rnn_generation_tutorial.html>`__
* `NLP From Scratch: Translation with a Sequence to Sequence Network and Attention <https://pytorch.org/tutorials/intermediate/seq2seq_translation_tutorial.html>`__

This is our second of three tutorials on "NLP From Scratch".
In the `first tutorial </tutorials/intermediate/char_rnn_classification_tutorial>`_
we used an RNN to classify names into their language of origin. This time
we'll turn around and generate names from languages.

.. code-block:: sh

    > python sample.py Russian RUS
    Rovakov
    Uantov
    Shavakov

    > python sample.py German GER
    Gerren
    Ereng
    Rosher

    > python sample.py Spanish SPA
    Salla
    Parer
    Allan

    > python sample.py Chinese CHI
    Chan
    Hang
    Iun

We are still hand-crafting a small RNN with a few linear layers. The big
difference is that instead of predicting a category after reading in all the
letters of a name, we input a category and output one letter at a time.
Recurrently predicting characters to form language (this could also be
done with words or other higher order constructs) is often referred to
as a "language model".

**Recommended Reading:**

I assume you have at least installed PyTorch, know Python, and
understand Tensors:

-  https://pytorch.org/ for installation instructions
-  :doc:`/beginner/deep_learning_60min_blitz` to get started with PyTorch in general
-  :doc:`/beginner/pytorch_with_examples` for a wide and deep overview
-  :doc:`/beginner/former_torchies_tutorial` if you are a former Lua Torch user

It would also be useful to know about RNNs and how they work:

-  `The Unreasonable Effectiveness of Recurrent Neural
   Networks <https://karpathy.github.io/2015/05/21/rnn-effectiveness/>`__
   shows a bunch of real-life examples
-  `Understanding LSTM
   Networks <https://colah.github.io/posts/2015-08-Understanding-LSTMs/>`__
   is about LSTMs specifically but also informative about RNNs in
   general

I also suggest the previous tutorial, :doc:`/intermediate/char_rnn_classification_tutorial`.


Preparing the Data
==================

.. note::
   Download the data from
   `here <https://download.pytorch.org/tutorial/data.zip>`_
   and extract it to the current directory.

See the last tutorial for more details on this process. In short, there
are a bunch of plain text files ``data/names/[Language].txt`` with a
name per line. We split lines into an array, convert Unicode to ASCII,
and end up with a dictionary ``{language: [names ...]}``.

"""
from io import open
import glob
import os
import unicodedata
import string

all_letters = string.ascii_letters + " .,;'-"
n_letters = len(all_letters) + 1  # Plus EOS marker

def findFiles(path): return glob.glob(path)

# Turn a Unicode string to plain ASCII, thanks to https://stackoverflow.com/a/518232/2809427
def unicodeToAscii(s):
    return ''.join(
        c for c in unicodedata.normalize('NFD', s)
        if unicodedata.category(c) != 'Mn'
        and c in all_letters
    )

# Read a file and split into lines
def readLines(filename):
    with open(filename, encoding='utf-8') as some_file:
        return [unicodeToAscii(line.strip()) for line in some_file]

# Build the category_lines dictionary, a list of lines per category
category_lines = {}
all_categories = []
for filename in findFiles('data/names/*.txt'):
    category = os.path.splitext(os.path.basename(filename))[0]
    all_categories.append(category)
    lines = readLines(filename)
    category_lines[category] = lines

n_categories = len(all_categories)

if n_categories == 0:
    raise RuntimeError('Data not found. Make sure that you downloaded data '
        'from https://download.pytorch.org/tutorial/data.zip and extracted it to '
        'the current directory.')

print('# categories:', n_categories, all_categories)
print(unicodeToAscii("O'Néàl"))
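

######################################################################
# As a quick sanity check, we can also print a few of the loaded names for
# one language. This assumes the ``Italian.txt`` file from the data archive
# was extracted along with the others; any other key in ``category_lines``
# would work just as well:

print(category_lines['Italian'][:5])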


######################################################################
# Creating the Network
# ====================
#
# This network extends `the last tutorial's RNN <#Creating-the-Network>`__
# with an extra argument for the category tensor, which is concatenated
# along with the others. The category tensor is a one-hot vector just like
# the letter input.
#
# We will interpret the output as the probability of the next letter. When
# sampling, the most likely output letter is used as the next input
# letter.
#
# I added a second linear layer ``o2o`` (after combining hidden and
# output) to give it more muscle to work with. There's also a dropout
# layer, which `randomly zeros parts of its
# input <https://arxiv.org/abs/1207.0580>`__ with a given probability
# (here 0.1) and is usually used to fuzz inputs to prevent overfitting.
# Here we're using it towards the end of the network to purposely add some
# chaos and increase sampling variety.
#
# .. figure:: https://i.imgur.com/jzVrf7f.png
#    :alt:
#
#

import torch
import torch.nn as nn

class RNN(nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super(RNN, self).__init__()
        self.hidden_size = hidden_size

        self.i2h = nn.Linear(n_categories + input_size + hidden_size, hidden_size)
        self.i2o = nn.Linear(n_categories + input_size + hidden_size, output_size)
        self.o2o = nn.Linear(hidden_size + output_size, output_size)
        self.dropout = nn.Dropout(0.1)
        self.softmax = nn.LogSoftmax(dim=1)

    def forward(self, category, input, hidden):
        input_combined = torch.cat((category, input, hidden), 1)
        hidden = self.i2h(input_combined)
        output = self.i2o(input_combined)
        output_combined = torch.cat((hidden, output), 1)
        output = self.o2o(output_combined)
        output = self.dropout(output)
        output = self.softmax(output)
        return output, hidden

    def initHidden(self):
        return torch.zeros(1, self.hidden_size)
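

######################################################################
# As a quick shape check, we can push one dummy step through an instance of
# the network. The ``check_rnn`` name and the all-zero stand-in tensors below
# are just for illustration; the model actually used for training is created
# in the training section further down:

check_rnn = RNN(n_letters, 128, n_letters)
dummy_category = torch.zeros(1, n_categories)  # stand-in for a one-hot category
dummy_letter = torch.zeros(1, n_letters)       # stand-in for a one-hot letter
dummy_hidden = check_rnn.initHidden()
dummy_output, dummy_hidden = check_rnn(dummy_category, dummy_letter, dummy_hidden)
print(dummy_output.size())  # torch.Size([1, n_letters]) - log-probabilities over next letters
print(dummy_hidden.size())  # torch.Size([1, 128]) - the next hidden state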


######################################################################
# Training
# =========
# Preparing for Training
# ----------------------
#
# First of all, helper functions to get random pairs of (category, line):
#

import random

# Random item from a list
def randomChoice(l):
    return l[random.randint(0, len(l) - 1)]

# Get a random category and random line from that category
def randomTrainingPair():
    category = randomChoice(all_categories)
    line = randomChoice(category_lines[category])
    return category, line


######################################################################
# For each timestep (that is, for each letter in a training word) the
# inputs of the network will be
# ``(category, current letter, hidden state)`` and the outputs will be
# ``(next letter, next hidden state)``. So for each training set, we'll
# need the category, a set of input letters, and a set of output/target
# letters.
#
# Since we are predicting the next letter from the current letter for each
# timestep, the letter pairs are groups of consecutive letters from the
# line - e.g. for ``"ABCD<EOS>"`` we would create ("A", "B"), ("B", "C"),
# ("C", "D"), ("D", "EOS").
#
# .. figure:: https://i.imgur.com/JH58tXY.png
#    :alt:
#
# The category tensor is a `one-hot
# tensor <https://en.wikipedia.org/wiki/One-hot>`__ of size
# ``<1 x n_categories>``. When training we feed it to the network at every
# timestep. This is a design choice; it could have been included as part
# of the initial hidden state or some other strategy.
#

# One-hot vector for category
def categoryTensor(category):
    li = all_categories.index(category)
    tensor = torch.zeros(1, n_categories)
    tensor[0][li] = 1
    return tensor

# One-hot matrix of first to last letters (not including EOS) for input
def inputTensor(line):
    tensor = torch.zeros(len(line), 1, n_letters)
    for li in range(len(line)):
        letter = line[li]
        tensor[li][0][all_letters.find(letter)] = 1
    return tensor

# ``LongTensor`` of second letter to end (EOS) for target
def targetTensor(line):
    letter_indexes = [all_letters.find(line[li]) for li in range(1, len(line))]
    letter_indexes.append(n_letters - 1)  # EOS
    return torch.LongTensor(letter_indexes)
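

######################################################################
# To make the input/target pairing concrete, here is what these helpers
# produce for a short made-up name (``'Abe'`` is just an illustrative string,
# not necessarily in the dataset): the input matrix has one one-hot row per
# letter, and the target holds the same sequence shifted by one, ending with
# the EOS index ``n_letters - 1``:

example_line = 'Abe'
print(inputTensor(example_line).size())  # torch.Size([3, 1, n_letters])
print(targetTensor(example_line))        # indices of 'b', 'e', then n_letters - 1 (EOS)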


######################################################################
# For convenience during training we'll make a ``randomTrainingExample``
# function that fetches a random (category, line) pair and turns them into
# the required (category, input, target) tensors.
#

# Make category, input, and target tensors from a random category, line pair
def randomTrainingExample():
    category, line = randomTrainingPair()
    category_tensor = categoryTensor(category)
    input_line_tensor = inputTensor(line)
    target_line_tensor = targetTensor(line)
    return category_tensor, input_line_tensor, target_line_tensor


######################################################################
# Training the Network
# --------------------
#
# In contrast to classification, where only the last output is used, we
# are making a prediction at every step, so we are calculating loss at
# every step.
#
# The magic of autograd allows you to simply sum these losses at each step
# and call backward at the end.
#

criterion = nn.NLLLoss()

learning_rate = 0.0005

def train(category_tensor, input_line_tensor, target_line_tensor):
    target_line_tensor.unsqueeze_(-1)
    hidden = rnn.initHidden()

    rnn.zero_grad()

    loss = torch.Tensor([0])  # you can also just simply use ``loss = 0``

    for i in range(input_line_tensor.size(0)):
        output, hidden = rnn(category_tensor, input_line_tensor[i], hidden)
        l = criterion(output, target_line_tensor[i])
        loss += l

    loss.backward()

    for p in rnn.parameters():
        p.data.add_(p.grad.data, alpha=-learning_rate)

    return output, loss.item() / input_line_tensor.size(0)


######################################################################
# To keep track of how long training takes, I am adding a
# ``timeSince(timestamp)`` function which returns a human-readable string:
#

import time
import math

def timeSince(since):
    now = time.time()
    s = now - since
    m = math.floor(s / 60)
    s -= m * 60
    return '%dm %ds' % (m, s)


######################################################################
# Training is business as usual - call train a bunch of times and wait a
# few minutes, printing the current time and loss every ``print_every``
# examples, and storing an average loss per ``plot_every`` examples
# in ``all_losses`` for plotting later.
#

rnn = RNN(n_letters, 128, n_letters)

n_iters = 100000
print_every = 5000
plot_every = 500
all_losses = []
total_loss = 0  # Reset every ``plot_every`` ``iters``

start = time.time()

for iter in range(1, n_iters + 1):
    output, loss = train(*randomTrainingExample())
    total_loss += loss

    if iter % print_every == 0:
        print('%s (%d %d%%) %.4f' % (timeSince(start), iter, iter / n_iters * 100, loss))

    if iter % plot_every == 0:
        all_losses.append(total_loss / plot_every)
        total_loss = 0


######################################################################
# Plotting the Losses
# -------------------
#
# Plotting the historical loss from ``all_losses`` shows the network
# learning:
#

import matplotlib.pyplot as plt

plt.figure()
plt.plot(all_losses)


######################################################################
# Sampling the Network
# ====================
#
# To sample we give the network a letter and ask what the next one is,
# feed that in as the next letter, and repeat until the EOS token.
#
# -  Create tensors for input category, starting letter, and empty hidden
#    state
# -  Create a string ``output_name`` with the starting letter
# -  Up to a maximum output length,
#
#    -  Feed the current letter to the network
#    -  Get the next letter from highest output, and next hidden state
#    -  If the letter is EOS, stop here
#    -  If a regular letter, add to ``output_name`` and continue
#
# -  Return the final name
#
# .. note::
#    Rather than having to give it a starting letter, another
#    strategy would have been to include a "start of string" token in
#    training and have the network choose its own starting letter.
#

max_length = 20

# Sample from a category and starting letter
def sample(category, start_letter='A'):
    with torch.no_grad():  # no need to track history in sampling
        category_tensor = categoryTensor(category)
        input = inputTensor(start_letter)
        hidden = rnn.initHidden()

        output_name = start_letter

        for i in range(max_length):
            output, hidden = rnn(category_tensor, input[0], hidden)
            topv, topi = output.topk(1)
            topi = topi[0][0]
            if topi == n_letters - 1:
                break
            else:
                letter = all_letters[topi]
                output_name += letter
            input = inputTensor(letter)

        return output_name

# Get multiple samples from one category and multiple starting letters
def samples(category, start_letters='ABC'):
    for start_letter in start_letters:
        print(sample(category, start_letter))

samples('Russian', 'RUS')

samples('German', 'GER')

samples('Spanish', 'SPA')

samples('Chinese', 'CHI')
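

######################################################################
# The shell session at the top of this tutorial runs a standalone ``sample.py``
# script, which implies the trained model gets reused outside this file. One
# simple way to hand the weights over (the file name below is just an example)
# is to save the ``state_dict`` once training has finished:

torch.save(rnn.state_dict(), 'char_rnn_generation.pt')

# In a separate script (such as a ``sample.py``) you could then rebuild the
# model and load the saved weights before sampling:
#
#   rnn = RNN(n_letters, 128, n_letters)
#   rnn.load_state_dict(torch.load('char_rnn_generation.pt'))
#   rnn.eval()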


######################################################################
# Exercises
# =========
#
# -  Try with a different dataset of category -> line, for example:
#
#    -  Fictional series -> Character name
#    -  Part of speech -> Word
#    -  Country -> City
#
# -  Use a "start of sentence" token so that sampling can be done without
#    choosing a start letter
# -  Get better results with a bigger and/or better shaped network
#
#    -  Try the ``nn.LSTM`` and ``nn.GRU`` layers
#    -  Combine multiple of these RNNs as a higher level network
#
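

######################################################################
# For the ``nn.GRU`` exercise above, one possible starting point (a sketch;
# the class and attribute names here are chosen just for illustration) is to
# let ``nn.GRU`` handle the recurrence while still feeding the category
# vector in as an extra input feature at every step. The inputs and outputs
# match what ``train()`` and ``sample()`` expect, so it could be swapped in
# for the hand-rolled ``RNN``:

class GRUGenerator(nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super(GRUGenerator, self).__init__()
        self.hidden_size = hidden_size
        # The GRU consumes the concatenated (category, letter) vector at each step
        self.gru = nn.GRU(n_categories + input_size, hidden_size)
        self.h2o = nn.Linear(hidden_size, output_size)
        self.softmax = nn.LogSoftmax(dim=1)

    def forward(self, category, input, hidden):
        combined = torch.cat((category, input), 1).unsqueeze(0)  # (1, 1, features)
        output, hidden = self.gru(combined, hidden)
        output = self.softmax(self.h2o(output[0]))                # (1, output_size)
        return output, hidden

    def initHidden(self):
        # ``nn.GRU`` expects a (num_layers, batch, hidden_size) hidden state
        return torch.zeros(1, 1, self.hidden_size)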