Retrieval with data parallel training
Author: Abheesht Sharma, Fabien Hertschuh
Date created: 2025/04/28
Last modified: 2025/04/28
Description: Retrieve movies using a two tower model (data parallel training).
Introduction
In this tutorial, we are going to train the exact same retrieval model as we did in our basic retrieval tutorial, but in a distributed way.
Distributed training is used to train models on multiple devices or machines simultaneously, thereby reducing training time. Here, we focus on synchronous data parallel training. Each accelerator (GPU/TPU) holds a complete replica of the model, and sees a different mini-batch of the input data. Local gradients are computed on each device, aggregated and used to compute a global gradient update.
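To make the aggregation step concrete, here is a minimal NumPy sketch of one synchronous data-parallel update. It is an illustration only, not how Keras or JAX implement it: the linear model, the `mse_gradient` helper, the device count and the learning rate are all assumptions chosen for the example.

```python
import numpy as np


def mse_gradient(w, x, y):
    """Gradient of mean((x @ w - y) ** 2) with respect to w."""
    residual = x @ w - y
    return 2.0 * x.T @ residual / len(y)


rng = np.random.default_rng(0)
x = rng.normal(size=(8, 3))  # one global batch of 8 examples
y = rng.normal(size=8)
w = np.zeros(3)  # every "device" holds the same copy of the weights

num_devices = 2  # hypothetical number of accelerators
x_shards = np.split(x, num_devices)  # each device sees a different mini-batch
y_shards = np.split(y, num_devices)

# Each device computes a gradient on its own shard only.
local_grads = [mse_gradient(w, xs, ys) for xs, ys in zip(x_shards, y_shards)]

# Aggregation: average the local gradients into one global gradient.
global_grad = np.mean(local_grads, axis=0)

# Every replica applies the same update, so the copies stay in sync.
learning_rate = 0.1  # assumed value
w_new = w - learning_rate * global_grad
```

Because the shards have equal size, the averaged gradient equals the gradient computed on the full batch, which is why synchronous data parallelism trains the same model as a single device using the same global batch.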
Before we begin, let's note down a few things:
The number of accelerators should be greater than 1.
The `keras.distribution` API works only with JAX, so make sure you select JAX as your backend!
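One common way to select the backend is the `KERAS_BACKEND` environment variable, which Keras reads at import time. A minimal sketch:

```python
import os

# Keras reads this variable when it is first imported, so it has to be set
# before any `import keras` statement runs in the process.
os.environ["KERAS_BACKEND"] = "jax"
```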
Data Parallel
For the synchronous data parallelism strategy in distributed training, we will use the `DataParallel` class present in the `keras.distribution` API.
Alternatively, you can create the `DataParallel` object from a 1D `DeviceMesh` object.
Preparing the dataset
Now that we are done defining the global distribution strategy, the rest of the guide looks exactly the same as the previous basic retrieval guide.
Let's load and prepare the dataset. Here too, we use the MovieLens dataset.
WARNING:absl:Variant folder /root/tensorflow_datasets/movielens/100k-movies/0.1.1 has no dataset_info.json
Dataset movielens downloaded and prepared to /root/tensorflow_datasets/movielens/100k-ratings/0.1.1. Subsequent calls will reuse this data.
Dataset movielens downloaded and prepared to /root/tensorflow_datasets/movielens/100k-movies/0.1.1. Subsequent calls will reuse this data.
Fitting and evaluating
After defining the model, we can use the standard Keras `model.fit()` to train and evaluate the model.
Let's train the model. Evaluation takes a bit of time, so we only evaluate the model every 5 epochs.
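Evaluating only every few epochs is controlled by the `validation_freq` argument of `model.fit()`. The sketch below shows the mechanism on a hypothetical toy regression model with random data. The model, data shapes, optimizer and epoch count are stand-ins and are not the retrieval model or settings used in this tutorial.

```python
import os

# Respect an existing backend choice, otherwise default to JAX.
os.environ.setdefault("KERAS_BACKEND", "jax")

import keras
import numpy as np

rng = np.random.default_rng(0)
x_train = rng.normal(size=(64, 4)).astype("float32")
y_train = rng.normal(size=(64, 1)).astype("float32")
x_val = rng.normal(size=(16, 4)).astype("float32")
y_val = rng.normal(size=(16, 1)).astype("float32")

# Toy stand-in model (hypothetical, not the two-tower retrieval model).
model = keras.Sequential([keras.Input(shape=(4,)), keras.layers.Dense(1)])
model.compile(optimizer="adagrad", loss="mse")

history = model.fit(
    x_train,
    y_train,
    validation_data=(x_val, y_val),
    validation_freq=5,  # run validation on epochs 5, 10, ... only
    epochs=10,
    batch_size=16,
    verbose=0,
)

# `loss` is recorded every epoch, `val_loss` only on validation epochs.
num_train_entries = len(history.history["loss"])
num_val_entries = len(history.history["val_loss"])
```

With `validation_freq=5`, `val_loss` appears only on every fifth epoch, which matches the pattern in the tutorial's own 50-epoch training log.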
Epoch 1/50
80/80 ━━━━━━━━━━━━━━━━━━━━ 3s 8ms/step - loss: 0.4772
Epoch 2/50
80/80 ━━━━━━━━━━━━━━━━━━━━ 1s 6ms/step - loss: 0.4771
Epoch 3/50
80/80 ━━━━━━━━━━━━━━━━━━━━ 1s 6ms/step - loss: 0.4770
Epoch 4/50
80/80 ━━━━━━━━━━━━━━━━━━━━ 1s 6ms/step - loss: 0.4769
Epoch 5/50
80/80 ━━━━━━━━━━━━━━━━━━━━ 3s 37ms/step - loss: 0.4769 - val_loss: 0.4836
Epoch 6/50
80/80 ━━━━━━━━━━━━━━━━━━━━ 1s 7ms/step - loss: 0.4768
Epoch 7/50
80/80 ━━━━━━━━━━━━━━━━━━━━ 1s 6ms/step - loss: 0.4767
Epoch 8/50
80/80 ━━━━━━━━━━━━━━━━━━━━ 1s 6ms/step - loss: 0.4766
Epoch 9/50
80/80 ━━━━━━━━━━━━━━━━━━━━ 1s 7ms/step - loss: 0.4764
Epoch 10/50
80/80 ━━━━━━━━━━━━━━━━━━━━ 1s 8ms/step - loss: 0.4763 - val_loss: 0.4833
Epoch 11/50
80/80 ━━━━━━━━━━━━━━━━━━━━ 1s 6ms/step - loss: 0.4761
Epoch 12/50
80/80 ━━━━━━━━━━━━━━━━━━━━ 1s 6ms/step - loss: 0.4759
Epoch 13/50
80/80 ━━━━━━━━━━━━━━━━━━━━ 1s 6ms/step - loss: 0.4757
Epoch 14/50
80/80 ━━━━━━━━━━━━━━━━━━━━ 1s 6ms/step - loss: 0.4754
Epoch 15/50
80/80 ━━━━━━━━━━━━━━━━━━━━ 1s 8ms/step - loss: 0.4750 - val_loss: 0.4821
Epoch 16/50
80/80 ━━━━━━━━━━━━━━━━━━━━ 1s 7ms/step - loss: 0.4746
Epoch 17/50
80/80 ━━━━━━━━━━━━━━━━━━━━ 1s 7ms/step - loss: 0.4740
Epoch 18/50
80/80 ━━━━━━━━━━━━━━━━━━━━ 1s 6ms/step - loss: 0.4734
Epoch 19/50
80/80 ━━━━━━━━━━━━━━━━━━━━ 1s 6ms/step - loss: 0.4725
Epoch 20/50
80/80 ━━━━━━━━━━━━━━━━━━━━ 1s 8ms/step - loss: 0.4715 - val_loss: 0.4784
Epoch 21/50
80/80 ━━━━━━━━━━━━━━━━━━━━ 1s 6ms/step - loss: 0.4702
Epoch 22/50
80/80 ━━━━━━━━━━━━━━━━━━━━ 1s 6ms/step - loss: 0.4686
Epoch 23/50
80/80 ━━━━━━━━━━━━━━━━━━━━ 1s 6ms/step - loss: 0.4666
Epoch 24/50
80/80 ━━━━━━━━━━━━━━━━━━━━ 1s 6ms/step - loss: 0.4641
Epoch 25/50
80/80 ━━━━━━━━━━━━━━━━━━━━ 1s 8ms/step - loss: 0.4609 - val_loss: 0.4664
Epoch 26/50
80/80 ━━━━━━━━━━━━━━━━━━━━ 1s 6ms/step - loss: 0.4571
Epoch 27/50
80/80 ━━━━━━━━━━━━━━━━━━━━ 1s 6ms/step - loss: 0.4524
Epoch 28/50
80/80 ━━━━━━━━━━━━━━━━━━━━ 1s 6ms/step - loss: 0.4466
Epoch 29/50
80/80 ━━━━━━━━━━━━━━━━━━━━ 1s 6ms/step - loss: 0.4395
Epoch 30/50
80/80 ━━━━━━━━━━━━━━━━━━━━ 1s 8ms/step - loss: 0.4311 - val_loss: 0.4326
Epoch 31/50
80/80 ━━━━━━━━━━━━━━━━━━━━ 1s 6ms/step - loss: 0.4210
Epoch 32/50
80/80 ━━━━━━━━━━━━━━━━━━━━ 1s 6ms/step - loss: 0.4093
Epoch 33/50
80/80 ━━━━━━━━━━━━━━━━━━━━ 1s 6ms/step - loss: 0.3957
Epoch 34/50
80/80 ━━━━━━━━━━━━━━━━━━━━ 1s 6ms/step - loss: 0.3805
Epoch 35/50
80/80 ━━━━━━━━━━━━━━━━━━━━ 1s 8ms/step - loss: 0.3636 - val_loss: 0.3597
Epoch 36/50
80/80 ━━━━━━━━━━━━━━━━━━━━ 1s 6ms/step - loss: 0.3455
Epoch 37/50
80/80 ━━━━━━━━━━━━━━━━━━━━ 1s 6ms/step - loss: 0.3265
Epoch 38/50
80/80 ━━━━━━━━━━━━━━━━━━━━ 1s 6ms/step - loss: 0.3072
Epoch 39/50
80/80 ━━━━━━━━━━━━━━━━━━━━ 0s 6ms/step - loss: 0.2880
Epoch 40/50
80/80 ━━━━━━━━━━━━━━━━━━━━ 1s 8ms/step - loss: 0.2696 - val_loss: 0.2664
Epoch 41/50
80/80 ━━━━━━━━━━━━━━━━━━━━ 1s 6ms/step - loss: 0.2523
Epoch 42/50
80/80 ━━━━━━━━━━━━━━━━━━━━ 0s 6ms/step - loss: 0.2363
Epoch 43/50
80/80 ━━━━━━━━━━━━━━━━━━━━ 1s 6ms/step - loss: 0.2218
Epoch 44/50
80/80 ━━━━━━━━━━━━━━━━━━━━ 1s 6ms/step - loss: 0.2087
Epoch 45/50
80/80 ━━━━━━━━━━━━━━━━━━━━ 1s 8ms/step - loss: 0.1970 - val_loss: 0.1986
Epoch 46/50
80/80 ━━━━━━━━━━━━━━━━━━━━ 1s 6ms/step - loss: 0.1866
Epoch 47/50
80/80 ━━━━━━━━━━━━━━━━━━━━ 1s 6ms/step - loss: 0.1773
Epoch 48/50
80/80 ━━━━━━━━━━━━━━━━━━━━ 0s 6ms/step - loss: 0.1689
Epoch 49/50
80/80 ━━━━━━━━━━━━━━━━━━━━ 1s 6ms/step - loss: 0.1613
Epoch 50/50
80/80 ━━━━━━━━━━━━━━━━━━━━ 1s 8ms/step - loss: 0.1544 - val_loss: 0.1586
We then simply use the Keras `model.predict()` method. Under the hood, it calls the `BruteForceRetrieval` layer to perform the actual retrieval.
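Conceptually, brute-force retrieval scores the query embedding against every candidate embedding and keeps the highest-scoring candidates. The NumPy sketch below illustrates that idea. It is not the KerasRS implementation: the `brute_force_top_k` helper, the dot-product scoring and the tiny embedding table are assumptions made for the example.

```python
import numpy as np


def brute_force_top_k(query_embedding, candidate_embeddings, k):
    """Return indices and scores of the k best candidates by dot product."""
    # One score per candidate: compare the query with every row.
    scores = candidate_embeddings @ query_embedding
    # Sort by descending score and keep the first k indices.
    top_indices = np.argsort(-scores)[:k]
    return top_indices, scores[top_indices]


# Hypothetical 2-dimensional embeddings for four candidate movies.
candidate_embeddings = np.array(
    [[1.0, 0.0], [0.0, 1.0], [0.7, 0.7], [-1.0, 0.0]]
)
user_embedding = np.array([1.0, 0.2])

indices, scores = brute_force_top_k(user_embedding, candidate_embeddings, k=2)
```

Here the scores are 1.0, 0.2, 0.84 and -1.0, so candidates 0 and 2 are returned. In the tutorial, the returned indices are mapped back to movie titles to produce output like the recommendations listed next.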
==Recommended movies for user 793== b'Star Wars (1977)' b'Godfather, The (1972)' b'Raiders of the Lost Ark (1981)' b'Fargo (1996)' b'Silence of the Lambs, The (1991)' b"Schindler's List (1993)" b'Shawshank Redemption, The (1994)' b'Titanic (1997)' b'Braveheart (1995)' b'Pulp Fiction (1994)'
==Recommended movies for user 188== b'Star Wars (1977)' b'Fargo (1996)' b'Godfather, The (1972)' b'Silence of the Lambs, The (1991)' b"Schindler's List (1993)" b'Return of the Jedi (1983)' b'Raiders of the Lost Ark (1981)' b'Pulp Fiction (1994)' b'Toy Story (1995)' b'Empire Strikes Back, The (1980)'
==Recommended movies for user 865== b'Star Wars (1977)' b'Fargo (1996)' b'Godfather, The (1972)' b'Silence of the Lambs, The (1991)' b'Raiders of the Lost Ark (1981)' b"Schindler's List (1993)" b'Return of the Jedi (1983)' b'Shawshank Redemption, The (1994)' b'Pulp Fiction (1994)' b'Empire Strikes Back, The (1980)'
==Recommended movies for user 710== b'Star Wars (1977)' b'Fargo (1996)' b'Godfather, The (1972)' b'Silence of the Lambs, The (1991)' b'Raiders of the Lost Ark (1981)' b"Schindler's List (1993)" b'Pulp Fiction (1994)' b'Return of the Jedi (1983)' b'Empire Strikes Back, The (1980)' b'Toy Story (1995)'
==Recommended movies for user 721== b'Star Wars (1977)' b'Fargo (1996)' b'Godfather, The (1972)' b'Raiders of the Lost Ark (1981)' b'Silence of the Lambs, The (1991)' b"Schindler's List (1993)" b'Return of the Jedi (1983)' b'Empire Strikes Back, The (1980)' b'Pulp Fiction (1994)' b'Casablanca (1942)'
==Recommended movies for user 451== b'Star Wars (1977)' b'Raiders of the Lost Ark (1981)' b'Godfather, The (1972)' b'Fargo (1996)' b'Silence of the Lambs, The (1991)' b'Return of the Jedi (1983)' b'Contact (1997)' b'Casablanca (1942)' b'Empire Strikes Back, The (1980)' b'Pulp Fiction (1994)'
==Recommended movies for user 228== b'Star Wars (1977)' b'Fargo (1996)' b'Godfather, The (1972)' b'Raiders of the Lost Ark (1981)' b'Silence of the Lambs, The (1991)' b"Schindler's List (1993)" b'Return of the Jedi (1983)' b'Pulp Fiction (1994)' b'Empire Strikes Back, The (1980)' b'Shawshank Redemption, The (1994)'
==Recommended movies for user 175== b'Star Wars (1977)' b'Fargo (1996)' b'Silence of the Lambs, The (1991)' b'Raiders of the Lost Ark (1981)' b'Return of the Jedi (1983)' b'Casablanca (1942)' b"Schindler's List (1993)" b'Empire Strikes Back, The (1980)' b'Godfather, The (1972)' b"One Flew Over the Cuckoo's Nest (1975)"