# 📚 The CoCalc Library - books, templates and other resources
# License: OTHER
from collections import Counter
import random
import os
import sys
sys.path.append('..')
import zipfile

import numpy as np
from six.moves import urllib
import tensorflow as tf

import utils

def read_data(file_path):
    """ Read data into a list of tokens
    There should be 17,005,207 tokens
    """
    # the zip is assumed to contain a single member holding the whole corpus
    with zipfile.ZipFile(file_path) as f:
        words = tf.compat.as_str(f.read(f.namelist()[0])).split()
    return words

def build_vocab(words, vocab_size, visual_fld):
    """ Build vocabulary of VOCAB_SIZE most frequent words and write it to
    visualization/vocab.tsv

    Returns:
        dictionary: word -> index ('UNK' is index 0)
        index_dictionary: index -> word (inverse mapping)
    """
    utils.safe_mkdir(visual_fld)

    # 'UNK' takes slot 0; the remaining vocab_size - 1 slots go to the
    # most frequent words, in descending frequency order.
    count = [('UNK', -1)]
    count.extend(Counter(words).most_common(vocab_size - 1))

    dictionary = dict()
    # fix: use a context manager so the file handle is released even if a
    # write fails (original left an open handle on error paths)
    with open(os.path.join(visual_fld, 'vocab.tsv'), 'w') as file:
        for index, (word, _) in enumerate(count):
            dictionary[word] = index
            file.write(word + '\n')

    index_dictionary = dict(zip(dictionary.values(), dictionary.keys()))
    return dictionary, index_dictionary

def convert_words_to_index(words, dictionary):
    """ Replace each word in the dataset with its index in the dictionary.
    Out-of-vocabulary words map to 0 ('UNK').
    """
    # dict.get avoids the double lookup of `x if word in dictionary else 0`
    return [dictionary.get(word, 0) for word in words]

def generate_sample(index_words, context_window_size):
    """ Form training pairs according to the skip-gram model.

    Yields (center, target) index pairs; the context width for each center
    word is drawn uniformly from [1, context_window_size].
    """
    for index, center in enumerate(index_words):
        context = random.randint(1, context_window_size)
        # get a random target before the center word
        for target in index_words[max(0, index - context): index]:
            yield center, target
        # get a random target after the center word
        for target in index_words[index + 1: index + context + 1]:
            yield center, target

def most_common_words(visual_fld, num_visualize):
    """ create a list of num_visualize most frequent words to visualize on TensorBoard.
    saved to visualization/vocab_[num_visualize].tsv
    """
    # fix: the original never closed the file it read from and left the
    # output handle's closing to chance on error; use context managers.
    with open(os.path.join(visual_fld, 'vocab.tsv'), 'r') as f:
        words = f.readlines()[:num_visualize]
    out_path = os.path.join(visual_fld, 'vocab_' + str(num_visualize) + '.tsv')
    with open(out_path, 'w') as file:
        for word in words:
            file.write(word)

def batch_gen(download_url, expected_byte, vocab_size, batch_size,
              skip_window, visual_fld):
    """ Download the corpus (if needed), build the vocabulary, and yield
    (center_batch, target_batch) training batches indefinitely.

    center_batch: int32 array of shape (batch_size,)
    target_batch: array of shape (batch_size, 1)
        NOTE(review): dtype is the NumPy float default, preserved from the
        original — confirm whether the consumer (e.g. an NCE loss) expects
        integer labels before changing it.
    """
    local_dest = 'data/text8.zip'
    utils.download_one_file(download_url, local_dest, expected_byte)
    words = read_data(local_dest)
    dictionary, _ = build_vocab(words, vocab_size, visual_fld)
    index_words = convert_words_to_index(words, dictionary)
    del words  # to save memory
    single_gen = generate_sample(index_words, skip_window)

    while True:
        center_batch = np.zeros(batch_size, dtype=np.int32)
        target_batch = np.zeros([batch_size, 1])
        for index in range(batch_size):
            center_batch[index], target_batch[index] = next(single_gen)
        yield center_batch, target_batch