Path: blob/master/examples/keras_rs/distributed_embedding_tf.py
"""1Title: DistributedEmbedding using TPU SparseCore and TensorFlow2Author: [Fabien Hertschuh](https://github.com/hertschuh/), [Abheesht Sharma](https://github.com/abheesht17/)3Date created: 2025/09/024Last modified: 2025/09/025Description: Rank movies using a two tower model with embeddings on SparseCore.6Accelerator: TPU7"""89"""10## Introduction1112In the [basic ranking](/keras_rs/examples/basic_ranking/) tutorial, we showed13how to build a ranking model for the MovieLens dataset to suggest movies to14users.1516This tutorial implements the same model trained on the same dataset but with the17use of `keras_rs.layers.DistributedEmbedding`, which makes use of SparseCore on18TPU. This is the TensorFlow version of the tutorial. It needs to be run on TPU19v5p or v6e.2021Let's begin by installing the necessary libraries. Note that we need22`tensorflow-tpu` version 2.19. We'll also install `keras-rs`.23"""2425"""shell26pip install -U -q tensorflow-tpu==2.19.127pip install -q keras-rs28"""2930"""31We're using the PJRT version of the runtime for TensorFlow. We're also enabling32the MLIR bridge. This requires setting a few flags before importing TensorFlow.33"""3435import os36import libtpu3738os.environ["PJRT_DEVICE"] = "TPU"39os.environ["NEXT_PLUGGABLE_DEVICE_USE_C_API"] = "true"40os.environ["TF_PLUGGABLE_DEVICE_LIBRARY_PATH"] = libtpu.get_library_path()41os.environ["TF_XLA_FLAGS"] = (42"--tf_mlir_enable_mlir_bridge=true "43"--tf_mlir_enable_convert_control_to_data_outputs_pass=true "44"--tf_mlir_enable_merge_control_flow_pass=true"45)4647import tensorflow as tf4849"""50We now set the Keras backend to TensorFlow and import the necessary libraries.51"""5253os.environ["KERAS_BACKEND"] = "tensorflow"5455import keras56import keras_rs57import tensorflow_datasets as tfds5859"""60## Creating a `TPUStrategy`6162To run TensorFlow on TPU, you need to use a63[`tf.distribute.TPUStrategy`](https://www.tensorflow.org/api_docs/python/tf/distribute/TPUStrategy)64to handle the distribution of the model.6566The core of the model is replicated across TPU instances, which is done by the67`TPUStrategy`. Note that on GPU you would use68[`tf.distribute.MirroredStrategy`](https://www.tensorflow.org/api_docs/python/tf/distribute/MirroredStrategy)69instead, but this strategy is not for TPU.7071Only the embedding tables handled by `DistributedEmbedding` are sharded across72the SparseCore chips of all the available TPUs.73"""7475resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu="local")76topology = tf.tpu.experimental.initialize_tpu_system(resolver)77tpu_metadata = resolver.get_tpu_system_metadata()7879device_assignment = tf.tpu.experimental.DeviceAssignment.build(80topology, num_replicas=tpu_metadata.num_cores81)82strategy = tf.distribute.TPUStrategy(83resolver, experimental_device_assignment=device_assignment84)8586"""87## Dataset distribution8889While the model is replicated and the embedding tables are sharded across90SparseCores, the dataset is distributed by sharding each batch across the TPUs.91We need to make sure the batch size is a multiple of the number of TPUs.92"""9394PER_REPLICA_BATCH_SIZE = 25695BATCH_SIZE = PER_REPLICA_BATCH_SIZE * strategy.num_replicas_in_sync9697"""98## Preparing the dataset99100We're going to use the same MovieLens data. 
"""
## Preparing the dataset

We're going to use the same MovieLens data. The ratings are the values we are
trying to predict.
"""

# Ratings data.
ratings = tfds.load("movielens/100k-ratings", split="train")
# Features of all the available movies.
movies = tfds.load("movielens/100k-movies", split="train")

"""
We need to know the number of users as we're using the user ID directly as an
index in the user embedding table.
"""

users_count = int(
    ratings.map(lambda x: tf.strings.to_number(x["user_id"], out_type=tf.int32))
    .reduce(tf.constant(0, tf.int32), tf.maximum)
    .numpy()
)

"""
We also need to know the number of movies as we're using the movie ID directly
as an index in the movie embedding table.
"""

movies_count = int(movies.cardinality().numpy())

"""
The inputs to the model are the user IDs and movie IDs, and the labels are the
ratings.
"""


def preprocess_rating(x):
    return (
        # Inputs are user IDs and movie IDs.
        {
            "user_id": tf.strings.to_number(x["user_id"], out_type=tf.int32),
            "movie_id": tf.strings.to_number(x["movie_id"], out_type=tf.int32),
        },
        # Labels are ratings between 0 and 1.
        (x["user_rating"] - 1.0) / 4.0,
    )


"""
We'll split the data by putting 80% of the ratings in the train set, and 20% in
the test set.
"""

shuffled_ratings = ratings.map(preprocess_rating).shuffle(
    100_000, seed=42, reshuffle_each_iteration=False
)
train_ratings = (
    shuffled_ratings.take(80_000).batch(BATCH_SIZE, drop_remainder=True).cache()
)
test_ratings = (
    shuffled_ratings.skip(80_000)
    .take(20_000)
    .batch(BATCH_SIZE, drop_remainder=True)
    .cache()
)
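"""
Before configuring the embedding layer, it can help to peek at a single batch
to see the structure the model will receive: a dict of integer ID tensors and a
float rating tensor, each with the global batch size. This is a minimal,
optional check.
"""

for features, labels in train_ratings.take(1):
    # Each feature tensor and the label tensor have shape (BATCH_SIZE,).
    print("user_id shape:", features["user_id"].shape)
    print("movie_id shape:", features["movie_id"].shape)
    print("labels shape:", labels.shape)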
"""
## Configuring DistributedEmbedding

The `keras_rs.layers.DistributedEmbedding` layer handles multiple features and
multiple embedding tables. This enables the sharing of tables between features
and allows some optimizations that come from combining multiple embedding
lookups into a single invocation. In this section, we'll describe how to
configure these.

### Configuring tables

Tables are configured using `keras_rs.layers.TableConfig`, which has:

- A name.
- A vocabulary size (input size).
- An embedding dimension (output size).
- A combiner to specify how to reduce multiple embeddings into a single one in
  the case where we embed a sequence. Note that this doesn't apply to our
  example because we're getting a single embedding for each user and each
  movie.
- A placement to specify whether to put the table on the SparseCore chips or
  not. In this case, we want the `"sparsecore"` placement.
- An optimizer to specify how to apply gradients when training. Each table has
  its own optimizer, and the one passed to `model.compile()` is not used for
  the embedding tables.

### Configuring features

Features are configured using `keras_rs.layers.FeatureConfig`, which has:

- A name.
- A table, the embedding table to use.
- An input shape (the batch size is the global batch size across all TPUs).
- An output shape (the batch size is the global batch size across all TPUs).

We can organize features in any structure we want, which can be nested. A dict
is often a good choice to have names for the inputs and outputs.
"""

EMBEDDING_DIMENSION = 32

movie_table = keras_rs.layers.TableConfig(
    name="movie_table",
    vocabulary_size=movies_count + 1,  # +1 for movie ID 0, which is not used
    embedding_dim=EMBEDDING_DIMENSION,
    optimizer="adam",
    placement="sparsecore",
)
user_table = keras_rs.layers.TableConfig(
    name="user_table",
    vocabulary_size=users_count + 1,  # +1 for user ID 0, which is not used
    embedding_dim=EMBEDDING_DIMENSION,
    optimizer="adam",
    placement="sparsecore",
)

FEATURE_CONFIGS = {
    "movie_id": keras_rs.layers.FeatureConfig(
        name="movie",
        table=movie_table,
        input_shape=(BATCH_SIZE,),
        output_shape=(BATCH_SIZE, EMBEDDING_DIMENSION),
    ),
    "user_id": keras_rs.layers.FeatureConfig(
        name="user",
        table=user_table,
        input_shape=(BATCH_SIZE,),
        output_shape=(BATCH_SIZE, EMBEDDING_DIMENSION),
    ),
}

"""
## Defining the model

We're now ready to create a `DistributedEmbedding` inside a model. Once we have
the configuration, we simply pass it to the constructor of
`DistributedEmbedding`. Then, within the model's `call` method,
`DistributedEmbedding` is the first layer we call.

The outputs have the exact same structure as the inputs. In our example, we
concatenate the embeddings we got as outputs and run them through a tower of
dense layers.
"""


class EmbeddingModel(keras.Model):
    """Create the model with the embedding configuration.

    Args:
        feature_configs: the configuration for `DistributedEmbedding`.
    """

    def __init__(self, feature_configs):
        super().__init__()

        self.embedding_layer = keras_rs.layers.DistributedEmbedding(
            feature_configs=feature_configs
        )
        self.ratings = keras.Sequential(
            [
                # Learn multiple dense layers.
                keras.layers.Dense(256, activation="relu"),
                keras.layers.Dense(64, activation="relu"),
                # Make rating predictions in the final layer.
                keras.layers.Dense(1),
            ]
        )

    def call(self, features):
        # Embedding lookup. Outputs have the same structure as the inputs.
        embedding = self.embedding_layer(features)
        return self.ratings(
            keras.ops.concatenate(
                [embedding["user_id"], embedding["movie_id"]],
                axis=1,
            )
        )


"""
Let's now instantiate the model. We then use `model.compile()` to configure the
loss, metrics and optimizer. Again, the Adagrad optimizer specified here only
applies to the dense layers, not the embedding tables.
"""

with strategy.scope():
    model = EmbeddingModel(FEATURE_CONFIGS)

    model.compile(
        loss=keras.losses.MeanSquaredError(),
        metrics=[keras.metrics.RootMeanSquaredError()],
        optimizer="adagrad",
    )

"""
## Fitting and evaluating

We can use the standard Keras `model.fit()` to train the model. Keras will
automatically use the `TPUStrategy` to distribute the model and the data.
"""

with strategy.scope():
    model.fit(train_ratings, epochs=5)

"""
The same goes for `model.evaluate()`.
"""

with strategy.scope():
    model.evaluate(test_ratings, return_dict=True)

"""
That's it.

This example shows that after setting up the `TPUStrategy` and configuring the
`DistributedEmbedding`, you can use the standard Keras workflows.
"""
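"""
As an optional last step, here is a minimal sketch of how you could run
inference with the trained model. We assume `model.predict()` distributes under
the same `TPUStrategy` in the same way `model.evaluate()` does; passing the
already-batched dataset keeps the fixed global batch size expected by the
embedding configuration.
"""

with strategy.scope():
    # Predict normalized ratings (between 0 and 1) for one batch of test data.
    predictions = model.predict(test_ratings.take(1))
print("Predicted normalized ratings shape:", predictions.shape)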