"""
2
Title: Faster retrieval with Scalable Nearest Neighbours (ScANN)
3
Author: [Abheesht Sharma](https://github.com/abheesht17/), [Fabien Hertschuh](https://github.com/hertschuh/)
4
Date created: 2025/04/28
5
Last modified: 2025/04/28
6
Description: Using ScANN for faster retrieval.
7
Accelerator: GPU
8
"""

"""
## Introduction

Retrieval models are designed to quickly identify a small set of highly relevant
candidates from vast pools of data, often comprising millions or even hundreds of
millions of items. To respond to a user's context and behavior in real time,
these models must perform this task in just milliseconds.

Approximate nearest neighbor (ANN) search is the key technology that enables this
level of efficiency. In this tutorial, we demonstrate how to use ScaNN, a
state-of-the-art nearest neighbor retrieval library, to scale retrieval to
millions of items.

[ScaNN](https://research.google/blog/announcing-scann-efficient-vector-similarity-search/),
developed by Google Research, is a high-performance library for dense vector
similarity search at scale. It efficiently indexes a database of candidate
embeddings, enabling rapid search at inference time. By combining vector
compression with finely tuned search algorithms, ScaNN strikes a balance between
speed and accuracy, and can significantly outperform brute-force search with
minimal loss in accuracy.
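
To make this concrete: with dot-product similarity, exact retrieval scores every
candidate against the query and keeps the k best. The sketch below (plain NumPy,
with made-up shapes) shows the computation that ANN libraries like ScaNN
approximate:

```python
import numpy as np

rng = np.random.default_rng(seed=42)
db = rng.normal(size=(1_000_000, 32))  # a hypothetical database of 1M embeddings
query = rng.normal(size=(32,))  # a single query embedding

scores = db @ query  # one dot product per candidate: O(N * d)
top_10 = np.argsort(-scores)[:10]  # indices of the 10 best candidates
```

Scanning all candidates this way scales linearly with the database size, which
is exactly the cost that ScaNN's indexing avoids.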

We will start with the same code as the
[basic retrieval example](/keras_rs/examples/basic_retrieval/).
Data processing, model building, and training remain exactly the same. Feel free
to skip this part if you have gone over the basic retrieval example before.

Note: ScaNN does not have its own layer in KerasRS because the ScaNN library is
TensorFlow-only. In this example, we use the ScaNN library directly and
demonstrate its usage with KerasRS.

## Imports

Let's install the `scann` library and import all the necessary packages. We will
also set the backend to JAX.
"""

"""shell
pip install -q keras-rs
pip install -q scann
"""

import os

os.environ["KERAS_BACKEND"] = "jax"  # `"tensorflow"`/`"torch"`

import time
import uuid

import keras
import tensorflow as tf  # Needed for the dataset
import tensorflow_datasets as tfds
from scann import scann_ops

import keras_rs

"""
## Preparing the dataset
"""

# Ratings data with user and movie data.
ratings = tfds.load("movielens/100k-ratings", split="train")
# Features of all the available movies.
movies = tfds.load("movielens/100k-movies", split="train")

# Get user and movie counts so that we can define embedding layers for both.
# MovieLens 100K user IDs are contiguous integers starting at 1, so the maximum
# ID equals the user count.
users_count = (
    ratings.map(lambda x: tf.strings.to_number(x["user_id"], out_type=tf.int32))
    .reduce(tf.constant(0, tf.int32), tf.maximum)
    .numpy()
)

movies_count = movies.cardinality().numpy()


# Preprocess the dataset by selecting only the relevant features.
def preprocess_rating(x):
    return (
        # Input is the user IDs.
        tf.strings.to_number(x["user_id"], out_type=tf.int32),
        # Labels are movie IDs and ratings rescaled from [1, 5] to [0, 1].
        {
            "movie_id": tf.strings.to_number(x["movie_id"], out_type=tf.int32),
            "rating": (x["user_rating"] - 1.0) / 4.0,
        },
    )


shuffled_ratings = ratings.map(preprocess_rating).shuffle(
    100_000, seed=42, reshuffle_each_iteration=False
)
# Train-test split.
train_ratings = shuffled_ratings.take(80_000).batch(1000).cache()
test_ratings = shuffled_ratings.skip(80_000).take(20_000).batch(1000).cache()

"""
## Implementing the Model
"""


class RetrievalModel(keras.Model):
    def __init__(
        self,
        num_users,
        num_candidates,
        embedding_dimension=32,
        **kwargs,
    ):
        super().__init__(**kwargs)
        # Our query tower, simply an embedding table.
        self.user_embedding = keras.layers.Embedding(num_users, embedding_dimension)
        # Our candidate tower, simply an embedding table.
        self.candidate_embedding = keras.layers.Embedding(
            num_candidates, embedding_dimension
        )

        self.loss_fn = keras.losses.MeanSquaredError()

    def build(self, input_shape):
        self.user_embedding.build(input_shape)
        self.candidate_embedding.build(input_shape)

        super().build(input_shape)

    def call(self, inputs, training=False):
        user_embeddings = self.user_embedding(inputs)
        result = {
            "user_embeddings": user_embeddings,
        }
        return result

    def compute_loss(self, x, y, y_pred, sample_weight, training=True):
        candidate_id, rating = y["movie_id"], y["rating"]
        user_embeddings = y_pred["user_embeddings"]
        candidate_embeddings = self.candidate_embedding(candidate_id)

        labels = keras.ops.expand_dims(rating, -1)
        # The affinity score is the dot product of the user and candidate
        # embeddings; we regress it towards the rescaled rating.
        scores = keras.ops.sum(
            keras.ops.multiply(user_embeddings, candidate_embeddings),
            axis=1,
            keepdims=True,
        )
        return self.loss_fn(labels, scores, sample_weight)


"""
## Training the model
"""

model = RetrievalModel(users_count + 1000, movies_count + 1000)
model.compile(optimizer=keras.optimizers.Adagrad(learning_rate=0.1))

history = model.fit(
    train_ratings, validation_data=test_ratings, validation_freq=5, epochs=50
)

"""
## Making predictions

Before we try out ScaNN, let's start with the brute-force method: for a given
user, scores are computed for all movies, the scores are sorted, and the top-k
movies are picked. This, of course, does not scale well to a very large number
of movies.
"""

candidate_embeddings = keras.ops.array(model.candidate_embedding.embeddings.numpy())
# Artificially duplicate the candidate embeddings (with random scaling) to
# simulate a large corpus: 101 copies of the table, i.e. several hundred
# thousand candidates.
candidate_embeddings = keras.ops.concatenate(
    [candidate_embeddings]
    + [
        candidate_embeddings
        * keras.random.uniform(keras.ops.shape(candidate_embeddings))
        for _ in range(100)
    ],
    axis=0,
)

user_embedding = model.user_embedding(keras.ops.array([10, 5, 42, 345]))

# Define the brute force retrieval layer.
brute_force_layer = keras_rs.layers.BruteForceRetrieval(
    candidate_embeddings=candidate_embeddings,
    k=10,
    return_scores=False,
)

"""
Now, let's do a forward pass on the layer. Note that in previous tutorials, we
made the retrieval layer an attribute of the model class and called `.predict()`
on the model. That is obviously faster, since the code gets compiled with XLA,
but since we cannot do the same for ScaNN, we run a plain, uncompiled forward
pass here to keep the comparison fair.
"""

t0 = time.time()
pred_movie_ids = brute_force_layer(user_embedding)
print("Time taken by brute force layer (sec):", time.time() - t0)

"""
Now, let's retrieve movies using ScaNN. We will use the ScaNN library from
Google Research to build the searcher and then call it. To fully understand all
the arguments, please refer to the
[ScaNN README file](https://github.com/google-research/google-research/tree/master/scann#readme).
"""


def build_scann(
    candidates,
    k=10,
    distance_measure="dot_product",
    dimensions_per_block=2,
    num_reordering_candidates=500,
    num_leaves=100,
    num_leaves_to_search=30,
    training_iterations=12,
):
    builder = scann_ops.builder(
        db=candidates,
        num_neighbors=k,
        distance_measure=distance_measure,
    )

    # Partition the database into `num_leaves` clusters and only search the
    # `num_leaves_to_search` most promising clusters per query.
    builder = builder.tree(
        num_leaves=num_leaves,
        num_leaves_to_search=num_leaves_to_search,
        training_iterations=training_iterations,
    )
    # Score candidates with asymmetric hashing, i.e. against quantized
    # embeddings.
    builder = builder.score_ah(dimensions_per_block=dimensions_per_block)

    # Re-score the best candidates with exact distances for higher accuracy.
    if num_reordering_candidates is not None:
        builder = builder.reorder(num_reordering_candidates)

    # Set a unique name to prevent unintentional sharing between
    # ScaNN instances.
    searcher = builder.build(shared_name=str(uuid.uuid4()))
    return searcher


def run_scann(searcher):
    pred_movie_ids = searcher.search_batched_parallel(
        user_embedding,
        final_num_neighbors=10,
    ).indices
    return pred_movie_ids


searcher = build_scann(candidates=candidate_embeddings)

t0 = time.time()
pred_movie_ids = run_scann(searcher)
print("Time taken by ScaNN (sec):", time.time() - t0)

"""
The latency improvement is clear: in this run, ScaNN (about 0.003 seconds) took
roughly one-fiftieth of the time taken by the brute-force layer (about 0.15
seconds)!
"""