KuDNN的矩阵乘之对接KML_BLAS GEMM接口实现详解
发表于 2026/05/12
0
1. 概述
本文将围绕KuDNN的GEMM基本原理展开介绍,GEMM接口分别对接了KML BLAS的cblas_gemm接口和JIT code,本次主要介绍KML BLAS这一算法分支的原理和实现。
KuDNN源码可参考: https://gitcode.com/kunpengcompute/kudnn
2. 接口介绍
include/operations/kudnn_gemm.hpp中声明了Gemm接口类方法,主要是Gemm构造函数和Run执行计算两个成员函数,以及指向真正Gemm实现GemmImpl类的pImpl指针。
class KUDNN_API_PUBLIC Gemm final {
public:
Gemm(const TensorInfo &aInfo, const TensorInfo &bInfo, const TensorInfo &cInfo, const TensorInfo &biasInfo, int numThreads = 0) noexcept(false);
void Run(const void *a, const void *b, void *c, const void *bias, float alpha = 1.0f, float beta = 0.0f,
int numThreads = 0) const noexcept(false);
private:
std::unique_ptr<Detail::GemmImpl> pImpl;
};
kudnn的所有接口实现均采用Pointer to Implement机制,将类的实现细节从其对象表示中移除,通过不透明指针将它们放置在单独的类中。此技术用于构建具有稳定 ABI 的 C++ 库接口,并减少编译时依赖。因为类的私有数据成员参与其对象表示,影响大小和布局,并且因为类的私有成员函数参与重载决议(在成员访问检查之前进行),所以对这些实现细节的任何更改都需要重新编译类的所有用户。PImpl 消除了这种编译依赖;对实现的更改不会导致重新编译。因此,如果库在其 ABI 中使用 PImpl,则新版本的库可以更改实现,同时与旧版本保持 ABI 兼容。
Gemm的必要参数为A,B,C张量,计算过程为C=A*B+C。tensorInfo描述了张量信息,包括张量的维度大小(dims),数据类型,布局(layout)和步长(stride); 观察接口会发现相比BLAS的标准GEMM接口少量LDA、LDB和LDC参数,其实这些参数可以用stride表示实现。
3. 算法流程
3.1 实例化Gemm
从接口层的Gemm类只提供接口定义,其实现完全由Detail::GemmImpl来完成
// all public methods just call corresponding pImpl implemnentation
Gemm::Gemm(const TensorInfo &aInfo, const TensorInfo &bInfo, const TensorInfo &cInfo, const TensorInfo &biasInfo,
int numThreads) noexcept(false): pImpl(new Detail::GemmImpl(aInfo, bInfo, cInfo, biasInfo, numThreads)){}
Detail::GemmImpl类中有成员GemmInfo记录gemm参数信息,其类成员就是src,wei,dst,bias等参数; 同时将impl设置为nullptr
GemmImpl(const TensorInfo &srcInfo, const TensorInfo &weiInfo, const TensorInfo &dstInfo, const TensorInfo &biaInfo,
int numThreads) noexcept(false) : gemmImplInfo(srcInfo, weiInfo, dstInfo, biaInfo), impl(nullptr)
在GemmImpl构造函数中,①先校验输入参数的合法性;②如果是鲲鹏920,判断输入输出张量类型是否满足使用JIT的条件,如果满足会进一步调用FindSolution查找合适的解决方案, 并赋值给类成员impl指针;否则保持impl为nullptr,表示后续回退到BLAS计算gemm。
{
Service::ThrowOnStatus(Gemm::ValidateInput(srcInfo, weiInfo, dstInfo, biaInfo, numThreads), "GEMM");
bool srcJIT = weiJIT = dstJIT = false;
#ifdef KUDNNL_920Pro
srcJIT = srcJIT || (gemmImplInfo.srcInfo.GetType() == Element::TypeT::S8);
weiJIT = weiJIT || (gemmImplInfo.weiInfo.GetType() == Element::TypeT::S8);
...
dstJIT = dstJIT || (gemmImplInfo.dstInfo.GetType() == Element::TypeT::S32);
#endif // KUDNNL_920Pro
...
if (srcJIT && weiJIT && dstJIT) {
impl = FindSolution(gemmImplInfo, Threading::GetMaxNumThreads());
}
}
FinSolution函数会生成JITcode,具体分析将在另一篇《kudnn的JIT》文章中详细展开,这里不再赘述。至此,不用JIT的情景下的Gemm实例构造的过程结束,下面看Gemm的计算过程。
3.2. Gemm计算
Gemm计算的整体架构是
GemmImpl::Run (入口)
↓
BatchedExtendedGemm (回退BLAS GEMM算法)
↓
ChooseImpl (类型匹配与转换)
↓
GEMMCaller::Call (参数准备)
↓
GEMMWrapper (底层BLAS调用)
3.2.1. GemmImpl::Run
调用Gemm.Run接口输入需要计算的张量数据指针a,b,c实现计算。Gemm类的Run函数同样只是简单转调GemmImpl的Run函数。GemmImpl的Run实现如下,判断impl是否为nullptr,如果非空指针则跳转到JIT模块执行。否则调用BatchedExtendedGemm函数,采用传统BLAS接口计算。
void Run(const void *a, const void *b, void *c, const void *bias, float alpha, float beta, int numThreads){
if (impl != nullptr) {
if (impl->GetNThreads() != Threading::GetMaxNumThreads()) {
impl = FindSolution(impl->GetTask(), Threading::GetMaxNumThreads());
}
impl->Run({a, b, c, bias, alpha, beta});
} else {
GemmHelpers::BatchedExtendedGemm(gemmImplInfo, a, b, c, bias, alpha, beta, numThreads);
}
}
gemmImplInfo是中保存有GEMM的维度大小和布局等元数据信息,结合Run接口输入的计算数据可组合参数调用KBLAS完成GEMM计算。
3.2.2. BatchedExtendedGemm
BatchedExtendedGemm函数完成3件事,依次介绍
①布局标准化
判断输入张量的布局是否符合满足标准ABX布局,先通过ReorderLayer转换。下面介绍一下什么是ABX以及为什么要转换。
ABX布局是矩阵乘法中的一种标准内存布局约定,它规定了输入矩阵A、B和输出矩阵X在内存中的排列方式。从ABX的代码可知其采用行优先布局,维度从高到低依次排列。
Layout GetStandardABXLayout() const {
switch (dims.GetNumDims()) {
case DIM_1: {
return Layout::A;
}
...
case DIM_5: {
return Layout::ABCDE;
}
}
}
ABX布局转换的核心目的是:
- 消除布局差异,所有矩阵统一内存排列最大化BLAS性能,使用最优的NoTrans路径简化地址计算,在batch处理中避免stride语义混乱提升cache效率,保证连续内存访问
② 5D广播
auto &&broadcastedDimsA = Service::BroadcastTo5D(
Service::GetShapeAccordingToLayout(gemmInfoExec.srcInfo.GetDims(), gemmInfoExec.srcInfo.GetLayout()));
auto &&broadcastedDimsB = Service::BroadcastTo5D(
Service::GetShapeAccordingToLayout(gemmInfoExec.weiInfo.GetDims(), gemmInfoExec.weiInfo.GetLayout()));
Shape broadcastedDims = {std::max(broadcastedDimsA[IDX_0], broadcastedDimsB[IDX_0]),
std::max(broadcastedDimsA[IDX_1], broadcastedDimsB[IDX_1]),
std::max(broadcastedDimsA[IDX_2], broadcastedDimsB[IDX_2]), 1, 1};
auto &&reorderedStridesA = Service::BroadcastTo5D(Service::GetShapeAccordingToLayout(gemmInfoExec.srcInfo.GetStrides(), gemmInfoExec.srcInfo.GetLayout()), false);
auto &&reorderedStridesB = Service::BroadcastTo5D(Service::GetShapeAccordingToLayout(gemmInfoExec.weiInfo.GetStrides(), gemmInfoExec.weiInfo.GetLayout()), false);
...
③ 三重循环批处理
因为BLAS的gemm一次计算一个2D的矩阵运算,对于kudnn最高支持5维张量的情景,需要个三重循环依次遍历计算GEMM。
for (SizeType i0 = 0; i0 < broadcastedDims[IDX_0]; ++i0) {
for (SizeType i1 = 0; i1 < broadcastedDims[IDX_1]; ++i1) {
for (SizeType i2 = 0; i2 < broadcastedDims[IDX_2]; ++i2) {
const std::byte *aPtr = static_cast<const std::byte *>(aExec) +(i0 * reorderedStridesA[IDX_0] + i1 * reorderedStridesA[IDX_1] + i2 * reorderedStridesA[IDX_2]) * gemmInfoExec.srcInfo.GetType().GetSize();
const std::byte *bPtr = ...
std::byte *cPtr = ...
const std::byte *biasPtr = static_cast<const std::byte *>(biasExec);
if (biasExec) {
biasPtr += ...
}
ChooseImpl(gemmInfoExec, aPtr, bPtr, cPtr, biasPtr, alpha, beta, numThreads);
}
}
}3.2.3. ChooseImpl
实现了三级回退策略
第1级:精确类型匹配
↓ (失败)
第2级:扩展精度匹配
↓ (失败)
第3级:全FP32转换
所有的Gemm精度类型组合列表如下,由float,fp16,bf16,int8,uint8之间组合而成,每一项都是一个特化的GEMMCaller模板类,类中实现了对应数据类型的cblas_gemm函数调用,共有11种类型组合。
static std::vector<std::shared_ptr<GEMMCallerBase>> g_gemmImpls = {
std::make_shared<GEMMCaller<float, float, float, float>>(),
std::make_shared<GEMMCaller<__fp16, __fp16, __fp16, __fp16>>(),
std::make_shared<GEMMCaller<__fp16, __fp16, float, float>>(),
...
std::make_shared<GEMMCaller<std::int8_t, std::int8_t, std::int32_t, std::int32_t>>(),
...
};
第1级 - 精确匹配,如果输入的src、wei、dst和bias的类型与既定的g_gemmImpls里的数据类型完全匹配,就直接调用对应的GEMMCaller方法。
// check for exact match
for (auto &&gemmImpl : g_gemmImpls) {
bool isTypeMatch = (gemmImpl->GetSrcDt() == srcType) && (gemmImpl->GetWeiDt() == weiType) &&
(gemmImpl->GetDstDt() == dstType) && (gemmImpl->GetBiaDt() == biaType);
if (isTypeMatch) {
gemmImpl->Call(gemmInfo, src, wei, dst, bia, alpha, beta, numThreads);
return;
}
}
第2级 - 扩展精度匹配:
通过GetWiderType部分获取扩展的目标数据类型,然后选择dst为widerType类型的GEMMCaller作为目标函数,将其余参数扩展转换为目标函数的参数类型,然后调用Call函数继续执行。如果其余参数有比目标函数的参数更宽的数据类型,则继续回退到全FP32类型。
// try to convert to wider type and check if there's implementation with wider type
auto &&widerType = GetWiderType(srcType, weiType, dstType, biaType);
for (auto &&gemmImpl : g_gemmImpls) {
bool isWiderType = ((gemmImpl->GetDstDt() == widerType) && (dstType == widerType)) &&
((srcType.GetSize() <= gemmImpl->GetSrcDt().GetSize()) &&
(srcType.IsSigned() == gemmImpl->GetSrcDt().IsSigned())) &&
...
if (isWiderType) {
MatrixConverter s(srcType, gemmImpl->GetSrcDt(), src, m * k);
MatrixConverter w(weiType, gemmImpl->GetWeiDt(), wei, k * n);
MatrixConverter b(biaType, gemmImpl->GetBiaDt(), bia, biasM * biasN);
gemmImpl->Call(gemmInfo, s.Get(), w.Get(), dst, b.Get(), alpha, beta, numThreads);
return;
}
}
GetWiderType函数实现了类型扩展,如果存在BF16则返回BF16;如果输入全是FP或integer类型,就返回位数最宽的那个参数的类型;其他情况返回FP32类型,至于为什么是FP32是因为FP32是kudnn支持的最宽数据类型,对其他数据类型兼容性最好。
第3级 - 全FP32回退:
将所有输入类型全部转换为FP32类型,调用FP32的GEMMCaller计算,最后将计算输出dst转换为原来转换前的数据类型,作为最终输出。
// some integer combinations can't be supported when all the types are converted to wider type
MatrixConverter s(srcType, Element::TypeT::F32, src, m * k);
...
GEMMCaller<float, float, float, float> {}.Call(gemmInfo, s.Get(), w.Get(), d.Get(), b.Get(), alpha, beta, numThreads);
if (dstType != Element::TypeT::F32) {
switch (dstType) {
case Element::TypeT::F16: {
Service::ConvertFp32ToFp16(static_cast<const float *>(d.Get()), static_cast<__fp16 *>(dst), m * n);
break;
}
...
}
}
3.2.4. GEMMCaller::Call
GEMMCaller负责从gemmInfo中获取矩阵乘所需的M、N、K、lda、ldb、ldc、transa、transb等参数,结合src、wei、dst、bias和offsetC得到cblas_gemm所需的全部参数;再通过调用BlasSetNumThreadsLocal设置BLAS线程数,最后调用GEMMWrapper完成计算
3.2.5. GEMMWrapper
GEMMWrapper特化模板函数调用具体的cblas_?gemm接口完成计算。
template <typename SrcDt, typename WeiDt, typename DstDt, typename BiaDt>
void GEMMWrapper(...)
以void GEMMWrapper<float, float, float, float>为例,将矩阵乘法分为三种情况优化:
① n=1: GEMV(矩阵-向量乘),选取合适的布局后调用 cblas_sgemv
② m=1: GEMV,向量-矩阵乘
③ 其他: GEMM,直接调用 cblas_sgemm
最后通过 AddBias 添加偏置项


