In [None]:
import time
import numpy as np

# requires Faiss to be installed, see 
# https://github.com/facebookresearch/faiss/blob/main/INSTALL.md#installing-faiss-via-conda
# on how to install the CPU version

import faiss

from faiss.contrib.datasets import SyntheticDataset

from matplotlib import pyplot

In [None]:
# setup that works for my machine. Adjust to yours 
faiss.omp_set_num_threads(32)

%matplotlib inline
%config InlineBackend.figure_format='retina'

In [None]:
# get some data
ds = SyntheticDataset(64, 1000_000, 10000, 100)
print(ds)

In [None]:
# get training set
xt = ds.get_train()
xt.shape

In [None]:
d = ds.d

# Run k-means 

In [None]:
# 4096 centroids 
km = faiss.Kmeans(ds.d, 4096)

In [None]:
%%time
km.train(xt)

In [None]:
centroids = km.centroids 
centroids.shape

In [None]:
MSE = km.obj[-1] / len(xt)
MSE

In [None]:
pyplot.plot(km.obj / ds.nt)
pyplot.ylabel("Mean Squared Error")
pyplot.xlabel("Iteration")
pyplot.grid()

# Hierarchical k-means 

In [None]:
def recursive_run_kmeans(xt, k, level):
    """Hierarchical k-means: recursively split `xt` into `k` clusters, `level` times.

    Parameters
    ----------
    xt : 2D array, one vector per row
        Vectors to quantize.
    k : int
        Number of clusters created at each level of the hierarchy.
    level : int
        Number of levels still to split. At level 0 all vectors of `xt`
        are represented by a single centroid (their mean).

    Returns
    -------
    centroids : list of 1D arrays
        Leaf centroids, at most k ** level of them. There are fewer when
        some clusters are empty or too small to be split further.
    tot_sum : float
        Sum over all vectors of the squared distance to their leaf centroid
        (divide by the number of vectors to get the MSE).
    """
    n = len(xt)
    if n == 0:
        # An empty cluster contributes no centroid and no error; taking the
        # mean of zero vectors would produce a NaN centroid instead.
        return [], 0.0
    if level == 0 or n < k:
        # Leaf: all vectors are encoded to the same centroid, compute the error.
        # Clusters with fewer than k points cannot be split into k clusters,
        # so they are turned into a leaf as well.
        centroid = xt.mean(axis=0)
        s = ((xt - centroid) ** 2).sum()
        return [centroid], s
    # take the dimension from the data itself rather than from a global dataset
    km = faiss.Kmeans(xt.shape[1], k)
    km.train(xt)
    _, labels = km.assign(xt)
    tot_sum = 0
    centroids = []
    for i in range(k):
        subset = labels == i
        cent_i, sum_i = recursive_run_kmeans(xt[subset], k, level - 1)
        centroids += cent_i
        tot_sum += sum_i
    return centroids, tot_sum

In [None]:
%%time 
# 4096 = 8 ** 4
cents, s = recursive_run_kmeans(xt, 8, 4)
MSE = s / len(xt)
MSE

In [None]:
%%time 
# 4096 = 64 ** 2
cents, s = recursive_run_kmeans(xt, 64, 2)
MSE = s / len(xt)
MSE

In [None]:
# search from centroids directly 

D, _ = faiss.knn(xt, cents, k=1)
MSE = D.mean()
MSE

## Searching in a vector database 

In [None]:
# the database and set of query vectors are arrays
xq = ds.get_queries()
xb = ds.get_database()

In [None]:
xq.shape

In [None]:
xb.shape

### Ground truth and the knn function

In [None]:
# find ground-truth nearest neighbors 
gt_dis, gt = faiss.knn(xq, xb, k=10)

In [None]:
gt.shape

In [None]:
gt[:3]

In [None]:
gt_dis[:3]

In [None]:
# Recompute by hand the squared L2 distance between query 1 and the database
# vector reported as its nearest neighbor; compare with gt_dis[1, 0].
# The id is read from gt rather than copied from a previous output.
((xq[1] - xb[gt[1, 0]])**2).sum()

