""" A clean, no-frills character-level generative language model.
Created by Danijar Hafner (danijar.com), edited by Chip Huyen
for the class CS 20SI: "TensorFlow for Deep Learning Research"

Based on Andrej Karpathy's blog:
http://karpathy.github.io/2015/05/21/rnn-effectiveness/
"""
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
import sys
sys.path.append('..')

import time

import tensorflow as tf

import utils

DATA_PATH = 'data/arvix_abstracts.txt'
HIDDEN_SIZE = 200
BATCH_SIZE = 64
NUM_STEPS = 50
SKIP_STEP = 40
TEMPERATURE = 0.7
LR = 0.003
LEN_GENERATED = 300

def vocab_encode(text, vocab):
    return [vocab.index(x) + 1 for x in text if x in vocab]

def vocab_decode(array, vocab):
    return ''.join([vocab[x - 1] for x in array])
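
# Note (added for clarity): vocab_encode maps each character to its 1-based index
# in vocab, leaving index 0 free to serve as the padding symbol used in read_data;
# vocab_decode inverts the mapping. Characters not in vocab are silently dropped.
# Illustrative round trip, assuming the vocab defined in main():
#   vocab_decode(vocab_encode('TensorFlow', vocab), vocab) == 'TensorFlow'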

def read_data(filename, vocab, window=NUM_STEPS, overlap=NUM_STEPS//2):
    for text in open(filename):
        text = vocab_encode(text, vocab)
        for start in range(0, len(text) - window, overlap):
            chunk = text[start: start + window]
            chunk += [0] * (window - len(chunk))
            yield chunk
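
# Sketch of the windowing (added): with window=50 and overlap=25, a line of 120
# encoded characters yields chunks starting at positions 0, 25 and 50, each 50
# characters long; lines no longer than the window produce no chunks because the
# range() above is empty. The zero-padding line is a safeguard for short chunks,
# 0 being the one index that vocab_encode never assigns to a real character.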

def read_batch(stream, batch_size=BATCH_SIZE):
    batch = []
    for element in stream:
        batch.append(element)
        if len(batch) == batch_size:
            yield batch
            batch = []
    yield batch
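
# Note (added): the final `yield batch` outside the loop emits whatever is left
# once the stream is exhausted, so the last batch may be shorter than batch_size
# (or empty when the number of chunks is an exact multiple of batch_size).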

def create_rnn(seq, hidden_size=HIDDEN_SIZE):
    cell = tf.contrib.rnn.GRUCell(hidden_size)
    in_state = tf.placeholder_with_default(
        cell.zero_state(tf.shape(seq)[0], tf.float32), [None, hidden_size])
    # compute the real length of each sequence in the batch;
    # all sequences are padded to the same length, NUM_STEPS
    length = tf.reduce_sum(tf.reduce_max(tf.sign(seq), 2), 1)
    output, out_state = tf.nn.dynamic_rnn(cell, seq, length, in_state)
    return output, in_state, out_state
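
# Note (added): tf.placeholder_with_default lets one graph serve both phases;
# during training nothing is fed for in_state, so the GRU starts from its zero
# state, while during generation the previous out_state is fed back in. The
# sequence_length argument (`length`) tells dynamic_rnn how many time steps of
# each example are real, so it stops unrolling at the padded positions.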

def create_model(seq, temp, vocab, hidden=HIDDEN_SIZE):
    seq = tf.one_hot(seq, len(vocab))
    output, in_state, out_state = create_rnn(seq, hidden)
    # fully_connected is syntactic sugar for tf.matmul(output, w) + b;
    # it creates w and b for us
    logits = tf.contrib.layers.fully_connected(output, len(vocab), None)
    loss = tf.reduce_sum(tf.nn.softmax_cross_entropy_with_logits(logits=logits[:, :-1], labels=seq[:, 1:]))
    # sample the next character from the Boltzmann (softmax) distribution with temperature temp
    # it works equally well without tf.exp
    sample = tf.multinomial(tf.exp(logits[:, -1] / temp), 1)[:, 0]
    return loss, sample, in_state, out_state
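
# Sketch of the objective and sampler (added): logits has shape
# [batch, time, len(vocab)]. logits[:, :-1] are the predictions at steps 0..T-2
# and seq[:, 1:] are the one-hot characters at steps 1..T-1, so the cross-entropy
# trains the model to predict each next character. For sampling, logits[:, -1]
# (the prediction after the last input character) is scaled by 1/temp and handed
# to tf.multinomial, which treats its input as unnormalized log-probabilities;
# lower temperatures make samples greedier, higher ones more diverse.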

def training(vocab, seq, loss, optimizer, global_step, temp, sample, in_state, out_state):
    saver = tf.train.Saver()
    start = time.time()
    with tf.Session() as sess:
        writer = tf.summary.FileWriter('graphs/gist', sess.graph)
        sess.run(tf.global_variables_initializer())

        ckpt = tf.train.get_checkpoint_state(os.path.dirname('checkpoints/arvix/checkpoint'))
        if ckpt and ckpt.model_checkpoint_path:
            saver.restore(sess, ckpt.model_checkpoint_path)

        iteration = global_step.eval()
        for batch in read_batch(read_data(DATA_PATH, vocab)):
            batch_loss, _ = sess.run([loss, optimizer], {seq: batch})
            if (iteration + 1) % SKIP_STEP == 0:
                print('Iter {}. \n Loss {}. Time {}'.format(iteration, batch_loss, time.time() - start))
                online_inference(sess, vocab, seq, sample, temp, in_state, out_state)
                start = time.time()
                saver.save(sess, 'checkpoints/arvix/char-rnn', iteration)
            iteration += 1
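
# Note (added): tf.train.get_checkpoint_state looks for the 'checkpoint' file in
# checkpoints/arvix/, so re-running the script restores the latest weights and
# resumes counting from the saved global_step (incremented once per batch by the
# optimizer's minimize() call in main()); every SKIP_STEP batches it prints the
# loss, generates a sample with online_inference, and saves a new checkpoint.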

def online_inference(sess, vocab, seq, sample, temp, in_state, out_state, seed='T'):
    """ Generate the sequence one character at a time, conditioned on the previous character. """
    sentence = seed
    state = None
    for _ in range(LEN_GENERATED):
        batch = [vocab_encode(sentence[-1], vocab)]
        feed = {seq: batch, temp: TEMPERATURE}
        # for the first decoder step, the state is None
        if state is not None:
            feed.update({in_state: state})
        index, state = sess.run([sample, out_state], feed)
        sentence += vocab_decode(index, vocab)
    print(sentence)
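
# Sketch of the generation loop (added): only the most recent character is fed at
# each step (a batch of one length-1 sequence), and the GRU state returned in
# out_state is fed back through in_state on the next step; that recurrent state
# carries the context of everything generated so far.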

def main():
    vocab = (
        " $%'()+,-./0123456789:;=?ABCDEFGHIJKLMNOPQRSTUVWXYZ"
        "\\^_abcdefghijklmnopqrstuvwxyz{|}")
    seq = tf.placeholder(tf.int32, [None, None])
    temp = tf.placeholder(tf.float32)
    loss, sample, in_state, out_state = create_model(seq, temp, vocab)
    global_step = tf.Variable(0, dtype=tf.int32, trainable=False, name='global_step')
    optimizer = tf.train.AdamOptimizer(LR).minimize(loss, global_step=global_step)
    utils.make_dir('checkpoints')
    utils.make_dir('checkpoints/arvix')
    training(vocab, seq, loss, optimizer, global_step, temp, sample, in_state, out_state)

if __name__ == '__main__':
    main()
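
# Usage sketch (assumptions noted): the script expects the training text at
# data/arvix_abstracts.txt relative to the working directory and a utils module
# providing make_dir() importable from the parent directory (hence the
# sys.path.append('..') above). Run it directly under a TensorFlow 1.x install
# that still ships tf.contrib; checkpoints are written to checkpoints/arvix/ and
# a sample sequence is printed every SKIP_STEP batches.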