# -*- coding: utf-8 -*-
"""
Deploying a Seq2Seq Model with TorchScript
==================================================
**Author:** `Matthew Inkawhich <https://github.com/MatthewInkawhich>`_
"""


######################################################################
# This tutorial will walk through the process of transitioning a
# sequence-to-sequence model to TorchScript using the TorchScript
# API. The model that we will convert is the chatbot model from the
# `Chatbot tutorial <https://pytorch.org/tutorials/beginner/chatbot_tutorial.html>`__.
# You can either treat this tutorial as a “Part 2” to the Chatbot tutorial
# and deploy your own pretrained model, or you can start with this
# document and use a pretrained model that we host. In the latter case,
# you can reference the original Chatbot tutorial for details
# regarding data preprocessing, model theory and definition, and model
# training.
#
# What is TorchScript?
# ----------------------------
#
# During the research and development phase of a deep learning-based
# project, it is advantageous to interact with an **eager**, imperative
# interface like PyTorch’s. This gives users the ability to write
# familiar, idiomatic Python, allowing for the use of Python data
# structures, control flow operations, print statements, and debugging
# utilities. Although the eager interface is a beneficial tool for
# research and experimentation applications, when it comes time to deploy
# the model in a production environment, having a **graph**-based model
# representation is very beneficial. A deferred graph representation
# allows for optimizations such as out-of-order execution, and the ability
# to target highly optimized hardware architectures. Also, a graph-based
# representation enables framework-agnostic model exportation. PyTorch
# provides mechanisms for incrementally converting eager-mode code into
# TorchScript, a statically analyzable and optimizable subset of Python
# that Torch uses to represent deep learning programs independently from
# the Python runtime.
#
# The API for converting eager-mode PyTorch programs into TorchScript is
# found in the ``torch.jit`` module. This module has two core modalities for
# converting an eager-mode model to a TorchScript graph representation:
# **tracing** and **scripting**. The ``torch.jit.trace`` function takes a
# module or function and a set of example inputs. It then runs the example
# input through the function or module while tracing the computational
# steps that are encountered, and outputs a graph-based function that
# performs the traced operations. **Tracing** is great for straightforward
# modules and functions that do not involve data-dependent control flow,
# such as standard convolutional neural networks. However, if a function
# with data-dependent if statements and loops is traced, only the
# operations called along the execution route taken by the example input
# will be recorded. In other words, the control flow itself is not
# captured. To convert modules and functions containing data-dependent
# control flow, a **scripting** mechanism is provided. The
# ``torch.jit.script`` function/decorator takes a module or function and
# does not require example inputs. Scripting then explicitly converts
# the module or function code to TorchScript, including all control flow.
# One caveat with using scripting is that it only supports a subset of
# Python, so you might need to rewrite the code to make it compatible
# with the TorchScript syntax.
#
# For all details relating to the supported features, see the `TorchScript
# language reference <https://pytorch.org/docs/master/jit.html>`__.
# To provide the maximum flexibility, you can also mix tracing and scripting
# modes together to represent your whole program, and these techniques can
# be applied incrementally.
#
# .. figure:: /_static/img/chatbot/pytorch_workflow.png
#    :align: center
#    :alt: workflow
#
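#
# As a quick illustration of the difference, here is a minimal sketch of how
# the two modes treat data-dependent control flow (the ``choose`` function is
# a hypothetical toy example, not part of this tutorial’s model):
#
# .. code-block:: python
#
#    import torch
#
#    def choose(x):
#        # Data-dependent control flow: which branch runs depends on the input values
#        if x.sum() > 0:
#            return x * 2
#        return x - 1
#
#    # Tracing records only the operations executed for this particular example
#    # input, so the ``else`` branch is absent from the resulting graph
#    traced = torch.jit.trace(choose, (torch.ones(3),))
#
#    # Scripting compiles the function’s source directly and preserves both branches
#    scripted = torch.jit.script(choose)
#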



######################################################################
# Acknowledgments
# ----------------
#
# This tutorial was inspired by the following sources:
#
# 1) Yuan-Kuei Wu's pytorch-chatbot implementation:
#    https://github.com/ywk991112/pytorch-chatbot
#
# 2) Sean Robertson's practical-pytorch seq2seq-translation example:
#    https://github.com/spro/practical-pytorch/tree/master/seq2seq-translation
#
# 3) FloydHub's Cornell Movie Corpus preprocessing code:
#    https://github.com/floydhub/textutil-preprocess-cornell-movie-corpus
#


######################################################################
# Prepare Environment
# -------------------
#
# First, we will import the required modules and set some constants. If
# you are planning on using your own model, be sure that the
# ``MAX_LENGTH`` constant is set correctly. As a reminder, this constant
# defines the maximum allowed sentence length during training and the
# maximum length output that the model is capable of producing.
#

import torch
import torch.nn as nn
import torch.nn.functional as F
import re
import os
import unicodedata
import numpy as np

device = torch.device("cpu")


MAX_LENGTH = 10  # Maximum sentence length

# Default word tokens
PAD_token = 0  # Used for padding short sentences
SOS_token = 1  # Start-of-sentence token
EOS_token = 2  # End-of-sentence token


######################################################################
# Model Overview
# --------------
#
# As mentioned, the model that we are using is a
# `sequence-to-sequence <https://arxiv.org/abs/1409.3215>`__ (seq2seq)
# model. This type of model is used in cases when our input is a
# variable-length sequence, and our output is also a variable-length
# sequence that is not necessarily a one-to-one mapping of the input. A
# seq2seq model is comprised of two recurrent neural networks (RNNs) that
# work cooperatively: an **encoder** and a **decoder**.
#
# .. figure:: /_static/img/chatbot/seq2seq_ts.png
#    :align: center
#    :alt: model
#
#
# Image source:
# https://jeddy92.github.io/JEddy92.github.io/ts_seq2seq_intro/
#
# Encoder
# ~~~~~~~
#
# The encoder RNN iterates through the input sentence one token
# (e.g. word) at a time, at each time step outputting an “output” vector
# and a “hidden state” vector. The hidden state vector is then passed to
# the next time step, while the output vector is recorded. The encoder
# transforms the context it saw at each point in the sequence into a set
# of points in a high-dimensional space, which the decoder will use to
# generate a meaningful output for the given task.
#
# Decoder
# ~~~~~~~
#
# The decoder RNN generates the response sentence in a token-by-token
# fashion. It uses the encoder’s context vectors and internal hidden
# states to generate the next word in the sequence. It continues
# generating words until it outputs an *EOS_token*, representing the end
# of the sentence. We use an `attention
# mechanism <https://arxiv.org/abs/1409.0473>`__ in our decoder to help it
# to “pay attention” to certain parts of the input when generating the
# output. For our model, we implement `Luong et
# al. <https://arxiv.org/abs/1508.04025>`__\ ’s “Global attention” module,
# and use it as a submodule in our decoder model.
#


######################################################################
# Data Handling
# -------------
#
# Although our models conceptually deal with sequences of tokens, in
# reality, they deal with numbers like all machine learning models do. In
# this case, every word in the model’s vocabulary, which was established
# before training, is mapped to an integer index. We use a ``Voc`` object
# to contain the mappings from word to index, as well as the total number
# of words in the vocabulary. We will load the object later before we run
# the model.
#
# Also, in order for us to be able to run evaluations, we must provide a
# tool for processing our string inputs. The ``normalizeString`` function
# converts all characters in a string to lowercase and removes all
# non-letter characters. The ``indexesFromSentence`` function takes a
# sentence of words and returns the corresponding sequence of word
# indexes.
#

class Voc:
    def __init__(self, name):
        self.name = name
        self.trimmed = False
        self.word2index = {}
        self.word2count = {}
        self.index2word = {PAD_token: "PAD", SOS_token: "SOS", EOS_token: "EOS"}
        self.num_words = 3  # Count SOS, EOS, PAD

    def addSentence(self, sentence):
        for word in sentence.split(' '):
            self.addWord(word)

    def addWord(self, word):
        if word not in self.word2index:
            self.word2index[word] = self.num_words
            self.word2count[word] = 1
            self.index2word[self.num_words] = word
            self.num_words += 1
        else:
            self.word2count[word] += 1

    # Remove words below a certain count threshold
    def trim(self, min_count):
        if self.trimmed:
            return
        self.trimmed = True
        keep_words = []
        for k, v in self.word2count.items():
            if v >= min_count:
                keep_words.append(k)

        print('keep_words {} / {} = {:.4f}'.format(
            len(keep_words), len(self.word2index), len(keep_words) / len(self.word2index)
        ))
        # Reinitialize dictionaries
        self.word2index = {}
        self.word2count = {}
        self.index2word = {PAD_token: "PAD", SOS_token: "SOS", EOS_token: "EOS"}
        self.num_words = 3  # Count default tokens
        for word in keep_words:
            self.addWord(word)


# Lowercase and remove non-letter characters
def normalizeString(s):
    s = s.lower()
    s = re.sub(r"([.!?])", r" \1", s)
    s = re.sub(r"[^a-zA-Z.!?]+", r" ", s)
    return s


# Takes string sentence, returns sentence of word indexes
def indexesFromSentence(voc, sentence):
    return [voc.word2index[word] for word in sentence.split(' ')] + [EOS_token]
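
# As a quick sanity check of these helpers (the vocabulary and sentences here
# are made up purely for illustration), one could run something like:
#
# .. code-block:: python
#
#    voc = Voc("example")
#    voc.addSentence("hello how are you")
#    print(normalizeString("Hello!"))              # -> "hello !"
#    print(indexesFromSentence(voc, "hello how"))  # -> [3, 4, 2]  (2 is the EOS_token)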


######################################################################
# Define Encoder
# --------------
#
# We implement our encoder’s RNN with the ``torch.nn.GRU`` module, to which
# we feed a batch of sentences (vectors of word embeddings); it internally
# iterates through the sentences one token at a time, calculating the
# hidden states. We initialize this module to be bidirectional, meaning
# that we have two independent GRUs: one that iterates through the
# sequences in chronological order, and another that iterates in reverse
# order. We ultimately return the sum of these two GRUs’ outputs. Since
# our model was trained using batching, our ``EncoderRNN`` model’s
# ``forward`` function expects a padded input batch. To batch
# variable-length sentences, we allow a maximum of *MAX_LENGTH* tokens in
# a sentence, and all sentences in the batch that have fewer than
# *MAX_LENGTH* tokens are padded at the end with our dedicated *PAD_token*
# tokens. To use padded batches with a PyTorch RNN module, we must wrap
# the forward pass call with ``torch.nn.utils.rnn.pack_padded_sequence``
# and ``torch.nn.utils.rnn.pad_packed_sequence`` data transformations.
# Note that the ``forward`` function also takes an ``input_lengths`` list,
# which contains the length of each sentence in the batch. This input is
# used by the ``torch.nn.utils.rnn.pack_padded_sequence`` function when
# packing the padded batch.
#
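# As a rough illustration of what the packing and unpacking transformations
# do (a standalone sketch with a made-up, already-padded batch; it is not
# part of the tutorial pipeline):
#
# .. code-block:: python
#
#    import torch
#
#    # A padded batch of 2 sequences with max length 3 and feature size 1;
#    # the second sequence (length 2) is padded with a trailing zero
#    padded = torch.tensor([[[1.], [4.]],
#                           [[2.], [5.]],
#                           [[3.], [0.]]])
#    lengths = torch.tensor([3, 2])
#
#    packed = torch.nn.utils.rnn.pack_padded_sequence(padded, lengths)
#    # ... run an RNN module over ``packed`` here ...
#    unpacked, unpacked_lengths = torch.nn.utils.rnn.pad_packed_sequence(packed)
#    # ``unpacked`` has shape (3, 2, 1) again and ``unpacked_lengths`` is [3, 2]
#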
# TorchScript Notes:
# ~~~~~~~~~~~~~~~~~~~~~~
#
# Since the encoder’s ``forward`` function does not contain any
# data-dependent control flow, we will use **tracing** to convert it to
# TorchScript. When tracing a module, we can leave the module definition
# as-is. We will initialize all models towards the end of this document
# before we run evaluations.
#

class EncoderRNN(nn.Module):
    def __init__(self, hidden_size, embedding, n_layers=1, dropout=0):
        super(EncoderRNN, self).__init__()
        self.n_layers = n_layers
        self.hidden_size = hidden_size
        self.embedding = embedding

        # Initialize GRU; the ``input_size`` and ``hidden_size`` parameters are both set to 'hidden_size'
        # because our input size is a word embedding with number of features == hidden_size
        self.gru = nn.GRU(hidden_size, hidden_size, n_layers,
                          dropout=(0 if n_layers == 1 else dropout), bidirectional=True)

    def forward(self, input_seq, input_lengths, hidden=None):
        # type: (Tensor, Tensor, Optional[Tensor]) -> Tuple[Tensor, Tensor]
        # Convert word indexes to embeddings
        embedded = self.embedding(input_seq)
        # Pack padded batch of sequences for RNN module
        packed = torch.nn.utils.rnn.pack_padded_sequence(embedded, input_lengths)
        # Forward pass through GRU
        outputs, hidden = self.gru(packed, hidden)
        # Unpack padding
        outputs, _ = torch.nn.utils.rnn.pad_packed_sequence(outputs)
        # Sum bidirectional GRU outputs
        outputs = outputs[:, :, :self.hidden_size] + outputs[:, :, self.hidden_size:]
        # Return output and final hidden state
        return outputs, hidden


######################################################################
# Define Decoder’s Attention Module
# ---------------------------------
#
# Next, we’ll define our attention module (``Attn``). Note that this
# module will be used as a submodule in our decoder model. Luong et
# al. consider various “score functions”, which take the current decoder
# RNN output and the entire encoder output, and return attention
# “energies”. This attention energies tensor is the same size as the
# encoder output, and the two are ultimately multiplied, resulting in a
# weighted tensor whose largest values represent the most important parts
# of the query sentence at a particular time-step of decoding.
#
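# For reference, the three score functions from Luong et al. that the
# module below implements can be written as follows, with decoder hidden
# state :math:`h_t`, encoder output :math:`\bar{h}_s`, and learned
# parameters :math:`W_a` and :math:`v_a`:
#
# .. math::
#
#    \text{score}(h_t, \bar{h}_s) =
#    \begin{cases}
#    h_t^\top \bar{h}_s & \text{dot} \\
#    h_t^\top W_a \bar{h}_s & \text{general} \\
#    v_a^\top \tanh\left(W_a [h_t ; \bar{h}_s]\right) & \text{concat}
#    \end{cases}
#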

# Luong attention layer
class Attn(nn.Module):
    def __init__(self, method, hidden_size):
        super(Attn, self).__init__()
        self.method = method
        if self.method not in ['dot', 'general', 'concat']:
            raise ValueError(self.method, "is not an appropriate attention method.")
        self.hidden_size = hidden_size
        if self.method == 'general':
            self.attn = nn.Linear(self.hidden_size, hidden_size)
        elif self.method == 'concat':
            self.attn = nn.Linear(self.hidden_size * 2, hidden_size)
            self.v = nn.Parameter(torch.FloatTensor(hidden_size))

    def dot_score(self, hidden, encoder_output):
        return torch.sum(hidden * encoder_output, dim=2)

    def general_score(self, hidden, encoder_output):
        energy = self.attn(encoder_output)
        return torch.sum(hidden * energy, dim=2)

    def concat_score(self, hidden, encoder_output):
        energy = self.attn(torch.cat((hidden.expand(encoder_output.size(0), -1, -1), encoder_output), 2)).tanh()
        return torch.sum(self.v * energy, dim=2)

    def forward(self, hidden, encoder_outputs):
        # Calculate the attention weights (energies) based on the given method
        if self.method == 'general':
            attn_energies = self.general_score(hidden, encoder_outputs)
        elif self.method == 'concat':
            attn_energies = self.concat_score(hidden, encoder_outputs)
        elif self.method == 'dot':
            attn_energies = self.dot_score(hidden, encoder_outputs)

        # Transpose max_length and batch_size dimensions
        attn_energies = attn_energies.t()

        # Return the softmax normalized probability scores (with added dimension)
        return F.softmax(attn_energies, dim=1).unsqueeze(1)


######################################################################
# Define Decoder
# --------------
#
# Similarly to the ``EncoderRNN``, we use the ``torch.nn.GRU`` module for
# our decoder’s RNN. This time, however, we use a unidirectional GRU. It
# is important to note that unlike the encoder, we will feed the decoder
# RNN one word at a time. We start by getting the embedding of the current
# word and applying a
# `dropout <https://pytorch.org/docs/stable/nn.html?highlight=dropout#torch.nn.Dropout>`__.
# Next, we forward the embedding and the last hidden state to the GRU and
# obtain a current GRU output and hidden state. We then use our ``Attn``
# module as a layer to obtain the attention weights, which we multiply by
# the encoder’s output to obtain our attended encoder output. We use this
# attended encoder output as our ``context`` tensor, which represents a
# weighted sum indicating what parts of the encoder’s output to pay
# attention to. From here, we use a linear layer and softmax normalization
# to select the next word in the output sequence.
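#
# For reference, these last two steps correspond to equations 5 and 6 of
# Luong et al.’s “Global attention” formulation, with attentional context
# :math:`c_t` and current GRU output :math:`h_t`:
#
# .. math::
#
#    \tilde{h}_t = \tanh\left(W_c [c_t ; h_t]\right)
#
# .. math::
#
#    p(y_t \mid y_{<t}, x) = \mathrm{softmax}\left(W_s \tilde{h}_t\right)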
#
# TorchScript Notes:
# ~~~~~~~~~~~~~~~~~~~~~~
#
# Similarly to the ``EncoderRNN``, this module does not contain any
# data-dependent control flow. Therefore, we can once again use
# **tracing** to convert this model to TorchScript after it
# is initialized and its parameters are loaded.
#

class LuongAttnDecoderRNN(nn.Module):
    def __init__(self, attn_model, embedding, hidden_size, output_size, n_layers=1, dropout=0.1):
        super(LuongAttnDecoderRNN, self).__init__()

        # Keep for reference
        self.attn_model = attn_model
        self.hidden_size = hidden_size
        self.output_size = output_size
        self.n_layers = n_layers
        self.dropout = dropout

        # Define layers
        self.embedding = embedding
        self.embedding_dropout = nn.Dropout(dropout)
        self.gru = nn.GRU(hidden_size, hidden_size, n_layers, dropout=(0 if n_layers == 1 else dropout))
        self.concat = nn.Linear(hidden_size * 2, hidden_size)
        self.out = nn.Linear(hidden_size, output_size)

        self.attn = Attn(attn_model, hidden_size)

    def forward(self, input_step, last_hidden, encoder_outputs):
        # Note: we run this one step (word) at a time
        # Get embedding of current input word
        embedded = self.embedding(input_step)
        embedded = self.embedding_dropout(embedded)
        # Forward through unidirectional GRU
        rnn_output, hidden = self.gru(embedded, last_hidden)
        # Calculate attention weights from the current GRU output
        attn_weights = self.attn(rnn_output, encoder_outputs)
        # Multiply attention weights to encoder outputs to get new "weighted sum" context vector
        context = attn_weights.bmm(encoder_outputs.transpose(0, 1))
        # Concatenate weighted context vector and GRU output using Luong eq. 5
        rnn_output = rnn_output.squeeze(0)
        context = context.squeeze(1)
        concat_input = torch.cat((rnn_output, context), 1)
        concat_output = torch.tanh(self.concat(concat_input))
        # Predict next word using Luong eq. 6
        output = self.out(concat_output)
        output = F.softmax(output, dim=1)
        # Return output and final hidden state
        return output, hidden


######################################################################
# Define Evaluation
# -----------------
#
# Greedy Search Decoder
# ~~~~~~~~~~~~~~~~~~~~~
#
# As in the chatbot tutorial, we use a ``GreedySearchDecoder`` module to
# facilitate the actual decoding process. This module has the trained
# encoder and decoder models as attributes, and drives the process of
# encoding an input sentence (a vector of word indexes), and iteratively
# decoding an output response sequence one word (word index) at a time.
#
# Encoding the input sequence is straightforward: simply forward the
# entire sequence tensor and its corresponding lengths vector to the
# ``encoder``. It is important to note that this module only deals with
# one input sequence at a time, **NOT** batches of sequences. Therefore,
# when the constant **1** is used for declaring tensor sizes, this
# corresponds to a batch size of 1. To decode a given decoder output, we
# must iteratively run forward passes through our decoder model, which
# outputs softmax scores corresponding to the probability of each word
# being the correct next word in the decoded sequence. We initialize the
# ``decoder_input`` to a tensor containing an *SOS_token*. After each pass
# through the ``decoder``, we *greedily* append the word with the highest
# softmax probability to the ``decoded_words`` list. We also use this word
# as the ``decoder_input`` for the next iteration. The decoding process
# terminates either if the ``decoded_words`` list has reached a length of
# *MAX_LENGTH* or if the predicted word is the *EOS_token*.
#
# TorchScript Notes:
# ~~~~~~~~~~~~~~~~~~~~~~
#
# The ``forward`` method of this module involves iterating over the range
# of :math:`[0, max\_length)` when decoding an output sequence one word at
# a time. Because of this, we should use **scripting** to convert this
# module to TorchScript. Unlike with our encoder and decoder models,
# which we can trace, we must make some necessary changes to the
# ``GreedySearchDecoder`` module in order to initialize an object without
# error. In other words, we must ensure that our module adheres to the
# rules of the TorchScript mechanism, and does not utilize any language
# features outside of the subset of Python that TorchScript includes.
#
# To get an idea of some manipulations that may be required, we will go
# over the diffs between the ``GreedySearchDecoder`` implementation from
# the chatbot tutorial and the implementation that we use in the cell
# below. Note that the lines highlighted in red are lines removed from the
# original implementation and the lines highlighted in green are new.
#
# .. figure:: /_static/img/chatbot/diff.png
#    :align: center
#    :alt: diff
#
# Changes:
# ^^^^^^^^
#
# - Added ``decoder_n_layers`` to the constructor arguments
#
#   - This change stems from the fact that the encoder and decoder
#     models that we pass to this module will be children of
#     ``TracedModule`` (not ``Module``). Therefore, we cannot access the
#     decoder’s number of layers with ``decoder.n_layers``. Instead, we
#     plan for this, and pass this value in during module construction.
#
#
# - Store away new attributes as constants
#
#   - In the original implementation, we were free to use variables from
#     the surrounding (global) scope in our ``GreedySearchDecoder``\ ’s
#     ``forward`` method. However, now that we are using scripting, we
#     do not have this freedom, as the assumption with scripting is that
#     we cannot necessarily hold on to Python objects, especially when
#     exporting. An easy solution to this is to store these values from
#     the global scope as attributes to the module in the constructor,
#     and add them to a special list called ``__constants__`` so that
#     they can be used as literal values when constructing the graph in
#     the ``forward`` method. An example of this usage is on NEW line
#     19, where instead of using the ``device`` and ``SOS_token`` global
#     values, we use our constant attributes ``self._device`` and
#     ``self._SOS_token``.
#
#
# - Enforce types of ``forward`` method arguments
#
#   - By default, all parameters to a TorchScript function are assumed
#     to be Tensor. If we need to pass an argument of a different type,
#     we can use function type annotations as introduced in `PEP
#     3107 <https://www.python.org/dev/peps/pep-3107/>`__. In addition,
#     it is possible to declare arguments of different types using
#     Mypy-style type annotations (see the
#     `documentation <https://pytorch.org/docs/master/jit.html#types>`__).
#
#
# - Change initialization of ``decoder_input``
#
#   - In the original implementation, we initialized our
#     ``decoder_input`` tensor with ``torch.LongTensor([[SOS_token]])``.
#     When scripting, we are not allowed to initialize tensors in a
#     literal fashion like this. Instead, we can initialize our tensor
#     with an explicit torch function such as ``torch.ones``. In this
#     case, we can easily replicate the scalar ``decoder_input`` tensor
#     by multiplying 1 by our SOS_token value stored in the constant
#     ``self._SOS_token``.
#
539
class GreedySearchDecoder(nn.Module):
540
def __init__(self, encoder, decoder, decoder_n_layers):
541
super(GreedySearchDecoder, self).__init__()
542
self.encoder = encoder
543
self.decoder = decoder
544
self._device = device
545
self._SOS_token = SOS_token
546
self._decoder_n_layers = decoder_n_layers
547
548
__constants__ = ['_device', '_SOS_token', '_decoder_n_layers']
549
550
def forward(self, input_seq : torch.Tensor, input_length : torch.Tensor, max_length : int):
551
# Forward input through encoder model
552
encoder_outputs, encoder_hidden = self.encoder(input_seq, input_length)
553
# Prepare encoder's final hidden layer to be first hidden input to the decoder
554
decoder_hidden = encoder_hidden[:self._decoder_n_layers]
555
# Initialize decoder input with SOS_token
556
decoder_input = torch.ones(1, 1, device=self._device, dtype=torch.long) * self._SOS_token
557
# Initialize tensors to append decoded words to
558
all_tokens = torch.zeros([0], device=self._device, dtype=torch.long)
559
all_scores = torch.zeros([0], device=self._device)
560
# Iteratively decode one word token at a time
561
for _ in range(max_length):
562
# Forward pass through decoder
563
decoder_output, decoder_hidden = self.decoder(decoder_input, decoder_hidden, encoder_outputs)
564
# Obtain most likely word token and its softmax score
565
decoder_scores, decoder_input = torch.max(decoder_output, dim=1)
566
# Record token and score
567
all_tokens = torch.cat((all_tokens, decoder_input), dim=0)
568
all_scores = torch.cat((all_scores, decoder_scores), dim=0)
569
# Prepare current token to be next decoder input (add a dimension)
570
decoder_input = torch.unsqueeze(decoder_input, 0)
571
# Return collections of word tokens and scores
572
return all_tokens, all_scores



######################################################################
# Evaluating an Input
# ~~~~~~~~~~~~~~~~~~~
#
# Next, we define some functions for evaluating an input. The ``evaluate``
# function takes a normalized string sentence, processes it to a tensor of
# its corresponding word indexes (with batch size of 1), and passes this
# tensor to a ``GreedySearchDecoder`` instance called ``searcher`` to
# handle the encoding/decoding process. The searcher returns the output
# word index vector and a scores tensor corresponding to the softmax
# scores for each decoded word token. The final step is to convert each
# word index back to its string representation using ``voc.index2word``.
#
# We also define two functions for evaluating an input sentence. The
# ``evaluateInput`` function prompts a user for an input, and evaluates
# it. It will continue to ask for another input until the user enters ‘q’
# or ‘quit’.
#
# The ``evaluateExample`` function simply takes a string input sentence as
# an argument, normalizes it, evaluates it, and prints the response.
#

def evaluate(searcher, voc, sentence, max_length=MAX_LENGTH):
    ### Format input sentence as a batch
    # words -> indexes
    indexes_batch = [indexesFromSentence(voc, sentence)]
    # Create lengths tensor
    lengths = torch.tensor([len(indexes) for indexes in indexes_batch])
    # Transpose dimensions of batch to match models' expectations
    input_batch = torch.LongTensor(indexes_batch).transpose(0, 1)
    # Use appropriate device
    input_batch = input_batch.to(device)
    lengths = lengths.to(device)
    # Decode sentence with searcher
    tokens, scores = searcher(input_batch, lengths, max_length)
    # indexes -> words
    decoded_words = [voc.index2word[token.item()] for token in tokens]
    return decoded_words


# Evaluate inputs from user input (``stdin``)
def evaluateInput(searcher, voc):
    input_sentence = ''
    while(1):
        try:
            # Get input sentence
            input_sentence = input('> ')
            # Check if it is quit case
            if input_sentence == 'q' or input_sentence == 'quit': break
            # Normalize sentence
            input_sentence = normalizeString(input_sentence)
            # Evaluate sentence
            output_words = evaluate(searcher, voc, input_sentence)
            # Format and print response sentence
            output_words[:] = [x for x in output_words if not (x == 'EOS' or x == 'PAD')]
            print('Bot:', ' '.join(output_words))

        except KeyError:
            print("Error: Encountered unknown word.")


# Normalize input sentence and call ``evaluate()``
def evaluateExample(sentence, searcher, voc):
    print("> " + sentence)
    # Normalize sentence
    input_sentence = normalizeString(sentence)
    # Evaluate sentence
    output_words = evaluate(searcher, voc, input_sentence)
    output_words[:] = [x for x in output_words if not (x == 'EOS' or x == 'PAD')]
    print('Bot:', ' '.join(output_words))


######################################################################
# Load Pretrained Parameters
# --------------------------
#
# Now, let's load our model!
#
# Use hosted model
# ~~~~~~~~~~~~~~~~
#
# To load the hosted model:
#
# 1) Download the model `here <https://download.pytorch.org/models/tutorials/4000_checkpoint.tar>`__.
#
# 2) Set the ``loadFilename`` variable to the path to the downloaded
#    checkpoint file.
#
# 3) Leave the ``checkpoint = torch.load(loadFilename)`` line uncommented,
#    as the hosted model was trained on CPU.
#
# Use your own model
# ~~~~~~~~~~~~~~~~~~
#
# To load your own pretrained model:
#
# 1) Set the ``loadFilename`` variable to the path to the checkpoint file
#    that you wish to load. Note that if you followed the convention for
#    saving the model from the chatbot tutorial, this may involve changing
#    the ``model_name``, ``encoder_n_layers``, ``decoder_n_layers``,
#    ``hidden_size``, and ``checkpoint_iter`` (as these values are used in
#    the model path).
#
# 2) If you trained the model on a CPU, make sure that you are opening the
#    checkpoint with the ``checkpoint = torch.load(loadFilename)`` line.
#    If you trained the model on a GPU and are running this tutorial on a
#    CPU, uncomment the
#    ``checkpoint = torch.load(loadFilename, map_location=torch.device('cpu'))``
#    line.
#
# TorchScript Notes:
# ~~~~~~~~~~~~~~~~~~~~~~
#
# Notice that we initialize and load parameters into our encoder and
# decoder models as usual. If you are using tracing mode (``torch.jit.trace``)
# for some part of your models, you must call ``.to(device)`` to set the device
# options of the models and ``.eval()`` to set the dropout layers to test mode
# **before** tracing the models. ``TracedModule`` objects do not inherit the
# ``to`` or ``eval`` methods. Since we trace the encoder and decoder below, we
# make sure to do this right after loading the parameters and before tracing
# (which is the same as we normally do before running evaluation in eager mode).
#

save_dir = os.path.join("data", "save")
corpus_name = "cornell movie-dialogs corpus"

# Configure models
model_name = 'cb_model'
attn_model = 'dot'
#attn_model = 'general'
#attn_model = 'concat'
hidden_size = 500
encoder_n_layers = 2
decoder_n_layers = 2
dropout = 0.1
batch_size = 64

# If you're loading your own model
# Set checkpoint to load from
checkpoint_iter = 4000

#############################################################
# Sample code to load from a checkpoint:
#
# .. code-block:: python
#
#    loadFilename = os.path.join(save_dir, model_name, corpus_name,
#                                '{}-{}_{}'.format(encoder_n_layers, decoder_n_layers, hidden_size),
#                                '{}_checkpoint.tar'.format(checkpoint_iter))

# If you're loading the hosted model
loadFilename = 'data/4000_checkpoint.tar'

# Load model
# Force CPU device options (to match tensors in this tutorial)
checkpoint = torch.load(loadFilename, map_location=torch.device('cpu'))
encoder_sd = checkpoint['en']
decoder_sd = checkpoint['de']
encoder_optimizer_sd = checkpoint['en_opt']
decoder_optimizer_sd = checkpoint['de_opt']
embedding_sd = checkpoint['embedding']
voc = Voc(corpus_name)
voc.__dict__ = checkpoint['voc_dict']


print('Building encoder and decoder ...')
# Initialize word embeddings
embedding = nn.Embedding(voc.num_words, hidden_size)
embedding.load_state_dict(embedding_sd)
# Initialize encoder & decoder models
encoder = EncoderRNN(hidden_size, embedding, encoder_n_layers, dropout)
decoder = LuongAttnDecoderRNN(attn_model, embedding, hidden_size, voc.num_words, decoder_n_layers, dropout)
# Load trained model parameters
encoder.load_state_dict(encoder_sd)
decoder.load_state_dict(decoder_sd)
# Use appropriate device
encoder = encoder.to(device)
decoder = decoder.to(device)
# Set dropout layers to ``eval`` mode
encoder.eval()
decoder.eval()
print('Models built and ready to go!')


######################################################################
# Convert Model to TorchScript
# -----------------------------
#
# Encoder
# ~~~~~~~
#
# As previously mentioned, to convert the encoder model to TorchScript,
# we use **tracing**. The encoder model takes an input sequence and
# a corresponding lengths tensor. Therefore, we create an example input
# sequence tensor ``test_seq``, which is of appropriate size (MAX_LENGTH,
# 1), contains numbers in the appropriate range
# :math:`[0, voc.num\_words)`, and is of the appropriate type (int64). We
# also create a ``test_seq_length`` scalar which realistically contains
# the value corresponding to how many words are in the ``test_seq``. The
# next step is to use the ``torch.jit.trace`` function to trace the model.
# Notice that the first argument we pass is the module that we want to
# trace, and the second is a tuple of arguments to the module’s
# ``forward`` method.
#
# Decoder
# ~~~~~~~
#
# We perform the same process for tracing the decoder as we did for the
# encoder. Notice that we call ``forward`` on a set of random inputs to the
# ``traced_encoder`` to get the output that we need for the decoder. This is
# not required, as we could also simply manufacture a tensor of the
# correct shape, type, and value range. This method is possible because in
# our case we do not have any constraints on the values of the tensors,
# since we do not have any operations that could fault on out-of-range
# inputs.
#
# GreedySearchDecoder
# ~~~~~~~~~~~~~~~~~~~
#
# Recall that we script our searcher module due to the presence of
# data-dependent control flow. In the case of scripting, we make the
# necessary language changes to ensure that the implementation complies
# with TorchScript. We initialize the scripted searcher the same way that
# we would initialize an unscripted variant.
#

### Compile the whole greedy search model to a TorchScript model
# Create artificial inputs
test_seq = torch.LongTensor(MAX_LENGTH, 1).random_(0, voc.num_words).to(device)
test_seq_length = torch.LongTensor([test_seq.size()[0]]).to(device)
# Trace the model
traced_encoder = torch.jit.trace(encoder, (test_seq, test_seq_length))

### Convert decoder model
# Create and generate artificial inputs
test_encoder_outputs, test_encoder_hidden = traced_encoder(test_seq, test_seq_length)
test_decoder_hidden = test_encoder_hidden[:decoder.n_layers]
test_decoder_input = torch.LongTensor(1, 1).random_(0, voc.num_words)
# Trace the model
traced_decoder = torch.jit.trace(decoder, (test_decoder_input, test_decoder_hidden, test_encoder_outputs))

### Initialize searcher module by wrapping ``torch.jit.script`` call
scripted_searcher = torch.jit.script(GreedySearchDecoder(traced_encoder, traced_decoder, decoder.n_layers))


######################################################################
# Print Graphs
# ------------
#
# Now that our models are in TorchScript form, we can print the graphs of
# each to ensure that we captured the computational graph appropriately.
# Since TorchScript allows us to recursively compile the whole model
# hierarchy and inline the ``encoder`` and ``decoder`` graphs into a single
# graph, we just need to print the ``scripted_searcher`` graph.

print('scripted_searcher graph:\n', scripted_searcher.graph)
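
# The printed graph can be fairly low-level. TorchScript modules also expose a
# ``code`` attribute that shows a more readable, Python-like form of the same
# program; printing it is optional and not part of the original pipeline:
#
# .. code-block:: python
#
#    print('scripted_searcher code:\n', scripted_searcher.code)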


######################################################################
# Run Evaluation
# --------------
#
# Finally, we will run evaluation of the chatbot model using the TorchScript
# models. If converted correctly, the models will behave exactly as they
# would in their eager-mode representation.
#
# By default, we evaluate a few common query sentences. If you want to
# chat with the bot yourself, uncomment the ``evaluateInput`` line and
# give it a spin.
#


# Use appropriate device
scripted_searcher.to(device)
# Set dropout layers to ``eval`` mode
scripted_searcher.eval()

# Evaluate examples
sentences = ["hello", "what's up?", "who are you?", "where am I?", "where are you from?"]
for s in sentences:
    evaluateExample(s, scripted_searcher, voc)

# Evaluate your input by running
# ``evaluateInput(scripted_searcher, voc)``


######################################################################
# Save Model
# ----------
#
# Now that we have successfully converted our model to TorchScript, we
# will serialize it for use in a non-Python deployment environment. To do
# this, we can simply save our ``scripted_searcher`` module, as this is
# the user-facing interface for running inference against the chatbot
# model. When saving a Script module, use ``script_module.save(PATH)``
# instead of ``torch.save(model, PATH)``.
#

scripted_searcher.save("scripted_chatbot.pth")
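
# A saved TorchScript module can later be loaded back without the original
# model definitions, in Python with ``torch.jit.load`` or in C++ with
# ``torch::jit::load``. A minimal sketch of reloading and reusing the saved
# searcher (the call to ``evaluate`` assumes the ``voc`` object from this
# tutorial is still available):
#
# .. code-block:: python
#
#    loaded_searcher = torch.jit.load("scripted_chatbot.pth")
#    loaded_searcher.eval()
#    print(evaluate(loaded_searcher, voc, normalizeString("hello")))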