Path: examples/keras_rs/basic_retrieval.py
"""1Title: Recommending movies: retrieval2Author: [Fabien Hertschuh](https://github.com/hertschuh/), [Abheesht Sharma](https://github.com/abheesht17/)3Date created: 2025/04/284Last modified: 2025/04/285Description: Retrieve movies using a two tower model.6Accelerator: GPU7"""89"""10## Introduction1112Recommender systems are often composed of two stages:13141. The retrieval stage is responsible for selecting an initial set of hundreds15of candidates from all possible candidates. The main objective of this model16is to efficiently weed out all candidates that the user is not interested in.17Because the retrieval model may be dealing with millions of candidates, it18has to be computationally efficient.192. The ranking stage takes the outputs of the retrieval model and fine-tunes20them to select the best possible handful of recommendations. Its task is to21narrow down the set of items the user may be interested in to a shortlist of22likely candidates.2324In this tutorial, we're going to focus on the first stage, retrieval. If you are25interested in the ranking stage, have a look at our26[ranking](/keras_rs/examples/basic_ranking/) tutorial.2728Retrieval models are often composed of two sub-models:29301. A query tower computing the query representation (normally a31fixed-dimensionality embedding vector) using query features.322. A candidate tower computing the candidate representation (an equally-sized33vector) using the candidate features. The outputs of the two models are then34multiplied together to give a query-candidate affinity score, with higher35scores expressing a better match between the candidate and the query.3637In this tutorial, we're going to build and train such a two-tower model using38the Movielens dataset.3940We're going to:41421. Get our data and split it into a training and test set.432. Implement a retrieval model.443. Fit and evaluate it.454. Test running predictions with the model.4647### The dataset4849The Movielens dataset is a classic dataset from the50[GroupLens](https://grouplens.org/datasets/movielens/) research group at the51University of Minnesota. It contains a set of ratings given to movies by a set52of users, and is a standard for recommender systems research.5354The data can be treated in two ways:55561. It can be interpreted as expressesing which movies the users watched (and57rated), and which they did not. This is a form of implicit feedback, where58users' watches tell us which things they prefer to see and which they'd59rather not see.602. It can also be seen as expressesing how much the users liked the movies they61did watch. This is a form of explicit feedback: given that a user watched a62movie, we can tell how much they liked by looking at the rating they have63given.6465In this tutorial, we are focusing on a retrieval system: a model that predicts a66set of movies from the catalogue that the user is likely to watch. For this, the67model will try to predict the rating users would give to all the movies in the68catalogue. 
Let's begin by choosing JAX as the backend we want to run on, and import all
the necessary libraries.
"""

"""shell
pip install -q keras-rs
"""

import os

os.environ["KERAS_BACKEND"] = "jax"  # `"tensorflow"`/`"torch"`

import keras
import tensorflow as tf  # Needed for the dataset
import tensorflow_datasets as tfds

import keras_rs

"""
## Preparing the dataset

Let's first have a look at the data.

We use the MovieLens dataset from
[Tensorflow Datasets](https://www.tensorflow.org/datasets). Loading
`movielens/100k-ratings` yields a `tf.data.Dataset` object containing the
ratings alongside user and movie data. Loading `movielens/100k-movies` yields
a `tf.data.Dataset` object containing only the movies data.

Note that since the MovieLens dataset does not have predefined splits, all
data are under the `train` split.
"""

# Ratings data with user and movie data.
ratings = tfds.load("movielens/100k-ratings", split="train")
# Features of all the available movies.
movies = tfds.load("movielens/100k-movies", split="train")

"""
The ratings dataset returns a dictionary of movie id, user id, the assigned
rating, timestamp, movie information, and user information:
"""

for data in ratings.take(1).as_numpy_iterator():
    print(str(data).replace(", '", ",\n '"))

"""
In the MovieLens dataset, user IDs are integers (represented as strings)
starting at 1 and with no gap. Normally, you would need to create a lookup
table to map user IDs to integers from 0 to N-1. But as a simplification,
we'll use the user id directly as an index in our model, in particular to look
up the user embedding from the user embedding table. So we need to know the
number of users.
"""

users_count = (
    ratings.map(lambda x: tf.strings.to_number(x["user_id"], out_type=tf.int32))
    .reduce(tf.constant(0, tf.int32), tf.maximum)
    .numpy()
)

"""
The movies dataset contains the movie id, movie title, and the genres it
belongs to. Note that the genres are encoded with integer labels.
"""

for data in movies.take(1).as_numpy_iterator():
    print(str(data).replace(", '", ",\n '"))

"""
In the MovieLens dataset, movie IDs are integers (represented as strings)
starting at 1 and with no gap. Normally, you would need to create a lookup
table to map movie IDs to integers from 0 to N-1. But as a simplification,
we'll use the movie id directly as an index in our model, in particular to
look up the movie embedding from the movie embedding table. So we need to know
the number of movies.
"""

movies_count = movies.cardinality().numpy()
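"""
As an aside, here is what the more general lookup-table approach mentioned
above could look like. This is a minimal sketch using Keras'
`keras.layers.StringLookup` layer, which maps arbitrary string IDs to dense
integer indices; the `sample_user_ids` vocabulary is made up for illustration
and is not part of this tutorial's pipeline.
"""

# Hypothetical vocabulary; in practice, you would gather the unique IDs from
# the dataset itself.
sample_user_ids = ["user_a", "user_b", "user_c"]
user_id_lookup = keras.layers.StringLookup(vocabulary=sample_user_ids)
# Known IDs map to 1..N; unknown IDs map to the out-of-vocabulary index 0.
print(user_id_lookup(["user_b", "user_unknown"]))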
"""
In this example, we're going to focus on the ratings data. Other tutorials
explore how to use the movie information data as well as the user information
to improve the model quality.

We keep only the `user_id`, `movie_id` and `rating` fields in the dataset. Our
input is the `user_id`. The labels are the `movie_id` alongside the `rating`
for the given movie and user.

The `rating` is a number between 1 and 5; we rescale it to be between 0 and 1.
"""


def preprocess_rating(x):
    return (
        # Input is the user IDs.
        tf.strings.to_number(x["user_id"], out_type=tf.int32),
        # Labels are movie IDs + ratings between 0 and 1.
        {
            "movie_id": tf.strings.to_number(x["movie_id"], out_type=tf.int32),
            "rating": (x["user_rating"] - 1.0) / 4.0,
        },
    )


"""
To fit and evaluate the model, we need to split the data into a training and
evaluation set. In a real recommender system, this would most likely be done
by time: the data up to time *T* would be used to predict interactions after
*T*.

In this simple example, however, let's use a random split, putting 80% of the
ratings in the train set, and 20% in the test set. A time-based alternative is
sketched after the next code block.
"""

shuffled_ratings = ratings.map(preprocess_rating).shuffle(
    100_000, seed=42, reshuffle_each_iteration=False
)
train_ratings = shuffled_ratings.take(80_000).batch(1000).cache()
test_ratings = shuffled_ratings.skip(80_000).take(20_000).batch(1000).cache()
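"""
For illustration, here is a minimal sketch of the time-based split mentioned
above, using the `timestamp` feature present in the ratings data. The 80th
percentile cutoff is an arbitrary choice that mirrors the proportions of the
random split; the rest of this tutorial keeps using the random split.
"""

import numpy as np

# Find a cutoff time T such that 80% of the interactions happen before T.
timestamps = np.array([int(x["timestamp"]) for x in ratings.as_numpy_iterator()])
cutoff = int(np.percentile(timestamps, 80))

# Interactions up to time T would train the model, later ones would test it.
train_by_time = ratings.filter(lambda x: x["timestamp"] <= cutoff).map(
    preprocess_rating
)
test_by_time = ratings.filter(lambda x: x["timestamp"] > cutoff).map(
    preprocess_rating
)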
"""
## Implementing the Model

Choosing the architecture of our model is a key part of modelling.

We are building a two-tower retrieval model, so we need to combine a query
tower for users and a candidate tower for movies.

The first step is to decide on the dimensionality of the query and candidate
representations. This is the `embedding_dimension` argument in our model
constructor. We'll test with a value of `32`. Higher values will correspond to
models that may be more accurate, but will also be slower to fit and more
prone to overfitting.

### Query and Candidate Towers

The second step is to define the model itself. In this simple example, the
query tower and candidate tower are simply embeddings with nothing else. We'll
use Keras' `Embedding` layer.

We can easily extend the towers to make them arbitrarily complex using
standard Keras components, as long as we return an `embedding_dimension`-wide
output at the end.

### Retrieval

The retrieval itself will be performed by the `BruteForceRetrieval` layer from
Keras Recommenders. This layer computes the affinity scores for the given
users and all the candidate movies, then returns the top K in order.

Note that during training, we don't actually need to perform any retrieval
since the only affinity scores we need are the ones for the users and movies
in the batch. As an optimization, we skip the retrieval entirely in the `call`
method.

### Loss

The next component is the loss used to train our model. In this case, we use a
mean squared error loss to measure the difference between the predicted movie
ratings and the actual ratings from users.

Note that we override `compute_loss` from the `keras.Model` class. This allows
us to compute the query-candidate affinity score, which is obtained by
multiplying the outputs of the two towers together. That affinity score can
then be passed to the loss function.
"""


class RetrievalModel(keras.Model):
    """Create the retrieval model with the provided parameters.

    Args:
        num_users: Number of entries in the user embedding table.
        num_candidates: Number of entries in the candidate embedding table.
        embedding_dimension: Output dimension for user and movie embedding
            tables.
    """

    def __init__(
        self,
        num_users,
        num_candidates,
        embedding_dimension=32,
        **kwargs,
    ):
        super().__init__(**kwargs)
        # Our query tower, simply an embedding table.
        self.user_embedding = keras.layers.Embedding(num_users, embedding_dimension)
        # Our candidate tower, simply an embedding table.
        self.candidate_embedding = keras.layers.Embedding(
            num_candidates, embedding_dimension
        )
        # The layer that performs the retrieval.
        self.retrieval = keras_rs.layers.BruteForceRetrieval(k=10, return_scores=False)
        self.loss_fn = keras.losses.MeanSquaredError()

    def build(self, input_shape):
        self.user_embedding.build(input_shape)
        self.candidate_embedding.build(input_shape)
        # In this case, the candidates are directly the movie embeddings.
        # We take a shortcut and directly reuse the variable.
        self.retrieval.candidate_embeddings = self.candidate_embedding.embeddings
        self.retrieval.build(input_shape)
        super().build(input_shape)

    def call(self, inputs, training=False):
        user_embeddings = self.user_embedding(inputs)
        result = {
            "user_embeddings": user_embeddings,
        }
        if not training:
            # Skip the retrieval of top movies during training as the
            # predictions are not used.
            result["predictions"] = self.retrieval(user_embeddings)
        return result

    def compute_loss(self, x, y, y_pred, sample_weight, training=True):
        candidate_id, rating = y["movie_id"], y["rating"]
        user_embeddings = y_pred["user_embeddings"]
        candidate_embeddings = self.candidate_embedding(candidate_id)

        labels = keras.ops.expand_dims(rating, -1)
        # Compute the affinity score by multiplying the two embeddings.
        scores = keras.ops.sum(
            keras.ops.multiply(user_embeddings, candidate_embeddings),
            axis=1,
            keepdims=True,
        )
        return self.loss_fn(labels, scores, sample_weight)


"""
## Fitting and evaluating

After defining the model, we can use the standard Keras `model.fit()` to train
and evaluate the model.

Let's first instantiate the model. Note that we add `+ 1` to the number of
users and movies to account for the fact that id zero is not used for either
(IDs start at 1), but still takes a row in the embedding tables.
"""

model = RetrievalModel(users_count + 1, movies_count + 1)
model.compile(optimizer=keras.optimizers.Adagrad(learning_rate=0.1))

"""
Then train the model. Evaluation takes a bit of time, so we only evaluate the
model every 5 epochs.
"""

history = model.fit(
    train_ratings, validation_data=test_ratings, validation_freq=5, epochs=50
)

"""
## Making predictions

Now that we have a model, we would like to be able to make predictions.

So far, we have only handled movies by id. Now is the time to create a mapping
keyed by movie IDs to be able to surface the titles.
"""

movie_id_to_movie_title = {
    int(x["movie_id"]): x["movie_title"] for x in movies.as_numpy_iterator()
}
movie_id_to_movie_title[0] = ""  # Because id 0 is not in the dataset.

"""
We then simply use the Keras `model.predict()` method. Under the hood, it
calls the `BruteForceRetrieval` layer to perform the actual retrieval.

Note that this model can retrieve movies already watched by the user. We could
easily add logic to remove them if that is desirable; we sketch one way to do
so after the prediction code below.
"""

user_id = 42
predictions = model.predict(keras.ops.convert_to_tensor([user_id]))
predictions = keras.ops.convert_to_numpy(predictions["predictions"])

print(f"Recommended movies for user {user_id}:")
for movie_id in predictions[0]:
    print(movie_id_to_movie_title[movie_id])
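"""
As a quick illustration of the note above, here is one minimal way to drop
already-watched movies from the recommendations. The `watched` set is built by
scanning the raw ratings data; in practice, you would retrieve more than 10
candidates before filtering so that enough recommendations survive, a
refinement we skip here for brevity.
"""

# Movie IDs the user has already rated, and hence watched.
watched = {
    int(x["movie_id"])
    for x in ratings.as_numpy_iterator()
    if int(x["user_id"]) == user_id
}

print(f"Recommended movies for user {user_id}, excluding already watched:")
for movie_id in predictions[0]:
    if int(movie_id) not in watched:
        print(movie_id_to_movie_title[int(movie_id)])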
"""
## Item-to-item recommendation

So far, we have built a user-movie model. However, for some applications (for
example, product detail pages) it's common to perform item-to-item (for
example, movie-to-movie or product-to-product) recommendations.

Training models like this would follow the same pattern as shown in this
tutorial, but with different training data. Here, we had a user and a movie
tower, and used (user, movie) pairs to train them. In an item-to-item model,
we would have two item towers (for the query and candidate item), and train
the model using (query item, candidate item) pairs. These could be constructed
from clicks on product detail pages.
"""
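"""
To make the data construction concrete, here is an illustrative sketch of how
(query item, candidate item) pairs could be derived from the MovieLens ratings
we already have, by pairing up movies watched by the same user. This is just
one of many possible constructions, not a prescribed recipe.
"""

from collections import defaultdict

# Collect the movies each user interacted with.
user_to_movies = defaultdict(list)
for x in ratings.as_numpy_iterator():
    user_to_movies[int(x["user_id"])].append(int(x["movie_id"]))

# Pair up consecutive movies per user as (query item, candidate item).
item_pairs = [
    (movie_ids[i], movie_ids[i + 1])
    for movie_ids in user_to_movies.values()
    for i in range(len(movie_ids) - 1)
]
print(f"Constructed {len(item_pairs)} item-to-item training pairs.")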