📚 The CoCalc Library - books, templates and other resources
License: OTHER
from gensim.models import word2vec
from os.path import join, exists, split
import os
import numpy as np


def train_word2vec(sentence_matrix, vocabulary_inv,
                   num_features=300, min_word_count=1, context=10):
    """
    Trains, saves, loads Word2Vec model
    Returns initial weights for embedding layer.

    inputs:
    sentence_matrix  # int matrix: num_sentences x max_sentence_len
    vocabulary_inv   # list mapping word index -> word (inverse vocabulary)
    num_features     # Word vector dimensionality
    min_word_count   # Minimum word count
    context          # Context window size
    """
    model_dir = 'word2vec_models'
    model_name = "{:d}features_{:d}minwords_{:d}context".format(
        num_features, min_word_count, context)
    model_name = join(model_dir, model_name)
    if exists(model_name):
        embedding_model = word2vec.Word2Vec.load(model_name)
        print("Loading existing Word2Vec model '%s'" % split(model_name)[-1])
    else:
        # Set values for various parameters
        num_workers = 2      # Number of threads to run in parallel
        downsampling = 1e-3  # Downsample setting for frequent words

        # Map word indices back to words and train the model
        # (pre-4.0 gensim API; see the compatibility sketch below)
        print("Training Word2Vec model...")
        sentences = [[vocabulary_inv[w] for w in s] for s in sentence_matrix]
        embedding_model = word2vec.Word2Vec(sentences, workers=num_workers,
                                            size=num_features, min_count=min_word_count,
                                            window=context, sample=downsampling)

        # If we don't plan to train the model any further, calling
        # init_sims will make the model much more memory-efficient.
        embedding_model.init_sims(replace=True)

        # Save the model for later use; it can be reloaded with Word2Vec.load()
        if not exists(model_dir):
            os.mkdir(model_dir)
        print("Saving Word2Vec model '%s'" % split(model_name)[-1])
        embedding_model.save(model_name)

    # Build the embedding matrix, drawing random vectors for words the
    # Word2Vec model has never seen (e.g. dropped by min_word_count)
    embedding_weights = [np.array([embedding_model[w] if w in embedding_model
                                   else np.random.uniform(-0.25, 0.25, embedding_model.vector_size)
                                   for w in vocabulary_inv])]
    return embedding_weights


if __name__ == '__main__':
    import data_helpers
    print("Loading data...")
    x, _, _, vocabulary_inv = data_helpers.load_data()
    w = train_word2vec(x, vocabulary_inv)
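
# --- gensim >= 4.0 compatibility sketch (assumption: you run gensim 4.x) ---
# The training step above uses the pre-4.0 gensim API: in 4.0 `size` was
# renamed to `vector_size`, direct indexing (model[w], w in model) moved to
# `model.wv`, and init_sims() was deprecated. A rough equivalent, not a
# tested drop-in replacement:
def train_word2vec_gensim4(sentences, num_features=300, min_word_count=1, context=10):
    from gensim.models import word2vec
    model = word2vec.Word2Vec(sentences, workers=2,
                              vector_size=num_features, min_count=min_word_count,
                              window=context, sample=1e-3)
    # KeyedVectors: wv[w] and (w in wv) replace model[w] and (w in model)
    return model.wv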
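
# --- usage sketch: feeding the returned weights into an embedding layer ---
# train_word2vec returns a one-element list of arrays, which matches the
# shape Keras' `weights=` argument (and Layer.set_weights) expects. The
# surrounding model and the sequence_length parameter here are illustrative
# assumptions, not part of the script above:
def build_embedding_layer(x, vocabulary_inv, sequence_length):
    from keras.models import Sequential
    from keras.layers import Embedding
    weights = train_word2vec(x, vocabulary_inv)  # list: [vocab_size x num_features]
    model = Sequential()
    model.add(Embedding(len(vocabulary_inv), weights[0].shape[1],
                        input_length=sequence_length, weights=weights))
    return model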