# The inverted file 

In [None]:
nlist = 4096

# compute IVF entries for database = find the nearest centroid for each database vector 
_, list_nos = faiss.knn(xb, centroids, k=1)
list_nos = list_nos.flatten()

In [None]:
# One inverted list per centroid: the ids of the database vectors assigned
# to that centroid...
ivf_ids = [
    np.where(list_nos == list_no)[0]
    for list_no in range(nlist)
]
# ...and the corresponding vectors, stored in the same order as the ids.
ivf_vectors = [xb[ids] for ids in ivf_ids]

In [None]:
len(ivf_ids), len(ivf_vectors)

In [None]:
max(len(l) for l in ivf_ids)

In [None]:
min(len(l) for l in ivf_ids)

In [None]:
# searching in the nearest centroid 
_, q_list_nos = faiss.knn(xq, centroids, k=1)
found_nns = []
# loop over every query (instead of a hard-coded count) so that the result
# always lines up with the ground-truth array
for q in range(len(xq)): 
    query = xq[q]
    # fetch contents of cluster
    cluster_vectors = ivf_vectors[q_list_nos[q, 0]]
    cluster_ids = ivf_ids[q_list_nos[q, 0]]
    if cluster_ids.size == 0: 
        # empty inverted list: nothing to compare with, report "not found"
        found_nns.append(-1)
        continue
    # compute distances 
    distances = ((query - cluster_vectors)**2).sum(1)
    # collect result id
    result_id = cluster_ids[distances.argmin()]
    found_nns.append(result_id)
    

In [None]:
(found_nns == gt[:, 0]).sum()

That's not much. Maybe we need to explore more clusters per query?

In [None]:
nprobe = 13
# searching in the nprobe nearest centroids 
_, q_list_nos = faiss.knn(xq, centroids, k=nprobe)
found_nns = []
ndis = 0   # total number of query-to-database distances computed
# loop over every query (instead of a hard-coded count) so that the result
# always lines up with the ground-truth array
for q in range(len(xq)): 
    query = xq[q]
    # fetch contents of clusters 
    cluster_vectors = np.vstack([
        ivf_vectors[i]
        for i in q_list_nos[q]
    ])
    cluster_ids = np.hstack([
        ivf_ids[i]
        for i in q_list_nos[q]
    ])
    if cluster_ids.size == 0: 
        # all probed inverted lists are empty: report "not found"
        found_nns.append(-1)
        continue
    # compute distances 
    distances = ((query - cluster_vectors)**2).sum(1)
    ndis += len(cluster_ids)
    # collect result id
    result_id = cluster_ids[distances.argmin()]
    found_nns.append(result_id)

In [None]:
(found_nns == gt[:, 0]).sum()

In [None]:
# average number of distance computations per query
ndis / len(xq)

That's better, we computed just 106 distances on average per query (out of 10000)

## Inverted file in Faiss 

In [None]:
index = faiss.index_factory(d, "IVF1024,Flat") # flat means: don't encode the vectors!

In [None]:
index.train(xt)

In [None]:
index.add(xb)

In [None]:
D, I = index.search(xq, 10)

In [None]:
(I[:, 0] == gt[:, 0]).sum()

In [None]:
index.nprobe = 10
D, I = index.search(xq, 10)
(I[:, 0] == gt[:, 0]).sum()

## Tradeoff speed / accuracy 

In [None]:
# Sweep over index sizes (nlist) and search-time effort (nprobe).
# results maps (nlist, nprobe) -> (number of correct top-1 results,
#                                 total time in ms over all timing runs)
results = {}
nprobe_values = (1, 2, 4, 8, 16, 32, 64, 128)
n_runs = 100   # several runs to get stable timings
for nlist in (64, 256, 1024):
    index = faiss.index_factory(d, f"IVF{nlist},Flat")
    index.train(xt)
    index.add(xb)
    # skip settings that would probe more lists than the index has
    for nprobe in (p for p in nprobe_values if p <= nlist):
        index.nprobe = nprobe
        t0 = time.time()
        for run in range(n_runs):
            D, I = index.search(xq, 10)
        t1 = time.time()
        recall = (I[:, 0] == gt[:, 0]).sum()
        print(f"{nlist=:} {nprobe=:} {recall=:} time={(t1 - t0) * 1000 :.3f} ms")
        results[(nlist, nprobe)] = (recall, (t1 - t0) * 1000)
        

