开发者
KuDNN的矩阵乘之对接KML_BLAS GEMM接口实现详解

KuDNN的矩阵乘之对接KML_BLAS GEMM接口实现详解

HPC

发表于 2026/05/12

0

1. 概述

本文将围绕KuDNN的GEMM基本原理展开介绍,GEMM接口分别对接了KML BLAS的cblas_gemm接口和JIT code,本次主要介绍KML BLAS这一算法分支的原理和实现。

KuDNN源码可参考: https://gitcode.com/kunpengcompute/kudnn

2. 接口介绍

include/operations/kudnn_gemm.hpp中声明了Gemm接口类方法,主要是Gemm构造函数和Run执行计算两个成员函数,以及指向真正Gemm实现GemmImpl类的pImpl指针。

class KUDNN_API_PUBLIC Gemm final {
public:
    Gemm(const TensorInfo &aInfo, const TensorInfo &bInfo, const TensorInfo &cInfo, const TensorInfo &biasInfo, int numThreads = 0) noexcept(false);
    void Run(const void *a, const void *b, void *c, const void *bias, float alpha = 1.0f, float beta = 0.0f,
             int numThreads = 0) const noexcept(false);
private:
    std::unique_ptr<Detail::GemmImpl> pImpl;
};

kudnn的所有接口实现均采用Pointer to Implement机制,将类的实现细节从其对象表示中移除,通过不透明指针将它们放置在单独的类中。此技术用于构建具有稳定 ABI 的 C++ 库接口,并减少编译时依赖。因为类的私有数据成员参与其对象表示,影响大小和布局,并且因为类的私有成员函数参与重载决议(在成员访问检查之前进行),所以对这些实现细节的任何更改都需要重新编译类的所有用户。PImpl 消除了这种编译依赖;对实现的更改不会导致重新编译。因此,如果库在其 ABI 中使用 PImpl,则新版本的库可以更改实现,同时与旧版本保持 ABI 兼容。

Gemm的必要参数为A,B,C张量,计算过程为C=A*B+C。tensorInfo描述了张量信息,包括张量的维度大小(dims),数据类型,布局(layout)和步长(stride); 观察接口会发现相比BLAS的标准GEMM接口少量LDA、LDB和LDC参数,其实这些参数可以用stride表示实现。

3. 算法流程

3.1 实例化Gemm

从接口层的Gemm类只提供接口定义,其实现完全由Detail::GemmImpl来完成

// all public methods just call corresponding pImpl implemnentation
Gemm::Gemm(const TensorInfo &aInfo, const TensorInfo &bInfo, const TensorInfo &cInfo, const TensorInfo &biasInfo,
           int numThreads) noexcept(false): pImpl(new Detail::GemmImpl(aInfo, bInfo, cInfo, biasInfo, numThreads)){}

Detail::GemmImpl类中有成员GemmInfo记录gemm参数信息,其类成员就是src,wei,dst,bias等参数; 同时将impl设置为nullptr

GemmImpl(const TensorInfo &srcInfo, const TensorInfo &weiInfo, const TensorInfo &dstInfo, const TensorInfo &biaInfo,
            int numThreads) noexcept(false) : gemmImplInfo(srcInfo, weiInfo, dstInfo, biaInfo), impl(nullptr)

在GemmImpl构造函数中,①先校验输入参数的合法性;②如果是鲲鹏920,判断输入输出张量类型是否满足使用JIT的条件,如果满足会进一步调用FindSolution查找合适的解决方案, 并赋值给类成员impl指针;否则保持impl为nullptr,表示后续回退到BLAS计算gemm。

    {
        Service::ThrowOnStatus(Gemm::ValidateInput(srcInfo, weiInfo, dstInfo, biaInfo, numThreads), "GEMM");
        bool srcJIT = weiJIT = dstJIT = false;
#ifdef KUDNNL_920Pro
        srcJIT = srcJIT || (gemmImplInfo.srcInfo.GetType() == Element::TypeT::S8);
        weiJIT = weiJIT || (gemmImplInfo.weiInfo.GetType() == Element::TypeT::S8);
        ...
        dstJIT = dstJIT || (gemmImplInfo.dstInfo.GetType() == Element::TypeT::S32);
#endif // KUDNNL_920Pro
        ...
        if (srcJIT && weiJIT && dstJIT) {
            impl = FindSolution(gemmImplInfo, Threading::GetMaxNumThreads());
        }
    }

FinSolution函数会生成JITcode,具体分析将在另一篇《kudnn的JIT》文章中详细展开,这里不再赘述。至此,不用JIT的情景下的Gemm实例构造的过程结束,下面看Gemm的计算过程。

3.2. Gemm计算

Gemm计算的整体架构是

GemmImpl::Run (入口)
    ↓
BatchedExtendedGemm (回退BLAS GEMM算法)
    ↓
