Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
Download

📚 The CoCalc Library - books, templates and other resources

132928 views
License: OTHER
1
""" A neural chatbot using sequence to sequence model with
2
attentional decoder.
3
4
This is based on Google Translate Tensorflow model
5
https://github.com/tensorflow/models/blob/master/tutorials/rnn/translate/
6
7
Sequence to sequence model by Cho et al.(2014)
8
9
Created by Chip Huyen as the starter code for assignment 3,
10
class CS 20SI: "TensorFlow for Deep Learning Research"
11
cs20si.stanford.edu
12
13
This file contains the code to run the model.
14
15
See readme.md for instructions on how to run the starter code.
16
"""
17
from __future__ import division
18
from __future__ import print_function
19
20
import argparse
21
import os
22
os.environ['TF_CPP_MIN_LOG_LEVEL']='2'
23
import random
24
import sys
25
import time
26
27
import numpy as np
28
import tensorflow as tf
29
30
from model import ChatBotModel
31
import config
32
import data
33
34
def _get_random_bucket(train_buckets_scale):
35
""" Get a random bucket from which to choose a training sample """
36
rand = random.random()
37
return min([i for i in range(len(train_buckets_scale))
38
if train_buckets_scale[i] > rand])
39
40
def _assert_lengths(encoder_size, decoder_size, encoder_inputs, decoder_inputs, decoder_masks):
41
""" Assert that the encoder inputs, decoder inputs, and decoder masks are
42
of the expected lengths """
43
if len(encoder_inputs) != encoder_size:
44
raise ValueError("Encoder length must be equal to the one in bucket,"
45
" %d != %d." % (len(encoder_inputs), encoder_size))
46
if len(decoder_inputs) != decoder_size:
47
raise ValueError("Decoder length must be equal to the one in bucket,"
48
" %d != %d." % (len(decoder_inputs), decoder_size))
49
if len(decoder_masks) != decoder_size:
50
raise ValueError("Weights length must be equal to the one in bucket,"
51
" %d != %d." % (len(decoder_masks), decoder_size))
52
53
def run_step(sess, model, encoder_inputs, decoder_inputs, decoder_masks, bucket_id, forward_only):
    """ Run the model for one step on a single batch.

    @forward_only: when False, the training op is run and the step updates
    the model; when True, only the loss and the per-position outputs are
    fetched (used for evaluating on the test set and for chat mode).

    Returns a triple (gradient norm, loss, outputs):
    - training step: (gradient norm, loss, None)
    - forward-only step: (None, loss, list of decoder_size outputs)
    Raises ValueError if the inputs do not match the bucket's sizes.
    """
    encoder_size, decoder_size = config.BUCKETS[bucket_id]
    _assert_lengths(encoder_size, decoder_size, encoder_inputs, decoder_inputs, decoder_masks)

    # feed dict: placeholder name -> value, for every position in the bucket.
    feed = {model.encoder_inputs[pos].name: encoder_inputs[pos]
            for pos in range(encoder_size)}
    for pos in range(decoder_size):
        feed[model.decoder_inputs[pos].name] = decoder_inputs[pos]
        feed[model.decoder_masks[pos].name] = decoder_masks[pos]

    # One decoder input placeholder past the bucket length is fed zeros.
    # NOTE(review): presumably needed because targets are the decoder inputs
    # shifted by one -- confirm against model.py.
    trailing_input = model.decoder_inputs[decoder_size].name
    feed[trailing_input] = np.zeros([model.batch_size], dtype=np.int32)

    if forward_only:
        # loss for this batch, followed by the output at every decoder position.
        fetches = [model.losses[bucket_id]]
        fetches.extend(model.outputs[bucket_id][pos] for pos in range(decoder_size))
        results = sess.run(fetches, feed)
        return None, results[0], results[1:]  # No gradient norm, loss, outputs.

    fetches = [model.train_ops[bucket_id],       # update op
               model.gradient_norms[bucket_id],  # gradient norm
               model.losses[bucket_id]]          # loss for this batch
    results = sess.run(fetches, feed)
    return results[1], results[2], None  # Gradient norm, loss, no outputs.
87
88
def _get_buckets():
    """ Load the test and train datasets, already grouped into buckets.

    Returns (test_buckets, data_buckets, train_buckets_scale), where
    train_buckets_scale[i] is the fraction of all training samples held by
    buckets 0..i. It is the interval table used later to choose a random
    bucket.
    """
    test_buckets = data.load_data('test_ids.enc', 'test_ids.dec')
    data_buckets = data.load_data('train_ids.enc', 'train_ids.dec')

    train_bucket_sizes = []
    for bucket_id in range(len(config.BUCKETS)):
        train_bucket_sizes.append(len(data_buckets[bucket_id]))
    print("Number of samples in each bucket:\n", train_bucket_sizes)

    train_total_size = sum(train_bucket_sizes)
    # cumulative share of the samples, bucket by bucket.
    train_buckets_scale = []
    seen_so_far = 0
    for size in train_bucket_sizes:
        seen_so_far += size
        train_buckets_scale.append(seen_so_far / train_total_size)
    print("Bucket scale:\n", train_buckets_scale)
    return test_buckets, data_buckets, train_buckets_scale
