Eigen的tensorContract适配KBLAS
发表于 2025/11/27
0
简介
Eigen自身虽然提供了BLAS支持,但是没有用于tensorcontract接口,TensorFlow作为主流的一款AI框架,其中Matmul算子在ARM平台上大量用到了Eigen的tensorContract接口,在鲲鹏处理器平台上,实现Eigen的tensorContract接口调用KBLAS将能显著提升TensorFlow性能。
适配方法
分析tenorflow和Eigen代码,发现tensorflow中有个XLA模块,它为Eigen接口做了其他的实现,将tensor.contract接口适配到MKLDNN的dnnl_sgemm接口,适用于X86平台。适配代码位于org_tensorflow/third_party/xla/xla/tsl/framework/contraction/eigen_contraction_kernel.h
1. TensorFlow
我们将基于TensorFlow的XLA框架里Eigen::tensorContract适配oneDNN的代码进行修改,路径位于tensorflow/third_party/xla/xla/tsl/framework/contraction/eigen_contraction_kernel.h。
首先从eigen_contraction_kernel.h中删除dnnl相关的接口调用和数据类型,避免在鲲鹏平台编译出未定义的接口错误。代码分别位于dnnl_gemm_kernel和mkldnn_gemm_s8u8s32_kernel的operator()中,图中注释即部分为要删除的代码。


XLA的有两个宏分别是TENSORFLOW_USE_MKLDNN_CONTRACTION_KERNEL和TENSORFLOW_USE_CUSTOM_CONTRACTION_KERNEL,编译时定义这两宏可以让Eigen的contractkernel替换为MLKDNN。如果只想适配tensor.contract接口,无需让TENSORFLOW_USE_MKLDNN_CONTRACTION_KERNEL全局生效,可以在eigen_contraction_kernel.h里将其替换为TENSORFLOW_USE_MKLDNN_CONTRACTION_KERNEL_KML

然后在#include “dnnl.h”的位置加入#include "kblas.h",并将dnnl.h注释掉

修改结构体dnnl_gemm_kernel的operator()的定义实现,注释掉原来的dnnl_sgemm调用代码,新增KBlas的cblas_sgemm代码

cblas_sgemm的默认为colmajor,与dnnl_sgemm的默认rowmajor相反。另一方面XLA已经按照colmajor对数据做了重排,所以cblas_sgemm的参数与dnnl_sgemm相反。
2. Eigen
如果打算不经过TensorFlow使用tensorContract接口,那么可以将上面修改后的eigen_contraction_kernel.h头文件include引入到应用程序的头文件列表里。eigen_contraction_kernel.h需要放在<unsupported\Eigen\CXX11\Tensor>之后,因为<unsupported\Eigen\CXX11\Tensor>里引用了tensorContract的Eigen原生实现,让我们优化的tensorContract覆盖Eigen的原生定义。
编译运行
编译时需要指定宏TENSORFLOW_USE_CUSTOM_CONTRACTION_KERNEL和TENSORFLOW_USE_MKLDNN_CONTRACTION_KERNEL_KML,并且编译链接参数添加kblas.h头文件路径 和 kblas.so库路径。


