Usage Example

The following example uses the sift-128-euclidean.hdf5 dataset, which can be downloaded as follows:

wget http://ann-benchmarks.com/sift-128-euclidean.hdf5 --no-check-certificate
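
To sanity-check the download, here is a minimal sketch (assuming h5py is installed) that prints the layout of the file; the train, test, and neighbors arrays it lists are the ones load_dataset() in main.py below reads. For sift-128-euclidean.hdf5 this should report 1,000,000 train vectors and 10,000 query vectors of dimension 128:

import h5py

with h5py.File("datasets/sift-128-euclidean.hdf5", "r") as f:
    for name in ("train", "test", "neighbors"):                        # arrays read by load_dataset() below
        print(name, f[name].shape, f[name].dtype)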

Assuming the program runs in the directory "/path/to/kscann_test", the complete directory layout should look like this:

├── datasets                                                // directory holding the dataset
│   └── sift-128-euclidean.hdf5
├── main.py                                                 // script containing the run functions
└── sift.json                                               // configuration file for the dataset

The steps to run it are as follows:

  1. In the run directory "/path/to/kscann_test", check that datasets/sift-128-euclidean.hdf5, main.py, and sift.json all exist. main.py and sift.json are provided below.
  2. Make sure the sixth element of "query_args" in sift.json equals the number of CPU cores actually used at run time multiplied by 4 (a sketch that automates this follows these steps).
  3. Install the numactl dependency.
    yum install numactl
    
  4. Install the Python dependencies used by main.py (the scann module itself is assumed to come with the KScaNN installation).
    pip install numpy h5py psutil
    
  5. Run main.py. It must be bound to a NUMA node.
    numactl -N 0 -m 0 python main.py sift.json
    

    When the run completes, main.py prints the measured throughput as a line of the form "QPS: <value>".
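
Step 2 above requires the sixth "query_args" value to be 4 times the CPU core count. Below is a minimal sketch that rewrites sift.json accordingly (one assumption: it uses os.cpu_count(), which counts every core visible to the OS; when binding with numactl as in step 5, use the core count of the bound NUMA node instead):

import json
import os

with open("sift.json") as f:
    cfg = json.load(f)

cfg["query_args"][5] = os.cpu_count() * 4                              # sixth parameter: CPU cores * 4
with open("sift.json", "w") as f:
    json.dump(cfg, f, indent=4)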

The content of main.py is as follows:
import sys
import json
import time
import threading
from multiprocessing.pool import ThreadPool

import numpy as np
import h5py
import psutil
import scann

from typing import Any, Dict, Optional


class BaseANN(object):                                                                                                   # BaseANN base class
    def done(self) -> None:
        pass

    def get_memory_usage(self) -> Optional[float]:
        return psutil.Process().memory_info().rss / 1024

    def fit(self, X: np.array) -> None:
        pass

    def query(self, q: np.array, n: int) -> np.array:
        return []  # array of candidate indices

    def batch_query(self, X: np.array, n: int) -> None:
        pool = ThreadPool()
        self.res = pool.map(lambda q: self.query(q, n), X)

    def get_batch_results(self) -> np.array:
        return self.res

    def get_additional(self) -> Dict[str, Any]:
        return {}

    def __str__(self) -> str:
        return self.name


class Scann(BaseANN):                                                                                                    # ScaNN subclass
    def __init__(self, n_leaves, avq_threshold, dims_per_block, dist):
        self.name = "scann n_leaves={} avq_threshold={:.02f} dims_per_block={}".format(
            n_leaves, avq_threshold, dims_per_block
        )
        self.n_leaves = n_leaves
        self.avq_threshold = avq_threshold
        self.dims_per_block = dims_per_block
        self.dist = dist

    def fit(self, X):
        if self.dist == "dot_product":
            spherical = True
            X[np.linalg.norm(X, axis=1) == 0] = 1.0 / np.sqrt(X.shape[1])
            X /= np.linalg.norm(X, axis=1)[:, np.newaxis]
        else:
            spherical = False

        self.searcher = (                                                                                                 # ScannBuilder receives the build parameters
            scann.scann_ops_pybind.builder(X, 10, self.dist)
            .tree(self.n_leaves, 1, training_sample_size=len(X), spherical=spherical, quantize_centroids=True)            # parameters for the IVF inverted-index partitioning
            .score_ah(self.dims_per_block, anisotropic_quantization_threshold=self.avq_threshold)                         # parameters for PQ quantization
            .reorder(1)                                                                                                   # parameters for reordering
            .build()                                                                                                      # build the index
        )

    def set_query_arguments(self, query_args):
        if len(query_args) == 5:
            self.leaves_to_search, self.reorder, self.thd, self.refined, self.batch_size = query_args
            self.num_threads = 320                                                                                        # 320 is treated as "use the default" in batch_query
        elif len(query_args) == 6:
            self.leaves_to_search, self.reorder, self.thd, self.refined, self.batch_size, self.num_threads = query_args
        else:
            self.leaves_to_search, self.reorder, self.thd, self.refined = query_args
            self.num_threads = 320

    def query(self, v, n):                                                                                                # single query
        return self.searcher.search(v, n, self.reorder, self.leaves_to_search)[0]

    def batch_query(self, v, n):                                                                                          # batch query
        if self.dist == "dot_product":
            v[np.linalg.norm(v, axis=1) == 0] = 1.0 / np.sqrt(v.shape[1])
            v /= np.linalg.norm(v, axis=1)[:, np.newaxis]
        self.searcher.search_additional_params(self.thd, self.refined, self.leaves_to_search)                             # configure extra search parameters
        if (self.num_threads != 320) and (self.num_threads >= 1):
            self.searcher.set_num_threads(self.num_threads - 1)                                                           # number of threads used during the search
        self.res = self.searcher.search_batched_parallel(v, n, self.reorder, self.leaves_to_search, self.batch_size)      # multi-threaded parallel batch search

def load_dataset(file_path):                                                                                              # load the dataset
    with h5py.File(file_path, 'r') as f:
        train = np.array(f['train'])
        test = np.array(f['test'])
        neighbors = np.array(f['neighbors'])
    return train, test, neighbors

def read_query_args(json_file):                                                                                           # read the configuration file
    with open(json_file, 'r') as file:
        json_info = json.load(file)
    return json_info.get('query_args', [])

def calculate_qps(start_time, num_queries):                                                                               # compute QPS
    elapsed_time = time.time() - start_time
    return num_queries / elapsed_time

def run_benchmark():
    if len(sys.argv) < 2:
        print("Usage: python main.py <config_file>")
        sys.exit(1)
    json_file = sys.argv[1]
    query_args = read_query_args(json_file)

    train_data, query_data, ground_truth = load_dataset('datasets/sift-128-euclidean.hdf5')
    scann_model = Scann(n_leaves=100, avq_threshold=0.2, dims_per_block=2, dist="dot_product")
    scann_model.fit(train_data)
    scann_model.set_query_arguments(query_args)

    start_time = time.time()
    scann_model.batch_query(query_data, 10)
    results = scann_model.get_batch_results()

    qps = calculate_qps(start_time, len(query_data))

    print(f"QPS: {qps:.2f}")

if __name__ == "__main__":
    run_benchmark()
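
main.py loads the ground-truth neighbors array but only reports QPS. Below is a minimal recall@10 sketch, assuming search_batched_parallel follows upstream ScaNN's convention of returning an (indices, distances) pair; note that the script builds the index with dist="dot_product", so recall measured against the euclidean ground truth is only indicative:

def recall_at_k(results, ground_truth, k=10):
    indices = results[0] if isinstance(results, tuple) else results    # keep only the neighbor indices
    hits = sum(len(set(idx[:k]) & set(gt[:k])) for idx, gt in zip(indices, ground_truth))
    return hits / (len(ground_truth) * k)

# e.g. at the end of run_benchmark(): print(f"Recall@10: {recall_at_k(results, ground_truth):.4f}")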
The content of sift.json is as follows:
{
    "query_args": [31, 180, 0.27, 0, 200, 320]
}
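
For reference, the six "query_args" values map positionally to the attributes unpacked in set_query_arguments() above:

# "query_args": [31, 180, 0.27, 0, 200, 320]
leaves_to_search = 31    # IVF leaves probed per query
reorder = 180            # reordering depth passed to the search calls
thd = 0.27               # threshold forwarded via search_additional_params
refined = 0              # refinement flag forwarded via search_additional_params
batch_size = 200         # batch size for search_batched_parallel
num_threads = 320        # search threads; 320 means "use the default" in batch_query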