"""
Title: Retrieval with data parallel training
Author: [Abheesht Sharma](https://github.com/abheesht17/), [Fabien Hertschuh](https://github.com/hertschuh/)
Date created: 2025/04/28
Last modified: 2025/04/28
Description: Retrieve movies using a two-tower model (data parallel training).
Accelerator: TPU
"""

"""
## Introduction

In this tutorial, we are going to train the exact same retrieval model as we
did in our
[basic retrieval](/keras_rs/examples/basic_retrieval/)
tutorial, but in a distributed way.

Distributed training is used to train models on multiple devices or machines
simultaneously, thereby reducing training time. Here, we focus on synchronous
data parallel training: each accelerator (GPU/TPU) holds a complete replica
of the model and sees a different mini-batch of the input data. Local gradients
are computed on each device, then aggregated across devices to form a single,
global gradient update (a minimal JAX sketch of this step follows the notes
below).

Before we begin, let's note down a few things:

1. The number of accelerators should be greater than 1.
2. The `keras.distribution` API works only with JAX, so make sure you select
JAX as your backend!
"""

"""shell
pip install -q keras-rs
"""

import os

os.environ["KERAS_BACKEND"] = "jax"

import random

import jax
import keras
import tensorflow as tf  # Needed only for the dataset.
import tensorflow_datasets as tfds

import keras_rs

"""
## Data Parallel

For the synchronous data parallelism strategy in distributed training,
we will use the `DataParallel` class present in the `keras.distribution`
API.
"""
devices = jax.devices()  # Assume there is more than one local device.
data_parallel = keras.distribution.DataParallel(devices=devices)
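
"""
Before going further, it is worth checking that we actually have more than one
device; with a single device, the code below still runs, but effectively
degenerates to plain single-device training.
"""

print(f"Number of local devices: {len(devices)}")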

"""
Alternatively, you can choose to create the `DataParallel` object
using a 1D `DeviceMesh` object, like so:

```
mesh_1d = keras.distribution.DeviceMesh(
    shape=(len(devices),), axis_names=["data"], devices=devices
)
data_parallel = keras.distribution.DataParallel(device_mesh=mesh_1d)
```
"""

# Set the global distribution strategy.
keras.distribution.set_distribution(data_parallel)
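
"""
To confirm which strategy is active, we can query it back. (In recent Keras 3
releases the accessor is `keras.distribution.distribution()`; the exact API
may differ across versions.)
"""

print(keras.distribution.distribution())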

"""
## Preparing the dataset

Now that we have defined the global distribution strategy, the rest of the
guide looks exactly the same as the previous basic retrieval guide.

Let's load and prepare the dataset. Here too, we use the MovieLens 100K
dataset.
"""

# Ratings data with user and movie data.
ratings = tfds.load("movielens/100k-ratings", split="train")
# Features of all the available movies.
movies = tfds.load("movielens/100k-movies", split="train")

# User and movie counts for defining the vocabularies.
users_count = (
    ratings.map(lambda x: tf.strings.to_number(x["user_id"], out_type=tf.int32))
    .reduce(tf.constant(0, tf.int32), tf.maximum)
    .numpy()
)
movies_count = movies.cardinality().numpy()
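
"""
As an optional sanity check: MovieLens 100K has 943 users and 1,682 movies, so
we expect `users_count` (the largest user ID) to be 943 and `movies_count` to
be 1682.
"""

print(users_count, movies_count)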


# Preprocess the dataset, and split it into train-test datasets.
def preprocess_rating(x):
    return (
        # Input is the user IDs.
        tf.strings.to_number(x["user_id"], out_type=tf.int32),
        # Labels are movie IDs, plus ratings rescaled from the original
        # 1-5 scale to values between 0 and 1.
        {
            "movie_id": tf.strings.to_number(x["movie_id"], out_type=tf.int32),
            "rating": (x["user_rating"] - 1.0) / 4.0,
        },
    )


shuffled_ratings = ratings.map(preprocess_rating).shuffle(
    100_000, seed=42, reshuffle_each_iteration=False
)
train_ratings = shuffled_ratings.take(80_000).batch(1000).cache()
test_ratings = shuffled_ratings.skip(80_000).take(20_000).batch(1000).cache()
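
"""
To see what the model will consume, we can peek at one training batch (an
optional inspection step). Each element is a pair: a batch of user IDs, and a
dict holding the corresponding movie IDs and rescaled ratings.
"""

for user_ids_batch, labels_batch in train_ratings.take(1):
    print(user_ids_batch.shape)  # (1000,)
    print(labels_batch["movie_id"].shape)  # (1000,)
    print(labels_batch["rating"].shape)  # (1000,)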

"""
## Implementing the Model

We build a two-tower retrieval model, combining a query tower for users and a
candidate tower for movies. Note that we don't have to change anything here
compared to the previous basic retrieval tutorial.
"""


class RetrievalModel(keras.Model):
    """Create the retrieval model with the provided parameters.

    Args:
      num_users: Number of entries in the user embedding table.
      num_candidates: Number of entries in the candidate embedding table.
      embedding_dimension: Output dimension for user and movie embedding tables.
    """

    def __init__(
        self,
        num_users,
        num_candidates,
        embedding_dimension=32,
        **kwargs,
    ):
        super().__init__(**kwargs)
        # Our query tower, simply an embedding table.
        self.user_embedding = keras.layers.Embedding(num_users, embedding_dimension)
        # Our candidate tower, simply an embedding table.
        self.candidate_embedding = keras.layers.Embedding(
            num_candidates, embedding_dimension
        )
        # The layer that performs the retrieval.
        self.retrieval = keras_rs.layers.BruteForceRetrieval(k=10, return_scores=False)
        self.loss_fn = keras.losses.MeanSquaredError()

    def build(self, input_shape):
        self.user_embedding.build(input_shape)
        self.candidate_embedding.build(input_shape)
        # In this case, the candidates are directly the movie embeddings.
        # We take a shortcut and directly reuse the variable.
        self.retrieval.candidate_embeddings = self.candidate_embedding.embeddings
        self.retrieval.build(input_shape)
        super().build(input_shape)

    def call(self, inputs, training=False):
        user_embeddings = self.user_embedding(inputs)
        result = {
            "user_embeddings": user_embeddings,
        }
        if not training:
            # Skip the retrieval of top movies during training as the
            # predictions are not used.
            result["predictions"] = self.retrieval(user_embeddings)
        return result

    def compute_loss(self, x, y, y_pred, sample_weight, training=True):
        candidate_id, rating = y["movie_id"], y["rating"]
        user_embeddings = y_pred["user_embeddings"]
        candidate_embeddings = self.candidate_embedding(candidate_id)

        labels = keras.ops.expand_dims(rating, -1)
        # Compute the affinity score as the dot product of the two embeddings.
        scores = keras.ops.sum(
            keras.ops.multiply(user_embeddings, candidate_embeddings),
            axis=1,
            keepdims=True,
        )
        return self.loss_fn(labels, scores, sample_weight)
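

"""
To make the affinity score concrete, here is the same dot product computed on
a toy pair of 4-dimensional embeddings. This is an illustrative aside, not
part of the model: 1 * 0.5 + 0 * 1 + 2 * 1 + (-1) * 0 = 2.5.
"""

u = keras.ops.convert_to_tensor([[1.0, 0.0, 2.0, -1.0]])
c = keras.ops.convert_to_tensor([[0.5, 1.0, 1.0, 0.0]])
print(keras.ops.sum(keras.ops.multiply(u, c), axis=1, keepdims=True))  # [[2.5]]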

"""
## Fitting and evaluating

After defining the model, we can use the standard Keras `model.fit()` to train
and evaluate the model.
"""

model = RetrievalModel(users_count + 1, movies_count + 1)
model.compile(optimizer=keras.optimizers.Adagrad(learning_rate=0.2))

"""
Let's train the model. Evaluation takes a bit of time, so we only evaluate the
model every 5 epochs.
"""

history = model.fit(
    train_ratings, validation_data=test_ratings, validation_freq=5, epochs=50
)
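
"""
The `history` object records the training loss for every epoch, but the
validation loss only for the epochs on which evaluation actually ran (every
fifth epoch here), so the two lists should have different lengths.
"""

print(len(history.history["loss"]))  # Expected: 50.
print(len(history.history["val_loss"]))  # Expected: 10.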

"""
## Making predictions

Now that we have a model, let's run inference and make predictions.
"""

movie_id_to_movie_title = {
    int(x["movie_id"]): x["movie_title"] for x in movies.as_numpy_iterator()
}
movie_id_to_movie_title[0] = ""  # Because id 0 is not in the dataset.
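
"""
A quick spot check of the mapping. Note that `as_numpy_iterator()` yields
titles as raw bytes; in MovieLens 100K, movie ID 1 is "Toy Story (1995)".
"""

print(movie_id_to_movie_title[1])  # b'Toy Story (1995)'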

"""
We then simply use the Keras `model.predict()` method. Under the hood, it calls
the `BruteForceRetrieval` layer to perform the actual retrieval.
"""

# Sample user IDs within the range of IDs actually present in the dataset,
# i.e. 1 to `users_count`, so we never index past the embedding table.
user_ids = random.sample(range(1, users_count + 1), len(devices))
predictions = model.predict(keras.ops.convert_to_tensor(user_ids))
predictions = keras.ops.convert_to_numpy(predictions["predictions"])

for i, user_id in enumerate(user_ids):
    print(f"\n==Recommended movies for user {user_id}==")
    for movie_id in predictions[i]:
        print(movie_id_to_movie_title[movie_id])

"""
And we're done! For data parallel training, all we had to do was add roughly
three to five lines of code. The rest is exactly the same.
"""