"""
Title: DistributedEmbedding using TPU SparseCore and TensorFlow
Author: [Fabien Hertschuh](https://github.com/hertschuh/), [Abheesht Sharma](https://github.com/abheesht17/)
Date created: 2025/09/02
Last modified: 2025/09/02
Description: Rank movies using a two tower model with embeddings on SparseCore.
Accelerator: TPU
"""

"""
## Introduction

In the [basic ranking](/keras_rs/examples/basic_ranking/) tutorial, we showed
how to build a ranking model for the MovieLens dataset to suggest movies to
users.

This tutorial implements the same model trained on the same dataset, but uses
`keras_rs.layers.DistributedEmbedding`, which takes advantage of SparseCore on
TPU. This is the TensorFlow version of the tutorial. It needs to be run on TPU
v5p or v6e.

Let's begin by installing the necessary libraries. Note that we need
`tensorflow-tpu` version 2.19. We'll also install `keras-rs`.
"""

"""shell
pip install -U -q tensorflow-tpu==2.19.1
pip install -q keras-rs
"""

"""
We're using the PJRT version of the runtime for TensorFlow. We're also enabling
the MLIR bridge. This requires setting a few flags before importing TensorFlow.
"""

import os
import libtpu

os.environ["PJRT_DEVICE"] = "TPU"
os.environ["NEXT_PLUGGABLE_DEVICE_USE_C_API"] = "true"
os.environ["TF_PLUGGABLE_DEVICE_LIBRARY_PATH"] = libtpu.get_library_path()
os.environ["TF_XLA_FLAGS"] = (
    "--tf_mlir_enable_mlir_bridge=true "
    "--tf_mlir_enable_convert_control_to_data_outputs_pass=true "
    "--tf_mlir_enable_merge_control_flow_pass=true"
)

import tensorflow as tf

"""
We now set the Keras backend to TensorFlow and import the necessary libraries.
"""

os.environ["KERAS_BACKEND"] = "tensorflow"

import keras
import keras_rs
import tensorflow_datasets as tfds

"""
## Creating a `TPUStrategy`

To run TensorFlow on TPU, you need to use a
[`tf.distribute.TPUStrategy`](https://www.tensorflow.org/api_docs/python/tf/distribute/TPUStrategy)
to handle the distribution of the model.

The core of the model is replicated across TPU instances, which is handled by
the `TPUStrategy`. Note that on GPU you would use
[`tf.distribute.MirroredStrategy`](https://www.tensorflow.org/api_docs/python/tf/distribute/MirroredStrategy)
instead, but that strategy cannot be used with TPUs.

Only the embedding tables handled by `DistributedEmbedding` are sharded across
the SparseCore chips of all the available TPUs.
"""

resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu="local")
topology = tf.tpu.experimental.initialize_tpu_system(resolver)
tpu_metadata = resolver.get_tpu_system_metadata()

device_assignment = tf.tpu.experimental.DeviceAssignment.build(
    topology, num_replicas=tpu_metadata.num_cores
)
strategy = tf.distribute.TPUStrategy(
    resolver, experimental_device_assignment=device_assignment
)
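
"""
As a quick sanity check, you could print the number of model replicas the
strategy will use and the logical TPU devices it sees, for example:

```python
print("Number of model replicas:", strategy.num_replicas_in_sync)
print("TPU devices:", tf.config.list_logical_devices("TPU"))
```
"""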

"""
## Dataset distribution

While the model is replicated and the embedding tables are sharded across
SparseCores, the dataset is distributed by sharding each batch across the TPUs.
We need to make sure the global batch size is a multiple of the number of
model replicas (one per TPU core here).
"""

PER_REPLICA_BATCH_SIZE = 256
BATCH_SIZE = PER_REPLICA_BATCH_SIZE * strategy.num_replicas_in_sync
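
"""
Keras' `model.fit()` takes care of this distribution automatically when a
`TPUStrategy` is in scope. For reference, in a custom training loop you could
distribute a dataset yourself along these lines (a sketch only; `my_dataset`
stands for a `tf.data.Dataset` batched with the global `BATCH_SIZE`):

```python
# Sketch of manual distribution for a custom training loop.
distributed_dataset = strategy.experimental_distribute_dataset(my_dataset)
for per_replica_batch in distributed_dataset:
    ...  # run one replicated step, e.g. with strategy.run(train_step, ...)
```
"""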

"""
## Preparing the dataset

We're going to use the same MovieLens data. The ratings are the objectives we
are trying to predict.
"""

# Ratings data.
ratings = tfds.load("movielens/100k-ratings", split="train")
# Features of all the available movies.
movies = tfds.load("movielens/100k-movies", split="train")

"""
We need to know the number of users as we're using the user ID directly as an
index in the user embedding table.
"""

users_count = int(
    ratings.map(lambda x: tf.strings.to_number(x["user_id"], out_type=tf.int32))
    .reduce(tf.constant(0, tf.int32), tf.maximum)
    .numpy()
)

"""
We also need to know the number of movies as we're using the movie ID directly
as an index in the movie embedding table.
"""

movies_count = int(movies.cardinality().numpy())

"""
The inputs to the model are the user IDs and movie IDs, and the labels are the
ratings.
"""


def preprocess_rating(x):
    return (
        # Inputs are user IDs and movie IDs
        {
            "user_id": tf.strings.to_number(x["user_id"], out_type=tf.int32),
            "movie_id": tf.strings.to_number(x["movie_id"], out_type=tf.int32),
        },
        # Labels are ratings between 0 and 1.
        (x["user_rating"] - 1.0) / 4.0,
    )
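

"""
To see what a preprocessed example looks like, you could, for instance, peek at
one element of the mapped dataset:

```python
# One example: a dict of int32 IDs and a float label rescaled to [0, 1].
for features, label in ratings.map(preprocess_rating).take(1):
    print(features, label)
```
"""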

"""
We'll split the data by putting 80% of the ratings in the train set, and 20% in
the test set.
"""

shuffled_ratings = ratings.map(preprocess_rating).shuffle(
    100_000, seed=42, reshuffle_each_iteration=False
)
train_ratings = (
    shuffled_ratings.take(80_000).batch(BATCH_SIZE, drop_remainder=True).cache()
)
test_ratings = (
    shuffled_ratings.skip(80_000)
    .take(20_000)
    .batch(BATCH_SIZE, drop_remainder=True)
    .cache()
)

"""
## Configuring DistributedEmbedding

The `keras_rs.layers.DistributedEmbedding` layer handles multiple features and
multiple embedding tables. This enables sharing tables between features and
allows some optimizations that come from combining multiple embedding lookups
into a single invocation. In this section, we'll describe how to configure
these.

### Configuring tables

Tables are configured using `keras_rs.layers.TableConfig`, which has:

- A name.
- A vocabulary size (input size).
- An embedding dimension (output size).
- A combiner to specify how to reduce multiple embeddings into a single one in
  the case when we embed a sequence. Note that this doesn't apply to our
  example because we're getting a single embedding for each user and each
  movie; a sketch of a sequence feature is shown after the configuration code
  below.
- A placement to tell whether to put the table on the SparseCore chips or not.
  In this case, we want the `"sparsecore"` placement.
- An optimizer to specify how to apply gradients when training. Each table has
  its own optimizer and the one passed to `model.compile()` is not used for the
  embedding tables.

### Configuring features

Features are configured using `keras_rs.layers.FeatureConfig`, which has:

- A name.
- A table, the embedding table to use.
- An input shape (the batch size here is the global batch size across all
  TPUs).
- An output shape (the batch size here is the global batch size across all
  TPUs).

We can organize features in any structure we want, which can be nested. A dict
is often a good choice to have names for the inputs and outputs.
"""

EMBEDDING_DIMENSION = 32

movie_table = keras_rs.layers.TableConfig(
    name="movie_table",
    vocabulary_size=movies_count + 1,  # +1 for movie ID 0, which is not used
    embedding_dim=EMBEDDING_DIMENSION,
    optimizer="adam",
    placement="sparsecore",
)
user_table = keras_rs.layers.TableConfig(
    name="user_table",
    vocabulary_size=users_count + 1,  # +1 for user ID 0, which is not used
    embedding_dim=EMBEDDING_DIMENSION,
    optimizer="adam",
    placement="sparsecore",
)

FEATURE_CONFIGS = {
    "movie_id": keras_rs.layers.FeatureConfig(
        name="movie",
        table=movie_table,
        input_shape=(BATCH_SIZE,),
        output_shape=(BATCH_SIZE, EMBEDDING_DIMENSION),
    ),
    "user_id": keras_rs.layers.FeatureConfig(
        name="user",
        table=user_table,
        input_shape=(BATCH_SIZE,),
        output_shape=(BATCH_SIZE, EMBEDDING_DIMENSION),
    ),
}
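
"""
For comparison, here is what a sequence feature using a combiner might look
like. This is a minimal, hypothetical sketch: the `"watch_history"` feature and
`SEQUENCE_LENGTH` are made up for illustration and are not used in this
example.

```python
SEQUENCE_LENGTH = 10  # assumed, illustrative only

history_table = keras_rs.layers.TableConfig(
    name="history_table",
    vocabulary_size=movies_count + 1,
    embedding_dim=EMBEDDING_DIMENSION,
    combiner="mean",  # reduce the sequence of embeddings into a single one
    optimizer="adam",
    placement="sparsecore",
)
history_feature = keras_rs.layers.FeatureConfig(
    name="watch_history",
    table=history_table,
    input_shape=(BATCH_SIZE, SEQUENCE_LENGTH),
    output_shape=(BATCH_SIZE, EMBEDDING_DIMENSION),
)
```
"""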

"""
## Defining the Model

We're now ready to create a `DistributedEmbedding` inside a model. Once we have
the configuration, we simply pass it to the constructor of
`DistributedEmbedding`. Then, within the model's `call` method,
`DistributedEmbedding` is the first layer we call.

The outputs have the exact same structure as the inputs. In our example, we
concatenate the embeddings we got as outputs and run them through a tower of
dense layers.
"""


class EmbeddingModel(keras.Model):
    """Create the model with the embedding configuration.

    Args:
        feature_configs: the configuration for `DistributedEmbedding`.
    """

    def __init__(self, feature_configs):
        super().__init__()

        self.embedding_layer = keras_rs.layers.DistributedEmbedding(
            feature_configs=feature_configs
        )
        self.ratings = keras.Sequential(
            [
                # Learn multiple dense layers.
                keras.layers.Dense(256, activation="relu"),
                keras.layers.Dense(64, activation="relu"),
                # Make rating predictions in the final layer.
                keras.layers.Dense(1),
            ]
        )

    def call(self, features):
        # Embedding lookup. Outputs have the same structure as the inputs.
        embedding = self.embedding_layer(features)
        return self.ratings(
            keras.ops.concatenate(
                [embedding["user_id"], embedding["movie_id"]],
                axis=1,
            )
        )


"""
Let's now instantiate the model. We then use `model.compile()` to configure the
loss, metrics and optimizer. Again, this Adagrad optimizer will only apply to
the dense layers and not the embedding tables.
"""

with strategy.scope():
    model = EmbeddingModel(FEATURE_CONFIGS)

    model.compile(
        loss=keras.losses.MeanSquaredError(),
        metrics=[keras.metrics.RootMeanSquaredError()],
        optimizer="adagrad",
    )

"""
## Fitting and evaluating

We can use the standard Keras `model.fit()` to train the model. Keras will
automatically use the `TPUStrategy` to distribute the model and the data.
"""

with strategy.scope():
    model.fit(train_ratings, epochs=5)
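
"""
If you want to monitor the test metrics during training, `model.fit()` also
accepts a validation dataset, for example:

```python
with strategy.scope():
    model.fit(train_ratings, epochs=5, validation_data=test_ratings)
```
"""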

"""
Same for `model.evaluate()`.
"""

with strategy.scope():
    model.evaluate(test_ratings, return_dict=True)
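
"""
Once trained, the model can also be used for inference via `model.predict()`,
for example to score a single batch of the test set:

```python
with strategy.scope():
    # Predictions are on the same [0, 1] scale as the labels.
    predictions = model.predict(test_ratings.take(1))
```
"""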

"""
That's it.

This example shows that after setting up the `TPUStrategy` and configuring the
`DistributedEmbedding`, you can use the standard Keras workflows.
"""