ChooseImpl (类型匹配与转换)
    ↓
GEMMCaller::Call (参数准备)
    ↓
GEMMWrapper (底层BLAS调用)

3.2.1. GemmImpl::Run

调用Gemm.Run接口输入需要计算的张量数据指针a,b,c实现计算。Gemm类的Run函数同样只是简单转调GemmImpl的Run函数。GemmImpl的Run实现如下,判断impl是否为nullptr,如果非空指针则跳转到JIT模块执行。否则调用BatchedExtendedGemm函数,采用传统BLAS接口计算。

void Run(const void *a, const void *b, void *c, const void *bias, float alpha, float beta, int numThreads){
    if (impl != nullptr) {
        if (impl->GetNThreads() != Threading::GetMaxNumThreads()) {
            impl = FindSolution(impl->GetTask(), Threading::GetMaxNumThreads());
        }
        impl->Run({a, b, c, bias, alpha, beta});
    } else {
        GemmHelpers::BatchedExtendedGemm(gemmImplInfo, a, b, c, bias, alpha, beta, numThreads);
    }
}

gemmImplInfo是中保存有GEMM的维度大小和布局等元数据信息,结合Run接口输入的计算数据可组合参数调用KBLAS完成GEMM计算。

3.2.2. BatchedExtendedGemm

BatchedExtendedGemm函数完成3件事,依次介绍

①布局标准化
判断输入张量的布局是否符合满足标准ABX布局,先通过ReorderLayer转换。下面介绍一下什么是ABX以及为什么要转换。
ABX布局是矩阵乘法中的一种标准内存布局约定,它规定了输入矩阵A、B和输出矩阵X在内存中的排列方式。从ABX的代码可知其采用行优先布局,维度从高到低依次排列。

Layout GetStandardABXLayout() const {
    switch (dims.GetNumDims()) {
        case DIM_1: {
            return Layout::A;
        }
        ...
        case DIM_5: {
            return Layout::ABCDE;
        }
    }
}

   ABX布局转换的核心目的是:

  • 消除布局差异,所有矩阵统一内存排列最大化BLAS性能,使用最优的NoTrans路径简化地址计算,在batch处理中避免stride语义混乱提升cache效率,保证连续内存访问

② 5D广播

auto &&broadcastedDimsA = Service::BroadcastTo5D(
        Service::GetShapeAccordingToLayout(gemmInfoExec.srcInfo.GetDims(), gemmInfoExec.srcInfo.GetLayout()));
auto &&broadcastedDimsB = Service::BroadcastTo5D(
    Service::GetShapeAccordingToLayout(gemmInfoExec.weiInfo.GetDims(), gemmInfoExec.weiInfo.GetLayout()));

Shape broadcastedDims = {std::max(broadcastedDimsA[IDX_0], broadcastedDimsB[IDX_0]),
                            std::max(broadcastedDimsA[IDX_1], broadcastedDimsB[IDX_1]),
                            std::max(broadcastedDimsA[IDX_2], broadcastedDimsB[IDX_2]), 1, 1};
auto &&reorderedStridesA = Service::BroadcastTo5D(Service::GetShapeAccordingToLayout(gemmInfoExec.srcInfo.GetStrides(), gemmInfoExec.srcInfo.GetLayout()), false);
auto &&reorderedStridesB = Service::BroadcastTo5D(Service::GetShapeAccordingToLayout(gemmInfoExec.weiInfo.GetStrides(), gemmInfoExec.weiInfo.GetLayout()), false);
...

③ 三重循环批处理
因为BLAS的gemm一次计算一个2D的矩阵运算,对于kudnn最高支持5维张量的情景,需要个三重循环依次遍历计算GEMM。

for (SizeType i0 = 0; i0 < broadcastedDims[IDX_0]; ++i0) {
    for (SizeType i1 = 0; i1 < broadcastedDims[IDX_1]; ++i1) {
        for (SizeType i2 = 0; i2 < broadcastedDims[IDX_2]; ++i2) {
            const std::byte *aPtr = static_cast<const std::byte *>(aExec) +(i0 * reorderedStridesA[IDX_0] + i1 * reorderedStridesA[IDX_1] + i2 * reorderedStridesA[IDX_2]) * gemmInfoExec.srcInfo.GetType().GetSize();
            const std::byte *bPtr = ...
            std::byte *cPtr = ...
            const std::byte *biasPtr = static_cast<const std::byte *>(biasExec);
            if (biasExec) {
                biasPtr += ...
            }
            ChooseImpl(gemmInfoExec, aPtr, bPtr, cPtr, biasPtr, alpha, beta, numThreads);
        }
    }
}

3.2.3. ChooseImpl

实现了三级回退策略

第1级:精确类型匹配
  ↓ (失败)
第2级:扩展精度匹配
  ↓ (失败)  
第3级:全FP32转换

