"""1Title: Sequential retrieval [GRU4Rec]2Author: [Abheesht Sharma](https://github.com/abheesht17/), [Fabien Hertschuh](https://github.com/hertschuh/)3Date created: 2025/04/284Last modified: 2025/04/285Description: Recommend movies using a GRU-based sequential retrieval model.6Accelerator: GPU7"""89"""10## Introduction1112In this example, we are going to build a sequential retrieval model. Sequential13recommendation is a popular model that looks at a sequence of items that users14have interacted with previously and then predicts the next item. Here, the order15of the items within each sequence matters. So, we are going to use a recurrent16neural network to model the sequential relationship. For more details,17please refer to the [GRU4Rec](https://arxiv.org/abs/1511.06939) paper.1819Let's begin by choosing JAX as the backend we want to run on, and import all20the necessary libraries.21"""2223"""shell24pip install -q keras-rs25"""2627import os2829os.environ["KERAS_BACKEND"] = "jax" # `"tensorflow"`/`"torch"`3031import collections32import os33import random3435import keras36import pandas as pd37import tensorflow as tf # Needed only for the dataset3839import keras_rs4041"""42Let's also define all important variables/hyperparameters below.43"""4445DATA_DIR = "./raw/data/"4647# MovieLens-specific variables48MOVIELENS_1M_URL = "https://files.grouplens.org/datasets/movielens/ml-1m.zip"49MOVIELENS_ZIP_HASH = "a6898adb50b9ca05aa231689da44c217cb524e7ebd39d264c56e2832f2c54e20"5051RATINGS_FILE_NAME = "ratings.dat"52MOVIES_FILE_NAME = "movies.dat"5354# Data processing args55MAX_CONTEXT_LENGTH = 1056MIN_SEQUENCE_LENGTH = 35758RATINGS_DATA_COLUMNS = ["UserID", "MovieID", "Rating", "Timestamp"]59MOVIES_DATA_COLUMNS = ["MovieID", "Title", "Genres"]60MIN_RATING = 26162# Training/model args63BATCH_SIZE = 409664TEST_BATCH_SIZE = 204865EMBEDDING_DIM = 3266NUM_EPOCHS = 567LEARNING_RATE = 0.0056869"""70## Dataset7172Next, we need to prepare our dataset. Like we did in the73[basic retrieval](/keras_rs/examples/basic_retrieval/)74example, we are going to use the MovieLens dataset.7576The dataset preparation step is fairly involved. The original ratings dataset77contains `(user, movie ID, rating, timestamp)` tuples (among other columns,78which are not important for this example). 

Let's start by downloading and reading the dataset.
"""

# Download the MovieLens dataset.
if not os.path.exists(DATA_DIR):
    os.makedirs(DATA_DIR)

path_to_zip = keras.utils.get_file(
    fname="ml-1m.zip",
    origin=MOVIELENS_1M_URL,
    file_hash=MOVIELENS_ZIP_HASH,
    hash_algorithm="sha256",
    extract=True,
    cache_dir=DATA_DIR,
)
movielens_extracted_dir = os.path.join(
    os.path.dirname(path_to_zip),
    "ml-1m_extracted",
    "ml-1m",
)


# Read the dataset.
def read_data(data_directory, min_rating=None):
    """Read the MovieLens ratings.dat and movies.dat files
    into dataframes.
    """

    ratings_df = pd.read_csv(
        os.path.join(data_directory, RATINGS_FILE_NAME),
        sep="::",
        names=RATINGS_DATA_COLUMNS,
        encoding="unicode_escape",
        engine="python",
    )
    ratings_df["Timestamp"] = ratings_df["Timestamp"].apply(int)

    # Drop interactions with `rating < min_rating`.
    if min_rating is not None:
        ratings_df = ratings_df[ratings_df["Rating"] >= min_rating]

    movies_df = pd.read_csv(
        os.path.join(data_directory, MOVIES_FILE_NAME),
        sep="::",
        names=MOVIES_DATA_COLUMNS,
        encoding="unicode_escape",
        engine="python",
    )
    return ratings_df, movies_df


ratings_df, movies_df = read_data(
    data_directory=movielens_extracted_dir, min_rating=MIN_RATING
)

# We need the number of movies to define the embedding layers.
movies_count = movies_df["MovieID"].max()

"""
Now that we have read the dataset, let's create sequences of movies
for every user. Here is the function for doing just that.
"""


def get_movie_sequence_per_user(ratings_df):
    """Get movie ID sequences for every user."""
    sequences = collections.defaultdict(list)

    for user_id, movie_id, rating, timestamp in ratings_df.values:
        sequences[user_id].append(
            {
                "movie_id": movie_id,
                "timestamp": timestamp,
                "rating": rating,
            }
        )

    # Sort movie sequences by timestamp for every user.
    for user_id, context in sequences.items():
        context.sort(key=lambda x: x["timestamp"])
        sequences[user_id] = context

    return sequences


"""
We need to do some filtering and processing before we proceed
with training the model:

1. Form sequences of all lengths up to
   `min(user_sequence_length, MAX_CONTEXT_LENGTH)`. So, every user
   will have multiple sequences corresponding to it.
2. Get labels: given a sequence of length `n`, the first `n - 1`
   tokens are fed to the model as input, and the label is the
   last token.
3. Remove all user sequences with fewer than `MIN_SEQUENCE_LENGTH`
   movies.
4. Pad all sequences to `MAX_CONTEXT_LENGTH`.

An important point to note is how we form the train-test splits. We do not
form the entire dataset of sequences and then split it into train and test
sets. Instead, for every user, we take the last sequence to be part of the
test set, and all other sequences to be part of the train set. This is to
prevent data leakage. A worked example of these rules is shown below.
"""
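
"""
To make these rules concrete, consider a hypothetical user sequence of five
movie IDs, `[1, 2, 3, 4, 5]` (made-up values), with `MAX_CONTEXT_LENGTH = 10`:
the training examples are `[1] -> 2`, `[1, 2] -> 3` and `[1, 2, 3] -> 4`; the
single test example is `[1, 2, 3, 4] -> 5`; and every context is then
zero-padded to length 10.
"""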


def generate_examples_from_user_sequences(sequences):
    """Generates examples for all users, with padding, truncation, etc."""

    def generate_examples_from_user_sequence(sequence):
        """Generates examples for a single user sequence."""

        train_examples = []
        test_examples = []
        for label_idx in range(1, len(sequence)):
            start_idx = max(0, label_idx - MAX_CONTEXT_LENGTH)
            context = sequence[start_idx:label_idx]

            # Pad the context with dummy entries up to `MAX_CONTEXT_LENGTH`.
            while len(context) < MAX_CONTEXT_LENGTH:
                context.append(
                    {
                        "movie_id": 0,
                        "timestamp": 0,
                        "rating": 0.0,
                    }
                )

            label_movie_id = int(sequence[label_idx]["movie_id"])
            context_movie_id = [int(movie["movie_id"]) for movie in context]

            example = {
                "context_movie_id": context_movie_id,
                "label_movie_id": label_movie_id,
            }

            if label_idx == len(sequence) - 1:
                test_examples.append(example)
            else:
                train_examples.append(example)

        return train_examples, test_examples

    all_train_examples = []
    all_test_examples = []
    for sequence in sequences.values():
        if len(sequence) < MIN_SEQUENCE_LENGTH:
            continue

        user_train_examples, user_test_examples = generate_examples_from_user_sequence(
            sequence
        )

        all_train_examples.extend(user_train_examples)
        all_test_examples.extend(user_test_examples)

    return all_train_examples, all_test_examples


"""
Let's split the dataset into train and test sets. Also, we need to convert
the examples from lists of dictionaries to dictionaries of lists so that we
can turn them into `tf.data.Dataset` objects.
"""
sequences = get_movie_sequence_per_user(ratings_df)
train_examples, test_examples = generate_examples_from_user_sequences(sequences)


def list_of_dicts_to_dict_of_lists(list_of_dicts):
    """Convert a list of dictionaries to a dictionary of lists for
    `tf.data` conversion.
    """
    dict_of_lists = collections.defaultdict(list)
    for dictionary in list_of_dicts:
        for key, value in dictionary.items():
            dict_of_lists[key].append(value)
    return dict_of_lists


train_examples = list_of_dicts_to_dict_of_lists(train_examples)
test_examples = list_of_dicts_to_dict_of_lists(test_examples)

train_ds = tf.data.Dataset.from_tensor_slices(train_examples).map(
    lambda x: (x["context_movie_id"], x["label_movie_id"])
)
test_ds = tf.data.Dataset.from_tensor_slices(test_examples).map(
    lambda x: (x["context_movie_id"], x["label_movie_id"])
)

"""
We need to batch our datasets. We also use `cache()` and `prefetch()`
for better performance.
"""
train_ds = train_ds.batch(BATCH_SIZE).cache().prefetch(tf.data.AUTOTUNE)
test_ds = test_ds.batch(TEST_BATCH_SIZE).cache().prefetch(tf.data.AUTOTUNE)

"""
Let's print out one batch.
"""

for sample in train_ds.take(1):
    print(sample)

"""
## Model and Training

In the basic retrieval example, we used a query tower for the user and a
candidate tower for the candidate movie. We are going to use a two-tower
architecture here as well. However, the query tower now uses a Gated
Recurrent Unit (GRU) layer to encode the sequence of historical movies,
while the candidate tower remains the same.

Note: Take a look at how the labels are defined. The label tensor
(of shape `(batch_size, batch_size)`) contains one-hot vectors. The idea is:
for every sample, the movie IDs corresponding to the other samples in the
batch are treated as negatives.
"""
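
"""
To make the in-batch negative scheme concrete before defining the model, here
is a tiny standalone sketch with a toy batch of 4 and random embeddings (the
`toy_*` names are illustrative, not part of the pipeline): candidate `i` is
the positive for query `i`, and every other candidate in the batch serves as
a negative.
"""

toy_batch_size = 4
toy_queries = keras.random.normal((toy_batch_size, EMBEDDING_DIM))
toy_candidates = keras.random.normal((toy_batch_size, EMBEDDING_DIM))

# Affinity scores between every query and every candidate in the batch.
toy_scores = keras.ops.matmul(toy_queries, keras.ops.transpose(toy_candidates))
# One-hot labels: the diagonal marks each query's positive candidate.
toy_labels = keras.ops.eye(toy_batch_size, toy_batch_size)
print(keras.losses.CategoricalCrossentropy(from_logits=True)(toy_labels, toy_scores))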
The idea298is: for every sample, consider movie IDs corresponding to other samples in299the batch as negatives.300"""301302303class SequentialRetrievalModel(keras.Model):304"""Create the sequential retrieval model.305306Args:307movies_count: Total number of unique movies in the dataset.308embedding_dimension: Output dimension for movie embedding tables.309"""310311def __init__(312self,313movies_count,314embedding_dimension=128,315**kwargs,316):317super().__init__(**kwargs)318# Our query tower, simply an embedding table followed by319# a GRU unit. This encodes sequence of historical movies.320self.query_model = keras.Sequential(321[322keras.layers.Embedding(movies_count + 1, embedding_dimension),323keras.layers.GRU(embedding_dimension),324]325)326327# Our candidate tower, simply an embedding table.328self.candidate_model = keras.layers.Embedding(329movies_count + 1, embedding_dimension330)331332# The layer that performs the retrieval.333self.retrieval = keras_rs.layers.BruteForceRetrieval(k=10, return_scores=False)334self.loss_fn = keras.losses.CategoricalCrossentropy(335from_logits=True,336)337338def build(self, input_shape):339self.query_model.build(input_shape)340self.candidate_model.build(input_shape)341342# In this case, the candidates are directly the movie embeddings.343# We take a shortcut and directly reuse the variable.344self.retrieval.candidate_embeddings = self.candidate_model.embeddings345self.retrieval.build(input_shape)346super().build(input_shape)347348def call(self, inputs, training=False):349query_embeddings = self.query_model(inputs)350result = {351"query_embeddings": query_embeddings,352}353354if not training:355# Skip the retrieval of top movies during training as the356# predictions are not used.357result["predictions"] = self.retrieval(query_embeddings)358return result359360def compute_loss(self, x, y, y_pred, sample_weight, training=True):361candidate_id = y362query_embeddings = y_pred["query_embeddings"]363candidate_embeddings = self.candidate_model(candidate_id)364365num_queries = keras.ops.shape(query_embeddings)[0]366num_candidates = keras.ops.shape(candidate_embeddings)[0]367368# One-hot vectors for labels.369labels = keras.ops.eye(num_queries, num_candidates)370371# Compute the affinity score by multiplying the two embeddings.372scores = keras.ops.matmul(373query_embeddings, keras.ops.transpose(candidate_embeddings)374)375376return self.loss_fn(labels, scores, sample_weight)377378379"""380Let's instantiate, compile and train our model.381"""382383model = SequentialRetrievalModel(384movies_count=movies_count, embedding_dimension=EMBEDDING_DIM385)386387# Compile.388model.compile(optimizer=keras.optimizers.AdamW(learning_rate=LEARNING_RATE))389390# Train.391model.fit(392train_ds,393validation_data=test_ds,394epochs=NUM_EPOCHS,395)396397"""398## Making predictions399400Now that we have a model, we would like to be able to make predictions.401402So far, we have only handled movies by id. Now is the time to create a mapping403keyed by movie IDs to be able to surface the titles.404"""405406movie_id_to_movie_title = dict(zip(movies_df["MovieID"], movies_df["Title"]))407movie_id_to_movie_title[0] = "" # Because id 0 is not in the dataset.408409"""410We then simply use the Keras `model.predict()` method. Under the hood, it calls411the `BruteForceRetrieval` layer to perform the actual retrieval.412413Note that this model can retrieve movies already watched by the user. 
We could414easily add logic to remove them if that is desirable.415"""416417print("\n==> Movies the user has watched:")418movie_sequence = test_ds.unbatch().take(1)419for element in movie_sequence:420for movie_id in element[0][:-1]:421print(movie_id_to_movie_title[movie_id.numpy()], end=", ")422print(movie_id_to_movie_title[element[0][-1].numpy()])423424predictions = model.predict(movie_sequence.batch(1))425predictions = keras.ops.convert_to_numpy(predictions["predictions"])426427print("\n==> Recommended movies for the above sequence:")428for movie_id in predictions[0]:429print(movie_id_to_movie_title[movie_id])430431432