Faster retrieval with Scalable Nearest Neighbours (ScaNN)
Author: Abheesht Sharma, Fabien Hertschuh
Date created: 2025/04/28
Last modified: 2025/04/28
Description: Using ScaNN for faster retrieval.
Introduction
Retrieval models are designed to quickly identify a small set of highly relevant candidates from vast pools of data, often comprising millions or even hundreds of millions of items. To effectively respond to the user's context and behavior in real time, these models must perform this task in just milliseconds.
Approximate nearest neighbor (ANN) search is the key technology that enables this level of efficiency. In this tutorial, we'll demonstrate how to leverage ScaNN, a cutting-edge nearest neighbor retrieval library, to effortlessly scale retrieval for millions of items.
ScaNN, developed by Google Research, is a high-performance library designed for dense vector similarity search at scale. It efficiently indexes a database of candidate embeddings, enabling rapid search during inference. By leveraging advanced vector compression techniques and finely tuned algorithms, ScaNN strikes an optimal balance between speed and accuracy. As a result, it can significantly outperform brute-force search methods, delivering fast retrieval with minimal loss in accuracy.
We will start with the same code as the basic retrieval example. Data processing, model building, and training remain exactly the same. Feel free to skip this part if you have gone over the basic retrieval example before.
Note: ScaNN does not have its own separate layer in KerasRS because the ScaNN library is TensorFlow-only. In this example, we use the ScaNN library directly and demonstrate its usage with KerasRS.
Imports
Let's install the `scann` library and import all necessary packages. We will also set the backend to JAX.
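A minimal setup, assuming the JAX backend and the TensorFlow build of ScaNN shipped in the `scann` PyPI package, might look like this:

```python
# ScaNN and KerasRS are both available on PyPI:
# pip install -q keras-rs scann

import os

# The backend must be set before Keras is imported.
os.environ["KERAS_BACKEND"] = "jax"

import time

import keras
import numpy as np
import tensorflow as tf  # Needed for tf.data and the ScaNN ops.
import tensorflow_datasets as tfds
from scann import scann_ops

import keras_rs
```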
Preparing the dataset
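As in the basic retrieval example, we use MovieLens 100K: we keep only the user and movie IDs, rescale the ratings to [0, 1] so they can serve as training weights, and split the data into 80,000 training and 20,000 test examples. A condensed sketch, with the split sizes and seed carried over from that example:

```python
# MovieLens 100K: 100,000 ratings from 943 users on 1,682 movies.
ratings = tfds.load("movielens/100k-ratings", split="train")
movies = tfds.load("movielens/100k-movies", split="train")

users_count = 943
movies_count = 1682


def preprocess_rating(x):
    # Keep only the IDs and rescale the rating from [1, 5] to [0, 1]
    # so it can be used as a weight for the retrieval loss.
    return (
        {
            "user_id": tf.strings.to_number(x["user_id"], out_type=tf.int32),
            "movie_id": tf.strings.to_number(x["movie_id"], out_type=tf.int32),
        },
        (x["user_rating"] - 1.0) / 4.0,
    )


shuffled_ratings = ratings.map(preprocess_rating).shuffle(
    100_000, seed=42, reshuffle_each_iteration=False
)
# 80 batches of 1,000 for training (matching the "80/80" in the log below).
train_ratings = shuffled_ratings.take(80_000).batch(1000).cache()
test_ratings = shuffled_ratings.skip(80_000).take(20_000).batch(1000).cache()
```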
Implementing the Model
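The model is the same two-tower architecture as in the basic retrieval example: one embedding table per tower, with the dot product of a user embedding and a movie embedding trained to match the rescaled rating. A condensed sketch:

```python
class RetrievalModel(keras.Model):
    """Two-tower retrieval model: one embedding table per tower."""

    def __init__(self, num_users, num_candidates, embedding_dim=32, **kwargs):
        super().__init__(**kwargs)
        self.user_embedding = keras.layers.Embedding(num_users, embedding_dim)
        self.candidate_embedding = keras.layers.Embedding(
            num_candidates, embedding_dim
        )
        self.loss_fn = keras.losses.MeanSquaredError()

    def call(self, inputs):
        return {
            "user_embeddings": self.user_embedding(inputs["user_id"]),
            "candidate_embeddings": self.candidate_embedding(inputs["movie_id"]),
        }

    def compute_loss(self, x, y, y_pred, sample_weight, training=True):
        # The affinity score is the dot product of the two embeddings,
        # regressed against the rescaled rating.
        scores = keras.ops.sum(
            y_pred["user_embeddings"] * y_pred["candidate_embeddings"],
            axis=-1,
        )
        return self.loss_fn(y, scores, sample_weight)
```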
Training the model
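We compile the model and train for 50 epochs, running validation every 5 epochs, which matches the log below. The Adagrad optimizer and learning rate are assumptions carried over from the basic retrieval example:

```python
model = RetrievalModel(users_count + 1, movies_count + 1)
model.compile(optimizer=keras.optimizers.Adagrad(learning_rate=0.1))

history = model.fit(
    train_ratings,
    validation_data=test_ratings,
    validation_freq=5,
    epochs=50,
)
```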
Epoch 1/50
80/80 ━━━━━━━━━━━━━━━━━━━━ 3s 9ms/step - loss: 0.4772
Epoch 2/50
80/80 ━━━━━━━━━━━━━━━━━━━━ 1s 2ms/step - loss: 0.4772
Epoch 3/50
80/80 ━━━━━━━━━━━━━━━━━━━━ 0s 2ms/step - loss: 0.4771
Epoch 4/50
80/80 ━━━━━━━━━━━━━━━━━━━━ 0s 2ms/step - loss: 0.4771
Epoch 5/50
80/80 ━━━━━━━━━━━━━━━━━━━━ 2s 25ms/step - loss: 0.4771 - val_loss: 0.4835
Epoch 6/50
80/80 ━━━━━━━━━━━━━━━━━━━━ 0s 2ms/step - loss: 0.4770
Epoch 7/50
80/80 ━━━━━━━━━━━━━━━━━━━━ 0s 2ms/step - loss: 0.4770
Epoch 8/50
80/80 ━━━━━━━━━━━━━━━━━━━━ 0s 2ms/step - loss: 0.4769
Epoch 9/50
80/80 ━━━━━━━━━━━━━━━━━━━━ 0s 2ms/step - loss: 0.4769
Epoch 10/50
80/80 ━━━━━━━━━━━━━━━━━━━━ 0s 2ms/step - loss: 0.4769 - val_loss: 0.4835
Epoch 11/50
80/80 ━━━━━━━━━━━━━━━━━━━━ 0s 2ms/step - loss: 0.4768
Epoch 12/50
80/80 ━━━━━━━━━━━━━━━━━━━━ 0s 2ms/step - loss: 0.4768
Epoch 13/50
80/80 ━━━━━━━━━━━━━━━━━━━━ 0s 2ms/step - loss: 0.4767
Epoch 14/50
80/80 ━━━━━━━━━━━━━━━━━━━━ 0s 2ms/step - loss: 0.4767
Epoch 15/50
80/80 ━━━━━━━━━━━━━━━━━━━━ 0s 2ms/step - loss: 0.4766 - val_loss: 0.4834
Epoch 16/50
80/80 ━━━━━━━━━━━━━━━━━━━━ 0s 2ms/step - loss: 0.4766
Epoch 17/50
80/80 ━━━━━━━━━━━━━━━━━━━━ 0s 2ms/step - loss: 0.4765
Epoch 18/50
80/80 ━━━━━━━━━━━━━━━━━━━━ 0s 2ms/step - loss: 0.4765
Epoch 19/50
80/80 ━━━━━━━━━━━━━━━━━━━━ 0s 2ms/step - loss: 0.4764
Epoch 20/50
80/80 ━━━━━━━━━━━━━━━━━━━━ 0s 2ms/step - loss: 0.4763 - val_loss: 0.4833
Epoch 21/50
80/80 ━━━━━━━━━━━━━━━━━━━━ 0s 2ms/step - loss: 0.4762
Epoch 22/50
80/80 ━━━━━━━━━━━━━━━━━━━━ 0s 2ms/step - loss: 0.4761
Epoch 23/50
80/80 ━━━━━━━━━━━━━━━━━━━━ 0s 2ms/step - loss: 0.4760
Epoch 24/50
80/80 ━━━━━━━━━━━━━━━━━━━━ 0s 2ms/step - loss: 0.4759
Epoch 25/50
80/80 ━━━━━━━━━━━━━━━━━━━━ 0s 2ms/step - loss: 0.4758 - val_loss: 0.4829
Epoch 26/50
80/80 ━━━━━━━━━━━━━━━━━━━━ 0s 2ms/step - loss: 0.4757
Epoch 27/50
80/80 ━━━━━━━━━━━━━━━━━━━━ 0s 2ms/step - loss: 0.4756
Epoch 28/50
80/80 ━━━━━━━━━━━━━━━━━━━━ 0s 2ms/step - loss: 0.4754
Epoch 29/50
80/80 ━━━━━━━━━━━━━━━━━━━━ 0s 2ms/step - loss: 0.4752
Epoch 30/50
80/80 ━━━━━━━━━━━━━━━━━━━━ 0s 2ms/step - loss: 0.4750 - val_loss: 0.4823
Epoch 31/50
80/80 ━━━━━━━━━━━━━━━━━━━━ 0s 2ms/step - loss: 0.4748
Epoch 32/50
80/80 ━━━━━━━━━━━━━━━━━━━━ 0s 2ms/step - loss: 0.4746
Epoch 33/50
80/80 ━━━━━━━━━━━━━━━━━━━━ 0s 2ms/step - loss: 0.4744
Epoch 34/50
80/80 ━━━━━━━━━━━━━━━━━━━━ 0s 2ms/step - loss: 0.4741
Epoch 35/50
80/80 ━━━━━━━━━━━━━━━━━━━━ 0s 2ms/step - loss: 0.4738 - val_loss: 0.4810
Epoch 36/50
80/80 ━━━━━━━━━━━━━━━━━━━━ 0s 1ms/step - loss: 0.4734
Epoch 37/50
80/80 ━━━━━━━━━━━━━━━━━━━━ 0s 2ms/step - loss: 0.4730
Epoch 38/50
80/80 ━━━━━━━━━━━━━━━━━━━━ 0s 1ms/step - loss: 0.4726
Epoch 39/50
80/80 ━━━━━━━━━━━━━━━━━━━━ 0s 2ms/step - loss: 0.4721
Epoch 40/50
80/80 ━━━━━━━━━━━━━━━━━━━━ 0s 2ms/step - loss: 0.4715 - val_loss: 0.4788
Epoch 41/50
80/80 ━━━━━━━━━━━━━━━━━━━━ 0s 2ms/step - loss: 0.4709
Epoch 42/50
80/80 ━━━━━━━━━━━━━━━━━━━━ 0s 2ms/step - loss: 0.4702
Epoch 43/50
80/80 ━━━━━━━━━━━━━━━━━━━━ 0s 2ms/step - loss: 0.4695
Epoch 44/50
80/80 ━━━━━━━━━━━━━━━━━━━━ 0s 2ms/step - loss: 0.4686
Epoch 45/50
80/80 ━━━━━━━━━━━━━━━━━━━━ 0s 2ms/step - loss: 0.4677 - val_loss: 0.4749
Epoch 46/50
80/80 ━━━━━━━━━━━━━━━━━━━━ 0s 2ms/step - loss: 0.4666
Epoch 47/50
80/80 ━━━━━━━━━━━━━━━━━━━━ 0s 2ms/step - loss: 0.4654
Epoch 48/50
80/80 ━━━━━━━━━━━━━━━━━━━━ 0s 2ms/step - loss: 0.4641
Epoch 49/50
80/80 ━━━━━━━━━━━━━━━━━━━━ 0s 2ms/step - loss: 0.4627
Epoch 50/50
80/80 ━━━━━━━━━━━━━━━━━━━━ 0s 2ms/step - loss: 0.4610 - val_loss: 0.4679
Now, let's do a forward pass on the layer. Note that in previous tutorials, we had the retrieval layer as an attribute of the model class and called `.predict()` on the model. That would be faster, since it runs compiled XLA code, but we cannot do the same for ScaNN, so here we do a plain forward pass without compilation to ensure a fair comparison.
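As a sketch, we pull the trained movie embeddings out of the model, wrap them in a standalone `keras_rs.layers.BruteForceRetrieval` layer, and time a single uncompiled call (user ID 42 is an arbitrary choice for illustration):

```python
# All trained movie embeddings, one row per movie ID.
candidate_embeddings = keras.ops.convert_to_tensor(
    model.candidate_embedding.embeddings
)

brute_force_layer = keras_rs.layers.BruteForceRetrieval(
    candidate_embeddings=candidate_embeddings, k=10, return_scores=False
)

# Query embedding for one arbitrary user.
user_embedding = model.user_embedding(keras.ops.convert_to_tensor([42]))

t0 = time.time()
predicted_movie_ids = brute_force_layer(user_embedding)
print(f"Brute-force retrieval took {time.time() - t0:.4f} s")
```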
Now, let's retrieve movies using ScaNN. We will use the ScaNN library from Google Research to build the searcher and then call it. For a full explanation of all the arguments, please refer to the ScaNN README file.
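Below is a sketch of the builder-pattern API exposed by `scann_ops`; the configuration values here are illustrative rather than tuned, and the exact return structure of `search_batched` may vary across ScaNN versions:

```python
def build_scann(
    candidates,
    k=10,
    distance_measure="dot_product",
    dimensions_per_block=2,
    num_reordering_candidates=500,
    num_leaves=100,
    num_leaves_to_search=30,
    training_iterations=12,
):
    builder = scann_ops.builder(
        db=candidates,
        num_neighbors=k,
        distance_measure=distance_measure,
    )
    # Partition the database into `num_leaves` clusters and search only
    # the `num_leaves_to_search` closest clusters per query.
    builder = builder.tree(
        num_leaves=num_leaves,
        num_leaves_to_search=num_leaves_to_search,
        training_iterations=training_iterations,
    )
    # Asymmetric hashing: compress candidate vectors for fast scoring.
    builder = builder.score_ah(dimensions_per_block=dimensions_per_block)
    # Exactly re-score the best candidates before returning the top k.
    builder = builder.reorder(num_reordering_candidates)
    return builder.build()


searcher = build_scann(candidates=np.asarray(candidate_embeddings))

t0 = time.time()
indices, distances = searcher.search_batched(np.asarray(user_embedding))
print(f"ScaNN retrieval took {time.time() - t0:.4f} s")
```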
The latency improvement is clear: ScaNN (0.003 seconds) takes roughly one-fiftieth of the time the brute-force layer (0.15 seconds) needs to run!