GitHub Repository: keras-team/keras-io
Path: blob/master/examples/keras_rs/listwise_ranking.py
"""
Title: List-wise ranking
Author: [Abheesht Sharma](https://github.com/abheesht17/), [Fabien Hertschuh](https://github.com/hertschuh/)
Date created: 2025/04/28
Last modified: 2025/04/28
Description: Rank movies using pairwise losses instead of pointwise losses.
Accelerator: GPU
"""

"""
## Introduction

In our
[basic ranking tutorial](/keras_rs/examples/basic_ranking/), we explored a model
that learned to predict ratings for specific user-movie combinations. This model
took (user, movie) pairs as input and was trained using mean-squared error to
precisely predict the rating a user might give to a movie.

However, solely optimizing a model's accuracy in predicting individual movie
scores isn't always the most effective strategy for developing ranking systems.
For ranking models, pinpoint accuracy in predicting scores is less critical than
the model's capability to generate an ordered list of items that aligns with a
user's preferences. In essence, the relative order of items matters more than
the exact predicted values.

Instead of focusing on the model's predictions for individual query-item pairs
(a pointwise approach), we can optimize the model based on its ability to
correctly order items. One common method for this is pairwise ranking. In this
approach, the model learns by comparing pairs of items (e.g., item A and item B)
and determining which one should be ranked higher for a given user or query. The
goal is to minimize the number of incorrectly ordered pairs.
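
To make the notion of an incorrectly ordered pair concrete, here is a minimal
plain-Python sketch that counts such pairs for a single list. The helper name
`count_misordered_pairs` and the toy labels and scores are illustrative
assumptions, not part of KerasRS.

```python
from itertools import combinations


def count_misordered_pairs(labels, scores):
    # Counts item pairs whose predicted order contradicts the true order.
    misordered = 0
    for i, j in combinations(range(len(labels)), 2):
        if labels[i] == labels[j]:
            continue  # Ties in the true labels carry no ordering information.
        # A pair is misordered when the label difference and the score
        # difference have opposite signs.
        if (labels[i] - labels[j]) * (scores[i] - scores[j]) < 0:
            misordered += 1
    return misordered


labels = [1.0, 0.5, 0.0]  # True (normalised) ratings for three movies.
good_scores = [0.9, 0.4, 0.1]  # Same order as the labels.
bad_scores = [0.1, 0.4, 0.9]  # Reversed order.

print(count_misordered_pairs(labels, good_scores))
print(count_misordered_pairs(labels, bad_scores))
```

Note that the scores in `good_scores` do not match the labels exactly, yet no
pair is misordered: only the relative order is being measured.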

Let's begin by importing all the necessary libraries.
"""

"""shell
pip install -q keras-rs
"""

import os

os.environ["KERAS_BACKEND"] = "jax"  # `"tensorflow"`/`"torch"`

import collections

import keras
import numpy as np
import tensorflow as tf  # Needed only for the dataset
import tensorflow_datasets as tfds
from keras import ops

import keras_rs

"""
Let's define some hyperparameters here.
"""

# Data args
TRAIN_NUM_LIST_PER_USER = 50
TEST_NUM_LIST_PER_USER = 1
NUM_EXAMPLES_PER_LIST = 5

# Model args
EMBEDDING_DIM = 32

# Train args
BATCH_SIZE = 1024
EPOCHS = 5
LEARNING_RATE = 0.1

"""
## Preparing the dataset

We use the MovieLens dataset. The data loading and processing steps are similar
to previous tutorials, so we will only discuss the differences here.
"""

# Ratings data.
ratings = tfds.load("movielens/100k-ratings", split="train")
# Features of all the available movies.
movies = tfds.load("movielens/100k-movies", split="train")

users_count = (
    ratings.map(lambda x: tf.strings.to_number(x["user_id"], out_type=tf.int32))
    .reduce(tf.constant(0, tf.int32), tf.maximum)
    .numpy()
)
movies_count = movies.cardinality().numpy()


def preprocess_rating(x):
    return {
        "user_id": tf.strings.to_number(x["user_id"], out_type=tf.int32),
        "movie_id": tf.strings.to_number(x["movie_id"], out_type=tf.int32),
        # Normalise ratings between 0 and 1.
        "user_rating": (x["user_rating"] - 1.0) / 4.0,
    }


shuffled_ratings = ratings.map(preprocess_rating).shuffle(
    100_000, seed=42, reshuffle_each_iteration=False
)
train_ratings = shuffled_ratings.take(70_000)
val_ratings = shuffled_ratings.skip(70_000).take(15_000)
test_ratings = shuffled_ratings.skip(85_000).take(15_000)

"""
So far, we've replicated what we have in the basic ranking tutorial.

However, this existing dataset is not directly applicable to list-wise
optimization. List-wise optimization requires, for each user, a list of movies
they have rated, allowing the model to learn from the relative orderings within
that list. The MovieLens 100K dataset, in its original form, provides individual
rating instances (one user, one movie, one rating per example), rather than
these aggregated user-specific lists.

To enable list-wise optimization, we need to restructure the dataset. This
involves transforming it so that each data point or example represents a single
user ID accompanied by a list of movies that user has rated. Within these lists,
some movies will naturally be ranked higher by the user (as evidenced by their
ratings) than others. The primary objective for our model will then be to learn
to predict item orderings that correspond to these observed user preferences.

Let's start by getting the entire list of movies and corresponding ratings for
every user. We remove `user_ids` corresponding to users who have rated fewer
than `NUM_EXAMPLES_PER_LIST` movies.
"""


def get_movie_sequence_per_user(ratings, min_examples_per_list):
    """Gets movieID sequences and ratings for every user."""
    sequences = collections.defaultdict(list)

    for sample in ratings:
        user_id = sample["user_id"]
        movie_id = sample["movie_id"]
        user_rating = sample["user_rating"]

        sequences[int(user_id.numpy())].append(
            {
                "movie_id": int(movie_id.numpy()),
                "user_rating": float(user_rating.numpy()),
            }
        )

    # Remove lists with < `min_examples_per_list` number of elements.
    sequences = {
        user_id: sequence
        for user_id, sequence in sequences.items()
        if len(sequence) >= min_examples_per_list
    }

    return sequences


"""
We now sample 50 lists for each user for the training data. For each list, we
randomly sample 5 movies from the movies the user rated.
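
As a rough picture of the grouping, filtering and sampling steps, the sketch
below applies them to a handful of made-up records using only the standard
library. The records, the constant `MIN_LIST_LENGTH` and the fixed seed are
assumptions for illustration; the tutorial's own functions work on the
`tf.data` pipeline instead.

```python
import collections
import random

# Toy pointwise records: (user_id, movie_id, normalised rating).
records = [
    (1, 10, 1.0),
    (1, 11, 0.5),
    (1, 12, 0.0),
    (1, 13, 0.75),
    (2, 10, 0.25),
    (2, 14, 1.0),
]
MIN_LIST_LENGTH = 3  # Hypothetical stand-in for NUM_EXAMPLES_PER_LIST.

# Group the ratings by user.
per_user = collections.defaultdict(list)
for user_id, movie_id, rating in records:
    per_user[user_id].append((movie_id, rating))

# Drop users with too few ratings to fill one list.
per_user = {u: r for u, r in per_user.items() if len(r) >= MIN_LIST_LENGTH}

# Sample one fixed-length list per remaining user, without replacement.
rng = random.Random(0)
examples = []
for user_id, rated in per_user.items():
    sampled = rng.sample(rated, MIN_LIST_LENGTH)
    examples.append(
        {
            "user_id": user_id,
            "movie_id": [m for m, _ in sampled],
            "user_rating": [r for _, r in sampled],
        }
    )

print(examples)
```

Each resulting example pairs one user ID with a list of movie IDs and a list of
labels of the same length, which is the shape a list-wise loss expects.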
"""


def sample_sublist_from_list(
    lst,
    num_examples_per_list,
):
    """Randomly selects `num_examples_per_list` elements from the list."""

    indices = np.random.choice(
        range(len(lst)),
        size=num_examples_per_list,
        replace=False,
    )

    samples = [lst[i] for i in indices]
    return samples


def get_examples(
    sequences,
    num_list_per_user,
    num_examples_per_list,
):
    inputs = {
        "user_id": [],
        "movie_id": [],
    }
    labels = []
    for user_id, user_list in sequences.items():
        for _ in range(num_list_per_user):
            sampled_list = sample_sublist_from_list(
                user_list,
                num_examples_per_list,
            )

            inputs["user_id"].append(user_id)
            inputs["movie_id"].append(
                tf.convert_to_tensor([f["movie_id"] for f in sampled_list])
            )
            labels.append(
                tf.convert_to_tensor([f["user_rating"] for f in sampled_list])
            )

    return (
        {"user_id": inputs["user_id"], "movie_id": inputs["movie_id"]},
        labels,
    )


train_sequences = get_movie_sequence_per_user(
    ratings=train_ratings, min_examples_per_list=NUM_EXAMPLES_PER_LIST
)
train_examples = get_examples(
    train_sequences,
    num_list_per_user=TRAIN_NUM_LIST_PER_USER,
    num_examples_per_list=NUM_EXAMPLES_PER_LIST,
)
train_ds = tf.data.Dataset.from_tensor_slices(train_examples)

# Use the same minimum list length as for training, so that every remaining
# user has enough ratings to sample a full list from.
val_sequences = get_movie_sequence_per_user(
    ratings=val_ratings, min_examples_per_list=NUM_EXAMPLES_PER_LIST
)
val_examples = get_examples(
    val_sequences,
    num_list_per_user=TEST_NUM_LIST_PER_USER,
    num_examples_per_list=NUM_EXAMPLES_PER_LIST,
)
val_ds = tf.data.Dataset.from_tensor_slices(val_examples)

test_sequences = get_movie_sequence_per_user(
    ratings=test_ratings, min_examples_per_list=NUM_EXAMPLES_PER_LIST
)
test_examples = get_examples(
    test_sequences,
    num_list_per_user=TEST_NUM_LIST_PER_USER,
    num_examples_per_list=NUM_EXAMPLES_PER_LIST,
)
test_ds = tf.data.Dataset.from_tensor_slices(test_examples)

"""
Batch up the dataset, and cache it.
"""

train_ds = train_ds.batch(BATCH_SIZE).cache()
val_ds = val_ds.batch(BATCH_SIZE).cache()
test_ds = test_ds.batch(BATCH_SIZE).cache()

"""
## Building the model

We build a typical two-tower ranking model, similar to the
[basic ranking tutorial](/keras_rs/examples/basic_ranking/).
We have separate embedding layers for user IDs and movie IDs. After obtaining
these embeddings, we concatenate them and pass them through a network of dense
layers.

The only point of difference is that for movie IDs, we take a list of IDs
rather than just one movie ID. So, when we concatenate the user ID embedding and
the movie IDs' embeddings, we "repeat" the user ID embedding
`NUM_EXAMPLES_PER_LIST` times so as to get the same shape as the movie IDs'
embeddings.
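
The repeat-and-concatenate step can be pictured with plain Python lists
standing in for tensors. This is only a shape illustration for a single user;
the embedding values, the dimension of 2 and the constant `LIST_LENGTH` are
made up.

```python
LIST_LENGTH = 3  # Hypothetical list length for this toy example.

# One user embedding of dimension 2.
user_embedding = [0.1, 0.2]
# One embedding of dimension 2 per movie in the list.
candidate_embeddings = [[1.0, 1.1], [2.0, 2.1], [3.0, 3.1]]

# "Repeat" the user embedding once per movie in the list.
user_embeddings_repeated = [user_embedding for _ in range(LIST_LENGTH)]

# Concatenate along the feature axis: each row holds the user features
# followed by the features of one movie.
concatenated = [
    user + movie
    for user, movie in zip(user_embeddings_repeated, candidate_embeddings)
]

for row in concatenated:
    print(row)
```

Every row now has twice the embedding dimension, so the dense layers can score
each (user, movie) combination in the list independently.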
"""


class RankingModel(keras.Model):
    """Create the ranking model with the provided parameters.

    Args:
        num_users: Number of entries in the user embedding table.
        num_candidates: Number of entries in the candidate embedding table.
        embedding_dimension: Output dimension for user and movie embedding tables.
    """

    def __init__(
        self,
        num_users,
        num_candidates,
        embedding_dimension=32,
        **kwargs,
    ):
        super().__init__(**kwargs)
        # Embedding table for users.
        self.user_embedding = keras.layers.Embedding(num_users, embedding_dimension)
        # Embedding table for candidates.
        self.candidate_embedding = keras.layers.Embedding(
            num_candidates, embedding_dimension
        )
        # Predictions.
        self.ratings = keras.Sequential(
            [
                # Learn multiple dense layers.
                keras.layers.Dense(256, activation="relu"),
                keras.layers.Dense(64, activation="relu"),
                # Make rating predictions in the final layer.
                keras.layers.Dense(1),
            ]
        )

    def build(self, input_shape):
        self.user_embedding.build(input_shape["user_id"])
        self.candidate_embedding.build(input_shape["movie_id"])

        output_shape = self.candidate_embedding.compute_output_shape(
            input_shape["movie_id"]
        )

        self.ratings.build(list(output_shape[:-1]) + [2 * output_shape[-1]])

    def call(self, inputs):
        user_id, movie_id = inputs["user_id"], inputs["movie_id"]
        user_embeddings = self.user_embedding(user_id)
        candidate_embeddings = self.candidate_embedding(movie_id)

        list_length = ops.shape(movie_id)[-1]
        user_embeddings_repeated = ops.repeat(
            ops.expand_dims(user_embeddings, axis=1),
            repeats=list_length,
            axis=1,
        )
        concatenated_embeddings = ops.concatenate(
            [user_embeddings_repeated, candidate_embeddings], axis=-1
        )

        scores = self.ratings(concatenated_embeddings)
        scores = ops.squeeze(scores, axis=-1)

        return scores
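
    def compute_output_shape(self, input_shape):
        # `input_shape` is a dict here (see `build`). The model returns one
        # score per movie, so the output has the shape of the `movie_id` input.
        return tuple(input_shape["movie_id"])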


"""
Let's instantiate, compile and train our model. We will train two models:
one with vanilla mean-squared error, and the other with pairwise hinge loss.
For the latter, we will use `keras_rs.losses.PairwiseHingeLoss`.

Pairwise losses compare pairs of items within each list, penalizing cases where
an item with a higher true label has a lower predicted score than an item with a
lower true label. This is why they are more suited for ranking tasks than
pointwise losses.
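
As an illustration of this idea, a pairwise hinge loss for a single list can be
sketched in plain Python. The sketch assumes the textbook form
`max(0, margin - (s_i - s_j))`, summed over pairs in which item `i` has the
higher label; `keras_rs.losses.PairwiseHingeLoss` may weight, mask or reduce
the terms differently, so treat the numbers as indicative only.

```python
def pairwise_hinge(labels, scores, margin=1.0):
    # Sums max(0, margin - (s_i - s_j)) over all pairs with label_i > label_j.
    total = 0.0
    for i, label_i in enumerate(labels):
        for j, label_j in enumerate(labels):
            if label_i > label_j:
                total += max(0.0, margin - (scores[i] - scores[j]))
    return total


labels = [1.0, 0.0]  # The first movie is preferred.

# Correct order with a gap larger than the margin: no penalty.
print(pairwise_hinge(labels, [2.0, 0.0]))
# Wrong order: the penalty grows with the size of the inversion.
print(pairwise_hinge(labels, [0.0, 2.0]))
```

The loss depends only on score differences within a list, never on how close a
score is to the label itself.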

To quantify these results, we compute nDCG. nDCG is a measure of ranking quality
that evaluates how well a system orders items based on relevance, giving more
importance to highly relevant items appearing at the top of the list and
normalizing the score against an ideal ranking.
To compute it, we just need to pass `keras_rs.metrics.NDCG()` as a metric to
`model.compile`.
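
For intuition, nDCG for one list can be computed by hand. The sketch below uses
a common formulation (gain `2^label - 1`, discount `log2(rank + 1)`); whether
this matches the defaults of `keras_rs.metrics.NDCG` exactly is an assumption,
and the helper names are illustrative.

```python
import math


def dcg(labels_in_ranked_order):
    # Gain 2^label - 1, discounted by log2(rank + 1), with ranks starting at 1.
    return sum(
        (2.0**label - 1.0) / math.log2(rank + 1)
        for rank, label in enumerate(labels_in_ranked_order, start=1)
    )


def ndcg(labels, scores):
    # Order the true labels by descending predicted score.
    ranked = sorted(zip(scores, labels), key=lambda pair: -pair[0])
    labels_by_score = [label for _, label in ranked]
    # The ideal ranking orders the labels themselves from high to low.
    ideal_dcg = dcg(sorted(labels, reverse=True))
    return dcg(labels_by_score) / ideal_dcg if ideal_dcg > 0 else 0.0


labels = [1.0, 0.5, 0.0]
perfect_scores = [0.9, 0.4, 0.1]
reversed_scores = [0.1, 0.4, 0.9]

print(ndcg(labels, perfect_scores))
print(ndcg(labels, reversed_scores))
```

A perfectly ordered list scores 1, and the score drops as relevant items are
pushed further down the list.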
"""

model_mse = RankingModel(
    num_users=users_count + 1,
    num_candidates=movies_count + 1,
    embedding_dimension=EMBEDDING_DIM,
)
model_mse.compile(
    loss=keras.losses.MeanSquaredError(),
    metrics=[keras_rs.metrics.NDCG(k=NUM_EXAMPLES_PER_LIST, name="ndcg")],
    optimizer=keras.optimizers.Adagrad(learning_rate=LEARNING_RATE),
)
model_mse.fit(train_ds, validation_data=val_ds, epochs=EPOCHS)

"""
And now, the model with pairwise hinge loss.
"""

model_hinge = RankingModel(
    num_users=users_count + 1,
    num_candidates=movies_count + 1,
    embedding_dimension=EMBEDDING_DIM,
)
model_hinge.compile(
    loss=keras_rs.losses.PairwiseHingeLoss(),
    metrics=[keras_rs.metrics.NDCG(k=NUM_EXAMPLES_PER_LIST, name="ndcg")],
    optimizer=keras.optimizers.Adagrad(learning_rate=LEARNING_RATE),
)
model_hinge.fit(train_ds, validation_data=val_ds, epochs=EPOCHS)

"""
## Evaluation

Comparing the validation nDCG values, it is clear that the model trained with
the pairwise hinge loss outperforms the other one. Let's make this observation
more concrete by comparing results on the test set.
"""

ndcg_mse = model_mse.evaluate(test_ds, return_dict=True)["ndcg"]
ndcg_hinge = model_hinge.evaluate(test_ds, return_dict=True)["ndcg"]
print(ndcg_mse, ndcg_hinge)

"""
## Prediction

Now, let's rank some lists!

Let's create a mapping from movie ID to title so that we can surface the titles
for the ranked list.
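
Ranking a list then amounts to ordering the candidates from the highest
predicted score to the lowest. A minimal plain-Python sketch, with made-up
scores standing in for model predictions:

```python
movie_ids = [409, 237, 131]  # Candidate movie IDs (arbitrary).
scores = [0.2, 1.5, -0.3]  # Hypothetical predicted scores, one per movie.

# Indices of the candidates, ordered from highest to lowest score.
ranked_indices = sorted(
    range(len(scores)), key=lambda i: scores[i], reverse=True
)
ranked_movies = [movie_ids[i] for i in ranked_indices]

print(ranked_movies)
```

Keep in mind that `np.argsort` returns indices in ascending order, so its
result has to be reversed to obtain a best-first ranking.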
"""

movie_id_to_movie_title = {
    # Titles come back as bytes from the dataset, so decode them for display.
    int(x["movie_id"]): x["movie_title"].decode("utf-8")
    for x in movies.as_numpy_iterator()
}
movie_id_to_movie_title[0] = ""  # Because id 0 is not in the dataset.

user_id = 42
movie_ids = [409, 237, 131, 941, 543]
predictions = model_hinge.predict(
    {
        "user_id": keras.ops.array([user_id]),
        "movie_id": keras.ops.array([movie_ids]),
    }
)
predictions = keras.ops.convert_to_numpy(keras.ops.squeeze(predictions, axis=0))
# `np.argsort` sorts in ascending order; reverse it so that the movie with the
# highest predicted score comes first.
sorted_indices = np.argsort(predictions)[::-1]
sorted_movies = [movie_ids[i] for i in sorted_indices]

for i, movie_id in enumerate(sorted_movies):
    print(f"{i + 1}. ", movie_id_to_movie_title[movie_id])

"""
And we're all done!
"""