"""1Title: List-wise ranking2Author: [Abheesht Sharma](https://github.com/abheesht17/), [Fabien Hertschuh](https://github.com/hertschuh/)3Date created: 2025/04/284Last modified: 2025/04/285Description: Rank movies using pairwise losses instead of pointwise losses.6Accelerator: GPU7"""89"""10## Introduction1112In our13[basic ranking tutorial](/keras_rs/examples/basic_ranking/), we explored a model14that learned to predict ratings for specific user-movie combinations. This model15took (user, movie) pairs as input and was trained using mean-squared error to16precisely predict the rating a user might give to a movie.1718However, solely optimizing a model's accuracy in predicting individual movie19scores isn't always the most effective strategy for developing ranking systems.20For ranking models, pinpoint accuracy in predicting scores is less critical than21the model's capability to generate an ordered list of items that aligns with a22user's preferences. In essence, the relative order of items matters more than23the exact predicted values.2425Instead of focusing on the model's predictions for individual query-item pairs26(a pointwise approach), we can optimize the model based on its ability to27correctly order items. One common method for this is pairwise ranking. In this28approach, the model learns by comparing pairs of items (e.g., item A and item B)29and determining which one should be ranked higher for a given user or query. The30goal is to minimize the number of incorrectly ordered pairs.3132Let's begin by importing all the necessary libraries.33"""3435"""shell36pip install -q keras-rs37"""3839import os4041os.environ["KERAS_BACKEND"] = "jax" # `"tensorflow"`/`"torch"`4243import collections4445import keras46import numpy as np47import tensorflow as tf # Needed only for the dataset48import tensorflow_datasets as tfds49from keras import ops5051import keras_rs5253"""54Let's define some hyperparameters here.55"""5657# Data args58TRAIN_NUM_LIST_PER_USER = 5059TEST_NUM_LIST_PER_USER = 160NUM_EXAMPLES_PER_LIST = 56162# Model args63EMBEDDING_DIM = 326465# Train args66BATCH_SIZE = 102467EPOCHS = 568LEARNING_RATE = 0.16970"""71## Preparing the dataset7273We use the MovieLens dataset. The data loading and processing steps are similar74to previous tutorials, so, we will only discuss the differences here.75"""7677# Ratings data.78ratings = tfds.load("movielens/100k-ratings", split="train")79# Features of all the available movies.80movies = tfds.load("movielens/100k-movies", split="train")8182users_count = (83ratings.map(lambda x: tf.strings.to_number(x["user_id"], out_type=tf.int32))84.reduce(tf.constant(0, tf.int32), tf.maximum)85.numpy()86)87movies_count = movies.cardinality().numpy()888990def preprocess_rating(x):91return {92"user_id": tf.strings.to_number(x["user_id"], out_type=tf.int32),93"movie_id": tf.strings.to_number(x["movie_id"], out_type=tf.int32),94# Normalise ratings between 0 and 1.95"user_rating": (x["user_rating"] - 1.0) / 4.0,96}979899shuffled_ratings = ratings.map(preprocess_rating).shuffle(100100_000, seed=42, reshuffle_each_iteration=False101)102train_ratings = shuffled_ratings.take(70_000)103val_ratings = shuffled_ratings.skip(70_000).take(15_000)104test_ratings = shuffled_ratings.skip(85_000).take(15_000)105106"""107So far, we've replicated what we have in the basic ranking tutorial.108109However, this existing dataset is not directly applicable to list-wise110optimization. 
"""
Let's begin by importing all the necessary libraries.
"""

"""shell
pip install -q keras-rs
"""

import os

os.environ["KERAS_BACKEND"] = "jax"  # `"tensorflow"`/`"torch"`

import collections

import keras
import numpy as np
import tensorflow as tf  # Needed only for the dataset
import tensorflow_datasets as tfds
from keras import ops

import keras_rs

"""
Let's define some hyperparameters here.
"""

# Data args
TRAIN_NUM_LIST_PER_USER = 50
TEST_NUM_LIST_PER_USER = 1
NUM_EXAMPLES_PER_LIST = 5

# Model args
EMBEDDING_DIM = 32

# Train args
BATCH_SIZE = 1024
EPOCHS = 5
LEARNING_RATE = 0.1

"""
## Preparing the dataset

We use the MovieLens dataset. The data loading and processing steps are similar
to previous tutorials, so we will only discuss the differences here.
"""

# Ratings data.
ratings = tfds.load("movielens/100k-ratings", split="train")
# Features of all the available movies.
movies = tfds.load("movielens/100k-movies", split="train")

users_count = (
    ratings.map(lambda x: tf.strings.to_number(x["user_id"], out_type=tf.int32))
    .reduce(tf.constant(0, tf.int32), tf.maximum)
    .numpy()
)
movies_count = movies.cardinality().numpy()


def preprocess_rating(x):
    return {
        "user_id": tf.strings.to_number(x["user_id"], out_type=tf.int32),
        "movie_id": tf.strings.to_number(x["movie_id"], out_type=tf.int32),
        # Normalise ratings between 0 and 1.
        "user_rating": (x["user_rating"] - 1.0) / 4.0,
    }


shuffled_ratings = ratings.map(preprocess_rating).shuffle(
    100_000, seed=42, reshuffle_each_iteration=False
)
train_ratings = shuffled_ratings.take(70_000)
val_ratings = shuffled_ratings.skip(70_000).take(15_000)
test_ratings = shuffled_ratings.skip(85_000).take(15_000)

"""
So far, we've replicated what we have in the basic ranking tutorial.

However, this existing dataset is not directly applicable to list-wise
optimization. List-wise optimization requires, for each user, a list of movies
they have rated, allowing the model to learn from the relative orderings within
that list. The MovieLens 100K dataset, in its original form, provides individual
rating instances (one user, one movie, one rating per example), rather than
these aggregated user-specific lists.

To enable listwise optimization, we need to restructure the dataset. This
involves transforming it so that each example represents a single user ID
accompanied by a list of movies that user has rated. Within these lists, some
movies will naturally be ranked higher by the user (as evidenced by their
ratings) than others. The primary objective of our model will then be to learn
to predict item orderings that correspond to these observed user preferences.

Let's start by getting the entire list of movies and corresponding ratings for
every user. We remove `user_ids` corresponding to users who have rated fewer
than `NUM_EXAMPLES_PER_LIST` movies.
"""


def get_movie_sequence_per_user(ratings, min_examples_per_list):
    """Gets movieID sequences and ratings for every user."""
    sequences = collections.defaultdict(list)

    for sample in ratings:
        user_id = sample["user_id"]
        movie_id = sample["movie_id"]
        user_rating = sample["user_rating"]

        sequences[int(user_id.numpy())].append(
            {
                "movie_id": int(movie_id.numpy()),
                "user_rating": float(user_rating.numpy()),
            }
        )

    # Remove lists with < `min_examples_per_list` number of elements.
    sequences = {
        user_id: sequence
        for user_id, sequence in sequences.items()
        if len(sequence) >= min_examples_per_list
    }

    return sequences


"""
We now sample 50 lists for each user for the training data. For each list, we
randomly sample 5 movies from the movies the user rated.
"""


def sample_sublist_from_list(
    lst,
    num_examples_per_list,
):
    """Randomly selects `num_examples_per_list` elements from a list."""

    indices = np.random.choice(
        range(len(lst)),
        size=num_examples_per_list,
        replace=False,
    )

    samples = [lst[i] for i in indices]
    return samples


def get_examples(
    sequences,
    num_list_per_user,
    num_examples_per_list,
):
    inputs = {
        "user_id": [],
        "movie_id": [],
    }
    labels = []
    for user_id, user_list in sequences.items():
        for _ in range(num_list_per_user):
            sampled_list = sample_sublist_from_list(
                user_list,
                num_examples_per_list,
            )

            inputs["user_id"].append(user_id)
            inputs["movie_id"].append(
                tf.convert_to_tensor([f["movie_id"] for f in sampled_list])
            )
            labels.append(
                tf.convert_to_tensor([f["user_rating"] for f in sampled_list])
            )

    return (
        {"user_id": inputs["user_id"], "movie_id": inputs["movie_id"]},
        labels,
    )


train_sequences = get_movie_sequence_per_user(
    ratings=train_ratings, min_examples_per_list=NUM_EXAMPLES_PER_LIST
)
train_examples = get_examples(
    train_sequences,
    num_list_per_user=TRAIN_NUM_LIST_PER_USER,
    num_examples_per_list=NUM_EXAMPLES_PER_LIST,
)
train_ds = tf.data.Dataset.from_tensor_slices(train_examples)

val_sequences = get_movie_sequence_per_user(
    ratings=val_ratings, min_examples_per_list=5
)
val_examples = get_examples(
    val_sequences,
    num_list_per_user=TEST_NUM_LIST_PER_USER,
    num_examples_per_list=NUM_EXAMPLES_PER_LIST,
)
val_ds = tf.data.Dataset.from_tensor_slices(val_examples)

test_sequences = get_movie_sequence_per_user(
    ratings=test_ratings, min_examples_per_list=5
)
test_examples = get_examples(
    test_sequences,
    num_list_per_user=TEST_NUM_LIST_PER_USER,
    num_examples_per_list=NUM_EXAMPLES_PER_LIST,
)
test_ds = tf.data.Dataset.from_tensor_slices(test_examples)

"""
Batch up the dataset, and cache it.
"""

train_ds = train_ds.batch(BATCH_SIZE).cache()
val_ds = val_ds.batch(BATCH_SIZE).cache()
test_ds = test_ds.batch(BATCH_SIZE).cache()
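"""
As a quick sanity check on the list-wise structure (an extra peek, not needed
for training), let's look at the shapes in one batch: per example, we expect a
scalar user ID, a list of `NUM_EXAMPLES_PER_LIST` movie IDs, and a matching
list of ratings.
"""

for inputs_batch, labels_batch in train_ds.take(1):
    print(inputs_batch["user_id"].shape)  # (batch_size,)
    print(inputs_batch["movie_id"].shape)  # (batch_size, NUM_EXAMPLES_PER_LIST)
    print(labels_batch.shape)  # (batch_size, NUM_EXAMPLES_PER_LIST)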
"""
## Building the model

We build a typical two-tower ranking model, similar to the
[basic ranking tutorial](/keras_rs/examples/basic_ranking/).
We have separate embedding layers for the user ID and movie IDs. After obtaining
these embeddings, we concatenate them and pass them through a network of dense
layers.

The only point of difference is that for movie IDs, we take a list of IDs
rather than just one movie ID. So, when we concatenate the user ID embedding
with the movie ID embeddings, we "repeat" the user ID embedding
`NUM_EXAMPLES_PER_LIST` times so as to get the same shape as the movie ID
embeddings.
"""
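"""
Here is a shape-only sketch of that "repeat" trick with dummy tensors, before
we use it inside the model below.
"""

dummy_user_embeddings = ops.ones((2, EMBEDDING_DIM))  # (batch, dim)
dummy_movie_embeddings = ops.ones(
    (2, NUM_EXAMPLES_PER_LIST, EMBEDDING_DIM)
)  # (batch, list_length, dim)
repeated_user_embeddings = ops.repeat(
    ops.expand_dims(dummy_user_embeddings, axis=1),
    repeats=NUM_EXAMPLES_PER_LIST,
    axis=1,
)  # (batch, list_length, dim)
print(
    ops.shape(
        ops.concatenate([repeated_user_embeddings, dummy_movie_embeddings], axis=-1)
    )
)  # (2, 5, 64), i.e. (batch, list_length, 2 * dim)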
class RankingModel(keras.Model):
    """Create the ranking model with the provided parameters.

    Args:
        num_users: Number of entries in the user embedding table.
        num_candidates: Number of entries in the candidate embedding table.
        embedding_dimension: Output dimension for user and movie embedding tables.
    """

    def __init__(
        self,
        num_users,
        num_candidates,
        embedding_dimension=32,
        **kwargs,
    ):
        super().__init__(**kwargs)
        # Embedding table for users.
        self.user_embedding = keras.layers.Embedding(num_users, embedding_dimension)
        # Embedding table for candidates.
        self.candidate_embedding = keras.layers.Embedding(
            num_candidates, embedding_dimension
        )
        # Predictions.
        self.ratings = keras.Sequential(
            [
                # Learn multiple dense layers.
                keras.layers.Dense(256, activation="relu"),
                keras.layers.Dense(64, activation="relu"),
                # Make rating predictions in the final layer.
                keras.layers.Dense(1),
            ]
        )

    def build(self, input_shape):
        self.user_embedding.build(input_shape["user_id"])
        self.candidate_embedding.build(input_shape["movie_id"])

        output_shape = self.candidate_embedding.compute_output_shape(
            input_shape["movie_id"]
        )

        self.ratings.build(list(output_shape[:-1]) + [2 * output_shape[-1]])

    def call(self, inputs):
        user_id, movie_id = inputs["user_id"], inputs["movie_id"]
        user_embeddings = self.user_embedding(user_id)
        candidate_embeddings = self.candidate_embedding(movie_id)

        list_length = ops.shape(movie_id)[-1]
        user_embeddings_repeated = ops.repeat(
            ops.expand_dims(user_embeddings, axis=1),
            repeats=list_length,
            axis=1,
        )
        concatenated_embeddings = ops.concatenate(
            [user_embeddings_repeated, candidate_embeddings], axis=-1
        )

        scores = self.ratings(concatenated_embeddings)
        scores = ops.squeeze(scores, axis=-1)

        return scores

    def compute_output_shape(self, input_shape):
        return (input_shape[0], input_shape[1])


"""
Let's instantiate, compile and train our model. We will train two models:
one with vanilla mean-squared error, and the other with pairwise hinge loss.
For the latter, we will use `keras_rs.losses.PairwiseHingeLoss`.

Pairwise losses compare pairs of items within each list, penalizing cases where
an item with a higher true label has a lower predicted score than an item with a
lower true label. This is why they are more suited for ranking tasks than
pointwise losses.

To quantify these results, we compute nDCG. nDCG is a measure of ranking quality
that evaluates how well a system orders items based on relevance, giving more
importance to highly relevant items appearing at the top of the list and
normalizing the score against an ideal ranking.
To compute it, we just need to pass `keras_rs.metrics.NDCG()` as a metric to
`model.compile`.
"""
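"""
To build intuition for these two quantities, here is a small NumPy sketch on a
toy list with made-up labels and scores. It is a simplified reference
computation: the exact formulation `keras_rs` uses (e.g., pair weighting for
the loss, or the gain and rank-discount functions for nDCG) may differ in
detail.
"""

toy_labels = np.array([1.0, 0.0, 0.5])  # true relevance
toy_scores = np.array([0.2, 0.8, 0.5])  # predicted scores

# Pairwise hinge loss: for every pair where item i is truly more relevant
# than item j, penalize max(0, 1 - (score_i - score_j)).
pairwise_hinge = sum(
    max(0.0, 1.0 - (toy_scores[i] - toy_scores[j]))
    for i in range(len(toy_labels))
    for j in range(len(toy_labels))
    if toy_labels[i] > toy_labels[j]
)
print(pairwise_hinge)  # 1.6 + 1.3 + 1.3 ≈ 4.2


def dcg(relevances):
    # Discounted cumulative gain with the standard 2^rel - 1 gain and
    # 1 / log2(rank + 1) discount.
    ranks = np.arange(1, len(relevances) + 1)
    return np.sum((2.0**relevances - 1.0) / np.log2(ranks + 1))


# nDCG: DCG of the labels in predicted order, normalized by the ideal DCG.
labels_in_predicted_order = toy_labels[np.argsort(-toy_scores)]
labels_in_ideal_order = np.sort(toy_labels)[::-1]
print(dcg(labels_in_predicted_order) / dcg(labels_in_ideal_order))  # ~0.60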
model_mse = RankingModel(
    num_users=users_count + 1,
    num_candidates=movies_count + 1,
    embedding_dimension=EMBEDDING_DIM,
)
model_mse.compile(
    loss=keras.losses.MeanSquaredError(),
    metrics=[keras_rs.metrics.NDCG(k=NUM_EXAMPLES_PER_LIST, name="ndcg")],
    optimizer=keras.optimizers.Adagrad(learning_rate=LEARNING_RATE),
)
model_mse.fit(train_ds, validation_data=val_ds, epochs=EPOCHS)

"""
And now, the model with pairwise hinge loss.
"""

model_hinge = RankingModel(
    num_users=users_count + 1,
    num_candidates=movies_count + 1,
    embedding_dimension=EMBEDDING_DIM,
)
model_hinge.compile(
    loss=keras_rs.losses.PairwiseHingeLoss(),
    metrics=[keras_rs.metrics.NDCG(k=NUM_EXAMPLES_PER_LIST, name="ndcg")],
    optimizer=keras.optimizers.Adagrad(learning_rate=LEARNING_RATE),
)
model_hinge.fit(train_ds, validation_data=val_ds, epochs=EPOCHS)

"""
## Evaluation

Comparing the validation nDCG values, it is clear that the model trained with
the pairwise hinge loss outperforms the other one. Let's make this observation
more concrete by comparing results on the test set.
"""

ndcg_mse = model_mse.evaluate(test_ds, return_dict=True)["ndcg"]
ndcg_hinge = model_hinge.evaluate(test_ds, return_dict=True)["ndcg"]
print(ndcg_mse, ndcg_hinge)

"""
## Prediction

Now, let's rank some lists!

Let's create a mapping from movie ID to title so that we can surface the titles
for the ranked list.
"""

movie_id_to_movie_title = {
    int(x["movie_id"]): x["movie_title"] for x in movies.as_numpy_iterator()
}
movie_id_to_movie_title[0] = ""  # Because id 0 is not in the dataset.

user_id = 42
movie_ids = [409, 237, 131, 941, 543]
predictions = model_hinge.predict(
    {
        "user_id": keras.ops.array([user_id]),
        "movie_id": keras.ops.array([movie_ids]),
    }
)
predictions = keras.ops.convert_to_numpy(keras.ops.squeeze(predictions, axis=0))
# Sort in descending order of predicted score, so the best movie comes first.
sorted_indices = np.argsort(predictions)[::-1]
sorted_movies = [movie_ids[i] for i in sorted_indices]

for i, movie_id in enumerate(sorted_movies):
    print(f"{i + 1}. ", movie_id_to_movie_title[movie_id])

"""
And we're all done!
"""