Retrieval with data parallel training
Author: Abheesht Sharma, Fabien Hertschuh
Date created: 2025/04/28
Last modified: 2025/04/28
Description: Retrieve movies using a two-tower model (data parallel training).
Introduction
In this tutorial, we are going to train the exact same retrieval model as we did in our basic retrieval tutorial, but in a distributed way.
Distributed training is used to train models on multiple devices or machines simultaneously, thereby reducing training time. Here, we focus on synchronous data parallel training: each accelerator (GPU/TPU) holds a complete replica of the model and sees a different mini-batch of the input data. Local gradients are computed on each device, then aggregated across devices and used to compute a global gradient update.
Before we begin, let's note down a few things:
The number of accelerators should be greater than 1.
The keras.distribution API works only with JAX. So, make sure you select JAX as your backend!
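As a quick sketch of that setup (the environment variable must be set before Keras is imported; a CPU-only machine will report a single device):

```python
import os

# The backend must be selected before Keras is imported.
os.environ["KERAS_BACKEND"] = "jax"

import keras

# Data parallel training only makes sense with more than one accelerator.
devices = keras.distribution.list_devices()
print(f"Found {len(devices)} device(s): {devices}")
```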
Data Parallel
For the synchronous data parallelism strategy in distributed training, we will use the DataParallel class present in the keras.distribution API.
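A minimal sketch of that setup (the exact cell in the notebook may differ slightly): list the available accelerators, build a DataParallel distribution over them, and set it as the global distribution.

```python
import keras

# Enumerate the available accelerators (GPUs/TPUs, or CPUs as a fallback).
devices = keras.distribution.list_devices()

# Every device gets a full replica of the model and a different shard of
# each input batch.
data_parallel = keras.distribution.DataParallel(devices=devices)

# Any model built after this call uses the data parallel distribution.
keras.distribution.set_distribution(data_parallel)
```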
Alternatively, you can choose to create the DataParallel object using a 1D DeviceMesh object, like so:
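The sketch below assumes the standard keras.distribution API; the single "data" axis name is illustrative.

```python
import keras

devices = keras.distribution.list_devices()

# A 1D mesh with a single "data" axis spanning all devices.
mesh = keras.distribution.DeviceMesh(
    shape=(len(devices),), axis_names=["data"], devices=devices
)

data_parallel = keras.distribution.DataParallel(device_mesh=mesh)
keras.distribution.set_distribution(data_parallel)
```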
Preparing the dataset
Now that we are done defining the global distribution strategy, the rest of the guide looks exactly the same as the previous basic retrieval guide.
Let's load and prepare the dataset. Here too, we use the MovieLens dataset.
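A condensed sketch of that pipeline, assuming the MovieLens 100K ratings split from TensorFlow Datasets; the field names and split sizes below follow the basic retrieval tutorial, but the notebook's actual preprocessing may differ:

```python
import tensorflow as tf
import tensorflow_datasets as tfds

# MovieLens 100K ratings: each example carries a user id, a movie id and a rating.
ratings = tfds.load("movielens/100k-ratings", split="train")


def preprocess(x):
    # Inputs are the (user id, movie id) pair; the label is the rating.
    return (
        (
            tf.strings.to_number(x["user_id"], out_type=tf.int32),
            tf.strings.to_number(x["movie_id"], out_type=tf.int32),
        ),
        x["user_rating"],
    )


# Shuffle once, then split into train/test sets and batch.
shuffled = ratings.map(preprocess).shuffle(
    100_000, seed=42, reshuffle_each_iteration=False
)
train_ds = shuffled.take(80_000).batch(1_000).cache()
test_ds = shuffled.skip(80_000).take(20_000).batch(1_000).cache()
```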
Implementing the Model
We build a two-tower retrieval model. Therefore, we need to combine a query tower for users and a candidate tower for movies. Note that we don't have to change anything here from the previous basic retrieval tutorial.
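As a hedged sketch of the idea (not the exact keras_rs model from the tutorial): two embedding towers, with the affinity between a user and a movie computed as the dot product of their embeddings. The vocabulary sizes and the MSE-on-ratings loss below are illustrative choices.

```python
import keras
from keras import ops

# Illustrative vocabulary sizes; the notebook derives them from the dataset
# (MovieLens 100K has 943 users and 1,682 movies, with ids starting at 1).
NUM_USERS = 1_000
NUM_MOVIES = 2_000
EMBEDDING_DIM = 32


class RetrievalModel(keras.Model):
    """Two-tower model: a query tower for users, a candidate tower for movies."""

    def __init__(self, num_users, num_movies, embedding_dim, **kwargs):
        super().__init__(**kwargs)
        # Query tower: user id -> user embedding.
        self.user_embedding = keras.layers.Embedding(num_users, embedding_dim)
        # Candidate tower: movie id -> movie embedding.
        self.movie_embedding = keras.layers.Embedding(num_movies, embedding_dim)

    def call(self, inputs):
        user_ids, movie_ids = inputs
        query = self.user_embedding(user_ids)
        candidate = self.movie_embedding(movie_ids)
        # Affinity score: dot product between the two towers' embeddings.
        return ops.sum(query * candidate, axis=-1)


model = RetrievalModel(NUM_USERS, NUM_MOVIES, EMBEDDING_DIM)
model.compile(
    loss=keras.losses.MeanSquaredError(),
    optimizer=keras.optimizers.Adagrad(learning_rate=0.1),
)
```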
Fitting and evaluating
After defining the model, we can use the standard Keras model.fit() to train and evaluate the model.
Let's train the model. Evaluation takes a bit of time, so we only evaluate the model every 5 epochs.
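A sketch of that call, assuming the train_ds/test_ds datasets and model from the sketches above; validation_freq=5 restricts evaluation to every fifth epoch.

```python
# Train for 50 epochs; run the (slower) evaluation only every 5th epoch.
history = model.fit(
    train_ds,
    validation_data=test_ds,
    validation_freq=5,
    epochs=50,
)
```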
Making predictions
Now that we have a model, let's run inference and make predictions.
We then simply use the Keras model.predict() method. Under the hood, it calls the BruteForceRetrieval layer to perform the actual retrieval.
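For illustration, here is what that brute-force retrieval amounts to for a single, hypothetical user, written out with keras.ops against the sketch model above rather than through the actual keras_rs layer:

```python
from keras import ops

# Score every movie against one user's query embedding and keep the 10
# highest-scoring candidates.
user_id = 42  # hypothetical user id
query = model.user_embedding(ops.convert_to_tensor([user_id]))  # shape (1, dim)
candidates = model.movie_embedding.embeddings                   # shape (num_movies, dim)
scores = ops.matmul(query, ops.transpose(candidates))           # shape (1, num_movies)
top_scores, top_movie_ids = ops.top_k(scores, k=10)
print("Top recommended movie ids:", ops.convert_to_numpy(top_movie_ids))
```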
And we're done! For data parallel training, all we had to do was add ~3-5 LoC. The rest is exactly the same.