开发者
我要评分
获取效率
正确性
完整性
易理解
在线提单
论坛求助

对接KuDNN算子库

HPCKit开发套件中主要用于AI框架的加速算子库KuDNN也提供了已使能矩阵算力的接口,如矩阵乘法计算的KuDNN::Gemm接口,该接口能够通过JIT动态代码技术根据输入规格生成相应的矩阵算力kernel,可以达到更好的性能。KuDNN::Gemm支持INT8、BF16、FP16和FP32四种类型。更多接口说明及使用指导可以参考鲲鹏社区的《KuDNN开发指南》

以使用OpenBLAS的单精度浮点数矩阵乘接口sgemm为例,说明如何进行替换

替换前源文件OpenBLAS_sgemm.c

#include <cblas.h>  
void matrix_multiply(float *A, float *B, float *C, int m, int n, int k) {
    cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasNoTrans, m, n, k, 1.0, A, k, B, n, 0.0, C, n); 
}

替换后源文件KuDNN_sgemm.cpp

#include "kudnn.hpp"
void matrix_multiply(float *A, float *B, float *C, int m, int n, int k)
{
    // layout为AB
    using SizeType = KuDNN::SizeType;
    using Shape = KuDNN::Shape;
    using Type KuDNN::Element::TypeT;
    Shape srcShape(m, k);
    Shape weiShape(k, n);
    Shape dstShape(m, n);
    // Tensor初始化
    const KuDNN::TensorInfo srcTensor = {srcShape, Type::F32, KuDNN::Layout::AB};
    const KuDNN::TensorInfo weiTensor = {weiShape, Type::F32, KuDNN::Layout::AB};
    const KuDNN::TensorInfo dstTensor = {dstShape, Type::F32, KuDNN::Layout::AB};
    int numThreads = 1;
    // 构造算子
    KuDNN::Gemm gemmLayer(srcTensor, weiTensor, dstTensor, numThreads);
    // 执行算子
    gemmLayer.Run(A, B, C, 1.0, 0.0, numThreads);
}

替换后编译指令:

clang++ -O3 -o KuDNN_sgemm KuDNN_sgemm.cpp -lkudnn

除此之外,KuDNN算子库也可以通过插件的形式集成到PyTorch等AI框架中,具体方式可参考鲲鹏社区的《KuDNN适配PyTorch接口说明》