
📚 The CoCalc Library - books, templates and other resources

License: OTHER
""" word2vec with NCE loss and code to visualize the embeddings on TensorBoard
Author: Chip Huyen
Prepared for the class CS 20SI: "TensorFlow for Deep Learning Research"
cs20si.stanford.edu
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'

import numpy as np
from tensorflow.contrib.tensorboard.plugins import projector
import tensorflow as tf

from process_data import process_data
import utils

VOCAB_SIZE = 50000
BATCH_SIZE = 128
EMBED_SIZE = 128        # dimension of the word embedding vectors
SKIP_WINDOW = 1         # the context window
NUM_SAMPLED = 64        # number of negative examples to sample
LEARNING_RATE = 1.0
NUM_TRAIN_STEPS = 100000
WEIGHTS_FLD = 'processed/'
SKIP_STEP = 2000
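
# Note: process_data is expected to yield batches of (center word, target word) index pairs
# for the skip-gram task, each target drawn from a window of SKIP_WINDOW words around its
# center word; with SKIP_WINDOW = 1 every target is an immediate neighbor.
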
class SkipGramModel:
    """ Build the graph for word2vec model """
    def __init__(self, vocab_size, embed_size, batch_size, num_sampled, learning_rate):
        self.vocab_size = vocab_size
        self.embed_size = embed_size
        self.batch_size = batch_size
        self.num_sampled = num_sampled
        self.lr = learning_rate
        self.global_step = tf.Variable(0, dtype=tf.int32, trainable=False, name='global_step')

    def _create_placeholders(self):
        """ Step 1: define the placeholders for input and output """
        with tf.name_scope("data"):
            self.center_words = tf.placeholder(tf.int32, shape=[self.batch_size], name='center_words')
            self.target_words = tf.placeholder(tf.int32, shape=[self.batch_size, 1], name='target_words')
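
    # Note: center_words holds one word index per example, while target_words has shape
    # [batch_size, 1] because tf.nn.nce_loss expects labels with an explicit num_true
    # dimension; here there is a single true context word per center word.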

    def _create_embedding(self):
        """ Step 2: define weights. In word2vec, it's actually the weights that we care about """
        # Assemble this part of the graph on the CPU. You can change it to GPU if you have one
        with tf.device('/cpu:0'):
            with tf.name_scope("embed"):
                self.embed_matrix = tf.Variable(tf.random_uniform([self.vocab_size,
                                                                    self.embed_size],
                                                                   -1.0, 1.0),
                                                name='embed_matrix')
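
    # Note: each row of embed_matrix is the EMBED_SIZE-dimensional vector for one vocabulary
    # word, initialized uniformly in [-1.0, 1.0); these rows are the embeddings that training
    # actually learns.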

    def _create_loss(self):
        """ Step 3 + 4: define the model + the loss function """
        with tf.device('/cpu:0'):
            with tf.name_scope("loss"):
                # Step 3: define the inference
                embed = tf.nn.embedding_lookup(self.embed_matrix, self.center_words, name='embed')

                # Step 4: define the loss function
                # construct variables for NCE loss
                nce_weight = tf.Variable(tf.truncated_normal([self.vocab_size, self.embed_size],
                                                             stddev=1.0 / (self.embed_size ** 0.5)),
                                         name='nce_weight')
                nce_bias = tf.Variable(tf.zeros([self.vocab_size]), name='nce_bias')

                # define the loss to be the NCE loss
                self.loss = tf.reduce_mean(tf.nn.nce_loss(weights=nce_weight,
                                                          biases=nce_bias,
                                                          labels=self.target_words,
                                                          inputs=embed,
                                                          num_sampled=self.num_sampled,
                                                          num_classes=self.vocab_size),
                                           name='loss')
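
    # Note: tf.nn.nce_loss avoids computing a full softmax over all vocab_size output words.
    # For each (center, target) pair it samples num_sampled negative classes and trains the
    # model to distinguish the true target from the sampled noise words, which keeps each
    # step cheap even with a 50,000-word vocabulary.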

    def _create_optimizer(self):
        """ Step 5: define optimizer """
        with tf.device('/cpu:0'):
            self.optimizer = tf.train.GradientDescentOptimizer(self.lr).minimize(self.loss,
                                                                                  global_step=self.global_step)
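
    # Note: passing global_step to minimize() makes the optimizer increment it on every
    # update, so the step count is saved in checkpoints and train_model can resume from the
    # correct step via model.global_step.eval().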

    def _create_summaries(self):
        with tf.name_scope("summaries"):
            tf.summary.scalar("loss", self.loss)
            tf.summary.histogram("histogram loss", self.loss)
            # because we have several summaries, merge them all
            # into one op to make them easier to manage
            self.summary_op = tf.summary.merge_all()
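
    # Note: tf.summary.merge_all() bundles every summary defined in the graph into a single
    # op, so one extra fetch in sess.run writes both the scalar and histogram summaries.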

    def build_graph(self):
        """ Build the graph for our model """
        self._create_placeholders()
        self._create_embedding()
        self._create_loss()
        self._create_optimizer()
        self._create_summaries()


def train_model(model, batch_gen, num_train_steps, weights_fld):
    saver = tf.train.Saver()  # defaults to saving all variables - in this case embed_matrix, nce_weight, nce_bias

    initial_step = 0
    utils.make_dir('checkpoints')
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        ckpt = tf.train.get_checkpoint_state(os.path.dirname('checkpoints/checkpoint'))
        # if that checkpoint exists, restore from checkpoint
        if ckpt and ckpt.model_checkpoint_path:
            saver.restore(sess, ckpt.model_checkpoint_path)

        total_loss = 0.0  # we use this to calculate the average loss over the last SKIP_STEP steps
        writer = tf.summary.FileWriter('improved_graph/lr' + str(LEARNING_RATE), sess.graph)
        initial_step = model.global_step.eval()
        for index in range(initial_step, initial_step + num_train_steps):
            centers, targets = next(batch_gen)
            feed_dict = {model.center_words: centers, model.target_words: targets}
            loss_batch, _, summary = sess.run([model.loss, model.optimizer, model.summary_op],
                                              feed_dict=feed_dict)
            writer.add_summary(summary, global_step=index)
            total_loss += loss_batch
            if (index + 1) % SKIP_STEP == 0:
                print('Average loss at step {}: {:5.1f}'.format(index, total_loss / SKIP_STEP))
                total_loss = 0.0
                saver.save(sess, 'checkpoints/skip-gram', index)
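
        # Note: the summaries written above go to improved_graph/lr1.0 (the folder name
        # includes LEARNING_RATE); inspect them during or after training with
        # "tensorboard --logdir improved_graph".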

        ####################
        # code to visualize the embeddings. uncomment the code below to visualize embeddings
        # run "tensorboard --logdir=processed" to see the embeddings
        # final_embed_matrix = sess.run(model.embed_matrix)

        # # it has to be a variable. constants don't work here. you can't reuse model.embed_matrix
        # embedding_var = tf.Variable(final_embed_matrix[:1000], name='embedding')
        # sess.run(embedding_var.initializer)

        # config = projector.ProjectorConfig()
        # summary_writer = tf.summary.FileWriter('processed')

        # # add embedding to the config file
        # embedding = config.embeddings.add()
        # embedding.tensor_name = embedding_var.name

        # # link this tensor to its metadata file, in this case the first 1000 words of the vocab
        # embedding.metadata_path = 'processed/vocab_1000.tsv'

        # # saves a configuration file that TensorBoard will read during startup
        # projector.visualize_embeddings(summary_writer, config)
        # saver_embed = tf.train.Saver([embedding_var])
        # saver_embed.save(sess, 'processed/model3.ckpt', 1)
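
        # Note: the projector reads word labels from processed/vocab_1000.tsv, a metadata
        # file with one vocabulary word per line that the data-processing step is assumed to
        # have written; only the first 1000 embeddings are visualized here.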


def main():
    model = SkipGramModel(VOCAB_SIZE, EMBED_SIZE, BATCH_SIZE, NUM_SAMPLED, LEARNING_RATE)
    model.build_graph()
    batch_gen = process_data(VOCAB_SIZE, BATCH_SIZE, SKIP_WINDOW)
    train_model(model, batch_gen, NUM_TRAIN_STEPS, WEIGHTS_FLD)


if __name__ == '__main__':
    main()
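
# Usage note: this script assumes process_data.py and utils.py from the CS 20SI course
# examples are importable alongside it. Running it builds the graph, trains for
# NUM_TRAIN_STEPS batches, and writes checkpoints to checkpoints/ and TensorBoard summaries
# to improved_graph/.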