"""
Title: Ranking with Deep and Cross Networks
Author: [Abheesht Sharma](https://github.com/abheesht17/), [Fabien Hertschuh](https://github.com/hertschuh/)
Date created: 2025/04/28
Last modified: 2025/04/28
Description: Rank movies using Deep and Cross Networks (DCN).
Accelerator: GPU
"""

"""
## Introduction

This tutorial demonstrates how to use Deep & Cross Networks (DCN) to effectively
learn feature crosses. Before diving into the example, let's briefly discuss
feature crosses.

Imagine that we are building a recommender system for blenders. Individual
features might include a customer's past purchase history (e.g.,
`purchased_bananas`, `purchased_cooking_books`) or geographic location. However,
a customer who has purchased both bananas and cooking books is more likely to be
interested in a blender than someone who purchased only one or the other. The
combination of `purchased_bananas` and `purchased_cooking_books` is a feature
cross. Feature crosses capture interaction information between individual
features, providing richer context than the individual features alone.

![Why are feature crosses important?](https://i.imgur.com/qDK6UZh.gif)

Learning effective feature crosses presents several challenges. In web-scale
applications, data is often categorical, resulting in high-dimensional and
sparse feature spaces. Identifying impactful feature crosses in such
environments typically relies on manual feature engineering or computationally
expensive exhaustive searches. While traditional feed-forward multilayer
perceptrons (MLPs) are universal function approximators, they often struggle to
efficiently learn even second- or third-order feature interactions.

The Deep & Cross Network (DCN) architecture is designed for more effective
learning of explicit and bounded-degree feature crosses. It comprises three main
components: an input layer (typically an embedding layer), a cross network for
modeling explicit feature interactions, and a deep network for capturing
implicit interactions.

The cross network is the core of the DCN. It explicitly performs feature
crossing at each layer, with the highest polynomial degree of feature
interaction increasing with depth. The following figure shows the `(i+1)`-th
cross layer.

![Feature Cross Layer](https://i.imgur.com/ip5uRsl.png)
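
Concretely, each cross layer takes the original input `x0` and the output of
the previous layer `xi`, and computes `x_{i+1} = x0 * (W @ xi + b) + xi`, where
`*` is element-wise multiplication. Below is a minimal NumPy sketch of that
computation, purely for illustration; the layer we actually use later is
`keras_rs.layers.FeatureCross`, whose internal parametrization may differ in
detail.

```python
import numpy as np


def cross_layer(x0, xi, W, b):
    # One cross layer: element-wise product of the original input with a linear
    # transform of the previous layer's output, plus a residual term.
    return x0 * (W @ xi + b) + xi


x0 = np.array([0.1, 0.2, 0.3])  # original (embedded) input features
W = np.random.rand(3, 3)  # learned weight matrix
b = np.zeros(3)  # learned bias
x1 = cross_layer(x0, x0, W, b)  # output of the first cross layer
```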

The deep network is a standard feedforward multilayer perceptron
(MLP). These two networks are then combined to form the DCN. Two common
combination strategies exist: a stacked structure, where the deep network is
placed on top of the cross network, and a parallel structure, where they
operate in parallel.

<table>
<tr>
<td>
<figure>
<img src="https://i.imgur.com/rNn0zxS.png" alt="Parallel layers" width="1000" height="500">
<figcaption>Parallel layers</figcaption>
</figure>
</td>
<td>
<figure>
<img src="https://i.imgur.com/g32nzCl.png" alt="Stacked layers" width="1000" height="500">
<figcaption>Stacked layers</figcaption>
</figure>
</td>
</tr>
</table>
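
Schematically, the two combination strategies look like this in the Keras
functional API. This is a rough sketch only: the input width of 48 and the
layer sizes are placeholders, not the values used later in this example.

```python
import keras
import keras_rs

inputs = keras.Input(shape=(48,))  # e.g. concatenated feature embeddings

# Stacked: the deep network sits on top of the cross network.
x = keras_rs.layers.FeatureCross()(inputs)
x = keras.layers.Dense(64, activation="relu")(x)
stacked_model = keras.Model(inputs, keras.layers.Dense(1)(x))

# Parallel: cross and deep branches process the input side by side, and their
# outputs are concatenated before the final prediction layer.
cross_out = keras_rs.layers.FeatureCross()(inputs)
deep_out = keras.layers.Dense(64, activation="relu")(inputs)
merged = keras.layers.Concatenate()([cross_out, deep_out])
parallel_model = keras.Model(inputs, keras.layers.Dense(1)(merged))
```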

Now that we know a little bit about DCN, let's start writing some code. We will
first train a DCN on a toy dataset, and demonstrate that the model has indeed
learned important feature crosses.

Let's set the backend to JAX, and get our imports sorted.
"""

"""shell
pip install -q keras-rs
"""

import os

os.environ["KERAS_BACKEND"] = "jax"  # `"tensorflow"`/`"torch"`

import keras
import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
import tensorflow_datasets as tfds
from mpl_toolkits.axes_grid1 import make_axes_locatable

import keras_rs

"""
Let's also define variables which will be reused throughout the example.
"""

TOY_CONFIG = {
    "learning_rate": 0.01,
    "num_epochs": 100,
    "batch_size": 1024,
}

MOVIELENS_CONFIG = {
    # features
    "int_features": [
        "movie_id",
        "user_id",
        "user_gender",
        "bucketized_user_age",
    ],
    "str_features": [
        "user_zip_code",
        "user_occupation_text",
    ],
    # model
    "embedding_dim": 8,
    "deep_net_num_units": [192, 192, 192],
    "projection_dim": 8,
    "dcn_num_units": [192, 192],
    # training
    "learning_rate": 1e-2,
    "num_epochs": 8,
    "batch_size": 8192,
}


"""
Here, we define a helper function for visualizing the weights of the cross
layer, which will help us understand what the model has learned. We also define
functions for compiling, training and evaluating a given model, and for
reporting results.
"""


def visualize_layer(matrix, features):
    # Plot the absolute values of the weight matrix as a heatmap, with one
    # row/column per feature and a colorbar on the right.
    plt.figure(figsize=(9, 9))

    im = plt.matshow(np.abs(matrix), cmap=plt.cm.Blues)

    ax = plt.gca()
    divider = make_axes_locatable(plt.gca())
    cax = divider.append_axes("right", size="5%", pad=0.05)
    plt.colorbar(im, cax=cax)
    cax.tick_params(labelsize=10)
    ax.set_xticklabels([""] + features, rotation=45, fontsize=5)
    ax.set_yticklabels([""] + features, fontsize=5)


def train_and_evaluate(
    learning_rate,
    epochs,
    train_data,
    test_data,
    model,
):
    # Compile the model with MSE loss, train it, and return the test RMSE along
    # with the parameter count.
    optimizer = keras.optimizers.AdamW(learning_rate=learning_rate)
    loss = keras.losses.MeanSquaredError()
    rmse = keras.metrics.RootMeanSquaredError()

    model.compile(
        optimizer=optimizer,
        loss=loss,
        metrics=[rmse],
    )

    model.fit(
        train_data,
        epochs=epochs,
        verbose=0,
    )

    results = model.evaluate(test_data, return_dict=True, verbose=0)
    rmse_value = results["root_mean_squared_error"]

    return rmse_value, model.count_params()


def print_stats(rmse_list, num_params, model_name):
    # Report metrics.
    num_trials = len(rmse_list)
    avg_rmse = np.mean(rmse_list)
    std_rmse = np.std(rmse_list)

    if num_trials == 1:
        print(f"{model_name}: RMSE = {avg_rmse}; #params = {num_params}")
    else:
        print(f"{model_name}: RMSE = {avg_rmse} ± {std_rmse}; #params = {num_params}")

"""
## Toy Example

To illustrate the benefits of DCNs, let's consider a simple example. Suppose we
have a dataset for modeling the likelihood of a customer clicking on a blender
advertisement. The features and label are defined as follows:

| **Features / Label** | **Description**                | **Range**|
|:--------------------:|:------------------------------:|:--------:|
| `x1` = country       | Customer's resident country    | [0, 199] |
| `x2` = bananas       | # bananas purchased            | [0, 23]  |
| `x3` = cookbooks     | # cooking books purchased      | [0, 5]   |
| `y`                  | Blender ad click likelihood    | -        |

We then let the data follow this underlying distribution:
`y = f(x1, x2, x3) = 0.1x1 + 0.4x2 + 0.7x3 + 0.1x1x2 + 3.1x2x3 + 0.1x3^2`.

This distribution shows that the click likelihood (`y`) depends linearly on
individual features (`xi`) and on multiplicative interactions between them. In
this scenario, the likelihood of purchasing a blender (`y`) is influenced not
only by purchasing bananas (`x2`) or cookbooks (`x3`) individually, but also
significantly by the interaction of purchasing both bananas and cookbooks
(`x2x3`).

### Preparing the dataset

Let's create synthetic data based on the above equation, and form the train-test
splits.
"""


def get_mixer_data(data_size=100_000):
    country = np.random.randint(200, size=[data_size, 1]) / 200.0
    bananas = np.random.randint(24, size=[data_size, 1]) / 24.0
    cookbooks = np.random.randint(6, size=[data_size, 1]) / 6.0

    x = np.concatenate([country, bananas, cookbooks], axis=1)

    # Create 1st-order terms.
    y = 0.1 * country + 0.4 * bananas + 0.7 * cookbooks

    # Create 2nd-order cross terms.
    y += (
        0.1 * country * bananas
        + 3.1 * bananas * cookbooks
        + (0.1 * cookbooks * cookbooks)
    )

    return x, y


x, y = get_mixer_data(data_size=100_000)
num_train = 90_000
train_x = x[:num_train]
train_y = y[:num_train]
test_x = x[num_train:]
test_y = y[num_train:]

"""
### Building the model

To demonstrate the advantages of a cross network in recommender systems, we'll
compare its performance with a deep network. Since our example data only
contains second-order feature interactions, a single-layered cross network will
suffice. For datasets with higher-order interactions, multiple cross layers can
be stacked to form a multi-layered cross network. We will build two models:

1. A cross network with a single cross layer.
2. A deep network with wider and deeper feedforward layers.
"""

cross_network = keras.Sequential(
    [
        keras_rs.layers.FeatureCross(),
        keras.layers.Dense(1),
    ]
)

deep_network = keras.Sequential(
    [
        keras.layers.Dense(512, activation="relu"),
        keras.layers.Dense(256, activation="relu"),
        keras.layers.Dense(128, activation="relu"),
        keras.layers.Dense(1),
    ]
)

"""
### Model training

Before we train the model, we need to batch our datasets.
"""

train_ds = tf.data.Dataset.from_tensor_slices((train_x, train_y)).batch(
    TOY_CONFIG["batch_size"]
)
test_ds = tf.data.Dataset.from_tensor_slices((test_x, test_y)).batch(
    TOY_CONFIG["batch_size"]
)

"""
Let's train both models. Remember, we have set `verbose=0` for brevity's
sake, so do not be alarmed if you do not see any output for a while.

After training, we evaluate the models on the unseen dataset. We will report
the Root Mean Squared Error (RMSE) here.

We observe that the cross network achieves a significantly lower RMSE than the
ReLU-based DNN, while also using fewer parameters. This points to the
efficiency of the cross network in learning feature interactions.
"""

cross_network_rmse, cross_network_num_params = train_and_evaluate(
    learning_rate=TOY_CONFIG["learning_rate"],
    epochs=TOY_CONFIG["num_epochs"],
    train_data=train_ds,
    test_data=test_ds,
    model=cross_network,
)
print_stats(
    rmse_list=[cross_network_rmse],
    num_params=cross_network_num_params,
    model_name="Cross Network",
)

deep_network_rmse, deep_network_num_params = train_and_evaluate(
    learning_rate=TOY_CONFIG["learning_rate"],
    epochs=TOY_CONFIG["num_epochs"],
    train_data=train_ds,
    test_data=test_ds,
    model=deep_network,
)
print_stats(
    rmse_list=[deep_network_rmse],
    num_params=deep_network_num_params,
    model_name="Deep Network",
)

"""
### Visualizing feature interactions

Since we already know which feature crosses are important in our data, it would
be interesting to verify whether our model has indeed learned these key feature
interactions. This can be done by visualizing the learned weight matrix in the
cross network, where the weight `Wij` represents the learned importance of
the interaction between features `xi` and `xj`.
"""

visualize_layer(
    matrix=cross_network.weights[0].numpy(),
    features=["country", "purchased_bananas", "purchased_cookbooks"],
)

"""
## Real-world example

Let's use the MovieLens 100K dataset. This dataset is used to train models to
predict users' movie ratings, based on user-related features and movie-related
features.

### Preparing the dataset

The dataset processing steps here are similar to what's given in the
[basic ranking](/keras_rs/examples/basic_ranking/)
tutorial. Let's load the dataset, and keep only the useful columns.
"""

ratings_ds = tfds.load("movielens/100k-ratings", split="train")
ratings_ds = ratings_ds.map(
    lambda x: (
        {
            "movie_id": int(x["movie_id"]),
            "user_id": int(x["user_id"]),
            "user_gender": int(x["user_gender"]),
            "user_zip_code": x["user_zip_code"],
            "user_occupation_text": x["user_occupation_text"],
            "bucketized_user_age": int(x["bucketized_user_age"]),
        },
        x["user_rating"],  # label
    )
)
375
"""
376
For every feature, let's get the list of unique values, i.e., vocabulary, so
377
that we can use that for the embedding layer.
378
"""

vocabularies = {}
for feature_name in MOVIELENS_CONFIG["int_features"] + MOVIELENS_CONFIG["str_features"]:
    vocabulary = ratings_ds.batch(10_000).map(lambda x, y: x[feature_name])
    vocabularies[feature_name] = np.unique(np.concatenate(list(vocabulary)))

"""
We use `keras.layers.StringLookup` and `keras.layers.IntegerLookup` to convert
all features into indices, which can then be fed into embedding layers.
"""

lookup_layers = {}
lookup_layers.update(
    {
        feature: keras.layers.IntegerLookup(vocabulary=vocabularies[feature])
        for feature in MOVIELENS_CONFIG["int_features"]
    }
)
lookup_layers.update(
    {
        feature: keras.layers.StringLookup(vocabulary=vocabularies[feature])
        for feature in MOVIELENS_CONFIG["str_features"]
    }
)

ratings_ds = ratings_ds.map(
    lambda x, y: (
        {
            feature_name: lookup_layers[feature_name](x[feature_name])
            for feature_name in vocabularies
        },
        y,
    )
)

"""
Let's split our data into train and test sets. We also use `cache()` and
`prefetch()` for better performance.
"""

ratings_ds = ratings_ds.shuffle(100_000)

train_ds = (
    ratings_ds.take(80_000)
    .batch(MOVIELENS_CONFIG["batch_size"])
    .cache()
    .prefetch(tf.data.AUTOTUNE)
)
test_ds = (
    ratings_ds.skip(80_000)
    .take(20_000)
    .batch(MOVIELENS_CONFIG["batch_size"])
    .cache()
    .prefetch(tf.data.AUTOTUNE)
)

"""
### Building the model

The model will have embedding layers, followed by cross and/or feedforward
layers.
"""


class DCN(keras.Model):
    def __init__(
        self,
        dense_num_units_lst,
        embedding_dim=MOVIELENS_CONFIG["embedding_dim"],
        use_cross_layer=False,
        projection_dim=None,
        **kwargs,
    ):
        super().__init__(**kwargs)

        # Layers.

        self.embedding_layers = []
        for feature_name, vocabulary in vocabularies.items():
            self.embedding_layers.append(
                keras.layers.Embedding(
                    input_dim=len(vocabulary) + 1,
                    output_dim=embedding_dim,
                )
            )

        if use_cross_layer:
            self.cross_layer = keras_rs.layers.FeatureCross(
                projection_dim=projection_dim
            )

        self.dense_layers = []
        for num_units in dense_num_units_lst:
            self.dense_layers.append(keras.layers.Dense(num_units, activation="relu"))

        self.output_layer = keras.layers.Dense(1)

        # Attributes.
        self.dense_num_units_lst = dense_num_units_lst
        self.embedding_dim = embedding_dim
        self.use_cross_layer = use_cross_layer
        self.projection_dim = projection_dim

    def call(self, inputs):
        embeddings = []
        for feature_name, embedding_layer in zip(vocabularies, self.embedding_layers):
            embeddings.append(embedding_layer(inputs[feature_name]))

        x = keras.ops.concatenate(embeddings, axis=1)

        if self.use_cross_layer:
            x = self.cross_layer(x)

        for dense_layer in self.dense_layers:
            x = dense_layer(x)

        x = self.output_layer(x)

        return x

"""
We have three models: a deep cross network, an optimised deep cross
network with a low-rank matrix (to reduce training and serving costs), and a
normal deep network without cross layers. The deep cross network is a stacked
DCN model, i.e., the inputs are fed to cross layers, followed by feedforward
layers. Let's run each model 20 times, and report the average/standard
deviation of the RMSE.
"""

cross_network_rmse_list = []
opt_cross_network_rmse_list = []
deep_network_rmse_list = []

for _ in range(20):
    cross_network = DCN(
        dense_num_units_lst=MOVIELENS_CONFIG["dcn_num_units"],
        embedding_dim=MOVIELENS_CONFIG["embedding_dim"],
        use_cross_layer=True,
    )
    rmse, cross_network_num_params = train_and_evaluate(
        learning_rate=MOVIELENS_CONFIG["learning_rate"],
        epochs=MOVIELENS_CONFIG["num_epochs"],
        train_data=train_ds,
        test_data=test_ds,
        model=cross_network,
    )
    cross_network_rmse_list.append(rmse)

    opt_cross_network = DCN(
        dense_num_units_lst=MOVIELENS_CONFIG["dcn_num_units"],
        embedding_dim=MOVIELENS_CONFIG["embedding_dim"],
        use_cross_layer=True,
        projection_dim=MOVIELENS_CONFIG["projection_dim"],
    )
    rmse, opt_cross_network_num_params = train_and_evaluate(
        learning_rate=MOVIELENS_CONFIG["learning_rate"],
        epochs=MOVIELENS_CONFIG["num_epochs"],
        train_data=train_ds,
        test_data=test_ds,
        model=opt_cross_network,
    )
    opt_cross_network_rmse_list.append(rmse)

    deep_network = DCN(dense_num_units_lst=MOVIELENS_CONFIG["deep_net_num_units"])
    rmse, deep_network_num_params = train_and_evaluate(
        learning_rate=MOVIELENS_CONFIG["learning_rate"],
        epochs=MOVIELENS_CONFIG["num_epochs"],
        train_data=train_ds,
        test_data=test_ds,
        model=deep_network,
    )
    deep_network_rmse_list.append(rmse)

print_stats(
    rmse_list=cross_network_rmse_list,
    num_params=cross_network_num_params,
    model_name="Cross Network",
)
print_stats(
    rmse_list=opt_cross_network_rmse_list,
    num_params=opt_cross_network_num_params,
    model_name="Optimised Cross Network",
)
print_stats(
    rmse_list=deep_network_rmse_list,
    num_params=deep_network_num_params,
    model_name="Deep Network",
)

"""
The DCN slightly outperforms a larger DNN with ReLU layers. Furthermore, the
low-rank DCN effectively reduces the number of parameters without compromising
accuracy.
"""

"""
### Visualizing feature interactions

Like we did for the toy example, we will plot the weight matrix of the cross
layer to see which feature crosses are important. In the previous example,
the importance of the interaction between the `i`-th and `j`-th features is
captured by the `(i, j)`-th element of the weight matrix.

In this case, the feature embeddings are of size 8 rather than 1. Therefore,
the importance of feature interactions is represented by the `(i, j)`-th
block of the weight matrix, which has dimensions `8 x 8`. To quantify the
significance of these interactions, we use the Frobenius norm of each block. A
larger value implies higher importance.
"""

features = list(vocabularies.keys())
mat = cross_network.weights[len(features)].numpy()
embedding_dim = MOVIELENS_CONFIG["embedding_dim"]

block_norm = np.zeros([len(features), len(features)])

# Compute the norms of the blocks.
for i in range(len(features)):
    for j in range(len(features)):
        block = mat[
            i * embedding_dim : (i + 1) * embedding_dim,
            j * embedding_dim : (j + 1) * embedding_dim,
        ]
        block_norm[i, j] = np.linalg.norm(block, ord="fro")

visualize_layer(
    matrix=block_norm,
    features=features,
)

"""
And we are all done!
"""