Faster retrieval with Scalable Nearest Neighbors (ScaNN)
Author: Abheesht Sharma, Fabien Hertschuh
Date created: 2025/04/28
Last modified: 2025/04/28
Description: Using ScaNN for faster retrieval.
Introduction
Retrieval models are designed to quickly identify a small set of highly relevant candidates from vast pools of data, often comprising millions or even hundreds of millions of items. To effectively respond to the user's context and behavior in real time, these models must perform this task in just milliseconds.
Approximate nearest neighbor (ANN) search is the key technology that enables this level of efficiency. In this tutorial, we'll demonstrate how to leverage ScaNN, a state-of-the-art nearest neighbor retrieval library, to scale retrieval to millions of items with little effort.
ScaNN, developed by Google Research, is a high-performance library designed for dense vector similarity search at scale. It efficiently indexes a database of candidate embeddings, enabling rapid search during inference. By leveraging advanced vector compression techniques and finely tuned algorithms, ScaNN strikes an optimal balance between speed and accuracy. As a result, it can significantly outperform brute-force search methods, delivering fast retrieval with minimal loss in accuracy.
We will start with the same code as the basic retrieval example. Data processing, model building, and training remain exactly the same. Feel free to skip this part if you have gone over the basic retrieval example before.
Note: ScaNN does not have its own dedicated layer in KerasRS because the ScaNN library is TensorFlow-only. In this example, we use the ScaNN library directly and demonstrate its usage alongside KerasRS.
Imports
Let's install the scann library and import all necessary packages. We will also set the backend to JAX.
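The sketch below shows a minimal setup under these assumptions: the pip package names (scann, keras-rs) are the public ones, and the exact import list of the original notebook may differ slightly.

```python
# Setup sketch; installation normally lives in its own notebook cell:
#   pip install -q scann keras-rs tensorflow-datasets
import os

# The backend must be selected before Keras is imported.
os.environ["KERAS_BACKEND"] = "jax"

import time

import keras
import tensorflow as tf  # used for the tf.data input pipelines
import tensorflow_datasets as tfds  # provides the MovieLens datasets

import keras_rs  # Keras Recommenders, used by the basic retrieval example
```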
Preparing the dataset
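Since the data pipeline is the same as in the basic retrieval example, here is a hedged sketch of it. The TFDS dataset names and feature keys (user_id, movie_id, user_rating) follow the MovieLens catalog; the split sizes and batch size are assumptions.

```python
# Load MovieLens 100K: ratings for training, the movie catalog for retrieval.
ratings = tfds.load("movielens/100k-ratings", split="train")
movies = tfds.load("movielens/100k-movies", split="train")

# Known sizes of MovieLens 100K.
USERS_COUNT = 943
MOVIES_COUNT = 1682


def preprocess_rating(x):
    # Keep ((user_id, movie_id), rating); string IDs become integers so they
    # can index embedding tables.
    return (
        (
            tf.strings.to_number(x["user_id"], out_type=tf.int32),
            tf.strings.to_number(x["movie_id"], out_type=tf.int32),
        ),
        x["user_rating"],
    )


shuffled = ratings.map(preprocess_rating).shuffle(
    100_000, seed=42, reshuffle_each_iteration=False
)
train_ratings = shuffled.take(80_000).batch(1_000).cache()
test_ratings = shuffled.skip(80_000).take(20_000).batch(1_000).cache()
```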
Implementing the model
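The model below is a simplified two-tower sketch in the spirit of the basic retrieval example, not a verbatim copy: user and movie embeddings whose dot product approximates the rating. The class name, embedding dimension, and optimizer settings are assumptions.

```python
EMBEDDING_DIM = 32  # assumed embedding size


class RetrievalModel(keras.Model):
    def __init__(self, users_count, movies_count, embedding_dim=EMBEDDING_DIM):
        super().__init__()
        # "+ 1" because MovieLens IDs start at 1.
        self.user_embedding = keras.layers.Embedding(users_count + 1, embedding_dim)
        self.movie_embedding = keras.layers.Embedding(movies_count + 1, embedding_dim)

    def call(self, inputs):
        user_ids, movie_ids = inputs
        user_vectors = self.user_embedding(user_ids)
        movie_vectors = self.movie_embedding(movie_ids)
        # The user/movie affinity is the dot product of the two towers.
        return keras.ops.sum(user_vectors * movie_vectors, axis=-1)


model = RetrievalModel(USERS_COUNT, MOVIES_COUNT)
model.compile(
    loss=keras.losses.MeanSquaredError(),
    optimizer=keras.optimizers.Adagrad(learning_rate=0.1),
)
```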
Training the model
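Training is then a regular model.fit call on the sketch above; the number of epochs here is an arbitrary choice.

```python
# Train on the 80k split and monitor the 20k hold-out split.
model.fit(train_ratings, validation_data=test_ratings, epochs=5)
```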
Making predictions
Before we try out ScaNN, let's go with the brute-force method: for a given user, we compute scores for all movies, sort them, and pick the top-k. This is, of course, not very scalable when we have a huge number of movies.
Now, let's do a forward pass on the layer. Note that in the previous tutorials, we kept the above layer as an attribute of the model class and called .predict(). That is obviously faster (since it runs compiled XLA code), but since we cannot do the same for ScaNN, we do a plain forward pass here, without compilation, to ensure a fair comparison.
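As a concrete illustration, here is a hand-rolled brute-force retrieval under the assumptions of the earlier sketches: it scores every movie with a dot product against the full embedding table and keeps the top-k. It is not the KerasRS retrieval layer itself, just an equivalent uncompiled forward pass we can time; the user ID and k are arbitrary.

```python
# All candidate (movie) embeddings: shape (MOVIES_COUNT + 1, EMBEDDING_DIM).
candidate_embeddings = model.movie_embedding.embeddings


def brute_force_retrieve(user_id, k=10):
    # Embed the query user, score every movie, then keep the k best scores.
    user_vector = model.user_embedding(keras.ops.convert_to_tensor([user_id]))
    scores = keras.ops.matmul(user_vector, keras.ops.transpose(candidate_embeddings))
    _, top_ids = keras.ops.top_k(scores, k=k)
    return top_ids


start = time.time()
recommendations = brute_force_retrieve(user_id=42, k=10)
print(f"Brute force: {time.time() - start:.4f}s, movie IDs: {recommendations}")
```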
Now, let's retrieve movies using ScaNN. We will use the ScaNN library from Google Research to build the retrieval layer and then call it. To fully understand all the arguments, please refer to the ScaNN README.
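Below is a hedged sketch based on the usage shown in the ScaNN README. It uses the NumPy interface (scann_ops_pybind) rather than the TensorFlow ops interface the notebook likely relies on, reuses the candidate embeddings gathered above, and every tree/scoring parameter value is an assumption to be tuned following the README.

```python
from scann import scann_ops_pybind

# Index the candidate (movie) embeddings once; searches afterwards are fast.
candidates = keras.ops.convert_to_numpy(candidate_embeddings)

searcher = (
    scann_ops_pybind.builder(candidates, 10, "dot_product")
    .tree(
        num_leaves=100,
        num_leaves_to_search=10,
        training_sample_size=candidates.shape[0],
    )
    .score_ah(2, anisotropic_quantization_threshold=0.2)
    .reorder(100)
    .build()
)

# Query with the same user embedding as in the brute-force comparison.
query = keras.ops.convert_to_numpy(
    model.user_embedding(keras.ops.convert_to_tensor([42]))
)

start = time.time()
neighbors, distances = searcher.search_batched(query)
print(f"ScaNN: {time.time() - start:.4f}s, movie IDs: {neighbors}")
```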
The latency improvement is clear: ScaNN (0.003 seconds) takes roughly one-fiftieth of the time the brute-force layer (0.15 seconds) needs to run!