# -*- coding: utf-8 -*-
"""
NLP From Scratch: Translation with a Sequence to Sequence Network and Attention
*******************************************************************************
**Author**: `Sean Robertson <https://github.com/spro>`_

This is the third and final tutorial on doing "NLP From Scratch", where we
write our own classes and functions to preprocess the data to do our NLP
modeling tasks. We hope after you complete this tutorial that you'll proceed to
learn how `torchtext` can handle much of this preprocessing for you in the
three tutorials immediately following this one.

In this project we will be teaching a neural network to translate from
French to English.

.. code-block:: sh

    [KEY: > input, = target, < output]

    > il est en train de peindre un tableau .
    = he is painting a picture .
    < he is painting a picture .

    > pourquoi ne pas essayer ce vin delicieux ?
    = why not try that delicious wine ?
    < why not try that delicious wine ?

    > elle n est pas poete mais romanciere .
    = she is not a poet but a novelist .
    < she not not a poet but a novelist .

    > vous etes trop maigre .
    = you re too skinny .
    < you re all alone .

... to varying degrees of success.

This is made possible by the simple but powerful idea of the `sequence
to sequence network <https://arxiv.org/abs/1409.3215>`__, in which two
recurrent neural networks work together to transform one sequence to
another. An encoder network condenses an input sequence into a vector,
and a decoder network unfolds that vector into a new sequence.

.. figure:: /_static/img/seq-seq-images/seq2seq.png
   :alt:

To improve upon this model we'll use an `attention
mechanism <https://arxiv.org/abs/1409.0473>`__, which lets the decoder
learn to focus over a specific range of the input sequence.

**Recommended Reading:**

I assume you have at least installed PyTorch, know Python, and
understand Tensors:

-  https://pytorch.org/ For installation instructions
-  :doc:`/beginner/deep_learning_60min_blitz` to get started with PyTorch in general
-  :doc:`/beginner/pytorch_with_examples` for a wide and deep overview
-  :doc:`/beginner/former_torchies_tutorial` if you are a former Lua Torch user

It would also be useful to know about Sequence to Sequence networks and
how they work:

-  `Learning Phrase Representations using RNN Encoder-Decoder for
   Statistical Machine Translation <https://arxiv.org/abs/1406.1078>`__
-  `Sequence to Sequence Learning with Neural
   Networks <https://arxiv.org/abs/1409.3215>`__
-  `Neural Machine Translation by Jointly Learning to Align and
   Translate <https://arxiv.org/abs/1409.0473>`__
-  `A Neural Conversational Model <https://arxiv.org/abs/1506.05869>`__

You will also find the previous tutorials on
:doc:`/intermediate/char_rnn_classification_tutorial`
and :doc:`/intermediate/char_rnn_generation_tutorial`
helpful as those concepts are very similar to the Encoder and Decoder
models, respectively.

**Requirements**
"""
from __future__ import unicode_literals, print_function, division
from io import open
import unicodedata
import re
import random

import torch
import torch.nn as nn
from torch import optim
import torch.nn.functional as F

import numpy as np
from torch.utils.data import TensorDataset, DataLoader, RandomSampler

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

######################################################################
# Loading data files
# ==================
#
# The data for this project is a set of many thousands of English to
# French translation pairs.
#
# `This question on Open Data Stack
# Exchange <https://opendata.stackexchange.com/questions/3888/dataset-of-sentences-translated-into-many-languages>`__
# pointed me to the open translation site https://tatoeba.org/ which has
# downloads available at https://tatoeba.org/eng/downloads - and better
# yet, someone did the extra work of splitting language pairs into
# individual text files here: https://www.manythings.org/anki/
#
# The English to French pairs are too big to include in the repository, so
# download to ``data/eng-fra.txt`` before continuing. The file is a tab
# separated list of translation pairs:
#
# .. code-block:: sh
#
#    I am cold.    J'ai froid.
#
# .. note::
#    Download the data from
#    `here <https://download.pytorch.org/tutorial/data.zip>`_
#    and extract it to the current directory.

######################################################################
# Similar to the character encoding used in the character-level RNN
# tutorials, we will be representing each word in a language as a one-hot
# vector, or giant vector of zeros except for a single one (at the index
# of the word). Compared to the dozens of characters that might exist in a
# language, there are many many more words, so the encoding vector is much
# larger. We will however cheat a bit and trim the data to only use a few
# thousand words per language.
#
# .. figure:: /_static/img/seq-seq-images/word-encoding.png
#    :alt:
#
#
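
######################################################################
# As a concrete, added illustration (the rest of the tutorial only ever
# stores the integer indices, never the full vectors): a word with index 3
# in a toy 7-word vocabulary corresponds to this one-hot vector.

print(F.one_hot(torch.tensor(3), num_classes=7))  # tensor([0, 0, 0, 1, 0, 0, 0])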


######################################################################
# We'll need a unique index per word to use as the inputs and targets of
# the networks later. To keep track of all this we will use a helper class
# called ``Lang`` which has word → index (``word2index``) and index → word
# (``index2word``) dictionaries, as well as a count of each word
# ``word2count`` which can later be used to replace rare words.
#

SOS_token = 0
EOS_token = 1

class Lang:
    def __init__(self, name):
        self.name = name
        self.word2index = {}
        self.word2count = {}
        self.index2word = {0: "SOS", 1: "EOS"}
        self.n_words = 2  # Count SOS and EOS

    def addSentence(self, sentence):
        for word in sentence.split(' '):
            self.addWord(word)

    def addWord(self, word):
        if word not in self.word2index:
            self.word2index[word] = self.n_words
            self.word2count[word] = 1
            self.index2word[self.n_words] = word
            self.n_words += 1
        else:
            self.word2count[word] += 1
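

######################################################################
# A brief, added illustration of how ``Lang`` assigns indices: 0 and 1 are
# reserved for SOS and EOS, so new words start at index 2. The
# ``example_lang`` object below is a throwaway and is not used anywhere else.

example_lang = Lang("demo")
example_lang.addSentence("je suis content")
print(example_lang.word2index)  # {'je': 2, 'suis': 3, 'content': 4}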


######################################################################
# The files are all in Unicode; to simplify, we will turn Unicode
# characters to ASCII, make everything lowercase, and trim most
# punctuation.
#

# Turn a Unicode string to plain ASCII, thanks to
# https://stackoverflow.com/a/518232/2809427
def unicodeToAscii(s):
    return ''.join(
        c for c in unicodedata.normalize('NFD', s)
        if unicodedata.category(c) != 'Mn'
    )

# Lowercase, trim, and remove non-letter characters
def normalizeString(s):
    s = unicodeToAscii(s.lower().strip())
    s = re.sub(r"([.!?])", r" \1", s)
    s = re.sub(r"[^a-zA-Z!?]+", r" ", s)
    return s.strip()
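
# A quick, added sanity check: accents are stripped, everything is lowercased,
# and sentence-ending punctuation becomes its own token.
print(normalizeString("J'ai froid !"))  # j ai froid !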


######################################################################
# To read the data file we will split the file into lines, and then split
# lines into pairs. The files are all English → Other Language, so if we
# want to translate from Other Language → English I added the ``reverse``
# flag to reverse the pairs.
#

def readLangs(lang1, lang2, reverse=False):
    print("Reading lines...")

    # Read the file and split into lines
    lines = open('data/%s-%s.txt' % (lang1, lang2), encoding='utf-8').\
        read().strip().split('\n')

    # Split every line into pairs and normalize
    pairs = [[normalizeString(s) for s in l.split('\t')] for l in lines]

    # Reverse pairs, make Lang instances
    if reverse:
        pairs = [list(reversed(p)) for p in pairs]
        input_lang = Lang(lang2)
        output_lang = Lang(lang1)
    else:
        input_lang = Lang(lang1)
        output_lang = Lang(lang2)

    return input_lang, output_lang, pairs


######################################################################
# Since there are a *lot* of example sentences and we want to train
# something quickly, we'll trim the data set to only relatively short and
# simple sentences. Here the maximum length is 10 words (that includes
# ending punctuation) and we're filtering to sentences that translate to
# the form "I am" or "He is" etc. (accounting for apostrophes replaced
# earlier).
#

MAX_LENGTH = 10

eng_prefixes = (
    "i am ", "i m ",
    "he is", "he s ",
    "she is", "she s ",
    "you are", "you re ",
    "we are", "we re ",
    "they are", "they re "
)

def filterPair(p):
    return len(p[0].split(' ')) < MAX_LENGTH and \
        len(p[1].split(' ')) < MAX_LENGTH and \
        p[1].startswith(eng_prefixes)


def filterPairs(pairs):
    return [pair for pair in pairs if filterPair(pair)]
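
# Two added examples of the filter in action: the first pair is short and its
# English side starts with "i am ", so it is kept; the second is rejected
# because "it is" is not one of the prefixes above.
print(filterPair(["j ai froid .", "i am cold ."]))                          # True
print(filterPair(["il fait tres froid ici .", "it is very cold here ."]))   # False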


######################################################################
# The full process for preparing the data is:
#
# -  Read text file and split into lines, split lines into pairs
# -  Normalize text, filter by length and content
# -  Make word lists from sentences in pairs
#

def prepareData(lang1, lang2, reverse=False):
    input_lang, output_lang, pairs = readLangs(lang1, lang2, reverse)
    print("Read %s sentence pairs" % len(pairs))
    pairs = filterPairs(pairs)
    print("Trimmed to %s sentence pairs" % len(pairs))
    print("Counting words...")
    for pair in pairs:
        input_lang.addSentence(pair[0])
        output_lang.addSentence(pair[1])
    print("Counted words:")
    print(input_lang.name, input_lang.n_words)
    print(output_lang.name, output_lang.n_words)
    return input_lang, output_lang, pairs

input_lang, output_lang, pairs = prepareData('eng', 'fra', True)
print(random.choice(pairs))


######################################################################
# The Seq2Seq Model
# =================
#
# A Recurrent Neural Network, or RNN, is a network that operates on a
# sequence and uses its own output as input for subsequent steps.
#
# A `Sequence to Sequence network <https://arxiv.org/abs/1409.3215>`__, or
# seq2seq network, or `Encoder Decoder
# network <https://arxiv.org/pdf/1406.1078v3.pdf>`__, is a model
# consisting of two RNNs called the encoder and decoder. The encoder reads
# an input sequence and outputs a single vector, and the decoder reads
# that vector to produce an output sequence.
#
# .. figure:: /_static/img/seq-seq-images/seq2seq.png
#    :alt:
#
# Unlike sequence prediction with a single RNN, where every input
# corresponds to an output, the seq2seq model frees us from sequence
# length and order, which makes it ideal for translation between two
# languages.
#
# Consider the sentence ``Je ne suis pas le chat noir`` → ``I am not the
# black cat``. Most of the words in the input sentence have a direct
# translation in the output sentence, but are in slightly different
# orders, e.g. ``chat noir`` and ``black cat``. Because of the ``ne/pas``
# construction there is also one more word in the input sentence. It would
# be difficult to produce a correct translation directly from the sequence
# of input words.
#
# With a seq2seq model the encoder creates a single vector which, in the
# ideal case, encodes the "meaning" of the input sequence as a single
# point in some N dimensional space of sentences.
#


######################################################################
# The Encoder
# -----------
#
# The encoder of a seq2seq network is an RNN that outputs some value for
# every word from the input sentence. For every input word the encoder
# outputs a vector and a hidden state, and uses the hidden state for the
# next input word.
#
# .. figure:: /_static/img/seq-seq-images/encoder-network.png
#    :alt:
#
#

class EncoderRNN(nn.Module):
    def __init__(self, input_size, hidden_size, dropout_p=0.1):
        super(EncoderRNN, self).__init__()
        self.hidden_size = hidden_size

        self.embedding = nn.Embedding(input_size, hidden_size)
        self.gru = nn.GRU(hidden_size, hidden_size, batch_first=True)
        self.dropout = nn.Dropout(dropout_p)

    def forward(self, input):
        embedded = self.dropout(self.embedding(input))
        output, hidden = self.gru(embedded)
        return output, hidden
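
# A quick, added shape check (the ``_toy_*`` names are throwaways, not reused):
# for a batch of 32 index sequences of length 10, the encoder returns one
# output per time step plus a final hidden state for each sequence.
_toy_encoder = EncoderRNN(input_lang.n_words, 128).to(device)
_toy_outputs, _toy_hidden = _toy_encoder(torch.zeros(32, 10, dtype=torch.long, device=device))
print(_toy_outputs.shape, _toy_hidden.shape)  # torch.Size([32, 10, 128]) torch.Size([1, 32, 128])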

######################################################################
# The Decoder
# -----------
#
# The decoder is another RNN that takes the encoder output vector(s) and
# outputs a sequence of words to create the translation.
#


######################################################################
# Simple Decoder
# ^^^^^^^^^^^^^^
#
# In the simplest seq2seq decoder we use only the last output of the
# encoder. This last output is sometimes called the *context vector* as it
# encodes context from the entire sequence. This context vector is used as
# the initial hidden state of the decoder.
#
# At every step of decoding, the decoder is given an input token and
# hidden state. The initial input token is the start-of-string ``<SOS>``
# token, and the first hidden state is the context vector (the encoder's
# last hidden state).
#
# .. figure:: /_static/img/seq-seq-images/decoder-network.png
#    :alt:
#
#

class DecoderRNN(nn.Module):
    def __init__(self, hidden_size, output_size):
        super(DecoderRNN, self).__init__()
        self.embedding = nn.Embedding(output_size, hidden_size)
        self.gru = nn.GRU(hidden_size, hidden_size, batch_first=True)
        self.out = nn.Linear(hidden_size, output_size)

    def forward(self, encoder_outputs, encoder_hidden, target_tensor=None):
        batch_size = encoder_outputs.size(0)
        decoder_input = torch.empty(batch_size, 1, dtype=torch.long, device=device).fill_(SOS_token)
        decoder_hidden = encoder_hidden
        decoder_outputs = []

        for i in range(MAX_LENGTH):
            decoder_output, decoder_hidden = self.forward_step(decoder_input, decoder_hidden)
            decoder_outputs.append(decoder_output)

            if target_tensor is not None:
                # Teacher forcing: Feed the target as the next input
                decoder_input = target_tensor[:, i].unsqueeze(1)  # Teacher forcing
            else:
                # Without teacher forcing: use its own predictions as the next input
                _, topi = decoder_output.topk(1)
                decoder_input = topi.squeeze(-1).detach()  # detach from history as input

        decoder_outputs = torch.cat(decoder_outputs, dim=1)
        decoder_outputs = F.log_softmax(decoder_outputs, dim=-1)
        return decoder_outputs, decoder_hidden, None  # We return `None` for consistency in the training loop

    def forward_step(self, input, hidden):
        output = self.embedding(input)
        output = F.relu(output)
        output, hidden = self.gru(output, hidden)
        output = self.out(output)
        return output, hidden

######################################################################
# I encourage you to train and observe the results of this model, but to
# save space we'll be going straight for the gold and introducing the
# Attention Mechanism.
#


######################################################################
# Attention Decoder
# ^^^^^^^^^^^^^^^^^
#
# If only the context vector is passed between the encoder and decoder,
# that single vector carries the burden of encoding the entire sentence.
#
# Attention allows the decoder network to "focus" on a different part of
# the encoder's outputs for every step of the decoder's own outputs. First
# we calculate a set of *attention weights*. These will be multiplied by
# the encoder output vectors to create a weighted combination. The result
# (called ``context`` in the code below) should contain information about
# that specific part of the input sequence, and thus help the decoder
# choose the right output words.
#
# .. figure:: https://i.imgur.com/1152PYf.png
#    :alt:
#
# Calculating the attention weights is done with a small feed-forward
# network (the ``BahdanauAttention`` module below), using the decoder's
# hidden state and the encoder's outputs as inputs. Because there are
# sentences of all sizes in the training data, we pick a maximum sentence
# length (``MAX_LENGTH``, which caps the encoder outputs) to pad and decode
# to. Sentences of the maximum length will use all the attention weights,
# while shorter sentences will only use the first few.
#
# .. figure:: /_static/img/seq-seq-images/attention-decoder-network.png
#    :alt:
#
#
# Bahdanau attention, also known as additive attention, is a commonly used
# attention mechanism in sequence-to-sequence models, particularly in neural
# machine translation tasks. It was introduced by Bahdanau et al. in their
# paper titled `Neural Machine Translation by Jointly Learning to Align and Translate <https://arxiv.org/pdf/1409.0473.pdf>`__.
# This attention mechanism employs a learned alignment model to compute attention
# scores between the encoder and decoder hidden states. It utilizes a feed-forward
# neural network to calculate alignment scores.
#
# However, there are alternative attention mechanisms available, such as Luong attention,
# which computes attention scores by taking the dot product between the decoder hidden
# state and the encoder hidden states. It does not involve the non-linear transformation
# used in Bahdanau attention.
#
# In this tutorial, we will be using Bahdanau attention. However, it would be a valuable
# exercise to explore modifying the attention mechanism to use Luong attention
# (a sketch follows the ``BahdanauAttention`` implementation below).

class BahdanauAttention(nn.Module):
    def __init__(self, hidden_size):
        super(BahdanauAttention, self).__init__()
        self.Wa = nn.Linear(hidden_size, hidden_size)
        self.Ua = nn.Linear(hidden_size, hidden_size)
        self.Va = nn.Linear(hidden_size, 1)

    def forward(self, query, keys):
        scores = self.Va(torch.tanh(self.Wa(query) + self.Ua(keys)))
        scores = scores.squeeze(2).unsqueeze(1)

        weights = F.softmax(scores, dim=-1)
        context = torch.bmm(weights, keys)

        return context, weights
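

######################################################################
# The Luong-style alternative mentioned above replaces the learned alignment
# layers with a plain dot product. The class below is an added sketch of that
# exercise (``LuongDotAttention`` is our name, and it is not used anywhere else
# in this tutorial); it keeps the same input and output shapes as
# ``BahdanauAttention``, so it could be swapped into the attention decoder
# defined next.

class LuongDotAttention(nn.Module):
    def __init__(self):
        super(LuongDotAttention, self).__init__()

    def forward(self, query, keys):
        # query: (batch, 1, hidden) decoder state; keys: (batch, seq_len, hidden) encoder outputs
        scores = torch.bmm(query, keys.transpose(1, 2))  # (batch, 1, seq_len)
        weights = F.softmax(scores, dim=-1)
        context = torch.bmm(weights, keys)               # (batch, 1, hidden)
        return context, weights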

class AttnDecoderRNN(nn.Module):
    def __init__(self, hidden_size, output_size, dropout_p=0.1):
        super(AttnDecoderRNN, self).__init__()
        self.embedding = nn.Embedding(output_size, hidden_size)
        self.attention = BahdanauAttention(hidden_size)
        self.gru = nn.GRU(2 * hidden_size, hidden_size, batch_first=True)
        self.out = nn.Linear(hidden_size, output_size)
        self.dropout = nn.Dropout(dropout_p)

    def forward(self, encoder_outputs, encoder_hidden, target_tensor=None):
        batch_size = encoder_outputs.size(0)
        decoder_input = torch.empty(batch_size, 1, dtype=torch.long, device=device).fill_(SOS_token)
        decoder_hidden = encoder_hidden
        decoder_outputs = []
        attentions = []

        for i in range(MAX_LENGTH):
            decoder_output, decoder_hidden, attn_weights = self.forward_step(
                decoder_input, decoder_hidden, encoder_outputs
            )
            decoder_outputs.append(decoder_output)
            attentions.append(attn_weights)

            if target_tensor is not None:
                # Teacher forcing: Feed the target as the next input
                decoder_input = target_tensor[:, i].unsqueeze(1)  # Teacher forcing
            else:
                # Without teacher forcing: use its own predictions as the next input
                _, topi = decoder_output.topk(1)
                decoder_input = topi.squeeze(-1).detach()  # detach from history as input

        decoder_outputs = torch.cat(decoder_outputs, dim=1)
        decoder_outputs = F.log_softmax(decoder_outputs, dim=-1)
        attentions = torch.cat(attentions, dim=1)

        return decoder_outputs, decoder_hidden, attentions

    def forward_step(self, input, hidden, encoder_outputs):
        embedded = self.dropout(self.embedding(input))

        query = hidden.permute(1, 0, 2)
        context, attn_weights = self.attention(query, encoder_outputs)
        input_gru = torch.cat((embedded, context), dim=2)

        output, hidden = self.gru(input_gru, hidden)
        output = self.out(output)

        return output, hidden, attn_weights


######################################################################
# .. note:: There are other forms of attention that work around the length
#    limitation by using a relative position approach. Read about "local
#    attention" in `Effective Approaches to Attention-based Neural Machine
#    Translation <https://arxiv.org/abs/1508.04025>`__.
#
# Training
# ========
#
# Preparing Training Data
# -----------------------
#
# To train, for each pair we will need an input tensor (indexes of the
# words in the input sentence) and target tensor (indexes of the words in
# the target sentence). While creating these vectors we will append the
# EOS token to both sequences.
#

def indexesFromSentence(lang, sentence):
    return [lang.word2index[word] for word in sentence.split(' ')]

def tensorFromSentence(lang, sentence):
    indexes = indexesFromSentence(lang, sentence)
    indexes.append(EOS_token)
    return torch.tensor(indexes, dtype=torch.long, device=device).view(1, -1)

def tensorsFromPair(pair):
    input_tensor = tensorFromSentence(input_lang, pair[0])
    target_tensor = tensorFromSentence(output_lang, pair[1])
    return (input_tensor, target_tensor)
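
# A quick, added look at one training example: both tensors are single rows of
# word indices ending in the EOS token (index 1).
print(tensorsFromPair(random.choice(pairs)))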

def get_dataloader(batch_size):
    input_lang, output_lang, pairs = prepareData('eng', 'fra', True)

    n = len(pairs)
    input_ids = np.zeros((n, MAX_LENGTH), dtype=np.int32)
    target_ids = np.zeros((n, MAX_LENGTH), dtype=np.int32)

    for idx, (inp, tgt) in enumerate(pairs):
        inp_ids = indexesFromSentence(input_lang, inp)
        tgt_ids = indexesFromSentence(output_lang, tgt)
        inp_ids.append(EOS_token)
        tgt_ids.append(EOS_token)
        input_ids[idx, :len(inp_ids)] = inp_ids
        target_ids[idx, :len(tgt_ids)] = tgt_ids

    train_data = TensorDataset(torch.LongTensor(input_ids).to(device),
                               torch.LongTensor(target_ids).to(device))

    train_sampler = RandomSampler(train_data)
    train_dataloader = DataLoader(train_data, sampler=train_sampler, batch_size=batch_size)
    return input_lang, output_lang, train_dataloader


######################################################################
# Training the Model
# ------------------
#
# To train we run the input sentence through the encoder, and keep track
# of every output and the latest hidden state. Then the decoder is given
# the ``<SOS>`` token as its first input, and the last hidden state of the
# encoder as its first hidden state.
#
# "Teacher forcing" is the concept of using the real target outputs as
# each next input, instead of using the decoder's guess as the next input.
# Using teacher forcing causes it to converge faster but `when the trained
# network is exploited, it may exhibit
# instability <http://citeseerx.ist.psu.edu/viewdoc/download?doi=10.1.1.378.4095&rep=rep1&type=pdf>`__.
#
# You can observe outputs of teacher-forced networks that read with
# coherent grammar but wander far from the correct translation -
# intuitively they have learned to represent the output grammar and can
# "pick up" the meaning once the teacher tells them the first few words,
# but they have not properly learned how to create the sentence from the
# translation in the first place.
#
# Because of the freedom PyTorch's autograd gives us, we can randomly
# choose to use teacher forcing or not with a simple if statement. In the
# decoders above, teacher forcing is applied whenever a ``target_tensor``
# is passed in, so ``train_epoch`` below always uses it; a sketch that
# randomly mixes in free running follows the function.
#

def train_epoch(dataloader, encoder, decoder, encoder_optimizer,
                decoder_optimizer, criterion):

    total_loss = 0
    for data in dataloader:
        input_tensor, target_tensor = data

        encoder_optimizer.zero_grad()
        decoder_optimizer.zero_grad()

        encoder_outputs, encoder_hidden = encoder(input_tensor)
        decoder_outputs, _, _ = decoder(encoder_outputs, encoder_hidden, target_tensor)

        loss = criterion(
            decoder_outputs.view(-1, decoder_outputs.size(-1)),
            target_tensor.view(-1)
        )
        loss.backward()

        encoder_optimizer.step()
        decoder_optimizer.step()

        total_loss += loss.item()

    return total_loss / len(dataloader)
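

######################################################################
# The sketch promised above: a variant of ``train_epoch`` that randomly turns
# teacher forcing off for some batches. This is an added illustration and is
# never called in the rest of the tutorial; ``train_epoch_with_ratio`` and
# ``teacher_forcing_ratio`` are names introduced only here.

def train_epoch_with_ratio(dataloader, encoder, decoder, encoder_optimizer,
                           decoder_optimizer, criterion, teacher_forcing_ratio=0.5):
    total_loss = 0
    for input_tensor, target_tensor in dataloader:
        encoder_optimizer.zero_grad()
        decoder_optimizer.zero_grad()

        encoder_outputs, encoder_hidden = encoder(input_tensor)

        # Pass the targets only some of the time; the decoders above fall back
        # to feeding their own predictions when ``target_tensor`` is None.
        use_teacher_forcing = random.random() < teacher_forcing_ratio
        decoder_outputs, _, _ = decoder(
            encoder_outputs, encoder_hidden,
            target_tensor if use_teacher_forcing else None)

        loss = criterion(decoder_outputs.view(-1, decoder_outputs.size(-1)),
                         target_tensor.view(-1))
        loss.backward()

        encoder_optimizer.step()
        decoder_optimizer.step()
        total_loss += loss.item()

    return total_loss / len(dataloader)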


######################################################################
# This is a helper function to print time elapsed and estimated time
# remaining given the current time and progress %.
#

import time
import math

def asMinutes(s):
    m = math.floor(s / 60)
    s -= m * 60
    return '%dm %ds' % (m, s)

def timeSince(since, percent):
    now = time.time()
    s = now - since
    es = s / (percent)
    rs = es - s
    return '%s (- %s)' % (asMinutes(s), asMinutes(rs))
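
# An added example: 90 seconds into a job that is 25% done prints roughly
# "1m 30s (- 4m 30s)", i.e. elapsed time, then estimated time remaining.
print(timeSince(time.time() - 90, 0.25))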


######################################################################
# The whole training process looks like this:
#
# -  Start a timer
# -  Initialize optimizers and criterion
# -  Create set of training pairs
# -  Start empty losses array for plotting
#
# Then we call ``train_epoch`` once per epoch and occasionally print the
# progress (% of epochs, time so far, estimated time remaining) and the
# average loss.
#

def train(train_dataloader, encoder, decoder, n_epochs, learning_rate=0.001,
          print_every=100, plot_every=100):
    start = time.time()
    plot_losses = []
    print_loss_total = 0  # Reset every print_every
    plot_loss_total = 0  # Reset every plot_every

    encoder_optimizer = optim.Adam(encoder.parameters(), lr=learning_rate)
    decoder_optimizer = optim.Adam(decoder.parameters(), lr=learning_rate)
    criterion = nn.NLLLoss()

    for epoch in range(1, n_epochs + 1):
        loss = train_epoch(train_dataloader, encoder, decoder, encoder_optimizer, decoder_optimizer, criterion)
        print_loss_total += loss
        plot_loss_total += loss

        if epoch % print_every == 0:
            print_loss_avg = print_loss_total / print_every
            print_loss_total = 0
            print('%s (%d %d%%) %.4f' % (timeSince(start, epoch / n_epochs),
                                         epoch, epoch / n_epochs * 100, print_loss_avg))

        if epoch % plot_every == 0:
            plot_loss_avg = plot_loss_total / plot_every
            plot_losses.append(plot_loss_avg)
            plot_loss_total = 0

    showPlot(plot_losses)

######################################################################
# Plotting results
# ----------------
#
# Plotting is done with matplotlib, using the array of loss values
# ``plot_losses`` saved while training.
#

import matplotlib.pyplot as plt
plt.switch_backend('agg')
import matplotlib.ticker as ticker
import numpy as np

def showPlot(points):
    plt.figure()
    fig, ax = plt.subplots()
    # this locator puts ticks at regular intervals
    loc = ticker.MultipleLocator(base=0.2)
    ax.yaxis.set_major_locator(loc)
    plt.plot(points)


######################################################################
# Evaluation
# ==========
#
# Evaluation is mostly the same as training, but there are no targets so
# we simply feed the decoder's predictions back to itself for each step.
# Every time it predicts a word we add it to the output string, and if it
# predicts the EOS token we stop there. We also store the decoder's
# attention outputs for display later.
#

def evaluate(encoder, decoder, sentence, input_lang, output_lang):
    with torch.no_grad():
        input_tensor = tensorFromSentence(input_lang, sentence)

        encoder_outputs, encoder_hidden = encoder(input_tensor)
        decoder_outputs, decoder_hidden, decoder_attn = decoder(encoder_outputs, encoder_hidden)

        _, topi = decoder_outputs.topk(1)
        decoded_ids = topi.squeeze()

        decoded_words = []
        for idx in decoded_ids:
            if idx.item() == EOS_token:
                decoded_words.append('<EOS>')
                break
            decoded_words.append(output_lang.index2word[idx.item()])
    return decoded_words, decoder_attn


######################################################################
# We can evaluate random sentences from the training set and print out the
# input, target, and output to make some subjective quality judgements:
#

def evaluateRandomly(encoder, decoder, n=10):
    for i in range(n):
        pair = random.choice(pairs)
        print('>', pair[0])
        print('=', pair[1])
        output_words, _ = evaluate(encoder, decoder, pair[0], input_lang, output_lang)
        output_sentence = ' '.join(output_words)
        print('<', output_sentence)
        print('')


######################################################################
# Training and Evaluating
# =======================
#
# With all these helper functions in place (it looks like extra work, but
# it makes it easier to run multiple experiments) we can actually
# initialize a network and start training.
#
# Remember that the input sentences were heavily filtered. For this small
# dataset we can use relatively small networks of 128 hidden nodes and a
# single GRU layer. After about 40 minutes on a MacBook CPU we'll get some
# reasonable results.
#
# .. note::
#    If you run this notebook you can train, interrupt the kernel,
#    evaluate, and continue training later. Comment out the lines where the
#    encoder and decoder are initialized and run ``train`` again.
#

hidden_size = 128
batch_size = 32

input_lang, output_lang, train_dataloader = get_dataloader(batch_size)

encoder = EncoderRNN(input_lang.n_words, hidden_size).to(device)
decoder = AttnDecoderRNN(hidden_size, output_lang.n_words).to(device)

train(train_dataloader, encoder, decoder, 80, print_every=5, plot_every=5)
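
# An optional, added step for the interrupt-and-resume workflow described in
# the note above: checkpoint the weights with the standard ``state_dict`` API
# (the file names here are arbitrary). Uncomment to use.
#
# torch.save(encoder.state_dict(), 'encoder.pt')
# torch.save(decoder.state_dict(), 'decoder.pt')
#
# To pick up training or evaluation later, rebuild the models as above and then:
#
# encoder.load_state_dict(torch.load('encoder.pt'))
# decoder.load_state_dict(torch.load('decoder.pt'))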

######################################################################
#
# Set dropout layers to ``eval`` mode
encoder.eval()
decoder.eval()
evaluateRandomly(encoder, decoder)


######################################################################
# Visualizing Attention
# ---------------------
#
# A useful property of the attention mechanism is its highly interpretable
# outputs. Because it is used to weight specific encoder outputs of the
# input sequence, we can imagine looking where the network is focused most
# at each time step.
#
# You could simply run ``plt.matshow(attentions)`` to see attention output
# displayed as a matrix. For a better viewing experience we will do the
# extra work of adding axes and labels:
#

def showAttention(input_sentence, output_words, attentions):
    fig = plt.figure()
    ax = fig.add_subplot(111)
    cax = ax.matshow(attentions.cpu().numpy(), cmap='bone')
    fig.colorbar(cax)

    # Set up axes
    ax.set_xticklabels([''] + input_sentence.split(' ') +
                       ['<EOS>'], rotation=90)
    ax.set_yticklabels([''] + output_words)

    # Show label at every tick
    ax.xaxis.set_major_locator(ticker.MultipleLocator(1))
    ax.yaxis.set_major_locator(ticker.MultipleLocator(1))

    plt.show()


def evaluateAndShowAttention(input_sentence):
    output_words, attentions = evaluate(encoder, decoder, input_sentence, input_lang, output_lang)
    print('input =', input_sentence)
    print('output =', ' '.join(output_words))
    showAttention(input_sentence, output_words, attentions[0, :len(output_words), :])


evaluateAndShowAttention('il n est pas aussi grand que son pere')

evaluateAndShowAttention('je suis trop fatigue pour conduire')

evaluateAndShowAttention('je suis desole si c est une question idiote')

evaluateAndShowAttention('je suis reellement fiere de vous')


######################################################################
# Exercises
# =========
#
# -  Try with a different dataset
#
#    -  Another language pair
#    -  Human → Machine (e.g. IoT commands)
#    -  Chat → Response
#    -  Question → Answer
#
# -  Replace the embeddings with pretrained word embeddings such as ``word2vec`` or
#    ``GloVe``
# -  Try with more layers, more hidden units, and more sentences. Compare
#    the training time and results.
# -  If you use a translation file where pairs have two of the same phrase
#    (``I am test \t I am test``), you can use this as an autoencoder. Try
#    this:
#
#    -  Train as an autoencoder
#    -  Save only the Encoder network
#    -  Train a new Decoder for translation from there
#