# Source: GitHub repository keras-team/keras-io, path examples/keras_rs/multi_task.py (master)
"""
Title: Multi-task recommenders: retrieval + ranking
Author: [Abheesht Sharma](https://github.com/abheesht17/), [Fabien Hertschuh](https://github.com/hertschuh/)
Date created: 2025/04/28
Last modified: 2025/04/28
Description: Using one model for both retrieval and ranking.
Accelerator: GPU
"""

"""
## Introduction

In the
[basic retrieval](/keras_rs/examples/basic_retrieval/)
and
[basic ranking](/keras_rs/examples/basic_ranking/)
tutorials, we created separate models for the retrieval and ranking tasks,
respectively. However, in many cases, building a single, joint model for
multiple tasks can lead to better performance than creating distinct models for
each task. This is especially true when the data is unevenly distributed across
tasks, with abundant data for one (e.g., clicks) and sparse data for another
(e.g., purchases, returns, or manual reviews). In such scenarios, a joint model
can leverage representations learned from the abundant data to improve
predictions on the sparse data, a technique known as transfer learning.
For instance, [research](https://openreview.net/forum?id=SJxPVcSonN) shows that
a model trained to predict user ratings from sparse survey data can be
significantly enhanced by incorporating an auxiliary task using abundant click
log data.

In this example, we develop a multi-objective recommender system using the
MovieLens dataset. We incorporate both implicit feedback (e.g., movie watches)
and explicit feedback (e.g., ratings) to create a more robust and effective
recommendation model. For the former, we predict "movie watches", i.e., whether
a user has watched a movie, and for the latter, we predict the rating given by a
user to a movie.

Let's start by importing the necessary packages.
"""

"""shell
pip install -q keras-rs
"""

import os

os.environ["KERAS_BACKEND"] = "jax"  # `"tensorflow"`/`"torch"`

import keras
import tensorflow as tf  # Needed for the dataset
import tensorflow_datasets as tfds

import keras_rs

"""
## Prepare the dataset

We use the MovieLens dataset. The data loading and processing steps are similar
to those in the previous tutorials, so we will not discuss them in detail here.
"""

# Ratings data with user and movie data.
ratings = tfds.load("movielens/100k-ratings", split="train")
# Features of all the available movies.
movies = tfds.load("movielens/100k-movies", split="train")

"""
Get user and movie counts so that we can define embedding layers.
"""

users_count = (
    ratings.map(lambda x: tf.strings.to_number(x["user_id"], out_type=tf.int32))
    .reduce(tf.constant(0, tf.int32), tf.maximum)
    .numpy()
)

movies_count = movies.cardinality().numpy()

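"""
Later on, the model is constructed with `users_count + 1` and
`movies_count + 1` rows in its embedding tables. The plain-Python sketch below
illustrates why, under the assumption that IDs start at 1 and are used directly
as row indices. The helper `embedding_table_size` and the toy IDs are
hypothetical and are not part of the tutorial code.

```python
def embedding_table_size(ids):
    # An embedding lookup uses the raw ID as a row index, so the table
    # needs rows 0..max(ids) inclusive, i.e. max(ids) + 1 rows.
    return max(ids) + 1


# Hypothetical 1-based IDs; row 0 exists in the table but is never looked up.
toy_user_ids = [1, 2, 5, 3]
size = embedding_table_size(toy_user_ids)
print(size)
```

Reserving the unused row 0 costs one extra embedding vector, but it avoids
having to remap every ID to a 0-based index.
"""
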
"""
Our inputs are `"user_id"` and `"movie_id"`. Our label for the ranking task is
`"user_rating"`. `"user_rating"` takes values between 1 and 5. We rescale it
to `[0, 1]`.
"""


def preprocess_rating(x):
    return (
        {
            "user_id": tf.strings.to_number(x["user_id"], out_type=tf.int32),
            "movie_id": tf.strings.to_number(x["movie_id"], out_type=tf.int32),
        },
        # Rescale the rating to `[0, 1]`.
        (x["user_rating"] - 1.0) / 4.0,
    )


shuffled_ratings = ratings.map(preprocess_rating).shuffle(
    100_000, seed=42, reshuffle_each_iteration=False
)


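"""
As a small aside, here is a minimal plain-Python sketch of the label scaling
used in `preprocess_rating`, together with its inverse. It assumes ratings run
from 1 to 5, which is what `(x - 1.0) / 4.0` implies. The names `MIN_RATING`,
`MAX_RATING`, `scale_rating` and `unscale_rating` are hypothetical helpers for
illustration only.

```python
MIN_RATING = 1.0  # Assumed lowest possible rating.
MAX_RATING = 5.0  # Assumed highest possible rating.


def scale_rating(rating):
    # Map a rating in [MIN_RATING, MAX_RATING] to [0, 1].
    return (rating - MIN_RATING) / (MAX_RATING - MIN_RATING)


def unscale_rating(scaled):
    # Inverse mapping: take a value in [0, 1] back to the rating scale.
    return scaled * (MAX_RATING - MIN_RATING) + MIN_RATING


for stars in (1.0, 3.0, 5.0):
    scaled = scale_rating(stars)
    print(stars, scaled, unscale_rating(scaled))
```

The inverse mapping is what turns a model prediction in `[0, 1]` back into a
rating on the original scale.
"""
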
"""
Split the dataset into train-test sets.
"""

train_ratings = shuffled_ratings.take(80_000).batch(1000).cache()
test_ratings = shuffled_ratings.skip(80_000).take(20_000).batch(1000).cache()

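"""
The split above relies on the shuffle order being fixed
(`reshuffle_each_iteration=False`): `take` and `skip` then see the same
ordering, so the two subsets cannot overlap. The sketch below mimics this with
the standard library only. `shuffled_once` and the toy data are hypothetical
stand-ins, not `tf.data` calls.

```python
import random


def shuffled_once(items, seed):
    # Shuffle a copy with a fixed seed, mimicking a shuffle whose order
    # does not change between iterations.
    items = list(items)
    random.Random(seed).shuffle(items)
    return items


examples = list(range(10))  # Stand-in for the rating examples.
order = shuffled_once(examples, seed=42)

train = order[:8]  # Analogous to `.take(8)`.
test = order[8:]  # Analogous to `.skip(8).take(2)`.

print(len(train), len(test), set(train).isdisjoint(test))
```

If the order changed on every pass over the data, the "first 80%" and the
"last 20%" would be drawn from different orderings and could share examples.
"""
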
"""
## Building the model

We build the model in a similar way to the basic retrieval and basic ranking
guides.

For the retrieval task (i.e., predicting whether a user watched a movie),
we compute the similarity of the corresponding user and movie embeddings, and
use cross entropy loss, where the positive pairs are labelled one, and all other
samples in the batch are considered "negatives". We report top-k accuracy for
this task.

For the ranking task (i.e., given a user-movie pair, predict the rating), we
concatenate the user and movie embeddings and pass them to a dense module. We
use MSE loss here, and report the Root Mean Squared Error (RMSE).

The final loss is a weighted combination of the two losses mentioned above,
where the weights are `retrieval_loss_wt` and `ranking_loss_wt`. These
weights decide which task the model will focus on.
"""


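"""
To make the loss description concrete, here is a minimal plain-Python sketch of
the three pieces: a softmax cross entropy with in-batch negatives (the positive
for row `i` is column `i`), a mean squared error, and their weighted sum. It is
an illustration under simplifying assumptions (no sample weights, lists instead
of tensors). The function names and toy values are hypothetical and this is not
the Keras implementation.

```python
import math


def in_batch_retrieval_loss(scores):
    # `scores[i][j]` is the affinity of user i with the movie of example j.
    # The positive for row i is column i; the other columns act as negatives.
    total = 0.0
    for i, row in enumerate(scores):
        row_max = max(row)  # Subtract the max for numerical stability.
        log_sum_exp = row_max + math.log(sum(math.exp(s - row_max) for s in row))
        total += log_sum_exp - row[i]  # Equals -log(softmax(row)[i]).
    return total  # Summed over the batch, mirroring `reduction="sum"`.


def mse(labels, preds):
    return sum((l - p) ** 2 for l, p in zip(labels, preds)) / len(labels)


def total_loss(scores, labels, preds, retrieval_wt=1.0, ranking_wt=1.0):
    return retrieval_wt * in_batch_retrieval_loss(scores) + ranking_wt * mse(
        labels, preds
    )


toy_scores = [[2.0, 0.0], [0.0, 2.0]]
toy_labels = [1.0, 0.0]
toy_preds = [0.5, 0.5]

joint = total_loss(toy_scores, toy_labels, toy_preds)
ranking_only = total_loss(toy_scores, toy_labels, toy_preds, retrieval_wt=0.0)
print(joint, ranking_only)
```

Setting one of the weights to zero removes that task from the objective, which
is how the specialised models in the training section are obtained.
"""
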
class MultiTaskModel(keras.Model):
    def __init__(
        self,
        num_users,
        num_candidates,
        embedding_dimension=32,
        layer_sizes=(256, 128),
        retrieval_loss_wt=1.0,
        ranking_loss_wt=1.0,
        **kwargs,
    ):
        super().__init__(**kwargs)
        # Our query tower, simply an embedding table.
        self.user_embedding = keras.layers.Embedding(num_users, embedding_dimension)

        # Our candidate tower, simply an embedding table.
        self.candidate_embedding = keras.layers.Embedding(
            num_candidates, embedding_dimension
        )

        # Rating model.
        self.rating_model = keras.Sequential(
            [
                keras.layers.Dense(layer_size, activation="relu")
                for layer_size in layer_sizes
            ]
            + [keras.layers.Dense(1)]
        )

        # The layer that performs the retrieval.
        self.retrieval = keras_rs.layers.BruteForceRetrieval(k=10, return_scores=False)

        self.retrieval_loss_fn = keras.losses.CategoricalCrossentropy(
            from_logits=True,
            reduction="sum",
        )
        self.ranking_loss_fn = keras.losses.MeanSquaredError()

        # Top-k accuracy for retrieval
        self.top_k_metric = keras.metrics.SparseTopKCategoricalAccuracy(
            k=10, from_sorted_ids=True
        )
        # RMSE for ranking
        self.rmse_metric = keras.metrics.RootMeanSquaredError()

        # Attributes.
        self.num_users = num_users
        self.num_candidates = num_candidates
        self.embedding_dimension = embedding_dimension
        self.layer_sizes = layer_sizes
        self.retrieval_loss_wt = retrieval_loss_wt
        self.ranking_loss_wt = ranking_loss_wt

    def build(self, input_shape):
        self.user_embedding.build(input_shape)
        self.candidate_embedding.build(input_shape)
        # In this case, the candidates are directly the movie embeddings.
        # We take a shortcut and directly reuse the variable.
        self.retrieval.candidate_embeddings = self.candidate_embedding.embeddings
        self.retrieval.build(input_shape)

        self.rating_model.build((None, 2 * self.embedding_dimension))

        super().build(input_shape)

    def call(self, inputs, training=False):
        # Unpack inputs. Note that we have the if condition throughout this
        # `call()` method so that we can do a `.predict()` for the retrieval
        # task.
        user_id = inputs["user_id"]
        if "movie_id" in inputs:
            movie_id = inputs["movie_id"]

        result = {}

        # Get user, movie embeddings.
        user_embeddings = self.user_embedding(user_id)
        result["user_embeddings"] = user_embeddings

        if "movie_id" in inputs:
            candidate_embeddings = self.candidate_embedding(movie_id)
            result["candidate_embeddings"] = candidate_embeddings

            # Pass both embeddings through the rating block of the model.
            rating = self.rating_model(
                keras.ops.concatenate([user_embeddings, candidate_embeddings], axis=1)
            )
            result["rating"] = rating

        if not training:
            # Skip the retrieval of top movies during training as the
            # predictions are not used.
            result["predictions"] = self.retrieval(user_embeddings)

        return result

    def compute_loss(self, x, y, y_pred, sample_weight, training=True):
        user_embeddings = y_pred["user_embeddings"]
        candidate_embeddings = y_pred["candidate_embeddings"]

        # 1. Retrieval

        # Compute the affinity score by multiplying the two embeddings.
        scores = keras.ops.matmul(
            user_embeddings,
            keras.ops.transpose(candidate_embeddings),
        )

        # Retrieval labels: One-hot vectors
        num_users = keras.ops.shape(user_embeddings)[0]
        num_candidates = keras.ops.shape(candidate_embeddings)[0]
        retrieval_labels = keras.ops.eye(num_users, num_candidates)
        # Retrieval loss
        retrieval_loss = self.retrieval_loss_fn(retrieval_labels, scores, sample_weight)

        # 2. Ranking
        ratings = y
        pred_rating = y_pred["rating"]

        # Ranking labels are just ratings.
        ranking_labels = keras.ops.expand_dims(ratings, -1)
        # Ranking loss
        ranking_loss = self.ranking_loss_fn(ranking_labels, pred_rating, sample_weight)

        # Total loss is a weighted combination of the two losses.
        total_loss = (
            self.retrieval_loss_wt * retrieval_loss
            + self.ranking_loss_wt * ranking_loss
        )

        return total_loss

    def compute_metrics(self, x, y, y_pred, sample_weight=None):
        # RMSE can be computed irrespective of whether we are
        # training/evaluating.
        self.rmse_metric.update_state(
            y,
            y_pred["rating"],
            sample_weight=sample_weight,
        )

        if "predictions" in y_pred:
            # We are evaluating or predicting. Update `top_k_metric`.
            movie_ids = x["movie_id"]
            predictions = y_pred["predictions"]
            # For `top_k_metric`, which is a `SparseTopKCategoricalAccuracy`, we
            # only take top rated movies, and we put a weight of 0 for the rest.
            rating_weight = keras.ops.cast(keras.ops.greater(y, 0.9), "float32")
            sample_weight = (
                rating_weight
                if sample_weight is None
                else keras.ops.multiply(rating_weight, sample_weight)
            )
            self.top_k_metric.update_state(
                movie_ids, predictions, sample_weight=sample_weight
            )

            return self.get_metrics_result()
        else:
            # We are training. `top_k_metric` is not updated and is zero, so
            # don't report it.
            result = self.get_metrics_result()
            result.pop(self.top_k_metric.name)
            return result


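"""
The weighting used in `compute_metrics` means that only top-rated movies
(scaled rating above 0.9) count towards the top-k accuracy. A rough plain-Python
sketch of that idea follows, assuming the metric is a weighted mean of
per-example hits. `weighted_top_k_accuracy` and the toy values are hypothetical
and are not the Keras metric itself.

```python
def weighted_top_k_accuracy(true_ids, top_k_ids, weights):
    # true_ids[i]: the movie the user actually rated.
    # top_k_ids[i]: the k movie IDs retrieved for that user.
    # weights[i]: 1.0 to count the example, 0.0 to ignore it.
    hits = sum(w for t, ids, w in zip(true_ids, top_k_ids, weights) if t in ids)
    total = sum(weights)
    return hits / total if total else 0.0


scaled_ratings = [1.0, 0.5, 1.0, 0.25]
# Keep only examples whose scaled rating exceeds 0.9, as in the model above.
weights = [1.0 if r > 0.9 else 0.0 for r in scaled_ratings]

true_ids = [10, 11, 12, 13]
top_k_ids = [[10, 3], [11, 4], [5, 6], [7, 8]]

acc = weighted_top_k_accuracy(true_ids, top_k_ids, weights)
print(acc)
```

In this toy batch the second example is a hit but has weight zero, so it does
not affect the result. Only the two top-rated examples are counted.
"""
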
"""
## Training and evaluating

We will train three different models here. This can be done easily by passing
the correct loss weights:

1. Rating-specialised model
2. Retrieval-specialised model
3. Multi-task model
"""

# Rating-specialised model
model = MultiTaskModel(
    num_users=users_count + 1,
    num_candidates=movies_count + 1,
    ranking_loss_wt=1.0,
    retrieval_loss_wt=0.0,
)
model.compile(optimizer=keras.optimizers.Adagrad(0.1))
model.fit(train_ratings, epochs=5)

model.evaluate(test_ratings)

# Retrieval-specialised model
model = MultiTaskModel(
    num_users=users_count + 1,
    num_candidates=movies_count + 1,
    ranking_loss_wt=0.0,
    retrieval_loss_wt=1.0,
)
model.compile(optimizer=keras.optimizers.Adagrad(0.1))
model.fit(train_ratings, epochs=5)

model.evaluate(test_ratings)

# Multi-task model
model = MultiTaskModel(
    num_users=users_count + 1,
    num_candidates=movies_count + 1,
    ranking_loss_wt=1.0,
    retrieval_loss_wt=1.0,
)
model.compile(optimizer=keras.optimizers.Adagrad(0.1))
model.fit(train_ratings, epochs=5)

model.evaluate(test_ratings)

"""
Let's tabulate the metrics and note our observations:

| Model                 | Top-K Accuracy (↑) | RMSE (↓) |
|-----------------------|--------------------|----------|
| rating-specialised    | 0.005              | 0.26     |
| retrieval-specialised | 0.020              | 0.78     |
| multi-task            | 0.022              | 0.25     |

As expected, the rating-specialised model has good RMSE, but poor top-k
accuracy. For the retrieval-specialised model, it's the opposite.

The multi-task model does well on both tasks, and is even slightly better than
the two specialised models (top-k accuracy of 0.022 versus 0.020, RMSE of 0.25
versus 0.26). In general, we can expect multi-task learning to bring about
better results, especially when one task has a data-abundant source, and the
other task is trained on sparse data.

Now, let's make a prediction! We will first do a retrieval, and then for the
retrieved list of movies, we will predict the rating using the same model.
"""

movie_id_to_movie_title = {
    # Titles come back as bytes; decode them so they print as plain text.
    int(x["movie_id"]): x["movie_title"].decode("utf-8")
    for x in movies.as_numpy_iterator()
}
movie_id_to_movie_title[0] = ""  # Because id 0 is not in the dataset.

user_id = 5
retrieved_movie_ids = model.predict(
    {
        "user_id": keras.ops.array([user_id]),
    }
)
retrieved_movie_ids = keras.ops.convert_to_numpy(retrieved_movie_ids["predictions"][0])
retrieved_movies = [movie_id_to_movie_title[x] for x in retrieved_movie_ids]

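"""
Conceptually, a brute-force retrieval scores every candidate against the user
embedding and keeps the `k` best. The sketch below shows that idea in plain
Python with dot-product scores. It is only an illustration under that
assumption: `brute_force_top_k` and the toy embeddings are hypothetical, and
this is not how `keras_rs.layers.BruteForceRetrieval` is implemented.

```python
def brute_force_top_k(user_embedding, candidate_embeddings, k):
    # Score each candidate with a dot product against the user embedding.
    scores = [
        sum(u * c for u, c in zip(user_embedding, candidate))
        for candidate in candidate_embeddings
    ]
    # Sort candidate indices by decreasing score and keep the first k.
    ranked = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)
    return ranked[:k]


toy_user = [1.0, 0.0]
toy_candidates = [[0.0, 1.0], [0.9, 0.1], [0.5, 0.5], [-1.0, 0.0]]

top = brute_force_top_k(toy_user, toy_candidates, k=2)
print(top)
```

Because every candidate is scored, the cost grows linearly with the number of
candidates, which is acceptable for a catalogue of this size.
"""
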
"""
For these retrieved movies, we can now get the corresponding ratings.
"""

pred_ratings = model.predict(
    {
        "user_id": keras.ops.array([user_id] * len(retrieved_movie_ids)),
        "movie_id": keras.ops.array(retrieved_movie_ids),
    }
)["rating"]
pred_ratings = keras.ops.convert_to_numpy(keras.ops.squeeze(pred_ratings, axis=1))

for movie_id, prediction in zip(retrieved_movie_ids, pred_ratings):
    # Invert the `(rating - 1.0) / 4.0` scaling applied in `preprocess_rating`.
    print(f"{movie_id_to_movie_title[movie_id]}: {4.0 * prediction + 1.0:,.2f}")