In [None]:
# Plot the measurements collected in `results`: one curve per nlist,
# one point per nprobe setting.
# No index is built here: the plot only needs the stored measurements, so
# re-training and re-filling an index per nlist would be wasted work.
nprobe_values = [1, 2, 4, 8, 16, 32, 64, 128]
for nlist in 64, 256, 1024: 
    # use the same condition as the measurement loop (nprobe <= nlist),
    # otherwise the nprobe == nlist point is measured but never plotted
    res = [results[(nlist, nprobe)] for nprobe in nprobe_values if nprobe <= nlist]
    recalls = [r[0] for r in res]
    times = [r[1] for r in res]
    pyplot.plot(recalls, times, label=f"{nlist=:}")

pyplot.ylabel("time (ms)")
pyplot.xlabel("R@1")
pyplot.legend()
pyplot.grid()
    

## Search cost as a function of the database size 

In [None]:
ns = 2 ** np.arange(10, 25)
nprobe = 15 # fix nprobe 
for k in 4 ** np.arange(3, 7): 
    coarse_quantization_cost = k
    ivf_scanning_cost = nprobe / k * ns
    pyplot.loglog(ns, coarse_quantization_cost + ivf_scanning_cost, label=f"{k=:}")
pyplot.xlabel("database size")
pyplot.ylabel("nb distance computations")
pyplot.title(f"search cost at {nprobe=:}")
pyplot.legend()
pyplot.grid()

# Searching in compressed vectors 

In [None]:
# work on a smaller subset because otherwise we don't see anything with such small codes 
xb_small = xb[:1000]
_, gt_small = faiss.knn(xq, xb_small, k=10)

In [None]:
# compute codes for database = find the nearest centroid for each database vector 
encoding_errors, codes = faiss.knn(xb_small, centroids, k=1)

In [None]:
codes.shape

In [None]:
codes = codes.flatten()

In [None]:
# reconstruct 
reconstructed_xb = centroids[codes]

In [None]:
MSE = ((reconstructed_xb - xb_small) ** 2).sum(1).mean()
MSE

Similar but a bit worse than the training MSE 

In [None]:
# another way of computing it
encoding_errors.mean()

## Asymmetric search

In [None]:
found_dis, found_indices = faiss.knn(xq, reconstructed_xb, k=10)

In [None]:
(gt_small[:, 0] == found_indices[:, 0]).sum() 

We lose 73% of the nearest neighbors because the vectors are compressed a lot (12 bits). But note that chance level is 1/1000 = 0.1%.

## Symmetric search 

In [None]:
# let's encode and decode the queries as well 
_, xq_codes = faiss.knn(xq, centroids, k=1)
xq_codes = xq_codes.flatten()
reconstructed_xq = centroids[xq_codes]

In [None]:
found_dis, found_indices = faiss.knn(reconstructed_xq, reconstructed_xb, k=10)

In [None]:
(gt_small[:, 0] == found_indices[:, 0]).sum() 

Wow that's even worse

## Asymmetric search with look-up tables 

In [None]:
# recall reference results
found_dis, found_indices = faiss.knn(xq, reconstructed_xb, k=10)

In [None]:
# make look-up tables for all queries
def pairwise_distances(A, B):
    """Return the matrix of squared L2 distances between rows of A and rows of B.

    Uses the expansion ||a - b||^2 = ||a||^2 + ||b||^2 - 2 <a, b>, so entry
    (i, j) of the result is the squared distance between A[i] and B[j].
    """
    sq_norms_a = (A ** 2).sum(1)
    sq_norms_b = (B ** 2).sum(1)
    twice_dot = (2 * A) @ B.T
    # column vector + row vector broadcasts to the full (len(A), len(B)) matrix
    return sq_norms_a[:, None] + sq_norms_b - twice_dot

