Path: blob/master/examples/graph/node2vec_movielens.py
"""1Title: Graph representation learning with node2vec2Author: [Khalid Salama](https://www.linkedin.com/in/khalid-salama-24403144/)3Date created: 2021/05/154Last modified: 2021/05/155Description: Implementing the node2vec model to generate embeddings for movies from the MovieLens dataset.6Accelerator: GPU7"""89"""10## Introduction1112Learning useful representations from objects structured as graphs is useful for13a variety of machine learning (ML) applications—such as social and communication networks analysis,14biomedicine studies, and recommendation systems.15[Graph representation Learning](https://www.cs.mcgill.ca/~wlh/grl_book/) aims to16learn embeddings for the graph nodes, which can be used for a variety of ML tasks17such as node label prediction (e.g. categorizing an article based on its citations)18and link prediction (e.g. recommending an interest group to a user in a social network).1920[node2vec](https://arxiv.org/abs/1607.00653) is a simple, yet scalable and effective21technique for learning low-dimensional embeddings for nodes in a graph by optimizing22a neighborhood-preserving objective. The aim is to learn similar embeddings for23neighboring nodes, with respect to the graph structure.2425Given your data items structured as a graph (where the items are represented as26nodes and the relationship between items are represented as edges),27node2vec works as follows:28291. Generate item sequences using (biased) random walk.302. Create positive and negative training examples from these sequences.313. Train a [word2vec](https://www.tensorflow.org/tutorials/text/word2vec) model32(skip-gram) to learn embeddings for the items.3334In this example, we demonstrate the node2vec technique on the35[small version of the Movielens dataset](https://files.grouplens.org/datasets/movielens/ml-latest-small-README.html)36to learn movie embeddings. Such a dataset can be represented as a graph by treating37the movies as nodes, and creating edges between movies that have similar ratings38by the users. The learnt movie embeddings can be used for tasks such as movie recommendation,39or movie genres prediction.4041This example requires `networkx` package, which can be installed using the following command:4243```shell44pip install networkx45```46"""4748"""49## Setup50"""5152import os53from collections import defaultdict54import math55import networkx as nx56import random57from tqdm import tqdm58from zipfile import ZipFile59from urllib.request import urlretrieve60import numpy as np61import pandas as pd62import tensorflow as tf63from tensorflow import keras64from tensorflow.keras import layers65import matplotlib.pyplot as plt6667"""68## Download the MovieLens dataset and prepare the data6970The small version of the MovieLens dataset includes around 100k ratings71from 610 users on 9,742 movies.7273First, let's download the dataset. The downloaded folder will contain74three data files: `users.csv`, `movies.csv`, and `ratings.csv`. 
"""

urlretrieve(
    "http://files.grouplens.org/datasets/movielens/ml-latest-small.zip", "movielens.zip"
)
ZipFile("movielens.zip", "r").extractall()

"""
Then, we load the data into a Pandas DataFrame and perform some basic preprocessing.
"""

# Load movies to a DataFrame.
movies = pd.read_csv("ml-latest-small/movies.csv")
# Create a `movieId` string.
movies["movieId"] = movies["movieId"].apply(lambda x: f"movie_{x}")

# Load ratings to a DataFrame.
ratings = pd.read_csv("ml-latest-small/ratings.csv")
# Convert the `rating` values to floating point.
ratings["rating"] = ratings["rating"].apply(lambda x: float(x))
# Create the `movieId` string.
ratings["movieId"] = ratings["movieId"].apply(lambda x: f"movie_{x}")

print("Movies data shape:", movies.shape)
print("Ratings data shape:", ratings.shape)

"""
Let's inspect a sample instance of the `ratings` DataFrame.
"""

ratings.head()

"""
Next, let's check a sample instance of the `movies` DataFrame.
"""

movies.head()

"""
Implement two utility functions for the `movies` DataFrame.
"""


def get_movie_title_by_id(movieId):
    return list(movies[movies.movieId == movieId].title)[0]


def get_movie_id_by_title(title):
    return list(movies[movies.title == title].movieId)[0]


"""
## Construct the Movies graph

We create an edge between two movie nodes in the graph if both movies are rated
by the same user with a rating >= `min_rating`. The weight of the edge is based on the
[pointwise mutual information](https://en.wikipedia.org/wiki/Pointwise_mutual_information)
between the two movies, which is computed as: `log(xy) - log(x) - log(y) + log(D)`, where:

* `xy` is how many users rated both movie `x` and movie `y` with >= `min_rating`.
* `x` is how many users rated movie `x` with >= `min_rating`.
* `y` is how many users rated movie `y` with >= `min_rating`.
* `D` is the total number of movie ratings >= `min_rating`.
"""
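"""
For intuition, let's compute this weight once by hand for a hypothetical movie pair.
The counts below are made up purely for illustration and do not come from the dataset.
"""

# Hypothetical counts, for illustration only.
example_xy = 15  # users who rated both movies >= min_rating
example_x = 120  # users who rated the first movie >= min_rating
example_y = 80  # users who rated the second movie >= min_rating
example_D = 50000  # total number of ratings >= min_rating
example_pmi = (
    math.log(example_xy)
    - math.log(example_x)
    - math.log(example_y)
    + math.log(example_D)
)
# In Step 2 below, the PMI is additionally scaled by the pair frequency `xy`
# to obtain the final edge weight.
print("Example PMI:", round(example_pmi, 3))
print("Example edge weight (PMI * xy):", round(example_pmi * example_xy, 3))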
"""
### Step 1: create the weighted edges between movies.
"""

min_rating = 5
pair_frequency = defaultdict(int)
item_frequency = defaultdict(int)

# Filter instances where rating is greater than or equal to min_rating.
rated_movies = ratings[ratings.rating >= min_rating]
# Group instances by user.
movies_grouped_by_users = list(rated_movies.groupby("userId"))
for group in tqdm(
    movies_grouped_by_users,
    position=0,
    leave=True,
    desc="Compute movie rating frequencies",
):
    # Get a list of movies rated by the user.
    current_movies = list(group[1]["movieId"])

    for i in range(len(current_movies)):
        item_frequency[current_movies[i]] += 1
        for j in range(i + 1, len(current_movies)):
            x = min(current_movies[i], current_movies[j])
            y = max(current_movies[i], current_movies[j])
            pair_frequency[(x, y)] += 1

"""
### Step 2: create the graph with the nodes and the edges

To reduce the number of edges between nodes, we only add an edge between movies
if the weight of the edge is greater than or equal to `min_weight`.
"""

min_weight = 10
D = math.log(sum(item_frequency.values()))

# Create the movies undirected graph.
movies_graph = nx.Graph()
# Add weighted edges between movies.
# This automatically adds the movie nodes to the graph.
for pair in tqdm(
    pair_frequency, position=0, leave=True, desc="Creating the movie graph"
):
    x, y = pair
    xy_frequency = pair_frequency[pair]
    x_frequency = item_frequency[x]
    y_frequency = item_frequency[y]
    pmi = math.log(xy_frequency) - math.log(x_frequency) - math.log(y_frequency) + D
    weight = pmi * xy_frequency
    # Only include edges with weight >= min_weight.
    if weight >= min_weight:
        movies_graph.add_edge(x, y, weight=weight)

"""
Let's display the total number of nodes and edges in the graph.
Note that the number of nodes is less than the total number of movies,
since only the movies that have edges to other movies are added.
"""

print("Total number of graph nodes:", movies_graph.number_of_nodes())
print("Total number of graph edges:", movies_graph.number_of_edges())

"""
Let's display the average node degree (number of neighbours) in the graph.
"""

degrees = []
for node in movies_graph.nodes:
    degrees.append(movies_graph.degree[node])

print("Average node degree:", round(sum(degrees) / len(degrees), 2))

"""
### Step 3: Create vocabulary and a mapping from tokens to integer indices

The vocabulary is the nodes (movie IDs) in the graph.
"""

vocabulary = ["NA"] + list(movies_graph.nodes)
vocabulary_lookup = {token: idx for idx, token in enumerate(vocabulary)}

"""
## Implement the biased random walk

A random walk starts from a given node, and randomly picks a neighbour node to move to.
If the edges are weighted, the neighbour is selected *probabilistically* with
respect to the weights of the edges between the current node and its neighbours.
This procedure is repeated for `num_steps` to generate a sequence of *related* nodes.

The [*biased* random walk](https://en.wikipedia.org/wiki/Biased_random_walk_on_a_graph) balances between **breadth-first sampling**
(where only local neighbours are visited) and **depth-first sampling**
(where distant neighbours are visited) by introducing the following two parameters:

1. **Return parameter** (`p`): Controls the likelihood of immediately revisiting
a node in the walk. Setting it to a high value encourages moderate exploration,
while setting it to a low value would keep the walk local.
2. **In-out parameter** (`q`): Allows the search to differentiate
between *inward* and *outward* nodes. Setting it to a high value biases the
random walk towards local nodes, while setting it to a low value biases the walk
to visit nodes which are further away.
"""
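"""
As a small illustration with made-up numbers (mirroring the re-weighting rule implemented
in `next_step` below): assume the previous node is `A`, the current node is `B`, all edge
weights equal 1, `p = 2`, and `q = 0.5`. Returning to `A` gets an unnormalized weight of
`1 / p = 0.5`, a common neighbour of `A` and `B` keeps a weight of `1`, and a node not
connected to `A` gets `1 / q = 2`, so the walk is pushed outwards.
"""

# Illustrative only: re-weighting a single step with made-up values.
example_p, example_q, example_edge_weight = 2.0, 0.5, 1.0
unnormalized_weights = {
    "return to previous node": example_edge_weight / example_p,
    "common neighbour": example_edge_weight,
    "outward node": example_edge_weight / example_q,
}
total_weight = sum(unnormalized_weights.values())
for move, w in unnormalized_weights.items():
    print(f"{move}: weight={w:.1f}, probability={w / total_weight:.2f}")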
def next_step(graph, previous, current, p, q):
    neighbors = list(graph.neighbors(current))

    weights = []
    # Adjust the weights of the edges to the neighbors with respect to p and q.
    for neighbor in neighbors:
        if neighbor == previous:
            # Control the probability to return to the previous node.
            weights.append(graph[current][neighbor]["weight"] / p)
        elif graph.has_edge(neighbor, previous):
            # The probability of visiting a local node.
            weights.append(graph[current][neighbor]["weight"])
        else:
            # Control the probability to move forward.
            weights.append(graph[current][neighbor]["weight"] / q)

    # Compute the probabilities of visiting each neighbor.
    weight_sum = sum(weights)
    probabilities = [weight / weight_sum for weight in weights]
    # Probabilistically select a neighbor to visit.
    next = np.random.choice(neighbors, size=1, p=probabilities)[0]
    return next


def random_walk(graph, num_walks, num_steps, p, q):
    walks = []
    nodes = list(graph.nodes())
    # Perform multiple iterations of the random walk.
    for walk_iteration in range(num_walks):
        random.shuffle(nodes)

        for node in tqdm(
            nodes,
            position=0,
            leave=True,
            desc=f"Random walks iteration {walk_iteration + 1} of {num_walks}",
        ):
            # Start the walk with a random node from the graph.
            walk = [node]
            # Randomly walk for num_steps.
            while len(walk) < num_steps:
                current = walk[-1]
                previous = walk[-2] if len(walk) > 1 else None
                # Compute the next node to visit.
                next = next_step(graph, previous, current, p, q)
                walk.append(next)
            # Replace node ids (movie ids) in the walk with token ids.
            walk = [vocabulary_lookup[token] for token in walk]
            # Add the walk to the generated sequences.
            walks.append(walk)

    return walks


"""
## Generate training data using the biased random walk

You can explore different configurations of `p` and `q` to obtain different sets of
related movies.
"""
# Random walk return parameter.
p = 1
# Random walk in-out parameter.
q = 1
# Number of iterations of random walks.
num_walks = 5
# Number of steps of each random walk.
num_steps = 10
walks = random_walk(movies_graph, num_walks, num_steps, p, q)

print("Number of walks generated:", len(walks))
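"""
As an optional sanity check (not part of the pipeline itself), we can decode one of the
generated walks back into movie titles to verify that the walk moves between related movies.
"""

# Map the token ids of the first walk back to movie titles.
sample_walk = walks[0]
print("Sample walk:")
for token in sample_walk:
    print("-", get_movie_title_by_id(vocabulary[token]))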
"""
## Generate positive and negative examples

To train a skip-gram model, we use the generated walks to create positive and
negative training examples. Each example includes the following features:

1. `target`: A movie in a walk sequence.
2. `context`: Another movie in a walk sequence.
3. `weight`: How many times these two movies occurred in walk sequences.
4. `label`: The label is 1 if these two movies are sampled from the walk sequences,
otherwise (i.e., if randomly sampled) the label is 0.
"""

"""
### Generate examples
"""


def generate_examples(sequences, window_size, num_negative_samples, vocabulary_size):
    example_weights = defaultdict(int)
    # Iterate over all sequences (walks).
    for sequence in tqdm(
        sequences,
        position=0,
        leave=True,
        desc="Generating positive and negative examples",
    ):
        # Generate positive and negative skip-gram pairs for a sequence (walk).
        pairs, labels = keras.preprocessing.sequence.skipgrams(
            sequence,
            vocabulary_size=vocabulary_size,
            window_size=window_size,
            negative_samples=num_negative_samples,
        )
        for idx in range(len(pairs)):
            pair = pairs[idx]
            label = labels[idx]
            target, context = min(pair[0], pair[1]), max(pair[0], pair[1])
            if target == context:
                continue
            entry = (target, context, label)
            example_weights[entry] += 1

    targets, contexts, labels, weights = [], [], [], []
    for entry in example_weights:
        weight = example_weights[entry]
        target, context, label = entry
        targets.append(target)
        contexts.append(context)
        labels.append(label)
        weights.append(weight)

    return np.array(targets), np.array(contexts), np.array(labels), np.array(weights)


num_negative_samples = 4
targets, contexts, labels, weights = generate_examples(
    sequences=walks,
    window_size=num_steps,
    num_negative_samples=num_negative_samples,
    vocabulary_size=len(vocabulary),
)

"""
Let's display the shapes of the outputs.
"""

print(f"Targets shape: {targets.shape}")
print(f"Contexts shape: {contexts.shape}")
print(f"Labels shape: {labels.shape}")
print(f"Weights shape: {weights.shape}")

"""
### Convert the data into `tf.data.Dataset` objects
"""

batch_size = 1024


def create_dataset(targets, contexts, labels, weights, batch_size):
    inputs = {
        "target": targets,
        "context": contexts,
    }
    dataset = tf.data.Dataset.from_tensor_slices((inputs, labels, weights))
    dataset = dataset.shuffle(buffer_size=batch_size * 2)
    dataset = dataset.batch(batch_size, drop_remainder=True)
    dataset = dataset.prefetch(tf.data.AUTOTUNE)
    return dataset


dataset = create_dataset(
    targets=targets,
    contexts=contexts,
    labels=labels,
    weights=weights,
    batch_size=batch_size,
)
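"""
As an optional check (not part of the original example), we can pull a single batch from
the dataset to confirm that each element is an `(inputs, labels, weights)` tuple with the
expected shapes.
"""

# Inspect one batch: `inputs` is a dict with `target` and `context` tensors.
for batch_inputs, batch_labels, batch_weights in dataset.take(1):
    print("target batch shape:", batch_inputs["target"].shape)
    print("context batch shape:", batch_inputs["context"].shape)
    print("labels batch shape:", batch_labels.shape)
    print("weights batch shape:", batch_weights.shape)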
"""
## Train the skip-gram model

Our skip-gram is a simple binary classification model that works as follows:

1. An embedding is looked up for the `target` movie.
2. An embedding is looked up for the `context` movie.
3. The dot product is computed between these two embeddings.
4. The result (after a sigmoid activation) is compared to the label.
5. A binary crossentropy loss is used.
"""

learning_rate = 0.001
embedding_dim = 50
num_epochs = 10

"""
### Implement the model
"""


def create_model(vocabulary_size, embedding_dim):
    inputs = {
        "target": layers.Input(name="target", shape=(), dtype="int32"),
        "context": layers.Input(name="context", shape=(), dtype="int32"),
    }
    # Initialize item embeddings.
    embed_item = layers.Embedding(
        input_dim=vocabulary_size,
        output_dim=embedding_dim,
        embeddings_initializer="he_normal",
        embeddings_regularizer=keras.regularizers.l2(1e-6),
        name="item_embeddings",
    )
    # Lookup embeddings for target.
    target_embeddings = embed_item(inputs["target"])
    # Lookup embeddings for context.
    context_embeddings = embed_item(inputs["context"])
    # Compute dot similarity between target and context embeddings.
    logits = layers.Dot(axes=1, normalize=False, name="dot_similarity")(
        [target_embeddings, context_embeddings]
    )
    # Create the model.
    model = keras.Model(inputs=inputs, outputs=logits)
    return model


"""
### Train the model
"""

"""
We instantiate the model and compile it.
"""

model = create_model(len(vocabulary), embedding_dim)
model.compile(
    optimizer=keras.optimizers.Adam(learning_rate),
    loss=keras.losses.BinaryCrossentropy(from_logits=True),
)

"""
Let's plot the model.
"""

keras.utils.plot_model(
    model,
    show_shapes=True,
    show_dtype=True,
    show_layer_names=True,
)

"""
Now we train the model on the `dataset`.
"""

history = model.fit(dataset, epochs=num_epochs)

"""
Finally we plot the learning history.
"""

plt.plot(history.history["loss"])
plt.ylabel("loss")
plt.xlabel("epoch")
plt.show()

"""
## Analyze the learnt embeddings
"""

movie_embeddings = model.get_layer("item_embeddings").get_weights()[0]
print("Embeddings shape:", movie_embeddings.shape)

"""
### Find related movies

Define a list with some movies called `query_movies`.
"""

query_movies = [
    "Matrix, The (1999)",
    "Star Wars: Episode IV - A New Hope (1977)",
    "Lion King, The (1994)",
    "Terminator 2: Judgment Day (1991)",
    "Godfather, The (1972)",
]

"""
Get the embeddings of the movies in `query_movies`.
"""

query_embeddings = []

for movie_title in query_movies:
    movieId = get_movie_id_by_title(movie_title)
    token_id = vocabulary_lookup[movieId]
    movie_embedding = movie_embeddings[token_id]
    query_embeddings.append(movie_embedding)

query_embeddings = np.array(query_embeddings)

"""
Compute the [cosine similarity](https://en.wikipedia.org/wiki/Cosine_similarity) between the embeddings of `query_movies`
and all the other movies, then pick the top k for each.
"""

# Normalize each embedding vector (row-wise) so that the dot product is the cosine similarity.
similarities = tf.linalg.matmul(
    tf.math.l2_normalize(query_embeddings, axis=1),
    tf.math.l2_normalize(movie_embeddings, axis=1),
    transpose_b=True,
)

_, indices = tf.math.top_k(similarities, k=5)
indices = indices.numpy().tolist()

"""
Display the top related movies for each movie in `query_movies`.
"""

for idx, title in enumerate(query_movies):
    print(title)
    print("".rjust(len(title), "-"))
    similar_tokens = indices[idx]
    for token in similar_tokens:
        similar_movieId = vocabulary[token]
        similar_title = get_movie_title_by_id(similar_movieId)
        print(f"- {similar_title}")
    print()

"""
### Visualize the embeddings using the Embedding Projector
"""

import io

out_v = io.open("embeddings.tsv", "w", encoding="utf-8")
out_m = io.open("metadata.tsv", "w", encoding="utf-8")
"w", encoding="utf-8")576577for idx, movie_id in enumerate(vocabulary[1:]):578movie_title = list(movies[movies.movieId == movie_id].title)[0]579vector = movie_embeddings[idx]580out_v.write("\t".join([str(x) for x in vector]) + "\n")581out_m.write(movie_title + "\n")582583out_v.close()584out_m.close()585586"""587Download the `embeddings.tsv` and `metadata.tsv` to analyze the obtained embeddings588in the [Embedding Projector](https://projector.tensorflow.org/).589"""590591"""592593**Example available on HuggingFace**594595| Trained Model | Demo |596| :--: | :--: |597| [](https://huggingface.co/keras-io/Node2Vec_MovieLens) | [](https://huggingface.co/spaces/keras-io/Node2Vec_MovieLens) |598"""599600601