鲲鹏计算平台 KuDNN 适配 oneDNN 实践
发表于 2026/05/19
0
1. 概述
KuDNN是专门为鲲鹏处理器定制的高性能深度学习算子库,使用KML BLAS高性能库和JIT动态代码生成kernel作为底层计算方法,优化linear、softmax和normal等AI算子性能。KuDNN可以通过算子库插件形式集成进开源软件oneDNN框架,也可以单独作为后端加速算子集成到TensorFlow或者Pytorch。
本文介绍如何将 KuDNN 集成到oneDNN-3.4 中。通过这种适配,上层框架(如 PyTorch、TensorFlow)可以通过标准的 oneDNN 接口无缝调用鲲鹏底层优化算子。
2. 构建系统适配(CMake)
适配的第一步是让 oneDNN 的构建系统能够识别并链接 kuDNN。
2.1. 引入KuDNN
在 oneDNN-3.4/cmake/KuDNN.cmake 中,通过 find_package 寻找 KuDNN 及其依赖(如 KBLAS)。
if(kdnn_cmake_included)
return()
endif()
set(kdnn_cmake_included true)
include("cmake/options.cmake")
if(NOT DNNL_TARGET_ARCH STREQUAL "AARCH64")
return()
endif()
if(NOT DNNL_AARCH64_USE_KUDNN)
return()
endif()
find_package(KuDNN REQUIRED)
if(KUDNN_FOUND)
list(APPEND EXTRA_SHARED_LIBS ${KUDNN_LIBRARIES})
include_directories(${KUDNN_INCLUDE_DIRS})
message(STATUS "KPL Library: ${KUDNN_LIBRARIES}")
message(STATUS "KPL Library headers: ${KUDNN_INCLUDE_DIRS}")
add_definitions(-DDNNL_AARCH64_USE_KUDNN)
set(CMAKE_CXX_STANDARD 14)
set(CMAKE_CXX_EXTENSIONS "OFF")
endif()- 关键宏定义: 定义 -DDNNL_AARCH64_USE_KUDNN 供 C++ 源码区分逻辑。
- 依赖管理: 将 KUDNN_LIBRARIES 添加到 EXTRA_SHARED_LIBS 中。
2.2. 自动发现库文件
创建 oneDNN-3.4/cmake/FindKuDNN.cmake,通过环境变量 KUDNN_ROOT_DIR 自动检索头文件和库文件路径,确保环境适配的灵活性。
find_path(KUDNN_INCLUDE_DIR
NAMES kudnn.hpp
PATHS ENV KUDNN_ROOT_DIR
PATH_SUFFIXES include
NO_DEFAULT_PATH
)
find_path(KPL_BLAS_INCLUDE_DIR
NAMES kblas.h
PATHS ENV BLAS_ROOT_DIR
PATH_SUFFIXES include
NO_DEFAULT_PATH
)
find_library(KUDNN_LIBRARY
NAMES kudnn
PATHS ENV KUDNN_ROOT_DIR
PATH_SUFFIXES lib
NO_DEFAULT_PATH
)
find_library(KPL_BLAS_LIBRARY
NAMES kblas
PATHS ENV BLAS_ROOT_DIR
PATH_SUFFIXES lib
NO_DEFAULT_PATH
)
include(FindPackageHandleStandardArgs)
find_package_handle_standard_args(KuDNN DEFAULT_MSG
KUDNN_INCLUDE_DIR
KPL_BLAS_INCLUDE_DIR
KUDNN_LIBRARY
KPL_BLAS_LIBRARY
)
mark_as_advanced(
KUDNN_LIBRARY
KPL_BLAS_LIBRARY
KPL_FFT_LIBRARY
KPL_FFTF_LIBRARY
KPL_FFTH_LIBRARY
KUDNN_INCLUDE_DIR
KPL_BLAS_INCLUDE_DIR
KPL_FFT_INCLUDE_DIR
)
# Find the extra libraries and include dirs
if(KUDNN_FOUND)
list(APPEND KUDNN_INCLUDE_DIRS
${KUDNN_INCLUDE_DIR} ${KPL_BLAS_INCLUDE_DIR} ${KPL_FFT_INCLUDE_DIR})
list(APPEND KUDNN_LIBRARIES
${KUDNN_LIBRARY} ${KPL_BLAS_LIBRARY} ${KPL_FFT_LIBRARY} ${KPL_FFTF_LIBRARY} ${KPL_FFTH_LIBRARY})
endif()2.3. 修改主控脚本
在 oneDNN-3.4/CMakeList.txt 中,将原有的 ACL 包含逻辑替换为:
include("cmake/KuDNN.cmake") (ACL->KuDNN)3. 源码组织架构
在 src/cpu/aarch64/ 下新建 kudnn 目录,用于存放适配层代码。
- kudnn_matmul.hpp/kudnn_convolution.hpp...:具体的算子适配实现。
- kudnn_utils.hpp/cpp:负责 oneDNN 内存描述符(Memory Descriptor)与 kuDNN 参数格式的转换。
- kudnn/jit/:存放基于 SVE/NEON 指令生成的 JIT 内核代码。
oneDNN-3.4/src/cpu/aarch64/kudnn/
├── [Top Level] 算子原始定义 (Primitives)
│ ├── kudnn_matmul.hpp, kudnn_convolution.hpp, kudnn_eltwise.hpp...
│ └── kudnn_post_ops.hpp/cpp (算子融合逻辑)
│
├── [Utils] 数据与协议转换层
│ ├── kudnn_utils.hpp (总入口)
│ ├── kudnn_utils_gemm.cpp, kudnn_utils_conv.cpp (MatMul/Conv 映射)
│ ├── kudnn_utils_act.cpp, kudnn_utils_softmax.cpp (激活/归一化映射)
│ └── kudnn_utils_common.cpp, kudnn_utils_thread.cpp (通用与线程管理)
│
└── [JIT] 即时编译加速引擎
├── utils/
│ ├── kudnn_jit_generator.hpp (AArch64 汇编生成基类)
│ └── kudnn_jit_primitive_conf.hpp (JIT 参数配置)
├── eltwise/
│ ├── kudnn_jit_uni_eltwise.cpp (通用 Eltwise JIT 实现)
│ └── kudnn_jit_uni_eltwise_injector.cpp (汇编级算子注入器)
└── convolution/
├── kudnn_jit_sve_convolution.cpp (针对 SVE 指令集的卷积实现)
├── kudnn_jit_f32_convolution.cpp (FP32 卷积逻辑)
└── kudnn_jit_f32_conv_kernel.cpp (卷积核心计算 Kernel)4. 算子注册机制
oneDNN 通过一个“算子实现列表(Implementation List)”来管理所有可选方案。我们需要将 KuDNN 的实现排在列表前列,以确保高优先级。
以 MatMul 为例,修改 src/cpu/matmul/cpu_matmul_list.cpp:
- 引入头文件:使用 DNNL_AARCH64_USE_KUDNN 宏隔离。
- 添加实例:使用 CPU_INSTANCE_AARCH64_KUDNN(kdnn_matmul_t) 宏将 KuDNN 实现插入列表。宏将KuDNN 实现插入列表。
#if DNNL_X64
#include "cpu/x64/matmul/brgemm_matmul.hpp"
#include "cpu/x64/matmul/jit_uni_sparse_matmul.hpp"
using namespace dnnl::impl::cpu::x64::matmul;
using namespace dnnl::impl::cpu::x64;
#elif DNNL_AARCH64
#ifdef DNNL_AARCH64_USE_ACL
#include "cpu/aarch64/matmul/acl_matmul.hpp"
#endif
#if DNNL_AARCH64_USE_KUDNN
#include "cpu/aarch64/kudnn/kudnn_matmul.hpp"
#endif // DNNL_AARCH64_USE_KUDNN
...
constexpr impl_list_item_t impl_list[] = REG_MATMUL_P({
CPU_INSTANCE_AARCH64_KUDNN(kdnn_matmul_t)
CPU_INSTANCE_AARCH64_ACL(acl_matmul_t)
...
nullptr,
});提示: 列表中的位置决定了搜索优先级。将 KuDNN 放在 ACL 之前,可以确保在鲲鹏平台上优先触发最强性能路径。
5. 核心适配逻辑实现:以 MatMul 为例
适配代码的核心在于实现 primitive_t 和 pd_t(Primitive Descriptor)。
5.1 描述符初始化 (pd_t::init)
在 init 函数中,需要进行严格的兼容性检查:
- 检查数据类型(如是否支持 BF16/FP32)。
- 检查属性(Attributes)和后置操作(Post-ops)。
- 调用 kdnn_utils::convert_to_kdnn_gemm 将 oneDNN 描述符映射为 KuDNN 算子对象。
status_t init(engine_t *engine) {
bool ok = !has_zero_dim_memory() &&
attr()->has_default_values(dnnl_primitive_attr::skip_mask_t::post_ops |
dnnl_primitive_attr::skip_mask_t::scales_runtime) &&
attr_scales_ok() &&
(attr()->output_scales_.mask_ == 0) &&
(attr()->scales_.get(DNNL_ARG_WEIGHTS).mask_ == 0) &&
!has_runtime_dims_or_strides();
if (!ok) return status::unimplemented;
if (weights_md_.format_kind == dnnl_format_kind_t::dnnl_format_kind_any) {
// Default to 'N T' form to reduce reorder calls in the model.
const auto wei_tag = pick(weights_md_.ndims - 2, ba, acb, abdc, abced);
memory_desc_init_by_tag(weights_md_, wei_tag);
}
VDISPATCH_MATMUL(set_default_formats(), VERBOSE_UNSUPPORTED_TAG);
const memory_desc_wrapper src_d(&src_md_);
const memory_desc_wrapper wei_d(&weights_md_);
data_type_t dst_dt = dst_md_.data_type;
if (attr_.post_ops_.len() > 1) {
return status::unimplemented; // Currently, only one post-ops is supported.
}
bool has_post_ops = (attr_.post_ops_.len() >= 1);
if (has_post_ops) {
dst_md_.data_type = data_type_t::dnnl_f32; // Temporarily change the output data type to f32
}
const memory_desc_wrapper dst_d(dst_md());
const memory_desc_wrapper bias_d(&bias_md_);
auto&& matmul = kdnn_utils::convert_to_kdnn_gemm(src_d, wei_d, dst_d, bias_d);
if (!matmul.first) {
return status::unimplemented;
} else {
kdnn_matmul_prim_.reset(matmul.second);
if (has_post_ops) {
dst_md_.data_type = dst_dt;
}
CHECK(post_ops_.init(engine, attr_.post_ops_, dst_md_));
if (has_post_ops) {
need_tmp_dst_ = true;
dst_size_ = dst_d.nelems() * types::data_type_size(data_type_t::dnnl_f32);
} else {
need_tmp_dst_ = false;
dst_size_ = 0;
}
auto scratchpad = scratchpad_registry().registrar();
book_precomputed_scales(scratchpad, attr()->scales_, N());
return status::success;
}5.2 资源管理 (resource_t)
由于 KuDNN 算子对象(如 KuDNN::Gemm)是非线程安全的或需要管理生命周期,我们使用 kdnn_matmul_resource_t 对其进行封装,并利用 oneDNN 的 resource_mapper 进行管理。
struct kdnn_matmul_resource_t : public resource_t {
kdnn_matmul_resource_t(const std::unique_ptr<KuDNN::Gemm> &kdnn_matmul_prim) noexcept
: kdnn_matmul_obj_(new KuDNN::Gemm{*(kdnn_matmul_prim.get())}) {}
KuDNN::Gemm &get_kdnn_obj() const noexcept { return *kdnn_matmul_obj_; }
DNNL_DISALLOW_COPY_AND_ASSIGN(kdnn_matmul_resource_t);
private:
std::unique_ptr<KuDNN::Gemm> kdnn_matmul_obj_;
}; // kdnn_matmul_resource_t5.3 执行运算(execute_forward)
在执行阶段,核心流程如下:
- 资源准备:由于 resource_mapper 和底层的 kdnn_obj 在执行期间可能存在非线程安全的操作,需要先锁定当前算子实例,同时动态获取与当前执行上下文绑定的 KuDNN 算子对象。std::lock_guard<std::mutex> _lock {this->mtx}; KuDNN::Gemm &kdnn_obj = (ctx.get_resource_mapper()->get<kdnn_matmul_resource_t>(this))->get_kdnn_obj();
- 处理 Scale:通过 precompute_scales 处理量化参数。const float *scales; try { DEFINE_ARG_SCALES_BUFFER(src_scales, DNNL_ARG_SRC); DEFINE_ARG_SCALES_BUFFER(wei_scales, DNNL_ARG_WEIGHTS); auto scratchpad = ctx.get_scratchpad_grantor(); const int ndims = pd()->ndims(); const auto dst_d = ctx.memory_mdw(DNNL_ARG_DST, pd()->dst_md()); scales = precompute_scales(scratchpad, src_scales, wei_scales, dst_d.dims()[ndims - 1], pd()->attr()); } catch (const std::exception &e) { return status::runtime_error; }
- 执行运算:
- 路径 A:带 Post-ops 的混合执行
当 pd()->need_tmp_dst_ 为真时(例如存在 ReLU、Add 或降精度需求): - 路径 B:没有 Post-ops 时直接执行
kdnn_obj.Run(src, wei, dst, ...),数据直接写入最终输出内存。
if (pd()->need_tmp_dst_) {
std::unique_ptr<float, KuDNN::Service::Deallocator<float>> tmp_buffer(
static_cast<float*>(KuDNN::Service::AlignedAlloc(pd()->dst_size_)));
try {
kdnn_obj.Run(src, wei, tmp_buffer.get(), bias, scales[0], 0.0f);
} catch (const std::exception &e) {
return status::runtime_error;
}
pd()->post_ops_.execute(ctx, tmp_buffer.get(), dst);
} else {
try {
kdnn_obj.Run(src, wei, dst, bias, scales[0], 0.0f);
} catch (const std::exception &e) {
return status::runtime_error;
}
}

