"""
Title: Recommending movies: retrieval
Author: [Fabien Hertschuh](https://github.com/hertschuh/), [Abheesht Sharma](https://github.com/abheesht17/)
Date created: 2025/04/28
Last modified: 2025/04/28
Description: Retrieve movies using a two-tower model.
Accelerator: GPU
"""

"""
## Introduction

Recommender systems are often composed of two stages:

1. The retrieval stage is responsible for selecting an initial set of hundreds
of candidates from all possible candidates. The main objective of this model
is to efficiently weed out all candidates that the user is not interested in.
Because the retrieval model may be dealing with millions of candidates, it
has to be computationally efficient.
2. The ranking stage takes the outputs of the retrieval model and fine-tunes
them to select the best possible handful of recommendations. Its task is to
narrow down the set of items the user may be interested in to a shortlist of
likely candidates.

In this tutorial, we're going to focus on the first stage, retrieval. If you are
interested in the ranking stage, have a look at our
[ranking](/keras_rs/examples/basic_ranking/) tutorial.

Retrieval models are often composed of two sub-models:

1. A query tower computing the query representation (normally a
fixed-dimensionality embedding vector) using query features.
2. A candidate tower computing the candidate representation (an equally-sized
vector) using the candidate features. The outputs of the two models are then
multiplied together to give a query-candidate affinity score, with higher
scores expressing a better match between the candidate and the query.
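
To make this concrete, here is a toy sketch (with made-up numbers, using NumPy
directly rather than anything from this tutorial) of how the dot product
between a query embedding and candidate embeddings yields one affinity score
per candidate:
"""

import numpy as np

# A hypothetical 4-dimensional query embedding and two candidate embeddings.
query = np.array([0.1, 0.5, -0.3, 0.8])
candidates = np.array(
    [
        [0.2, 0.4, -0.1, 0.7],
        [-0.5, 0.1, 0.9, -0.2],
    ]
)
# One dot-product affinity score per candidate; higher means a better match.
scores = candidates @ query
print(scores)  # The first candidate is the better match for this query.

"""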

In this tutorial, we're going to build and train such a two-tower model using
the MovieLens dataset.

We're going to:

1. Get our data and split it into a training and test set.
2. Implement a retrieval model.
3. Fit and evaluate it.
4. Test running predictions with the model.

### The dataset

The MovieLens dataset is a classic dataset from the
[GroupLens](https://grouplens.org/datasets/movielens/) research group at the
University of Minnesota. It contains a set of ratings given to movies by a set
of users, and is a standard for recommender systems research.

The data can be treated in two ways:

1. It can be interpreted as expressing which movies the users watched (and
rated), and which they did not. This is a form of implicit feedback, where
users' watches tell us which things they prefer to see and which they'd
rather not see.
2. It can also be seen as expressing how much the users liked the movies they
did watch. This is a form of explicit feedback: given that a user watched a
movie, we can tell how much they liked it by looking at the rating they have
given.

In this tutorial, we are focusing on a retrieval system: a model that predicts a
set of movies from the catalogue that the user is likely to watch. For this, the
model will try to predict the rating users would give to all the movies in the
catalogue. We will therefore use the explicit rating data.

Let's begin by choosing JAX as the backend we want to run on, and import all
the necessary libraries.
"""

"""shell
pip install -q keras-rs
"""

import os

os.environ["KERAS_BACKEND"] = "jax"  # `"tensorflow"`/`"torch"`

import keras
import tensorflow as tf  # Needed for the dataset
import tensorflow_datasets as tfds

import keras_rs

"""
## Preparing the dataset

Let's first have a look at the data.

We use the MovieLens dataset from
[TensorFlow Datasets](https://www.tensorflow.org/datasets). Loading
`movielens/100k-ratings` yields a `tf.data.Dataset` object containing the
ratings alongside user and movie data. Loading `movielens/100k-movies` yields a
`tf.data.Dataset` object containing only the movies data.

Note that since the MovieLens dataset does not have predefined splits, all the
data is under the `train` split.
"""

# Ratings data with user and movie data.
ratings = tfds.load("movielens/100k-ratings", split="train")
# Features of all the available movies.
movies = tfds.load("movielens/100k-movies", split="train")

"""
The ratings dataset returns a dictionary of movie id, user id, the assigned
rating, timestamp, movie information, and user information:
"""

for data in ratings.take(1).as_numpy_iterator():
    print(str(data).replace(", '", ",\n '"))

"""
In the MovieLens dataset, user IDs are integers (represented as strings)
starting at 1 and with no gap. Normally, you would need to create a lookup table
to map user IDs to integers from 0 to N-1. But as a simplification, we'll use
the user id directly as an index in our model, in particular to look up the user
embedding from the user embedding table. So we need to know the number of users.
"""

users_count = (
    ratings.map(lambda x: tf.strings.to_number(x["user_id"], out_type=tf.int32))
    .reduce(tf.constant(0, tf.int32), tf.maximum)
    .numpy()
)

"""
The movies dataset contains the movie id, movie title, and the genres it belongs
to. Note that the genres are encoded with integer labels.
"""

for data in movies.take(1).as_numpy_iterator():
    print(str(data).replace(", '", ",\n '"))

"""
In the MovieLens dataset, movie IDs are integers (represented as strings)
starting at 1 and with no gap. Normally, you would need to create a lookup table
to map movie IDs to integers from 0 to N-1. But as a simplification, we'll use
the movie id directly as an index in our model, in particular to look up the
movie embedding from the movie embedding table. So we need to know the number of
movies.
"""

movies_count = movies.cardinality().numpy()
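
"""
As a quick sanity check (for MovieLens 100K, we expect 943 users and 1,682
movies):
"""

print(f"Number of users: {users_count}, number of movies: {movies_count}")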

"""
In this example, we're going to focus on the ratings data. Other tutorials
explore how to use the movie information data as well as the user information to
improve the model quality.

We keep only the `user_id`, `movie_id` and `rating` fields in the dataset. Our
input is the `user_id`. The labels are the `movie_id` alongside the `rating` for
the given movie and user.

The `rating` is a number between 1 and 5; we rescale it to be between 0 and 1.
"""


def preprocess_rating(x):
    return (
        # Input is the user IDs
        tf.strings.to_number(x["user_id"], out_type=tf.int32),
        # Labels are movie IDs + ratings between 0 and 1.
        {
            "movie_id": tf.strings.to_number(x["movie_id"], out_type=tf.int32),
            "rating": (x["user_rating"] - 1.0) / 4.0,
        },
    )
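

"""
As a quick check (a sketch, not part of the original example), we can print one
preprocessed example: the input is a user ID, and the label holds a movie ID and
a rescaled rating:
"""

for example in ratings.map(preprocess_rating).take(1).as_numpy_iterator():
    print(example)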


"""
To fit and evaluate the model, we need to split the data into a training and
evaluation set. In a real recommender system, this would most likely be done by
time: the data up to time *T* would be used to predict interactions after *T*.
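
A hypothetical time-based split could filter on the raw `timestamp` field; the
cutoff below is an arbitrary example value, and this split is not used later:
"""

cutoff = 880_000_000  # Unix timestamp in seconds, arbitrary example cutoff.
train_by_time = ratings.filter(lambda x: x["timestamp"] < cutoff)
test_by_time = ratings.filter(lambda x: x["timestamp"] >= cutoff)

"""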

In this simple example, however, let's use a random split, putting 80% of the
ratings in the train set, and 20% in the test set.
"""

shuffled_ratings = ratings.map(preprocess_rating).shuffle(
    100_000, seed=42, reshuffle_each_iteration=False
)
train_ratings = shuffled_ratings.take(80_000).batch(1000).cache()
test_ratings = shuffled_ratings.skip(80_000).take(20_000).batch(1000).cache()

"""
## Implementing the Model

Choosing the architecture of our model is a key part of modelling.

We are building a two-tower retrieval model, so we need to combine a query
tower for users and a candidate tower for movies.

The first step is to decide on the dimensionality of the query and candidate
representations. This is the `embedding_dimension` argument in our model
constructor. We'll test with a value of `32`. Higher values will correspond to
models that may be more accurate, but will also be slower to fit and more prone
to overfitting.

### Query and Candidate Towers

The second step is to define the model itself. In this simple example, the query
tower and candidate tower are simply embeddings with nothing else. We'll use
Keras' `Embedding` layer.

We can easily extend the towers to make them arbitrarily complex using standard
Keras components, as long as we return an `embedding_dimension`-wide output at
the end.
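
For instance, a hypothetical deeper query tower (a sketch only; it is not used
anywhere below) could stack `Dense` layers on top of the embedding, as long as
the final output stays `embedding_dimension`-wide:
"""

deeper_query_tower = keras.Sequential(
    [
        keras.layers.Embedding(users_count + 1, 64),
        keras.layers.Dense(32, activation="relu"),
        keras.layers.Dense(32),  # Final width matches embedding_dimension=32.
    ]
)

"""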

### Retrieval

The retrieval itself will be performed by the `BruteForceRetrieval` layer from
Keras Recommenders. This layer computes the affinity scores for the given users
and all the candidate movies, then returns the top K in order.

Note that during training, we don't actually need to perform any retrieval since
the only affinity scores we need are the ones for the users and movies in the
batch. As an optimization, we skip the retrieval entirely in the `call` method.

### Loss

The next component is the loss used to train our model. In this case, we use a
mean squared error loss to measure the difference between the predicted movie
ratings and the actual ratings from users.

Note that we override `compute_loss` from the `keras.Model` class. This allows
us to compute the query-candidate affinity score, which is obtained by
multiplying the outputs of the two towers together. That affinity score can then
be passed to the loss function.
"""


class RetrievalModel(keras.Model):
    """Create the retrieval model with the provided parameters.

    Args:
        num_users: Number of entries in the user embedding table.
        num_candidates: Number of entries in the candidate embedding table.
        embedding_dimension: Output dimension for user and movie embedding
            tables.
    """

    def __init__(
        self,
        num_users,
        num_candidates,
        embedding_dimension=32,
        **kwargs,
    ):
        super().__init__(**kwargs)
        # Our query tower, simply an embedding table.
        self.user_embedding = keras.layers.Embedding(num_users, embedding_dimension)
        # Our candidate tower, simply an embedding table.
        self.candidate_embedding = keras.layers.Embedding(
            num_candidates, embedding_dimension
        )
        # The layer that performs the retrieval.
        self.retrieval = keras_rs.layers.BruteForceRetrieval(k=10, return_scores=False)
        self.loss_fn = keras.losses.MeanSquaredError()

    def build(self, input_shape):
        self.user_embedding.build(input_shape)
        self.candidate_embedding.build(input_shape)
        # In this case, the candidates are directly the movie embeddings.
        # We take a shortcut and directly reuse the variable.
        self.retrieval.candidate_embeddings = self.candidate_embedding.embeddings
        self.retrieval.build(input_shape)
        super().build(input_shape)

    def call(self, inputs, training=False):
        user_embeddings = self.user_embedding(inputs)
        result = {
            "user_embeddings": user_embeddings,
        }
        if not training:
            # Skip the retrieval of top movies during training as the
            # predictions are not used.
            result["predictions"] = self.retrieval(user_embeddings)
        return result

    def compute_loss(self, x, y, y_pred, sample_weight, training=True):
        candidate_id, rating = y["movie_id"], y["rating"]
        user_embeddings = y_pred["user_embeddings"]
        candidate_embeddings = self.candidate_embedding(candidate_id)

        labels = keras.ops.expand_dims(rating, -1)
        # Compute the affinity score by multiplying the two embeddings.
        scores = keras.ops.sum(
            keras.ops.multiply(user_embeddings, candidate_embeddings),
            axis=1,
            keepdims=True,
        )
        return self.loss_fn(labels, scores, sample_weight)


"""
## Fitting and evaluating

After defining the model, we can use the standard Keras `model.fit()` to train
and evaluate the model.

Let's first instantiate the model. Note that we add `+ 1` to the number of users
and movies to account for the fact that id zero is not used for either (IDs
start at 1), but still takes a row in the embedding tables.
"""

model = RetrievalModel(users_count + 1, movies_count + 1)
model.compile(optimizer=keras.optimizers.Adagrad(learning_rate=0.1))

"""
Then train the model. Evaluation takes a bit of time, so we only evaluate the
model every 5 epochs.
"""

history = model.fit(
    train_ratings, validation_data=test_ratings, validation_freq=5, epochs=50
)
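
"""
As a quick check (a sketch; `loss` and `val_loss` are the default Keras history
keys for this setup), we can look at the last recorded losses:
"""

print("Final training loss:", history.history["loss"][-1])
print("Final validation loss:", history.history["val_loss"][-1])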

"""
## Making predictions

Now that we have a model, we would like to be able to make predictions.

So far, we have only handled movies by id. Now is the time to create a mapping
keyed by movie IDs to be able to surface the titles.
"""

movie_id_to_movie_title = {
    int(x["movie_id"]): x["movie_title"] for x in movies.as_numpy_iterator()
}
movie_id_to_movie_title[0] = ""  # Because id 0 is not in the dataset.

"""
We then simply use the Keras `model.predict()` method. Under the hood, it calls
the `BruteForceRetrieval` layer to perform the actual retrieval.

Note that this model can retrieve movies already watched by the user. We could
easily add logic to remove them if that is desirable; see the sketch below.
"""

user_id = 42
predictions = model.predict(keras.ops.convert_to_tensor([user_id]))
predictions = keras.ops.convert_to_numpy(predictions["predictions"])

print(f"Recommended movies for user {user_id}:")
for movie_id in predictions[0]:
    print(movie_id_to_movie_title[movie_id])
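
"""
A hypothetical way to drop already-watched movies (a sketch using the raw
ratings data; this is not part of the original example) is to post-filter the
predictions against the user's history:
"""

watched_movie_ids = {
    int(x["movie_id"])
    for x in ratings.as_numpy_iterator()
    if int(x["user_id"]) == user_id
}
print(f"Unwatched recommendations for user {user_id}:")
for movie_id in predictions[0]:
    if movie_id not in watched_movie_ids:
        print(movie_id_to_movie_title[movie_id])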

"""
## Item-to-item recommendation

In this model, we created a user-movie model. However, for some applications
(for example, product detail pages) it's common to perform item-to-item (for
example, movie-to-movie or product-to-product) recommendations.

Training models like this would follow the same pattern as shown in this
tutorial, but with different training data. Here, we had a user and a movie
tower, and used (user, movie) pairs to train them. In an item-to-item model, we
would have two item towers (for the query and candidate item), and train the
model using (query item, candidate item) pairs. These could be constructed, for
example, from clicks on product detail pages; see the sketch below.
"""
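
"""
As a minimal hypothetical sketch (the pair construction is omitted), such a
model could reuse the `RetrievalModel` class above with movies on both sides:
"""

item_to_item_model = RetrievalModel(
    num_users=movies_count + 1,  # The query tower now also indexes movies.
    num_candidates=movies_count + 1,
)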