GitHub Repository: keras-team/keras-io
Path: blob/master/examples/keras_rs/deep_recommender.py
"""
Title: Deep Recommenders
Author: [Fabien Hertschuh](https://github.com/hertschuh/), [Abheesht Sharma](https://github.com/abheesht17/)
Date created: 2025/04/28
Last modified: 2025/04/28
Description: Building a deep retrieval model with multiple stacked layers.
Accelerator: GPU
"""

"""
## Introduction

One of the great advantages of using Keras to build recommender models is the
freedom to build rich, flexible feature representations.

The first step in doing so is preparing the features, as raw features will
usually not be immediately usable in a model.

For example:

- User and item IDs may be strings (titles, usernames) or large, non-contiguous
  integers (database IDs).
- Item descriptions could be raw text.
- Interaction timestamps could be raw Unix timestamps.

These need to be appropriately transformed in order to be useful in building
models:

- User and item IDs have to be translated into embedding vectors,
  high-dimensional numerical representations that are adjusted during training
  to help the model predict its objective better.
- Raw text needs to be tokenized (split into smaller parts such as individual
  words) and translated into embeddings.
- Numerical features need to be normalized so that their values lie in a small
  interval around 0.

Fortunately, the Keras
[`FeatureSpace`](/api/utils/feature_space/) utility makes this
preprocessing easy.

In this tutorial, we are going to incorporate multiple features in our models.
These features will come from preprocessing the MovieLens dataset.

In the
[basic retrieval](/keras_rs/examples/basic_retrieval/)
tutorial, the models consist of only an embedding layer. In this tutorial, we
add more dense layers to our models to increase their expressive power.

In general, deeper models are capable of learning more complex patterns than
shallower models. For example, our user model incorporates user IDs and user
features such as age, gender and occupation. A shallow model (say, a single
embedding layer) may only be able to learn the simplest relationships between
those features and movies: a given user generally prefers horror movies to
comedies. To capture more complex relationships, such as user preferences
evolving with their age, we may need a deeper model with multiple stacked dense
layers.

Of course, complex models also have their disadvantages. The first is
computational cost, as larger models require both more memory and more
computation to train and serve. The second is the requirement for more data. In
general, more training data is needed to take advantage of deeper models. With
more parameters, deep models might overfit or even simply memorize the training
examples instead of learning a function that can generalize. Finally, training
deeper models may be harder, and more care needs to be taken in choosing
settings like regularization and learning rate.

Finding a good architecture for a real-world recommender system is a complex
art, requiring good intuition and careful hyperparameter tuning. For example,
factors such as the depth and width of the model, activation function, learning
rate, and optimizer can radically change the performance of the model. Modelling
choices are further complicated by the fact that good offline evaluation metrics
may not correspond to good online performance, and that the choice of what to
optimize for is often more critical than the choice of model itself.

Nevertheless, effort put into building and fine-tuning larger models often pays
off. In this tutorial, we will illustrate how to build a deep retrieval model.
We'll do this by building progressively more complex models to see how this
affects model performance.
"""

"""shell
pip install -q keras-rs
"""

import os

os.environ["KERAS_BACKEND"] = "jax"  # `"tensorflow"`/`"torch"`

import keras
import matplotlib.pyplot as plt
import tensorflow as tf  # Needed for the dataset
import tensorflow_datasets as tfds

import keras_rs

"""
## The MovieLens dataset

Let's first have a look at what features we can use from the MovieLens dataset.
"""

# Ratings data with user and movie data.
ratings = tfds.load("movielens/100k-ratings", split="train")
# Features of all the available movies.
movies = tfds.load("movielens/100k-movies", split="train")

"""
The ratings dataset returns a dictionary of movie id, user id, the assigned
rating, timestamp, movie information, and user information:
"""

for data in ratings.take(1).as_numpy_iterator():
    print(str(data).replace(", '", ",\n '"))

"""
In the MovieLens dataset, user IDs are integers (represented as strings)
starting at 1 and with no gap. Normally, you would need to create a lookup table
to map user IDs to integers from 0 to N-1. But as a simplification, we'll use the
user ID directly as an index in our model, in particular to look up the user
embedding from the user embedding table. So we need to know the number of users.
"""

USERS_COUNT = (
    ratings.map(lambda x: tf.strings.to_number(x["user_id"], out_type=tf.int32))
    .reduce(tf.constant(0, tf.int32), tf.maximum)
    .numpy()
)

"""
The movies dataset contains the movie id, movie title, and the genres it belongs
to. Note that the genres are encoded with integer labels.
"""

for data in movies.take(1).as_numpy_iterator():
    print(str(data).replace(", '", ",\n '"))

"""
In the MovieLens dataset, movie IDs are integers (represented as strings)
starting at 1 and with no gap. Normally, you would need to create a lookup table
to map movie IDs to integers from 0 to N-1. But as a simplification, we'll use the
movie ID directly as an index in our model, in particular to look up the movie
embedding from the movie embedding table. So we need to know the number of
movies.
"""

MOVIES_COUNT = movies.cardinality().numpy()

"""
## Preprocessing the dataset

### Normalizing continuous features

Continuous features may need normalization so that they fall within an
acceptable range for the model. We will give two examples of such normalization.

#### Discretization

A common transformation is to turn a continuous feature into a number of
categorical features. This makes good sense if we have reasons to suspect that a
feature's effect is non-continuous.

We need to decide on the number of buckets we will use for discretization. Then,
we will use the Keras `FeatureSpace` utility, whose `adapt` step learns the
bucket boundaries from the data, to perform the discretization.

In this example, we will discretize the user age.
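
As a rough illustration of what discretization does once the bucket boundaries
are known, the sketch below maps ages to bucket indices using only the standard
library. The boundaries and the `bucketize` helper are hypothetical and
hand-picked for this example; in this tutorial the boundaries are learned by
`adapt` and will differ.

```python
import bisect

# Hypothetical, hand-picked boundaries; `FeatureSpace.adapt` learns its own.
AGE_BOUNDARIES = [18, 25, 35, 45, 50, 56]


def bucketize(value, boundaries):
    # With B sorted boundaries there are B + 1 buckets, indexed 0..B.
    # A value equal to a boundary goes into the bucket to its right.
    return bisect.bisect_right(boundaries, value)


ages = [7, 18, 24, 40, 73]
bucket_ids = [bucketize(age, AGE_BOUNDARIES) for age in ages]
print(bucket_ids)
```

Each age is replaced by a small integer, which can then be used as an index
into an embedding table like any other categorical feature.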
"""

AGE_BINS_COUNT = 10
user_age_feature = keras.utils.FeatureSpace.float_discretized(
    num_bins=AGE_BINS_COUNT, output_mode="int"
)

"""
#### Rescaling

Often, we want continuous features to be between 0 and 1, or between -1 and 1.
To achieve this, we can rescale features that have a different range.

In this example, we will map the rating, which is an integer between 1 and 5,
to a float between 0 and 1. To do so, we need to both scale it and offset it.
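
To see why a scale of 1/4 and an offset of -1/4 are the right values, the
sketch below applies the affine map `value * scale + offset` to every possible
rating. The `rescale` helper is a hypothetical stand-in written for this
illustration, not the Keras layer itself.

```python
def rescale(value, scale, offset):
    # Affine map: multiply by `scale`, then add `offset`.
    return value * scale + offset


# Ratings range over 1..5, so (rating - 1) / 4 maps them onto 0..1.
# Expanding that expression gives rating * (1 / 4) + (-1 / 4).
SCALE = 1.0 / 4.0
OFFSET = -1.0 / 4.0

rescaled = {rating: rescale(rating, SCALE, OFFSET) for rating in (1, 2, 3, 4, 5)}
print(rescaled)
```

The lowest rating maps to 0 and the highest to 1, with the intermediate
ratings evenly spaced in between.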
"""

user_rating_feature = keras.utils.FeatureSpace.float_rescaled(
    scale=1.0 / 4.0, offset=-1.0 / 4.0
)

"""
### Turning categorical features into embeddings

A categorical feature is a feature that does not express a continuous quantity,
but rather takes on one of a set of fixed values.

Most deep learning models express these features by turning them into
high-dimensional vectors. During model training, the value of that vector is
adjusted to help the model predict its objective better.

For example, suppose that our goal is to predict which user is going to watch
which movie. To do that, we represent each user and each movie by an embedding
vector. Initially, these embeddings will take on random values. During training,
we adjust them so that embeddings of users and the movies they watch end up
closer together.

Taking raw categorical features and turning them into embeddings is normally a
two-step process:
1. First, we need to translate the raw values into a range of contiguous
   integers, normally by building a mapping (called a "vocabulary") that maps
   raw values to integers.
2. Second, we need to take these integers and turn them into embeddings.
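
The two steps above can be sketched in plain Python as follows. The helper
names (`build_vocabulary`, `make_embedding_table`) and the occupation strings
are made up for this illustration; in this tutorial, `FeatureSpace` handles
the first step and `keras.layers.Embedding` the second.

```python
import random


def build_vocabulary(raw_values):
    # Step 1: map each distinct raw value to a contiguous integer index.
    return {value: index for index, value in enumerate(sorted(set(raw_values)))}


def make_embedding_table(num_rows, dimension, seed=0):
    # Step 2: one small random vector per index. A real model would adjust
    # these values during training; here they stay at their initial values.
    rng = random.Random(seed)
    return [
        [rng.uniform(-0.05, 0.05) for _ in range(dimension)]
        for _ in range(num_rows)
    ]


occupations = ["doctor", "writer", "student", "writer", "doctor"]
vocabulary = build_vocabulary(occupations)
table = make_embedding_table(len(vocabulary), dimension=4)

index = vocabulary["student"]
print(vocabulary)
print(index, table[index])
```

Looking up an embedding is just indexing into the table, which is why the
indices need to be contiguous and start at zero.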
"""

"""
#### Defining categorical features

We will use the Keras `FeatureSpace` utility for the first step. Its `adapt`
method automatically discovers the vocabulary for categorical features.
"""

user_gender_feature = keras.utils.FeatureSpace.integer_categorical(
    num_oov_indices=0, output_mode="int"
)
user_occupation_feature = keras.utils.FeatureSpace.integer_categorical(
    num_oov_indices=0, output_mode="int"
)

"""
#### Using feature crosses

With crosses, we can model interactions between multiple categorical
features. This can be a powerful way to express that a combination of features
represents a specific taste for movies.

Note that combining multiple features can result in a very large feature
space, which is why the `crossing_dim` parameter is important: it limits the
output dimension of the cross feature.

In this example, we will cross age and gender with the Keras `FeatureSpace`
utility.
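
One common way to bound the size of a cross is to hash each combination of
values into a fixed number of buckets. The sketch below shows the idea with
the standard library. The `hashed_cross` helper and its use of SHA-256 are
assumptions made for this illustration and are not the hash function Keras
uses internally.

```python
import hashlib


def hashed_cross(values, crossing_dim):
    # Join the values into one key, hash it, and fold the hash into a bucket.
    key = "_X_".join(str(value) for value in values).encode("utf-8")
    digest = hashlib.sha256(key).digest()
    return int.from_bytes(digest[:8], "big") % crossing_dim


CROSSING_DIM = 20
buckets = {
    (gender, age_bin): hashed_cross((gender, age_bin), CROSSING_DIM)
    for gender in (0, 1)
    for age_bin in range(10)
}
print(buckets[(0, 3)], buckets[(1, 3)])
print(len(set(buckets.values())), "distinct buckets for", len(buckets), "pairs")
```

Because the mapping is a hash, distinct combinations may collide in the same
bucket. A larger `crossing_dim` makes collisions less likely at the cost of a
larger embedding table.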
"""

USER_GENDER_CROSS_COUNT = 20
user_gender_age_cross = keras.utils.FeatureSpace.cross(
    feature_names=("user_gender", "raw_user_age"),
    crossing_dim=USER_GENDER_CROSS_COUNT,
    output_mode="int",
)

"""
### Processing text features

We may also want to add text features to our model. Usually, things like product
descriptions are free-form text, and we can hope that our model can learn to use
the information they contain to make better recommendations, especially in a
cold-start or long-tail scenario.

While the MovieLens dataset does not give us rich textual features, we can still
use movie titles. This may help us capture the fact that movies with very
similar titles are likely to belong to the same series.

The first transformation we need to apply to text is tokenization (splitting
into constituent words or word-pieces), followed by vocabulary learning,
followed by an embedding.

The
[`keras.layers.TextVectorization`](/api/layers/preprocessing_layers/text/text_vectorization/)
layer can do the first two steps for us.
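
To make the first two steps concrete, here is a minimal pure-Python sketch of
tokenization and vocabulary learning. It assumes lowercasing, punctuation
stripping, whitespace splitting, index 0 for padding and index 1 for unknown
tokens. The helper names and the three titles are invented for this example,
and the real layer may order its vocabulary differently.

```python
import collections
import string


def tokenize(text):
    # Lowercase, strip punctuation, and split on whitespace.
    cleaned = text.lower().translate(str.maketrans("", "", string.punctuation))
    return cleaned.split()


def build_token_vocabulary(texts, max_tokens):
    # Indices 0 (padding) and 1 (unknown) are reserved, so real tokens start
    # at 2. Tokens are ordered by decreasing count, then alphabetically.
    counts = collections.Counter(tok for text in texts for tok in tokenize(text))
    ordered = sorted(counts.items(), key=lambda item: (-item[1], item[0]))
    return {tok: i + 2 for i, (tok, _) in enumerate(ordered[: max_tokens - 2])}


def vectorize(text, vocabulary, sequence_length):
    # Truncate or right-pad with zeros to a fixed length.
    ids = [vocabulary.get(tok, 1) for tok in tokenize(text)][:sequence_length]
    return ids + [0] * (sequence_length - len(ids))


titles = ["Toy Story (1995)", "Toy Story 2 (1999)", "Heat (1995)"]
vocab = build_token_vocabulary(titles, max_tokens=100)
print(vocab)
print(vectorize("Toy Story 3 (2010)", vocab, sequence_length=6))
```

Tokens that were never seen while building the vocabulary all map to the same
"unknown" index, and short titles are padded so that every title has the same
length.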
"""

title_vectorizer = keras.layers.TextVectorization(
    max_tokens=10_000, output_sequence_length=16, dtype="int32"
)
title_vectorizer.adapt(movies.map(lambda x: x["movie_title"]))

"""
Let's try it out:
"""

for data in movies.take(1).as_numpy_iterator():
    print(title_vectorizer(data["movie_title"]))

"""
Each title is translated into a sequence of tokens, one for each piece we've
tokenized.

We can check the learned vocabulary to verify that the layer is using the
correct tokenization:
"""

print(title_vectorizer.get_vocabulary()[40:50])

"""
This looks correct: the layer is tokenizing titles into individual words. Later,
we will see how to embed this tokenized text. For now, we turn this vectorizer
into a Keras `FeatureSpace` feature.
"""

title_feature = keras.utils.FeatureSpace.feature(
    preprocessor=title_vectorizer, dtype="string", output_mode="float"
)
TITLE_TOKEN_COUNT = title_vectorizer.vocabulary_size()

"""
### Putting the FeatureSpace features together

We're now ready to assemble the features with preprocessors in a `FeatureSpace`
object. We then use `adapt` to go through the dataset and learn what needs
to be learned, such as the vocabulary for categorical features or the
bucket boundaries for discretized features.
"""

feature_space = keras.utils.FeatureSpace(
    features={
        # Numerical features to discretize.
        "raw_user_age": user_age_feature,
        # Categorical features encoded as integers.
        "user_gender": user_gender_feature,
        "user_occupation_label": user_occupation_feature,
        # Labels are ratings between 0 and 1.
        "user_rating": user_rating_feature,
        "movie_title": title_feature,
    },
    crosses=[user_gender_age_cross],
    output_mode="dict",
)

feature_space.adapt(ratings)
GENDERS_COUNT = feature_space.preprocessors["user_gender"].vocabulary_size()
OCCUPATIONS_COUNT = feature_space.preprocessors[
    "user_occupation_label"
].vocabulary_size()

"""
## Pre-building the candidate set

Our model is going to be based on a `Retrieval` layer, which can provide a set
of best candidates among the full set of candidates. To do this, the retrieval
layer needs to know all the candidates and their features. In this section, we
assemble the full set of movies with the associated features.

### Extract raw candidate features

First, we gather all the raw features from the dataset in lists, that is, the
titles of the movies and the genres. Note that one or more genres are
associated with each movie, and the number of genres varies per movie.
"""

movie_titles = [""] * (MOVIES_COUNT + 1)
movie_genres = [[]] * (MOVIES_COUNT + 1)
for x in movies.as_numpy_iterator():
    movie_id = int(x["movie_id"])
    movie_titles[movie_id] = x["movie_title"]
    movie_genres[movie_id] = x["movie_genres"].tolist()

"""
### Preprocess candidate features

Genres are already in the form of category numbers starting at zero. However, we
do need to figure out two things:
- The maximum number of genres a single movie can have; this will determine the
  dimension for this feature.
- The maximum value for the genre, which will give us the total number of genres
  and determine the size of our embedding table for genres.
"""

MAX_GENRES_PER_MOVIE = 0
max_genre_id = 0
for one_movie_genres in movie_genres:
    MAX_GENRES_PER_MOVIE = max(MAX_GENRES_PER_MOVIE, len(one_movie_genres))
    if one_movie_genres:
        max_genre_id = max(max_genre_id, max(one_movie_genres))

GENRES_COUNT = max_genre_id + 1

"""
Now we need to pad genres with an Out Of Vocabulary value to be able to
represent genres as a fixed size vector. We'll pad with zeros for simplicity, so
we're adding one to the genres to not conflict with genre zero, which is a valid
genre.
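
On a toy input, the shift-then-pad scheme looks like this. The `pad_genres`
helper and the genre lists are hypothetical and only serve to show the effect
on a few rows.

```python
def pad_genres(genres, max_genres, padding_value=0):
    # Shift every genre ID by one so that 0 is free to mean "padding",
    # then right-pad the list to a fixed length.
    shifted = [genre + 1 for genre in genres]
    return shifted + [padding_value] * (max_genres - len(shifted))


raw_genres = [[0, 4], [7], [1, 2, 3]]
max_genres = max(len(genres) for genres in raw_genres)
padded = [pad_genres(genres, max_genres) for genres in raw_genres]
print(padded)
```

After the shift, genre zero becomes 1 and can no longer be confused with the
padding value, and every row has the same length.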
"""

movie_genres = [
    [g + 1 for g in genres] + [0] * (MAX_GENRES_PER_MOVIE - len(genres))
    for genres in movie_genres
]

"""
Then, we vectorize all the movie titles.
"""

movie_titles_vectors = title_vectorizer(movie_titles)

"""
### Convert candidate set to native tensors

We're now ready to combine these in a dataset. The last step is to make sure
everything is a native tensor that can be consumed by the retrieval layer.
As a reminder, movie ID zero does not exist.
"""

MOVIES_DATASET = {
    "movie_id": keras.ops.arange(0, MOVIES_COUNT + 1, dtype="int32"),
    "movie_title_vector": movie_titles_vectors,
    "movie_genres": keras.ops.convert_to_tensor(movie_genres, dtype="int32"),
}

"""
## Preparing the data

We can now define our preprocessing function. Most features will be handled
by the `FeatureSpace`. User IDs and Movie IDs need to be extracted. Movie genres
need to be padded. Then everything is packaged as a tuple with a dict of input
features and a float for the rating, which is used as a label.
"""


def preprocess_rating(x):
    features = feature_space(
        {
            "raw_user_age": x["raw_user_age"],
            "user_gender": x["user_gender"],
            "user_occupation_label": x["user_occupation_label"],
            "user_rating": x["user_rating"],
            "movie_title": x["movie_title"],
        }
    )
    features = {k: tf.squeeze(v, axis=0) for k, v in features.items()}
    movie_genres = x["movie_genres"]

    return (
        {
            # User inputs are user ID and user features
            "user_id": int(x["user_id"]),
            "raw_user_age": features["raw_user_age"],
            "user_gender": features["user_gender"],
            "user_occupation_label": features["user_occupation_label"],
            "user_gender_X_raw_user_age": tf.squeeze(
                features["user_gender_X_raw_user_age"], axis=-1
            ),
            # Movie inputs are movie ID, vectorized title and genres
            "movie_id": int(x["movie_id"]),
            "movie_title_vector": features["movie_title"],
            "movie_genres": tf.pad(
                movie_genres + 1,
                [[0, MAX_GENRES_PER_MOVIE - tf.shape(movie_genres)[0]]],
            ),
        },
        # Label is user rating between 0 and 1
        features["user_rating"],
    )


"""
We shuffle and then split the data into a training set and a testing set.
"""

shuffled_ratings = ratings.map(preprocess_rating).shuffle(
    100_000, seed=42, reshuffle_each_iteration=False
)

train_ratings = shuffled_ratings.take(80_000).batch(1000).cache()
test_ratings = shuffled_ratings.skip(80_000).take(20_000).batch(1000).cache()

"""
## Model definition

### Query model

The query model is first tasked with converting user features to embeddings. The
embeddings are then concatenated into a single vector.

Defining deeper models will require us to stack more layers on top of this first
set of embeddings. A progressively narrower stack of layers, separated by an
activation function, is a common pattern:

```
    +----------------------+
    |       64 x 32        |
    +----------------------+
               | relu
  +--------------------------+
  |         128 x 64         |
  +--------------------------+
               | relu
+------------------------------+
|          ... x 128           |
+------------------------------+
```

Since the expressive power of deep linear models is no greater than that of
shallow linear models, we use ReLU activations for all but the last hidden
layer. The final hidden layer does not use any activation function: using an
activation function would limit the output space of the final embeddings and
might negatively impact the performance of the model. For instance, if ReLUs are
used in the projection layer, all components in the output embedding would be
non-negative.

We're going to try this here. To make experimentation with different depths
easy, let's define a model whose depth (and width) is defined by a constructor
parameter. The `layer_sizes` parameter gives us the depth and width of the
model. We can vary it to experiment with shallower or deeper models.
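
The claim that stacked linear layers are no more expressive than a single
linear layer can be checked numerically. The sketch below uses small
hand-written weight matrices (chosen arbitrarily for this illustration, with
biases omitted) and plain-Python helpers to compare a two-layer linear stack,
its single-matrix equivalent, and the same stack with a ReLU in between.

```python
def matvec(matrix, vector):
    # Matrix-vector product for a matrix given as a list of rows.
    return [sum(w * x for w, x in zip(row, vector)) for row in matrix]


def matmul(a, b):
    # Matrix product a @ b for matrices given as lists of rows.
    columns = list(zip(*b))
    return [[sum(x * y for x, y in zip(row, col)) for col in columns] for row in a]


def relu(vector):
    return [max(0.0, x) for x in vector]


W1 = [[1.0, -2.0], [0.5, 1.0], [-1.0, 3.0]]  # first layer: 2 -> 3
W2 = [[1.0, 0.0, -1.0], [2.0, 1.0, 0.0]]  # second layer: 3 -> 2
x = [1.0, 2.0]

stacked_linear = matvec(W2, matvec(W1, x))
single_linear = matvec(matmul(W2, W1), x)
with_relu = matvec(W2, relu(matvec(W1, x)))

print(stacked_linear, single_linear, with_relu)
```

The two linear variants produce the same output, so the extra layer adds
nothing without a non-linearity. With the ReLU in between, the result differs,
and because the last layer has no activation, the output can still contain
negative components.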
"""


class QueryModel(keras.Model):
    """Model for encoding user queries."""

    def __init__(self, layer_sizes, embedding_dimension=32):
        """Construct a model for encoding user queries.

        Args:
            layer_sizes: A list of integers where the i-th entry represents the
                number of units the i-th layer contains.
            embedding_dimension: Output dimension for all embedding tables.
        """
        super().__init__()

        # We first generate embeddings.
        self.user_embedding = keras.layers.Embedding(
            # +1 for user ID zero, which does not exist
            USERS_COUNT + 1,
            embedding_dimension,
        )
        self.gender_embedding = keras.layers.Embedding(
            GENDERS_COUNT, embedding_dimension
        )
        self.age_embedding = keras.layers.Embedding(AGE_BINS_COUNT, embedding_dimension)
        self.gender_x_age_embedding = keras.layers.Embedding(
            USER_GENDER_CROSS_COUNT, embedding_dimension
        )
        self.occupation_embedding = keras.layers.Embedding(
            OCCUPATIONS_COUNT, embedding_dimension
        )

        # Then construct the layers.
        self.dense_layers = keras.Sequential()

        # Use the ReLU activation for all but the last layer.
        for layer_size in layer_sizes[:-1]:
            self.dense_layers.add(keras.layers.Dense(layer_size, activation="relu"))

        # No activation for the last layer.
        self.dense_layers.add(keras.layers.Dense(layer_sizes[-1]))

    def call(self, inputs):
        # Take the inputs, pass each through its embedding layer, concatenate.
        feature_embedding = keras.ops.concatenate(
            [
                self.user_embedding(inputs["user_id"]),
                self.gender_embedding(inputs["user_gender"]),
                self.age_embedding(inputs["raw_user_age"]),
                self.gender_x_age_embedding(inputs["user_gender_X_raw_user_age"]),
                self.occupation_embedding(inputs["user_occupation_label"]),
            ],
            axis=1,
        )
        return self.dense_layers(feature_embedding)


"""
### Candidate model

We can adopt the same approach for the candidate model. Again, we start by
converting movie features to embeddings, concatenate them, and then pass the
result through a stack of hidden layers:
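
For the title and genre features, the candidate model averages the embeddings
of all tokens while ignoring padding. A minimal plain-Python sketch of that
masked average follows. The `masked_mean` helper and the tiny embedding table
are made up for this illustration and are not how Keras implements masking
internally.

```python
def masked_mean(token_ids, embedding_table, padding_id=0):
    # Look up the embedding of every non-padding token and average them.
    vectors = [embedding_table[t] for t in token_ids if t != padding_id]
    if not vectors:
        # Only padding: fall back to a zero vector of the right size.
        return [0.0] * len(embedding_table[0])
    return [sum(column) / len(vectors) for column in zip(*vectors)]


# Row 0 belongs to the padding ID and is never used in the average.
table = [[0.0, 0.0], [1.0, 3.0], [3.0, 5.0], [10.0, 10.0]]
pooled = masked_mean([1, 2, 0, 0], table)
print(pooled)
```

Without the mask, the two padding positions would drag the average toward the
embedding of the padding ID, so titles of different lengths would not be
comparable.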
"""


class CandidateModel(keras.Model):
    """Model for encoding candidates (movies)."""

    def __init__(self, layer_sizes, embedding_dimension=32):
        """Construct a model for encoding candidates (movies).

        Args:
            layer_sizes: A list of integers where the i-th entry represents the
                number of units the i-th layer contains.
            embedding_dimension: Output dimension for all embedding tables.
        """
        super().__init__()

        # We first generate embeddings.
        self.movie_embedding = keras.layers.Embedding(
            # +1 for movie ID zero, which does not exist
            MOVIES_COUNT + 1,
            embedding_dimension,
        )
        # Take all the title tokens for the title of the movie, embed each
        # token, and then take the mean of all token embeddings.
        self.movie_title_embedding = keras.Sequential(
            [
                keras.layers.Embedding(
                    # +1 for OOV token, which is used for padding
                    TITLE_TOKEN_COUNT + 1,
                    embedding_dimension,
                    mask_zero=True,
                ),
                keras.layers.GlobalAveragePooling1D(),
            ]
        )
        # Take all the genres for the movie, embed each genre, and then take the
        # mean of all genre embeddings.
        self.movie_genres_embedding = keras.Sequential(
            [
                keras.layers.Embedding(
                    # +1 for OOV genre, which is used for padding
                    GENRES_COUNT + 1,
                    embedding_dimension,
                    mask_zero=True,
                ),
                keras.layers.GlobalAveragePooling1D(),
            ]
        )

        # Then construct the layers.
        self.dense_layers = keras.Sequential()

        # Use the ReLU activation for all but the last layer.
        for layer_size in layer_sizes[:-1]:
            self.dense_layers.add(keras.layers.Dense(layer_size, activation="relu"))

        # No activation for the last layer.
        self.dense_layers.add(keras.layers.Dense(layer_sizes[-1]))

    def call(self, inputs):
        movie_id = inputs["movie_id"]
        movie_title_vector = inputs["movie_title_vector"]
        movie_genres = inputs["movie_genres"]
        feature_embedding = keras.ops.concatenate(
            [
                self.movie_embedding(movie_id),
                self.movie_title_embedding(movie_title_vector),
                self.movie_genres_embedding(movie_genres),
            ],
            axis=1,
        )
        return self.dense_layers(feature_embedding)


"""
### Combined model

With both `QueryModel` and `CandidateModel` defined, we can put together a
combined model and implement our loss and metrics logic. To make things simple,
we'll enforce that the model structure is the same across the query and
candidate models.
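
The loss logic in the combined model scores each (user, movie) pair with the
dot product of the two embeddings and compares that score to the rescaled
rating with a mean squared error. The sketch below reproduces that computation
in plain Python on made-up embeddings and labels. The helper names are
hypothetical and sample weights are left out.

```python
def dot(u, v):
    return sum(a * b for a, b in zip(u, v))


def mean_squared_error(labels, scores):
    return sum((y - s) ** 2 for y, s in zip(labels, scores)) / len(labels)


# One row per example in the batch; the values are arbitrary.
query_embeddings = [[1.0, 0.0], [0.5, 0.5]]
candidate_embeddings = [[0.5, 2.0], [1.0, 1.0]]
labels = [0.75, 1.0]  # ratings rescaled to the range 0..1

scores = [dot(q, c) for q, c in zip(query_embeddings, candidate_embeddings)]
loss = mean_squared_error(labels, scores)
print(scores, loss)
```

Training therefore pushes the dot product of a user embedding and a movie
embedding toward the rating that user gave to that movie.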
"""


class RetrievalModel(keras.Model):
    """Combined model."""

    def __init__(
        self,
        layer_sizes=(32,),
        embedding_dimension=32,
        retrieval_k=100,
    ):
        """Construct a combined model.

        Args:
            layer_sizes: A list of integers where the i-th entry represents the
                number of units the i-th layer contains.
            embedding_dimension: Output dimension for all embedding tables.
            retrieval_k: How many candidate movies to retrieve.
        """
        super().__init__()
        self.query_model = QueryModel(layer_sizes, embedding_dimension)
        self.candidate_model = CandidateModel(layer_sizes, embedding_dimension)
        self.retrieval = keras_rs.layers.BruteForceRetrieval(
            k=retrieval_k, return_scores=False
        )
        self.update_candidates()  # Provide an initial set of candidates
        self.loss_fn = keras.losses.MeanSquaredError()
        self.top_k_metric = keras.metrics.SparseTopKCategoricalAccuracy(
            k=retrieval_k, from_sorted_ids=True
        )

    def update_candidates(self):
        self.retrieval.update_candidates(
            self.candidate_model.predict(MOVIES_DATASET, verbose=0)
        )

    def call(self, inputs, training=False):
        query_embeddings = self.query_model(
            {
                "user_id": inputs["user_id"],
                "raw_user_age": inputs["raw_user_age"],
                "user_gender": inputs["user_gender"],
                "user_occupation_label": inputs["user_occupation_label"],
                "user_gender_X_raw_user_age": inputs["user_gender_X_raw_user_age"],
            }
        )
        candidate_embeddings = self.candidate_model(
            {
                "movie_id": inputs["movie_id"],
                "movie_title_vector": inputs["movie_title_vector"],
                "movie_genres": inputs["movie_genres"],
            }
        )

        result = {
            "query_embeddings": query_embeddings,
            "candidate_embeddings": candidate_embeddings,
        }
        if not training:
            # No need to spend time extracting top predicted movies during
            # training, they are not used.
            result["predictions"] = self.retrieval(query_embeddings)
        return result

    def evaluate(
        self,
        x=None,
        y=None,
        batch_size=None,
        verbose="auto",
        sample_weight=None,
        steps=None,
        callbacks=None,
        return_dict=False,
        **kwargs,
    ):
        """Overridden to update the candidate set.

        Before evaluating the model, we need to update our retrieval layer by
        re-computing the values predicted by the candidate model for all the
        candidates.
        """
        self.update_candidates()
        return super().evaluate(
            x,
            y,
            batch_size=batch_size,
            verbose=verbose,
            sample_weight=sample_weight,
            steps=steps,
            callbacks=callbacks,
            return_dict=return_dict,
            **kwargs,
        )

    def compute_loss(self, x, y, y_pred, sample_weight, training=True):
        query_embeddings = y_pred["query_embeddings"]
        candidate_embeddings = y_pred["candidate_embeddings"]

        labels = keras.ops.expand_dims(y, -1)
        # Compute the affinity score by multiplying the two embeddings.
        scores = keras.ops.sum(
            keras.ops.multiply(query_embeddings, candidate_embeddings),
            axis=1,
            keepdims=True,
        )
        return self.loss_fn(labels, scores, sample_weight)

    def compute_metrics(self, x, y, y_pred, sample_weight=None):
        if "predictions" in y_pred:
            # We are evaluating or predicting. Update `top_k_metric`.
            movie_ids = x["movie_id"]
            predictions = y_pred["predictions"]
            # For `top_k_metric`, which is a `SparseTopKCategoricalAccuracy`, we
            # only take top rated movies, and we put a weight of 0 for the rest.
            rating_weight = keras.ops.cast(keras.ops.greater(y, 0.9), "float32")
            sample_weight = (
                rating_weight
                if sample_weight is None
                else keras.ops.multiply(rating_weight, sample_weight)
            )
            self.top_k_metric.update_state(
                movie_ids, predictions, sample_weight=sample_weight
            )
            return self.get_metrics_result()
        else:
            # We are training. `top_k_metric` is not updated and is zero, so
            # don't report it.
            result = self.get_metrics_result()
            result.pop(self.top_k_metric.name)
            return result


"""
## Training the model

### Shallow model

We're ready to try out our first, shallow, model!
"""

NUM_EPOCHS = 30

one_layer_model = RetrievalModel((32,))
one_layer_model.compile(optimizer=keras.optimizers.Adagrad(0.05))

one_layer_history = one_layer_model.fit(
    train_ratings,
    validation_data=test_ratings,
    validation_freq=5,
    epochs=NUM_EPOCHS,
)

"""
This gives us a top-100 accuracy of around 0.30. We can use this as a reference
point for evaluating deeper models.

### Deeper model

What about a deeper model with two layers?
"""

two_layer_model = RetrievalModel((64, 32))
two_layer_model.compile(optimizer=keras.optimizers.Adagrad(0.05))
two_layer_history = two_layer_model.fit(
    train_ratings,
    validation_data=test_ratings,
    validation_freq=5,
    epochs=NUM_EPOCHS,
)

"""
While the deeper model seems to learn a bit better than the shallow model at
first, the difference becomes minimal towards the end of the training. We can
plot the validation accuracy curves to illustrate this:
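
For reference, the top-k accuracy reported during validation can be sketched
in plain Python as the weighted fraction of examples whose true movie appears
among the retrieved ones. The `top_k_accuracy` helper, the movie IDs and the
ratings below are invented for this illustration. The weighting mirrors the
`compute_metrics` logic, which only counts examples whose rescaled rating is
above 0.9.

```python
def top_k_accuracy(true_ids, retrieved_ids, weights=None):
    # Weighted fraction of examples whose true ID is among the retrieved IDs.
    if weights is None:
        weights = [1.0] * len(true_ids)
    hits = sum(w for t, r, w in zip(true_ids, retrieved_ids, weights) if t in r)
    total = sum(weights)
    return hits / total if total else 0.0


true_movie_ids = [10, 20, 30, 40]
retrieved = [[10, 11, 12], [21, 22, 23], [31, 30, 32], [41, 42, 43]]
rescaled_ratings = [1.0, 0.5, 1.0, 1.0]

# Only top-rated examples count toward the metric.
weights = [1.0 if rating > 0.9 else 0.0 for rating in rescaled_ratings]
accuracy = top_k_accuracy(true_movie_ids, retrieved, weights)
print(accuracy)
```

Here three examples are top-rated and two of them have their movie among the
retrieved candidates, so the weighted accuracy is two thirds.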
"""

METRIC = "val_sparse_top_k_categorical_accuracy"
num_validation_runs = len(one_layer_history.history[METRIC])
epochs = [(x + 1) * 5 for x in range(num_validation_runs)]

plt.plot(epochs, one_layer_history.history[METRIC], label="1 layer")
plt.plot(epochs, two_layer_history.history[METRIC], label="2 layers")
plt.title("Accuracy vs epoch")
plt.xlabel("epoch")
plt.ylabel("Top-100 accuracy")
plt.legend()
plt.show()

"""
Deeper models are not necessarily better. The following model extends the depth
to three layers:
"""

three_layer_model = RetrievalModel((128, 64, 32))
three_layer_model.compile(optimizer=keras.optimizers.Adagrad(0.05))
three_layer_history = three_layer_model.fit(
    train_ratings,
    validation_data=test_ratings,
    validation_freq=5,
    epochs=NUM_EPOCHS,
)

"""
We don't really see an improvement over the shallow model:
"""

plt.plot(epochs, one_layer_history.history[METRIC], label="1 layer")
plt.plot(epochs, two_layer_history.history[METRIC], label="2 layers")
plt.plot(epochs, three_layer_history.history[METRIC], label="3 layers")
plt.title("Accuracy vs epoch")
plt.xlabel("epoch")
plt.ylabel("Top-100 accuracy")
plt.legend()
plt.show()

"""
This is a good illustration of the fact that deeper and larger models, while
capable of superior performance, often require very careful tuning. For example,
throughout this tutorial we used a single, fixed learning rate. Alternative
choices may give very different results and are worth exploring.

With appropriate tuning and sufficient data, the effort put into building larger
and deeper models is in many cases well worth it: larger models can lead to
substantial improvements in prediction accuracy.

## Next Steps

In this tutorial we expanded our retrieval model with dense layers and
activation functions. To see how to create a model that can perform not only
retrieval tasks but also rating tasks, take a look at the multitask tutorial.
"""