CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutSign UpSign In
pytorch

Real-time collaboration for Jupyter Notebooks, Linux Terminals, LaTeX, VS Code, R IDE, and more,
all in one place.

GitHub Repository: pytorch/tutorials
Path: blob/main/beginner_source/chatbot_tutorial.py
Views: 712
1
# -*- coding: utf-8 -*-
2
3
"""
4
Chatbot Tutorial
5
================
6
**Author:** `Matthew Inkawhich <https://github.com/MatthewInkawhich>`_
7
"""
8
9
10
######################################################################
11
# In this tutorial, we explore a fun and interesting use-case of recurrent
12
# sequence-to-sequence models. We will train a simple chatbot using movie
13
# scripts from the `Cornell Movie-Dialogs
14
# Corpus <https://www.cs.cornell.edu/~cristian/Cornell_Movie-Dialogs_Corpus.html>`__.
15
#
16
# Conversational models are a hot topic in artificial intelligence
17
# research. Chatbots can be found in a variety of settings, including
18
# customer service applications and online helpdesks. These bots are often
19
# powered by retrieval-based models, which output predefined responses to
20
# questions of certain forms. In a highly restricted domain like a
21
# company’s IT helpdesk, these models may be sufficient, however, they are
22
# not robust enough for more general use-cases. Teaching a machine to
23
# carry out a meaningful conversation with a human in multiple domains is
24
# a research question that is far from solved. Recently, the deep learning
25
# boom has allowed for powerful generative models like Google’s `Neural
26
# Conversational Model <https://arxiv.org/abs/1506.05869>`__, which marks
27
# a large step towards multi-domain generative conversational models. In
28
# this tutorial, we will implement this kind of model in PyTorch.
29
#
30
# .. figure:: /_static/img/chatbot/bot.png
31
# :align: center
32
# :alt: bot
33
#
34
# .. code-block:: python
35
#
36
# > hello?
37
# Bot: hello .
38
# > where am I?
39
# Bot: you re in a hospital .
40
# > who are you?
41
# Bot: i m a lawyer .
42
# > how are you doing?
43
# Bot: i m fine .
44
# > are you my friend?
45
# Bot: no .
46
# > you're under arrest
47
# Bot: i m trying to help you !
48
# > i'm just kidding
49
# Bot: i m sorry .
50
# > where are you from?
51
# Bot: san francisco .
52
# > it's time for me to leave
53
# Bot: i know .
54
# > goodbye
55
# Bot: goodbye .
56
#
57
# **Tutorial Highlights**
58
#
59
# - Handle loading and preprocessing of `Cornell Movie-Dialogs
60
# Corpus <https://www.cs.cornell.edu/~cristian/Cornell_Movie-Dialogs_Corpus.html>`__
61
# dataset
62
# - Implement a sequence-to-sequence model with `Luong attention
63
# mechanism(s) <https://arxiv.org/abs/1508.04025>`__
64
# - Jointly train encoder and decoder models using mini-batches
65
# - Implement greedy-search decoding module
66
# - Interact with trained chatbot
67
#
68
# **Acknowledgments**
69
#
70
# This tutorial borrows code from the following sources:
71
#
72
# 1) Yuan-Kuei Wu’s pytorch-chatbot implementation:
73
# https://github.com/ywk991112/pytorch-chatbot
74
#
75
# 2) Sean Robertson’s practical-pytorch seq2seq-translation example:
76
# https://github.com/spro/practical-pytorch/tree/master/seq2seq-translation
77
#
78
# 3) FloydHub Cornell Movie Corpus preprocessing code:
79
# https://github.com/floydhub/textutil-preprocess-cornell-movie-corpus
80
#
81
82
83
######################################################################
84
# Preparations
85
# ------------
86
#
87
# To get started, `download <https://zissou.infosci.cornell.edu/convokit/datasets/movie-corpus/movie-corpus.zip>`__ the Movie-Dialogs Corpus zip file.
88
89
# and put in a ``data/`` directory under the current directory.
90
#
91
# After that, let’s import some necessities.
92
#
93
94
import torch
95
from torch.jit import script, trace
96
import torch.nn as nn
97
from torch import optim
98
import torch.nn.functional as F
99
import csv
100
import random
101
import re
102
import os
103
import unicodedata
104
import codecs
105
from io import open
106
import itertools
107
import math
108
import json
109
110
111
USE_CUDA = torch.cuda.is_available()
112
device = torch.device("cuda" if USE_CUDA else "cpu")
113
114
115
######################################################################
116
# Load & Preprocess Data
117
# ----------------------
118
#
119
# The next step is to reformat our data file and load the data into
120
# structures that we can work with.
121
#
122
# The `Cornell Movie-Dialogs
123
# Corpus <https://www.cs.cornell.edu/~cristian/Cornell_Movie-Dialogs_Corpus.html>`__
124
# is a rich dataset of movie character dialog:
125
#
126
# - 220,579 conversational exchanges between 10,292 pairs of movie
127
# characters
128
# - 9,035 characters from 617 movies
129
# - 304,713 total utterances
130
#
131
# This dataset is large and diverse, and there is a great variation of
132
# language formality, time periods, sentiment, etc. Our hope is that this
133
# diversity makes our model robust to many forms of inputs and queries.
134
#
135
# First, we’ll take a look at some lines of our datafile to see the
136
# original format.
137
#
138
139
corpus_name = "movie-corpus"
140
corpus = os.path.join("data", corpus_name)
141
142
def printLines(file, n=10):
143
with open(file, 'rb') as datafile:
144
lines = datafile.readlines()
145
for line in lines[:n]:
146
print(line)
147
148
printLines(os.path.join(corpus, "utterances.jsonl"))
149
150
151
######################################################################
152
# Create formatted data file
153
# ~~~~~~~~~~~~~~~~~~~~~~~~~~
154
#
155
# For convenience, we'll create a nicely formatted data file in which each line
156
# contains a tab-separated *query sentence* and a *response sentence* pair.
157
#
158
# The following functions facilitate the parsing of the raw
159
# ``utterances.jsonl`` data file.
160
#
161
# - ``loadLinesAndConversations`` splits each line of the file into a dictionary of
162
# lines with fields: ``lineID``, ``characterID``, and text and then groups them
163
# into conversations with fields: ``conversationID``, ``movieID``, and lines.
164
# - ``extractSentencePairs`` extracts pairs of sentences from
165
# conversations
166
#
167
168
# Splits each line of the file to create lines and conversations
169
def loadLinesAndConversations(fileName):
170
lines = {}
171
conversations = {}
172
with open(fileName, 'r', encoding='iso-8859-1') as f:
173
for line in f:
174
lineJson = json.loads(line)
175
# Extract fields for line object
176
lineObj = {}
177
lineObj["lineID"] = lineJson["id"]
178
lineObj["characterID"] = lineJson["speaker"]
179
lineObj["text"] = lineJson["text"]
180
lines[lineObj['lineID']] = lineObj
181
182
# Extract fields for conversation object
183
if lineJson["conversation_id"] not in conversations:
184
convObj = {}
185
convObj["conversationID"] = lineJson["conversation_id"]
186
convObj["movieID"] = lineJson["meta"]["movie_id"]
187
convObj["lines"] = [lineObj]
188
else:
189
convObj = conversations[lineJson["conversation_id"]]
190
convObj["lines"].insert(0, lineObj)
191
conversations[convObj["conversationID"]] = convObj
192
193
return lines, conversations
194
195
196
# Extracts pairs of sentences from conversations
197
def extractSentencePairs(conversations):
198
qa_pairs = []
199
for conversation in conversations.values():
200
# Iterate over all the lines of the conversation
201
for i in range(len(conversation["lines"]) - 1): # We ignore the last line (no answer for it)
202
inputLine = conversation["lines"][i]["text"].strip()
203
targetLine = conversation["lines"][i+1]["text"].strip()
204
# Filter wrong samples (if one of the lists is empty)
205
if inputLine and targetLine:
206
qa_pairs.append([inputLine, targetLine])
207
return qa_pairs
208
209
210
######################################################################
211
# Now we’ll call these functions and create the file. We’ll call it
212
# ``formatted_movie_lines.txt``.
213
#
214
215
# Define path to new file
216
datafile = os.path.join(corpus, "formatted_movie_lines.txt")
217
218
delimiter = '\t'
219
# Unescape the delimiter
220
delimiter = str(codecs.decode(delimiter, "unicode_escape"))
221
222
# Initialize lines dict and conversations dict
223
lines = {}
224
conversations = {}
225
# Load lines and conversations
226
print("\nProcessing corpus into lines and conversations...")
227
lines, conversations = loadLinesAndConversations(os.path.join(corpus, "utterances.jsonl"))
228
229
# Write new csv file
230
print("\nWriting newly formatted file...")
231
with open(datafile, 'w', encoding='utf-8') as outputfile:
232
writer = csv.writer(outputfile, delimiter=delimiter, lineterminator='\n')
233
for pair in extractSentencePairs(conversations):
234
writer.writerow(pair)
235
236
# Print a sample of lines
237
print("\nSample lines from file:")
238
printLines(datafile)
239
240
241
######################################################################
242
# Load and trim data
243
# ~~~~~~~~~~~~~~~~~~
244
#
245
# Our next order of business is to create a vocabulary and load
246
# query/response sentence pairs into memory.
247
#
248
# Note that we are dealing with sequences of **words**, which do not have
249
# an implicit mapping to a discrete numerical space. Thus, we must create
250
# one by mapping each unique word that we encounter in our dataset to an
251
# index value.
252
#
253
# For this we define a ``Voc`` class, which keeps a mapping from words to
254
# indexes, a reverse mapping of indexes to words, a count of each word and
255
# a total word count. The class provides methods for adding a word to the
256
# vocabulary (``addWord``), adding all words in a sentence
257
# (``addSentence``) and trimming infrequently seen words (``trim``). More
258
# on trimming later.
259
#
260
261
# Default word tokens
262
PAD_token = 0 # Used for padding short sentences
263
SOS_token = 1 # Start-of-sentence token
264
EOS_token = 2 # End-of-sentence token
265
266
class Voc:
267
def __init__(self, name):
268
self.name = name
269
self.trimmed = False
270
self.word2index = {}
271
self.word2count = {}
272
self.index2word = {PAD_token: "PAD", SOS_token: "SOS", EOS_token: "EOS"}
273
self.num_words = 3 # Count SOS, EOS, PAD
274
275
def addSentence(self, sentence):
276
for word in sentence.split(' '):
277
self.addWord(word)
278
279
def addWord(self, word):
280
if word not in self.word2index:
281
self.word2index[word] = self.num_words
282
self.word2count[word] = 1
283
self.index2word[self.num_words] = word
284
self.num_words += 1
285
else:
286
self.word2count[word] += 1
287
288
# Remove words below a certain count threshold
289
def trim(self, min_count):
290
if self.trimmed:
291
return
292
self.trimmed = True
293
294
keep_words = []
295
296
for k, v in self.word2count.items():
297
if v >= min_count:
298
keep_words.append(k)
299
300
print('keep_words {} / {} = {:.4f}'.format(
301
len(keep_words), len(self.word2index), len(keep_words) / len(self.word2index)
302
))
303
304
# Reinitialize dictionaries
305
self.word2index = {}
306
self.word2count = {}
307
self.index2word = {PAD_token: "PAD", SOS_token: "SOS", EOS_token: "EOS"}
308
self.num_words = 3 # Count default tokens
309
310
for word in keep_words:
311
self.addWord(word)
312
313
314
######################################################################
315
# Now we can assemble our vocabulary and query/response sentence pairs.
316
# Before we are ready to use this data, we must perform some
317
# preprocessing.
318
#
319
# First, we must convert the Unicode strings to ASCII using
320
# ``unicodeToAscii``. Next, we should convert all letters to lowercase and
321
# trim all non-letter characters except for basic punctuation
322
# (``normalizeString``). Finally, to aid in training convergence, we will
323
# filter out sentences with length greater than the ``MAX_LENGTH``
324
# threshold (``filterPairs``).
325
#
326
327
MAX_LENGTH = 10 # Maximum sentence length to consider
328
329
# Turn a Unicode string to plain ASCII, thanks to
330
# https://stackoverflow.com/a/518232/2809427
331
def unicodeToAscii(s):
332
return ''.join(
333
c for c in unicodedata.normalize('NFD', s)
334
if unicodedata.category(c) != 'Mn'
335
)
336
337
# Lowercase, trim, and remove non-letter characters
338
def normalizeString(s):
339
s = unicodeToAscii(s.lower().strip())
340
s = re.sub(r"([.!?])", r" \1", s)
341
s = re.sub(r"[^a-zA-Z.!?]+", r" ", s)
342
s = re.sub(r"\s+", r" ", s).strip()
343
return s
344
345
# Read query/response pairs and return a voc object
346
def readVocs(datafile, corpus_name):
347
print("Reading lines...")
348
# Read the file and split into lines
349
lines = open(datafile, encoding='utf-8').\
350
read().strip().split('\n')
351
# Split every line into pairs and normalize
352
pairs = [[normalizeString(s) for s in l.split('\t')] for l in lines]
353
voc = Voc(corpus_name)
354
return voc, pairs
355
356
# Returns True if both sentences in a pair 'p' are under the MAX_LENGTH threshold
357
def filterPair(p):
358
# Input sequences need to preserve the last word for EOS token
359
return len(p[0].split(' ')) < MAX_LENGTH and len(p[1].split(' ')) < MAX_LENGTH
360
361
# Filter pairs using the ``filterPair`` condition
362
def filterPairs(pairs):
363
return [pair for pair in pairs if filterPair(pair)]
364
365
# Using the functions defined above, return a populated voc object and pairs list
366
def loadPrepareData(corpus, corpus_name, datafile, save_dir):
367
print("Start preparing training data ...")
368
voc, pairs = readVocs(datafile, corpus_name)
369
print("Read {!s} sentence pairs".format(len(pairs)))
370
pairs = filterPairs(pairs)
371
print("Trimmed to {!s} sentence pairs".format(len(pairs)))
372
print("Counting words...")
373
for pair in pairs:
374
voc.addSentence(pair[0])
375
voc.addSentence(pair[1])
376
print("Counted words:", voc.num_words)
377
return voc, pairs
378
379
380
# Load/Assemble voc and pairs
381
save_dir = os.path.join("data", "save")
382
voc, pairs = loadPrepareData(corpus, corpus_name, datafile, save_dir)
383
# Print some pairs to validate
384
print("\npairs:")
385
for pair in pairs[:10]:
386
print(pair)
387
388
389
######################################################################
390
# Another tactic that is beneficial to achieving faster convergence during
391
# training is trimming rarely used words out of our vocabulary. Decreasing
392
# the feature space will also soften the difficulty of the function that
393
# the model must learn to approximate. We will do this as a two-step
394
# process:
395
#
396
# 1) Trim words used under ``MIN_COUNT`` threshold using the ``voc.trim``
397
# function.
398
#
399
# 2) Filter out pairs with trimmed words.
400
#
401
402
MIN_COUNT = 3 # Minimum word count threshold for trimming
403
404
def trimRareWords(voc, pairs, MIN_COUNT):
405
# Trim words used under the MIN_COUNT from the voc
406
voc.trim(MIN_COUNT)
407
# Filter out pairs with trimmed words
408
keep_pairs = []
409
for pair in pairs:
410
input_sentence = pair[0]
411
output_sentence = pair[1]
412
keep_input = True
413
keep_output = True
414
# Check input sentence
415
for word in input_sentence.split(' '):
416
if word not in voc.word2index:
417
keep_input = False
418
break
419
# Check output sentence
420
for word in output_sentence.split(' '):
421
if word not in voc.word2index:
422
keep_output = False
423
break
424
425
# Only keep pairs that do not contain trimmed word(s) in their input or output sentence
426
if keep_input and keep_output:
427
keep_pairs.append(pair)
428
429
print("Trimmed from {} pairs to {}, {:.4f} of total".format(len(pairs), len(keep_pairs), len(keep_pairs) / len(pairs)))
430
return keep_pairs
431
432
433
# Trim voc and pairs
434
pairs = trimRareWords(voc, pairs, MIN_COUNT)
435
436
437
######################################################################
438
# Prepare Data for Models
439
# -----------------------
440
#
441
# Although we have put a great deal of effort into preparing and massaging our
442
# data into a nice vocabulary object and list of sentence pairs, our models
443
# will ultimately expect numerical torch tensors as inputs. One way to
444
# prepare the processed data for the models can be found in the `seq2seq
445
# translation
446
# tutorial <https://pytorch.org/tutorials/intermediate/seq2seq_translation_tutorial.html>`__.
447
# In that tutorial, we use a batch size of 1, meaning that all we have to
448
# do is convert the words in our sentence pairs to their corresponding
449
# indexes from the vocabulary and feed this to the models.
450
#
451
# However, if you’re interested in speeding up training and/or would like
452
# to leverage GPU parallelization capabilities, you will need to train
453
# with mini-batches.
454
#
455
# Using mini-batches also means that we must be mindful of the variation
456
# of sentence length in our batches. To accommodate sentences of different
457
# sizes in the same batch, we will make our batched input tensor of shape
458
# *(max_length, batch_size)*, where sentences shorter than the
459
# *max_length* are zero padded after an *EOS_token*.
460
#
461
# If we simply convert our English sentences to tensors by converting
462
# words to their indexes(\ ``indexesFromSentence``) and zero-pad, our
463
# tensor would have shape *(batch_size, max_length)* and indexing the
464
# first dimension would return a full sequence across all time-steps.
465
# However, we need to be able to index our batch along time, and across
466
# all sequences in the batch. Therefore, we transpose our input batch
467
# shape to *(max_length, batch_size)*, so that indexing across the first
468
# dimension returns a time step across all sentences in the batch. We
469
# handle this transpose implicitly in the ``zeroPadding`` function.
470
#
471
# .. figure:: /_static/img/chatbot/seq2seq_batches.png
472
# :align: center
473
# :alt: batches
474
#
475
# The ``inputVar`` function handles the process of converting sentences to
476
# tensor, ultimately creating a correctly shaped zero-padded tensor. It
477
# also returns a tensor of ``lengths`` for each of the sequences in the
478
# batch which will be passed to our decoder later.
479
#
480
# The ``outputVar`` function performs a similar function to ``inputVar``,
481
# but instead of returning a ``lengths`` tensor, it returns a binary mask
482
# tensor and a maximum target sentence length. The binary mask tensor has
483
# the same shape as the output target tensor, but every element that is a
484
# *PAD_token* is 0 and all others are 1.
485
#
486
# ``batch2TrainData`` simply takes a bunch of pairs and returns the input
487
# and target tensors using the aforementioned functions.
488
#
489
490
def indexesFromSentence(voc, sentence):
491
return [voc.word2index[word] for word in sentence.split(' ')] + [EOS_token]
492
493
494
def zeroPadding(l, fillvalue=PAD_token):
495
return list(itertools.zip_longest(*l, fillvalue=fillvalue))
496
497
def binaryMatrix(l, value=PAD_token):
498
m = []
499
for i, seq in enumerate(l):
500
m.append([])
501
for token in seq:
502
if token == PAD_token:
503
m[i].append(0)
504
else:
505
m[i].append(1)
506
return m
507
508
# Returns padded input sequence tensor and lengths
509
def inputVar(l, voc):
510
indexes_batch = [indexesFromSentence(voc, sentence) for sentence in l]
511
lengths = torch.tensor([len(indexes) for indexes in indexes_batch])
512
padList = zeroPadding(indexes_batch)
513
padVar = torch.LongTensor(padList)
514
return padVar, lengths
515
516
# Returns padded target sequence tensor, padding mask, and max target length
517
def outputVar(l, voc):
518
indexes_batch = [indexesFromSentence(voc, sentence) for sentence in l]
519
max_target_len = max([len(indexes) for indexes in indexes_batch])
520
padList = zeroPadding(indexes_batch)
521
mask = binaryMatrix(padList)
522
mask = torch.BoolTensor(mask)
523
padVar = torch.LongTensor(padList)
524
return padVar, mask, max_target_len
525
526
# Returns all items for a given batch of pairs
527
def batch2TrainData(voc, pair_batch):
528
pair_batch.sort(key=lambda x: len(x[0].split(" ")), reverse=True)
529
input_batch, output_batch = [], []
530
for pair in pair_batch:
531
input_batch.append(pair[0])
532
output_batch.append(pair[1])
533
inp, lengths = inputVar(input_batch, voc)
534
output, mask, max_target_len = outputVar(output_batch, voc)
535
return inp, lengths, output, mask, max_target_len
536
537
538
# Example for validation
539
small_batch_size = 5
540
batches = batch2TrainData(voc, [random.choice(pairs) for _ in range(small_batch_size)])
541
input_variable, lengths, target_variable, mask, max_target_len = batches
542
543
print("input_variable:", input_variable)
544
print("lengths:", lengths)
545
print("target_variable:", target_variable)
546
print("mask:", mask)
547
print("max_target_len:", max_target_len)
548
549
550
######################################################################
551
# Define Models
552
# -------------
553
#
554
# Seq2Seq Model
555
# ~~~~~~~~~~~~~
556
#
557
# The brains of our chatbot is a sequence-to-sequence (seq2seq) model. The
558
# goal of a seq2seq model is to take a variable-length sequence as an
559
# input, and return a variable-length sequence as an output using a
560
# fixed-sized model.
561
#
562
# `Sutskever et al. <https://arxiv.org/abs/1409.3215>`__ discovered that
563
# by using two separate recurrent neural nets together, we can accomplish
564
# this task. One RNN acts as an **encoder**, which encodes a variable
565
# length input sequence to a fixed-length context vector. In theory, this
566
# context vector (the final hidden layer of the RNN) will contain semantic
567
# information about the query sentence that is input to the bot. The
568
# second RNN is a **decoder**, which takes an input word and the context
569
# vector, and returns a guess for the next word in the sequence and a
570
# hidden state to use in the next iteration.
571
#
572
# .. figure:: /_static/img/chatbot/seq2seq_ts.png
573
# :align: center
574
# :alt: model
575
#
576
# Image source:
577
# https://jeddy92.github.io/JEddy92.github.io/ts_seq2seq_intro/
578
#
579
580
581
######################################################################
582
# Encoder
583
# ~~~~~~~
584
#
585
# The encoder RNN iterates through the input sentence one token
586
# (e.g. word) at a time, at each time step outputting an “output” vector
587
# and a “hidden state” vector. The hidden state vector is then passed to
588
# the next time step, while the output vector is recorded. The encoder
589
# transforms the context it saw at each point in the sequence into a set
590
# of points in a high-dimensional space, which the decoder will use to
591
# generate a meaningful output for the given task.
592
#
593
# At the heart of our encoder is a multi-layered Gated Recurrent Unit,
594
# invented by `Cho et al. <https://arxiv.org/pdf/1406.1078v3.pdf>`__ in
595
# 2014. We will use a bidirectional variant of the GRU, meaning that there
596
# are essentially two independent RNNs: one that is fed the input sequence
597
# in normal sequential order, and one that is fed the input sequence in
598
# reverse order. The outputs of each network are summed at each time step.
599
# Using a bidirectional GRU will give us the advantage of encoding both
600
# past and future contexts.
601
#
602
# Bidirectional RNN:
603
#
604
# .. figure:: /_static/img/chatbot/RNN-bidirectional.png
605
# :width: 70%
606
# :align: center
607
# :alt: rnn_bidir
608
#
609
# Image source: https://colah.github.io/posts/2015-09-NN-Types-FP/
610
#
611
# Note that an ``embedding`` layer is used to encode our word indices in
612
# an arbitrarily sized feature space. For our models, this layer will map
613
# each word to a feature space of size *hidden_size*. When trained, these
614
# values should encode semantic similarity between similar meaning words.
615
#
616
# Finally, if passing a padded batch of sequences to an RNN module, we
617
# must pack and unpack padding around the RNN pass using
618
# ``nn.utils.rnn.pack_padded_sequence`` and
619
# ``nn.utils.rnn.pad_packed_sequence`` respectively.
620
#
621
# **Computation Graph:**
622
#
623
# 1) Convert word indexes to embeddings.
624
# 2) Pack padded batch of sequences for RNN module.
625
# 3) Forward pass through GRU.
626
# 4) Unpack padding.
627
# 5) Sum bidirectional GRU outputs.
628
# 6) Return output and final hidden state.
629
#
630
# **Inputs:**
631
#
632
# - ``input_seq``: batch of input sentences; shape=\ *(max_length,
633
# batch_size)*
634
# - ``input_lengths``: list of sentence lengths corresponding to each
635
# sentence in the batch; shape=\ *(batch_size)*
636
# - ``hidden``: hidden state; shape=\ *(n_layers x num_directions,
637
# batch_size, hidden_size)*
638
#
639
# **Outputs:**
640
#
641
# - ``outputs``: output features from the last hidden layer of the GRU
642
# (sum of bidirectional outputs); shape=\ *(max_length, batch_size,
643
# hidden_size)*
644
# - ``hidden``: updated hidden state from GRU; shape=\ *(n_layers x
645
# num_directions, batch_size, hidden_size)*
646
#
647
#
648
649
class EncoderRNN(nn.Module):
650
def __init__(self, hidden_size, embedding, n_layers=1, dropout=0):
651
super(EncoderRNN, self).__init__()
652
self.n_layers = n_layers
653
self.hidden_size = hidden_size
654
self.embedding = embedding
655
656
# Initialize GRU; the input_size and hidden_size parameters are both set to 'hidden_size'
657
# because our input size is a word embedding with number of features == hidden_size
658
self.gru = nn.GRU(hidden_size, hidden_size, n_layers,
659
dropout=(0 if n_layers == 1 else dropout), bidirectional=True)
660
661
def forward(self, input_seq, input_lengths, hidden=None):
662
# Convert word indexes to embeddings
663
embedded = self.embedding(input_seq)
664
# Pack padded batch of sequences for RNN module
665
packed = nn.utils.rnn.pack_padded_sequence(embedded, input_lengths)
666
# Forward pass through GRU
667
outputs, hidden = self.gru(packed, hidden)
668
# Unpack padding
669
outputs, _ = nn.utils.rnn.pad_packed_sequence(outputs)
670
# Sum bidirectional GRU outputs
671
outputs = outputs[:, :, :self.hidden_size] + outputs[:, : ,self.hidden_size:]
672
# Return output and final hidden state
673
return outputs, hidden
674
675
676
######################################################################
677
# Decoder
678
# ~~~~~~~
679
#
680
# The decoder RNN generates the response sentence in a token-by-token
681
# fashion. It uses the encoder’s context vectors, and internal hidden
682
# states to generate the next word in the sequence. It continues
683
# generating words until it outputs an *EOS_token*, representing the end
684
# of the sentence. A common problem with a vanilla seq2seq decoder is that
685
# if we rely solely on the context vector to encode the entire input
686
# sequence’s meaning, it is likely that we will have information loss.
687
# This is especially the case when dealing with long input sequences,
688
# greatly limiting the capability of our decoder.
689
#
690
# To combat this, `Bahdanau et al. <https://arxiv.org/abs/1409.0473>`__
691
# created an “attention mechanism” that allows the decoder to pay
692
# attention to certain parts of the input sequence, rather than using the
693
# entire fixed context at every step.
694
#
695
# At a high level, attention is calculated using the decoder’s current
696
# hidden state and the encoder’s outputs. The output attention weights
697
# have the same shape as the input sequence, allowing us to multiply them
698
# by the encoder outputs, giving us a weighted sum which indicates the
699
# parts of encoder output to pay attention to. `Sean
700
# Robertson’s <https://github.com/spro>`__ figure describes this very
701
# well:
702
#
703
# .. figure:: /_static/img/chatbot/attn2.png
704
# :align: center
705
# :alt: attn2
706
#
707
# `Luong et al. <https://arxiv.org/abs/1508.04025>`__ improved upon
708
# Bahdanau et al.’s groundwork by creating “Global attention”. The key
709
# difference is that with “Global attention”, we consider all of the
710
# encoder’s hidden states, as opposed to Bahdanau et al.’s “Local
711
# attention”, which only considers the encoder’s hidden state from the
712
# current time step. Another difference is that with “Global attention”,
713
# we calculate attention weights, or energies, using the hidden state of
714
# the decoder from the current time step only. Bahdanau et al.’s attention
715
# calculation requires knowledge of the decoder’s state from the previous
716
# time step. Also, Luong et al. provides various methods to calculate the
717
# attention energies between the encoder output and decoder output which
718
# are called “score functions”:
719
#
720
# .. figure:: /_static/img/chatbot/scores.png
721
# :width: 60%
722
# :align: center
723
# :alt: scores
724
#
725
# where :math:`h_t` = current target decoder state and :math:`\bar{h}_s` =
726
# all encoder states.
727
#
728
# Overall, the Global attention mechanism can be summarized by the
729
# following figure. Note that we will implement the “Attention Layer” as a
730
# separate ``nn.Module`` called ``Attn``. The output of this module is a
731
# softmax normalized weights tensor of shape *(batch_size, 1,
732
# max_length)*.
733
#
734
# .. figure:: /_static/img/chatbot/global_attn.png
735
# :align: center
736
# :width: 60%
737
# :alt: global_attn
738
#
739
740
# Luong attention layer
741
class Attn(nn.Module):
742
def __init__(self, method, hidden_size):
743
super(Attn, self).__init__()
744
self.method = method
745
if self.method not in ['dot', 'general', 'concat']:
746
raise ValueError(self.method, "is not an appropriate attention method.")
747
self.hidden_size = hidden_size
748
if self.method == 'general':
749
self.attn = nn.Linear(self.hidden_size, hidden_size)
750
elif self.method == 'concat':
751
self.attn = nn.Linear(self.hidden_size * 2, hidden_size)
752
self.v = nn.Parameter(torch.FloatTensor(hidden_size))
753
754
def dot_score(self, hidden, encoder_output):
755
return torch.sum(hidden * encoder_output, dim=2)
756
757
def general_score(self, hidden, encoder_output):
758
energy = self.attn(encoder_output)
759
return torch.sum(hidden * energy, dim=2)
760
761
def concat_score(self, hidden, encoder_output):
762
energy = self.attn(torch.cat((hidden.expand(encoder_output.size(0), -1, -1), encoder_output), 2)).tanh()
763
return torch.sum(self.v * energy, dim=2)
764
765
def forward(self, hidden, encoder_outputs):
766
# Calculate the attention weights (energies) based on the given method
767
if self.method == 'general':
768
attn_energies = self.general_score(hidden, encoder_outputs)
769
elif self.method == 'concat':
770
attn_energies = self.concat_score(hidden, encoder_outputs)
771
elif self.method == 'dot':
772
attn_energies = self.dot_score(hidden, encoder_outputs)
773
774
# Transpose max_length and batch_size dimensions
775
attn_energies = attn_energies.t()
776
777
# Return the softmax normalized probability scores (with added dimension)
778
return F.softmax(attn_energies, dim=1).unsqueeze(1)
779
780
781
######################################################################
782
# Now that we have defined our attention submodule, we can implement the
783
# actual decoder model. For the decoder, we will manually feed our batch
784
# one time step at a time. This means that our embedded word tensor and
785
# GRU output will both have shape *(1, batch_size, hidden_size)*.
786
#
787
# **Computation Graph:**
788
#
789
# 1) Get embedding of current input word.
790
# 2) Forward through unidirectional GRU.
791
# 3) Calculate attention weights from the current GRU output from (2).
792
# 4) Multiply attention weights to encoder outputs to get new "weighted sum" context vector.
793
# 5) Concatenate weighted context vector and GRU output using Luong eq. 5.
794
# 6) Predict next word using Luong eq. 6 (without softmax).
795
# 7) Return output and final hidden state.
796
#
797
# **Inputs:**
798
#
799
# - ``input_step``: one time step (one word) of input sequence batch;
800
# shape=\ *(1, batch_size)*
801
# - ``last_hidden``: final hidden layer of GRU; shape=\ *(n_layers x
802
# num_directions, batch_size, hidden_size)*
803
# - ``encoder_outputs``: encoder model’s output; shape=\ *(max_length,
804
# batch_size, hidden_size)*
805
#
806
# **Outputs:**
807
#
808
# - ``output``: softmax normalized tensor giving probabilities of each
809
# word being the correct next word in the decoded sequence;
810
# shape=\ *(batch_size, voc.num_words)*
811
# - ``hidden``: final hidden state of GRU; shape=\ *(n_layers x
812
# num_directions, batch_size, hidden_size)*
813
#
814
815
class LuongAttnDecoderRNN(nn.Module):
816
def __init__(self, attn_model, embedding, hidden_size, output_size, n_layers=1, dropout=0.1):
817
super(LuongAttnDecoderRNN, self).__init__()
818
819
# Keep for reference
820
self.attn_model = attn_model
821
self.hidden_size = hidden_size
822
self.output_size = output_size
823
self.n_layers = n_layers
824
self.dropout = dropout
825
826
# Define layers
827
self.embedding = embedding
828
self.embedding_dropout = nn.Dropout(dropout)
829
self.gru = nn.GRU(hidden_size, hidden_size, n_layers, dropout=(0 if n_layers == 1 else dropout))
830
self.concat = nn.Linear(hidden_size * 2, hidden_size)
831
self.out = nn.Linear(hidden_size, output_size)
832
833
self.attn = Attn(attn_model, hidden_size)
834
835
def forward(self, input_step, last_hidden, encoder_outputs):
836
# Note: we run this one step (word) at a time
837
# Get embedding of current input word
838
embedded = self.embedding(input_step)
839
embedded = self.embedding_dropout(embedded)
840
# Forward through unidirectional GRU
841
rnn_output, hidden = self.gru(embedded, last_hidden)
842
# Calculate attention weights from the current GRU output
843
attn_weights = self.attn(rnn_output, encoder_outputs)
844
# Multiply attention weights to encoder outputs to get new "weighted sum" context vector
845
context = attn_weights.bmm(encoder_outputs.transpose(0, 1))
846
# Concatenate weighted context vector and GRU output using Luong eq. 5
847
rnn_output = rnn_output.squeeze(0)
848
context = context.squeeze(1)
849
concat_input = torch.cat((rnn_output, context), 1)
850
concat_output = torch.tanh(self.concat(concat_input))
851
# Predict next word using Luong eq. 6
852
output = self.out(concat_output)
853
output = F.softmax(output, dim=1)
854
# Return output and final hidden state
855
return output, hidden
856
857
858
######################################################################
859
# Define Training Procedure
860
# -------------------------
861
#
862
# Masked loss
863
# ~~~~~~~~~~~
864
#
865
# Since we are dealing with batches of padded sequences, we cannot simply
866
# consider all elements of the tensor when calculating loss. We define
867
# ``maskNLLLoss`` to calculate our loss based on our decoder’s output
868
# tensor, the target tensor, and a binary mask tensor describing the
869
# padding of the target tensor. This loss function calculates the average
870
# negative log likelihood of the elements that correspond to a *1* in the
871
# mask tensor.
872
#
873
874
def maskNLLLoss(inp, target, mask):
875
nTotal = mask.sum()
876
crossEntropy = -torch.log(torch.gather(inp, 1, target.view(-1, 1)).squeeze(1))
877
loss = crossEntropy.masked_select(mask).mean()
878
loss = loss.to(device)
879
return loss, nTotal.item()
880
881
882
######################################################################
883
# Single training iteration
884
# ~~~~~~~~~~~~~~~~~~~~~~~~~
885
#
886
# The ``train`` function contains the algorithm for a single training
887
# iteration (a single batch of inputs).
888
#
889
# We will use a couple of clever tricks to aid in convergence:
890
#
891
# - The first trick is using **teacher forcing**. This means that at some
892
# probability, set by ``teacher_forcing_ratio``, we use the current
893
# target word as the decoder’s next input rather than using the
894
# decoder’s current guess. This technique acts as training wheels for
895
# the decoder, aiding in more efficient training. However, teacher
896
# forcing can lead to model instability during inference, as the
897
# decoder may not have a sufficient chance to truly craft its own
898
# output sequences during training. Thus, we must be mindful of how we
899
# are setting the ``teacher_forcing_ratio``, and not be fooled by fast
900
# convergence.
901
#
902
# - The second trick that we implement is **gradient clipping**. This is
903
# a commonly used technique for countering the “exploding gradient”
904
# problem. In essence, by clipping or thresholding gradients to a
905
# maximum value, we prevent the gradients from growing exponentially
906
# and either overflow (NaN), or overshoot steep cliffs in the cost
907
# function.
908
#
909
# .. figure:: /_static/img/chatbot/grad_clip.png
910
# :align: center
911
# :width: 60%
912
# :alt: grad_clip
913
#
914
# Image source: Goodfellow et al. *Deep Learning*. 2016. https://www.deeplearningbook.org/
915
#
916
# **Sequence of Operations:**
917
#
918
# 1) Forward pass entire input batch through encoder.
919
# 2) Initialize decoder inputs as SOS_token, and hidden state as the encoder's final hidden state.
920
# 3) Forward input batch sequence through decoder one time step at a time.
921
# 4) If teacher forcing: set next decoder input as the current target; else: set next decoder input as current decoder output.
922
# 5) Calculate and accumulate loss.
923
# 6) Perform backpropagation.
924
# 7) Clip gradients.
925
# 8) Update encoder and decoder model parameters.
926
#
927
#
928
# .. Note ::
929
#
930
# PyTorch’s RNN modules (``RNN``, ``LSTM``, ``GRU``) can be used like any
931
# other non-recurrent layers by simply passing them the entire input
932
# sequence (or batch of sequences). We use the ``GRU`` layer like this in
933
# the ``encoder``. The reality is that under the hood, there is an
934
# iterative process looping over each time step calculating hidden states.
935
# Alternatively, you can run these modules one time-step at a time. In
936
# this case, we manually loop over the sequences during the training
937
# process like we must do for the ``decoder`` model. As long as you
938
# maintain the correct conceptual model of these modules, implementing
939
# sequential models can be very straightforward.
940
#
941
#
942
943
944
def train(input_variable, lengths, target_variable, mask, max_target_len, encoder, decoder, embedding,
945
encoder_optimizer, decoder_optimizer, batch_size, clip, max_length=MAX_LENGTH):
946
947
# Zero gradients
948
encoder_optimizer.zero_grad()
949
decoder_optimizer.zero_grad()
950
951
# Set device options
952
input_variable = input_variable.to(device)
953
target_variable = target_variable.to(device)
954
mask = mask.to(device)
955
# Lengths for RNN packing should always be on the CPU
956
lengths = lengths.to("cpu")
957
958
# Initialize variables
959
loss = 0
960
print_losses = []
961
n_totals = 0
962
963
# Forward pass through encoder
964
encoder_outputs, encoder_hidden = encoder(input_variable, lengths)
965
966
# Create initial decoder input (start with SOS tokens for each sentence)
967
decoder_input = torch.LongTensor([[SOS_token for _ in range(batch_size)]])
968
decoder_input = decoder_input.to(device)
969
970
# Set initial decoder hidden state to the encoder's final hidden state
971
decoder_hidden = encoder_hidden[:decoder.n_layers]
972
973
# Determine if we are using teacher forcing this iteration
974
use_teacher_forcing = True if random.random() < teacher_forcing_ratio else False
975
976
# Forward batch of sequences through decoder one time step at a time
977
if use_teacher_forcing:
978
for t in range(max_target_len):
979
decoder_output, decoder_hidden = decoder(
980
decoder_input, decoder_hidden, encoder_outputs
981
)
982
# Teacher forcing: next input is current target
983
decoder_input = target_variable[t].view(1, -1)
984
# Calculate and accumulate loss
985
mask_loss, nTotal = maskNLLLoss(decoder_output, target_variable[t], mask[t])
986
loss += mask_loss
987
print_losses.append(mask_loss.item() * nTotal)
988
n_totals += nTotal
989
else:
990
for t in range(max_target_len):
991
decoder_output, decoder_hidden = decoder(
992
decoder_input, decoder_hidden, encoder_outputs
993
)
994
# No teacher forcing: next input is decoder's own current output
995
_, topi = decoder_output.topk(1)
996
decoder_input = torch.LongTensor([[topi[i][0] for i in range(batch_size)]])
997
decoder_input = decoder_input.to(device)
998
# Calculate and accumulate loss
999
mask_loss, nTotal = maskNLLLoss(decoder_output, target_variable[t], mask[t])
1000
loss += mask_loss
1001
print_losses.append(mask_loss.item() * nTotal)
1002
n_totals += nTotal
1003
1004
# Perform backpropagation
1005
loss.backward()
1006
1007
# Clip gradients: gradients are modified in place
1008
_ = nn.utils.clip_grad_norm_(encoder.parameters(), clip)
1009
_ = nn.utils.clip_grad_norm_(decoder.parameters(), clip)
1010
1011
# Adjust model weights
1012
encoder_optimizer.step()
1013
decoder_optimizer.step()
1014
1015
return sum(print_losses) / n_totals
1016
1017
1018
######################################################################
1019
# Training iterations
1020
# ~~~~~~~~~~~~~~~~~~~
1021
#
1022
# It is finally time to tie the full training procedure together with the
1023
# data. The ``trainIters`` function is responsible for running
1024
# ``n_iterations`` of training given the passed models, optimizers, data,
1025
# etc. This function is quite self explanatory, as we have done the heavy
1026
# lifting with the ``train`` function.
1027
#
1028
# One thing to note is that when we save our model, we save a tarball
1029
# containing the encoder and decoder ``state_dicts`` (parameters), the
1030
# optimizers’ ``state_dicts``, the loss, the iteration, etc. Saving the model
1031
# in this way will give us the ultimate flexibility with the checkpoint.
1032
# After loading a checkpoint, we will be able to use the model parameters
1033
# to run inference, or we can continue training right where we left off.
1034
#
1035
1036
def trainIters(model_name, voc, pairs, encoder, decoder, encoder_optimizer, decoder_optimizer, embedding, encoder_n_layers, decoder_n_layers, save_dir, n_iteration, batch_size, print_every, save_every, clip, corpus_name, loadFilename):
1037
1038
# Load batches for each iteration
1039
training_batches = [batch2TrainData(voc, [random.choice(pairs) for _ in range(batch_size)])
1040
for _ in range(n_iteration)]
1041
1042
# Initializations
1043
print('Initializing ...')
1044
start_iteration = 1
1045
print_loss = 0
1046
if loadFilename:
1047
start_iteration = checkpoint['iteration'] + 1
1048
1049
# Training loop
1050
print("Training...")
1051
for iteration in range(start_iteration, n_iteration + 1):
1052
training_batch = training_batches[iteration - 1]
1053
# Extract fields from batch
1054
input_variable, lengths, target_variable, mask, max_target_len = training_batch
1055
1056
# Run a training iteration with batch
1057
loss = train(input_variable, lengths, target_variable, mask, max_target_len, encoder,
1058
decoder, embedding, encoder_optimizer, decoder_optimizer, batch_size, clip)
1059
print_loss += loss
1060
1061
# Print progress
1062
if iteration % print_every == 0:
1063
print_loss_avg = print_loss / print_every
1064
print("Iteration: {}; Percent complete: {:.1f}%; Average loss: {:.4f}".format(iteration, iteration / n_iteration * 100, print_loss_avg))
1065
print_loss = 0
1066
1067
# Save checkpoint
1068
if (iteration % save_every == 0):
1069
directory = os.path.join(save_dir, model_name, corpus_name, '{}-{}_{}'.format(encoder_n_layers, decoder_n_layers, hidden_size))
1070
if not os.path.exists(directory):
1071
os.makedirs(directory)
1072
torch.save({
1073
'iteration': iteration,
1074
'en': encoder.state_dict(),
1075
'de': decoder.state_dict(),
1076
'en_opt': encoder_optimizer.state_dict(),
1077
'de_opt': decoder_optimizer.state_dict(),
1078
'loss': loss,
1079
'voc_dict': voc.__dict__,
1080
'embedding': embedding.state_dict()
1081
}, os.path.join(directory, '{}_{}.tar'.format(iteration, 'checkpoint')))
1082
1083
1084
######################################################################
1085
# Define Evaluation
1086
# -----------------
1087
#
1088
# After training a model, we want to be able to talk to the bot ourselves.
1089
# First, we must define how we want the model to decode the encoded input.
1090
#
1091
# Greedy decoding
1092
# ~~~~~~~~~~~~~~~
1093
#
1094
# Greedy decoding is the decoding method that we use during training when
1095
# we are **NOT** using teacher forcing. In other words, for each time
1096
# step, we simply choose the word from ``decoder_output`` with the highest
1097
# softmax value. This decoding method is optimal on a single time-step
1098
# level.
1099
#
1100
# To facilitate the greedy decoding operation, we define a
1101
# ``GreedySearchDecoder`` class. When run, an object of this class takes
1102
# an input sequence (``input_seq``) of shape *(input_seq length, 1)*, a
1103
# scalar input length (``input_length``) tensor, and a ``max_length`` to
1104
# bound the response sentence length. The input sentence is evaluated
1105
# using the following computational graph:
1106
#
1107
# **Computation Graph:**
1108
#
1109
# 1) Forward input through encoder model.
1110
# 2) Prepare encoder's final hidden layer to be first hidden input to the decoder.
1111
# 3) Initialize decoder's first input as SOS_token.
1112
# 4) Initialize tensors to append decoded words to.
1113
# 5) Iteratively decode one word token at a time:
1114
# a) Forward pass through decoder.
1115
# b) Obtain most likely word token and its softmax score.
1116
# c) Record token and score.
1117
# d) Prepare current token to be next decoder input.
1118
# 6) Return collections of word tokens and scores.
1119
#
1120
1121
class GreedySearchDecoder(nn.Module):
1122
def __init__(self, encoder, decoder):
1123
super(GreedySearchDecoder, self).__init__()
1124
self.encoder = encoder
1125
self.decoder = decoder
1126
1127
def forward(self, input_seq, input_length, max_length):
1128
# Forward input through encoder model
1129
encoder_outputs, encoder_hidden = self.encoder(input_seq, input_length)
1130
# Prepare encoder's final hidden layer to be first hidden input to the decoder
1131
decoder_hidden = encoder_hidden[:decoder.n_layers]
1132
# Initialize decoder input with SOS_token
1133
decoder_input = torch.ones(1, 1, device=device, dtype=torch.long) * SOS_token
1134
# Initialize tensors to append decoded words to
1135
all_tokens = torch.zeros([0], device=device, dtype=torch.long)
1136
all_scores = torch.zeros([0], device=device)
1137
# Iteratively decode one word token at a time
1138
for _ in range(max_length):
1139
# Forward pass through decoder
1140
decoder_output, decoder_hidden = self.decoder(decoder_input, decoder_hidden, encoder_outputs)
1141
# Obtain most likely word token and its softmax score
1142
decoder_scores, decoder_input = torch.max(decoder_output, dim=1)
1143
# Record token and score
1144
all_tokens = torch.cat((all_tokens, decoder_input), dim=0)
1145
all_scores = torch.cat((all_scores, decoder_scores), dim=0)
1146
# Prepare current token to be next decoder input (add a dimension)
1147
decoder_input = torch.unsqueeze(decoder_input, 0)
1148
# Return collections of word tokens and scores
1149
return all_tokens, all_scores
1150
1151
1152
######################################################################
1153
# Evaluate my text
1154
# ~~~~~~~~~~~~~~~~
1155
#
1156
# Now that we have our decoding method defined, we can write functions for
1157
# evaluating a string input sentence. The ``evaluate`` function manages
1158
# the low-level process of handling the input sentence. We first format
1159
# the sentence as an input batch of word indexes with *batch_size==1*. We
1160
# do this by converting the words of the sentence to their corresponding
1161
# indexes, and transposing the dimensions to prepare the tensor for our
1162
# models. We also create a ``lengths`` tensor which contains the length of
1163
# our input sentence. In this case, ``lengths`` is scalar because we are
1164
# only evaluating one sentence at a time (batch_size==1). Next, we obtain
1165
# the decoded response sentence tensor using our ``GreedySearchDecoder``
1166
# object (``searcher``). Finally, we convert the response’s indexes to
1167
# words and return the list of decoded words.
1168
#
1169
# ``evaluateInput`` acts as the user interface for our chatbot. When
1170
# called, an input text field will spawn in which we can enter our query
1171
# sentence. After typing our input sentence and pressing *Enter*, our text
1172
# is normalized in the same way as our training data, and is ultimately
1173
# fed to the ``evaluate`` function to obtain a decoded output sentence. We
1174
# loop this process, so we can keep chatting with our bot until we enter
1175
# either “q” or “quit”.
1176
#
1177
# Finally, if a sentence is entered that contains a word that is not in
1178
# the vocabulary, we handle this gracefully by printing an error message
1179
# and prompting the user to enter another sentence.
1180
#
1181
1182
def evaluate(encoder, decoder, searcher, voc, sentence, max_length=MAX_LENGTH):
1183
### Format input sentence as a batch
1184
# words -> indexes
1185
indexes_batch = [indexesFromSentence(voc, sentence)]
1186
# Create lengths tensor
1187
lengths = torch.tensor([len(indexes) for indexes in indexes_batch])
1188
# Transpose dimensions of batch to match models' expectations
1189
input_batch = torch.LongTensor(indexes_batch).transpose(0, 1)
1190
# Use appropriate device
1191
input_batch = input_batch.to(device)
1192
lengths = lengths.to("cpu")
1193
# Decode sentence with searcher
1194
tokens, scores = searcher(input_batch, lengths, max_length)
1195
# indexes -> words
1196
decoded_words = [voc.index2word[token.item()] for token in tokens]
1197
return decoded_words
1198
1199
1200
def evaluateInput(encoder, decoder, searcher, voc):
1201
input_sentence = ''
1202
while(1):
1203
try:
1204
# Get input sentence
1205
input_sentence = input('> ')
1206
# Check if it is quit case
1207
if input_sentence == 'q' or input_sentence == 'quit': break
1208
# Normalize sentence
1209
input_sentence = normalizeString(input_sentence)
1210
# Evaluate sentence
1211
output_words = evaluate(encoder, decoder, searcher, voc, input_sentence)
1212
# Format and print response sentence
1213
output_words[:] = [x for x in output_words if not (x == 'EOS' or x == 'PAD')]
1214
print('Bot:', ' '.join(output_words))
1215
1216
except KeyError:
1217
print("Error: Encountered unknown word.")
1218
1219
1220
######################################################################
1221
# Run Model
1222
# ---------
1223
#
1224
# Finally, it is time to run our model!
1225
#
1226
# Regardless of whether we want to train or test the chatbot model, we
1227
# must initialize the individual encoder and decoder models. In the
1228
# following block, we set our desired configurations, choose to start from
1229
# scratch or set a checkpoint to load from, and build and initialize the
1230
# models. Feel free to play with different model configurations to
1231
# optimize performance.
1232
#
1233
1234
# Configure models
1235
model_name = 'cb_model'
1236
attn_model = 'dot'
1237
#``attn_model = 'general'``
1238
#``attn_model = 'concat'``
1239
hidden_size = 500
1240
encoder_n_layers = 2
1241
decoder_n_layers = 2
1242
dropout = 0.1
1243
batch_size = 64
1244
1245
# Set checkpoint to load from; set to None if starting from scratch
1246
loadFilename = None
1247
checkpoint_iter = 4000
1248
1249
#############################################################
1250
# Sample code to load from a checkpoint:
1251
#
1252
# .. code-block:: python
1253
#
1254
# loadFilename = os.path.join(save_dir, model_name, corpus_name,
1255
# '{}-{}_{}'.format(encoder_n_layers, decoder_n_layers, hidden_size),
1256
# '{}_checkpoint.tar'.format(checkpoint_iter))
1257
1258
# Load model if a ``loadFilename`` is provided
1259
if loadFilename:
1260
# If loading on same machine the model was trained on
1261
checkpoint = torch.load(loadFilename)
1262
# If loading a model trained on GPU to CPU
1263
#checkpoint = torch.load(loadFilename, map_location=torch.device('cpu'))
1264
encoder_sd = checkpoint['en']
1265
decoder_sd = checkpoint['de']
1266
encoder_optimizer_sd = checkpoint['en_opt']
1267
decoder_optimizer_sd = checkpoint['de_opt']
1268
embedding_sd = checkpoint['embedding']
1269
voc.__dict__ = checkpoint['voc_dict']
1270
1271
1272
print('Building encoder and decoder ...')
1273
# Initialize word embeddings
1274
embedding = nn.Embedding(voc.num_words, hidden_size)
1275
if loadFilename:
1276
embedding.load_state_dict(embedding_sd)
1277
# Initialize encoder & decoder models
1278
encoder = EncoderRNN(hidden_size, embedding, encoder_n_layers, dropout)
1279
decoder = LuongAttnDecoderRNN(attn_model, embedding, hidden_size, voc.num_words, decoder_n_layers, dropout)
1280
if loadFilename:
1281
encoder.load_state_dict(encoder_sd)
1282
decoder.load_state_dict(decoder_sd)
1283
# Use appropriate device
1284
encoder = encoder.to(device)
1285
decoder = decoder.to(device)
1286
print('Models built and ready to go!')
1287
1288
1289
######################################################################
1290
# Run Training
1291
# ~~~~~~~~~~~~
1292
#
1293
# Run the following block if you want to train the model.
1294
#
1295
# First we set training parameters, then we initialize our optimizers, and
1296
# finally we call the ``trainIters`` function to run our training
1297
# iterations.
1298
#
1299
1300
# Configure training/optimization
1301
clip = 50.0
1302
teacher_forcing_ratio = 1.0
1303
learning_rate = 0.0001
1304
decoder_learning_ratio = 5.0
1305
n_iteration = 4000
1306
print_every = 1
1307
save_every = 500
1308
1309
# Ensure dropout layers are in train mode
1310
encoder.train()
1311
decoder.train()
1312
1313
# Initialize optimizers
1314
print('Building optimizers ...')
1315
encoder_optimizer = optim.Adam(encoder.parameters(), lr=learning_rate)
1316
decoder_optimizer = optim.Adam(decoder.parameters(), lr=learning_rate * decoder_learning_ratio)
1317
if loadFilename:
1318
encoder_optimizer.load_state_dict(encoder_optimizer_sd)
1319
decoder_optimizer.load_state_dict(decoder_optimizer_sd)
1320
1321
# If you have CUDA, configure CUDA to call
1322
for state in encoder_optimizer.state.values():
1323
for k, v in state.items():
1324
if isinstance(v, torch.Tensor):
1325
state[k] = v.cuda()
1326
1327
for state in decoder_optimizer.state.values():
1328
for k, v in state.items():
1329
if isinstance(v, torch.Tensor):
1330
state[k] = v.cuda()
1331
1332
# Run training iterations
1333
print("Starting Training!")
1334
trainIters(model_name, voc, pairs, encoder, decoder, encoder_optimizer, decoder_optimizer,
1335
embedding, encoder_n_layers, decoder_n_layers, save_dir, n_iteration, batch_size,
1336
print_every, save_every, clip, corpus_name, loadFilename)
1337
1338
1339
######################################################################
1340
# Run Evaluation
1341
# ~~~~~~~~~~~~~~
1342
#
1343
# To chat with your model, run the following block.
1344
#
1345
1346
# Set dropout layers to ``eval`` mode
1347
encoder.eval()
1348
decoder.eval()
1349
1350
# Initialize search module
1351
searcher = GreedySearchDecoder(encoder, decoder)
1352
1353
# Begin chatting (uncomment and run the following line to begin)
1354
# evaluateInput(encoder, decoder, searcher, voc)
1355
1356
1357
######################################################################
1358
# Conclusion
1359
# ----------
1360
#
1361
# That’s all for this one, folks. Congratulations, you now know the
1362
# fundamentals to building a generative chatbot model! If you’re
1363
# interested, you can try tailoring the chatbot’s behavior by tweaking the
1364
# model and training parameters and customizing the data that you train
1365
# the model on.
1366
#
1367
# Check out the other tutorials for more cool deep learning applications
1368
# in PyTorch!
1369
#
1370
1371