Sequential retrieval [GRU4Rec]
Authors: Abheesht Sharma, Fabien Hertschuh
Date created: 2025/04/28
Last modified: 2025/04/28
Description: Recommend movies using a GRU-based sequential retrieval model.
Introduction
In this example, we are going to build a sequential retrieval model. Sequential recommendation is a popular paradigm in which the model looks at a sequence of items a user has interacted with previously, and predicts the next item. Here, the order of items within each sequence matters, so we use a recurrent neural network to model the sequential relationship. For more details, please refer to the GRU4Rec paper.
Let's begin by choosing JAX as the backend we want to run on, and import all the necessary libraries.
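For instance, with JAX as the backend (a minimal sketch of the setup; TensorFlow is imported only because we use `tf.data` for the input pipeline):

```python
import os

# The backend must be set before Keras is imported.
os.environ["KERAS_BACKEND"] = "jax"

import collections

import keras
import pandas as pd
import tensorflow as tf  # Needed only for the `tf.data` input pipeline.

import keras_rs
```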
Let's also define all important variables/hyperparameters below.
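Something along these lines; the exact values below are illustrative assumptions, not necessarily the settings used in the original notebook:

```python
# Data processing parameters.
MAX_CONTEXT_LENGTH = 10  # Maximum length of a movie sequence fed to the model.
MIN_SEQUENCE_LENGTH = 3  # Drop users with fewer interactions than this.
MIN_RATING = 2           # Ignore ratings below this threshold.

# Model and training parameters.
BATCH_SIZE = 4096
EMBEDDING_DIM = 32
NUM_EPOCHS = 5
LEARNING_RATE = 0.005
```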
Dataset
Next, we need to prepare our dataset. Like we did in the basic retrieval example, we are going to use the MovieLens dataset.
The dataset preparation step is fairly involved. The original ratings dataset contains (user, movie ID, rating, timestamp) tuples (among other columns, which are not important for this example). Since we are dealing with sequential retrieval, we need to create movie sequences for every user, where the sequences are ordered by timestamp.
Let's start by downloading and reading the dataset.
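A sketch of this step, assuming the GroupLens MovieLens 1M distribution. The `ml-1m_extracted` directory name reflects recent Keras versions of `get_file()`; older versions lay the extracted files out differently:

```python
# Download and extract the MovieLens 1M archive.
path_to_zip = keras.utils.get_file(
    fname="ml-1m.zip",
    origin="https://files.grouplens.org/datasets/movielens/ml-1m.zip",
    extract=True,
)
movielens_dir = os.path.join(
    os.path.dirname(path_to_zip), "ml-1m_extracted", "ml-1m"
)

# Read the ratings file into a dataframe.
ratings_df = pd.read_csv(
    os.path.join(movielens_dir, "ratings.dat"),
    sep="::",
    names=["UserID", "MovieID", "Rating", "Timestamp"],
    engine="python",  # pandas needs the Python engine for multi-char separators.
    encoding="unicode_escape",
)

# Keep only ratings at or above the threshold.
ratings_df = ratings_df[ratings_df["Rating"] >= MIN_RATING]

# The largest movie ID determines the size of the embedding tables.
movies_count = int(ratings_df["MovieID"].max())
```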
Now that we have read the dataset, let's create sequences of movies for every user. Here is the function for doing just that.
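A sketch of such a function:

```python
def get_movie_sequence_per_user(ratings_df):
    """Group ratings by user and sort each user's movies chronologically."""
    sequences = collections.defaultdict(list)
    for user_id, movie_id, rating, timestamp in ratings_df.values:
        sequences[user_id].append({"movie_id": movie_id, "timestamp": timestamp})

    # The order of items is what the GRU will model, so sort by timestamp.
    for context in sequences.values():
        context.sort(key=lambda x: x["timestamp"])
    return sequences


sequences = get_movie_sequence_per_user(ratings_df)
```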
We need to do some filtering and processing before we proceed with training the model (a code sketch of these steps follows the list):

1. Form sequences of all lengths up to `min(user_sequence_length, MAX_CONTEXT_LENGTH)`. So, every user will have multiple sequences corresponding to them.
2. Get labels, i.e., given a sequence of length `n`, the first `n-1` tokens will be fed to the model as input, and the label will be the last token.
3. Remove all user sequences with fewer than `MIN_SEQUENCE_LENGTH` movies.
4. Pad all sequences to `MAX_CONTEXT_LENGTH`.
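Here is a sketch of these four steps. The feature names `context_movie_ids` and `label_movie_id` are illustrative choices, not necessarily those of the original notebook:

```python
def generate_examples_from_user_sequences(sequences):
    """Turn each user's movie sequence into padded (context, label) examples."""
    examples_per_user = {}
    for user_id, context in sequences.items():
        movie_ids = [int(event["movie_id"]) for event in context]

        # Step 3: drop users with too few movies.
        if len(movie_ids) < MIN_SEQUENCE_LENGTH:
            continue

        examples = []
        # Step 1: one example per prefix of the user's history.
        for end in range(2, len(movie_ids) + 1):
            window = movie_ids[max(0, end - MAX_CONTEXT_LENGTH - 1) : end]

            # Step 2: the last token is the label, the rest is the input.
            context_ids, label = window[:-1], window[-1]

            # Step 4: pad the context to a fixed length; 0 is the padding ID.
            context_ids += [0] * (MAX_CONTEXT_LENGTH - len(context_ids))
            examples.append(
                {"context_movie_ids": context_ids, "label_movie_id": label}
            )
        examples_per_user[user_id] = examples
    return examples_per_user


examples_per_user = generate_examples_from_user_sequences(sequences)
```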
An important point to note is how we form the train-test splits. We do not form the entire dataset of sequences and then split it into train and test. Instead, for every user, we take the last sequence to be part of the test set, and all other sequences to be part of the train set. This is to prevent data leakage.
Let's split the dataset into train and test sets. Also, we need to change the format of the dataset dictionary so as to enable conversion to a `tf.data.Dataset` object.
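A sketch, building on the hypothetical `examples_per_user` mapping from above:

```python
train_examples, test_examples = [], []
for examples in examples_per_user.values():
    # The chronologically last example of each user is held out for testing.
    train_examples.extend(examples[:-1])
    test_examples.append(examples[-1])


def to_tf_dataset(examples):
    # Turn the list of dicts into (inputs, labels) tensor slices.
    contexts = [example["context_movie_ids"] for example in examples]
    labels = [example["label_movie_id"] for example in examples]
    return tf.data.Dataset.from_tensor_slices(
        ({"context_movie_ids": contexts}, {"label_movie_id": labels})
    )


train_ds = to_tf_dataset(train_examples)
test_ds = to_tf_dataset(test_examples)
```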
We need to batch our datasets. We also use `cache()` and `prefetch()` for better performance.
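For instance:

```python
train_ds = train_ds.batch(BATCH_SIZE).cache().prefetch(tf.data.AUTOTUNE)
test_ds = test_ds.batch(BATCH_SIZE).cache().prefetch(tf.data.AUTOTUNE)
```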
Let's print out one batch.
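One way to inspect it:

```python
for batch in train_ds.take(1):
    print(batch)
```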
Model and Training
In the basic retrieval example, we used one query tower for the user and one candidate tower for the candidate movie. We are going to use a two-tower architecture here as well. However, the query tower now uses a Gated Recurrent Unit (GRU) layer to encode the sequence of historical movies, while the candidate tower for the candidate movie stays the same.
Note: Take a look at how the labels are defined. The label tensor (of shape `(batch_size, batch_size)`) contains one-hot vectors. The idea is: for every sample, consider the movie IDs corresponding to the other samples in the batch as negatives.
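Below is a sketch of such a model. It assumes `keras_rs.layers.BruteForceRetrieval` (sharing the candidate tower's embedding table via its `candidate_embeddings` attribute) and the hypothetical feature names introduced earlier; the exact code in the original notebook may differ:

```python
class SequentialRetrievalModel(keras.Model):
    """Two-tower retrieval model with a GRU-based query tower."""

    def __init__(self, movies_count, embedding_dim, **kwargs):
        super().__init__(**kwargs)
        # Query tower: embed the sequence of watched movies, then let a GRU
        # summarize it into a single query vector.
        self.query_model = keras.Sequential(
            [
                keras.layers.Embedding(movies_count + 1, embedding_dim),
                keras.layers.GRU(embedding_dim),
            ]
        )
        # Candidate tower: a plain embedding table for candidate movies.
        self.candidate_model = keras.layers.Embedding(
            movies_count + 1, embedding_dim
        )
        # Retrieval layer used at inference time to fetch the top-k movies.
        self.retrieval = keras_rs.layers.BruteForceRetrieval(
            k=10, return_scores=False
        )
        self.loss_fn = keras.losses.CategoricalCrossentropy(from_logits=True)

    def build(self, input_shape):
        self.query_model.build(input_shape)
        self.candidate_model.build(input_shape)
        # The candidates are the movie embeddings themselves, so the
        # retrieval layer can share the candidate tower's variable.
        self.retrieval.candidate_embeddings = self.candidate_model.embeddings
        self.retrieval.build(input_shape)
        super().build(input_shape)

    def call(self, inputs, training=False):
        query_embeddings = self.query_model(inputs["context_movie_ids"])
        result = {"query_embeddings": query_embeddings}
        if not training:
            # Retrieve the top-k movie IDs; skipped during training.
            result["predictions"] = self.retrieval(query_embeddings)
        return result

    def compute_loss(self, x=None, y=None, y_pred=None, sample_weight=None, **kwargs):
        # Score every query against every candidate in the batch: the
        # diagonal holds the positives, all other entries are negatives.
        candidate_embeddings = self.candidate_model(y["label_movie_id"])
        logits = keras.ops.matmul(
            y_pred["query_embeddings"], keras.ops.transpose(candidate_embeddings)
        )
        # One-hot (batch_size, batch_size) labels, as described above.
        labels = keras.ops.eye(keras.ops.shape(logits)[0])
        return self.loss_fn(labels, logits)
```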
Let's instantiate, compile and train our model.
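For instance (the optimizer choice here is an assumption):

```python
model = SequentialRetrievalModel(
    movies_count=movies_count, embedding_dim=EMBEDDING_DIM
)

# No loss is passed to `compile()`: it is computed in `compute_loss()` above.
model.compile(optimizer=keras.optimizers.AdamW(learning_rate=LEARNING_RATE))

model.fit(train_ds, validation_data=test_ds, epochs=NUM_EPOCHS)
```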
Making predictions
Now that we have a model, we would like to be able to make predictions.
So far, we have only handled movies by their IDs. Now is the time to create a mapping keyed by movie ID so we can surface the titles.
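A sketch, reading `movies.dat` from the directory used earlier:

```python
movies_df = pd.read_csv(
    os.path.join(movielens_dir, "movies.dat"),
    sep="::",
    names=["MovieID", "Title", "Genres"],
    engine="python",
    encoding="unicode_escape",
)
movie_id_to_movie_title = dict(zip(movies_df["MovieID"], movies_df["Title"]))
movie_id_to_movie_title[0] = ""  # ID 0 is reserved for padding.
```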
We then simply use the Keras `model.predict()` method. Under the hood, it calls the `BruteForceRetrieval` layer to perform the actual retrieval.
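For instance, for the first user of a test batch (assuming the `predictions` key returned by the model sketched above):

```python
for batch in test_ds.take(1):
    inputs, _ = batch

# `predict()` returns the dict produced by `call()`; "predictions" holds
# the top-k retrieved movie IDs for each user in the batch.
predictions = model.predict(inputs)["predictions"]

# Surface the titles for the first user in the batch.
for movie_id in predictions[0]:
    print(movie_id_to_movie_title[int(movie_id)])
```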
Note that this model can retrieve movies already watched by the user. We could easily add logic to remove them if that is desirable.