Path: blob/master/examples/keras_rs/distributed_embedding_jax.py
3507 views
"""1Title: DistributedEmbedding using TPU SparseCore and JAX2Author: [Fabien Hertschuh](https://github.com/hertschuh/), [Abheesht Sharma](https://github.com/abheesht17/), [C. Antonio Sánchez](https://github.com/cantonios/)3Date created: 2025/06/034Last modified: 2025/09/025Description: Rank movies using a two tower model with embeddings on SparseCore.6Accelerator: TPU7"""89"""10## Introduction1112In the [basic ranking](/keras_rs/examples/basic_ranking/) tutorial, we showed13how to build a ranking model for the MovieLens dataset to suggest movies to14users.1516This tutorial implements the same model trained on the same dataset but with the17use of `keras_rs.layers.DistributedEmbedding`, which makes use of SparseCore on18TPU. This is the JAX version of the tutorial. It needs to be run on TPU v5p or19v6e.2021Let's begin by choosing JAX as the backend and importing all the necessary22libraries.23"""2425"""shell26pip install -q -U jax[tpu]>=0.7.027pip install -q jax-tpu-embedding28pip install -q tensorflow-cpu29pip install -q keras-rs30"""3132import os3334os.environ["KERAS_BACKEND"] = "jax"3536import jax37import keras38import keras_rs39import tensorflow as tf # Needed for the dataset40import tensorflow_datasets as tfds4142"""43## Dataset distribution4445While the model is replicated and the embedding tables are sharded across46SparseCores, the dataset is distributed by sharding each batch across the TPUs.47We need to make sure the batch size is a multiple of the number of TPUs.48"""4950PER_REPLICA_BATCH_SIZE = 25651BATCH_SIZE = PER_REPLICA_BATCH_SIZE * jax.local_device_count("tpu")5253distribution = keras.distribution.DataParallel(devices=jax.devices("tpu"))54keras.distribution.set_distribution(distribution)5556"""57## Preparing the dataset5859We're going to use the same MovieLens data. The ratings are the objectives we60are trying to predict.61"""6263# Ratings data.64ratings = tfds.load("movielens/100k-ratings", split="train")65# Features of all the available movies.66movies = tfds.load("movielens/100k-movies", split="train")6768"""69We need to know the number of users as we're using the user ID directly as an70index in the user embedding table.71"""7273users_count = int(74ratings.map(lambda x: tf.strings.to_number(x["user_id"], out_type=tf.int32))75.reduce(tf.constant(0, tf.int32), tf.maximum)76.numpy()77)7879"""80We also need do know the number of movies as we're using the movie ID directly81as an index in the movie embedding table.82"""8384movies_count = int(movies.cardinality().numpy())8586"""87The inputs to the model are the user IDs and movie IDs and the labels are the88ratings.89"""909192def preprocess_rating(x):93return (94# Inputs are user IDs and movie IDs95{96"user_id": tf.strings.to_number(x["user_id"], out_type=tf.int32),97"movie_id": tf.strings.to_number(x["movie_id"], out_type=tf.int32),98},99# Labels are ratings between 0 and 1.100(x["user_rating"] - 1.0) / 4.0,101)102103104"""105We'll split the data by putting 80% of the ratings in the train set, and 20% in106the test set.107"""108109shuffled_ratings = ratings.map(preprocess_rating).shuffle(110100_000, seed=42, reshuffle_each_iteration=False111)112train_ratings = (113shuffled_ratings.take(80_000).batch(BATCH_SIZE, drop_remainder=True).cache()114)115test_ratings = (116shuffled_ratings.skip(80_000)117.take(20_000)118.batch(BATCH_SIZE, drop_remainder=True)119.cache()120)121122"""123## Configuring DistributedEmbedding124125The `keras_rs.layers.DistributedEmbedding` handles multiple features and126multiple embedding tables. This is to enable the sharing of tables between127features and allow some optimizations that come from combining multiple128embedding lookups into a single invocation. In this section, we'll describe129how to configure these.130131### Configuring tables132133Tables are configured using `keras_rs.layers.TableConfig`, which has:134135- A name.136- A vocabulary size (input size).137- an embedding dimension (output size).138- A combiner to specify how to reduce multiple embeddings into a single one in139the case when we embed a sequence. Note that this doesn't apply to our example140because we're getting a single embedding for each user and each movie.141- A placement to tell whether to put the table on the SparseCore chips or not.142In this case, we want the `"sparsecore"` placement.143- An optimizer to specify how to apply gradients when training. Each table has144its own optimizer and the one passed to `model.compile()` is not used for the145embedding tables.146147### Configuring features148149Features are configured using `keras_rs.layers.FeatureConfig`, which has:150151- A name.152- A table, the embedding table to use.153- An input shape (batch size is for all TPUs).154- An output shape (batch size is for all TPUs).155156We can organize features in any structure we want, which can be nested. A dict157is often a good choice to have names for the inputs and outputs.158"""159160EMBEDDING_DIMENSION = 32161162movie_table = keras_rs.layers.TableConfig(163name="movie_table",164vocabulary_size=movies_count + 1, # +1 for movie ID 0, which is not used165embedding_dim=EMBEDDING_DIMENSION,166optimizer="adam",167placement="sparsecore",168)169user_table = keras_rs.layers.TableConfig(170name="user_table",171vocabulary_size=users_count + 1, # +1 for user ID 0, which is not used172embedding_dim=EMBEDDING_DIMENSION,173optimizer="adam",174placement="sparsecore",175)176177FEATURE_CONFIGS = {178"movie_id": keras_rs.layers.FeatureConfig(179name="movie",180table=movie_table,181input_shape=(BATCH_SIZE,),182output_shape=(BATCH_SIZE, EMBEDDING_DIMENSION),183),184"user_id": keras_rs.layers.FeatureConfig(185name="user",186table=user_table,187input_shape=(BATCH_SIZE,),188output_shape=(BATCH_SIZE, EMBEDDING_DIMENSION),189),190}191192"""193## Defining the Model194195We're now ready to create a `DistributedEmbedding` inside a model. Once we have196the configuration, we simply pass it the constructor of `DistributedEmbedding`.197Then, within the model `call` method, `DistributedEmbedding` is the first layer198we call.199200The ouputs have the exact same structure as the inputs. In our example, we201concatenate the embeddings we got as outputs and run them through a tower of202dense layers.203"""204205206class EmbeddingModel(keras.Model):207"""Create the model with the embedding configuration.208209Args:210feature_configs: the configuration for `DistributedEmbedding`.211"""212213def __init__(self, feature_configs):214super().__init__()215216self.embedding_layer = keras_rs.layers.DistributedEmbedding(217feature_configs=feature_configs218)219self.ratings = keras.Sequential(220[221# Learn multiple dense layers.222keras.layers.Dense(256, activation="relu"),223keras.layers.Dense(64, activation="relu"),224# Make rating predictions in the final layer.225keras.layers.Dense(1),226]227)228229def call(self, preprocessed_features):230# Embedding lookup. Outputs have the same structure as the inputs.231embedding = self.embedding_layer(preprocessed_features)232return self.ratings(233keras.ops.concatenate(234[embedding["user_id"], embedding["movie_id"]],235axis=1,236)237)238239240"""241Let's now instantiate the model. We then use `model.compile()` to configure the242loss, metrics and optimizer. Again, this Adagrad optimizer will only apply to243the dense layers and not the embedding tables.244"""245246model = EmbeddingModel(FEATURE_CONFIGS)247248model.compile(249loss=keras.losses.MeanSquaredError(),250metrics=[keras.metrics.RootMeanSquaredError()],251optimizer="adagrad",252)253254"""255With the JAX backend, we need to preprocess the inputs to convert them to a256hardware-dependent format required for use with SparseCores. We'll do this by257wrapping the datasets into generator functions.258"""259260261def train_dataset_generator():262for inputs, labels in iter(train_ratings):263yield model.embedding_layer.preprocess(inputs, training=True), labels264265266def test_dataset_generator():267for inputs, labels in iter(test_ratings):268yield model.embedding_layer.preprocess(inputs, training=False), labels269270271"""272## Fitting and evaluating273274We can use the standard Keras `model.fit()` to train the model. Keras will275automatically use the `TPUStrategy` to distribute the model and the data.276"""277278model.fit(train_dataset_generator(), epochs=5)279280"""281Same for `model.evaluate()`.282"""283284model.evaluate(test_dataset_generator(), return_dict=True)285286"""287That's it.288289This example shows that after configuring the `DistributedEmbedding` and setting290up the required preprocessing, you can use the standard Keras workflows.291"""292293294