103
104
def _get_skip_step(iteration):
105
""" How many steps should the model train before it saves all the weights. """
106
if iteration < 100:
107
return 30
108
return 100
109
110
def _check_restore_parameters(sess, saver):
    """ Load previously saved parameters into sess when a checkpoint exists.

    Looks up the checkpoint state for the directory derived from
    config.CPT_PATH. If one is found, saver restores it into sess; otherwise
    only a message is printed and sess is left untouched.
    """
    checkpoint_dir = os.path.dirname(config.CPT_PATH + '/checkpoint')
    ckpt = tf.train.get_checkpoint_state(checkpoint_dir)
    if not (ckpt and ckpt.model_checkpoint_path):
        print("Initializing fresh parameters for the Chatbot")
        return
    print("Loading parameters for the Chatbot")
    saver.restore(sess, ckpt.model_checkpoint_path)
118
119
def _eval_test_set(sess, model, test_buckets):
    """ Run one forward-only batch per test bucket and print its loss.

    Buckets with no samples are reported as empty and skipped. For every
    other bucket, the loss and the elapsed wall-clock time are printed.
    """
    for bucket_id, _ in enumerate(config.BUCKETS):
        samples = test_buckets[bucket_id]
        if len(samples) == 0:
            print(" Test: empty bucket %d" % (bucket_id))
            continue
        began = time.time()
        batch = data.get_batch(samples, bucket_id, batch_size=config.BATCH_SIZE)
        encoder_inputs, decoder_inputs, decoder_masks = batch
        # forward_only=True: evaluate without running the training op.
        _, step_loss, _ = run_step(sess, model, encoder_inputs, decoder_inputs,
                                   decoder_masks, bucket_id, True)
        elapsed = time.time() - began
        print('Test bucket {}: loss {}, time {}'.format(bucket_id, step_loss, elapsed))
132
133
def train():
    """ Train the bot.

    Builds the model with the backward path, restores a checkpoint if one
    exists, then loops forever: pick a random bucket, run one training step,
    and every skip_step iterations print the average loss, save a checkpoint
    and (every 10 * skip_step iterations) evaluate on the test set.
    """
    test_buckets, data_buckets, train_buckets_scale = _get_buckets()
    # in train mode, we need to create the backward path, so forward_only is False
    model = ChatBotModel(False, config.BATCH_SIZE)
    model.build_graph()

    saver = tf.train.Saver()
    checkpoint_prefix = os.path.join(config.CPT_PATH, 'chatbot')

    with tf.Session() as sess:
        print('Running session')
        sess.run(tf.global_variables_initializer())
        _check_restore_parameters(sess, saver)

        iteration = model.global_step.eval()
        total_loss = 0
        while True:
            skip_step = _get_skip_step(iteration)
            bucket_id = _get_random_bucket(train_buckets_scale)
            batch = data.get_batch(data_buckets[bucket_id],
                                   bucket_id,
                                   batch_size=config.BATCH_SIZE)
            encoder_inputs, decoder_inputs, decoder_masks = batch

            # The timer is re-armed before every step, so the time printed
            # below covers only the most recent step.
            step_started = time.time()
            _, step_loss, _ = run_step(sess, model, encoder_inputs, decoder_inputs,
                                       decoder_masks, bucket_id, False)
            total_loss += step_loss
            iteration += 1

            if iteration % skip_step != 0:
                continue

            # NOTE(review): assumes checkpointing, evaluation and the flush
            # all happen only on reporting iterations -- confirm.
            print('Iter {}: loss {}, time {}'.format(iteration, total_loss/skip_step,
                                                     time.time() - step_started))
            total_loss = 0
            saver.save(sess, checkpoint_prefix, global_step=model.global_step)
            if iteration % (10 * skip_step) == 0:
                # Run evals on development set and print their loss
                _eval_test_set(sess, model, test_buckets)
            sys.stdout.flush()
170
171
def _get_user_input():
    """ Show the "> " prompt and return one raw line read from stdin.

    The returned line still carries its trailing newline, if it had one;
    an empty string is returned at end of input.
    """
    sys.stdout.write("> ")
    sys.stdout.flush()  # make sure the prompt appears before blocking on stdin
    return sys.stdin.readline()
