Examples
This section uses the sift-128-euclidean.hdf5 dataset as an example. Run the following command to obtain the dataset:
wget http://ann-benchmarks.com/sift-128-euclidean.hdf5 --no-check-certificate
Assume that the directory where the program runs is /path/to/kscann_test. The complete directory structure is as follows:
├── datasets                        // Stores the dataset.
│   └── sift-128-euclidean.hdf5
├── main.py                         // The file that contains the running functions.
└── sift.json                       // The corresponding dataset configuration file.
Procedure:
- Assume that the program runs in the /path/to/kscann_test directory. Check whether the datasets/sift-128-euclidean.hdf5, main.py and sift.json files exist in the directory. main.py and sift.json are provided at the end of this section.
- Ensure that the sixth parameter of query_args in the sift.json file is set to the actual number of CPU cores multiplied by 4.
- Install the numactl dependency.
yum install numactl
- Install the Python dependency.
pip install pandas
- Run main.py. Binding to the NUMA node is required.
numactl -N 0 -m 0 python main.py sift.json
The execution result is as follows:

The content of main.py is as follows:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 | import sys import json import time import threading from multiprocessing.pool import ThreadPool import numpy as np import h5py import psutil import scann from typing import Any, Dict, Optional class BaseANN(object): # BaseANN base class. def done(self) -> None: pass def get_memory_usage(self) -> Optional[float]: return psutil.Process().memory_info().rss / 1024 def fit(self, X: np.array) -> None: pass def query(self, q: np.array, n: int) -> np.array: return [] # array of candidate indices def batch_query(self, X: np.array, n: int) -> None: pool = ThreadPool() self.res = pool.map(lambda q: self.query(q, n), X) def get_batch_results(self) -> np.array: return self.res def get_additional(self) -> Dict[str, Any]: return {} def __str__(self) -> str: return self.name class Scann(BaseANN): # ScaNN subclass. def __init__(self, n_leaves, avq_threshold, dims_per_block, dist): self.name = "scann n_leaves={} avq_threshold={:.02f} dims_per_block={}".format( n_leaves, avq_threshold, dims_per_block ) self.n_leaves = n_leaves self.avq_threshold = avq_threshold self.dims_per_block = dims_per_block self.dist = dist def fit(self, X): if self.dist == "dot_product": spherical = True X[np.linalg.norm(X, axis=1) == 0] = 1.0 / np.sqrt(X.shape[1]) X /= np.linalg.norm(X, axis=1)[:, np.newaxis] else: spherical = False self.searcher = ( # ScannBuilder class is used to receive build parameters. scann.scann_ops_pybind.builder(X, 10, self.dist) .tree(self.n_leaves, 1, training_sample_size=len(X), spherical=spherical, quantize_centroids=True) # Add parameters related to IVF index partition. 
.score_ah(self.dims_per_block, anisotropic_quantization_threshold=self.avq_threshold) # Add parameters related to partitioned PQ. .reorder(1) # Add parameters related to reordering. .build() # Start to build the index. ) def set_query_arguments(self, query_args): if len(query_args) == 5: self.leaves_to_search, self.reorder, self.thd, self.refined, self.batch_size = query_args elif len(query_args) == 6: self.leaves_to_search, self.reorder, self.thd, self.refined, self.batch_size, self.num_threads = query_args else : self.leaves_to_search, self.reorder, self.thd, self.refined = query_args def query(self, v, n): # Single query. return self.searcher.search(v, n, self.reorder, self.leaves_to_search)[0] def batch_query(self, v, n): # Batch query. if self.dist == "dot_product": v[np.linalg.norm(v, axis=1) == 0] = 1.0 / np.sqrt(v.shape[1]) v /= np.linalg.norm(v, axis=1)[:, np.newaxis] self.searcher.search_additional_params(self.thd, self.refined, self.leaves_to_search) # Provide an additional interface for configuring search parameters. if (self.num_threads != 320) and (self.num_threads >= 1): self.searcher.set_num_threads(self.num_threads-1) # Configure the number of threads enabled during search. self.res = self.searcher.search_batched_parallel(v, n, self.reorder, self.leaves_to_search, self.batch_size) # Multi-thread parallel batch-query search. def load_dataset(file_path): # Load the dataset. with h5py.File(file_path, 'r') as f: train = np.array(f['train']) test = np.array(f['test']) neighbors = np.array(f['neighbors']) return train, test, neighbors def read_query_args(json_file): # Read the configuration file. with open(json_file, 'r') as file: json_info = json.load(file) return json_info.get('query_args', []) def calculate_qps(start_time, num_queries): # Calculate the QPS. 
elapsed_time = time.time() - start_time return num_queries / elapsed_time def run_benchmark(): if len(sys.argv) < 2: print("Usage: python main.py <config_file>") sys.exit(1) json_file = sys.argv[1] query_args = read_query_args(json_file) train_data, query_data, ground_truth = load_dataset('datasets/sift-128-euclidean.hdf5') scann_model = Scann(n_leaves=100, avq_threshold=0.2, dims_per_block=2, dist="dot_product") scann_model.fit(train_data) scann_model.set_query_arguments(query_args) start_time = time.time() scann_model.batch_query(query_data, 10) results = scann_model.get_batch_results() qps = calculate_qps(start_time, len(query_data)) print(f"QPS: {qps:.2f}") if __name__ == "__main__": run_benchmark() |
The content of sift.json is as follows:
{
"query_args": [31, 180, 0.27, 0, 200, 320]
}
Parent topic: KScaNN APIs