📚 The CoCalc Library - books, templates and other resources
License: OTHER
""" A clean, no_frills character-level generative language model.12CS 20: "TensorFlow for Deep Learning Research"3cs20.stanford.edu4Danijar Hafner ([email protected])5& Chip Huyen ([email protected])6Lecture 117"""8import os9os.environ['TF_CPP_MIN_LOG_LEVEL']='2'10import random11import sys12sys.path.append('..')13import time1415import tensorflow as tf1617import utils1819def vocab_encode(text, vocab):20return [vocab.index(x) + 1 for x in text if x in vocab]2122def vocab_decode(array, vocab):23return ''.join([vocab[x - 1] for x in array])2425def read_data(filename, vocab, window, overlap):26lines = [line.strip() for line in open(filename, 'r').readlines()]27while True:28random.shuffle(lines)2930for text in lines:31text = vocab_encode(text, vocab)32for start in range(0, len(text) - window, overlap):33chunk = text[start: start + window]34chunk += [0] * (window - len(chunk))35yield chunk3637def read_batch(stream, batch_size):38batch = []39for element in stream:40batch.append(element)41if len(batch) == batch_size:42yield batch43batch = []44yield batch4546class CharRNN(object):47def __init__(self, model):48self.model = model49self.path = 'data/' + model + '.txt'50if 'trump' in model:51self.vocab = ("$%'()+,-./0123456789:;=?ABCDEFGHIJKLMNOPQRSTUVWXYZ"52" '\"_abcdefghijklmnopqrstuvwxyz{|}@#➡📈")53else:54self.vocab = (" $%'()+,-./0123456789:;=?ABCDEFGHIJKLMNOPQRSTUVWXYZ"55"\\^_abcdefghijklmnopqrstuvwxyz{|}")5657self.seq = tf.placeholder(tf.int32, [None, None])58self.temp = tf.constant(1.5)59self.hidden_sizes = [128, 256]60self.batch_size = 6461self.lr = 0.000362self.skip_step = 163self.num_steps = 50 # for RNN unrolled64self.len_generated = 20065self.gstep = tf.Variable(0, dtype=tf.int32, trainable=False, name='global_step')6667def create_rnn(self, seq):68layers = [tf.nn.rnn_cell.GRUCell(size) for size in self.hidden_sizes]69cells = tf.nn.rnn_cell.MultiRNNCell(layers)70batch = tf.shape(seq)[0]71zero_states = cells.zero_state(batch, dtype=tf.float32)72self.in_state = tuple([tf.placeholder_with_default(state, [None, state.shape[1]])73for state in zero_states])74# this line to calculate the real length of seq75# all seq are padded to be of the same length, which is num_steps76length = tf.reduce_sum(tf.reduce_max(tf.sign(seq), 2), 1)77self.output, self.out_state = tf.nn.dynamic_rnn(cells, seq, length, self.in_state)7879def create_model(self):80seq = tf.one_hot(self.seq, len(self.vocab))81self.create_rnn(seq)82self.logits = tf.layers.dense(self.output, len(self.vocab), None)83loss = tf.nn.softmax_cross_entropy_with_logits(logits=self.logits[:, :-1],84labels=seq[:, 1:])85self.loss = tf.reduce_sum(loss)86# sample the next character from Maxwell-Boltzmann Distribution87# with temperature temp. 
class CharRNN(object):
    def __init__(self, model):
        self.model = model
        self.path = 'data/' + model + '.txt'
        if 'trump' in model:
            self.vocab = ("$%'()+,-./0123456789:;=?ABCDEFGHIJKLMNOPQRSTUVWXYZ"
                          " '\"_abcdefghijklmnopqrstuvwxyz{|}@#➡📈")
        else:
            self.vocab = (" $%'()+,-./0123456789:;=?ABCDEFGHIJKLMNOPQRSTUVWXYZ"
                          "\\^_abcdefghijklmnopqrstuvwxyz{|}")

        self.seq = tf.placeholder(tf.int32, [None, None])
        self.temp = tf.constant(1.5)
        self.hidden_sizes = [128, 256]
        self.batch_size = 64
        self.lr = 0.0003
        self.skip_step = 1
        self.num_steps = 50  # window size the RNN is unrolled over
        self.len_generated = 200
        self.gstep = tf.Variable(0, dtype=tf.int32, trainable=False, name='global_step')

    def create_rnn(self, seq):
        layers = [tf.nn.rnn_cell.GRUCell(size) for size in self.hidden_sizes]
        cells = tf.nn.rnn_cell.MultiRNNCell(layers)
        batch = tf.shape(seq)[0]
        zero_states = cells.zero_state(batch, dtype=tf.float32)
        # placeholders that default to the zero state, so inference can feed
        # the previous state back in while training just uses the default
        self.in_state = tuple([tf.placeholder_with_default(state, [None, state.shape[1]])
                               for state in zero_states])
        # compute the real length of each sequence: all sequences are padded
        # to the same length (num_steps), and padded steps are all-zero
        # one-hot vectors
        length = tf.reduce_sum(tf.reduce_max(tf.sign(seq), 2), 1)
        self.output, self.out_state = tf.nn.dynamic_rnn(cells, seq, length, self.in_state)

    def create_model(self):
        seq = tf.one_hot(self.seq, len(self.vocab))
        self.create_rnn(seq)
        self.logits = tf.layers.dense(self.output, len(self.vocab), None)
        loss = tf.nn.softmax_cross_entropy_with_logits(logits=self.logits[:, :-1],
                                                       labels=seq[:, 1:])
        self.loss = tf.reduce_sum(loss)
        # sample the next character from the Boltzmann distribution (softmax)
        # with temperature temp; it works equally well without tf.exp, since
        # tf.multinomial accepts unnormalized logits
        self.sample = tf.multinomial(tf.exp(self.logits[:, -1] / self.temp), 1)[:, 0]
        self.opt = tf.train.AdamOptimizer(self.lr).minimize(self.loss, global_step=self.gstep)

    def train(self):
        saver = tf.train.Saver()
        start = time.time()
        min_loss = None
        with tf.Session() as sess:
            writer = tf.summary.FileWriter('graphs/gist', sess.graph)
            sess.run(tf.global_variables_initializer())

            ckpt = tf.train.get_checkpoint_state(os.path.dirname('checkpoints/' + self.model + '/checkpoint'))
            if ckpt and ckpt.model_checkpoint_path:
                saver.restore(sess, ckpt.model_checkpoint_path)

            iteration = self.gstep.eval()
            stream = read_data(self.path, self.vocab, self.num_steps, overlap=self.num_steps // 2)
            data = read_batch(stream, self.batch_size)
            while True:  # train until interrupted
                batch = next(data)
                batch_loss, _ = sess.run([self.loss, self.opt], {self.seq: batch})
                if (iteration + 1) % self.skip_step == 0:
                    print('Iter {}.\n    Loss {}. Time {}'.format(iteration, batch_loss, time.time() - start))
                    self.online_infer(sess)
                    start = time.time()
                    checkpoint_name = 'checkpoints/' + self.model + '/char-rnn'
                    if min_loss is None:
                        saver.save(sess, checkpoint_name, iteration)
                        min_loss = batch_loss  # remember the loss so later saves only happen on improvement
                    elif batch_loss < min_loss:
                        saver.save(sess, checkpoint_name, iteration)
                        min_loss = batch_loss
                iteration += 1

    def online_infer(self, sess):
        """ Generate sequences one character at a time, each step conditioned
        on the previous character and the carried-over RNN state.
        """
        for seed in ['Hillary', 'I', 'R', 'T', '@', 'N', 'M', '.', 'G', 'A', 'W']:
            sentence = seed
            state = None
            for _ in range(self.len_generated):
                batch = [vocab_encode(sentence[-1], self.vocab)]
                feed = {self.seq: batch}
                if state is not None:  # for the first decoder step, the state is None
                    for i in range(len(state)):
                        feed.update({self.in_state[i]: state[i]})
                index, state = sess.run([self.sample, self.out_state], feed)
                sentence += vocab_decode(index, self.vocab)
            print('\t' + sentence)

def main():
    model = 'trump_tweets'
    utils.safe_mkdir('checkpoints')
    utils.safe_mkdir('checkpoints/' + model)

    lm = CharRNN(model)
    lm.create_model()
    lm.train()

if __name__ == '__main__':
    main()
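# Why divide the logits by a temperature before sampling? A small numpy sketch
# (made-up logits, independent of the graph above; not part of the original
# lecture code): a higher temperature flattens the softmax, producing more
# varied characters, while a lower one sharpens it toward greedy decoding.
def _temperature_demo():
    import numpy as np
    logits = np.array([2.0, 1.0, 0.1])
    for temp in (0.5, 1.0, 1.5):
        probs = np.exp(logits / temp)
        probs /= probs.sum()
        print(temp, probs.round(3))  # 0.5 -> [0.864 0.117 0.019], 1.5 -> [0.557 0.286 0.157]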