176
177
def _find_right_bucket(length):
    """ Return the index of the first bucket able to hold an encoder input.

    A bucket fits when its first element (the encoder size) is at least
    length. Raises ValueError when no bucket in config.BUCKETS fits.
    """
    return min(bucket_id
               for bucket_id, bucket in enumerate(config.BUCKETS)
               if bucket[0] >= length)
181
182
def _construct_response(output_logits, inv_dec_vocab):
    """ Turn the decoder outputs into the bot's reply string.

    @output_logits: the outputs from sequence to sequence wrapper.
    output_logits is decoder_size np array, each of dim 1 x DEC_VOCAB
    @inv_dec_vocab: lookup from token id to the token's text.

    Decoding is greedy: each position contributes the argmax of its logits.
    Everything from the first EOS token on is dropped, and the remaining
    tokens are joined with single spaces.
    """
    # NOTE(review): looks like leftover debug output -- confirm it is wanted.
    print(output_logits[0])

    token_ids = []
    for logit in output_logits:
        token_ids.append(int(np.argmax(logit, axis=1)))

    # If there is an EOS symbol in the outputs, cut them at that point.
    if config.EOS_ID in token_ids:
        end = token_ids.index(config.EOS_ID)
    else:
        end = len(token_ids)

    words = [tf.compat.as_str(inv_dec_vocab[token_id]) for token_id in token_ids[:end]]
    return " ".join(words)
196
197
def chat():
    """ Run the bot in interactive chat mode.

    Builds the model without the backward path (forward_only=True, batch
    size 1), restores a checkpoint if one exists, then reads lines from
    stdin until an empty line or end of input. Each exchange is printed and
    appended to config.OUTPUT_FILE under config.PROCESSED_PATH.

    Inputs longer than the largest bucket's encoder size are rejected with a
    message and the user is prompted again.
    """
    _, enc_vocab = data.load_vocab(os.path.join(config.PROCESSED_PATH, 'vocab.enc'))
    inv_dec_vocab, _ = data.load_vocab(os.path.join(config.PROCESSED_PATH, 'vocab.dec'))

    # in chat mode, we don't need to create the backward path
    model = ChatBotModel(True, batch_size=1)
    model.build_graph()

    saver = tf.train.Saver()

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        _check_restore_parameters(sess, saver)
        log_path = os.path.join(config.PROCESSED_PATH, config.OUTPUT_FILE)
        # The context manager closes the log even if a step below raises.
        with open(log_path, 'a+') as output_file:
            # Decode from standard input.
            max_length = config.BUCKETS[-1][0]
            print('Welcome to TensorBro. Say something. Enter to exit. Max length is', max_length)
            while True:
                line = _get_user_input()
                if line.endswith('\n'):
                    line = line[:-1]
                if line == '':
                    break
                output_file.write('HUMAN ++++ ' + line + '\n')
                # Get token-ids for the input sentence.
                token_ids = data.sentence2id(enc_vocab, str(line))
                if len(token_ids) > max_length:
                    print('Max length I can handle is:', max_length)
                    # Go straight back to the prompt at the top of the loop;
                    # reading here as well would throw the next input away.
                    continue
                # Which bucket does it belong to?
                bucket_id = _find_right_bucket(len(token_ids))
                # Get a 1-element batch to feed the sentence to the model.
                encoder_inputs, decoder_inputs, decoder_masks = data.get_batch([(token_ids, [])],
                                                                               bucket_id,
                                                                               batch_size=1)
                # Get output logits for the sentence.
                _, _, output_logits = run_step(sess, model, encoder_inputs, decoder_inputs,
                                               decoder_masks, bucket_id, True)
                response = _construct_response(output_logits, inv_dec_vocab)
                print(response)
                output_file.write('BOT ++++ ' + response + '\n')
            output_file.write('=============================================\n')
242
243
def main():
    """ Command-line entry point.

    Parses --mode ('train' or 'chat', default 'train'), prepares the
    processed data if config.PROCESSED_PATH is not yet a directory, makes
    sure the checkpoint folder exists, and then runs the selected mode.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--mode', choices={'train', 'chat'},
                        default='train', help="mode. if not specified, it's in the train mode")
    mode = parser.parse_args().mode

    data_is_processed = os.path.isdir(config.PROCESSED_PATH)
    if not data_is_processed:
        data.prepare_raw_data()
        data.process_data()
    print('Data ready!')
    # create checkpoints folder if there isn't one already
    data.make_dir(config.CPT_PATH)

    runner = {'train': train, 'chat': chat}.get(mode)
    if runner is not None:
        runner()
260
261
# Run the command-line entry point only when this file is executed as a script.
if __name__ == '__main__':
    main()
263
264