鲲鹏社区首页
中文
注册
开发者
我要评分
文档获取效率
文档正确性
内容完整性
文档易理解
在线提单
论坛求助

添加额外参数

为充分利用多核CPU的性能优势,可以为测试命令添加额外的参数“--threads”,用于手动设置线程数。

实际修改过程请参照batch_query的调用链修改,确保“/data/ann-benchmarks-main/ann_benchmarks/main.py”中的main()方法能够正确传参即可。

  1. 进入module.py文件中查看源码。
    1
    vim /data/ann-benchmarks-main/ann_benchmarks/algorithms/base/module.py
    

    源码batch_query方法中,启动的线程数量默认与CPU数相同。通过增加外部传参,即可手动设置线程数。

  2. 进入main.py文件中修改代码。
    1
    vim /data/ann-benchmarks-main/ann_benchmarks/main.py
    
    • 第125行添加“--threads”参数。
      1
      parser.add_argument("--threads", type=int, help="Number of threads. If not set, the default value will be used, and the number of threads enabled is equal to the number of CPU.", default=-1)
      

    • 修改第71行run方法调用传参。
      1
      run(definition, args.dataset, args.count, args.runs, args.batch, args.threads)
      

  3. 进入runner.py文件中修改代码。
    1
    vim /data/ann-benchmarks-main/ann_benchmarks/runner.py
    
    • 修改第197行run方法参数列表。
      1
      def run(definition: Definition, dataset_name: str, count: int, run_count: int, batch: bool, threads: int) -> None:
      

    • 修改第230行run_individual_query方法调用传参。
      1
      descriptor, results = run_individual_query(algo, X_train, X_test, distance, count, run_count, batch, threads)
      

    • 修改第23行run_individual_query方法参数列表。
      1
      2
      def run_individual_query(algo: BaseANN, X_train: numpy.array, X_test: numpy.array, distance: str, count: int,
                                   run_count: int, batch: bool, threads: int) -> Tuple[dict, list]:
      

    • 修改第86行batch_query方法调用传参。
      1
      def batch_query(X: numpy.array, threads: int) -> List[Tuple[float, List[Tuple[int, float]]]]:
      

    • 修改第105行batch_query方法调用传参。
      1
      algo.batch_query(X, count, threads)
      

    • 修改第124行batch_query方法调用传参。
      1
      results = batch_query(X_test, threads)
      

  4. 进入module.py文件中修改代码。
    1
    vim /data/ann-benchmarks-main/ann_benchmarks/algorithms/base/module.py
    
    • 修改batch_query方法参数列表。
      1
      def batch_query(self, X: numpy.array, n: int, threads: int) -> None:
      
    • 修改batch_query方法内部实现。
      1
      2
      3
      4
      if threads <= 0:
         pool = ThreadPool()
      else:
         pool = ThreadPool(threads)