"""
Title: Recommending movies: ranking
Author: [Fabien Hertschuh](https://github.com/hertschuh/), [Abheesht Sharma](https://github.com/abheesht17/)
Date created: 2025/04/28
Last modified: 2025/04/28
Description: Rank movies using a two tower model.
Accelerator: GPU
"""

"""
## Introduction

Recommender systems are often composed of two stages:

1. The retrieval stage is responsible for selecting an initial set of hundreds
of candidates from all possible candidates. The main objective of this model
is to efficiently weed out all candidates that the user is not interested in.
Because the retrieval model may be dealing with millions of candidates, it
has to be computationally efficient.
2. The ranking stage takes the outputs of the retrieval model and fine-tunes
them to select the best possible handful of recommendations. Its task is to
narrow down the set of items the user may be interested in to a shortlist of
likely candidates.

In this tutorial, we're going to focus on the second stage, ranking. If you are
interested in the retrieval stage, have a look at our
[retrieval](/keras_rs/examples/basic_retrieval/)
tutorial.

In this tutorial, we're going to:

1. Get our data and split it into a training and test set.
2. Implement a ranking model.
3. Fit and evaluate it.
4. Test running predictions with the model.

Let's begin by choosing JAX as the backend we want to run on, and import all
the necessary libraries.
"""

import os

os.environ["KERAS_BACKEND"] = "jax"  # `"tensorflow"`/`"torch"`

import keras
import tensorflow as tf  # Needed for the dataset
import tensorflow_datasets as tfds

"""
## Preparing the dataset

We're going to use the same data as the
[retrieval](/keras_rs/examples/basic_retrieval/)
tutorial.
The ratings are the objectives we are trying to predict.
"""

# Ratings data.
ratings = tfds.load("movielens/100k-ratings", split="train")
# Features of all the available movies.
movies = tfds.load("movielens/100k-movies", split="train")

"""
In the MovieLens dataset, user IDs are integers (represented as strings)
starting at 1 and with no gap. Normally, you would need to create a lookup table
to map user IDs to integers from 0 to N-1. But as a simplification, we'll use the
user ID directly as an index in our model, in particular to look up the user
embedding from the user embedding table. So we need to know the number of users.
"""

users_count = (
    ratings.map(lambda x: tf.strings.to_number(x["user_id"], out_type=tf.int32))
    .reduce(tf.constant(0, tf.int32), tf.maximum)
    .numpy()
)

"""
In the MovieLens dataset, movie IDs are integers (represented as strings)
starting at 1 and with no gap. Normally, you would need to create a lookup table
to map movie IDs to integers from 0 to N-1. But as a simplification, we'll use the
movie ID directly as an index in our model, in particular to look up the movie
embedding from the movie embedding table.
So we need to know the number of
movies.
"""

movies_count = movies.cardinality().numpy()

"""
The inputs to the model are the user IDs and movie IDs and the labels are the
ratings.
"""


def preprocess_rating(x):
    return (
        # Inputs are user IDs and movie IDs
        {
            "user_id": tf.strings.to_number(x["user_id"], out_type=tf.int32),
            "movie_id": tf.strings.to_number(x["movie_id"], out_type=tf.int32),
        },
        # Labels are ratings between 0 and 1.
        (x["user_rating"] - 1.0) / 4.0,
    )


"""
We'll split the data by putting 80% of the ratings in the train set, and 20% in
the test set.
"""

shuffled_ratings = ratings.map(preprocess_rating).shuffle(
    100_000, seed=42, reshuffle_each_iteration=False
)
train_ratings = shuffled_ratings.take(80_000).batch(1000).cache()
test_ratings = shuffled_ratings.skip(80_000).take(20_000).batch(1000).cache()

"""
## Implementing the Model

### Architecture

Ranking models do not face the same efficiency constraints as retrieval models
do, and so we have a little bit more freedom in our choice of architectures.

A model composed of multiple stacked dense layers is a relatively common
architecture for ranking tasks.
We can implement it as follows:
"""


class RankingModel(keras.Model):
    """Create the ranking model with the provided parameters.

    Args:
        num_users: Number of entries in the user embedding table.
        num_candidates: Number of entries in the candidate embedding table.
        embedding_dimension: Output dimension for user and movie embedding tables.
    """

    def __init__(
        self,
        num_users,
        num_candidates,
        embedding_dimension=32,
        **kwargs,
    ):
        super().__init__(**kwargs)
        # Embedding table for users.
        self.user_embedding = keras.layers.Embedding(num_users, embedding_dimension)
        # Embedding table for candidates.
        self.candidate_embedding = keras.layers.Embedding(
            num_candidates, embedding_dimension
        )
        # Predictions.
        self.ratings = keras.Sequential(
            [
                # Learn multiple dense layers.
                keras.layers.Dense(256, activation="relu"),
                keras.layers.Dense(64, activation="relu"),
                # Make rating predictions in the final layer.
                keras.layers.Dense(1),
            ]
        )

    def call(self, inputs):
        user_id, movie_id = inputs["user_id"], inputs["movie_id"]
        user_embeddings = self.user_embedding(user_id)
        candidate_embeddings = self.candidate_embedding(movie_id)
        return self.ratings(
            keras.ops.concatenate([user_embeddings, candidate_embeddings], axis=1)
        )


"""
Let's first instantiate the model. Note that we add `+ 1` to the number of users
and movies to account for the fact that id zero is not used for either (IDs
start at 1), but still takes a row in the embedding tables.
"""

model = RankingModel(users_count + 1, movies_count + 1)

"""
### Loss and metrics

The next component is the loss used to train our model. Keras has several losses
to make this easy. In this instance, we'll make use of the `MeanSquaredError`
loss in order to predict the ratings.
We'll also look at the
`RootMeanSquaredError` metric.
"""

model.compile(
    loss=keras.losses.MeanSquaredError(),
    metrics=[keras.metrics.RootMeanSquaredError()],
    optimizer=keras.optimizers.Adagrad(learning_rate=0.1),
)

"""
## Fitting and evaluating

After defining the model, we can use the standard Keras `model.fit()` to train
the model.
"""

model.fit(train_ratings, epochs=5)

"""
As the model trains, the loss is falling and the RMSE metric is improving.

Finally, we can evaluate our model on the test set. The lower the RMSE metric,
the more accurate our model is at predicting ratings.
"""

model.evaluate(test_ratings, return_dict=True)

"""
## Testing the ranking model

So far, we have only handled movies by id. Now is the time to create a mapping
keyed by movie IDs to be able to surface the titles.
"""

movie_id_to_movie_title = {
    int(x["movie_id"]): x["movie_title"] for x in movies.as_numpy_iterator()
}
movie_id_to_movie_title[0] = ""  # Because id 0 is not in the dataset.

"""
Now we can test the ranking model by computing predictions for a set of movies
and then rank these movies based on the predictions:
"""

user_id = 42
movie_ids = [204, 141, 131]
predictions = model.predict(
    {
        "user_id": keras.ops.array([user_id] * len(movie_ids)),
        "movie_id": keras.ops.array(movie_ids),
    }
)
predictions = keras.ops.convert_to_numpy(keras.ops.squeeze(predictions, axis=1))

for movie_id, prediction in zip(movie_ids, predictions):
    # Labels were scaled as (rating - 1) / 4, so invert that mapping to get
    # back to the original 1 to 5 rating range.
    print(f"{movie_id_to_movie_title[movie_id]}: {4.0 * prediction + 1.0:,.2f}")
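
"""
The loop above prints the movies in the order they were passed in. As a minimal
sketch of the final ranking step, assuming NumPy is available, we could sort the
candidates by predicted score before displaying them. The `scores` values below
are made-up numbers for illustration, not outputs of the model trained in this
tutorial, and `ranked_movie_ids` and `ranked_ratings` are hypothetical names.

```python
import numpy as np

# Hypothetical model outputs in [0, 1] for three candidate movies. These are
# made-up numbers for illustration, not results from a trained model.
movie_ids = np.array([204, 141, 131])
scores = np.array([0.71, 0.64, 0.83])

# Sort candidate indices from highest to lowest predicted score.
order = np.argsort(-scores)
ranked_movie_ids = movie_ids[order]

# Undo the (rating - 1) / 4 label scaling to get back to the 1 to 5 range.
ranked_ratings = 4.0 * scores[order] + 1.0

for movie_id, rating in zip(ranked_movie_ids, ranked_ratings):
    print(f"movie {movie_id}: {rating:.2f}")
```

With real predictions, `scores` would be replaced by the squeezed output of
`model.predict()`, and the movie IDs could be mapped to titles with
`movie_id_to_movie_title`.
"""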