In [None]:
LUT = pairwise_distances(xq, centroids)

In [None]:
LUT.shape

In [None]:
codes.shape

In [None]:
distances = LUT[:, codes]

In [None]:
distances.shape

In [None]:
found_indices_2 = distances.argmin(axis=1)

In [None]:
np.all(found_indices[:, 0] == found_indices_2)

In [None]:
found_indices_2

In [None]:
found_indices[:, 0]

In [None]:
np.where(found_indices[:, 0] != found_indices_2)

# Product Quantization

In [None]:
# 4 sub-vectors, encode each in 2^8 elements
pq = faiss.ProductQuantizer(d, 4, 8)

In [None]:
pq.code_size   # in bytes, bits/8 rounded up to next integer

In [None]:
pq.train(xt)

In [None]:
xb_codes = pq.compute_codes(xb)

In [None]:
pq_reconstruction = pq.decode(xb_codes)

In [None]:
# compute the MSE
((pq_reconstruction - xb) ** 2).sum(1).mean()

Better MSE than the 12-bit k-means one

## Manual reconstruction

In [None]:
from faiss.contrib.inspect_tools import get_pq_centroids, get_additive_quantizer_codebooks

In [None]:
pq_centroids = get_pq_centroids(pq)

In [None]:
pq_centroids.shape

Layout: number of subvectors, K, subvector dimension

In [None]:
xb_codes[:2]

In [None]:
# reconstruct vector no 123: for each sub-quantizer m, look up the centroid
# selected by the m-th code and concatenate the resulting sub-vectors.
# pq_centroids has layout (number of subvectors, K, subvector dimension);
# this indexing assumes one code per byte, i.e. the 8-bit PQ defined above.
xb123_recons = np.hstack([
    pq_centroids[m, xb_codes[123, m]]
    for m in range(pq.M)
])

In [None]:
np.all(pq_reconstruction[123] == xb123_recons)

## Compare options for fixed code_size
Fix the code size budget (in bytes) and vary the number of sub-quantizers.

In [None]:
budget = 6  # budget 6 bytes per vector
for M in 4, 8, 16: 
    nbits = budget * 8 // M
    print(f"PQ {M}x{nbits}")
    pq = faiss.ProductQuantizer(d, M, nbits)
    print(f"Sub-vector size {pq.dsub} K={pq.ksub} code size {pq.code_size}")
    pq.train(xt)
    t0 = time.time()
    pq_reconstruction = pq.decode(pq.compute_codes(xb))
    t1 = time.time()
    MSE = ((pq_reconstruction - xb) ** 2).sum(1).mean()
    print(f"{MSE=:.2f} encode-decode time: {(t1 - t0)*1000:.3f} ms")

## Optimized product quantization

In [None]:
from faiss.contrib.inspect_tools import get_LinearTransform_matrix

In [None]:
opq = faiss.OPQMatrix(d, 4)
pq = faiss.ProductQuantizer(d, 4, 8)

In [None]:
opq.train(xt)

In [None]:
pq.train(opq.apply(xt))

In [None]:
xb_t = opq.apply(xb)

In [None]:
xb_t_recons = pq.decode(pq.compute_codes(xb_t))

In [None]:
((xb_t - xb_t_recons) ** 2).sum(1).mean()

The MSE for regular PQ was 13 --> improves

In [None]:
A, bias = get_LinearTransform_matrix(opq)  # how to get the OPQ matrix

In [None]:
A.shape

## PQ in an index

A product quantizer with a search function (uses look-up tables)

In [None]:
index = faiss.index_factory(d, "PQ8x6np")
index.train(xt)
index.add(xb)
D, I = index.search(xq, 10)
(I[:, 0] == gt[:, 0]).sum()

In [None]:
# The OPQ rotation is optimized for a given number of sub-vectors, so it
# should match the number of sub-quantizers of the PQ that follows (8 here).
index = faiss.index_factory(d, "OPQ8,PQ8x6np")
index.train(xt)
index.add(xb)
D, I = index.search(xq, 10)
# number of queries whose top-1 result is the true nearest neighbor
(I[:, 0] == gt[:, 0]).sum()

