📚 The CoCalc Library - books, templates and other resources
License: OTHER
""" word2vec skip-gram model with NCE loss and1code to visualize the embeddings on TensorBoard2CS 20: "TensorFlow for Deep Learning Research"3cs20.stanford.edu4Chip Huyen ([email protected])5Lecture 046"""78import os9os.environ['TF_CPP_MIN_LOG_LEVEL']='2'1011import numpy as np12from tensorflow.contrib.tensorboard.plugins import projector13import tensorflow as tf1415import utils16import word2vec_utils1718# Model hyperparameters19VOCAB_SIZE = 5000020BATCH_SIZE = 12821EMBED_SIZE = 128 # dimension of the word embedding vectors22SKIP_WINDOW = 1 # the context window23NUM_SAMPLED = 64 # number of negative examples to sample24LEARNING_RATE = 1.025NUM_TRAIN_STEPS = 10000026VISUAL_FLD = 'visualization'27SKIP_STEP = 50002829# Parameters for downloading data30DOWNLOAD_URL = 'http://mattmahoney.net/dc/text8.zip'31EXPECTED_BYTES = 3134401632NUM_VISUALIZE = 3000 # number of tokens to visualize3334class SkipGramModel:35""" Build the graph for word2vec model """36def __init__(self, dataset, vocab_size, embed_size, batch_size, num_sampled, learning_rate):37self.vocab_size = vocab_size38self.embed_size = embed_size39self.batch_size = batch_size40self.num_sampled = num_sampled41self.lr = learning_rate42self.global_step = tf.get_variable('global_step', initializer=tf.constant(0), trainable=False)43self.skip_step = SKIP_STEP44self.dataset = dataset4546def _import_data(self):47""" Step 1: import data48"""49with tf.name_scope('data'):50self.iterator = self.dataset.make_initializable_iterator()51self.center_words, self.target_words = self.iterator.get_next()5253def _create_embedding(self):54""" Step 2 + 3: define weights and embedding lookup.55In word2vec, it's actually the weights that we care about56"""57with tf.name_scope('embed'):58self.embed_matrix = tf.get_variable('embed_matrix',59shape=[self.vocab_size, self.embed_size],60initializer=tf.random_uniform_initializer())61self.embed = tf.nn.embedding_lookup(self.embed_matrix, self.center_words, name='embedding')6263def _create_loss(self):64""" Step 4: define the loss function """65with tf.name_scope('loss'):66# construct variables for NCE loss67nce_weight = tf.get_variable('nce_weight',68shape=[self.vocab_size, self.embed_size],69initializer=tf.truncated_normal_initializer(stddev=1.0 / (self.embed_size ** 0.5)))70nce_bias = tf.get_variable('nce_bias', initializer=tf.zeros([VOCAB_SIZE]))7172# define loss function to be NCE loss function73self.loss = tf.reduce_mean(tf.nn.nce_loss(weights=nce_weight,74biases=nce_bias,75labels=self.target_words,76inputs=self.embed,77num_sampled=self.num_sampled,78num_classes=self.vocab_size), name='loss')79def _create_optimizer(self):80""" Step 5: define optimizer """81self.optimizer = tf.train.GradientDescentOptimizer(self.lr).minimize(self.loss,82global_step=self.global_step)8384def _create_summaries(self):85with tf.name_scope('summaries'):86tf.summary.scalar('loss', self.loss)87tf.summary.histogram('histogram loss', self.loss)88# because you have several summaries, we should merge them all89# into one op to make it easier to manage90self.summary_op = tf.summary.merge_all()9192def build_graph(self):93""" Build the graph for our model """94self._import_data()95self._create_embedding()96self._create_loss()97self._create_optimizer()98self._create_summaries()99100def train(self, num_train_steps):101saver = tf.train.Saver() # defaults to saving all variables - in this case embed_matrix, nce_weight, nce_bias102103initial_step = 0104utils.safe_mkdir('checkpoints')105with tf.Session() as 
sess:106sess.run(self.iterator.initializer)107sess.run(tf.global_variables_initializer())108ckpt = tf.train.get_checkpoint_state(os.path.dirname('checkpoints/checkpoint'))109110# if that checkpoint exists, restore from checkpoint111if ckpt and ckpt.model_checkpoint_path:112saver.restore(sess, ckpt.model_checkpoint_path)113114total_loss = 0.0 # we use this to calculate late average loss in the last SKIP_STEP steps115writer = tf.summary.FileWriter('graphs/word2vec/lr' + str(self.lr), sess.graph)116initial_step = self.global_step.eval()117118for index in range(initial_step, initial_step + num_train_steps):119try:120loss_batch, _, summary = sess.run([self.loss, self.optimizer, self.summary_op])121writer.add_summary(summary, global_step=index)122total_loss += loss_batch123if (index + 1) % self.skip_step == 0:124print('Average loss at step {}: {:5.1f}'.format(index, total_loss / self.skip_step))125total_loss = 0.0126saver.save(sess, 'checkpoints/skip-gram', index)127except tf.errors.OutOfRangeError:128sess.run(self.iterator.initializer)129writer.close()130131def visualize(self, visual_fld, num_visualize):132""" run "'tensorboard --logdir='visualization'" to see the embeddings """133134# create the list of num_variable most common words to visualize135word2vec_utils.most_common_words(visual_fld, num_visualize)136137saver = tf.train.Saver()138with tf.Session() as sess:139sess.run(tf.global_variables_initializer())140ckpt = tf.train.get_checkpoint_state(os.path.dirname('checkpoints/checkpoint'))141142# if that checkpoint exists, restore from checkpoint143if ckpt and ckpt.model_checkpoint_path:144saver.restore(sess, ckpt.model_checkpoint_path)145146final_embed_matrix = sess.run(self.embed_matrix)147148# you have to store embeddings in a new variable149embedding_var = tf.Variable(final_embed_matrix[:num_visualize], name='embedding')150sess.run(embedding_var.initializer)151152config = projector.ProjectorConfig()153summary_writer = tf.summary.FileWriter(visual_fld)154155# add embedding to the config file156embedding = config.embeddings.add()157embedding.tensor_name = embedding_var.name158159# link this tensor to its metadata file, in this case the first NUM_VISUALIZE words of vocab160embedding.metadata_path = 'vocab_' + str(num_visualize) + '.tsv'161162# saves a configuration file that TensorBoard will read during startup.163projector.visualize_embeddings(summary_writer, config)164saver_embed = tf.train.Saver([embedding_var])165saver_embed.save(sess, os.path.join(visual_fld, 'model.ckpt'), 1)166167def gen():168yield from word2vec_utils.batch_gen(DOWNLOAD_URL, EXPECTED_BYTES, VOCAB_SIZE,169BATCH_SIZE, SKIP_WINDOW, VISUAL_FLD)170171def main():172dataset = tf.data.Dataset.from_generator(gen,173(tf.int32, tf.int32),174(tf.TensorShape([BATCH_SIZE]), tf.TensorShape([BATCH_SIZE, 1])))175model = SkipGramModel(dataset, VOCAB_SIZE, EMBED_SIZE, BATCH_SIZE, NUM_SAMPLED, LEARNING_RATE)176model.build_graph()177model.train(NUM_TRAIN_STEPS)178model.visualize(VISUAL_FLD, NUM_VISUALIZE)179180if __name__ == '__main__':181main()182183
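
A note on the input pipeline: word2vec_utils.batch_gen (a helper module from the course repository, not shown here) is expected to yield batches of BATCH_SIZE center-word ids and their context-word ids, matching the (tf.TensorShape([BATCH_SIZE]), tf.TensorShape([BATCH_SIZE, 1])) shapes declared in main. As a rough illustration of the skip-gram pairing step such a generator performs, here is a minimal sketch under that assumption; it is not the helper's actual implementation:

def skip_gram_pairs(tokens, skip_window=1):
    """ Illustrative only: yield (center, context) id pairs for every
    context word within skip_window of each center word. """
    for i, center in enumerate(tokens):
        for j in range(max(0, i - skip_window), min(len(tokens), i + skip_window + 1)):
            if j != i:
                yield center, tokens[j]

# e.g. list(skip_gram_pairs([5, 2, 8], skip_window=1)) -> [(5, 2), (2, 5), (2, 8), (8, 2)]

To inspect training, point TensorBoard at the directories the script writes: tensorboard --logdir=graphs/word2vec for the loss summaries, and tensorboard --logdir=visualization for the embedding projector.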