# GitHub Repository: keras-team/keras-io
# Path: blob/master/examples/keras_rs/distributed_embedding_jax.py

"""
Title: DistributedEmbedding using TPU SparseCore and JAX
Author: [Fabien Hertschuh](https://github.com/hertschuh/), [Abheesht Sharma](https://github.com/abheesht17/), [C. Antonio Sánchez](https://github.com/cantonios/)
Date created: 2025/06/03
Last modified: 2025/09/02
Description: Rank movies using a two tower model with embeddings on SparseCore.
Accelerator: TPU
"""

"""
## Introduction

In the [basic ranking](/keras_rs/examples/basic_ranking/) tutorial, we showed
how to build a ranking model for the MovieLens dataset to suggest movies to
users.

This tutorial implements the same model trained on the same dataset but with the
use of `keras_rs.layers.DistributedEmbedding`, which makes use of SparseCore on
TPU. This is the JAX version of the tutorial. It needs to be run on TPU v5p or
v6e.

Let's begin by choosing JAX as the backend and importing all the necessary
libraries.
"""

"""shell
pip install -q -U "jax[tpu]>=0.7.0"
pip install -q jax-tpu-embedding
pip install -q tensorflow-cpu
pip install -q keras-rs
"""

import os

os.environ["KERAS_BACKEND"] = "jax"

import jax
import keras
import keras_rs
import tensorflow as tf  # Needed for the dataset
import tensorflow_datasets as tfds

"""
## Dataset distribution

While the model is replicated and the embedding tables are sharded across
SparseCores, the dataset is distributed by sharding each batch across the TPUs.
We need to make sure the batch size is a multiple of the number of TPUs.
"""

PER_REPLICA_BATCH_SIZE = 256
BATCH_SIZE = PER_REPLICA_BATCH_SIZE * jax.local_device_count("tpu")

distribution = keras.distribution.DataParallel(devices=jax.devices("tpu"))
keras.distribution.set_distribution(distribution)
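
"""
Because `BATCH_SIZE` is built from `PER_REPLICA_BATCH_SIZE`, it is a multiple of
the local TPU count by construction. As a quick, purely illustrative sanity
check, we can assert this explicitly:
"""

# Illustrative check: the global batch must split evenly across the local TPUs.
num_tpus = jax.local_device_count("tpu")
assert BATCH_SIZE % num_tpus == 0, (
    f"BATCH_SIZE={BATCH_SIZE} is not a multiple of {num_tpus} local TPU devices."
)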

"""
## Preparing the dataset

We're going to use the same MovieLens data. The ratings are the objectives we
are trying to predict.
"""

# Ratings data.
ratings = tfds.load("movielens/100k-ratings", split="train")
# Features of all the available movies.
movies = tfds.load("movielens/100k-movies", split="train")

"""
We need to know the number of users as we're using the user ID directly as an
index in the user embedding table.
"""

users_count = int(
    ratings.map(lambda x: tf.strings.to_number(x["user_id"], out_type=tf.int32))
    .reduce(tf.constant(0, tf.int32), tf.maximum)
    .numpy()
)

"""
We also need to know the number of movies as we're using the movie ID directly
as an index in the movie embedding table.
"""

movies_count = int(movies.cardinality().numpy())

"""
The inputs to the model are the user IDs and movie IDs, and the labels are the
ratings.
"""


def preprocess_rating(x):
    return (
        # Inputs are user IDs and movie IDs
        {
            "user_id": tf.strings.to_number(x["user_id"], out_type=tf.int32),
            "movie_id": tf.strings.to_number(x["movie_id"], out_type=tf.int32),
        },
        # Labels are ratings between 0 and 1.
        (x["user_rating"] - 1.0) / 4.0,
    )
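

"""
To check the preprocessing, we can look at a single transformed example. This is
purely illustrative and not needed for the rest of the tutorial:
"""

# Illustrative: print one (inputs, label) pair produced by `preprocess_rating`.
for example_inputs, example_label in ratings.map(preprocess_rating).take(1):
    print(example_inputs)
    print(example_label)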


"""
We'll split the data by putting 80% of the ratings in the train set, and 20% in
the test set.
"""

shuffled_ratings = ratings.map(preprocess_rating).shuffle(
    100_000, seed=42, reshuffle_each_iteration=False
)
train_ratings = (
    shuffled_ratings.take(80_000).batch(BATCH_SIZE, drop_remainder=True).cache()
)
test_ratings = (
    shuffled_ratings.skip(80_000)
    .take(20_000)
    .batch(BATCH_SIZE, drop_remainder=True)
    .cache()
)

"""
## Configuring DistributedEmbedding

The `keras_rs.layers.DistributedEmbedding` layer handles multiple features and
multiple embedding tables. This is to enable the sharing of tables between
features and allow some optimizations that come from combining multiple
embedding lookups into a single invocation. In this section, we'll describe
how to configure these.

### Configuring tables

Tables are configured using `keras_rs.layers.TableConfig`, which has:

- A name.
- A vocabulary size (input size).
- An embedding dimension (output size).
- A combiner to specify how to reduce multiple embeddings into a single one in
  the case when we embed a sequence. Note that this doesn't apply to our example
  because we're getting a single embedding for each user and each movie.
- A placement to tell whether to put the table on the SparseCore chips or not.
  In this case, we want the `"sparsecore"` placement.
- An optimizer to specify how to apply gradients when training. Each table has
  its own optimizer and the one passed to `model.compile()` is not used for the
  embedding tables.

### Configuring features

Features are configured using `keras_rs.layers.FeatureConfig`, which has:

- A name.
- A table, the embedding table to use.
- An input shape (the batch size is the global batch size across all TPUs).
- An output shape (the batch size is the global batch size across all TPUs).

We can organize features in any structure we want, which can be nested. A dict
is often a good choice to have names for the inputs and outputs.
"""

EMBEDDING_DIMENSION = 32

movie_table = keras_rs.layers.TableConfig(
    name="movie_table",
    vocabulary_size=movies_count + 1,  # +1 for movie ID 0, which is not used
    embedding_dim=EMBEDDING_DIMENSION,
    optimizer="adam",
    placement="sparsecore",
)
user_table = keras_rs.layers.TableConfig(
    name="user_table",
    vocabulary_size=users_count + 1,  # +1 for user ID 0, which is not used
    embedding_dim=EMBEDDING_DIMENSION,
    optimizer="adam",
    placement="sparsecore",
)

FEATURE_CONFIGS = {
    "movie_id": keras_rs.layers.FeatureConfig(
        name="movie",
        table=movie_table,
        input_shape=(BATCH_SIZE,),
        output_shape=(BATCH_SIZE, EMBEDDING_DIMENSION),
    ),
    "user_id": keras_rs.layers.FeatureConfig(
        name="user",
        table=user_table,
        input_shape=(BATCH_SIZE,),
        output_shape=(BATCH_SIZE, EMBEDDING_DIMENSION),
    ),
}
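
"""
Here, each feature has its own table. Several features can also reference the
same `keras_rs.layers.TableConfig`, in which case they share embedding weights
and their lookups can be combined. The sketch below is purely illustrative and
uses hypothetical genre features that are not part of the MovieLens inputs in
this tutorial:

```python
# Illustrative sketch only: one hypothetical table shared by two features.
genre_table = keras_rs.layers.TableConfig(
    name="genre_table",
    vocabulary_size=1_000,  # made-up vocabulary size for this sketch
    embedding_dim=EMBEDDING_DIMENSION,
    optimizer="adam",
    placement="sparsecore",
)
shared_feature_configs = {
    "movie_genre": keras_rs.layers.FeatureConfig(
        name="movie_genre",
        table=genre_table,  # same table...
        input_shape=(BATCH_SIZE,),
        output_shape=(BATCH_SIZE, EMBEDDING_DIMENSION),
    ),
    "user_favorite_genre": keras_rs.layers.FeatureConfig(
        name="user_favorite_genre",
        table=genre_table,  # ...referenced by a second feature
        input_shape=(BATCH_SIZE,),
        output_shape=(BATCH_SIZE, EMBEDDING_DIMENSION),
    ),
}
```
"""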

"""
## Defining the Model

We're now ready to create a `DistributedEmbedding` inside a model. Once we have
the configuration, we simply pass it to the constructor of
`DistributedEmbedding`. Then, within the model `call` method,
`DistributedEmbedding` is the first layer we call.

The outputs have the exact same structure as the inputs. In our example, we
concatenate the embeddings we got as outputs and run them through a tower of
dense layers.
"""


class EmbeddingModel(keras.Model):
    """Create the model with the embedding configuration.

    Args:
        feature_configs: the configuration for `DistributedEmbedding`.
    """

    def __init__(self, feature_configs):
        super().__init__()

        self.embedding_layer = keras_rs.layers.DistributedEmbedding(
            feature_configs=feature_configs
        )
        self.ratings = keras.Sequential(
            [
                # Learn multiple dense layers.
                keras.layers.Dense(256, activation="relu"),
                keras.layers.Dense(64, activation="relu"),
                # Make rating predictions in the final layer.
                keras.layers.Dense(1),
            ]
        )

    def call(self, preprocessed_features):
        # Embedding lookup. Outputs have the same structure as the inputs.
        embedding = self.embedding_layer(preprocessed_features)
        return self.ratings(
            keras.ops.concatenate(
                [embedding["user_id"], embedding["movie_id"]],
                axis=1,
            )
        )


"""
Let's now instantiate the model. We then use `model.compile()` to configure the
loss, metrics and optimizer. Again, this Adagrad optimizer will only apply to
the dense layers and not the embedding tables.
"""

model = EmbeddingModel(FEATURE_CONFIGS)

model.compile(
    loss=keras.losses.MeanSquaredError(),
    metrics=[keras.metrics.RootMeanSquaredError()],
    optimizer="adagrad",
)

"""
With the JAX backend, we need to preprocess the inputs to convert them to a
hardware-dependent format required for use with SparseCores. We'll do this by
wrapping the datasets into generator functions.
"""


def train_dataset_generator():
    for inputs, labels in iter(train_ratings):
        yield model.embedding_layer.preprocess(inputs, training=True), labels


def test_dataset_generator():
    for inputs, labels in iter(test_ratings):
        yield model.embedding_layer.preprocess(inputs, training=False), labels


"""
## Fitting and evaluating

We can use the standard Keras `model.fit()` to train the model. Keras
automatically distributes the model and the data across the TPUs using the
`keras.distribution.DataParallel` distribution we set up earlier.
"""

model.fit(train_dataset_generator(), epochs=5)

"""
The same applies to `model.evaluate()`.
"""

model.evaluate(test_dataset_generator(), return_dict=True)
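
"""
Since we passed `return_dict=True`, the metrics come back as a dictionary keyed
by metric name. As a small illustrative sketch (assuming the default metric name
of `keras.metrics.RootMeanSquaredError`), you could capture and print the RMSE:

```python
results = model.evaluate(test_dataset_generator(), return_dict=True)
print(f"Test RMSE: {results['root_mean_squared_error']:.4f}")
```
"""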

"""
That's it.

This example shows that after configuring the `DistributedEmbedding` and setting
up the required preprocessing, you can use the standard Keras workflows.
"""