📚 The CoCalc Library - books, templates and other resources
License: OTHER
""" A neural chatbot using sequence to sequence model with1attentional decoder.23This is based on Google Translate Tensorflow model4https://github.com/tensorflow/models/blob/master/tutorials/rnn/translate/56Sequence to sequence model by Cho et al.(2014)78Created by Chip Huyen ([email protected])9CS20: "TensorFlow for Deep Learning Research"10cs20.stanford.edu1112This file contains the code to run the model.1314See README.md for instruction on how to run the starter code.15"""16import argparse17import os18os.environ['TF_CPP_MIN_LOG_LEVEL']='2'19import random20import sys21import time2223import numpy as np24import tensorflow as tf2526from model import ChatBotModel27import config28import data2930def _get_random_bucket(train_buckets_scale):31""" Get a random bucket from which to choose a training sample """32rand = random.random()33return min([i for i in range(len(train_buckets_scale))34if train_buckets_scale[i] > rand])3536def _assert_lengths(encoder_size, decoder_size, encoder_inputs, decoder_inputs, decoder_masks):37""" Assert that the encoder inputs, decoder inputs, and decoder masks are38of the expected lengths """39if len(encoder_inputs) != encoder_size:40raise ValueError("Encoder length must be equal to the one in bucket,"41" %d != %d." % (len(encoder_inputs), encoder_size))42if len(decoder_inputs) != decoder_size:43raise ValueError("Decoder length must be equal to the one in bucket,"44" %d != %d." % (len(decoder_inputs), decoder_size))45if len(decoder_masks) != decoder_size:46raise ValueError("Weights length must be equal to the one in bucket,"47" %d != %d." % (len(decoder_masks), decoder_size))4849def run_step(sess, model, encoder_inputs, decoder_inputs, decoder_masks, bucket_id, forward_only):50""" Run one step in training.51@forward_only: boolean value to decide whether a backward path should be created52forward_only is set to True when you just want to evaluate on the test set,53or when you want to the bot to be in chat mode. 
"""54encoder_size, decoder_size = config.BUCKETS[bucket_id]55_assert_lengths(encoder_size, decoder_size, encoder_inputs, decoder_inputs, decoder_masks)5657# input feed: encoder inputs, decoder inputs, target_weights, as provided.58input_feed = {}59for step in range(encoder_size):60input_feed[model.encoder_inputs[step].name] = encoder_inputs[step]61for step in range(decoder_size):62input_feed[model.decoder_inputs[step].name] = decoder_inputs[step]63input_feed[model.decoder_masks[step].name] = decoder_masks[step]6465last_target = model.decoder_inputs[decoder_size].name66input_feed[last_target] = np.zeros([model.batch_size], dtype=np.int32)6768# output feed: depends on whether we do a backward step or not.69if not forward_only:70output_feed = [model.train_ops[bucket_id], # update op that does SGD.71model.gradient_norms[bucket_id], # gradient norm.72model.losses[bucket_id]] # loss for this batch.73else:74output_feed = [model.losses[bucket_id]] # loss for this batch.75for step in range(decoder_size): # output logits.76output_feed.append(model.outputs[bucket_id][step])7778outputs = sess.run(output_feed, input_feed)79if not forward_only:80return outputs[1], outputs[2], None # Gradient norm, loss, no outputs.81else:82return None, outputs[0], outputs[1:] # No gradient norm, loss, outputs.8384def _get_buckets():85""" Load the dataset into buckets based on their lengths.86train_buckets_scale is the inverval that'll help us87choose a random bucket later on.88"""89test_buckets = data.load_data('test_ids.enc', 'test_ids.dec')90data_buckets = data.load_data('train_ids.enc', 'train_ids.dec')91train_bucket_sizes = [len(data_buckets[b]) for b in range(len(config.BUCKETS))]92print("Number of samples in each bucket:\n", train_bucket_sizes)93train_total_size = sum(train_bucket_sizes)94# list of increasing numbers from 0 to 1 that we'll use to select a bucket.95train_buckets_scale = [sum(train_bucket_sizes[:i + 1]) / train_total_size96for i in range(len(train_bucket_sizes))]97print("Bucket scale:\n", train_buckets_scale)98return test_buckets, data_buckets, train_buckets_scale99100def _get_skip_step(iteration):101""" How many steps should the model train before it saves all the weights. """102if iteration < 100:103return 30104return 100105106def _check_restore_parameters(sess, saver):107""" Restore the previously trained parameters if there are any. """108ckpt = tf.train.get_checkpoint_state(os.path.dirname(config.CPT_PATH + '/checkpoint'))109if ckpt and ckpt.model_checkpoint_path:110print("Loading parameters for the Chatbot")111saver.restore(sess, ckpt.model_checkpoint_path)112else:113print("Initializing fresh parameters for the Chatbot")114115def _eval_test_set(sess, model, test_buckets):116""" Evaluate on the test set. 
"""117for bucket_id in range(len(config.BUCKETS)):118if len(test_buckets[bucket_id]) == 0:119print(" Test: empty bucket %d" % (bucket_id))120continue121start = time.time()122encoder_inputs, decoder_inputs, decoder_masks = data.get_batch(test_buckets[bucket_id],123bucket_id,124batch_size=config.BATCH_SIZE)125_, step_loss, _ = run_step(sess, model, encoder_inputs, decoder_inputs,126decoder_masks, bucket_id, True)127print('Test bucket {}: loss {}, time {}'.format(bucket_id, step_loss, time.time() - start))128129def train():130""" Train the bot """131test_buckets, data_buckets, train_buckets_scale = _get_buckets()132# in train mode, we need to create the backward path, so forwrad_only is False133model = ChatBotModel(False, config.BATCH_SIZE)134model.build_graph()135136saver = tf.train.Saver()137138with tf.Session() as sess:139print('Running session')140sess.run(tf.global_variables_initializer())141_check_restore_parameters(sess, saver)142143iteration = model.global_step.eval()144total_loss = 0145while True:146skip_step = _get_skip_step(iteration)147bucket_id = _get_random_bucket(train_buckets_scale)148encoder_inputs, decoder_inputs, decoder_masks = data.get_batch(data_buckets[bucket_id],149bucket_id,150batch_size=config.BATCH_SIZE)151start = time.time()152_, step_loss, _ = run_step(sess, model, encoder_inputs, decoder_inputs, decoder_masks, bucket_id, False)153total_loss += step_loss154iteration += 1155156if iteration % skip_step == 0:157print('Iter {}: loss {}, time {}'.format(iteration, total_loss/skip_step, time.time() - start))158start = time.time()159total_loss = 0160saver.save(sess, os.path.join(config.CPT_PATH, 'chatbot'), global_step=model.global_step)161if iteration % (10 * skip_step) == 0:162# Run evals on development set and print their loss163_eval_test_set(sess, model, test_buckets)164start = time.time()165sys.stdout.flush()166167def _get_user_input():168""" Get user's input, which will be transformed into encoder input later """169print("> ", end="")170sys.stdout.flush()171return sys.stdin.readline()172173def _find_right_bucket(length):174""" Find the proper bucket for an encoder input based on its length """175return min([b for b in range(len(config.BUCKETS))176if config.BUCKETS[b][0] >= length])177178def _construct_response(output_logits, inv_dec_vocab):179""" Construct a response to the user's encoder input.180@output_logits: the outputs from sequence to sequence wrapper.181output_logits is decoder_size np array, each of dim 1 x DEC_VOCAB182183This is a greedy decoder - outputs are just argmaxes of output_logits.184"""185print(output_logits[0])186outputs = [int(np.argmax(logit, axis=1)) for logit in output_logits]187# If there is an EOS symbol in outputs, cut them at that point.188if config.EOS_ID in outputs:189outputs = outputs[:outputs.index(config.EOS_ID)]190# Print out sentence corresponding to outputs.191return " ".join([tf.compat.as_str(inv_dec_vocab[output]) for output in outputs])192193def chat():194""" in test mode, we don't to create the backward path195"""196_, enc_vocab = data.load_vocab(os.path.join(config.PROCESSED_PATH, 'vocab.enc'))197inv_dec_vocab, _ = data.load_vocab(os.path.join(config.PROCESSED_PATH, 'vocab.dec'))198199model = ChatBotModel(True, batch_size=1)200model.build_graph()201202saver = tf.train.Saver()203204with tf.Session() as sess:205sess.run(tf.global_variables_initializer())206_check_restore_parameters(sess, saver)207output_file = open(os.path.join(config.PROCESSED_PATH, config.OUTPUT_FILE), 'a+')208# Decode from standard input.209max_length = 
def chat():
    """ In chat (test) mode, we don't need to create the backward path.
    """
    _, enc_vocab = data.load_vocab(os.path.join(config.PROCESSED_PATH, 'vocab.enc'))
    inv_dec_vocab, _ = data.load_vocab(os.path.join(config.PROCESSED_PATH, 'vocab.dec'))

    model = ChatBotModel(True, batch_size=1)
    model.build_graph()

    saver = tf.train.Saver()

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        _check_restore_parameters(sess, saver)
        output_file = open(os.path.join(config.PROCESSED_PATH, config.OUTPUT_FILE), 'a+')
        # Decode from standard input.
        max_length = config.BUCKETS[-1][0]
        print('Welcome to TensorBro. Say something. Enter to exit. Max length is', max_length)
        while True:
            line = _get_user_input()
            if len(line) > 0 and line[-1] == '\n':
                line = line[:-1]
            if line == '':
                break
            output_file.write('HUMAN ++++ ' + line + '\n')
            # Get token-ids for the input sentence.
            token_ids = data.sentence2id(enc_vocab, str(line))
            if len(token_ids) > max_length:
                print('Max length I can handle is:', max_length)
                continue
            # Which bucket does it belong to?
            bucket_id = _find_right_bucket(len(token_ids))
            # Get a 1-element batch to feed the sentence to the model.
            encoder_inputs, decoder_inputs, decoder_masks = data.get_batch([(token_ids, [])],
                                                                           bucket_id,
                                                                           batch_size=1)
            # Get output logits for the sentence.
            _, _, output_logits = run_step(sess, model, encoder_inputs, decoder_inputs,
                                           decoder_masks, bucket_id, True)
            response = _construct_response(output_logits, inv_dec_vocab)
            print(response)
            output_file.write('BOT ++++ ' + response + '\n')
        output_file.write('=============================================\n')
        output_file.close()

def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--mode', choices={'train', 'chat'},
                        default='train', help="mode; defaults to train if not specified")
    args = parser.parse_args()

    if not os.path.isdir(config.PROCESSED_PATH):
        data.prepare_raw_data()
        data.process_data()
    print('Data ready!')
    # create the checkpoints folder if there isn't one already
    data.make_dir(config.CPT_PATH)

    if args.mode == 'train':
        train()
    elif args.mode == 'chat':
        chat()

if __name__ == '__main__':
    main()
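# Usage (mirrors the --mode flag defined in main(); assuming this file is saved as
# chatbot.py - see README.md for the full setup instructions):
#   python chatbot.py --mode train   # train the bot, checkpointing to config.CPT_PATH
#   python chatbot.py --mode chat    # restore the latest checkpoint and chat on stdin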