📚 The CoCalc Library - books, templates and other resources
License: OTHER
""" A neural chatbot using sequence to sequence model with1attentional decoder.23This is based on Google Translate Tensorflow model4https://github.com/tensorflow/models/blob/master/tutorials/rnn/translate/56Sequence to sequence model by Cho et al.(2014)78Created by Chip Huyen ([email protected])9CS20: "TensorFlow for Deep Learning Research"10cs20.stanford.edu1112This file contains the code to do the pre-processing for the13Cornell Movie-Dialogs Corpus.1415See readme.md for instruction on how to run the starter code.16"""17import os18import random19import re2021import numpy as np2223import config2425def get_lines():26id2line = {}27file_path = os.path.join(config.DATA_PATH, config.LINE_FILE)28print(config.LINE_FILE)29with open(file_path, 'r', errors='ignore') as f:30# lines = f.readlines()31# for line in lines:32i = 033try:34for line in f:35parts = line.split(' +++$+++ ')36if len(parts) == 5:37if parts[4][-1] == '\n':38parts[4] = parts[4][:-1]39id2line[parts[0]] = parts[4]40i += 141except UnicodeDecodeError:42print(i, line)43return id2line4445def get_convos():46""" Get conversations from the raw data """47file_path = os.path.join(config.DATA_PATH, config.CONVO_FILE)48convos = []49with open(file_path, 'r') as f:50for line in f.readlines():51parts = line.split(' +++$+++ ')52if len(parts) == 4:53convo = []54for line in parts[3][1:-2].split(', '):55convo.append(line[1:-1])56convos.append(convo)5758return convos5960def question_answers(id2line, convos):61""" Divide the dataset into two sets: questions and answers. """62questions, answers = [], []63for convo in convos:64for index, line in enumerate(convo[:-1]):65questions.append(id2line[convo[index]])66answers.append(id2line[convo[index + 1]])67assert len(questions) == len(answers)68return questions, answers6970def prepare_dataset(questions, answers):71# create path to store all the train & test encoder & decoder72make_dir(config.PROCESSED_PATH)7374# random convos to create the test set75test_ids = random.sample([i for i in range(len(questions))],config.TESTSET_SIZE)7677filenames = ['train.enc', 'train.dec', 'test.enc', 'test.dec']78files = []79for filename in filenames:80files.append(open(os.path.join(config.PROCESSED_PATH, filename),'w'))8182for i in range(len(questions)):83if i in test_ids:84files[2].write(questions[i] + '\n')85files[3].write(answers[i] + '\n')86else:87files[0].write(questions[i] + '\n')88files[1].write(answers[i] + '\n')8990for file in files:91file.close()9293def make_dir(path):94""" Create a directory if there isn't one already. """95try:96os.mkdir(path)97except OSError:98pass99100def basic_tokenizer(line, normalize_digits=True):101""" A basic tokenizer to tokenize text into tokens.102Feel free to change this to suit your need. 
"""103line = re.sub('<u>', '', line)104line = re.sub('</u>', '', line)105line = re.sub('\[', '', line)106line = re.sub('\]', '', line)107words = []108_WORD_SPLIT = re.compile("([.,!?\"'-<>:;)(])")109_DIGIT_RE = re.compile(r"\d")110for fragment in line.strip().lower().split():111for token in re.split(_WORD_SPLIT, fragment):112if not token:113continue114if normalize_digits:115token = re.sub(_DIGIT_RE, '#', token)116words.append(token)117return words118119def build_vocab(filename, normalize_digits=True):120in_path = os.path.join(config.PROCESSED_PATH, filename)121out_path = os.path.join(config.PROCESSED_PATH, 'vocab.{}'.format(filename[-3:]))122123vocab = {}124with open(in_path, 'r') as f:125for line in f.readlines():126for token in basic_tokenizer(line):127if not token in vocab:128vocab[token] = 0129vocab[token] += 1130131sorted_vocab = sorted(vocab, key=vocab.get, reverse=True)132with open(out_path, 'w') as f:133f.write('<pad>' + '\n')134f.write('<unk>' + '\n')135f.write('<s>' + '\n')136f.write('<\s>' + '\n')137index = 4138for word in sorted_vocab:139if vocab[word] < config.THRESHOLD:140break141f.write(word + '\n')142index += 1143with open('config.py', 'a') as cf:144if filename[-3:] == 'enc':145cf.write('ENC_VOCAB = ' + str(index) + '\n')146else:147cf.write('DEC_VOCAB = ' + str(index) + '\n')148149def load_vocab(vocab_path):150with open(vocab_path, 'r') as f:151words = f.read().splitlines()152return words, {words[i]: i for i in range(len(words))}153154def sentence2id(vocab, line):155return [vocab.get(token, vocab['<unk>']) for token in basic_tokenizer(line)]156157def token2id(data, mode):158""" Convert all the tokens in the data into their corresponding159index in the vocabulary. """160vocab_path = 'vocab.' + mode161in_path = data + '.' + mode162out_path = data + '_ids.' 
def token2id(data, mode):
    """ Convert all the tokens in the data into their corresponding
    index in the vocabulary. """
    vocab_path = 'vocab.' + mode
    in_path = data + '.' + mode
    out_path = data + '_ids.' + mode

    _, vocab = load_vocab(os.path.join(config.PROCESSED_PATH, vocab_path))
    in_file = open(os.path.join(config.PROCESSED_PATH, in_path), 'r')
    out_file = open(os.path.join(config.PROCESSED_PATH, out_path), 'w')

    lines = in_file.read().splitlines()
    for line in lines:
        if mode == 'dec':  # only decoder sequences are wrapped in '<s>' ... '</s>'
            ids = [vocab['<s>']]
        else:
            ids = []
        ids.extend(sentence2id(vocab, line))
        if mode == 'dec':
            ids.append(vocab['</s>'])
        out_file.write(' '.join(str(id_) for id_ in ids) + '\n')
    in_file.close()
    out_file.close()

def prepare_raw_data():
    print('Preparing raw data into train set and test set ...')
    id2line = get_lines()
    convos = get_convos()
    questions, answers = question_answers(id2line, convos)
    prepare_dataset(questions, answers)

def process_data():
    print('Preparing data to be model-ready ...')
    build_vocab('train.enc')
    build_vocab('train.dec')
    token2id('train', 'enc')
    token2id('train', 'dec')
    token2id('test', 'enc')
    token2id('test', 'dec')

def load_data(enc_filename, dec_filename, max_training_size=None):
    """ Group the (encoder, decoder) id sequences into the buckets defined in config.BUCKETS. """
    encode_file = open(os.path.join(config.PROCESSED_PATH, enc_filename), 'r')
    decode_file = open(os.path.join(config.PROCESSED_PATH, dec_filename), 'r')
    encode, decode = encode_file.readline(), decode_file.readline()
    data_buckets = [[] for _ in config.BUCKETS]
    i = 0
    while encode and decode:
        if (i + 1) % 10000 == 0:
            print("Bucketing conversation number", i)
        encode_ids = [int(id_) for id_ in encode.split()]
        decode_ids = [int(id_) for id_ in decode.split()]
        # put the pair into the first bucket that can hold both sequences
        for bucket_id, (encode_max_size, decode_max_size) in enumerate(config.BUCKETS):
            if len(encode_ids) <= encode_max_size and len(decode_ids) <= decode_max_size:
                data_buckets[bucket_id].append([encode_ids, decode_ids])
                break
        encode, decode = encode_file.readline(), decode_file.readline()
        i += 1
    encode_file.close()
    decode_file.close()
    return data_buckets

def _pad_input(input_, size):
    """ Pad a sequence of ids with config.PAD_ID up to the given size. """
    return input_ + [config.PAD_ID] * (size - len(input_))
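# Illustration (hedged: the concrete values live in config.py, which the rest
# of the starter code is assumed to define; config.BUCKETS is a list of
# (encoder_max_len, decoder_max_len) pairs, and '<pad>' is index 0 in the
# vocab files written by build_vocab, so PAD_ID is assumed to be 0 here):
#
#   if config.BUCKETS were [(8, 10), (12, 14), (16, 19)], a 7-id encoder
#   sequence with a 12-id decoder sequence would land in bucket 1 (the decoder
#   side is too long for bucket 0), and the encoder side would be padded to 12:
#
#     _pad_input([45, 3, 9, 7, 21, 6, 4], 12)
#       -> [45, 3, 9, 7, 21, 6, 4, 0, 0, 0, 0, 0]   # assuming PAD_ID == 0
#
#   get_batch() below also reverses the padded encoder sequence, following the
#   original TensorFlow translate model.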
def _reshape_batch(inputs, size, batch_size):
    """ Create batch-major inputs: the same ids re-indexed so that each entry
    of the returned list holds one time step across the whole batch.
    """
    batch_inputs = []
    for length_id in range(size):
        batch_inputs.append(np.array([inputs[batch_id][length_id]
                                      for batch_id in range(batch_size)], dtype=np.int32))
    return batch_inputs


def get_batch(data_bucket, bucket_id, batch_size=1):
    """ Return one batch to feed into the model. """
    # only pad to the max length of the bucket
    encoder_size, decoder_size = config.BUCKETS[bucket_id]
    encoder_inputs, decoder_inputs = [], []

    for _ in range(batch_size):
        encoder_input, decoder_input = random.choice(data_bucket)
        # pad both encoder and decoder, reverse the encoder
        encoder_inputs.append(list(reversed(_pad_input(encoder_input, encoder_size))))
        decoder_inputs.append(_pad_input(decoder_input, decoder_size))

    # now we create batch-major vectors from the data selected above
    batch_encoder_inputs = _reshape_batch(encoder_inputs, encoder_size, batch_size)
    batch_decoder_inputs = _reshape_batch(decoder_inputs, decoder_size, batch_size)

    # create decoder_masks that are 0 wherever the target is padding
    batch_masks = []
    for length_id in range(decoder_size):
        batch_mask = np.ones(batch_size, dtype=np.float32)
        for batch_id in range(batch_size):
            # we set the mask to 0 if the corresponding target is a PAD symbol;
            # the target for position length_id is decoder_input shifted forward by 1
            if length_id < decoder_size - 1:
                target = decoder_inputs[batch_id][length_id + 1]
            if length_id == decoder_size - 1 or target == config.PAD_ID:
                batch_mask[batch_id] = 0.0
        batch_masks.append(batch_mask)
    return batch_encoder_inputs, batch_decoder_inputs, batch_masks

if __name__ == '__main__':
    prepare_raw_data()
    process_data()
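# Usage sketch (hedged: assumes this module is saved as data.py next to the
# config.py and readme.md from the starter code, with the raw corpus files in
# config.DATA_PATH):
#
#   $ python data.py        # writes train/test splits, vocab files and *_ids files
#
# and later, from training code:
#
#   data_buckets = load_data('train_ids.enc', 'train_ids.dec')
#   enc_batch, dec_batch, masks = get_batch(data_buckets[0], bucket_id=0, batch_size=64)
#   # enc_batch and dec_batch are lists of length encoder_size / decoder_size,
#   # each entry an int32 array of shape (batch_size,); masks matches dec_batch.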