所有的Gemm精度类型组合列表如下,由float,fp16,bf16,int8,uint8之间组合而成,每一项都是一个特化的GEMMCaller模板类,类中实现了对应数据类型的cblas_gemm函数调用,共有11种类型组合。

static std::vector<std::shared_ptr<GEMMCallerBase>> g_gemmImpls = {
    std::make_shared<GEMMCaller<float, float, float, float>>(),
    std::make_shared<GEMMCaller<__fp16, __fp16, __fp16, __fp16>>(),
    std::make_shared<GEMMCaller<__fp16, __fp16, float, float>>(),
    ...
    std::make_shared<GEMMCaller<std::int8_t, std::int8_t, std::int32_t, std::int32_t>>(),
    ...
};

第1级 - 精确匹配,如果输入的src、wei、dst和bias的类型与既定的g_gemmImpls里的数据类型完全匹配,就直接调用对应的GEMMCaller方法。

// check for exact match
for (auto &&gemmImpl : g_gemmImpls) {
    bool isTypeMatch = (gemmImpl->GetSrcDt() == srcType) && (gemmImpl->GetWeiDt() == weiType) &&
        (gemmImpl->GetDstDt() == dstType) && (gemmImpl->GetBiaDt() == biaType);
    if (isTypeMatch) {
        gemmImpl->Call(gemmInfo, src, wei, dst, bia, alpha, beta, numThreads);
        return;
    }
}

第2级 - 扩展精度匹配:
通过GetWiderType部分获取扩展的目标数据类型,然后选择dst为widerType类型的GEMMCaller作为目标函数,将其余参数扩展转换为目标函数的参数类型,然后调用Call函数继续执行。如果其余参数有比目标函数的参数更宽的数据类型,则继续回退到全FP32类型。

// try to convert to wider type and check if there's implementation with wider type
auto &&widerType = GetWiderType(srcType, weiType, dstType, biaType);
for (auto &&gemmImpl : g_gemmImpls) {
    bool isWiderType = ((gemmImpl->GetDstDt() == widerType) && (dstType == widerType)) &&
        ((srcType.GetSize() <= gemmImpl->GetSrcDt().GetSize()) &&
        (srcType.IsSigned() == gemmImpl->GetSrcDt().IsSigned())) &&
        ...
    if (isWiderType) {
        MatrixConverter s(srcType, gemmImpl->GetSrcDt(), src, m * k);
        MatrixConverter w(weiType, gemmImpl->GetWeiDt(), wei, k * n);
        MatrixConverter b(biaType, gemmImpl->GetBiaDt(), bia, biasM * biasN);
        gemmImpl->Call(gemmInfo, s.Get(), w.Get(), dst, b.Get(), alpha, beta, numThreads);
        return;
    }
}

GetWiderType函数实现了类型扩展,如果存在BF16则返回BF16;如果输入全是FP或integer类型,就返回位数最宽的那个参数的类型;其他情况返回FP32类型,至于为什么是FP32是因为FP32是kudnn支持的最宽数据类型,对其他数据类型兼容性最好。

第3级 - 全FP32回退:
将所有输入类型全部转换为FP32类型,调用FP32的GEMMCaller计算,最后将计算输出dst转换为原来转换前的数据类型,作为最终输出。

// some integer combinations can't be supported when all the types are converted to wider type
MatrixConverter s(srcType, Element::TypeT::F32, src, m * k);
...
GEMMCaller<float, float, float, float> {}.Call(gemmInfo, s.Get(), w.Get(), d.Get(), b.Get(), alpha, beta, numThreads);

if (dstType != Element::TypeT::F32) {
    switch (dstType) {
        case Element::TypeT::F16: {
            Service::ConvertFp32ToFp16(static_cast<const float *>(d.Get()), static_cast<__fp16 *>(dst), m * n);
            break;
        }
        ...
    }
}

3.2.4. GEMMCaller::Call

GEMMCaller负责从gemmInfo中获取矩阵乘所需的M、N、K、lda、ldb、ldc、transa、transb等参数,结合src、wei、dst、bias和offsetC得到cblas_gemm所需的全部参数;再通过调用BlasSetNumThreadsLocal设置BLAS线程数,最后调用GEMMWrapper完成计算

3.2.5. GEMMWrapper

GEMMWrapper特化模板函数调用具体的cblas_?gemm接口完成计算。

template <typename SrcDt, typename WeiDt, typename DstDt, typename BiaDt>
void GEMMWrapper(...)

以void GEMMWrapper<float, float, float, float>为例,将矩阵乘法分为三种情况优化:
① n=1: GEMV(矩阵-向量乘),选取合适的布局后调用 cblas_sgemv
② m=1: GEMV,向量-矩阵乘
③ 其他: GEMM,直接调用 cblas_sgemm

最后通过 AddBias 添加偏置项


本页内容