Usage Example

The following example uses the sift-128-euclidean.hdf5 dataset, which can be downloaded as follows:
```bash
wget http://ann-benchmarks.com/sift-128-euclidean.hdf5 --no-check-certificate
```
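Optionally, the downloaded file can be sanity-checked before running. The sketch below (assuming h5py is installed) lists the HDF5 datasets that main.py later reads (train, test, and neighbors):

```python
# Optional sanity check of the downloaded dataset (sketch; assumes h5py is installed).
import h5py

with h5py.File("sift-128-euclidean.hdf5", "r") as f:
    for key in f.keys():  # main.py expects 'train', 'test' and 'neighbors'
        print(key, f[key].shape, f[key].dtype)
```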
Assuming the program runs in the directory /path/to/kscann_test, the complete directory structure should look like this:

```
├── datasets                    // holds the dataset
│   └── sift-128-euclidean.hdf5
├── main.py                     // file containing the run logic
└── sift.json                   // configuration file for the dataset
```
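A minimal shell sketch for assembling this layout (/path/to/kscann_test is a placeholder path):

```bash
mkdir -p /path/to/kscann_test/datasets
mv sift-128-euclidean.hdf5 /path/to/kscann_test/datasets/
# Place main.py and sift.json (provided at the end of this section) alongside datasets/.
cd /path/to/kscann_test
```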
The steps are as follows:

- Assuming the program runs in /path/to/kscann_test, check that datasets/sift-128-euclidean.hdf5, main.py, and sift.json exist in that directory; main.py and sift.json are provided below.
- Make sure the sixth element of "query_args" in sift.json is 4 times the number of CPU cores actually used at run time (see the sketch below).
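One way to compute this value on the target machine (a sketch using standard tools; note that when the process is bound to a single NUMA node with numactl, the relevant core count is that node's cores rather than the whole machine's):

```bash
# 4x the number of available CPU cores; adjust if binding to a single NUMA node.
echo $(( $(nproc) * 4 ))
```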
- Install the numactl dependency.

```bash
yum install numactl
```
- Install the Python dependencies.

```bash
pip install pandas
```
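Note that main.py also imports numpy, h5py, and psutil. If these are not already present in the environment (for example, installed together with the KScaNN package), they may need to be installed as well:

```bash
pip install numpy h5py psutil
```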
- Run main.py; it must be bound to a NUMA node.

```bash
numactl -N 0 -m 0 python main.py sift.json
```
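If you are unsure of the machine's NUMA layout, `numactl --hardware` prints the available nodes; `-N 0 -m 0` binds both execution and memory allocation to node 0:

```bash
numactl --hardware
```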
On completion, the script prints the measured throughput in the form `QPS: <value>`; the exact figure depends on the hardware and configuration.
The content of main.py is as follows:
```python
import sys
import json
import time
from multiprocessing.pool import ThreadPool
from typing import Any, Dict, Optional

import numpy as np
import h5py
import psutil
import scann


class BaseANN(object):  # base class for ANN algorithm wrappers
    def done(self) -> None:
        pass

    def get_memory_usage(self) -> Optional[float]:
        # Resident set size in KB
        return psutil.Process().memory_info().rss / 1024

    def fit(self, X: np.ndarray) -> None:
        pass

    def query(self, q: np.ndarray, n: int) -> np.ndarray:
        return []  # array of candidate indices

    def batch_query(self, X: np.ndarray, n: int) -> None:
        pool = ThreadPool()
        self.res = pool.map(lambda q: self.query(q, n), X)

    def get_batch_results(self) -> np.ndarray:
        return self.res

    def get_additional(self) -> Dict[str, Any]:
        return {}

    def __str__(self) -> str:
        return self.name


class Scann(BaseANN):  # ScaNN subclass
    def __init__(self, n_leaves, avq_threshold, dims_per_block, dist):
        self.name = "scann n_leaves={} avq_threshold={:.02f} dims_per_block={}".format(
            n_leaves, avq_threshold, dims_per_block
        )
        self.n_leaves = n_leaves
        self.avq_threshold = avq_threshold
        self.dims_per_block = dims_per_block
        self.dist = dist
        # 320 is treated as the sentinel for "use the default thread count"
        # (see batch_query); set it here so that shorter query_args lists
        # do not leave the attribute undefined.
        self.num_threads = 320

    def fit(self, X):
        if self.dist == "dot_product":
            spherical = True
            # Replace zero vectors, then L2-normalize for spherical partitioning
            X[np.linalg.norm(X, axis=1) == 0] = 1.0 / np.sqrt(X.shape[1])
            X /= np.linalg.norm(X, axis=1)[:, np.newaxis]
        else:
            spherical = False
        self.searcher = (
            # The builder receives the index-construction parameters
            scann.scann_ops_pybind.builder(X, 10, self.dist)
            # Parameters for the IVF (inverted file) partitioning
            .tree(self.n_leaves, 1, training_sample_size=len(X),
                  spherical=spherical, quantize_centroids=True)
            # Parameters for the PQ (product quantization) scoring
            .score_ah(self.dims_per_block,
                      anisotropic_quantization_threshold=self.avq_threshold)
            # Parameters for reordering (exact rescoring of candidates)
            .reorder(1)
            # Build the index
            .build()
        )

    def set_query_arguments(self, query_args):
        if len(query_args) == 5:
            (self.leaves_to_search, self.reorder, self.thd,
             self.refined, self.batch_size) = query_args
        elif len(query_args) == 6:
            (self.leaves_to_search, self.reorder, self.thd,
             self.refined, self.batch_size, self.num_threads) = query_args
        else:
            (self.leaves_to_search, self.reorder,
             self.thd, self.refined) = query_args

    def query(self, v, n):  # single query
        return self.searcher.search(v, n, self.reorder, self.leaves_to_search)[0]

    def batch_query(self, v, n):  # batch query
        if self.dist == "dot_product":
            v[np.linalg.norm(v, axis=1) == 0] = 1.0 / np.sqrt(v.shape[1])
            v /= np.linalg.norm(v, axis=1)[:, np.newaxis]
        # Additional search-parameter configuration interface
        self.searcher.search_additional_params(self.thd, self.refined, self.leaves_to_search)
        if (self.num_threads != 320) and (self.num_threads >= 1):
            # Number of threads enabled during search
            self.searcher.set_num_threads(self.num_threads - 1)
        # Multi-threaded parallel batch search
        self.res = self.searcher.search_batched_parallel(
            v, n, self.reorder, self.leaves_to_search, self.batch_size
        )


def load_dataset(file_path):  # load the dataset
    with h5py.File(file_path, 'r') as f:
        train = np.array(f['train'])
        test = np.array(f['test'])
        neighbors = np.array(f['neighbors'])
    return train, test, neighbors


def read_query_args(json_file):  # read the configuration file
    with open(json_file, 'r') as file:
        json_info = json.load(file)
    return json_info.get('query_args', [])


def calculate_qps(start_time, num_queries):  # compute queries per second
    elapsed_time = time.time() - start_time
    return num_queries / elapsed_time


def run_benchmark():
    if len(sys.argv) < 2:
        print("Usage: python main.py <config_file>")
        sys.exit(1)
    json_file = sys.argv[1]
    query_args = read_query_args(json_file)
    train_data, query_data, ground_truth = load_dataset('datasets/sift-128-euclidean.hdf5')
    scann_model = Scann(n_leaves=100, avq_threshold=0.2, dims_per_block=2, dist="dot_product")
    scann_model.fit(train_data)
    scann_model.set_query_arguments(query_args)
    start_time = time.time()
    scann_model.batch_query(query_data, 10)
    results = scann_model.get_batch_results()
    qps = calculate_qps(start_time, len(query_data))
    print(f"QPS: {qps:.2f}")


if __name__ == "__main__":
    run_benchmark()
```
The content of sift.json is as follows. Per set_query_arguments in main.py, the six values map in order to leaves_to_search, reorder, thd, refined, batch_size, and num_threads:

```json
{
    "query_args": [31, 180, 0.27, 0, 200, 320]
}
```