GitHub Repository: keras-team/keras-io
Path: blob/master/examples/keras_rs/sas_rec.py
"""
Title: Sequential retrieval using SASRec
Author: [Abheesht Sharma](https://github.com/abheesht17/), [Fabien Hertschuh](https://github.com/hertschuh/)
Date created: 2025/04/28
Last modified: 2025/04/28
Description: Recommend movies using a Transformer-based retrieval model (SASRec).
Accelerator: GPU
"""

"""
## Introduction

Sequential recommendation is a popular task: a model looks at a sequence of
items that a user has interacted with previously, and predicts the next item.
Here, the order of the items within each sequence matters. Previously, in the
[Recommending movies: retrieval using a sequential model](/keras_rs/examples/sequential_retrieval/)
example, we built a GRU-based sequential retrieval model. In this example, we
will build a popular Transformer decoder-based model named
[Self-Attentive Sequential Recommendation (SASRec)](https://arxiv.org/abs/1808.09781)
for the same sequential recommendation task.

Let's begin by importing all the necessary libraries.
"""

"""shell
pip install -q keras-rs
"""

import os

os.environ["KERAS_BACKEND"] = "jax"  # `"tensorflow"`/`"torch"`

import collections

import keras
import keras_hub
import numpy as np
import pandas as pd
import tensorflow as tf  # Needed only for the dataset
from keras import ops

import keras_rs

"""
Let's also define all the important variables/hyperparameters below.
"""

DATA_DIR = "./raw/data/"

# MovieLens-specific variables
MOVIELENS_1M_URL = "https://files.grouplens.org/datasets/movielens/ml-1m.zip"
MOVIELENS_ZIP_HASH = "a6898adb50b9ca05aa231689da44c217cb524e7ebd39d264c56e2832f2c54e20"

RATINGS_FILE_NAME = "ratings.dat"
MOVIES_FILE_NAME = "movies.dat"

# Data processing args
MAX_CONTEXT_LENGTH = 200
MIN_SEQUENCE_LENGTH = 3
PAD_ITEM_ID = 0

RATINGS_DATA_COLUMNS = ["UserID", "MovieID", "Rating", "Timestamp"]
MOVIES_DATA_COLUMNS = ["MovieID", "Title", "Genres"]
MIN_RATING = 2

# Training/model args picked from SASRec paper
BATCH_SIZE = 128
NUM_EPOCHS = 10
LEARNING_RATE = 0.001

NUM_LAYERS = 2
NUM_HEADS = 1
HIDDEN_DIM = 50
DROPOUT = 0.2
"""
78
## Dataset
79
80
Next, we need to prepare our dataset. Like we did in the
81
[sequential retrieval](/keras_rs/examples/sequential_retrieval/)
82
example, we are going to use the MovieLens dataset.
83
84
The dataset preparation step is fairly involved. The original ratings dataset
85
contains `(user, movie ID, rating, timestamp)` tuples (among other columns,
86
which are not important for this example). Since we are dealing with sequential
87
retrieval, we need to create movie sequences for every user, where the sequences
88
are ordered by timestamp.
89
90
Let's start by downloading and reading the dataset.
91
"""

# Download the MovieLens dataset.
if not os.path.exists(DATA_DIR):
    os.makedirs(DATA_DIR)

path_to_zip = keras.utils.get_file(
    fname="ml-1m.zip",
    origin=MOVIELENS_1M_URL,
    file_hash=MOVIELENS_ZIP_HASH,
    hash_algorithm="sha256",
    extract=True,
    cache_dir=DATA_DIR,
)
movielens_extracted_dir = os.path.join(
    os.path.dirname(path_to_zip),
    "ml-1m_extracted",
    "ml-1m",
)


# Read the dataset.
def read_data(data_directory, min_rating=None):
    """Read the MovieLens ratings.dat and movies.dat files into dataframes."""

    ratings_df = pd.read_csv(
        os.path.join(data_directory, RATINGS_FILE_NAME),
        sep="::",
        names=RATINGS_DATA_COLUMNS,
        encoding="unicode_escape",
    )
    ratings_df["Timestamp"] = ratings_df["Timestamp"].apply(int)

    # Remove interactions with `rating < min_rating`.
    if min_rating is not None:
        ratings_df = ratings_df[ratings_df["Rating"] >= min_rating]

    movies_df = pd.read_csv(
        os.path.join(data_directory, MOVIES_FILE_NAME),
        sep="::",
        names=MOVIES_DATA_COLUMNS,
        encoding="unicode_escape",
    )
    return ratings_df, movies_df


ratings_df, movies_df = read_data(
    data_directory=movielens_extracted_dir, min_rating=MIN_RATING
)

# We need the number of movies to define the embedding layers.
movies_count = movies_df["MovieID"].max()

"""
Now that we have read the dataset, let's create sequences of movies
for every user. Here is the function for doing just that.
"""


def get_movie_sequence_per_user(ratings_df):
    """Get movieID sequences for every user."""
    sequences = collections.defaultdict(list)

    for user_id, movie_id, rating, timestamp in ratings_df.values:
        sequences[user_id].append(
            {
                "movie_id": movie_id,
                "timestamp": timestamp,
                "rating": rating,
            }
        )

    # Sort movie sequences by timestamp for every user.
    for user_id, context in sequences.items():
        context.sort(key=lambda x: x["timestamp"])
        sequences[user_id] = context

    return sequences


sequences = get_movie_sequence_per_user(ratings_df)

"""
So far, we have essentially replicated what we did in the sequential retrieval
example. We have a sequence of movies for every user.

SASRec is trained contrastively, which means the model learns to distinguish
between sequences of movies a user has actually interacted with (positive
examples) and sequences they have not interacted with (negative examples).

The following function, `format_data`, prepares the data in this specific
format. For each user's movie sequence, it generates a corresponding
"negative sequence". This negative sequence consists of randomly
selected movies that the user has *not* interacted with, and has the same
length as the original sequence.
"""


def format_data(sequences):
    examples = {
        "sequence": [],
        "negative_sequence": [],
    }

    for user_id in sequences:
        sequence = [int(d["movie_id"]) for d in sequences[user_id]]

        # Get negative sequence.
        def random_negative_item_id(low, high, positive_lst):
            sampled = np.random.randint(low=low, high=high)
            while sampled in positive_lst:
                sampled = np.random.randint(low=low, high=high)
            return sampled

        negative_sequence = [
            random_negative_item_id(1, movies_count + 1, sequence)
            for _ in range(len(sequence))
        ]

        examples["sequence"].append(np.array(sequence))
        examples["negative_sequence"].append(np.array(negative_sequence))

    examples["sequence"] = tf.ragged.constant(examples["sequence"])
    examples["negative_sequence"] = tf.ragged.constant(examples["negative_sequence"])

    return examples


examples = format_data(sequences)
ds = tf.data.Dataset.from_tensor_slices(examples).batch(BATCH_SIZE)

"""
225
Now that we have the original movie interaction sequences for each user (from
226
`format_data`, stored in `examples["sequence"]`) and their corresponding
227
random negative sequences (in `examples["negative_sequence"]`), the next step is
228
to prepare this data for input to the model. The primary goals of this
229
preprocessing are:
230
231
1. Creating Input Features and Target Labels: For sequential
232
recommendation, the model learns to predict the next item in a sequence
233
given the preceding items. This is achieved by:
234
- taking the original `example["sequence"]` and creating the model's
235
input features (`item_ids`) from all items *except the last one*
236
(`example["sequence"][..., :-1]`);
237
- creating the target "positive sequence" (what the model tries to predict
238
as the actual next items) by taking the original `example["sequence"]`
239
and shifting it, using all items *except the first one*
240
(`example["sequence"][..., 1:]`);
241
- shifting `example["negative_sequence"]` (from `format_data`) is
242
to create the target "negative sequence" for the contrastive loss
243
(`example["negative_sequence"][..., 1:]`).
244
245
2. Handling Variable Length Sequences: Neural networks typically require
246
fixed-size inputs. Therefore, both the input feature sequences and the
247
target sequences are padded (with a special `PAD_ITEM_ID`) or truncated
248
to a predefined `MAX_CONTEXT_LENGTH`. A `padding_mask` is also generated
249
from the input features to ensure the model ignores these padded tokens
250
during attention calculations, i.e, these tokens will be masked.
251
252
3. Differentiating Training and Validation/Testing:
253
- During training:
254
- Input features (`item_ids`) and context for negative sequences
255
are prepared as described above (all but the last item of the
256
original sequences).
257
- Target positive and negative sequences are the shifted versions of
258
the original sequences.
259
- `sample_weight` is created based on the input features to ensure
260
that loss is calculated only on actual items, not on padding tokens
261
in the targets.
262
- During validation/testing:
263
- Input features are prepared similarly.
264
- The model's performance is typically evaluated on its ability to
265
predict the actual last item of the original sequence. Thus,
266
`sample_weight` is configured to focus the loss calculation
267
only on this final prediction in the target sequences.
268
269
Note: SASRec does the same thing we've done above, except that they take the
270
`item_ids[:-2]` for the validation set and `item_ids[:-1]` for the test set.
271
We skip that here for brevity.
272
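
To make this concrete, here is a toy illustration of the shift described in
(1) and the weighting described in (3), using a hypothetical
`MAX_CONTEXT_LENGTH = 4` (not the value defined above) for readability:

```
sequence          = [3, 8, 15, 7]     # user watched 3, then 8, then 15, then 7

# Training: drop the last item, then pad to MAX_CONTEXT_LENGTH + 1 = 5.
sequence          = [3, 8, 15, 0, 0]
item_ids          = [3, 8, 15, 0]     # sequence[..., :-1], the model inputs
positive_sequence = [8, 15, 0, 0]     # sequence[..., 1:], the next-item targets
sample_weight     = [1, 1, 0, 0]      # loss only where the target is a real item

# Validation: keep the full sequence; the weight selects only the final real
# transition (predicting 7 after [3, 8, 15]).
item_ids          = [3, 8, 15, 7]
positive_sequence = [8, 15, 7, 0]
sample_weight     = [0, 0, 1, 0]
```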
"""


def _preprocess(example, train=False):
    sequence = example["sequence"]
    negative_sequence = example["negative_sequence"]

    if train:
        sequence = example["sequence"][..., :-1]
        negative_sequence = example["negative_sequence"][..., :-1]

    batch_size = tf.shape(sequence)[0]

    if not train:
        # Loss computed only on last token.
        sample_weight = tf.zeros_like(sequence, dtype="float32")[..., :-1]
        sample_weight = tf.concat(
            [sample_weight, tf.ones((batch_size, 1), dtype="float32")], axis=1
        )

    # Truncate/pad sequence. +1 to account for truncation later.
    sequence = sequence.to_tensor(
        shape=[batch_size, MAX_CONTEXT_LENGTH + 1], default_value=PAD_ITEM_ID
    )
    negative_sequence = negative_sequence.to_tensor(
        shape=[batch_size, MAX_CONTEXT_LENGTH + 1], default_value=PAD_ITEM_ID
    )
    if train:
        sample_weight = tf.cast(sequence != PAD_ITEM_ID, dtype="float32")
    else:
        sample_weight = sample_weight.to_tensor(
            shape=[batch_size, MAX_CONTEXT_LENGTH + 1], default_value=0
        )

    example = (
        {
            # last token does not have a next token
            "item_ids": sequence[..., :-1],
            # padding mask for controlling attention mask
            "padding_mask": (sequence != PAD_ITEM_ID)[..., :-1],
        },
        {
            "positive_sequence": sequence[
                ..., 1:
            ],  # 0th token's label will be 1st token, and so on
            "negative_sequence": negative_sequence[..., 1:],
        },
        sample_weight[..., 1:],  # loss will not be computed on pad tokens
    )
    return example


def preprocess_train(examples):
    return _preprocess(examples, train=True)


def preprocess_val(examples):
    return _preprocess(examples, train=False)


train_ds = ds.map(preprocess_train)
val_ds = ds.map(preprocess_val)

"""
336
We can see a batch for each.
337
"""
338
339
for batch in train_ds.take(1):
340
print(batch)
341
342
for batch in val_ds.take(1):
343
print(batch)
344
345
346
"""
347
## Model
348
349
To encode the input sequence, we use a Transformer decoder-based model. This
350
part of the model is very similar to the GPT-2 architecture. Refer to the
351
[GPT text generation from scratch with KerasHub](/examples/generative/text_generation_gpt/#build-the-model)
352
guide for more details on this part.
353
354
One part to note is that when we are "predicting", i.e., `training` is `False`,
355
we get the embedding corresponding to the last movie in the sequence. This makes
356
sense, because at inference time, we want to predict the movie the user will
357
likely watch after watching the last movie.
358
359
Also, it's worth discussing the `compute_loss` method. We embed the positive
360
and negative sequences using the input embedding matrix. We compute the
361
similarity of (positive sequence, input sequence) and (negative sequence,
362
input sequence) pair embeddings by computing the dot product. The goal now is
363
to maximize the similarity of the former and minimize the similarity of
364
the latter. Let's see this mathematically. Binary Cross Entropy is written
365
as follows:
366
367
```
368
loss = - (y_true * np.log(y_pred) + (1 - y_true) * np.log(1 - y_pred))
369
```
370
371
Here, we assign the positive pairs a label of 1 and the negative pairs a label
372
of 0. So, for a positive pair, the loss reduces to:
373
374
```
375
loss = -np.log(positive_logits)
376
```
377
378
Minimising the loss means we want to maximize the log term, which in turn,
379
implies maximising `positive_logits`. Similarly, we want to minimize
380
`negative_logits`.
381
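
As a quick numeric sanity check (a standalone sketch, not part of the model
code below):

```
import numpy as np

def sigmoid(x):
    return 1 / (1 + np.exp(-x))

positive_logits, negative_logits = 3.0, -2.0
loss = -np.log(sigmoid(positive_logits)) - np.log(1 - sigmoid(negative_logits))
# ~= 0.049 + 0.127. The loss is small when positive pairs score high and
# negative pairs score low, and grows quickly when either sign flips.
```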
"""


class SasRec(keras.Model):
    def __init__(
        self,
        vocabulary_size,
        num_layers,
        num_heads,
        hidden_dim,
        dropout=0.0,
        max_sequence_length=100,
        dtype=None,
        **kwargs,
    ):
        super().__init__(dtype=dtype, **kwargs)

        # ======== Layers ========

        # === Embeddings ===
        self.item_embedding = keras_hub.layers.ReversibleEmbedding(
            input_dim=vocabulary_size,
            output_dim=hidden_dim,
            embeddings_initializer="glorot_uniform",
            embeddings_regularizer=keras.regularizers.l2(0.001),
            dtype=dtype,
            name="item_embedding",
        )
        self.position_embedding = keras_hub.layers.PositionEmbedding(
            initializer="glorot_uniform",
            sequence_length=max_sequence_length,
            dtype=dtype,
            name="position_embedding",
        )
        self.embeddings_add = keras.layers.Add(
            dtype=dtype,
            name="embeddings_add",
        )
        self.embeddings_dropout = keras.layers.Dropout(
            dropout,
            dtype=dtype,
            name="embeddings_dropout",
        )

        # === Decoder layers ===
        self.transformer_layers = []
        for i in range(num_layers):
            self.transformer_layers.append(
                keras_hub.layers.TransformerDecoder(
                    intermediate_dim=hidden_dim,
                    num_heads=num_heads,
                    dropout=dropout,
                    layer_norm_epsilon=1e-05,
                    # SASRec uses ReLU, although GeLU might be a better option
                    activation="relu",
                    kernel_initializer="glorot_uniform",
                    normalize_first=True,
                    dtype=dtype,
                    name=f"transformer_layer_{i}",
                )
            )

        # === Final layer norm ===
        self.layer_norm = keras.layers.LayerNormalization(
            axis=-1,
            epsilon=1e-8,
            dtype=dtype,
            name="layer_norm",
        )

        # === Retrieval ===
        # The layer that performs the retrieval.
        self.retrieval = keras_rs.layers.BruteForceRetrieval(k=10, return_scores=False)

        # === Loss ===
        self.loss_fn = keras.losses.BinaryCrossentropy(from_logits=True, reduction=None)

        # === Attributes ===
        self.vocabulary_size = vocabulary_size
        self.num_layers = num_layers
        self.num_heads = num_heads
        self.hidden_dim = hidden_dim
        self.dropout = dropout
        self.max_sequence_length = max_sequence_length

    def _get_last_non_padding_token(self, tensor, padding_mask):
        # `padding_mask` is True for real tokens and False for padding, so the
        # number of valid tokens is simply the sum of the mask.
        seq_lengths = ops.sum(ops.cast(padding_mask, "int32"), axis=1)
        last_token_indices = ops.maximum(seq_lengths - 1, 0)

        # Gather the embedding at the last non-padding position of each row.
        indices = ops.expand_dims(last_token_indices, axis=(-2, -1))
        gathered_tokens = ops.take_along_axis(tensor, indices, axis=1)
        last_token_embedding = ops.squeeze(gathered_tokens, axis=1)

        return last_token_embedding

    def build(self, input_shape):
        embedding_shape = list(input_shape) + [self.hidden_dim]

        # Model
        self.item_embedding.build(input_shape)
        self.position_embedding.build(embedding_shape)

        self.embeddings_add.build((embedding_shape, embedding_shape))
        self.embeddings_dropout.build(embedding_shape)

        for transformer_layer in self.transformer_layers:
            transformer_layer.build(decoder_sequence_shape=embedding_shape)

        self.layer_norm.build(embedding_shape)

        # Retrieval
        self.retrieval.candidate_embeddings = self.item_embedding.embeddings
        self.retrieval.build(input_shape)

        # Chain to super
        super().build(input_shape)

    def call(self, inputs, training=False):
        item_ids, padding_mask = inputs["item_ids"], inputs["padding_mask"]

        x = self.item_embedding(item_ids)
        position_embedding = self.position_embedding(x)
        x = self.embeddings_add((x, position_embedding))
        x = self.embeddings_dropout(x)

        for transformer_layer in self.transformer_layers:
            x = transformer_layer(x, decoder_padding_mask=padding_mask)

        item_sequence_embedding = self.layer_norm(x)
        result = {"item_sequence_embedding": item_sequence_embedding}

        # At inference, perform top-k retrieval.
        if not training:
            # Need to extract the last non-padding token.
            last_item_embedding = self._get_last_non_padding_token(
                item_sequence_embedding, padding_mask
            )
            result["predictions"] = self.retrieval(last_item_embedding)

        return result

    def compute_loss(self, x, y, y_pred, sample_weight, training=False):
        item_sequence_embedding = y_pred["item_sequence_embedding"]
        y_positive_sequence = y["positive_sequence"]
        y_negative_sequence = y["negative_sequence"]

        # Embed positive, negative sequences.
        positive_sequence_embedding = self.item_embedding(y_positive_sequence)
        negative_sequence_embedding = self.item_embedding(y_negative_sequence)

        # Logits
        positive_logits = ops.sum(
            ops.multiply(positive_sequence_embedding, item_sequence_embedding),
            axis=-1,
        )
        negative_logits = ops.sum(
            ops.multiply(negative_sequence_embedding, item_sequence_embedding),
            axis=-1,
        )
        logits = ops.concatenate([positive_logits, negative_logits], axis=1)

        # Labels
        labels = ops.concatenate(
            [
                ops.ones_like(positive_logits),
                ops.zeros_like(negative_logits),
            ],
            axis=1,
        )

        # Sample weights
        sample_weight = ops.concatenate(
            [sample_weight, sample_weight],
            axis=1,
        )

        loss = self.loss_fn(
            y_true=ops.expand_dims(labels, axis=-1),
            y_pred=ops.expand_dims(logits, axis=-1),
            sample_weight=sample_weight,
        )
        loss = ops.divide_no_nan(ops.sum(loss), ops.sum(sample_weight))

        return loss

    def compute_output_shape(self, inputs_shape):
        return list(inputs_shape) + [self.hidden_dim]


"""
572
Let's instantiate our model and do some sanity checks.
573
"""
574
575
model = SasRec(
    vocabulary_size=movies_count + 1,
    num_layers=NUM_LAYERS,
    num_heads=NUM_HEADS,
    hidden_dim=HIDDEN_DIM,
    dropout=DROPOUT,
    max_sequence_length=MAX_CONTEXT_LENGTH,
)

# Training
output = model(
    inputs={
        "item_ids": ops.ones((2, MAX_CONTEXT_LENGTH), dtype="int32"),
        "padding_mask": ops.ones((2, MAX_CONTEXT_LENGTH), dtype="bool"),
    },
    training=True,
)
print(output["item_sequence_embedding"].shape)

# Inference
output = model(
    inputs={
        "item_ids": ops.ones((2, MAX_CONTEXT_LENGTH), dtype="int32"),
        "padding_mask": ops.ones((2, MAX_CONTEXT_LENGTH), dtype="bool"),
    },
    training=False,
)
print(output["predictions"].shape)

"""
Now, let's compile and train our model.
"""

model.compile(
    optimizer=keras.optimizers.Adam(learning_rate=LEARNING_RATE, beta_2=0.98),
)
model.fit(
    x=train_ds,
    validation_data=val_ds,
    epochs=NUM_EPOCHS,
)

"""
618
## Making predictions
619
620
Now that we have a model, we would like to be able to make predictions.
621
622
So far, we have only handled movies by id. Now is the time to create a mapping
623
keyed by movie IDs to be able to surface the titles.
624
"""
625
626
movie_id_to_movie_title = dict(zip(movies_df["MovieID"], movies_df["Title"]))
627
movie_id_to_movie_title[0] = "" # Because id 0 is not in the dataset.
628
629
"""
630
We then simply use the Keras `model.predict()` method. Under the hood, it calls
631
the `BruteForceRetrieval` layer to perform the actual retrieval.
632
633
Note that this model can retrieve movies already watched by the user. We could
634
easily add logic to remove them if that is desirable.
635
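
Conceptually, the retrieval step is just a dot product of the query embedding
against every candidate item embedding, followed by a top-k. A simplified
sketch of the idea (not the layer's actual implementation):

```
scores = ops.matmul(query_embedding, ops.transpose(candidate_embeddings))
top_scores, top_ids = ops.top_k(scores, k=10)
```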
"""

for ele in val_ds.unbatch().take(1):
    test_sample = ele[0]
    test_sample["item_ids"] = tf.expand_dims(test_sample["item_ids"], axis=0)
    test_sample["padding_mask"] = tf.expand_dims(test_sample["padding_mask"], axis=0)

movie_sequence = np.array(test_sample["item_ids"])[0]
for movie_id in movie_sequence:
    if movie_id == 0:
        continue
    print(movie_id_to_movie_title[movie_id], end="; ")
print()

predictions = model.predict(test_sample)["predictions"]
predictions = keras.ops.convert_to_numpy(predictions)

for movie_id in predictions[0]:
    print(movie_id_to_movie_title[movie_id])

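"""
As noted above, the model may recommend movies the user has already watched. A
minimal way to filter those out of the printed recommendations (a sketch,
assuming enough candidates are retrieved for some to survive the filter):
"""

watched = set(movie_sequence[movie_sequence != 0])
for movie_id in predictions[0]:
    if movie_id not in watched:
        print(movie_id_to_movie_title[movie_id])
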
"""
And that's all!
"""