"""
2
Title: A Transformer-based recommendation system
3
Author: [Khalid Salama](https://www.linkedin.com/in/khalid-salama-24403144/)
4
Date created: 2020/12/30
5
Last modified: 2025/01/27
6
Description: Rating rate prediction using the Behavior Sequence Transformer (BST) model on the Movielens.
7
Accelerator: GPU
8
Made backend-agnostic by: [Humbulani Ndou](https://github.com/Humbulani1234)
9
"""

"""
## Introduction

This example demonstrates the [Behavior Sequence Transformer (BST)](https://arxiv.org/abs/1905.06874)
model, by Qiwei Chen et al., using the [Movielens dataset](https://grouplens.org/datasets/movielens/).
The BST model leverages the sequential behavior of users in watching and rating movies,
as well as user profile and movie features, to predict the user's rating of a target movie.

More precisely, the BST model aims to predict the rating of a target movie by accepting
the following inputs:

1. A fixed-length *sequence* of `movie_ids` watched by a user.
2. A fixed-length *sequence* of the `ratings` for the movies watched by a user.
3. A *set* of user features, including `user_id`, `sex`, `occupation`, and `age_group`.
4. A *set* of `genres` for each movie in the input sequence and the target movie.
5. A `target_movie_id` for which to predict the rating.

This example modifies the original BST model in the following ways:

1. We incorporate the movie features (genres) into the processing of the embedding of each
movie of the input sequence and the target movie, rather than treating them as "other features"
outside the transformer layer.
2. We utilize the ratings of movies in the input sequence, along with their positions
in the sequence, to update them before feeding them into the self-attention layer.

Note that this example should be run with TensorFlow 2.4 or higher; even with a JAX or
PyTorch backend, TensorFlow is still required for the `tf.data` input pipeline.
"""

"""
## The dataset

We use the [1M version of the Movielens dataset](https://grouplens.org/datasets/movielens/1m/).
The dataset includes around 1 million ratings from 6000 users on 4000 movies,
along with some user features and movie genres. In addition, the timestamp of each user-movie
rating is provided, which allows creating sequences of movie ratings for each user,
as expected by the BST model.
"""

"""
## Setup
"""

import os

os.environ["KERAS_BACKEND"] = "jax"  # or torch, or tensorflow

import math
from zipfile import ZipFile
from urllib.request import urlretrieve
import numpy as np
import pandas as pd

import keras
from keras import layers, ops
from keras.layers import StringLookup
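
"""
As a quick sanity check (illustrative only, not part of the original example), we can
confirm which backend Keras picked up from the environment variable set above.
"""

# Illustrative only: print the active Keras backend.
print("Keras backend:", keras.backend.backend())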

"""
## Prepare the data

### Download and prepare the DataFrames

First, let's download the movielens data.

The downloaded folder will contain three data files: `users.dat`, `movies.dat`,
and `ratings.dat`.
"""

urlretrieve("http://files.grouplens.org/datasets/movielens/ml-1m.zip", "movielens.zip")
ZipFile("movielens.zip", "r").extractall()

"""
Then, we load the data into pandas DataFrames with their proper column names.
"""

users = pd.read_csv(
    "ml-1m/users.dat",
    sep="::",
    names=["user_id", "sex", "age_group", "occupation", "zip_code"],
    encoding="ISO-8859-1",
    engine="python",
)

ratings = pd.read_csv(
    "ml-1m/ratings.dat",
    sep="::",
    names=["user_id", "movie_id", "rating", "unix_timestamp"],
    encoding="ISO-8859-1",
    engine="python",
)

movies = pd.read_csv(
    "ml-1m/movies.dat",
    sep="::",
    names=["movie_id", "title", "genres"],
    encoding="ISO-8859-1",
    engine="python",
)

"""
Here, we do some simple data processing to fix the data types of the columns.
"""

users["user_id"] = users["user_id"].apply(lambda x: f"user_{x}")
users["age_group"] = users["age_group"].apply(lambda x: f"group_{x}")
users["occupation"] = users["occupation"].apply(lambda x: f"occupation_{x}")

movies["movie_id"] = movies["movie_id"].apply(lambda x: f"movie_{x}")

ratings["movie_id"] = ratings["movie_id"].apply(lambda x: f"movie_{x}")
ratings["user_id"] = ratings["user_id"].apply(lambda x: f"user_{x}")
ratings["rating"] = ratings["rating"].apply(lambda x: float(x))

"""
Each movie has multiple genres. We split them into separate columns in the `movies`
DataFrame.
"""

genres = ["Action", "Adventure", "Animation", "Children's", "Comedy", "Crime"]
genres += ["Documentary", "Drama", "Fantasy", "Film-Noir", "Horror", "Musical"]
genres += ["Mystery", "Romance", "Sci-Fi", "Thriller", "War", "Western"]

for genre in genres:
    movies[genre] = movies["genres"].apply(
        lambda values: int(genre in values.split("|"))
    )
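
"""
To make the multi-hot encoding concrete, here is a quick spot-check (illustrative only,
not part of the original example): the first movie in `movies.dat` is Toy Story, whose
genres string is `Animation|Children's|Comedy`, so those three columns should be 1.
"""

# Illustrative only: inspect the genre columns of the first movie.
print(movies.loc[0, ["title", "genres", "Animation", "Children's", "Comedy"]])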

"""
### Transform the movie ratings data into sequences

First, let's sort the ratings data using the `unix_timestamp`, and then group the
`movie_id` values and the `rating` values by `user_id`.

The output DataFrame will have a record for each `user_id`, with two ordered lists
(sorted by rating datetime): the movies they have rated, and their ratings of these movies.
"""

ratings_group = ratings.sort_values(by=["unix_timestamp"]).groupby("user_id")

ratings_data = pd.DataFrame(
    data={
        "user_id": list(ratings_group.groups.keys()),
        "movie_ids": list(ratings_group.movie_id.apply(list)),
        "ratings": list(ratings_group.rating.apply(list)),
        "timestamps": list(ratings_group.unix_timestamp.apply(list)),
    }
)

"""
Now, let's split the `movie_ids` list into a set of sequences of a fixed length.
We do the same for the `ratings`. Set the `sequence_length` variable to change the length
of the input sequence to the model. You can also change the `step_size` to control the
number of sequences to generate for each user.
"""

sequence_length = 4
step_size = 2


def create_sequences(values, window_size, step_size):
    sequences = []
    start_index = 0
    while True:
        end_index = start_index + window_size
        seq = values[start_index:end_index]
        if len(seq) < window_size:
            seq = values[-window_size:]
            if len(seq) == window_size:
                sequences.append(seq)
            break
        sequences.append(seq)
        start_index += step_size
    return sequences
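

"""
To see what this windowing produces, here is a tiny illustration on a toy list
(illustrative only, not part of the original example). Note that when the tail of the
list is shorter than the window, the function falls back to the last `window_size`
values, so the final window can overlap the previous one.
"""

# Illustrative only: window a toy list with the same parameters used below.
print(create_sequences(list(range(1, 9)), window_size=4, step_size=2))
# [[1, 2, 3, 4], [3, 4, 5, 6], [5, 6, 7, 8], [5, 6, 7, 8]]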


ratings_data.movie_ids = ratings_data.movie_ids.apply(
    lambda ids: create_sequences(ids, sequence_length, step_size)
)

ratings_data.ratings = ratings_data.ratings.apply(
    lambda ids: create_sequences(ids, sequence_length, step_size)
)

del ratings_data["timestamps"]
"""
199
After that, we process the output to have each sequence in a separate records in
200
the DataFrame. In addition, we join the user features with the ratings data.
201
"""

ratings_data_movies = ratings_data[["user_id", "movie_ids"]].explode(
    "movie_ids", ignore_index=True
)
ratings_data_rating = ratings_data[["ratings"]].explode("ratings", ignore_index=True)
ratings_data_transformed = pd.concat([ratings_data_movies, ratings_data_rating], axis=1)
ratings_data_transformed = ratings_data_transformed.join(
    users.set_index("user_id"), on="user_id"
)
ratings_data_transformed.movie_ids = ratings_data_transformed.movie_ids.apply(
    lambda x: ",".join(x)
)
ratings_data_transformed.ratings = ratings_data_transformed.ratings.apply(
    lambda x: ",".join([str(v) for v in x])
)

del ratings_data_transformed["zip_code"]

ratings_data_transformed.rename(
    columns={"movie_ids": "sequence_movie_ids", "ratings": "sequence_ratings"},
    inplace=True,
)
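
"""
We can check how many sequences were generated (illustrative only, not part of the
original example); with the default settings this should match the count quoted below.
"""

# Illustrative only: number of sequence records after the transformation.
print(f"Number of sequences: {len(ratings_data_transformed.index)}")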

"""
With `sequence_length` of 4 and `step_size` of 2, we end up with 498,623 sequences.

Finally, we split the data into training and testing splits, with 85% and 15% of
the instances, respectively, and store them to CSV files.
"""

random_selection = np.random.rand(len(ratings_data_transformed.index)) <= 0.85
train_data = ratings_data_transformed[random_selection]
test_data = ratings_data_transformed[~random_selection]

train_data.to_csv("train_data.csv", index=False, sep="|", header=False)
test_data.to_csv("test_data.csv", index=False, sep="|", header=False)

"""
## Define metadata
"""

CSV_HEADER = list(ratings_data_transformed.columns)

CATEGORICAL_FEATURES_WITH_VOCABULARY = {
    "user_id": list(users.user_id.unique()),
    "movie_id": list(movies.movie_id.unique()),
    "sex": list(users.sex.unique()),
    "age_group": list(users.age_group.unique()),
    "occupation": list(users.occupation.unique()),
}

USER_FEATURES = ["sex", "age_group", "occupation"]

MOVIE_FEATURES = ["genres"]

"""
## Encode input features

The `encode_input_features` function works as follows:

1. Each categorical user feature is encoded using `layers.Embedding`, with embedding
dimension equal to the square root of the vocabulary size of the feature.
The embeddings of these features are concatenated to form a single input tensor.

2. Each movie in the movie sequence and the target movie is encoded using `layers.Embedding`,
where the dimension size is the square root of the number of movies.

3. A multi-hot genres vector for each movie is concatenated with its embedding vector,
and processed using a non-linear `layers.Dense` to output a vector of the same movie
embedding dimension.

4. A positional embedding is added to each movie embedding in the sequence, and then
multiplied by its rating from the ratings sequence.

5. The target movie embedding is concatenated to the sequence movie embeddings, producing
a tensor with the shape of `[batch size, sequence length, embedding size]`, as expected
by the attention layer for the transformer architecture.

6. The function returns a tuple of two elements: `encoded_transformer_features` and
`encoded_other_features`.
"""

# Required for tf.data.Dataset
import tensorflow as tf


def get_dataset_from_csv(csv_file_path, batch_size, shuffle=True):

    def process(features):
        movie_ids_string = features["sequence_movie_ids"]
        sequence_movie_ids = tf.strings.split(movie_ids_string, ",").to_tensor()
        # The last movie id in the sequence is the target movie.
        features["target_movie_id"] = sequence_movie_ids[:, -1]
        features["sequence_movie_ids"] = sequence_movie_ids[:, :-1]
        # Sequence ratings
        ratings_string = features["sequence_ratings"]
        sequence_ratings = tf.strings.to_number(
            tf.strings.split(ratings_string, ","), tf.dtypes.float32
        ).to_tensor()
        # The last rating in the sequence is the target for the model to predict.
        target = sequence_ratings[:, -1]
        features["sequence_ratings"] = sequence_ratings[:, :-1]

        def encoding_helper(feature_name):
            # These are target_movie_id and sequence_movie_ids, which share the
            # vocabulary of movie_id.
            if feature_name not in CATEGORICAL_FEATURES_WITH_VOCABULARY:
                vocabulary = CATEGORICAL_FEATURES_WITH_VOCABULARY["movie_id"]
                index_lookup = StringLookup(
                    vocabulary=vocabulary, mask_token=None, num_oov_indices=0
                )
                # Convert the string input values into integer indices.
                value_index = index_lookup(features[feature_name])
                features[feature_name] = value_index
            else:
                # movie_id itself is not part of the features, so it is not processed
                # here; it was mainly needed above for its vocabulary.
                if feature_name == "movie_id":
                    pass
                else:
                    vocabulary = CATEGORICAL_FEATURES_WITH_VOCABULARY[feature_name]
                    index_lookup = StringLookup(
                        vocabulary=vocabulary, mask_token=None, num_oov_indices=0
                    )
                    # Convert the string input values into integer indices.
                    value_index = index_lookup(features[feature_name])
                    features[feature_name] = value_index

        # Encode the user features.
        for feature_name in CATEGORICAL_FEATURES_WITH_VOCABULARY:
            encoding_helper(feature_name)
        # Encode target_movie_id.
        encoding_helper("target_movie_id")
        # Encode the sequence movie_ids.
        encoding_helper("sequence_movie_ids")
        return dict(features), target

    dataset = tf.data.experimental.make_csv_dataset(
        csv_file_path,
        batch_size=batch_size,
        column_names=CSV_HEADER,
        num_epochs=1,
        header=False,
        field_delim="|",
        shuffle=shuffle,
    ).map(process)
    return dataset
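

"""
To sanity-check the pipeline (illustrative only, not part of the original example),
we can pull a single small batch and look at the tensor shapes it produces.
"""

# Illustrative only: peek at one batch from the training CSV.
for features, target in get_dataset_from_csv("train_data.csv", batch_size=2).take(1):
    print({name: tensor.shape for name, tensor in features.items()})
    print("target shape:", target.shape)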


def encode_input_features(
    inputs,
    include_user_id,
    include_user_features,
    include_movie_features,
):
    encoded_transformer_features = []
    encoded_other_features = []

    other_feature_names = []
    if include_user_id:
        other_feature_names.append("user_id")
    if include_user_features:
        other_feature_names.extend(USER_FEATURES)

    ## Encode user features
    for feature_name in other_feature_names:
        vocabulary = CATEGORICAL_FEATURES_WITH_VOCABULARY[feature_name]
        # Compute embedding dimensions
        embedding_dims = int(math.sqrt(len(vocabulary)))
        # Create an embedding layer with the specified dimensions.
        embedding_encoder = layers.Embedding(
            input_dim=len(vocabulary),
            output_dim=embedding_dims,
            name=f"{feature_name}_embedding",
        )
        # Convert the index values to embedding representations.
        encoded_other_features.append(embedding_encoder(inputs[feature_name]))

    ## Create a single embedding vector for the user features
    if len(encoded_other_features) > 1:
        encoded_other_features = layers.concatenate(encoded_other_features)
    elif len(encoded_other_features) == 1:
        encoded_other_features = encoded_other_features[0]
    else:
        encoded_other_features = None

    ## Create a movie embedding encoder
    movie_vocabulary = CATEGORICAL_FEATURES_WITH_VOCABULARY["movie_id"]
    movie_embedding_dims = int(math.sqrt(len(movie_vocabulary)))
    # Create an embedding layer with the specified dimensions.
    movie_embedding_encoder = layers.Embedding(
        input_dim=len(movie_vocabulary),
        output_dim=movie_embedding_dims,
        name="movie_embedding",
    )
    # Create a vector lookup for movie genres.
    genre_vectors = movies[genres].to_numpy()
    movie_genres_lookup = layers.Embedding(
        input_dim=genre_vectors.shape[0],
        output_dim=genre_vectors.shape[1],
        embeddings_initializer=keras.initializers.Constant(genre_vectors),
        trainable=False,
        name="genres_vector",
    )
    # Create a processing layer for genres.
    movie_embedding_processor = layers.Dense(
        units=movie_embedding_dims,
        activation="relu",
        name="process_movie_embedding_with_genres",
    )

    ## Define a function to encode a given movie id.
    def encode_movie(movie_id):
        # Convert the movie id into an embedding vector.
        movie_embedding = movie_embedding_encoder(movie_id)
        encoded_movie = movie_embedding
        if include_movie_features:
            movie_genres_vector = movie_genres_lookup(movie_id)
            encoded_movie = movie_embedding_processor(
                layers.concatenate([movie_embedding, movie_genres_vector])
            )
        return encoded_movie

    ## Encoding target_movie_id
    target_movie_id = inputs["target_movie_id"]
    encoded_target_movie = encode_movie(target_movie_id)

    ## Encoding sequence movie_ids.
    sequence_movies_ids = inputs["sequence_movie_ids"]
    encoded_sequence_movies = encode_movie(sequence_movies_ids)
    # Create positional embedding.
    position_embedding_encoder = layers.Embedding(
        input_dim=sequence_length,
        output_dim=movie_embedding_dims,
        name="position_embedding",
    )
    positions = ops.arange(start=0, stop=sequence_length - 1, step=1)
    encoded_positions = position_embedding_encoder(positions)
    # Retrieve sequence ratings to incorporate them into the encoding of the movie.
    sequence_ratings = inputs["sequence_ratings"]
    sequence_ratings = ops.expand_dims(sequence_ratings, -1)
    # Add the positional encoding to the movie encodings and multiply them by rating.
    encoded_sequence_movies_with_position_and_rating = layers.Multiply()(
        [(encoded_sequence_movies + encoded_positions), sequence_ratings]
    )

    # Construct the transformer inputs.
    for i in range(sequence_length - 1):
        feature = encoded_sequence_movies_with_position_and_rating[:, i, ...]
        feature = ops.expand_dims(feature, 1)
        encoded_transformer_features.append(feature)
    encoded_transformer_features.append(encoded_target_movie)
    encoded_transformer_features = layers.concatenate(
        encoded_transformer_features, axis=1
    )
    return encoded_transformer_features, encoded_other_features


"""
## Create model inputs
"""


def create_model_inputs():
    return {
        "user_id": keras.Input(name="user_id", shape=(1,), dtype="int32"),
        "sequence_movie_ids": keras.Input(
            name="sequence_movie_ids", shape=(sequence_length - 1,), dtype="int32"
        ),
        "target_movie_id": keras.Input(
            name="target_movie_id", shape=(1,), dtype="int32"
        ),
        "sequence_ratings": keras.Input(
            name="sequence_ratings", shape=(sequence_length - 1,), dtype="float32"
        ),
        "sex": keras.Input(name="sex", shape=(1,), dtype="int32"),
        "age_group": keras.Input(name="age_group", shape=(1,), dtype="int32"),
        "occupation": keras.Input(name="occupation", shape=(1,), dtype="int32"),
    }


"""
## Create a BST model
"""

include_user_id = False
include_user_features = False
include_movie_features = False

hidden_units = [256, 128]
dropout_rate = 0.1
num_heads = 3


def create_model():

    inputs = create_model_inputs()
    transformer_features, other_features = encode_input_features(
        inputs, include_user_id, include_user_features, include_movie_features
    )
    # Create a multi-headed attention layer.
    attention_output = layers.MultiHeadAttention(
        num_heads=num_heads, key_dim=transformer_features.shape[2], dropout=dropout_rate
    )(transformer_features, transformer_features)

    # Transformer block.
    attention_output = layers.Dropout(dropout_rate)(attention_output)
    x1 = layers.Add()([transformer_features, attention_output])
    x1 = layers.LayerNormalization()(x1)
    x2 = layers.LeakyReLU()(x1)
    x2 = layers.Dense(units=x2.shape[-1])(x2)
    x2 = layers.Dropout(dropout_rate)(x2)
    transformer_features = layers.Add()([x1, x2])
    transformer_features = layers.LayerNormalization()(transformer_features)
    features = layers.Flatten()(transformer_features)

    # Include the other features, if any.
    if other_features is not None:
        features = layers.concatenate(
            [features, layers.Reshape([other_features.shape[-1]])(other_features)]
        )

    # Fully-connected layers.
    for num_units in hidden_units:
        features = layers.Dense(num_units)(features)
        features = layers.BatchNormalization()(features)
        features = layers.LeakyReLU()(features)
        features = layers.Dropout(dropout_rate)(features)
    outputs = layers.Dense(units=1)(features)
    model = keras.Model(inputs=inputs, outputs=outputs)
    return model


model = create_model()
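
"""
Optionally (illustrative only, not part of the original example), we can print the model
summary to inspect the layers and parameter counts before training.
"""

# Illustrative only: inspect the model architecture.
model.summary()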

"""
## Run training and evaluation experiment
"""

# Compile the model.
model.compile(
    optimizer=keras.optimizers.Adagrad(learning_rate=0.01),
    loss=keras.losses.MeanSquaredError(),
    metrics=[keras.metrics.MeanAbsoluteError()],
)

# Read the training data.
train_dataset = get_dataset_from_csv("train_data.csv", batch_size=265, shuffle=True)

# Fit the model with the training data.
model.fit(train_dataset, epochs=2)

# Read the test data.
test_dataset = get_dataset_from_csv("test_data.csv", batch_size=265)

# Evaluate the model on the test data.
_, mae = model.evaluate(test_dataset, verbose=0)
print(f"Test MAE: {round(mae, 3)}")

"""
You should achieve a Mean Absolute Error (MAE) at or around 0.7 on the test data.
"""

"""
## Conclusion

The BST model uses the Transformer layer in its architecture to capture the sequential signals underlying
users' behavior sequences for recommendation.

You can try training this model with different configurations, for example, by increasing
the input sequence length and training the model for a larger number of epochs. In addition,
you can try including other features like the movie release year and the customer
zip code, and including cross features like sex X genre.
"""