GitHub Repository: keras-team/keras-io
Path: blob/master/examples/graph/node2vec_movielens.py
"""
Title: Graph representation learning with node2vec
Author: [Khalid Salama](https://www.linkedin.com/in/khalid-salama-24403144/)
Date created: 2021/05/15
Last modified: 2021/05/15
Description: Implementing the node2vec model to generate embeddings for movies from the MovieLens dataset.
Accelerator: GPU
"""

"""
## Introduction

Learning representations of objects structured as graphs is useful for
a variety of machine learning (ML) applications, such as social and communication network analysis,
biomedicine studies, and recommendation systems.
[Graph representation learning](https://www.cs.mcgill.ca/~wlh/grl_book/) aims to
learn embeddings for the graph nodes, which can be used for a variety of ML tasks
such as node label prediction (e.g. categorizing an article based on its citations)
and link prediction (e.g. recommending an interest group to a user in a social network).

[node2vec](https://arxiv.org/abs/1607.00653) is a simple, yet scalable and effective
technique for learning low-dimensional embeddings for nodes in a graph by optimizing
a neighborhood-preserving objective. The aim is to learn similar embeddings for
neighboring nodes, with respect to the graph structure.

Given your data items structured as a graph (where the items are represented as
nodes and the relationships between items are represented as edges),
node2vec works as follows:

1. Generate item sequences using (biased) random walks.
2. Create positive and negative training examples from these sequences.
3. Train a [word2vec](https://www.tensorflow.org/tutorials/text/word2vec) model
(skip-gram) to learn embeddings for the items.

In this example, we demonstrate the node2vec technique on the
[small version of the MovieLens dataset](https://files.grouplens.org/datasets/movielens/ml-latest-small-README.html)
to learn movie embeddings. Such a dataset can be represented as a graph by treating
the movies as nodes, and creating edges between movies that have similar ratings
by the users. The learnt movie embeddings can be used for tasks such as movie recommendation,
or movie genre prediction.

This example requires the `networkx` package, which can be installed using the following command:

```shell
pip install networkx
```
"""

"""
## Setup
"""

import os
from collections import defaultdict
import math
import networkx as nx
import random
from tqdm import tqdm
from zipfile import ZipFile
from urllib.request import urlretrieve
import numpy as np
import pandas as pd
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
import matplotlib.pyplot as plt

"""
## Download the MovieLens dataset and prepare the data

The small version of the MovieLens dataset includes around 100k ratings
from 610 users on 9,742 movies.

First, let's download the dataset. The downloaded folder will contain
several data files, including `movies.csv` and `ratings.csv`. In this example,
we will only need these two files.
"""

urlretrieve(
    "http://files.grouplens.org/datasets/movielens/ml-latest-small.zip", "movielens.zip"
)
ZipFile("movielens.zip", "r").extractall()

"""
Then, we load the data into Pandas DataFrames and perform some basic preprocessing.
"""

# Load movies to a DataFrame.
movies = pd.read_csv("ml-latest-small/movies.csv")
# Create a `movieId` string.
movies["movieId"] = movies["movieId"].apply(lambda x: f"movie_{x}")

# Load ratings to a DataFrame.
ratings = pd.read_csv("ml-latest-small/ratings.csv")
# Convert the `rating` column to floating point.
ratings["rating"] = ratings["rating"].astype(float)
# Create a `movieId` string.
ratings["movieId"] = ratings["movieId"].apply(lambda x: f"movie_{x}")

print("Movies data shape:", movies.shape)
print("Ratings data shape:", ratings.shape)

"""
Let's inspect the first few rows of the `ratings` DataFrame.
"""

ratings.head()

"""
Next, let's check the first few rows of the `movies` DataFrame.
"""

movies.head()

"""
Implement two utility functions for the `movies` DataFrame.
"""


def get_movie_title_by_id(movieId):
    return list(movies[movies.movieId == movieId].title)[0]


def get_movie_id_by_title(title):
    return list(movies[movies.title == title].movieId)[0]


"""
## Construct the Movies graph

We create an edge between two movie nodes in the graph if the same user rated both
movies with a rating >= `min_rating`. The weight of the edge will be based on the
[pointwise mutual information](https://en.wikipedia.org/wiki/Pointwise_mutual_information)
between the two movies, which is computed as: `log(xy) - log(x) - log(y) + log(D)`, where:

* `xy` is how many users rated both movie `x` and movie `y` with >= `min_rating`.
* `x` is how many users rated movie `x` >= `min_rating`.
* `y` is how many users rated movie `y` >= `min_rating`.
* `D` is the total number of movie ratings >= `min_rating`.
"""

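"""
To make the PMI formula more concrete, here is a minimal sketch that evaluates it
for a single pair of movies. The helper name `pmi` and all the counts are hypothetical
and chosen only for illustration; they are not taken from the MovieLens data.

```python
import math


def pmi(xy_count, x_count, y_count, total):
    # log(xy) - log(x) - log(y) + log(D), as in the formula above.
    return math.log(xy_count) - math.log(x_count) - math.log(y_count) + math.log(total)


# Hypothetical counts (assumptions for this illustration only).
total = 1000  # D: number of ratings >= min_rating
x_count = 50  # users who rated movie x >= min_rating
y_count = 40  # users who rated movie y >= min_rating
xy_count = 20  # users who rated both movies >= min_rating

score = pmi(xy_count, x_count, y_count, total)
# Same value in closed form: log(xy * D / (x * y)) = log(10).
print(round(score, 4))  # 2.3026
```

A positive score means that the two movies are rated together more often than their
individual counts would suggest, while a negative score means the opposite.
"""
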
"""
### Step 1: create the weighted edges between movies.
"""

min_rating = 5
pair_frequency = defaultdict(int)
item_frequency = defaultdict(int)

# Filter instances where rating is greater than or equal to min_rating.
rated_movies = ratings[ratings.rating >= min_rating]
# Group instances by user.
movies_grouped_by_users = list(rated_movies.groupby("userId"))
for group in tqdm(
    movies_grouped_by_users,
    position=0,
    leave=True,
    desc="Compute movie rating frequencies",
):
    # Get a list of movies rated by the user.
    current_movies = list(group[1]["movieId"])

    for i in range(len(current_movies)):
        item_frequency[current_movies[i]] += 1
        for j in range(i + 1, len(current_movies)):
            x = min(current_movies[i], current_movies[j])
            y = max(current_movies[i], current_movies[j])
            pair_frequency[(x, y)] += 1

"""
### Step 2: create the graph with the nodes and the edges

To reduce the number of edges between nodes, we only add an edge between movies
if the weight of the edge is at least `min_weight`.
"""

min_weight = 10
# Note that `D` stores the logarithm of the total number of ratings >= min_rating.
D = math.log(sum(item_frequency.values()))

# Create the movies undirected graph.
movies_graph = nx.Graph()
# Add weighted edges between movies.
# This automatically adds the movie nodes to the graph.
for pair in tqdm(
    pair_frequency, position=0, leave=True, desc="Creating the movie graph"
):
    x, y = pair
    xy_frequency = pair_frequency[pair]
    x_frequency = item_frequency[x]
    y_frequency = item_frequency[y]
    pmi = math.log(xy_frequency) - math.log(x_frequency) - math.log(y_frequency) + D
    weight = pmi * xy_frequency
    # Only include edges with weight >= min_weight.
    if weight >= min_weight:
        movies_graph.add_edge(x, y, weight=weight)

"""
Let's display the total number of nodes and edges in the graph.
Note that the number of nodes is less than the total number of movies,
since only the movies that have edges to other movies are added.
"""

print("Total number of graph nodes:", movies_graph.number_of_nodes())
print("Total number of graph edges:", movies_graph.number_of_edges())

"""
Let's display the average node degree (number of neighbours) in the graph.
"""

degrees = []
for node in movies_graph.nodes:
    degrees.append(movies_graph.degree[node])

print("Average node degree:", round(sum(degrees) / len(degrees), 2))

"""
### Step 3: Create vocabulary and a mapping from tokens to integer indices

The vocabulary consists of the nodes (movie IDs) in the graph, preceded by an
`NA` token at index 0.
"""

vocabulary = ["NA"] + list(movies_graph.nodes)
vocabulary_lookup = {token: idx for idx, token in enumerate(vocabulary)}

"""
## Implement the biased random walk

A random walk starts from a given node, and randomly picks a neighbour node to move to.
If the edges are weighted, the neighbour is selected *probabilistically* with
respect to weights of the edges between the current node and its neighbours.
This procedure is repeated until the walk contains `num_steps` nodes, which
generates a sequence of *related* nodes.

The [*biased* random walk](https://en.wikipedia.org/wiki/Biased_random_walk_on_a_graph) balances between **breadth-first sampling**
(where only local neighbours are visited) and **depth-first sampling**
(where distant neighbours are visited) by introducing the following two parameters:

1. **Return parameter** (`p`): Controls the likelihood of immediately revisiting
a node in the walk. Setting it to a high value encourages moderate exploration,
while setting it to a low value would keep the walk local.
2. **In-out parameter** (`q`): Allows the search to differentiate
between *inward* and *outward* nodes. Setting it to a high value biases the
random walk towards local nodes, while setting it to a low value biases the walk
to visit nodes which are further away.
"""

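"""
The effect of `p` and `q` can be illustrated on a tiny graph. The sketch below is a
simplified, self-contained illustration that uses plain dictionaries instead of
`networkx`; the helper name `transition_probabilities` and the toy graph are
assumptions made for this illustration only.

```python
def transition_probabilities(adjacency, previous, current, p, q):
    # `adjacency` maps node -> {neighbour: edge weight} for an undirected graph.
    weights = {}
    for neighbour, weight in adjacency[current].items():
        if neighbour == previous:
            # Stepping back to the previous node is scaled by 1 / p.
            weights[neighbour] = weight / p
        elif neighbour in adjacency[previous]:
            # The neighbour is also connected to the previous node (local move).
            weights[neighbour] = weight
        else:
            # The neighbour leads away from the previous node, scaled by 1 / q.
            weights[neighbour] = weight / q
    total = sum(weights.values())
    return {node: weight / total for node, weight in weights.items()}


# Hypothetical toy graph with unit weights and edges a-b, a-c, b-c, b-d.
toy_graph = {
    "a": {"b": 1.0, "c": 1.0},
    "b": {"a": 1.0, "c": 1.0, "d": 1.0},
    "c": {"a": 1.0, "b": 1.0},
    "d": {"b": 1.0},
}

# The walk arrived at "b" from "a"; compare two settings of q.
balanced = transition_probabilities(toy_graph, "a", "b", p=1.0, q=1.0)
local = transition_probabilities(toy_graph, "a", "b", p=1.0, q=4.0)
print(balanced)
print(local)
```

With `q=1`, the three neighbours of `b` are equally likely. With `q=4`, the outward
node `d` becomes less likely than `a` and `c`, so the walk tends to stay local.
"""
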

def next_step(graph, previous, current, p, q):
    neighbors = list(graph.neighbors(current))

    weights = []
    # Adjust the weights of the edges to the neighbors with respect to p and q.
    for neighbor in neighbors:
        if neighbor == previous:
            # Control the probability to return to the previous node.
            weights.append(graph[current][neighbor]["weight"] / p)
        elif graph.has_edge(neighbor, previous):
            # The probability of visiting a local node.
            weights.append(graph[current][neighbor]["weight"])
        else:
            # Control the probability to move forward.
            weights.append(graph[current][neighbor]["weight"] / q)

    # Compute the probabilities of visiting each neighbor.
    weight_sum = sum(weights)
    probabilities = [weight / weight_sum for weight in weights]
    # Probabilistically select a neighbor to visit.
    next_node = np.random.choice(neighbors, size=1, p=probabilities)[0]
    return next_node


def random_walk(graph, num_walks, num_steps, p, q):
    walks = []
    nodes = list(graph.nodes())
    # Perform multiple iterations of the random walk.
    for walk_iteration in range(num_walks):
        random.shuffle(nodes)

        for node in tqdm(
            nodes,
            position=0,
            leave=True,
            desc=f"Random walks iteration {walk_iteration + 1} of {num_walks}",
        ):
            # Start a walk from this node (nodes are visited in shuffled order).
            walk = [node]
            # Randomly walk until the walk contains num_steps nodes.
            while len(walk) < num_steps:
                current = walk[-1]
                previous = walk[-2] if len(walk) > 1 else None
                # Compute the next node to visit.
                next_node = next_step(graph, previous, current, p, q)
                walk.append(next_node)
            # Replace node ids (movie ids) in the walk with token ids.
            walk = [vocabulary_lookup[token] for token in walk]
            # Add the walk to the generated sequence.
            walks.append(walk)

    return walks


"""
## Generate training data using the biased random walk

You can explore different configurations of `p` and `q` to obtain different results of
related movies.
"""

# Random walk return parameter.
p = 1
# Random walk in-out parameter.
q = 1
# Number of iterations of random walks.
num_walks = 5
# Number of steps of each random walk.
num_steps = 10
walks = random_walk(movies_graph, num_walks, num_steps, p, q)

print("Number of walks generated:", len(walks))

"""
## Generate positive and negative examples

To train a skip-gram model, we use the generated walks to create positive and
negative training examples. Each example includes the following features:

1. `target`: A movie in a walk sequence.
2. `context`: Another movie in a walk sequence.
3. `weight`: How many times these two movies occurred in walk sequences.
4. `label`: The label is 1 if these two movies are sampled from the walk sequences,
otherwise (i.e., if randomly sampled) the label is 0.
"""

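"""
To illustrate how positive and negative pairs can be derived from a single walk,
here is a simplified pure-Python sketch. It is not the Keras `skipgrams` utility and
differs from it in details such as how many negatives are drawn and how the output
is shuffled; the function name `skipgram_pairs`, the `seed` argument, and the example
walk are assumptions made for this illustration.

```python
import random


def skipgram_pairs(sequence, window_size, num_negative_samples, vocabulary_size, seed=0):
    rng = random.Random(seed)
    pairs, labels = [], []
    for i, target in enumerate(sequence):
        start = max(0, i - window_size)
        end = min(len(sequence), i + window_size + 1)
        for j in range(start, end):
            if j == i:
                continue
            # Positive example: both tokens occur within the window of one walk.
            pairs.append((target, sequence[j]))
            labels.append(1)
            # Negative examples: pair the target with randomly drawn token ids.
            # Token 0 is skipped because it is reserved for "NA".
            for _ in range(num_negative_samples):
                pairs.append((target, rng.randint(1, vocabulary_size - 1)))
                labels.append(0)
    return pairs, labels


walk = [3, 7, 2, 9]  # Hypothetical walk of token ids.
pairs, labels = skipgram_pairs(
    walk, window_size=1, num_negative_samples=2, vocabulary_size=10
)
print(len(pairs), sum(labels))  # 18 6
```

Note that a randomly drawn negative can occasionally coincide with a true context
token; this simple sketch does not filter out such cases.
"""
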
"""
### Generate examples
"""


def generate_examples(sequences, window_size, num_negative_samples, vocabulary_size):
    example_weights = defaultdict(int)
    # Iterate over all sequences (walks).
    for sequence in tqdm(
        sequences,
        position=0,
        leave=True,
        desc="Generating positive and negative examples",
    ):
        # Generate positive and negative skip-gram pairs for a sequence (walk).
        pairs, labels = keras.preprocessing.sequence.skipgrams(
            sequence,
            vocabulary_size=vocabulary_size,
            window_size=window_size,
            negative_samples=num_negative_samples,
        )
        for idx in range(len(pairs)):
            pair = pairs[idx]
            label = labels[idx]
            target, context = min(pair[0], pair[1]), max(pair[0], pair[1])
            if target == context:
                continue
            entry = (target, context, label)
            example_weights[entry] += 1

    targets, contexts, labels, weights = [], [], [], []
    for entry in example_weights:
        weight = example_weights[entry]
        target, context, label = entry
        targets.append(target)
        contexts.append(context)
        labels.append(label)
        weights.append(weight)

    return np.array(targets), np.array(contexts), np.array(labels), np.array(weights)


num_negative_samples = 4
targets, contexts, labels, weights = generate_examples(
    sequences=walks,
    window_size=num_steps,
    num_negative_samples=num_negative_samples,
    vocabulary_size=len(vocabulary),
)

"""
Let's display the shapes of the outputs.
"""

print(f"Targets shape: {targets.shape}")
print(f"Contexts shape: {contexts.shape}")
print(f"Labels shape: {labels.shape}")
print(f"Weights shape: {weights.shape}")

"""
### Convert the data into `tf.data.Dataset` objects
"""

batch_size = 1024


def create_dataset(targets, contexts, labels, weights, batch_size):
    inputs = {
        "target": targets,
        "context": contexts,
    }
    dataset = tf.data.Dataset.from_tensor_slices((inputs, labels, weights))
    dataset = dataset.shuffle(buffer_size=batch_size * 2)
    dataset = dataset.batch(batch_size, drop_remainder=True)
    dataset = dataset.prefetch(tf.data.AUTOTUNE)
    return dataset


dataset = create_dataset(
    targets=targets,
    contexts=contexts,
    labels=labels,
    weights=weights,
    batch_size=batch_size,
)

"""
## Train the skip-gram model

Our skip-gram model is a simple binary classification model that works as follows:

1. An embedding is looked up for the `target` movie.
2. An embedding is looked up for the `context` movie.
3. The dot product is computed between these two embeddings.
4. The result (after a sigmoid activation) is compared to the label.
5. A binary crossentropy loss is used.
"""

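"""
The five steps above can be traced by hand for a single (`target`, `context`) pair.
The following is a minimal pure-Python sketch rather than the Keras model itself;
the function name `skipgram_loss` and the 3-dimensional embeddings are hypothetical
values chosen for illustration.

```python
import math


def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))


def skipgram_loss(target_embedding, context_embedding, label):
    # The dot product of the two embeddings gives the logit.
    logit = sum(t * c for t, c in zip(target_embedding, context_embedding))
    # The sigmoid turns the logit into the probability of a positive pair.
    probability = sigmoid(logit)
    # Binary crossentropy for a single example.
    loss = -(label * math.log(probability) + (1 - label) * math.log(1.0 - probability))
    return logit, probability, loss


# Hypothetical embeddings (assumptions for this illustration only).
target = [0.5, -0.2, 0.1]
context = [0.4, 0.1, -0.3]
logit, probability, loss = skipgram_loss(target, context, label=1)
print(logit, probability, loss)
```

For a positive pair, the loss shrinks as the dot product grows, which pushes the
embeddings of movies that co-occur in walks towards each other.
"""
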
learning_rate = 0.001
embedding_dim = 50
num_epochs = 10

"""
### Implement the model
"""


def create_model(vocabulary_size, embedding_dim):
    inputs = {
        "target": layers.Input(name="target", shape=(), dtype="int32"),
        "context": layers.Input(name="context", shape=(), dtype="int32"),
    }
    # Initialize item embeddings.
    embed_item = layers.Embedding(
        input_dim=vocabulary_size,
        output_dim=embedding_dim,
        embeddings_initializer="he_normal",
        embeddings_regularizer=keras.regularizers.l2(1e-6),
        name="item_embeddings",
    )
    # Lookup embeddings for target.
    target_embeddings = embed_item(inputs["target"])
    # Lookup embeddings for context.
    context_embeddings = embed_item(inputs["context"])
    # Compute dot similarity between target and context embeddings.
    logits = layers.Dot(axes=1, normalize=False, name="dot_similarity")(
        [target_embeddings, context_embeddings]
    )
    # Create the model.
    model = keras.Model(inputs=inputs, outputs=logits)
    return model


"""
### Train the model
"""

"""
We instantiate the model and compile it.
"""

model = create_model(len(vocabulary), embedding_dim)
model.compile(
    optimizer=keras.optimizers.Adam(learning_rate),
    loss=keras.losses.BinaryCrossentropy(from_logits=True),
)

"""
Let's plot the model.
"""

keras.utils.plot_model(
    model,
    show_shapes=True,
    show_dtype=True,
    show_layer_names=True,
)

"""
Now we train the model on the `dataset`.
"""

history = model.fit(dataset, epochs=num_epochs)

"""
Finally, we plot the learning history.
"""

plt.plot(history.history["loss"])
plt.ylabel("loss")
plt.xlabel("epoch")
plt.show()

"""
## Analyze the learnt embeddings.
"""

movie_embeddings = model.get_layer("item_embeddings").get_weights()[0]
print("Embeddings shape:", movie_embeddings.shape)

"""
### Find related movies

Define a list with some movies called `query_movies`.
"""

query_movies = [
    "Matrix, The (1999)",
    "Star Wars: Episode IV - A New Hope (1977)",
    "Lion King, The (1994)",
    "Terminator 2: Judgment Day (1991)",
    "Godfather, The (1972)",
]

"""
Get the embeddings of the movies in `query_movies`.
"""

query_embeddings = []

for movie_title in query_movies:
    movieId = get_movie_id_by_title(movie_title)
    token_id = vocabulary_lookup[movieId]
    movie_embedding = movie_embeddings[token_id]
    query_embeddings.append(movie_embedding)

query_embeddings = np.array(query_embeddings)

"""
Compute the [cosine similarity](https://en.wikipedia.org/wiki/Cosine_similarity) between the embeddings of `query_movies`
and the embeddings of all the movies, then pick the top k for each.
"""

# Normalize each embedding (row) to unit length, so that the dot product
# between two rows is their cosine similarity.
similarities = tf.linalg.matmul(
    tf.math.l2_normalize(query_embeddings, axis=1),
    tf.math.l2_normalize(movie_embeddings, axis=1),
    transpose_b=True,
)

_, indices = tf.math.top_k(similarities, k=5)
indices = indices.numpy().tolist()

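"""
The top-k lookup above can be checked on a small example. The following pure-Python
sketch is only an illustration of cosine-similarity ranking; the helper names
`cosine_similarity` and `top_k` and the 2-dimensional vectors are hypothetical.

```python
import math


def cosine_similarity(u, v):
    dot = sum(a * b for a, b in zip(u, v))
    norm_u = math.sqrt(sum(a * a for a in u))
    norm_v = math.sqrt(sum(b * b for b in v))
    return dot / (norm_u * norm_v)


def top_k(query, candidates, k):
    # Rank candidate indices by cosine similarity to the query, highest first.
    scores = [cosine_similarity(query, candidate) for candidate in candidates]
    ranked = sorted(range(len(candidates)), key=lambda i: scores[i], reverse=True)
    return ranked[:k]


# Hypothetical 2-dimensional embeddings (assumptions for this illustration only).
query = [1.0, 0.0]
candidates = [[10.0, 10.0], [0.5, 0.0], [0.0, 3.0], [-1.0, 0.0]]
print(top_k(query, candidates, k=2))  # [1, 0]
```

Because each vector is scaled by its own length, the short vector `[0.5, 0.0]`
ranks above the much longer `[10.0, 10.0]`; a plain dot product would rank them
the other way round.
"""
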
"""
Display the top related movies for each movie in `query_movies`.
"""

for idx, title in enumerate(query_movies):
    print(title)
    print("".rjust(len(title), "-"))
    similar_tokens = indices[idx]
    for token in similar_tokens:
        similar_movieId = vocabulary[token]
        similar_title = get_movie_title_by_id(similar_movieId)
        print(f"- {similar_title}")
    print()

"""
### Visualize the embeddings using the Embedding Projector
"""

import io

out_v = io.open("embeddings.tsv", "w", encoding="utf-8")
out_m = io.open("metadata.tsv", "w", encoding="utf-8")

# Token ids start at 1 because index 0 is reserved for the "NA" token.
for idx, movie_id in enumerate(vocabulary[1:], start=1):
    movie_title = list(movies[movies.movieId == movie_id].title)[0]
    vector = movie_embeddings[idx]
    out_v.write("\t".join([str(x) for x in vector]) + "\n")
    out_m.write(movie_title + "\n")

out_v.close()
out_m.close()

"""
Download the `embeddings.tsv` and `metadata.tsv` to analyze the obtained embeddings
in the [Embedding Projector](https://projector.tensorflow.org/).
"""

"""

**Example available on HuggingFace**

| Trained Model | Demo |
| :--: | :--: |
| [![Generic badge](https://img.shields.io/badge/%F0%9F%A4%97%20Model%3A%20-Node2Vec%20Movielens-black.svg)](https://huggingface.co/keras-io/Node2Vec_MovieLens) | [![Generic badge](https://img.shields.io/badge/%F0%9F%A4%97%20Spaces%3A-Node2Vec%20Movielens-black.svg)](https://huggingface.co/spaces/keras-io/Node2Vec_MovieLens) |
"""