# GitHub Repository: keras-team/keras-io
# Path: blob/master/examples/keras_rs/dlrm.py
"""
Title: Ranking with Deep Learning Recommendation Model
Author: [Harshith Kulkarni](https://github.com/kharshith-k)
Date created: 2025/06/02
Last modified: 2025/09/04
Description: Rank movies with DLRM using KerasRS.
Accelerator: GPU
"""

"""
## Introduction

This tutorial demonstrates how to use the Deep Learning Recommendation Model (DLRM) to
effectively learn the relationships between items and user preferences using a
dot-product interaction mechanism. For more details, please refer to the
[DLRM](https://arxiv.org/abs/1906.00091) paper.

DLRM is designed to excel at capturing explicit, bounded-degree feature interactions and
is particularly effective at processing both categorical (sparse) and continuous (dense)
input features. The architecture consists of three main components: dedicated input
layers to handle diverse features (typically embedding layers for categorical features),
a dot-product interaction layer to explicitly model feature interactions, and a
Multi-Layer Perceptron (MLP) to capture implicit feature relationships.

The dot-product interaction layer lies at the heart of DLRM, efficiently computing
pairwise interactions between different feature embeddings. This contrasts with models
like Deep & Cross Network (DCN), which can treat elements within a feature vector as
independent units, potentially leading to a higher-dimensional space and increased
computational cost. The MLP is a standard feedforward network. The DLRM is formed by
combining the interaction layer and MLP.

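To make the interaction step concrete, here is a minimal NumPy sketch of a pairwise
dot-product interaction. It assumes every feature has already been mapped to a vector of
the same size `d`. It stacks the `F` vectors, computes all their dot products, and keeps
each unordered pair once, which gives `F * (F - 1) / 2` values per example. This
illustrates the idea only: the function name `pairwise_dot_interaction` is hypothetical,
and the exact output layout and options of `keras_rs.layers.DotInteraction` may differ.

```python
import numpy as np


def pairwise_dot_interaction(features):
    # `features` is a list of F arrays, each of shape (batch_size, d).
    stacked = np.stack(features, axis=1)  # (batch_size, F, d)
    # All pairwise dot products for each example: (batch_size, F, F).
    gram = np.einsum("bfd,bgd->bfg", stacked, stacked)
    # Keep each unordered pair once: strict upper triangle, no self-interactions.
    rows, cols = np.triu_indices(len(features), k=1)
    return gram[:, rows, cols]  # (batch_size, F * (F - 1) // 2)


rng = np.random.default_rng(0)
batch_size, dim = 4, 8
# Six vectors, e.g. five categorical embeddings plus one dense (bottom-MLP) vector.
feature_vectors = [rng.normal(size=(batch_size, dim)) for _ in range(6)]
interactions = pairwise_dot_interaction(feature_vectors)
print(interactions.shape)  # (4, 15)
```

The output size depends only on the number of features `F`, not on the vector size `d`.
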
The following image illustrates the DLRM architecture:

![DLRM Architecture](/img/examples/keras_rs/dlrm/dlrm_architecture.gif)


Now that we have a foundational understanding of DLRM's architecture and key
characteristics, let's dive into the code. We will train a DLRM on a real-world dataset
to demonstrate its capability to learn meaningful feature interactions. Let's begin by
setting the backend to JAX and organizing our imports.
"""

"""shell
!pip install -q keras-rs
"""

import os

os.environ["KERAS_BACKEND"] = "jax"  # `"tensorflow"`/`"torch"`

import keras
import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
import tensorflow_datasets as tfds
from mpl_toolkits.axes_grid1 import make_axes_locatable

import keras_rs

"""
Let's also define a configuration dictionary holding the feature names, model
hyperparameters, and training settings that are reused throughout the example.
"""

MOVIELENS_CONFIG = {
    # features
    "continuous_features": [
        "raw_user_age",
        "hour_of_day_sin",
        "hour_of_day_cos",
        "hour_of_week_sin",
        "hour_of_week_cos",
    ],
    "categorical_int_features": [
        "user_gender",
    ],
    "categorical_str_features": [
        "user_zip_code",
        "user_occupation_text",
        "movie_id",
        "user_id",
    ],
    # model
    "embedding_dim": 8,
    "mlp_dim": 8,
    "deep_net_num_units": [192, 192, 192],
    # training
    "learning_rate": 1e-4,
    "num_epochs": 30,
    "batch_size": 8192,
}

"""
Here, we define a helper function for plotting the metrics tracked during training, and
another for visualizing a feature-interaction matrix as a heatmap, in order to better
understand what the model has learned. We also define a function for compiling,
training, and evaluating a given model, and one for reporting its results.
"""


def plot_training_metrics(history):
    """Graphs all metrics tracked in the history object."""
    plt.figure(figsize=(12, 6))

    for metric_name, metric_values in history.history.items():
        plt.plot(metric_values, label=metric_name.replace("_", " ").title())

    plt.title("Metrics over Epochs")
    plt.xlabel("Epoch")
    plt.ylabel("Metric Value")
    plt.legend()
    plt.grid(True)


def visualize_layer(matrix, features, cmap=plt.cm.Blues):
    """Plots a square feature-by-feature matrix as a labeled heatmap."""
    im = plt.matshow(
        matrix, cmap=cmap, extent=[-0.5, len(features) - 0.5, len(features) - 0.5, -0.5]
    )

    ax = plt.gca()
    divider = make_axes_locatable(ax)
    cax = divider.append_axes("right", size="5%", pad=0.05)
    plt.colorbar(im, cax=cax)
    cax.tick_params(labelsize=10)

    # Set tick locations explicitly before setting labels
    ax.set_xticks(np.arange(len(features)))
    ax.set_yticks(np.arange(len(features)))

    ax.set_xticklabels(features, rotation=45, fontsize=5)
    ax.set_yticklabels(features, fontsize=5)

    plt.show()


def train_and_evaluate(
    learning_rate,
    epochs,
    train_data,
    test_data,
    model,
    plot_metrics=False,
):
    optimizer = keras.optimizers.AdamW(learning_rate=learning_rate, clipnorm=1.0)
    loss = keras.losses.MeanSquaredError()
    rmse = keras.metrics.RootMeanSquaredError()

    model.compile(
        optimizer=optimizer,
        loss=loss,
        metrics=[rmse],
    )

    history = model.fit(
        train_data,
        epochs=epochs,
        verbose=1,
    )
    if plot_metrics:
        plot_training_metrics(history)

    results = model.evaluate(test_data, return_dict=True, verbose=1)
    rmse_value = results["root_mean_squared_error"]

    return rmse_value, model.count_params()


def print_stats(rmse_list, num_params, model_name):
    # Report metrics.
    num_trials = len(rmse_list)
    avg_rmse = np.mean(rmse_list)
    std_rmse = np.std(rmse_list)

    if num_trials == 1:
        print(f"{model_name}: RMSE = {avg_rmse}; #params = {num_params}")
    else:
        print(f"{model_name}: RMSE = {avg_rmse} ± {std_rmse}; #params = {num_params}")


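"""
The RMSE reported by these helpers is the square root of the mean squared difference
between true and predicted ratings. As a quick check of that definition, the NumPy
sketch below computes it by hand on a few made-up ratings (illustrative values, not
MovieLens data). For a single batch of these inputs, `keras.metrics.RootMeanSquaredError`
should give the same value.

```python
import numpy as np

# Made-up true ratings and predictions.
y_true = np.array([4.0, 3.0, 5.0, 2.0])
y_pred = np.array([3.5, 3.0, 4.0, 2.0])

# RMSE = sqrt(mean((y_true - y_pred) ** 2)); here sqrt(1.25 / 4) = sqrt(0.3125).
rmse = np.sqrt(np.mean((y_true - y_pred) ** 2))
print(round(float(rmse), 4))  # 0.559
```
"""
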
"""
## Real-world example

Let's use the MovieLens 100K dataset. This dataset is used to train models to
predict users' movie ratings, based on user-related features and movie-related
features.

### Preparing the dataset

The dataset processing steps here are similar to what's given in the
[basic ranking](/keras_rs/examples/basic_ranking/)
tutorial. Let's load the dataset, keep only the useful columns, and derive cyclical
hour-of-day and hour-of-week features from the rating timestamp.
"""

ratings_ds = tfds.load("movielens/100k-ratings", split="train")


def preprocess_features(x):
    """Selects the features and label, and cyclically encodes the timestamp."""
    features = {
        "movie_id": x["movie_id"],
        "user_id": x["user_id"],
        "user_gender": tf.cast(x["user_gender"], dtype=tf.int32),
        "user_zip_code": x["user_zip_code"],
        "user_occupation_text": x["user_occupation_text"],
        "raw_user_age": tf.cast(x["raw_user_age"], dtype=tf.float32),
    }
    label = tf.cast(x["user_rating"], dtype=tf.float32)

    # The timestamp is in seconds since the epoch.
    timestamp = tf.cast(x["timestamp"], dtype=tf.float32)

    # Constants for time periods
    SECONDS_IN_HOUR = 3600.0
    HOURS_IN_DAY = 24.0
    HOURS_IN_WEEK = 168.0

    # Calculate hour of day and encode it
    hour_of_day = (timestamp / SECONDS_IN_HOUR) % HOURS_IN_DAY
    features["hour_of_day_sin"] = tf.sin(2 * np.pi * hour_of_day / HOURS_IN_DAY)
    features["hour_of_day_cos"] = tf.cos(2 * np.pi * hour_of_day / HOURS_IN_DAY)

    # Calculate hour of week and encode it
    hour_of_week = (timestamp / SECONDS_IN_HOUR) % HOURS_IN_WEEK
    features["hour_of_week_sin"] = tf.sin(2 * np.pi * hour_of_week / HOURS_IN_WEEK)
    features["hour_of_week_cos"] = tf.cos(2 * np.pi * hour_of_week / HOURS_IN_WEEK)

    return features, label


# Apply the preprocessing function.
ratings_ds = ratings_ds.map(preprocess_features)

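"""
The sine/cosine encoding in `preprocess_features` places each hour on a circle, so
hours that are close in time stay close numerically even across midnight. On a raw 0 to
23 scale, 23:00 and 00:00 would be 23 units apart. The NumPy sketch below reuses the same
hour-of-day formula and compares Euclidean distances between encoded hours. The helpers
`encode_hour` and `distance` are illustrative only.

```python
import numpy as np

HOURS_IN_DAY = 24.0


def encode_hour(hour):
    # Same sin/cos mapping as in `preprocess_features`, for a single hour value.
    angle = 2 * np.pi * hour / HOURS_IN_DAY
    return np.array([np.sin(angle), np.cos(angle)])


def distance(hour_a, hour_b):
    return float(np.linalg.norm(encode_hour(hour_a) - encode_hour(hour_b)))


# 23:00 -> 00:00 is as close as 11:00 -> 12:00; 00:00 -> 12:00 is the farthest pair.
print(round(distance(23, 0), 4), round(distance(11, 12), 4), round(distance(0, 12), 4))
# 0.2611 0.2611 2.0
```

The hour-of-week features follow the same idea with a period of 168 hours.
"""
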
"""
For every categorical feature, let's collect the list of unique values, i.e., its
vocabulary, so that we can build the lookup and embedding layers.
"""

vocabularies = {}
for feature_name in (
    MOVIELENS_CONFIG["categorical_int_features"]
    + MOVIELENS_CONFIG["categorical_str_features"]
):
    vocabulary = ratings_ds.batch(10_000).map(lambda x, y: x[feature_name])
    vocabularies[feature_name] = np.unique(np.concatenate(list(vocabulary)))

"""
Next, we use `keras.layers.StringLookup` and `keras.layers.IntegerLookup` to convert all
the categorical features into integer indices, which can then be fed into embedding
layers.
"""

lookup_layers = {}
lookup_layers.update(
    {
        feature: keras.layers.IntegerLookup(vocabulary=vocabularies[feature])
        for feature in MOVIELENS_CONFIG["categorical_int_features"]
    }
)
lookup_layers.update(
    {
        feature: keras.layers.StringLookup(vocabulary=vocabularies[feature])
        for feature in MOVIELENS_CONFIG["categorical_str_features"]
    }
)

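"""
With their default arguments (one OOV index and no mask token), these lookup layers
reserve index 0 for out-of-vocabulary (OOV) values. They map the vocabulary terms to
indices `1` through `len(vocabulary)`, which is why the embedding tables in the model
below have `len(vocabulary) + 1` rows. The pure-Python sketch below mimics that mapping
on a hypothetical occupation vocabulary. The `lookup` helper is illustrative, not a Keras
API.

```python
import numpy as np

# Hypothetical vocabulary: sorted unique values, as produced by `np.unique`.
vocabulary = np.array(["administrator", "artist", "doctor"])


def lookup(values, vocabulary):
    # Index 0 is the single OOV bucket; known terms map to 1..len(vocabulary).
    index = {term: i + 1 for i, term in enumerate(vocabulary.tolist())}
    return np.array([index.get(value, 0) for value in values])


indices = lookup(["artist", "doctor", "lawyer"], vocabulary)
print(indices)  # [2 3 0]  ("lawyer" is not in the vocabulary)

# An embedding table indexed by these values needs len(vocabulary) + 1 rows.
num_embedding_rows = len(vocabulary) + 1
assert indices.max() < num_embedding_rows
```
"""
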
"""
Let's normalize all the continuous features so that they are on a comparable scale
before being fed into the bottom MLP.
"""

normalization_layers = {}
for feature_name in MOVIELENS_CONFIG["continuous_features"]:
    normalization_layers[feature_name] = keras.layers.Normalization(axis=-1)

training_data_for_adaptation = ratings_ds.take(80_000).map(lambda x, y: x)

for feature_name in MOVIELENS_CONFIG["continuous_features"]:
    feature_ds = training_data_for_adaptation.map(
        lambda x: tf.expand_dims(x[feature_name], axis=-1)
    )
    normalization_layers[feature_name].adapt(feature_ds)

ratings_ds = ratings_ds.map(
    lambda x, y: (
        {
            **{
                feature_name: lookup_layers[feature_name](x[feature_name])
                for feature_name in vocabularies
            },
            # Apply the adapted normalization layers to the continuous features.
            **{
                feature_name: tf.squeeze(
                    normalization_layers[feature_name](
                        tf.expand_dims(x[feature_name], axis=-1)
                    ),
                    axis=-1,
                )
                for feature_name in MOVIELENS_CONFIG["continuous_features"]
            },
        },
        y,
    )
)

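"""
Conceptually, `adapt()` computes the mean and variance of each continuous feature over
the adaptation data. Calling the layer then standardizes inputs with those stored
statistics. The NumPy sketch below reproduces that idea on a few made-up ages
(illustrative values only). The small `epsilon` floor on the standard deviation is an
assumption meant to mirror the guard against division by zero inside Keras.

```python
import numpy as np

# Made-up raw ages standing in for the adaptation data.
raw_user_age = np.array([18.0, 25.0, 35.0, 50.0, 22.0], dtype=np.float32)

# "adapt": compute the statistics once.
mean = raw_user_age.mean()
variance = raw_user_age.var()

# "call": standardize with the stored statistics.
epsilon = 1e-7
normalized = (raw_user_age - mean) / np.maximum(np.sqrt(variance), epsilon)

# After standardization the values have (approximately) zero mean and unit variance.
print(normalized.mean(), normalized.std())
```
"""
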
"""
Let's split our data into train and test sets. We shuffle once with a fixed seed and
`reshuffle_each_iteration=False`. This keeps the shuffled order stable, so `take()` and
`skip()` select the same 80,000 training and 20,000 test examples every time the dataset
is iterated. We also use `cache()` and `prefetch()` for better performance.
"""

ratings_ds = ratings_ds.shuffle(100_000, seed=42, reshuffle_each_iteration=False)

train_ds = (
    ratings_ds.take(80_000)
    .batch(MOVIELENS_CONFIG["batch_size"])
    .cache()
    .prefetch(tf.data.AUTOTUNE)
)
test_ds = (
    ratings_ds.skip(80_000)
    .take(20_000)
    .batch(MOVIELENS_CONFIG["batch_size"])
    .cache()
    .prefetch(tf.data.AUTOTUNE)
)

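"""
To see why a stable shuffle order matters, the NumPy simulation below uses 100
hypothetical example IDs in place of the real dataset. Each time a split is iterated,
`take()` and `skip()` act on the current shuffled order. If that order changes between
the pass that fills the training cache and the pass that builds the test set, the two
splits can share examples. The `split` helper is illustrative.

```python
import numpy as np

num_examples, num_train = 100, 80


def split(order):
    # Equivalent of `take(num_train)` / `skip(num_train)` on one shuffled order.
    return set(order[:num_train].tolist()), set(order[num_train:].tolist())


# Shuffle once and reuse that order for both splits (no reshuffling).
fixed_order = np.random.default_rng(42).permutation(num_examples)
fixed_train_ids, _ = split(fixed_order)
_, fixed_test_ids = split(fixed_order)
print(len(fixed_train_ids & fixed_test_ids))  # 0

# Draw a fresh order for each pass, as a reshuffling dataset would.
rng = np.random.default_rng(42)
reshuffled_train_ids, _ = split(rng.permutation(num_examples))
_, reshuffled_test_ids = split(rng.permutation(num_examples))
# Nonzero: some "test" examples were also seen during training.
print(len(reshuffled_train_ids & reshuffled_test_ids))
```
"""
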
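"""
### Building the model

The model has an embedding layer for each categorical feature and a bottom MLP. The
bottom MLP maps the concatenated continuous features to a single vector with the same
size as the embeddings. A `DotInteraction` layer computes the pairwise dot products
between all of these vectors, and the top feedforward layers map the result to a
predicted rating.
"""

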
class DLRM(keras.Model):
    def __init__(
        self,
        dense_num_units_lst,
        embedding_dim=MOVIELENS_CONFIG["embedding_dim"],
        mlp_dim=MOVIELENS_CONFIG["mlp_dim"],
        **kwargs,
    ):
        super().__init__(**kwargs)

        self.embedding_layers = {}
        for feature_name in (
            MOVIELENS_CONFIG["categorical_int_features"]
            + MOVIELENS_CONFIG["categorical_str_features"]
        ):
            vocab_size = len(vocabularies[feature_name]) + 1  # +1 for OOV token
            self.embedding_layers[feature_name] = keras.layers.Embedding(
                input_dim=vocab_size,
                output_dim=embedding_dim,
            )

        self.bottom_mlp = keras.Sequential(
            [
                keras.layers.Dense(mlp_dim, activation="relu"),
                keras.layers.Dense(embedding_dim),  # Output must match embedding_dim
            ]
        )

        self.dot_layer = keras_rs.layers.DotInteraction()

        self.top_mlp = []
        for num_units in dense_num_units_lst:
            self.top_mlp.append(keras.layers.Dense(num_units, activation="relu"))

        self.output_layer = keras.layers.Dense(1)

        self.dense_num_units_lst = dense_num_units_lst
        self.embedding_dim = embedding_dim

    def call(self, inputs):
        embeddings = []
        for feature_name in (
            MOVIELENS_CONFIG["categorical_int_features"]
            + MOVIELENS_CONFIG["categorical_str_features"]
        ):
            embedding = self.embedding_layers[feature_name](inputs[feature_name])
            embeddings.append(embedding)

        # Process all continuous features together.
        continuous_inputs = []
        for feature_name in MOVIELENS_CONFIG["continuous_features"]:
            # Reshape each feature to (batch_size, 1)
            feature = keras.ops.reshape(
                keras.ops.cast(inputs[feature_name], dtype="float32"), (-1, 1)
            )
            continuous_inputs.append(feature)

        # Concatenate into a single tensor: (batch_size, num_continuous_features)
        concatenated_continuous = keras.ops.concatenate(continuous_inputs, axis=1)

        # Pass through the Bottom MLP to get one combined vector.
        processed_continuous = self.bottom_mlp(concatenated_continuous)

        # Combine with categorical embeddings. Note: we add a list containing the
        # single tensor.
        combined_features = embeddings + [processed_continuous]

        # Pass the list of features to the DotInteraction layer.
        x = self.dot_layer(combined_features)

        for layer in self.top_mlp:
            x = layer(x)

        x = self.output_layer(x)

        return x


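"""
To make the shapes concrete: the `DotInteraction` layer in this model receives six
vectors of size 8, namely five categorical embeddings plus the bottom-MLP output. The
sketch below uses the configuration values to count the resulting interaction features
and the dense (non-embedding) parameters. It assumes that `DotInteraction` with its
default arguments has no trainable weights. It also assumes the layer outputs only the
6 * 5 / 2 = 15 pairwise dot products, with no self-interactions and no concatenated dense
vector. The helper `dense_params` is illustrative. Embedding parameters depend on the data
and are left out; each categorical feature adds `embedding_dim * (len(vocabulary) + 1)`.

```python
def dense_params(input_dim, units):
    # A Dense layer has an (input_dim, units) kernel plus a bias of size units.
    return input_dim * units + units


num_categorical = 5  # user_gender, user_zip_code, user_occupation_text, movie_id, user_id
num_continuous = 5
embedding_dim, mlp_dim = 8, 8
top_units = [192, 192, 192]

# Bottom MLP: 5 continuous inputs -> 8 (ReLU) -> 8 (matches embedding_dim).
bottom = dense_params(num_continuous, mlp_dim) + dense_params(mlp_dim, embedding_dim)

# DotInteraction (assumed): one dot product per unordered pair of the 6 vectors.
num_vectors = num_categorical + 1
num_interactions = num_vectors * (num_vectors - 1) // 2

# Top MLP: 15 -> 192 -> 192 -> 192 -> 1.
top, fan_in = 0, num_interactions
for units in top_units + [1]:
    top += dense_params(fan_in, units)
    fan_in = units

print(num_interactions, bottom, top)  # 15 120 77377
```

Under these assumptions, the `#params` value printed after training should equal the
embedding parameters plus 120 + 77,377 = 77,497.
"""
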
dot_network = DLRM(
    dense_num_units_lst=MOVIELENS_CONFIG["deep_net_num_units"],
    embedding_dim=MOVIELENS_CONFIG["embedding_dim"],
    mlp_dim=MOVIELENS_CONFIG["mlp_dim"],
)

rmse, dot_network_num_params = train_and_evaluate(
    learning_rate=MOVIELENS_CONFIG["learning_rate"],
    epochs=MOVIELENS_CONFIG["num_epochs"],
    train_data=train_ds,
    test_data=test_ds,
    model=dot_network,
    plot_metrics=True,
)
print_stats(
    rmse_list=[rmse],
    num_params=dot_network_num_params,
    model_name="Dot Network",
)

"""
### Visualizing feature interactions

The DotInteraction layer itself doesn't have a conventional "weight" matrix like a Dense
layer. Instead, its function is to compute the dot product between the embedding vectors
of the features.

To visualize the strength of these interactions, we can calculate a matrix representing
the pairwise interaction strength between all feature embeddings. A common way to do this
is to take the dot product of the embedding matrices for each pair of features and then
aggregate the result into a single value (such as the mean of the absolute values) that
represents the overall interaction strength. All continuous features go through the
bottom MLP together, so they are represented by a single vector, labeled
`all_continuous_features`.
"""


def get_dot_interaction_matrix(model, categorical_features, continuous_features):
    # Plot labels: one per categorical feature, plus one for the bottom-MLP output.
    all_feature_names = categorical_features + ["all_continuous_features"]
    num_features = len(all_feature_names)

    # Store the vectors for each feature in plot order, each as a 2D NumPy array.
    all_feature_outputs = []

    # For categorical features, use the full learned embedding table. Row 0 is the
    # OOV index, which never occurs here because the vocabularies were built from the
    # whole dataset, so it is skipped.
    for feature_name in categorical_features:
        table = model.embedding_layers[feature_name].embeddings
        all_feature_outputs.append(keras.ops.convert_to_numpy(table)[1:])

    # Get a single output for ALL continuous features from the shared bottom MLP.
    # After normalization, an all-zeros input corresponds to average feature values.
    num_continuous_features = len(continuous_features)
    dummy_continuous_input = keras.ops.zeros((1, num_continuous_features))
    processed_continuous = model.bottom_mlp(dummy_continuous_input)
    all_feature_outputs.append(keras.ops.convert_to_numpy(processed_continuous))

    interaction_matrix = np.zeros((num_features, num_features))

    # Interaction strength: mean absolute dot product over all pairs of vectors.
    for i in range(num_features):
        for j in range(num_features):
            interaction = all_feature_outputs[i] @ all_feature_outputs[j].T
            interaction_matrix[i, j] = np.mean(np.abs(interaction))

    return interaction_matrix, all_feature_names


# Get the list of categorical feature names.
categorical_feature_names = (
    MOVIELENS_CONFIG["categorical_int_features"]
    + MOVIELENS_CONFIG["categorical_str_features"]
)

# Calculate the interaction matrix.
interaction_matrix, feature_names = get_dot_interaction_matrix(
    model=dot_network,
    categorical_features=categorical_feature_names,
    continuous_features=MOVIELENS_CONFIG["continuous_features"],
)

# Visualize the matrix as a heatmap.
print("\nVisualizing the feature interaction strengths:")
visualize_layer(interaction_matrix, feature_names)