Sequential retrieval using SASRec
Authors: Abheesht Sharma, Fabien Hertschuh
Date created: 2025/04/28
Last modified: 2025/04/28
Description: Recommend movies using a Transformer-based retrieval model (SASRec).
Introduction
Sequential recommendation is a popular approach in which a model looks at a sequence of items that users have interacted with previously and then predicts the next item. Here, the order of the items within each sequence matters. Previously, in the Recommending movies: retrieval using a sequential model example, we built a GRU-based sequential retrieval model. In this example, we will build a popular Transformer decoder-based model named Self-Attentive Sequential Recommendation (SASRec) for the same sequential recommendation task.
Let's begin by importing all the necessary libraries.
Let's also define all important variables/hyperparameters below.
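The original code cells are not reproduced here, so the following is only a rough sketch of what the imports and hyperparameters might look like; the specific values are illustrative assumptions rather than the ones used in the full example.

```python
import os

import keras
import numpy as np
import pandas as pd
import tensorflow as tf  # Only used for the tf.data input pipeline.

import keras_rs

# Data constants (illustrative values).
MAX_CONTEXT_LENGTH = 200  # Sequences are padded/truncated to this length.
PAD_ITEM_ID = 0  # Special ID used for padding positions.

# Model hyperparameters (illustrative values).
HIDDEN_DIM = 50
NUM_LAYERS = 2
NUM_HEADS = 1
DROPOUT = 0.2

# Training hyperparameters (illustrative values).
BATCH_SIZE = 128
NUM_EPOCHS = 5
LEARNING_RATE = 0.001
```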
Dataset
Next, we need to prepare our dataset. Like we did in the sequential retrieval example, we are going to use the MovieLens dataset.
The dataset preparation step is fairly involved. The original ratings dataset contains (user, movie ID, rating, timestamp) tuples (among other columns, which are not important for this example). Since we are dealing with sequential retrieval, we need to create movie sequences for every user, where the sequences are ordered by timestamp.
Let's start by downloading and reading the dataset.
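As a sketch, and assuming the MovieLens 1M release (`ml-1m`), downloading and reading the ratings file could look like this. The variable names are chosen here for illustration.

```python
# (Uses the imports from the sketch at the top of the example.)
path = keras.utils.get_file(
    fname="ml-1m.zip",
    origin="https://files.grouplens.org/datasets/movielens/ml-1m.zip",
    extract=True,
)
# Depending on the Keras version, `get_file` may return the archive path or the
# extracted directory; here we assume the files end up next to the downloaded zip.
movielens_dir = os.path.join(os.path.dirname(path), "ml-1m")

# ratings.dat rows look like: user_id::movie_id::rating::timestamp
ratings_df = pd.read_csv(
    os.path.join(movielens_dir, "ratings.dat"),
    sep="::",
    names=["user_id", "movie_id", "rating", "timestamp"],
    engine="python",  # The "::" separator requires the python engine.
    encoding="latin-1",
)
```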
Now that we have read the dataset, let's create sequences of movies for every user. Here is the function for doing just that.
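A minimal sketch of such a function is shown below; the function name and the exact structure of the returned value are assumptions.

```python
def get_movie_sequence_per_user(ratings_df):
    """Group ratings by user and sort each user's movies chronologically."""
    sequences = {}
    for user_id, movie_id, rating, timestamp in zip(
        ratings_df["user_id"],
        ratings_df["movie_id"],
        ratings_df["rating"],
        ratings_df["timestamp"],
    ):
        sequences.setdefault(user_id, []).append(
            {"movie_id": movie_id, "timestamp": timestamp, "rating": rating}
        )
    # Order each user's interactions by timestamp.
    for events in sequences.values():
        events.sort(key=lambda event: event["timestamp"])
    return sequences


sequences = get_movie_sequence_per_user(ratings_df)
```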
So far, we have essentially replicated what we did in the sequential retrieval example. We have a sequence of movies for every user.
SASRec is trained contrastively, which means the model learns to distinguish between sequences of movies a user has actually interacted with (positive examples) and sequences they have not interacted with (negative examples).
The following function, `format_data`, prepares the data in this specific format. For each user's movie sequence, it generates a corresponding "negative sequence". This negative sequence consists of randomly selected movies that the user has not interacted with, and it is of the same length as the original sequence.
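A minimal sketch of such a function is shown below. The uniform sampling strategy, the `movie_id_vocabulary` argument (assumed to be a list of all movie IDs), and the output format are assumptions.

```python
import random


def format_data(sequences, movie_id_vocabulary):
    """Pair each user's true movie sequence with a same-length sequence of
    randomly sampled movies that the user never interacted with."""
    examples = {"sequence": [], "negative_sequence": []}
    for events in sequences.values():
        sequence = [event["movie_id"] for event in events]
        watched = set(sequence)
        negative_sequence = []
        while len(negative_sequence) < len(sequence):
            candidate = random.choice(movie_id_vocabulary)
            if candidate not in watched:
                negative_sequence.append(candidate)
        examples["sequence"].append(sequence)
        examples["negative_sequence"].append(negative_sequence)
    return examples


examples = format_data(
    sequences, movie_id_vocabulary=list(ratings_df["movie_id"].unique())
)
```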
Now that we have the original movie interaction sequences for each user (from `format_data`, stored in `examples["sequence"]`) and their corresponding random negative sequences (in `examples["negative_sequence"]`), the next step is to prepare this data for input to the model. The primary goals of this preprocessing are:
1. Creating Input Features and Target Labels: For sequential recommendation, the model learns to predict the next item in a sequence given the preceding items. This is achieved by:
    - taking the original `example["sequence"]` and creating the model's input features (`item_ids`) from all items except the last one (`example["sequence"][..., :-1]`);
    - creating the target "positive sequence" (what the model tries to predict as the actual next items) by shifting the original `example["sequence"]` and using all items except the first one (`example["sequence"][..., 1:]`);
    - shifting `example["negative_sequence"]` (from `format_data`) in the same way to create the target "negative sequence" for the contrastive loss (`example["negative_sequence"][..., 1:]`).
2. Handling Variable Length Sequences: Neural networks typically require fixed-size inputs. Therefore, both the input feature sequences and the target sequences are padded (with a special `PAD_ITEM_ID`) or truncated to a predefined `MAX_CONTEXT_LENGTH`. A `padding_mask` is also generated from the input features to ensure the model ignores these padded tokens during attention calculations, i.e., these tokens will be masked.
3. Differentiating Training and Validation/Testing:
    - During training:
        - Input features (`item_ids`) and context for negative sequences are prepared as described above (all but the last item of the original sequences).
        - Target positive and negative sequences are the shifted versions of the original sequences.
        - `sample_weight` is created based on the input features to ensure that loss is calculated only on actual items, not on padding tokens in the targets.
    - During validation/testing:
        - Input features are prepared similarly.
        - The model's performance is typically evaluated on its ability to predict the actual last item of the original sequence, so `sample_weight` is configured to focus the loss calculation only on this final prediction in the target sequences.
Note: the original SASRec setup does the same thing we've done above, except that it takes `item_ids[:-2]` for the validation set and `item_ids[:-1]` for the test set. We skip that here for brevity.
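For concreteness, here is a minimal sketch of the training-time preprocessing described above. It assumes an unbatched `tf.data` pipeline with each example being a dictionary of 1-D integer tensors, right padding, and the feature names shown here are illustrative rather than the exact ones in the full example.

```python
# (Uses MAX_CONTEXT_LENGTH and PAD_ITEM_ID from the sketch at the top.)
import tensorflow as tf


def preprocess_train(example):
    """Training-time preprocessing for a single (unbatched) example."""
    sequence = example["sequence"]
    negative_sequence = example["negative_sequence"]

    # Inputs are all items but the last; targets are all items but the first.
    item_ids = sequence[:-1]
    target_positive = sequence[1:]
    target_negative = negative_sequence[1:]

    def pad_or_truncate(x):
        x = x[:MAX_CONTEXT_LENGTH]
        pad_amount = MAX_CONTEXT_LENGTH - tf.shape(x)[0]
        return tf.pad(x, [[0, pad_amount]], constant_values=PAD_ITEM_ID)

    item_ids = pad_or_truncate(item_ids)
    target_positive = pad_or_truncate(target_positive)
    target_negative = pad_or_truncate(target_negative)

    # Mask padded positions in attention; weight the loss only on real items.
    padding_mask = item_ids != PAD_ITEM_ID
    sample_weight = tf.cast(padding_mask, "float32")

    features = {
        "item_ids": item_ids,
        "padding_mask": padding_mask,
        "positive_sequence": target_positive,
        "negative_sequence": target_negative,
    }
    return features, target_positive, sample_weight
```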
We can see a batch for each.
Model
To encode the input sequence, we use a Transformer decoder-based model. This part of the model is very similar to the GPT-2 architecture. Refer to the GPT text generation from scratch with KerasHub guide for more details on this part.
One part to note is that when we are "predicting", i.e., when `training` is `False`, we get the embedding corresponding to the last movie in the sequence. This makes sense, because at inference time, we want to predict the movie the user will likely watch after watching the last movie.
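As a minimal sketch (with names chosen here for illustration, and assuming right padding as in the preprocessing sketch above), picking out that last embedding could look like this:

```python
import keras


def last_real_item_embedding(sequence_embeddings, padding_mask):
    """Return the embedding of the last non-padded item of each sequence.

    sequence_embeddings: (batch, seq_len, hidden_dim) float tensor.
    padding_mask: (batch, seq_len) boolean tensor, True for real items.
    Assumes right padding: real items first, padded positions after.
    """
    # Index of the last real item in each sequence.
    last_position = keras.ops.sum(keras.ops.cast(padding_mask, "int32"), axis=1) - 1
    last_position = keras.ops.expand_dims(keras.ops.expand_dims(last_position, -1), -1)
    # Gather that position's embedding for every example in the batch.
    gathered = keras.ops.take_along_axis(sequence_embeddings, last_position, axis=1)
    return gathered[:, 0, :]
```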
Also, it's worth discussing the `compute_loss` method. We embed the positive and negative sequences using the input embedding matrix. We compute the similarity of (positive sequence, input sequence) and (negative sequence, input sequence) pair embeddings by computing the dot product. The goal now is to maximize the similarity of the former and minimize the similarity of the latter. Let's see this mathematically. Binary Cross Entropy is written as follows:
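$$
\text{Loss} = -\sum_i \left[ y_i \log\left(\sigma(s_i)\right) + (1 - y_i) \log\left(1 - \sigma(s_i)\right) \right]
$$

where $y_i$ is the label, $s_i$ is the corresponding logit (the dot product described above), and $\sigma$ is the sigmoid function.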
Here, we assign the positive pairs a label of 1 and the negative pairs a label of 0. So, for a positive pair, the loss reduces to:
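$$
\text{Loss}_{\text{positive}} = -\log\left(\sigma(\text{positive\_logits})\right),
$$

and, for a negative pair (label 0), it reduces to $-\log\left(1 - \sigma(\text{negative\_logits})\right)$.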
Minimizing the loss means we want to maximize the log term, which, in turn, implies maximizing `positive_logits`. Similarly, we want to minimize `negative_logits`.
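A minimal sketch of this loss computation is shown below; the function name, argument shapes, and the exact masking/reduction are illustrative assumptions, not the library's exact implementation.

```python
import keras


def compute_contrastive_loss(
    input_embeddings, positive_embeddings, negative_embeddings, sample_weight
):
    """Dot-product logits followed by binary cross-entropy, as described above."""
    # Similarity between the encoded input sequence and each target sequence.
    positive_logits = keras.ops.sum(input_embeddings * positive_embeddings, axis=-1)
    negative_logits = keras.ops.sum(input_embeddings * negative_embeddings, axis=-1)

    # Positive pairs get label 1, negative pairs get label 0.
    logits = keras.ops.concatenate([positive_logits, negative_logits], axis=1)
    labels = keras.ops.concatenate(
        [keras.ops.ones_like(positive_logits), keras.ops.zeros_like(negative_logits)],
        axis=1,
    )
    weights = keras.ops.concatenate([sample_weight, sample_weight], axis=1)

    # Element-wise binary cross-entropy on the logits, ignoring padded positions.
    bce = keras.ops.binary_crossentropy(labels, logits, from_logits=True)
    return keras.ops.sum(bce * weights) / (keras.ops.sum(weights) + 1e-7)
```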
Let's instantiate our model and do some sanity checks.
Now, let's compile and train our model.
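Assuming the model defined above is called `model` and the preprocessed datasets are `train_ds` and `val_ds` (names chosen here for illustration), the compile-and-fit step might look roughly like this; the optimizer choice is also an assumption.

```python
model.compile(
    # The loss is computed inside the model's `compute_loss`, so only an
    # optimizer is passed here. The optimizer settings are illustrative.
    optimizer=keras.optimizers.AdamW(learning_rate=LEARNING_RATE),
)
model.fit(
    train_ds,
    validation_data=val_ds,
    epochs=NUM_EPOCHS,
)
```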
Making predictions
Now that we have a model, we would like to be able to make predictions.
So far, we have only handled movies by id. Now is the time to create a mapping keyed by movie IDs to be able to surface the titles.
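A sketch of this mapping, assuming the MovieLens 1M `movies.dat` file and the `movielens_dir` variable from the dataset sketch above:

```python
# movies.dat rows look like: movie_id::title::genres
movies_df = pd.read_csv(
    os.path.join(movielens_dir, "movies.dat"),
    sep="::",
    names=["movie_id", "title", "genres"],
    engine="python",
    encoding="latin-1",
)
movie_id_to_title = dict(zip(movies_df["movie_id"], movies_df["title"]))
```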
We then simply use the Keras `model.predict()` method. Under the hood, it calls the `BruteForceRetrieval` layer to perform the actual retrieval.
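As a rough sketch: the exact structure of the prediction output depends on how `BruteForceRetrieval` is wired into the model, so the batch unpacking and indexing below are assumptions based on the earlier preprocessing sketch.

```python
import numpy as np

# One batch of preprocessed validation inputs (structured as in the earlier sketch).
inputs, _, _ = next(iter(val_ds))

# The retrieval layer returns candidate movie IDs for each input sequence.
predicted_movie_ids = model.predict(inputs)

for sequence_of_ids in np.asarray(predicted_movie_ids)[:5]:
    print([movie_id_to_title.get(int(movie_id), "unknown") for movie_id in sequence_of_ids])
```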
Note that this model can retrieve movies already watched by the user. We could easily add logic to remove them if that is desirable.
And that's all!