📚 The CoCalc Library - books, templates and other resources
License: OTHER
""" A neural chatbot using sequence to sequence model with1attentional decoder.23This is based on Google Translate Tensorflow model4https://github.com/tensorflow/models/blob/master/tutorials/rnn/translate/56Sequence to sequence model by Cho et al.(2014)78Created by Chip Huyen as the starter code for assignment 3,9class CS 20SI: "TensorFlow for Deep Learning Research"10cs20si.stanford.edu1112This file contains the code to do the pre-processing for the13Cornell Movie-Dialogs Corpus.1415See readme.md for instruction on how to run the starter code.16"""17from __future__ import print_function1819import os20import random21import re2223import numpy as np2425import config2627def get_lines():28id2line = {}29file_path = os.path.join(config.DATA_PATH, config.LINE_FILE)30with open(file_path, 'rb') as f:31lines = f.readlines()32for line in lines:33parts = line.split(' +++$+++ ')34if len(parts) == 5:35if parts[4][-1] == '\n':36parts[4] = parts[4][:-1]37id2line[parts[0]] = parts[4]38return id2line3940def get_convos():41""" Get conversations from the raw data """42file_path = os.path.join(config.DATA_PATH, config.CONVO_FILE)43convos = []44with open(file_path, 'rb') as f:45for line in f.readlines():46parts = line.split(' +++$+++ ')47if len(parts) == 4:48convo = []49for line in parts[3][1:-2].split(', '):50convo.append(line[1:-1])51convos.append(convo)5253return convos5455def question_answers(id2line, convos):56""" Divide the dataset into two sets: questions and answers. """57questions, answers = [], []58for convo in convos:59for index, line in enumerate(convo[:-1]):60questions.append(id2line[convo[index]])61answers.append(id2line[convo[index + 1]])62assert len(questions) == len(answers)63return questions, answers6465def prepare_dataset(questions, answers):66# create path to store all the train & test encoder & decoder67make_dir(config.PROCESSED_PATH)6869# random convos to create the test set70test_ids = random.sample([i for i in range(len(questions))],config.TESTSET_SIZE)7172filenames = ['train.enc', 'train.dec', 'test.enc', 'test.dec']73files = []74for filename in filenames:75files.append(open(os.path.join(config.PROCESSED_PATH, filename),'wb'))7677for i in range(len(questions)):78if i in test_ids:79files[2].write(questions[i] + '\n')80files[3].write(answers[i] + '\n')81else:82files[0].write(questions[i] + '\n')83files[1].write(answers[i] + '\n')8485for file in files:86file.close()8788def make_dir(path):89""" Create a directory if there isn't one already. """90try:91os.mkdir(path)92except OSError:93pass9495def basic_tokenizer(line, normalize_digits=True):96""" A basic tokenizer to tokenize text into tokens.97Feel free to change this to suit your need. 
"""98line = re.sub('<u>', '', line)99line = re.sub('</u>', '', line)100line = re.sub('\[', '', line)101line = re.sub('\]', '', line)102words = []103_WORD_SPLIT = re.compile(b"([.,!?\"'-<>:;)(])")104_DIGIT_RE = re.compile(r"\d")105for fragment in line.strip().lower().split():106for token in re.split(_WORD_SPLIT, fragment):107if not token:108continue109if normalize_digits:110token = re.sub(_DIGIT_RE, b'#', token)111words.append(token)112return words113114def build_vocab(filename, normalize_digits=True):115in_path = os.path.join(config.PROCESSED_PATH, filename)116out_path = os.path.join(config.PROCESSED_PATH, 'vocab.{}'.format(filename[-3:]))117118vocab = {}119with open(in_path, 'rb') as f:120for line in f.readlines():121for token in basic_tokenizer(line):122if not token in vocab:123vocab[token] = 0124vocab[token] += 1125126sorted_vocab = sorted(vocab, key=vocab.get, reverse=True)127with open(out_path, 'wb') as f:128f.write('<pad>' + '\n')129f.write('<unk>' + '\n')130f.write('<s>' + '\n')131f.write('<\s>' + '\n')132index = 4133for word in sorted_vocab:134if vocab[word] < config.THRESHOLD:135with open('config.py', 'ab') as cf:136if filename[-3:] == 'enc':137cf.write('ENC_VOCAB = ' + str(index) + '\n')138else:139cf.write('DEC_VOCAB = ' + str(index) + '\n')140break141f.write(word + '\n')142index += 1143144def load_vocab(vocab_path):145with open(vocab_path, 'rb') as f:146words = f.read().splitlines()147return words, {words[i]: i for i in range(len(words))}148149def sentence2id(vocab, line):150return [vocab.get(token, vocab['<unk>']) for token in basic_tokenizer(line)]151152def token2id(data, mode):153""" Convert all the tokens in the data into their corresponding154index in the vocabulary. """155vocab_path = 'vocab.' + mode156in_path = data + '.' + mode157out_path = data + '_ids.' 
def token2id(data, mode):
    """ Convert all the tokens in the data into their corresponding
    index in the vocabulary. """
    vocab_path = 'vocab.' + mode
    in_path = data + '.' + mode
    out_path = data + '_ids.' + mode

    _, vocab = load_vocab(os.path.join(config.PROCESSED_PATH, vocab_path))
    in_file = open(os.path.join(config.PROCESSED_PATH, in_path), 'rb')
    out_file = open(os.path.join(config.PROCESSED_PATH, out_path), 'wb')

    lines = in_file.read().splitlines()
    for line in lines:
        if mode == 'dec':  # we only care about '<s>' and '<\s>' in the decoder
            ids = [vocab['<s>']]
        else:
            ids = []
        ids.extend(sentence2id(vocab, line))
        # ids.extend([vocab.get(token, vocab['<unk>']) for token in basic_tokenizer(line)])
        if mode == 'dec':
            ids.append(vocab['<\s>'])
        out_file.write(' '.join(str(id_) for id_ in ids) + '\n')

def prepare_raw_data():
    print('Preparing raw data into train set and test set ...')
    id2line = get_lines()
    convos = get_convos()
    questions, answers = question_answers(id2line, convos)
    prepare_dataset(questions, answers)

def process_data():
    print('Preparing data to be model-ready ...')
    build_vocab('train.enc')
    build_vocab('train.dec')
    token2id('train', 'enc')
    token2id('train', 'dec')
    token2id('test', 'enc')
    token2id('test', 'dec')

def load_data(enc_filename, dec_filename, max_training_size=None):
    encode_file = open(os.path.join(config.PROCESSED_PATH, enc_filename), 'rb')
    decode_file = open(os.path.join(config.PROCESSED_PATH, dec_filename), 'rb')
    encode, decode = encode_file.readline(), decode_file.readline()
    data_buckets = [[] for _ in config.BUCKETS]
    i = 0
    while encode and decode:
        if (i + 1) % 10000 == 0:
            print("Bucketing conversation number", i)
        encode_ids = [int(id_) for id_ in encode.split()]
        decode_ids = [int(id_) for id_ in decode.split()]
        for bucket_id, (encode_max_size, decode_max_size) in enumerate(config.BUCKETS):
            if len(encode_ids) <= encode_max_size and len(decode_ids) <= decode_max_size:
                data_buckets[bucket_id].append([encode_ids, decode_ids])
                break
        encode, decode = encode_file.readline(), decode_file.readline()
        i += 1
    return data_buckets

def _pad_input(input_, size):
    return input_ + [config.PAD_ID] * (size - len(input_))

def _reshape_batch(inputs, size, batch_size):
    """ Create batch-major inputs.
    Batch inputs are just re-indexed inputs.
    """
    batch_inputs = []
    for length_id in range(size):
        batch_inputs.append(np.array([inputs[batch_id][length_id]
                                      for batch_id in range(batch_size)], dtype=np.int32))
    return batch_inputs
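# Illustration (not part of the original starter code): _reshape_batch turns
# a list of padded sequences into one array per position, each of shape
# (batch_size,). For example, with size=3 and batch_size=2:
#
#   _reshape_batch([[1, 2, 3],
#                   [4, 5, 6]], size=3, batch_size=2)
#   # -> [array([1, 4]), array([2, 5]), array([3, 6])]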
def get_batch(data_bucket, bucket_id, batch_size=1):
    """ Return one batch to feed into the model """
    # only pad to the max length of the bucket
    encoder_size, decoder_size = config.BUCKETS[bucket_id]
    encoder_inputs, decoder_inputs = [], []

    for _ in range(batch_size):
        encoder_input, decoder_input = random.choice(data_bucket)
        # pad both encoder and decoder, reverse the encoder
        encoder_inputs.append(list(reversed(_pad_input(encoder_input, encoder_size))))
        decoder_inputs.append(_pad_input(decoder_input, decoder_size))

    # now we create batch-major vectors from the data selected above.
    batch_encoder_inputs = _reshape_batch(encoder_inputs, encoder_size, batch_size)
    batch_decoder_inputs = _reshape_batch(decoder_inputs, decoder_size, batch_size)

    # create decoder_masks to be 0 for decoders that are padding.
    batch_masks = []
    for length_id in range(decoder_size):
        batch_mask = np.ones(batch_size, dtype=np.float32)
        for batch_id in range(batch_size):
            # we set mask to 0 if the corresponding target is a PAD symbol.
            # the corresponding decoder is decoder_input shifted by 1 forward.
            if length_id < decoder_size - 1:
                target = decoder_inputs[batch_id][length_id + 1]
            if length_id == decoder_size - 1 or target == config.PAD_ID:
                batch_mask[batch_id] = 0.0
        batch_masks.append(batch_mask)
    return batch_encoder_inputs, batch_decoder_inputs, batch_masks

if __name__ == '__main__':
    prepare_raw_data()
    process_data()
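# Usage sketch (not part of the original starter code): after `python data.py`
# has produced the processed files, a training loop can bucket and batch the
# data like this, assuming config defines BUCKETS and PAD_ID as in the
# accompanying config.py:
#
#   data_buckets = load_data('train_ids.enc', 'train_ids.dec')
#   enc, dec, masks = get_batch(data_buckets[0], bucket_id=0, batch_size=64)
#   # enc:   encoder_size arrays of shape (64,), reversed encoder inputs padded to encoder_size
#   # dec:   decoder_size arrays of shape (64,), each sequence starting with the '<s>' id
#   # masks: decoder_size float32 arrays, 0.0 wherever the next target is padding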