📚 The CoCalc Library - books, templates and other resources
License: OTHER
""" A neural chatbot using sequence to sequence model with1attentional decoder.23This is based on Google Translate Tensorflow model4https://github.com/tensorflow/models/blob/master/tutorials/rnn/translate/56Sequence to sequence model by Cho et al.(2014)78Created by Chip Huyen as the starter code for assignment 3,9class CS 20SI: "TensorFlow for Deep Learning Research"10cs20si.stanford.edu1112This file contains the code to run the model.1314See readme.md for instruction on how to run the starter code.15"""16from __future__ import division17from __future__ import print_function1819import argparse20import os21os.environ['TF_CPP_MIN_LOG_LEVEL']='2'22import random23import sys24import time2526import numpy as np27import tensorflow as tf2829from model import ChatBotModel30import config31import data3233def _get_random_bucket(train_buckets_scale):34""" Get a random bucket from which to choose a training sample """35rand = random.random()36return min([i for i in range(len(train_buckets_scale))37if train_buckets_scale[i] > rand])3839def _assert_lengths(encoder_size, decoder_size, encoder_inputs, decoder_inputs, decoder_masks):40""" Assert that the encoder inputs, decoder inputs, and decoder masks are41of the expected lengths """42if len(encoder_inputs) != encoder_size:43raise ValueError("Encoder length must be equal to the one in bucket,"44" %d != %d." % (len(encoder_inputs), encoder_size))45if len(decoder_inputs) != decoder_size:46raise ValueError("Decoder length must be equal to the one in bucket,"47" %d != %d." % (len(decoder_inputs), decoder_size))48if len(decoder_masks) != decoder_size:49raise ValueError("Weights length must be equal to the one in bucket,"50" %d != %d." % (len(decoder_masks), decoder_size))5152def run_step(sess, model, encoder_inputs, decoder_inputs, decoder_masks, bucket_id, forward_only):53""" Run one step in training.54@forward_only: boolean value to decide whether a backward path should be created55forward_only is set to True when you just want to evaluate on the test set,56or when you want to the bot to be in chat mode. """57encoder_size, decoder_size = config.BUCKETS[bucket_id]58_assert_lengths(encoder_size, decoder_size, encoder_inputs, decoder_inputs, decoder_masks)5960# input feed: encoder inputs, decoder inputs, target_weights, as provided.61input_feed = {}62for step in range(encoder_size):63input_feed[model.encoder_inputs[step].name] = encoder_inputs[step]64for step in range(decoder_size):65input_feed[model.decoder_inputs[step].name] = decoder_inputs[step]66input_feed[model.decoder_masks[step].name] = decoder_masks[step]6768last_target = model.decoder_inputs[decoder_size].name69input_feed[last_target] = np.zeros([model.batch_size], dtype=np.int32)7071# output feed: depends on whether we do a backward step or not.72if not forward_only:73output_feed = [model.train_ops[bucket_id], # update op that does SGD.74model.gradient_norms[bucket_id], # gradient norm.75model.losses[bucket_id]] # loss for this batch.76else:77output_feed = [model.losses[bucket_id]] # loss for this batch.78for step in range(decoder_size): # output logits.79output_feed.append(model.outputs[bucket_id][step])8081outputs = sess.run(output_feed, input_feed)82if not forward_only:83return outputs[1], outputs[2], None # Gradient norm, loss, no outputs.84else:85return None, outputs[0], outputs[1:] # No gradient norm, loss, outputs.8687def _get_buckets():88""" Load the dataset into buckets based on their lengths.89train_buckets_scale is the inverval that'll help us90choose a random bucket later on.91"""92test_buckets = data.load_data('test_ids.enc', 'test_ids.dec')93data_buckets = data.load_data('train_ids.enc', 'train_ids.dec')94train_bucket_sizes = [len(data_buckets[b]) for b in range(len(config.BUCKETS))]95print("Number of samples in each bucket:\n", train_bucket_sizes)96train_total_size = sum(train_bucket_sizes)97# list of increasing numbers from 0 to 1 that we'll use to select a bucket.98train_buckets_scale = [sum(train_bucket_sizes[:i + 1]) / train_total_size99for i in range(len(train_bucket_sizes))]100print("Bucket scale:\n", train_buckets_scale)101return test_buckets, data_buckets, train_buckets_scale102103def _get_skip_step(iteration):104""" How many steps should the model train before it saves all the weights. """105if iteration < 100:106return 30107return 100108109def _check_restore_parameters(sess, saver):110""" Restore the previously trained parameters if there are any. """111ckpt = tf.train.get_checkpoint_state(os.path.dirname(config.CPT_PATH + '/checkpoint'))112if ckpt and ckpt.model_checkpoint_path:113print("Loading parameters for the Chatbot")114saver.restore(sess, ckpt.model_checkpoint_path)115else:116print("Initializing fresh parameters for the Chatbot")117118def _eval_test_set(sess, model, test_buckets):119""" Evaluate on the test set. """120for bucket_id in range(len(config.BUCKETS)):121if len(test_buckets[bucket_id]) == 0:122print(" Test: empty bucket %d" % (bucket_id))123continue124start = time.time()125encoder_inputs, decoder_inputs, decoder_masks = data.get_batch(test_buckets[bucket_id],126bucket_id,127batch_size=config.BATCH_SIZE)128_, step_loss, _ = run_step(sess, model, encoder_inputs, decoder_inputs,129decoder_masks, bucket_id, True)130print('Test bucket {}: loss {}, time {}'.format(bucket_id, step_loss, time.time() - start))131132def train():133""" Train the bot """134test_buckets, data_buckets, train_buckets_scale = _get_buckets()135# in train mode, we need to create the backward path, so forwrad_only is False136model = ChatBotModel(False, config.BATCH_SIZE)137model.build_graph()138139saver = tf.train.Saver()140141with tf.Session() as sess:142print('Running session')143sess.run(tf.global_variables_initializer())144_check_restore_parameters(sess, saver)145146iteration = model.global_step.eval()147total_loss = 0148while True:149skip_step = _get_skip_step(iteration)150bucket_id = _get_random_bucket(train_buckets_scale)151encoder_inputs, decoder_inputs, decoder_masks = data.get_batch(data_buckets[bucket_id],152bucket_id,153batch_size=config.BATCH_SIZE)154start = time.time()155_, step_loss, _ = run_step(sess, model, encoder_inputs, decoder_inputs, decoder_masks, bucket_id, False)156total_loss += step_loss157iteration += 1158159if iteration % skip_step == 0:160print('Iter {}: loss {}, time {}'.format(iteration, total_loss/skip_step, time.time() - start))161start = time.time()162total_loss = 0163saver.save(sess, os.path.join(config.CPT_PATH, 'chatbot'), global_step=model.global_step)164if iteration % (10 * skip_step) == 0:165# Run evals on development set and print their loss166_eval_test_set(sess, model, test_buckets)167start = time.time()168sys.stdout.flush()169170def _get_user_input():171""" Get user's input, which will be transformed into encoder input later """172print("> ", end="")173sys.stdout.flush()174return sys.stdin.readline()175176def _find_right_bucket(length):177""" Find the proper bucket for an encoder input based on its length """178return min([b for b in range(len(config.BUCKETS))179if config.BUCKETS[b][0] >= length])180181def _construct_response(output_logits, inv_dec_vocab):182""" Construct a response to the user's encoder input.183@output_logits: the outputs from sequence to sequence wrapper.184output_logits is decoder_size np array, each of dim 1 x DEC_VOCAB185186This is a greedy decoder - outputs are just argmaxes of output_logits.187"""188print(output_logits[0])189outputs = [int(np.argmax(logit, axis=1)) for logit in output_logits]190# If there is an EOS symbol in outputs, cut them at that point.191if config.EOS_ID in outputs:192outputs = outputs[:outputs.index(config.EOS_ID)]193# Print out sentence corresponding to outputs.194return " ".join([tf.compat.as_str(inv_dec_vocab[output]) for output in outputs])195196def chat():197""" in test mode, we don't to create the backward path198"""199_, enc_vocab = data.load_vocab(os.path.join(config.PROCESSED_PATH, 'vocab.enc'))200inv_dec_vocab, _ = data.load_vocab(os.path.join(config.PROCESSED_PATH, 'vocab.dec'))201202model = ChatBotModel(True, batch_size=1)203model.build_graph()204205saver = tf.train.Saver()206207with tf.Session() as sess:208sess.run(tf.global_variables_initializer())209_check_restore_parameters(sess, saver)210output_file = open(os.path.join(config.PROCESSED_PATH, config.OUTPUT_FILE), 'a+')211# Decode from standard input.212max_length = config.BUCKETS[-1][0]213print('Welcome to TensorBro. Say something. Enter to exit. Max length is', max_length)214while True:215line = _get_user_input()216if len(line) > 0 and line[-1] == '\n':217line = line[:-1]218if line == '':219break220output_file.write('HUMAN ++++ ' + line + '\n')221# Get token-ids for the input sentence.222token_ids = data.sentence2id(enc_vocab, str(line))223if (len(token_ids) > max_length):224print('Max length I can handle is:', max_length)225line = _get_user_input()226continue227# Which bucket does it belong to?228bucket_id = _find_right_bucket(len(token_ids))229# Get a 1-element batch to feed the sentence to the model.230encoder_inputs, decoder_inputs, decoder_masks = data.get_batch([(token_ids, [])],231bucket_id,232batch_size=1)233# Get output logits for the sentence.234_, _, output_logits = run_step(sess, model, encoder_inputs, decoder_inputs,235decoder_masks, bucket_id, True)236response = _construct_response(output_logits, inv_dec_vocab)237print(response)238output_file.write('BOT ++++ ' + response + '\n')239output_file.write('=============================================\n')240output_file.close()241242def main():243parser = argparse.ArgumentParser()244parser.add_argument('--mode', choices={'train', 'chat'},245default='train', help="mode. if not specified, it's in the train mode")246args = parser.parse_args()247248if not os.path.isdir(config.PROCESSED_PATH):249data.prepare_raw_data()250data.process_data()251print('Data ready!')252# create checkpoints folder if there isn't one already253data.make_dir(config.CPT_PATH)254255if args.mode == 'train':256train()257elif args.mode == 'chat':258chat()259260if __name__ == '__main__':261main()262263264