""" A neural chatbot using sequence to sequence model with
attentional decoder.

This is based on Google Translate Tensorflow model
https://github.com/tensorflow/models/blob/master/tutorials/rnn/translate/

Sequence to sequence model by Cho et al. (2014)

Created by Chip Huyen ([email protected])
CS20: "TensorFlow for Deep Learning Research"
cs20.stanford.edu

This file contains the code to run the model.

See README.md for instructions on how to run the starter code.
"""
import argparse
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'  # silence TensorFlow info/warning logs
import random
import sys
import time

import numpy as np
import tensorflow as tf

from model import ChatBotModel
import config
import data

def _get_random_bucket(train_buckets_scale):
    """ Get a random bucket from which to choose a training sample """
    rand = random.random()
    return min([i for i in range(len(train_buckets_scale))
                if train_buckets_scale[i] > rand])

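# Example with a hypothetical cumulative scale [0.3, 0.7, 1.0]: a uniform
# draw of 0.5 selects bucket 1, the smallest index whose cumulative share
# exceeds the draw, so buckets are sampled in proportion to their size:
# >>> min([i for i in range(3) if [0.3, 0.7, 1.0][i] > 0.5])
# 1
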
def _assert_lengths(encoder_size, decoder_size, encoder_inputs, decoder_inputs, decoder_masks):
    """ Assert that the encoder inputs, decoder inputs, and decoder masks are
    of the expected lengths """
    if len(encoder_inputs) != encoder_size:
        raise ValueError("Encoder length must be equal to the one in bucket,"
                         " %d != %d." % (len(encoder_inputs), encoder_size))
    if len(decoder_inputs) != decoder_size:
        raise ValueError("Decoder length must be equal to the one in bucket,"
                         " %d != %d." % (len(decoder_inputs), decoder_size))
    if len(decoder_masks) != decoder_size:
        raise ValueError("Weights length must be equal to the one in bucket,"
                         " %d != %d." % (len(decoder_masks), decoder_size))

def run_step(sess, model, encoder_inputs, decoder_inputs, decoder_masks, bucket_id, forward_only):
    """ Run one step in training.
    @forward_only: boolean value to decide whether a backward path should be created
    forward_only is set to True when you just want to evaluate on the test set,
    or when you want the bot to be in chat mode. """
    encoder_size, decoder_size = config.BUCKETS[bucket_id]
    _assert_lengths(encoder_size, decoder_size, encoder_inputs, decoder_inputs, decoder_masks)

    # input feed: encoder inputs, decoder inputs, target_weights, as provided.
    input_feed = {}
    for step in range(encoder_size):
        input_feed[model.encoder_inputs[step].name] = encoder_inputs[step]
    for step in range(decoder_size):
        input_feed[model.decoder_inputs[step].name] = decoder_inputs[step]
        input_feed[model.decoder_masks[step].name] = decoder_masks[step]

    # targets are decoder inputs shifted by one, so feed one extra all-zero slot.
    last_target = model.decoder_inputs[decoder_size].name
    input_feed[last_target] = np.zeros([model.batch_size], dtype=np.int32)

    # output feed: depends on whether we do a backward step or not.
    if not forward_only:
        output_feed = [model.train_ops[bucket_id],       # update op that does SGD.
                       model.gradient_norms[bucket_id],  # gradient norm.
                       model.losses[bucket_id]]          # loss for this batch.
    else:
        output_feed = [model.losses[bucket_id]]          # loss for this batch.
        for step in range(decoder_size):                 # output logits.
            output_feed.append(model.outputs[bucket_id][step])

    outputs = sess.run(output_feed, input_feed)
    if not forward_only:
        return outputs[1], outputs[2], None   # Gradient norm, loss, no outputs.
    else:
        return None, outputs[0], outputs[1:]  # No gradient norm, loss, outputs.

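# Note on shapes: the bucketed seq2seq model is time-major, so encoder_inputs,
# decoder_inputs, and decoder_masks are Python lists of per-timestep arrays,
# each of shape [batch_size]. With a hypothetical bucket (8, 10) and a batch
# size of 64, run_step feeds 8 encoder arrays and 10 decoder/mask arrays,
# each holding 64 token ids (or mask weights).
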
def _get_buckets():
    """ Load the dataset into buckets based on their lengths.
    train_buckets_scale is the interval that'll help us
    choose a random bucket later on.
    """
    test_buckets = data.load_data('test_ids.enc', 'test_ids.dec')
    data_buckets = data.load_data('train_ids.enc', 'train_ids.dec')
    train_bucket_sizes = [len(data_buckets[b]) for b in range(len(config.BUCKETS))]
    print("Number of samples in each bucket:\n", train_bucket_sizes)
    train_total_size = sum(train_bucket_sizes)
    # list of increasing numbers from 0 to 1 that we'll use to select a bucket.
    train_buckets_scale = [sum(train_bucket_sizes[:i + 1]) / train_total_size
                           for i in range(len(train_bucket_sizes))]
    print("Bucket scale:\n", train_buckets_scale)
    return test_buckets, data_buckets, train_buckets_scale

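# Worked example with hypothetical bucket sizes: if train_bucket_sizes is
# [10000, 30000, 60000], then train_total_size is 100000 and
# train_buckets_scale comes out as [0.1, 0.4, 1.0], the cumulative fraction
# of training samples up to and including each bucket.
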
def _get_skip_step(iteration):
    """ How many steps should the model train before it saves all the weights. """
    if iteration < 100:
        return 30
    return 100

def _check_restore_parameters(sess, saver):
    """ Restore the previously trained parameters if there are any. """
    ckpt = tf.train.get_checkpoint_state(os.path.dirname(config.CPT_PATH + '/checkpoint'))
    if ckpt and ckpt.model_checkpoint_path:
        print("Loading parameters for the Chatbot")
        saver.restore(sess, ckpt.model_checkpoint_path)
    else:
        print("Initializing fresh parameters for the Chatbot")

def _eval_test_set(sess, model, test_buckets):
    """ Evaluate on the test set. """
    for bucket_id in range(len(config.BUCKETS)):
        if len(test_buckets[bucket_id]) == 0:
            print("  Test: empty bucket %d" % (bucket_id))
            continue
        start = time.time()
        encoder_inputs, decoder_inputs, decoder_masks = data.get_batch(test_buckets[bucket_id],
                                                                       bucket_id,
                                                                       batch_size=config.BATCH_SIZE)
        _, step_loss, _ = run_step(sess, model, encoder_inputs, decoder_inputs,
                                   decoder_masks, bucket_id, True)
        print('Test bucket {}: loss {}, time {}'.format(bucket_id, step_loss, time.time() - start))

def train():
    """ Train the bot """
    test_buckets, data_buckets, train_buckets_scale = _get_buckets()
    # in train mode, we need to create the backward path, so forward_only is False
    model = ChatBotModel(False, config.BATCH_SIZE)
    model.build_graph()

    saver = tf.train.Saver()

    with tf.Session() as sess:
        print('Running session')
        sess.run(tf.global_variables_initializer())
        _check_restore_parameters(sess, saver)

        iteration = model.global_step.eval()
        total_loss = 0
        while True:
            skip_step = _get_skip_step(iteration)
            bucket_id = _get_random_bucket(train_buckets_scale)
            encoder_inputs, decoder_inputs, decoder_masks = data.get_batch(data_buckets[bucket_id],
                                                                           bucket_id,
                                                                           batch_size=config.BATCH_SIZE)
            start = time.time()
            _, step_loss, _ = run_step(sess, model, encoder_inputs, decoder_inputs, decoder_masks, bucket_id, False)
            total_loss += step_loss
            iteration += 1

            if iteration % skip_step == 0:
                print('Iter {}: loss {}, time {}'.format(iteration, total_loss / skip_step, time.time() - start))
                start = time.time()
                total_loss = 0
                saver.save(sess, os.path.join(config.CPT_PATH, 'chatbot'), global_step=model.global_step)
                if iteration % (10 * skip_step) == 0:
                    # Run evals on the development set and print their loss.
                    _eval_test_set(sess, model, test_buckets)
                    start = time.time()
                sys.stdout.flush()

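# Reporting cadence: _get_skip_step returns 30 for the first 100 iterations
# and 100 afterwards, so the loop prints the averaged loss and saves a
# checkpoint every 30 steps early on, then every 100 steps, and runs a
# test-set evaluation every 10 * skip_step steps.
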
def _get_user_input():
    """ Get user's input, which will be transformed into encoder input later """
    print("> ", end="")
    sys.stdout.flush()
    return sys.stdin.readline()

def _find_right_bucket(length):
    """ Find the proper bucket for an encoder input based on its length """
    return min([b for b in range(len(config.BUCKETS))
                if config.BUCKETS[b][0] >= length])

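# Example with hypothetical buckets: if config.BUCKETS were
# [(8, 10), (12, 14), (16, 19)], an input of length 9 maps to bucket 1,
# the smallest bucket whose encoder size (12) can hold it:
# >>> min([b for b in range(3) if [8, 12, 16][b] >= 9])
# 1
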
def _construct_response(output_logits, inv_dec_vocab):
    """ Construct a response to the user's encoder input.
    @output_logits: the outputs from sequence to sequence wrapper.
    output_logits is decoder_size np array, each of dim 1 x DEC_VOCAB

    This is a greedy decoder - outputs are just argmaxes of output_logits.
    """
    print(output_logits[0])  # debug print of the first step's logits
    outputs = [int(np.argmax(logit, axis=1)) for logit in output_logits]
    # If there is an EOS symbol in outputs, cut them at that point.
    if config.EOS_ID in outputs:
        outputs = outputs[:outputs.index(config.EOS_ID)]
    # Print out sentence corresponding to outputs.
    return " ".join([tf.compat.as_str(inv_dec_vocab[output]) for output in outputs])

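# Greedy step on toy logits (hypothetical vocabulary of size 4): the argmax
# over the vocabulary axis picks the single most likely token per timestep:
# >>> logit = np.array([[0.1, 2.5, 0.3, 0.7]])
# >>> int(np.argmax(logit, axis=1))
# 1
# Token id 1 is then mapped back to a word through inv_dec_vocab.
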
def chat():
    """ In test mode, we don't need to create the backward path.
    """
    _, enc_vocab = data.load_vocab(os.path.join(config.PROCESSED_PATH, 'vocab.enc'))
    inv_dec_vocab, _ = data.load_vocab(os.path.join(config.PROCESSED_PATH, 'vocab.dec'))

    model = ChatBotModel(True, batch_size=1)
    model.build_graph()

    saver = tf.train.Saver()

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        _check_restore_parameters(sess, saver)
        output_file = open(os.path.join(config.PROCESSED_PATH, config.OUTPUT_FILE), 'a+')
        # Decode from standard input.
        max_length = config.BUCKETS[-1][0]
        print('Welcome to TensorBro. Say something. Enter to exit. Max length is', max_length)
        while True:
            line = _get_user_input()
            if len(line) > 0 and line[-1] == '\n':
                line = line[:-1]
            if line == '':
                break
            output_file.write('HUMAN ++++ ' + line + '\n')
            # Get token-ids for the input sentence.
            token_ids = data.sentence2id(enc_vocab, str(line))
            if len(token_ids) > max_length:
                print('Max length I can handle is:', max_length)
                continue  # the top of the loop prompts for new input
            # Which bucket does it belong to?
            bucket_id = _find_right_bucket(len(token_ids))
            # Get a 1-element batch to feed the sentence to the model.
            encoder_inputs, decoder_inputs, decoder_masks = data.get_batch([(token_ids, [])],
                                                                           bucket_id,
                                                                           batch_size=1)
            # Get output logits for the sentence.
            _, _, output_logits = run_step(sess, model, encoder_inputs, decoder_inputs,
                                           decoder_masks, bucket_id, True)
            response = _construct_response(output_logits, inv_dec_vocab)
            print(response)
            output_file.write('BOT ++++ ' + response + '\n')
        output_file.write('=============================================\n')
        output_file.close()

def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--mode', choices={'train', 'chat'},
                        default='train', help="mode; defaults to train")
    args = parser.parse_args()

    if not os.path.isdir(config.PROCESSED_PATH):
        data.prepare_raw_data()
        data.process_data()
        print('Data ready!')
    # create the checkpoints folder if there isn't one already
    data.make_dir(config.CPT_PATH)

    if args.mode == 'train':
        train()
    elif args.mode == 'chat':
        chat()

if __name__ == '__main__':
    main()
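
# Usage (assuming this file is saved as chatbot.py, as in the CS20 starter code):
#   python chatbot.py --mode train   # preprocess the data if needed, then train
#   python chatbot.py --mode chat    # restore the latest checkpoint and chat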