📚 The CoCalc Library - books, templates and other resources
License: OTHER
""" A clean, no_frills character-level generative language model.1Created by Danijar Hafner (danijar.com), edited by Chip Huyen2for the class CS 20SI: "TensorFlow for Deep Learning Research"34Based on Andrej Karpathy's blog:5http://karpathy.github.io/2015/05/21/rnn-effectiveness/6"""7import os8os.environ['TF_CPP_MIN_LOG_LEVEL']='2'9import sys10sys.path.append('..')1112import time1314import tensorflow as tf1516import utils1718DATA_PATH = 'data/arvix_abstracts.txt'19HIDDEN_SIZE = 20020BATCH_SIZE = 6421NUM_STEPS = 5022SKIP_STEP = 4023TEMPRATURE = 0.724LR = 0.00325LEN_GENERATED = 3002627def vocab_encode(text, vocab):28return [vocab.index(x) + 1 for x in text if x in vocab]2930def vocab_decode(array, vocab):31return ''.join([vocab[x - 1] for x in array])3233def read_data(filename, vocab, window=NUM_STEPS, overlap=NUM_STEPS//2):34for text in open(filename):35text = vocab_encode(text, vocab)36for start in range(0, len(text) - window, overlap):37chunk = text[start: start + window]38chunk += [0] * (window - len(chunk))39yield chunk4041def read_batch(stream, batch_size=BATCH_SIZE):42batch = []43for element in stream:44batch.append(element)45if len(batch) == batch_size:46yield batch47batch = []48yield batch4950def create_rnn(seq, hidden_size=HIDDEN_SIZE):51cell = tf.contrib.rnn.GRUCell(hidden_size)52in_state = tf.placeholder_with_default(53cell.zero_state(tf.shape(seq)[0], tf.float32), [None, hidden_size])54# this line to calculate the real length of seq55# all seq are padded to be of the same length which is NUM_STEPS56length = tf.reduce_sum(tf.reduce_max(tf.sign(seq), 2), 1)57output, out_state = tf.nn.dynamic_rnn(cell, seq, length, in_state)58return output, in_state, out_state5960def create_model(seq, temp, vocab, hidden=HIDDEN_SIZE):61seq = tf.one_hot(seq, len(vocab))62output, in_state, out_state = create_rnn(seq, hidden)63# fully_connected is syntactic sugar for tf.matmul(w, output) + b64# it will create w and b for us65logits = tf.contrib.layers.fully_connected(output, len(vocab), None)66loss = tf.reduce_sum(tf.nn.softmax_cross_entropy_with_logits(logits=logits[:, :-1], labels=seq[:, 1:]))67# sample the next character from Maxwell-Boltzmann Distribution with temperature temp68# it works equally well without tf.exp69sample = tf.multinomial(tf.exp(logits[:, -1] / temp), 1)[:, 0]70return loss, sample, in_state, out_state7172def training(vocab, seq, loss, optimizer, global_step, temp, sample, in_state, out_state):73saver = tf.train.Saver()74start = time.time()75with tf.Session() as sess:76writer = tf.summary.FileWriter('graphs/gist', sess.graph)77sess.run(tf.global_variables_initializer())7879ckpt = tf.train.get_checkpoint_state(os.path.dirname('checkpoints/arvix/checkpoint'))80if ckpt and ckpt.model_checkpoint_path:81saver.restore(sess, ckpt.model_checkpoint_path)8283iteration = global_step.eval()84for batch in read_batch(read_data(DATA_PATH, vocab)):85batch_loss, _ = sess.run([loss, optimizer], {seq: batch})86if (iteration + 1) % SKIP_STEP == 0:87print('Iter {}. \n Loss {}. 
Time {}'.format(iteration, batch_loss, time.time() - start))88online_inference(sess, vocab, seq, sample, temp, in_state, out_state)89start = time.time()90saver.save(sess, 'checkpoints/arvix/char-rnn', iteration)91iteration += 19293def online_inference(sess, vocab, seq, sample, temp, in_state, out_state, seed='T'):94""" Generate sequence one character at a time, based on the previous character95"""96sentence = seed97state = None98for _ in range(LEN_GENERATED):99batch = [vocab_encode(sentence[-1], vocab)]100feed = {seq: batch, temp: TEMPRATURE}101# for the first decoder step, the state is None102if state is not None:103feed.update({in_state: state})104index, state = sess.run([sample, out_state], feed)105sentence += vocab_decode(index, vocab)106print(sentence)107108def main():109vocab = (110" $%'()+,-./0123456789:;=?ABCDEFGHIJKLMNOPQRSTUVWXYZ"111"\\^_abcdefghijklmnopqrstuvwxyz{|}")112seq = tf.placeholder(tf.int32, [None, None])113temp = tf.placeholder(tf.float32)114loss, sample, in_state, out_state = create_model(seq, temp, vocab)115global_step = tf.Variable(0, dtype=tf.int32, trainable=False, name='global_step')116optimizer = tf.train.AdamOptimizer(LR).minimize(loss, global_step=global_step)117utils.make_dir('checkpoints')118utils.make_dir('checkpoints/arvix')119training(vocab, seq, loss, optimizer, global_step, temp, sample, in_state, out_state)120121if __name__ == '__main__':122main()123124
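
# --------------------------------------------------------------------------
# Illustrative sketch, not part of the original script: a plain-NumPy view of
# the temperature-scaled sampling done in create_model(), i.e.
# sample = tf.multinomial(tf.exp(logits[:, -1] / temp), 1).
# The helper name `sample_with_temperature` and the use of NumPy here are
# assumptions made for illustration only.
import numpy as np

def sample_with_temperature(logits, temp=0.7):
    """Softmax the temperature-scaled logits, then draw one index from them."""
    scaled = np.asarray(logits, dtype=np.float64) / temp
    probs = np.exp(scaled - scaled.max())  # subtract the max for numerical stability
    probs /= probs.sum()                   # normalize into a probability distribution
    return np.random.choice(len(probs), p=probs)

# Example: sample_with_temperature([2.0, 0.5, 0.1], temp=0.7) usually returns 0;
# lowering temp sharpens the distribution, raising it flattens the distribution.
# --------------------------------------------------------------------------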