"""1Title: Retrieval with data parallel training2Author: [Abheesht Sharma](https://github.com/abheesht17/), [Fabien Hertschuh](https://github.com/hertschuh/)3Date created: 2025/04/284Last modified: 2025/04/285Description: Retrieve movies using a two tower model (data parallel training).6Accelerator: TPU7"""89"""10## Introduction1112In this tutorial, we are going to train the exact same retrieval model as we13did in our14[basic retrieval](/keras_rs/examples/basic_retrieval/)15tutorial, but in a distributed way.1617Distributed training is used to train models on multiple devices or machines18simultaneously, thereby reducing training time. Here, we focus on synchronous19data parallel training. Each accelerator (GPU/TPU) holds a complete replica20of the model, and sees a different mini-batch of the input data. Local gradients21are computed on each device, aggregated and used to compute a global gradient22update.2324Before we begin, let's note down a few things:25261. The number of accelerators should be greater than 1.272. The `keras.distribution` API works only with JAX. So, make sure you select28JAX as your backend!29"""3031"""shell32pip install -q keras-rs33"""3435import os3637os.environ["KERAS_BACKEND"] = "jax"3839import random4041import jax42import keras43import tensorflow as tf # Needed only for the dataset44import tensorflow_datasets as tfds4546import keras_rs4748"""49## Data Parallel5051For the synchronous data parallelism strategy in distributed training,52we will use the `DataParallel` class present in the `keras.distribution`53API.54"""55devices = jax.devices() # Assume it has >1 local devices.56data_parallel = keras.distribution.DataParallel(devices=devices)5758"""59Alternatively, you can choose to create the `DataParallel` object60using a 1D `DeviceMesh` object, like so:6162```63mesh_1d = keras.distribution.DeviceMesh(64shape=(len(devices),), axis_names=["data"], devices=devices65)66data_parallel = keras.distribution.DataParallel(device_mesh=mesh_1d)67```68"""6970# Set the global distribution strategy.71keras.distribution.set_distribution(data_parallel)7273"""74## Preparing the dataset7576Now that we are done defining the global distribution77strategy, the rest of the guide looks exactly the same78as the previous basic retrieval guide.7980Let's load and prepare the dataset. Here too, we use the81MovieLens dataset.82"""8384# Ratings data with user and movie data.85ratings = tfds.load("movielens/100k-ratings", split="train")86# Features of all the available movies.87movies = tfds.load("movielens/100k-movies", split="train")8889# User, movie counts for defining vocabularies.90users_count = (91ratings.map(lambda x: tf.strings.to_number(x["user_id"], out_type=tf.int32))92.reduce(tf.constant(0, tf.int32), tf.maximum)93.numpy()94)95movies_count = movies.cardinality().numpy()969798# Preprocess dataset, and split it into train-test datasets.99def preprocess_rating(x):100return (101# Input is the user IDs102tf.strings.to_number(x["user_id"], out_type=tf.int32),103# Labels are movie IDs + ratings between 0 and 1.104{105"movie_id": tf.strings.to_number(x["movie_id"], out_type=tf.int32),106"rating": (x["user_rating"] - 1.0) / 4.0,107},108)109110111shuffled_ratings = ratings.map(preprocess_rating).shuffle(112100_000, seed=42, reshuffle_each_iteration=False113)114train_ratings = shuffled_ratings.take(80_000).batch(1000).cache()115test_ratings = shuffled_ratings.skip(80_000).take(20_000).batch(1000).cache()116117"""118## Implementing the Model119120We build a two-tower retrieval model. 
"""
## Implementing the Model

We build a two-tower retrieval model, combining a query tower for users with
a candidate tower for movies. Note that we don't have to change anything here
compared to the previous basic retrieval tutorial.
"""


class RetrievalModel(keras.Model):
    """Create the retrieval model with the provided parameters.

    Args:
        num_users: Number of entries in the user embedding table.
        num_candidates: Number of entries in the candidate embedding table.
        embedding_dimension: Output dimension for user and movie embedding
            tables.
    """

    def __init__(
        self,
        num_users,
        num_candidates,
        embedding_dimension=32,
        **kwargs,
    ):
        super().__init__(**kwargs)
        # Our query tower, simply an embedding table.
        self.user_embedding = keras.layers.Embedding(num_users, embedding_dimension)
        # Our candidate tower, simply an embedding table.
        self.candidate_embedding = keras.layers.Embedding(
            num_candidates, embedding_dimension
        )
        # The layer that performs the retrieval.
        self.retrieval = keras_rs.layers.BruteForceRetrieval(k=10, return_scores=False)
        self.loss_fn = keras.losses.MeanSquaredError()

    def build(self, input_shape):
        self.user_embedding.build(input_shape)
        self.candidate_embedding.build(input_shape)
        # In this case, the candidates are directly the movie embeddings.
        # We take a shortcut and directly reuse the variable.
        self.retrieval.candidate_embeddings = self.candidate_embedding.embeddings
        self.retrieval.build(input_shape)
        super().build(input_shape)

    def call(self, inputs, training=False):
        user_embeddings = self.user_embedding(inputs)
        result = {
            "user_embeddings": user_embeddings,
        }
        if not training:
            # Skip the retrieval of top movies during training as the
            # predictions are not used.
            result["predictions"] = self.retrieval(user_embeddings)
        return result

    def compute_loss(self, x, y, y_pred, sample_weight, training=True):
        candidate_id, rating = y["movie_id"], y["rating"]
        user_embeddings = y_pred["user_embeddings"]
        candidate_embeddings = self.candidate_embedding(candidate_id)

        labels = keras.ops.expand_dims(rating, -1)
        # Compute the affinity score by taking the dot product of the two
        # embeddings.
        scores = keras.ops.sum(
            keras.ops.multiply(user_embeddings, candidate_embeddings),
            axis=1,
            keepdims=True,
        )
        return self.loss_fn(labels, scores, sample_weight)


"""
## Fitting and evaluating

After defining the model, we can use the standard Keras `model.fit()` to train
and evaluate the model.
"""

model = RetrievalModel(users_count + 1, movies_count + 1)
model.compile(optimizer=keras.optimizers.Adagrad(learning_rate=0.2))

"""
Let's train the model. Evaluation takes a bit of time, so we only evaluate the
model every 5 epochs.
"""

history = model.fit(
    train_ratings, validation_data=test_ratings, validation_freq=5, epochs=50
)
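"""
Under `DataParallel`, variables are replicated rather than sharded: every
device holds a full copy of both embedding tables, and only the data is split.
The check below is a small sketch we added to make this visible; it assumes
the JAX backend, where a Keras variable's `.value` is a `jax.Array` that
carries sharding metadata.
"""

# The user embedding table should be fully replicated across all devices.
user_table = model.user_embedding.embeddings.value
print("Shape:", user_table.shape)
print("Sharding:", user_table.sharding)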
"""
## Making predictions

Now that we have a model, let's run inference and make predictions.
"""

movie_id_to_movie_title = {
    int(x["movie_id"]): x["movie_title"] for x in movies.as_numpy_iterator()
}
movie_id_to_movie_title[0] = ""  # Because id 0 is not in the dataset.

"""
We then simply use the Keras `model.predict()` method. Under the hood, it calls
the `BruteForceRetrieval` layer to perform the actual retrieval.
"""

user_ids = random.sample(range(1, 1001), len(devices))
predictions = model.predict(keras.ops.convert_to_tensor(user_ids))
predictions = keras.ops.convert_to_numpy(predictions["predictions"])

for i, user_id in enumerate(user_ids):
    print(f"\n==Recommended movies for user {user_id}==")
    for movie_id in predictions[i]:
        print(movie_id_to_movie_title[movie_id])

"""
And we're done! For data parallel training, all we had to do was add about
3-5 lines of code. The rest is exactly the same.
"""
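"""
One last tip (our addition, not from the original tutorial): if you don't have
multiple accelerators at hand, JAX can emulate several devices on a single CPU
via an XLA flag. Set it at the very top of the script, before JAX is first
imported, and the rest of the tutorial runs unchanged:

```
import os

# Pretend the host CPU is 8 separate devices. Must be set before importing JAX.
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=8"

import jax

print(jax.devices())  # Eight CpuDevice entries.
```
"""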