鲲鹏社区首页
中文
注册
我要评分
文档获取效率
文档正确性
内容完整性
文档易理解
在线提单
论坛求助

运行和验证

ScaNN编译安装完成后获取到测试数据集进行ScaNN的运行和验证。

  1. 进入ScaNN功能验证规划路径。
    1
    cd /path/to/scann_test
    
  2. 下载测试数据集。
    1
    wget http://ann-benchmarks.com/glove-100-angular.hdf5 --no-check-certificate
    
  3. 执行以下命令创建并编写ScaNN测试脚本scann_test.py。
    1. 创建“scann_test.py”文件。
      1
      vi scann_test.py
      
    2. 按“i”进入编辑模式,编写“scann_test.py”文件,添加如下部分。
      import numpy as np
      import h5py
      import time
      
      import scann
      
      def compute_recall(neighbors, true_neighbors):
          total = 0
          for gt_row, row in zip(true_neighbors, neighbors):
              total += np.intersect1d(gt_row, row).shape[0]
          return total / true_neighbors.size
      
      def main():
          print("Load dataset: glove-100-angular.hdf5")
          glove_h5py = h5py.File("glove-100-angular.hdf5", "r")
      
          print("Dataset keys:", list(glove_h5py.keys()))
          dataset = glove_h5py['train']
          queries = glove_h5py['test']
          print("Train size: ", dataset.shape)
          print("Queries size:", queries.shape)
      
          print("\nCreate ScaNN searcher")
          start = time.time()
          normalized_dataset = dataset / np.linalg.norm(dataset, axis=1)[:, np.newaxis]
          searcher = scann.scann_ops_pybind.builder(normalized_dataset, 10, "dot_product").tree(
              num_leaves=2000, num_leaves_to_search=100, training_sample_size=250000).score_ah(
              2, anisotropic_quantization_threshold=0.2).reorder(100).build()
          end = time.time()
          print("Time (s):", end - start)
      
          print("\n1.Batched-query: queries")
          start = time.time()
          neighbors, distances = searcher.search_batched(queries)
          end = time.time()
          print("Recall:", compute_recall(neighbors, glove_h5py['neighbors'][:, :10]))
          print("Time (s):", end - start)
      
          print("\n2.Single-query: queries[0]")
          start = time.time()
          neighbors, distances = searcher.search(queries[0], final_num_neighbors=5)
          end = time.time()
      
          print("neighbors:", neighbors)
          print("distances:", distances)
          print("Time (ms):", 1000*(end - start))
      
      if __name__ == "__main__":
          main()
      
    3. 按“Esc”键,输入:wq!,按“Enter”保存并退出编辑。
  4. 运行测试。
    1
    python3 scann_test.py
    

    回显信息显示,测试程序先加载数据集glove-100-angular(该数据集为100维,训练集规模约100万条,查询数据集为10000条),之后创建一个ScaNN的searcher搜索器,最后采用2种模式查询数据,如下:

    1. Batched模式:批量查询全部查询数据集,在该参数配置下,召回率Recall为0.89965。
    2. Single模式:查询数据集中索引为0的数据,结果返回与其近似最近的5个邻居的位置(neighbors)和对应距离信息(distances)。

    若测试程序运行无报错,Batched模式下的召回率与上图回显信息相近,Single模式查询的数据信息与上图回显信息一致,则代表ScaNN功能正常。