From K-NN, HNSW to Product Quantization: Approximate Nearest Neighbor Search for vector search engines


Why do you need ANN? | Background

Elasticsearch, Faiss, and others have all adopted their own approximate nearest neighbor (ANN) search algorithms to handle vector search over huge corpora. During recent research for my final-year project (FYP), which is to develop a production-level distributed search engine, I have looked into vector search, inverted indexes, and Learning-to-Rank for its main and most-used feature: searching.

My FYP got me interested in exploring how to accelerate vector search. Based on my previous knowledge, in games / navigation a common solution to speed up the search process is the K-D Tree, which recursively partitions space along orthogonal axes to build a tree structure, so that a search takes only O(log n) time, referring to Fig. 1. However, even though partitioning sounds intuitive, it does not work well when the dimensionality grows explosively to 256 or 1024: with that many dimensions the tree can prune very little, and partitioning across all k dimensions takes a long time. Therefore, I am going to break down, step by step, a few techniques that are used to speed up the ANN process.
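As a quick sanity check, here is a small sketch using scipy's cKDTree (assuming scipy is installed; this is not part of my FYP code). Queries in 2 dimensions finish almost instantly, but as the dimensionality grows the tree prunes fewer branches and drifts toward brute-force behavior:

import time
import numpy as np
from scipy.spatial import cKDTree

rng = np.random.default_rng(0)
for d in (2, 16, 256):
    data = rng.random((50_000, d))   # 50k random points in d dimensions
    tree = cKDTree(data)             # recursive axis-aligned partitioning
    start = time.time()
    tree.query(rng.random((1_000, d)), k=10)  # 10-NN for 1,000 queries
    print(f"d={d}: {time.time() - start:.3f}s")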

This article mostly covers the vector search part, so we do not discuss inverted indexes or Learning-to-Rank.

K-D Tree
Figure 1: K-D Tree

Starting from K-NN

It is not a secret that K-Nearest Neighbors (K-NN) is the simplest and most naive nearest neighbor search algorithm.

The principle of the K-NN algorithm is very simple: find the nearest neighbors of the query vector in the dataset. Usually these neighbors belong to a cluster, which lets us assign the query vector to a specific class. This brute-force algorithm's time complexity is O(ND + N log K), where N denotes the number of vectors in the dataset, D denotes their dimensionality, and the N log K term comes from maintaining the order of the candidate neighbors. The K-NN idea also extends to finding the nearest cluster to categorize each query, as shown in Fig. 2.
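To make the O(ND + N log K) cost concrete, here is a minimal sketch of brute-force K-NN with a bounded heap (a hypothetical helper, separate from the evaluation script below):

import heapq
import numpy as np

def knn_bruteforce(query, dataset, k=10):
    """Scan all N vectors (the O(ND) part) keeping a size-k max-heap
    of the best candidates (the O(N log K) part)."""
    heap = []  # heapq is a min-heap, so store negated distances
    for i, vec in enumerate(dataset):
        d = np.linalg.norm(query - vec)
        if len(heap) < k:
            heapq.heappush(heap, (-d, i))
        elif d < -heap[0][0]:
            heapq.heapreplace(heap, (-d, i))   # evict the current furthest
    return sorted((-nd, i) for nd, i in heap)  # (distance, index), ascending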

K-NN
Figure 2: K-means clustering. Find the most similar cluster for a query data point.

However, when the vector dimension explodes or the number of vectors vastly increases, the K-NN algorithm reveals its weakness: it inevitably has to calculate the distance to each and every vector in the dataset.

Anyway, I used Claude to generate a template for me to evaluate on the MS MARCO dataset, which you will find appearing everywhere when it comes to search, RAG, or anything similar in information retrieval. I then simulate the K-NN algorithm in numpy by sorting the similarity score of every corpus vector for each query, and calculate Recall, MRR, and the time cost taken by the algorithm.

import numpy as np
import time
from sentence_transformers import SentenceTransformer
from datasets import load_dataset
from tqdm import tqdm

print("Loading model...")
model = SentenceTransformer('sentence-transformers/msmarco-MiniLM-L-6-v3')

# Load MS MARCO dataset
print("\nLoading dataset...")
dataset = load_dataset('ms_marco', 'v1.1', split='train[:10000]')

# Extract queries and build qrels
queries = []
qrels = {}
corpus = {}

print("\nProcessing data...")
for idx, item in enumerate(tqdm(dataset, desc="Processing items")):
    # Check if there are answers
    if item['answers'] and len(item['answers']) > 0:
        query_text = item['query']
        
        # Extract relevant passages
        relevant_passages = []
        passages = item['passages']
        
        # Add all passages to corpus (v1.1 doesn't have passage_id, need to create our own)
        for i in range(len(passages['passage_text'])):
            ptext = passages['passage_text'][i]
            is_selected = passages['is_selected'][i]
            
            # Create unique passage ID using query_id and passage index
            pid = f"{item['query_id']}_{i}"
            
            # Add to corpus
            if pid not in corpus:
                corpus[pid] = ptext
            
            # If it's a relevant passage, record it
            if is_selected == 1:
                relevant_passages.append(pid)
        
        # Only keep queries with relevant passages
        if relevant_passages:
            queries.append(query_text)
            qrels[len(queries) - 1] = set(relevant_passages)

print(f"\n{'='*50}")
print(f"Data Statistics:")
print(f"{'='*50}")
print(f"Queries: {len(queries)}")
print(f"Corpus: {len(corpus)}")
print(f"Average relevant docs per query: {np.mean([len(v) for v in qrels.values()]):.2f}")
print(f"{'='*50}")

# Prepare for encoding
corpus_ids = list(corpus.keys())
corpus_texts = list(corpus.values())

print("\nEncoding corpus...")
corpus_embeddings = model.encode(
    corpus_texts, 
    batch_size=128, 
    show_progress_bar=True,
    convert_to_numpy=True
)

print("\nEncoding queries...")
query_embeddings = model.encode(
    queries, 
    batch_size=128, 
    show_progress_bar=True,
    convert_to_numpy=True
)

# Normalize for cosine similarity
print("\nNormalizing embeddings...")
corpus_embeddings = corpus_embeddings / np.linalg.norm(corpus_embeddings, axis=1, keepdims=True)
query_embeddings = query_embeddings / np.linalg.norm(query_embeddings, axis=1, keepdims=True)

# KNN search
def knn_search(query_vecs, corpus_vecs, k=10):
    """Efficient KNN search using numpy"""
    scores = np.dot(query_vecs, corpus_vecs.T)
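    # argsort fully sorts all N scores per query (O(N log N));
    # np.argpartition would be cheaper, but we keep the demo simple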
    top_k_idx = np.argsort(scores, axis=1)[:, -k:][:, ::-1]
    top_k_scores = np.take_along_axis(scores, top_k_idx, axis=1)
    return top_k_scores, top_k_idx

print("\nPerforming retrieval...")
retrieval_start = time.time()
scores, indices = knn_search(query_embeddings, corpus_embeddings, k=10)
retrieval_time = time.time() - retrieval_start
print(f"Search completed in {retrieval_time:.4f} seconds")

# Evaluation
print("\nCalculating evaluation metrics...")
mrr = 0.0
recall_at_10 = 0.0
precision_at_10 = 0.0
valid_queries = 0

for q_idx in tqdm(range(len(queries)), desc="Evaluating"):
    if q_idx not in qrels:
        continue
    
    relevant = qrels[q_idx]
    retrieved = [corpus_ids[idx] for idx in indices[q_idx]]
    
    # Calculate MRR
    for rank, doc_id in enumerate(retrieved, 1):
        if doc_id in relevant:
            mrr += 1.0 / rank
            break
    
    # Calculate Recall@10
    hits = len(set(retrieved) & relevant)
    recall_at_10 += hits / len(relevant)
    
    # Calculate Precision@10
    precision_at_10 += hits / len(retrieved)
    
    valid_queries += 1

# Average metrics
mrr /= valid_queries
recall_at_10 /= valid_queries
precision_at_10 /= valid_queries

print(f"\n{'='*50}")
print(f"Evaluation Results:")
print(f"{'='*50}")
print(f"Number of evaluated queries: {valid_queries}")
print(f"MRR@10:        {mrr:.4f}")
print(f"Recall@10:     {recall_at_10:.4f}")
print(f"Precision@10:  {precision_at_10:.4f}")
print(f"{'='*50}")

Results

Loading dataset...

Processing data...
Processing items: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 10000/10000 [00:00<00:00, 13605.77it/s]

==================================================
Data Statistics:
==================================================
Queries: 9690
Corpus: 80128
Average relevant docs per query: 1.12
==================================================

Encoding corpus...
Batches: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 626/626 [00:39<00:00, 15.77it/s]

Encoding queries...
Batches: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 76/76 [00:00<00:00, 103.18it/s]

Normalizing embeddings...

Performing retrieval...
Search completed in 33.6314 seconds

Calculating evaluation metrics...
Evaluating: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 9690/9690 [00:00<00:00, 132921.27it/s]

==================================================
Evaluation Results:
==================================================
Number of evaluated queries: 9690
MRR@10:        0.5874
Recall@10:     0.9541
Precision@10:  0.1057
==================================================

How does this work? Easy: to find the top-10 results for each query vector, we brute-force search for the 10 nearest vectors in the whole corpus. We can tell the result is not fast enough, since it takes 33 seconds (which HNSW will later speed up many times over). Here, Recall@10 measures what fraction of a query's relevant passages appear in the top-10 results, and MRR@10 is the mean reciprocal rank of the first relevant result within the top 10.
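As a tiny worked example with made-up IDs (not taken from the run above): if a query's only relevant passage shows up at rank 3, it contributes 1/3 to MRR@10 and 1.0 to Recall@10.

relevant = {"q0_2"}                            # hypothetical qrels entry
retrieved = ["q0_7", "q0_5", "q0_2", "q0_1"]   # hypothetical ranked results
mrr_contrib = 1.0 / (retrieved.index("q0_2") + 1)                # 1/3
recall_contrib = len(set(retrieved) & relevant) / len(relevant)  # 1.0
print(mrr_contrib, recall_contrib)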

HNSW: Hierarchical Navigable Small World

A common rule of thumb for speeding up querying is to pre-construct indexes or graphs beforehand, so that less effort is needed later to compare each item.

Therefore, here comes today's first topic: HNSW (Hierarchical Navigable Small World). You have probably come across this algorithm when using vector search in PostgreSQL or Faiss. Indeed, we could implement this with Faiss, but here I chose to implement it from scratch and break it down piece by piece.

Main Idea

HNSW is similar to a skip list, in that both split data points hierarchically so that a query only has to be conducted over a progressively smaller scope. In a skip list, insertion, querying, and deletion are all O(log n), and Redis uses a skip list, instead of a red-black tree, for its in-memory sorted sets.

I have been told that the reason Redis uses a skip list is that it is simply easier to implement. Not sure if that is true.

Skip-list
Figure 3: A simple illustration of how a skip list works. As it shows, 55 is first found to lie between 40 and +inf, then at L1 it is narrowed down to between 40 and 60, and finally at L0 it is found by iterating over 40 and 50.

I would summarize the idea as play it smart, not hard, since it does not go through EVERY item to find the result. This is definitely a game changer for retrieval tasks. We can also tell that a higher layer is always less dense than the layer below it.
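To make this concrete, here is a toy sketch of skip-list-style descent over plain Python lists (illustrative only, and not how Redis actually implements it):

def skiplist_search(levels, target):
    # levels[0] is the sparsest layer; levels[-1] contains every key.
    pos = 0
    for depth, level in enumerate(levels):
        # advance right while the next key is still <= target
        while pos + 1 < len(level) and level[pos + 1] <= target:
            pos += 1
        if depth + 1 < len(levels):
            # descend: locate the current key in the denser layer below
            pos = levels[depth + 1].index(level[pos])
    return levels[-1][pos] == target

levels = [
    [10, 40],                          # L2
    [10, 20, 40, 60],                  # L1
    [10, 20, 30, 40, 50, 55, 60, 70],  # L0
]
print(skiplist_search(levels, 55))  # True: 40 -> (40, 60) at L1 -> 55 at L0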

The way HNSW stores its hierarchical structure is to build a skip-list-like graph in which the number of nodes in layer $L$ is always a fixed fraction (controlled by $m_L$) of the number in layer $L-1$. This is guaranteed by the randomized formula below:

$$L = \left\lfloor -\ln(\mathrm{rand}(0,1)) \times m_L \right\rfloor,$$

where $\mathrm{rand}(0,1)$ is a uniform random number between 0 and 1. This $L$ is calculated every time a new vector is inserted into the HNSW graph.

Code Implementation 1 - Initializing indexes

As mentioned above, we need to calculate a layer every time a new vector is inserted into the HNSW graph.

    def get_layer(self):
        return int(-np.log(np.random.random()) * self.mL)

This method makes sure that vectors are assigned to layers with the right probabilities. You can see the effect by calling this function repeatedly and printing the output: most nodes are allocated to layer 0, some to layer 1, even fewer to layer 2, and so on.
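For instance, a quick standalone simulation (using $m_L = 1/\ln 16$, matching the class below) shows the geometric decay; the exact counts here are illustrative:

from collections import Counter
import numpy as np

mL = 1 / np.log(16)  # same mL as in the HNSW class below
layers = [int(-np.log(np.random.random()) * mL) for _ in range(100_000)]
print(Counter(layers))  # roughly {0: ~93750, 1: ~5860, 2: ~370, ...}, each layer ~16x sparser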

Here goes our first function, search_in_one_layer. As the name suggests, it searches one layer of the graph for the query vector, starting from entry_point and keeping up to ef candidate neighbors. It ends up returning a List[Tuple[float, int]] where the first element is the distance and the second element is the index of the vector.

The mapping from a vector's index to its actual values is kept in self.corpus_vecs, and the HNSW graph itself is held in self.graph.

Some initializations:

import heapq
import numpy as np
from typing import Dict, List, Set, Tuple

class HNSW:
    def __init__(self, corpus_vecs):
        self.entry_point = None
        self.corpus_vecs = corpus_vecs
        self.M = 16                  # max connections per node
        self.ef_construction = 200   # candidate list size during construction
        self.graph: List[Dict[int, Set[int]]] = []  # one adjacency dict per layer
        self.id = 0
        self.mL = 1 / np.log(self.M)

    def search_in_one_layer(self, vector, entry_point: Set[int], ef=1, L=0) -> List[Tuple[float, int]]:
        # search in the Lth layer of the graph
        # return: (distance, node_id) pairs of the neighbors found in layer L
        # ef: number of candidate neighbors to keep
        visited = set(entry_point)
        candidates = []  # min heap: always pop the closest unexplored node
        results = []     # max heap (negated distances): top holds the furthest
                         # kept vector, ready to be kicked out by a pop

        for ep in entry_point:  # distances from the query vector to entry points
            dist = self.get_distance(vector, self.corpus_vecs[ep])
            heapq.heappush(candidates, (dist, ep))
            heapq.heappush(results, (-dist, ep))

        while candidates:
            dist_c, c = heapq.heappop(candidates)
            dist_furthest = -results[0][0]
            if dist_c > dist_furthest and len(results) >= ef:  # prune
                break

            neighbors = self.graph[L].get(c, set())  # c's neighbors in the Lth layer

            for n in neighbors:  # only search among neighbors of entry points
                if n not in visited:
                    visited.add(n)
                    dist = self.get_distance(vector, self.corpus_vecs[n])
                    if len(results) < ef or dist < -results[0][0]:
                        heapq.heappush(candidates, (dist, n))
                        heapq.heappush(results, (-dist, n))
                        if len(results) > ef:
                            heapq.heappop(results)
        return sorted((-d, n) for d, n in results)

Why keep a minimum heap and a maximum heap at the same time? candidates is a min heap that always hands us the closest unexplored node, helping select the top-ef nearest neighbors, while results is used when some vectors need to be kicked out. Since we kick out vectors by the furthest distance, a max heap keeps the furthest vector on top, ready to be removed by a single pop.
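The trick behind results is that Python's heapq only provides a min heap, so negating distances turns it into a max heap:

import heapq

results = []
for dist, node in [(0.9, 1), (0.2, 2), (0.5, 3)]:
    heapq.heappush(results, (-dist, node))  # negate: largest distance on top
print(-results[0][0])  # 0.9, the furthest candidate, first to be evicted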

select_neighbors_heuristic is supposed to be heuristic, but this blog tends toward a minimalist implementation, so it simply sorts by distance and takes the top-M neighbors.

def select_neighbors_heuristic(self, candidates: List[Tuple[float, int]], M: int) -> Set[int]:
    """Select M neighbors using a heuristic to maintain graph connectivity"""
    if len(candidates) <= M:
        return {node_id for _, node_id in candidates}
    
    # Sort by distance (ascending)
    candidates = sorted(candidates)
    selected = set()
    
    for dist, node_id in candidates:
        if len(selected) >= M:
            break
        selected.add(node_id)
    
    return selected

Appetizers are ready; let's move on to the main course: vector insertion.

The philosophy of inserting a vector can be described in two situations:

  1. For the very first vector, when HNSW has no levels yet, we append level_current + 1 layers to the graph variable, where level_current comes from the get_layer function.
  2. For each subsequent vector, we search from the top layer down to level_current + 1, greedily picking the entry point for the next layer.

Why use a greedy search in such a situation? Because the entry point passed in as a parameter is the node from the topmost layer, which you can still refer back to in the skip-list illustration. Every time we go down one layer, we hop from the previous layer's entry point to the next layer's entry point. Reusing the same vector would be pointless: if the current vector were already the answer, we would have stopped there. The simple trick is to set ef to 1 here, so we keep just the nearest neighbor as the next entry point.


    def insert(self, query_vec):
        node_id = self.id
        self.id += 1
        level_current = self.get_layer() # get the level for this vector
        max_level = len(self.graph) - 1  # record the top layer BEFORE this insertion
        
        # Initialize graph layers if needed
        while len(self.graph) <= level_current:
            self.graph.append({})
        
        # First node - just add it
        if node_id == 0:
            self.entry_point = node_id
            for l in range(level_current + 1):
                self.graph[l][node_id] = set()
            return

        # Search phase 1: find entry points from the top layer down to the target layer
        # (max_level was recorded before appending layers, so the final
        # entry-point update below actually triggers for new top nodes)
        ep = {self.entry_point}
        
        # Navigate from top layer down to level_current+1
        for l in range(max_level, level_current, -1):
            results = self.search_in_one_layer(query_vec, ep, ef=1, L=l)
            ep = {results[0][1]} if results else ep
        
        # Search phase 2: find neighbors and insert at each layer from level_current down to 0
        for l in range(level_current, -1, -1):
            # Find ef_construction nearest neighbors at this layer
            results = self.search_in_one_layer(query_vec, ep, ef=self.ef_construction, L=l)
            
            # Select M neighbors for this layer
            M = self.M if l > 0 else self.M * 2  # Layer 0 has more connections
            neighbors = self.select_neighbors_heuristic(results, M)
            
            # Add bidirectional links
            if node_id not in self.graph[l]:
                self.graph[l][node_id] = set()
            
            for neighbor_id in neighbors:
                # Add edge from new node to neighbor
                self.graph[l][node_id].add(neighbor_id)
                
                # Add edge from neighbor to new node
                if neighbor_id not in self.graph[l]:
                    self.graph[l][neighbor_id] = set()
                self.graph[l][neighbor_id].add(node_id)
           
            # Update entry points for next layer
            ep = neighbors
        
        # Update entry point if new node is at a higher level
        if level_current > max_level:
            self.entry_point = node_id

Here we search through one layer and gather candidate neighbors from it. We select M neighbors and connect the new node to them within this layer; constructing these edges between vectors is what inserting into the graph actually means. If the layer only has one vector, we connect only to that vector, since the entry point is always added to the results first, as written at the beginning of the search_in_one_layer function.

            # Find ef_construction nearest neighbors at this layer
            results = self.search_in_one_layer(query_vec, ep, ef=self.ef_construction, L=l)
            
            # Select M neighbors for this layer
            M = self.M if l > 0 else self.M * 2  # Layer 0 has more connections
            neighbors = self.select_neighbors_heuristic(results, M)
            
            # Add bidirectional links
            if node_id not in self.graph[l]:
                self.graph[l][node_id] = set()

Code Implementation 2 - Searching

After constructing the index, we move on to searching within the HNSW graph.

Just as with index insertion and the skip list, we search for vectors from top to bottom. Note that this is an approximate search, which means it gives up some accuracy by design.

Searching is rather simpler than insertion: we again adopt greedy search to explore each layer's entry points. Once we reach layer 0, where all nodes exist, we use these entry points to search for the ef nearest neighbors. Then the final results are returned.

    def search(self, query_vec, k=10, ef=None):
        """Search for k nearest neighbors
        Args:
            query_vec: query vector (already normalized)
            k: number of nearest neighbors to return
            ef: size of the dynamic candidate list (default: max(ef_construction, k))
        Returns:
            List of (distance, node_id) tuples
        """
        if ef is None:
            ef = max(self.ef_construction, k)
        
        if self.entry_point is None:
            return []
        
        # Start from entry point at the top layer
        ep = {self.entry_point}
        max_level = len(self.graph) - 1
        
        # Phase 1: Navigate from top to layer 1
        for l in range(max_level, 0, -1):
            results = self.search_in_one_layer(query_vec, ep, ef=1, L=l)
            ep = {results[0][1]} if results else ep
        
        # Phase 2: Search at layer 0 with ef candidates
        results = self.search_in_one_layer(query_vec, ep, ef=ef, L=0)
        
        # Return top k results
        return results[:k]
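Putting the pieces together, here is a minimal usage sketch, assuming the methods above (insert, search, search_in_one_layer, get_layer, plus get_distance from the next section) all live on the HNSW class, and reusing the normalized embeddings from the K-NN experiment:

index = HNSW(corpus_embeddings)
for vec in tqdm(corpus_embeddings, desc="Inserting vectors"):
    index.insert(vec)

for dist, node_id in index.search(query_embeddings[0], k=10):
    print(corpus_ids[node_id], dist)  # dist is the negated dot product here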

You might have noticed that the entry_point we start from each time is the node that most recently landed on the topmost layer during insertion.

Speeding up by changing the way of distance calculation

Honestly, in the beginning, querying without any optimization was so degraded that it ran at almost the same speed as K-NN. This blew my mind, because theoretically it should be much faster. It turned out that one of the most throttling parts was actually my distance calculation method.

In my previous implementation, the get_distance function was called to directly calculate the Euclidean distance between two vectors.

def get_distance(v1, v2):
    return np.linalg.norm(v1 - v2)

But this is slow, since every call allocates a temporary difference vector and takes a square root. A simple yet effective way to speed things up is to normalize all vectors first, and then use the (negated) dot product to calculate the distance.

def get_distance(v1, v2):
    return -np.dot(v1, v2)

The negation makes sure that the higher the similarity, the smaller the "distance", matching the minimum-heap convention used to find the top-k elements. For unit vectors this preserves the Euclidean ranking, since $\|u-v\|^2 = 2 - 2\,u\cdot v$.
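A quick numerical check of that identity (a standalone snippet, not part of the index code):

import numpy as np

u = np.random.randn(384); u /= np.linalg.norm(u)
v = np.random.randn(384); v /= np.linalg.norm(v)
# For unit vectors, ||u - v||^2 == 2 - 2 * u.v, so ranking by -u.v
# gives exactly the same order as ranking by Euclidean distance.
print(np.linalg.norm(u - v) ** 2, 2 - 2 * np.dot(u, v))  # (almost) equal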

Speeding up by pruning / early exit when inserting

This sped up querying 2x in my case. I did not include this part in the insertion section because I did not want to complicate the process.

In HNSW, it’s crucial to maintain a maximum degree (M) for each node to ensure the graph remains a “small world” network with logarithmic search complexity. If a node has too many edges, the search becomes inefficient. Therefore, when we add a back-link to a neighbor, we must check whether that neighbor now has more than M connections. If so, we re-evaluate all of its connections and keep only the best M using the heuristic.

So after adding an edge to the graph, we need a check:

if len(self.graph[l][neighbor_id]) > M:
    # Collect all existing connections of this neighbor
    current_connections = list(self.graph[l][neighbor_id])
    
    # Calculate distances from these connections to the neighbor_id
    conn_candidates = []
    neighbor_vec = self.corpus_vecs[neighbor_id]
    
    # Note: For simplicity, we calculate distances one by one. 
    # In a production environment, vectorized calculation would be more efficient.
    for conn_id in current_connections:
        dist = self.get_distance(neighbor_vec, self.corpus_vecs[conn_id])
        conn_candidates.append((dist, conn_id))
    
    # Re-execute heuristic selection to retain only the best M connections
    kept_neighbors = self.select_neighbors_heuristic(conn_candidates, M)
    
    # Update the graph structure
    self.graph[l][neighbor_id] = kept_neighbors   

Final Result

Loading model...

Loading dataset...

Processing data...
Processing items: 100%|█████████████████████████████████████████████████████████████████████████████████████| 10000/10000 [00:00<00:00, 22211.02it/s]

==================================================
Data Statistics:
==================================================
Queries: 9690
Corpus: 80128
Average relevant docs per query: 1.12
==================================================

Encoding corpus...
Batches: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████| 626/626 [00:36<00:00, 17.05it/s]

Encoding queries...
Batches: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████| 76/76 [00:00<00:00, 117.31it/s]

Normalizing embeddings...

Performing retrieval...
Building HNSW index...
Inserting vectors: 100%|██████████████████████████████████████████████████████████████████████████████████████| 80128/80128 [05:26<00:00, 245.23it/s]
HNSW index built with 80128 vectors

Searching queries...
Searching: 100%|███████████████████████████████████████████████████████████████████████████████████████████████| 9690/9690 [00:05<00:00, 1728.07it/s]
Search completed in 332.3713 seconds

Calculating evaluation metrics...
Evaluating: 100%|████████████████████████████████████████████████████████████████████████████████████████████| 9690/9690 [00:00<00:00, 296768.96it/s]

==================================================
Evaluation Results:
==================================================
Number of evaluated queries: 9690
MRR@10:        0.4968
Recall@10:     0.8122
Precision@10:  0.0902
==================================================

Querying is blazingly fast: it finishes within 5 seconds, where the brute-force K-NN search took 33 seconds.

But building the index takes over five minutes, which is the trade-off of this method. The truth is, in real-world scenarios, what we always care about is the online, real-time part rather than the offline one.

Experimental Setup

All experiments were conducted on an RTX 2080 Ti, Intel Core i5-12450H, 16 GB RAM, running Manjaro. We use Python 3.10 and the latest version of each library at the time of writing.

For reference

Following on from searching: Elasticsearch also uses the HNSW algorithm for its vector search.