"""
Title: A Transformer-based recommendation system
Author: [Khalid Salama](https://www.linkedin.com/in/khalid-salama-24403144/)
Date created: 2020/12/30
Last modified: 2025/01/27
Description: Rating prediction using the Behavior Sequence Transformer (BST) model on the Movielens dataset.
Accelerator: GPU
Made backend-agnostic by: [Humbulani Ndou](https://github.com/Humbulani1234)
"""

"""
## Introduction

This example demonstrates the [Behavior Sequence Transformer (BST)](https://arxiv.org/abs/1905.06874)
model, by Qiwei Chen et al., using the [Movielens dataset](https://grouplens.org/datasets/movielens/).
The BST model leverages the sequential behaviour of the users in watching and rating movies,
as well as user profile and movie features, to predict the user's rating of a target movie.

More precisely, the BST model aims to predict the rating of a target movie by accepting
the following inputs:

1. A fixed-length *sequence* of `movie_ids` watched by a user.
2. A fixed-length *sequence* of the `ratings` for the movies watched by a user.
3. A *set* of user features, including `user_id`, `sex`, `occupation`, and `age_group`.
4. A *set* of `genres` for each movie in the input sequence and the target movie.
5. A `target_movie_id` for which to predict the rating.

This example modifies the original BST model in the following ways:

1. We incorporate the movie features (genres) into the processing of the embedding of each
movie of the input sequence and the target movie, rather than treating them as "other features"
outside the transformer layer.
2. We utilize the ratings of movies in the input sequence, along with their positions
in the sequence, to update them before feeding them into the self-attention layer.

Note that this example should be run with TensorFlow 2.4 or higher.
"""

"""
## The dataset

We use the [1M version of the Movielens dataset](https://grouplens.org/datasets/movielens/1m/).
The dataset includes around 1 million ratings from 6000 users on 4000 movies,
along with some user features and movie genres. In addition, the timestamp of each user-movie
rating is provided, which allows creating sequences of movie ratings for each user,
as expected by the BST model.
"""

"""
## Setup
"""

import os

os.environ["KERAS_BACKEND"] = "jax"  # or torch, or tensorflow

import math
from zipfile import ZipFile
from urllib.request import urlretrieve
import numpy as np
import pandas as pd

import keras
from keras import layers, ops
from keras.layers import StringLookup

"""
## Prepare the data

### Download and prepare the DataFrames

First, let's download the movielens data.

The downloaded folder will contain three data files: `users.dat`, `movies.dat`,
and `ratings.dat`.
"""

urlretrieve("http://files.grouplens.org/datasets/movielens/ml-1m.zip", "movielens.zip")
ZipFile("movielens.zip", "r").extractall()

"""
Then, we load the data into pandas DataFrames with their proper column names.
"""

users = pd.read_csv(
    "ml-1m/users.dat",
    sep="::",
    names=["user_id", "sex", "age_group", "occupation", "zip_code"],
    encoding="ISO-8859-1",
    engine="python",
)

ratings = pd.read_csv(
    "ml-1m/ratings.dat",
    sep="::",
    names=["user_id", "movie_id", "rating", "unix_timestamp"],
    encoding="ISO-8859-1",
    engine="python",
)

movies = pd.read_csv(
    "ml-1m/movies.dat",
    sep="::",
    names=["movie_id", "title", "genres"],
    encoding="ISO-8859-1",
    engine="python",
)

"""
Here, we do some simple data processing to fix the data types of the columns: the ID-like
columns are turned into prefixed strings so they can be treated as categorical tokens,
and the ratings are cast to floats.
"""

users["user_id"] = users["user_id"].apply(lambda x: f"user_{x}")
users["age_group"] = users["age_group"].apply(lambda x: f"group_{x}")
users["occupation"] = users["occupation"].apply(lambda x: f"occupation_{x}")

movies["movie_id"] = movies["movie_id"].apply(lambda x: f"movie_{x}")

ratings["movie_id"] = ratings["movie_id"].apply(lambda x: f"movie_{x}")
ratings["user_id"] = ratings["user_id"].apply(lambda x: f"user_{x}")
ratings["rating"] = ratings["rating"].apply(lambda x: float(x))

"""
Each movie has multiple genres. We split them into separate columns in the `movies`
DataFrame.
"""

genres = ["Action", "Adventure", "Animation", "Children's", "Comedy", "Crime"]
genres += ["Documentary", "Drama", "Fantasy", "Film-Noir", "Horror", "Musical"]
genres += ["Mystery", "Romance", "Sci-Fi", "Thriller", "War", "Western"]

for genre in genres:
    movies[genre] = movies["genres"].apply(
        lambda values: int(genre in values.split("|"))
    )


"""
### Transform the movie ratings data into sequences

First, let's sort the ratings data using the `unix_timestamp`, and then group the
`movie_id` values and the `rating` values by `user_id`.

The output DataFrame will have a record for each `user_id`, with two ordered lists
(sorted by rating datetime): the movies they have rated, and their ratings of these movies.
"""

ratings_group = ratings.sort_values(by=["unix_timestamp"]).groupby("user_id")

ratings_data = pd.DataFrame(
    data={
        "user_id": list(ratings_group.groups.keys()),
        "movie_ids": list(ratings_group.movie_id.apply(list)),
        "ratings": list(ratings_group.rating.apply(list)),
        "timestamps": list(ratings_group.unix_timestamp.apply(list)),
    }
)


"""
Now, let's split the `movie_ids` list into a set of sequences of a fixed length.
We do the same for the `ratings`. Set the `sequence_length` variable to change the length
of the input sequence to the model.
You can also change the `step_size` to control the
number of sequences to generate for each user.
"""

sequence_length = 4
step_size = 2


def create_sequences(values, window_size, step_size):
    sequences = []
    start_index = 0
    while True:
        end_index = start_index + window_size
        seq = values[start_index:end_index]
        if len(seq) < window_size:
            seq = values[-window_size:]
            if len(seq) == window_size:
                sequences.append(seq)
            break
        sequences.append(seq)
        start_index += step_size
    return sequences


ratings_data.movie_ids = ratings_data.movie_ids.apply(
    lambda ids: create_sequences(ids, sequence_length, step_size)
)

ratings_data.ratings = ratings_data.ratings.apply(
    lambda ids: create_sequences(ids, sequence_length, step_size)
)

del ratings_data["timestamps"]

"""
After that, we process the output to have each sequence in a separate record in
the DataFrame. In addition, we join the user features with the ratings data.
"""

ratings_data_movies = ratings_data[["user_id", "movie_ids"]].explode(
    "movie_ids", ignore_index=True
)
ratings_data_rating = ratings_data[["ratings"]].explode("ratings", ignore_index=True)
ratings_data_transformed = pd.concat([ratings_data_movies, ratings_data_rating], axis=1)
ratings_data_transformed = ratings_data_transformed.join(
    users.set_index("user_id"), on="user_id"
)
ratings_data_transformed.movie_ids = ratings_data_transformed.movie_ids.apply(
    lambda x: ",".join(x)
)
ratings_data_transformed.ratings = ratings_data_transformed.ratings.apply(
    lambda x: ",".join([str(v) for v in x])
)

del ratings_data_transformed["zip_code"]

ratings_data_transformed.rename(
    columns={"movie_ids": "sequence_movie_ids", "ratings": "sequence_ratings"},
    inplace=True,
)

"""
With `sequence_length` of 4 and `step_size` of 2, we end up with 498,623 sequences.

Finally, we split the data into training and testing splits, with 85% and 15% of the
instances, respectively, and store them as CSV files.
"""

random_selection = np.random.rand(len(ratings_data_transformed.index)) <= 0.85
train_data = ratings_data_transformed[random_selection]
test_data = ratings_data_transformed[~random_selection]

train_data.to_csv("train_data.csv", index=False, sep="|", header=False)
test_data.to_csv("test_data.csv", index=False, sep="|", header=False)

"""
## Define metadata
"""

CSV_HEADER = list(ratings_data_transformed.columns)

CATEGORICAL_FEATURES_WITH_VOCABULARY = {
    "user_id": list(users.user_id.unique()),
    "movie_id": list(movies.movie_id.unique()),
    "sex": list(users.sex.unique()),
    "age_group": list(users.age_group.unique()),
    "occupation": list(users.occupation.unique()),
}

USER_FEATURES = ["sex", "age_group", "occupation"]

MOVIE_FEATURES = ["genres"]


"""
## Encode input features

The `encode_input_features` function works as follows:

1. Each categorical user feature (when `include_user_id` or `include_user_features` is set)
is encoded using `layers.Embedding`, with embedding dimension equal to the square root of
the vocabulary size of the feature, rounded down to an integer.
The embeddings of these features are concatenated to form a single input tensor.

2. Each movie in the movie sequence and the target movie is encoded using `layers.Embedding`,
where the dimension size is the square root of the number of movies.

3. When `include_movie_features` is set, a multi-hot genres vector for each movie is
concatenated with its embedding vector, and processed using a non-linear `layers.Dense`
to output a vector with the same dimension as the movie embedding.

4. A positional embedding is added to each movie embedding in the sequence, and the result
is then multiplied by the movie's rating from the ratings sequence.

5. The target movie embedding is concatenated to the sequence movie embeddings, producing
a tensor with the shape of `[batch size, sequence length, embedding size]`, as expected
by the attention layer for the transformer architecture.

6. The function returns a tuple of two elements: `encoded_transformer_features` and
`encoded_other_features`.
"""

# Required for tf.data.Dataset
import tensorflow as tf


def get_dataset_from_csv(csv_file_path, batch_size, shuffle=True):

    def process(features):
        movie_ids_string = features["sequence_movie_ids"]
        sequence_movie_ids = tf.strings.split(movie_ids_string, ",").to_tensor()
        # The last movie id in the sequence is the target movie.
        features["target_movie_id"] = sequence_movie_ids[:, -1]
        features["sequence_movie_ids"] = sequence_movie_ids[:, :-1]
        # Sequence ratings
        ratings_string = features["sequence_ratings"]
        sequence_ratings = tf.strings.to_number(
            tf.strings.split(ratings_string, ","), tf.dtypes.float32
        ).to_tensor()
        # The last rating in the sequence is the target for the model to predict.
        target = sequence_ratings[:, -1]
        features["sequence_ratings"] = sequence_ratings[:, :-1]

        def encoding_helper(feature_name):

            # These are target_movie_id and sequence_movie_ids, and they have the same
            # vocabulary as movie_id.
            if feature_name not in CATEGORICAL_FEATURES_WITH_VOCABULARY:
                vocabulary = CATEGORICAL_FEATURES_WITH_VOCABULARY["movie_id"]
                index_lookup = StringLookup(
                    vocabulary=vocabulary, mask_token=None, num_oov_indices=0
                )
                # Convert the string input values into integer indices.
                value_index = index_lookup(features[feature_name])
                features[feature_name] = value_index
            else:
                # movie_id is not part of the features, hence it is not processed.
                # It was mainly required for its vocabulary above.
                if feature_name == "movie_id":
                    pass
                else:
                    vocabulary = CATEGORICAL_FEATURES_WITH_VOCABULARY[feature_name]
                    index_lookup = StringLookup(
                        vocabulary=vocabulary, mask_token=None, num_oov_indices=0
                    )
                    # Convert the string input values into integer indices.
                    value_index = index_lookup(features[feature_name])
                    features[feature_name] = value_index

        # Encode the user features.
        for feature_name in CATEGORICAL_FEATURES_WITH_VOCABULARY:
            encoding_helper(feature_name)
        # Encode target_movie_id (the returned target is the last rating, not this id).
        encoding_helper("target_movie_id")
        # Encode the sequence movie_ids.
        encoding_helper("sequence_movie_ids")
        return dict(features), target

    dataset = tf.data.experimental.make_csv_dataset(
        csv_file_path,
        batch_size=batch_size,
        column_names=CSV_HEADER,
        num_epochs=1,
        header=False,
        field_delim="|",
        shuffle=shuffle,
    ).map(process)
    return dataset


def encode_input_features(
    inputs,
    include_user_id,
    include_user_features,
    include_movie_features,
):
    encoded_transformer_features = []
    encoded_other_features = []

    other_feature_names = []
    if include_user_id:
        other_feature_names.append("user_id")
    if include_user_features:
        other_feature_names.extend(USER_FEATURES)

    ## Encode user features
    for feature_name in other_feature_names:
        vocabulary = CATEGORICAL_FEATURES_WITH_VOCABULARY[feature_name]
        # Compute embedding dimensions
        embedding_dims = int(math.sqrt(len(vocabulary)))
        # Create an embedding layer with the specified dimensions.
        embedding_encoder = layers.Embedding(
            input_dim=len(vocabulary),
            output_dim=embedding_dims,
            name=f"{feature_name}_embedding",
        )
        # Convert the index values to embedding representations.
        encoded_other_features.append(embedding_encoder(inputs[feature_name]))

    ## Create a single embedding vector for the user features
    if len(encoded_other_features) > 1:
        encoded_other_features = layers.concatenate(encoded_other_features)
    elif len(encoded_other_features) == 1:
        encoded_other_features = encoded_other_features[0]
    else:
        encoded_other_features = None

    ## Create a movie embedding encoder
    movie_vocabulary = CATEGORICAL_FEATURES_WITH_VOCABULARY["movie_id"]
    movie_embedding_dims = int(math.sqrt(len(movie_vocabulary)))
    # Create an embedding layer with the specified dimensions.
    movie_embedding_encoder = layers.Embedding(
        input_dim=len(movie_vocabulary),
        output_dim=movie_embedding_dims,
        name="movie_embedding",
    )
    # Create a vector lookup for movie genres.
    genre_vectors = movies[genres].to_numpy()
    movie_genres_lookup = layers.Embedding(
        input_dim=genre_vectors.shape[0],
        output_dim=genre_vectors.shape[1],
        embeddings_initializer=keras.initializers.Constant(genre_vectors),
        trainable=False,
        name="genres_vector",
    )
    # Create a processing layer for genres.
    movie_embedding_processor = layers.Dense(
        units=movie_embedding_dims,
        activation="relu",
        name="process_movie_embedding_with_genres",
    )

    ## Define a function to encode a given movie id.
    def encode_movie(movie_id):
        # Look up the embedding vector for each integer movie index.
        movie_embedding = movie_embedding_encoder(movie_id)
        encoded_movie = movie_embedding
        if include_movie_features:
            movie_genres_vector = movie_genres_lookup(movie_id)
            encoded_movie = movie_embedding_processor(
                layers.concatenate([movie_embedding, movie_genres_vector])
            )
        return encoded_movie

    ## Encoding target_movie_id
    target_movie_id = inputs["target_movie_id"]
    encoded_target_movie = encode_movie(target_movie_id)

    ## Encoding sequence movie_ids.
    sequence_movies_ids = inputs["sequence_movie_ids"]
    encoded_sequence_movies = encode_movie(sequence_movies_ids)
    # Create positional embedding.
    position_embedding_encoder = layers.Embedding(
        input_dim=sequence_length,
        output_dim=movie_embedding_dims,
        name="position_embedding",
    )
    positions = ops.arange(start=0, stop=sequence_length - 1, step=1)
    encoded_positions = position_embedding_encoder(positions)
    # Retrieve sequence ratings to incorporate them into the encoding of the movie.
    sequence_ratings = inputs["sequence_ratings"]
    sequence_ratings = ops.expand_dims(sequence_ratings, -1)
    # Add the positional encoding to the movie encodings and multiply them by rating.
    encoded_sequence_movies_with_position_and_rating = layers.Multiply()(
        [(encoded_sequence_movies + encoded_positions), sequence_ratings]
    )

    # Construct the transformer inputs.
    for i in range(sequence_length - 1):
        feature = encoded_sequence_movies_with_position_and_rating[:, i, ...]
        feature = ops.expand_dims(feature, 1)
        encoded_transformer_features.append(feature)
    encoded_transformer_features.append(encoded_target_movie)
    encoded_transformer_features = layers.concatenate(
        encoded_transformer_features, axis=1
    )
    return encoded_transformer_features, encoded_other_features


"""
## Create model inputs
"""


def create_model_inputs():
    return {
        "user_id": keras.Input(name="user_id", shape=(1,), dtype="int32"),
        "sequence_movie_ids": keras.Input(
            name="sequence_movie_ids", shape=(sequence_length - 1,), dtype="int32"
        ),
        "target_movie_id": keras.Input(
            name="target_movie_id", shape=(1,), dtype="int32"
        ),
        "sequence_ratings": keras.Input(
            name="sequence_ratings", shape=(sequence_length - 1,), dtype="float32"
        ),
        "sex": keras.Input(name="sex", shape=(1,), dtype="int32"),
        "age_group": keras.Input(name="age_group", shape=(1,), dtype="int32"),
        "occupation": keras.Input(name="occupation", shape=(1,), dtype="int32"),
    }


"""
## Create a BST model
"""

include_user_id = False
include_user_features = False
include_movie_features = False

hidden_units = [256, 128]
dropout_rate = 0.1
num_heads = 3


def create_model():

    inputs = create_model_inputs()
    transformer_features, other_features = encode_input_features(
        inputs, include_user_id, include_user_features, include_movie_features
    )
    # Create a multi-headed attention layer.
    attention_output = layers.MultiHeadAttention(
        num_heads=num_heads, key_dim=transformer_features.shape[2], dropout=dropout_rate
    )(transformer_features, transformer_features)

    # Transformer block.
    attention_output = layers.Dropout(dropout_rate)(attention_output)
    x1 = layers.Add()([transformer_features, attention_output])
    x1 = layers.LayerNormalization()(x1)
    x2 = layers.LeakyReLU()(x1)
    x2 = layers.Dense(units=x2.shape[-1])(x2)
    x2 = layers.Dropout(dropout_rate)(x2)
    transformer_features = layers.Add()([x1, x2])
    transformer_features = layers.LayerNormalization()(transformer_features)
    features = layers.Flatten()(transformer_features)

    # Include the other_features.
    if other_features is not None:
        features = layers.concatenate(
            [features, layers.Reshape([other_features.shape[-1]])(other_features)]
        )

    # Fully-connected layers.
    for num_units in hidden_units:
        features = layers.Dense(num_units)(features)
        features = layers.BatchNormalization()(features)
        features = layers.LeakyReLU()(features)
        features = layers.Dropout(dropout_rate)(features)
    outputs = layers.Dense(units=1)(features)
    model = keras.Model(inputs=inputs, outputs=outputs)
    return model


model = create_model()

"""
## Run training and evaluation experiment
"""

# Compile the model.
model.compile(
    optimizer=keras.optimizers.Adagrad(learning_rate=0.01),
    loss=keras.losses.MeanSquaredError(),
    metrics=[keras.metrics.MeanAbsoluteError()],
)

# Read the training data.
train_dataset = get_dataset_from_csv("train_data.csv", batch_size=265, shuffle=True)

# Fit the model with the training data.
model.fit(train_dataset, epochs=2)

# Read the test data.
test_dataset = get_dataset_from_csv("test_data.csv", batch_size=265)

# Evaluate the model on the test data. The second value returned is the MAE metric.
_, mae = model.evaluate(test_dataset, verbose=0)
print(f"Test MAE: {round(mae, 3)}")

"""
You should achieve a Mean Absolute Error (MAE) at or around 0.7 on the test data.
"""

"""
## Conclusion

The BST model uses the Transformer layer in its architecture to capture the sequential signals underlying
users’ behavior sequences for recommendation.

You can try training this model with different configurations, for example, by increasing
the input sequence length and training the model for a larger number of epochs. In addition,
you can try including other features like movie release year and user
zip code, and including cross features like sex X genre.
"""
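
"""
One of the suggestions above is to increase the input sequence length. As a rough,
standalone sketch of what that changes in the data preparation, the snippet below copies
`create_sequences` from this example and applies it to a hypothetical 10-item watch history
(toy ids, an assumption for illustration, not MovieLens data). A longer window tends to
produce fewer, longer sequences per user, and a history shorter than the window produces
none at all, so raising `sequence_length` usually shrinks the number of training sequences.

```python
# Copy of the windowing helper defined earlier in this example, so the sketch runs standalone.
def create_sequences(values, window_size, step_size):
    sequences = []
    start_index = 0
    while True:
        end_index = start_index + window_size
        seq = values[start_index:end_index]
        if len(seq) < window_size:
            # Too few items left for a full window: fall back to the last full window.
            seq = values[-window_size:]
            if len(seq) == window_size:
                sequences.append(seq)
            break
        sequences.append(seq)
        start_index += step_size
    return sequences


# Hypothetical watch history of one user (toy ids, not real MovieLens data).
history = [f"movie_{i}" for i in range(1, 11)]

for window in (4, 8):
    windows = create_sequences(history, window, step_size=2)
    print(f"sequence_length={window}: {len(windows)} sequences")
# sequence_length=4: 5 sequences
# sequence_length=8: 3 sequences

# A history shorter than the window yields no sequences at all.
print(create_sequences(history[:3], 4, 2))  # []
```

When the history length minus the window size is a multiple of `step_size`, the fallback
branch appends the final window a second time, so the last sequence can appear twice;
this mirrors the helper exactly as written in this example.
"""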