OPQ a bit better, but free at search time.

# Residual quantization

In [None]:
rq = faiss.ResidualQuantizer(d, 4, 8)

In [None]:
rq.max_beam_size 

In [None]:
%%time 
rq.train(xt[:50_000])

In [None]:
xb_recons = rq.decode(rq.compute_codes(xb))
((xb - xb_recons) ** 2).sum(1).mean()

A bit better than OPQ

In [None]:
rq.max_beam_size = 50

In [None]:
%%time
xb_recons = rq.decode(rq.compute_codes(xb))
((xb - xb_recons) ** 2).sum(1).mean()

Improves (slowly)

# Search with additive quantizers

In [None]:
index = faiss.index_factory(d, "RQ8x6")
index.code_size

In [None]:
index.train(xt[:50_000])

In [None]:
index.add(xb)
D, I = index.search(xq, 10)
(I[:, 0] == gt[:, 0]).sum()

Better than PQ & OPQ

In [None]:
%timeit index.search(xq, 10)

This is a search timing with decoding 

In [None]:
index = faiss.index_factory(d, "RQ8x6_Nqint8")
index.code_size

In [None]:
index.train(xt[:50_000])
index.add(xb)
D, I = index.search(xq, 10)
(I[:, 0] == gt[:, 0]).sum()

In [None]:
%timeit index.search(xq, 10)

Same result but much faster (uses encoded norm) 

# Scalar quantizers

In [None]:
for key in "Flat", "SQfp16", "SQ8", "SQ6", "SQ4", "LSHrt": 
    index = faiss.index_factory(d, key)
    index.train(xt[:50_000])
    index.add(xb)
    D, I = index.search(xq, 10)
    nfound = (I[:, 0] == gt[:, 0]).sum()
    
    print(f"{key} {index.code_size=:} {nfound=:}")

# Polysemous codes 

In [None]:
index = faiss.index_factory(d, "PQ8x8") # omit the np

In [None]:
index.code_size

In [None]:
index.train(xt)
index.add(xb)

In [None]:
index.polysemous_ht  # threshold of binary code comparison -- default does not filter 

In [None]:
D, I = index.search(xq, 10)
(I[:, 0] == gt[:, 0]).sum()

In [None]:
%timeit index.search(xq, 10)

In [None]:
index.search_type = faiss.IndexPQ.ST_polysemous
index.polysemous_ht = 24
D, I = index.search(xq, 10)
(I[:, 0] == gt[:, 0]).sum()

In [None]:
%timeit index.search(xq, 10)

About twice as fast, with the same accuracy

# IVFPQ index

In [None]:
index = faiss.index_factory(d, "IVF200,PQ16x8np") 

In [None]:
index.train(xt)

In [None]:
index.add(xb)

In [None]:
D, I = index.search(xq, 10)

In [None]:
(I[:, 0] == gt[:, 0]).sum()

In [None]:
index.nprobe 

In [None]:
for nprobe in 2, 5, 10, 20, 50: 
    index.nprobe = nprobe 
    t0 = time.time()
    for _ in range(50): 
        D, I = index.search(xq, 10)
    t1 = time.time()
    nok = (I[:, 0] == gt[:, 0]).sum()
    print(f"{nprobe=:} {nok=:} {(t1 - t0)*1000:.3f} ms")

## Fast-scan SIMD implementation

In [None]:
index = faiss.index_factory(d, "IVF200,PQ32x4fsr") 
index.train(xt)
index.add(xb)

In [None]:
for nprobe in 2, 5, 10, 20, 50: 
    index.nprobe = nprobe 
    t0 = time.time()
    for _ in range(50): 
        D, I = index.search(xq, 10)
    t1 = time.time()
    nok = (I[:, 0] == gt[:, 0]).sum()
    print(f"{nprobe=:} {nok=:} {(t1 - t0)*1000:.3f} ms")