开发者
KuDNN的矩阵乘之JIT实现详解

KuDNN的矩阵乘之JIT实现详解

HPC

发表于 2026/05/15

0

概述

本文讲从KuDNN::Gemm接口入手,解析JIT的实现原理,分析其如何生成代码和执行。

基本结构层次

kudnn的目录结构为

src/dnn/
├── include/                # Public API headers
│   ├── types/              # Data types, shapes, layouts, enums
│   ├── service/            # Error handling, threading, service APIs
│   ├── operations/         # Operation descriptors (GEMM, Conv, Norm, etc.)
│   └── kudnn.hpp           # Main include file
└── src/                    # Implementation
    ├── kernels/            # Compute kernels (convolution, softmax, rmsnorm, reorder, vml)
    ├── matmul_jit/         # GEMM JIT code generation
    │   ├── MatrixComputation/  # JIT Code for SME
    │   ├── VectorComputation/  # JIT Code for SVE
    │   └── kudnn_gemm_jit.hpp  # JIT Base implement
    ├── aarch64_codegen/    # ARM64 assembly code generation
    ├── vector_apis/        # SVE vector intrinsics
    └── *.cpp               # Operation implementations

inlclude/operations目录下声明了GEMM接口类,src/matmul_jit下定义了JIT实现逻辑,JIT包含SME与SVE两类,

KuDNN 的 JIT 代码生成器分为三层:


┌─────────────────────────────────────────────────────────┐
│  kudnn_jit.hpp - JIT 框架层                              │
│  - KUDNN_JIT_Generator: 代码生成基类                      │
│  - ImplBase<Task>: 任务执行基类                           │
│  - SolutionHandler: 实现缓存与选择                        │
├─────────────────────────────────────────────────────────┤
│  matmul_jit/kudnn_gemm_jit.hpp - GEMM JIT 实现          │
│  - MatMulKernelCodeGenBase: 内核代码生成器模板             │
│  - MatMulImplBase: GEMM 执行基类                         │
├─────────────────────────────────────────────────────────┤
│  aarch64_codegen/ - ARM64 指令生成器                      │
│  - cg_codegen.hpp: CodeGenerator (SVE/SME 指令支持)      │
└─────────────────────────────────────────────────────────┘

SME 矩阵累加器 核心指令是SME的fmopa外积指令,计算一个2VLx2VL即32x32的向量外积。

               权重矩阵 (z2, z3, z6, z7)
                  ↓
   源矩阵         ┌───────┬───────┐
 (z0,z1,z4,z5)   │ za0.s │ za1.s │  ← 累加结果
                 ├───────┼───────┤
                 │ za2.s │ za3.s │
                 └───────┴───────┘

2. 算法流程

2.1. Gemm实例化

在KuDNN用BLAS一章中介绍了非JIT的实例化过程,在GemmImpl构造函数中会选择是否采用JIT方案,本次将沿着FindSolution函数这一JIT分支深入解析其实现,当前KuDNN仅支持基于sme指令的JIT实现。

{
   bool srcJIT = weiJIT = dstJIT = false;
#ifdef KUDNNL_920Pro
   srcJIT = srcJIT || (gemmImplInfo.srcInfo.GetType() == Element::TypeT::S8);
   ...
#endif // KUDNNL_920Pro
   if (srcJIT && weiJIT && dstJIT) {
       impl = FindSolution(gemmImplInfo, Threading::GetMaxNumThreads());
   }
}

2.2. 主要类图

从类图可以看到,CodeArray负责管理JIT代码内存与写入,MatMulKernelDefaultCodeGenImpl定义了JIT代码的生成逻辑,CodeGenerator定义了具体计指令的生成原理。

2.3. 函数流程

以SME架构的F32类型Gemm为例:

2.4. MatMulKernelDefaultCodeGenImpl::MatMulKernelDefaultCodeGenImpl函数详解

该函数中可以看到熟悉的矩阵乘法的实现逻辑,其过程像一个矩阵乘函数的汇编代码,有callee的寄存器保存,函数参数加载,谓词寄存器设置,开关sme模式,ZA寄存器初始化,K循环计算,ZA数据写回内存,callee函数保存的寄存器恢复,函数return。

{
    Service::Unused(srcOffsetArg, weiOffsetArg);
    if ((m >= 0) && (n >= 0) && (k >= 0)) {
        Preamble();     //函数入栈
        Smstart(ARMCG::StreamMode::SMZA);   // SME 模式开始

        Mov(mVal, m);
        Mov(nVal, n);
        //参数传递
        Mov(biaPtrCopy, biaPtr);

        IntType vl = svcntw();
        // initialize predicates for loads and stores
        Mov(tmp, 0);
        // p0, p2 used for SRC, p1, p3 used for WEI
        Whilelt(ARMCG::PRegS(NUM_0), tmp, mVal);
        Whilelt(ARMCG::PRegS(NUM_1), tmp, nVal);
        if ((m > vl) || (n > vl)) {
            Incw(tmp);
            if (m > vl) {
                Whilelt(ARMCG::PRegS(NUM_2), tmp, mVal);
            }
            if (n > vl) {
                Whilelt(ARMCG::PRegS(NUM_3), tmp, nVal);
            }
        }
        // constants to perform continious stores from ZA tiles
        Mov(tmpLDStC0, NUM_0);
        Mov(tmpLDStC1, NUM_4);
        Mov(tmpLDStC2, NUM_8);
        Mov(tmpLDStC3, NUM_12);

        InitializeDstRegisters(m, n, vl);
        GenerateKLoop(m, n, k, vl);     // K方向累加计算指令
        StoreDstRegisters(m, n, vl);    // 写回累加器结果

        Smstop(ARMCG::StreamMode::SMZA);    // SME 模式结束

        Postamble();    //函数出栈
    }
    Ret();
}

2.5. 指令生成

这里以smestart指令为例,Concat通过组合不同的字段得到32位指令,WriteData将指令写到内存。top_是jit code的顶部,地址递增地写入新指令。

void CodeGenerator::Smstart(StreamMode mod)
{
    MatrixComputationStart(mod);
}

void CodeGenerator::MatrixComputationStart(StreamMode mod)
{
    uint32_t code =
        Concat({LeftShift(0x1AA068, BITFIELD_11), LeftShift(mask, BITFIELD_9), LeftShift(0x17F, BITFIELD_0)});
    WriteData(code);
}

void CodeArray::WriteData(uint32_t code)
{
    if (size_ >= maxSize_) {
        if (type_ == Type::MEM_AUTO_GROW) {
            GrowMemory();
        } else {
            throw Error(Err::CODE_IS_TOO_BIG);
        }
    }
    top_[size_++] = code;
}

2.6. CodeGenerator::Ready函数详解

Ready函数的定义位于CodeGenerator类里,计算JIT Code被跳转地址,设置JIT code所在内存属性,确保缓存写入到内存中。

void CodeGenerator::Ready(ProtectMode mode)
{
    if (HasUndefinedLabel()) {
        throw Error(Err::LABEL_IS_NOT_FOUND);
    }
    if (IsAutoGrow()) {
        CalcJmpAddress();
    }
    if (UseProtect()) {
        SetProtectMode(mode);
    }
    ClearCache(static_cast<void *>(GetCode()), static_cast<void *>(GetCurr()));
}

2.7. Gemm Run计算

在GemmImpl::Run的定义中,下面的条件分支会因为impl!=nullptr跳转到JIT分支执行

if (impl != nullptr) {
    if (impl->GetNThreads() != Threading::GetMaxNumThreads()) {
        impl = FindSolution(impl->GetTask(), Threading::GetMaxNumThreads());
    }
    impl->Run({a, b, c, bias, alpha, beta});
} else {
    GemmHelpers::BatchedExtendedGemm(gemmImplInfo, a, b, c, bias, alpha, beta, numThreads);
}

这里有个额外判断,如果线程数发生变化需要重新执行FindSolution函数,因为线程不同其任务划分会发生变化。

函数调用流程如下:

2.8. MatMulImplBase::Run函数

启动多线并行域,计算线程并行调用MatMulImplBase::RunSequential,

2.9. RunSequential函数

RunSequential是单线程内逻辑,根据线程id遍历本线程所分配的batchDims维度,也就是5D tensor的高3维的一部分,依次处理

for (SizeType i0 = 0; i0 < perThreadParams[threadId].broadcastedDims[IDX_0]; ++i0) {
    for (SizeType i1 = 0; i1 < perThreadParams[threadId].broadcastedDims[IDX_1]; ++i1) {
        for (SizeType i2 = 0; i2 < perThreadParams[threadId].broadcastedDims[IDX_2]; ++i2) {
            // 计算src, wei, dst, bias数据指针偏移
            // ...
            RunSequential2D(threadId, srcPtr, weiPtr, dstPtr, biaExec, buffer, alpha, beta);
        }
    }
}

2.10. RunSequential2D函数

RunSequential2D函数的主要过程如下,是一个5层循环,外三层依次按块大小遍历M,K,N;内2层遍历分块后的M和N,最终跳转到实例化阶段生成的JIT Code执行。

for (IntType mBlock = 0; mBlock < m; mBlock += CodeGen::GetMBlockSize(m, n, k)) {
    for (IntType kBlock = 0; kBlock < k; kBlock += CodeGen::GetKBlockSize(m, n, k)) {
        for (IntType nBlock = 0; nBlock < n; nBlock += CodeGen::GetNBlockSize(m, n, k)) {
            for (IntType mSubBlock = 0; mSubBlock < mBlockExec; mSubBlock += CodeGen::kernelMSize * vl) {
                for (IntType nSubBlock = 0; nSubBlock < nBlockExec; nSubBlock += CodeGen::kernelNSize * vl) {
                    auto kernel = perThreadImpls[threadId][mIdx * NUM_4 + nIdx * NUM_2 + kIdx].template GetCode<void (*)(const SrcType *, const WeiType *, DstType *, const BiaType *, DstLoadFlags, float)>();
                    kernel(srcBuffer + mSubBlock * kBlockExec, weiBuffer + nSubBlock * kBlockExec, dstExec, biaExec, biaFlagsExec, beta);
                }
            }
        }
    }
}

.templete GetCode会调用CodeGenWrapper类的模板类型MatMulKernelDefaultCodeGen的基类CodeArray的GetCode()函数,返回JIT 代码的入口地址

template <class F>
const F GetCode() const
{
    return reinterpret_cast<F>(top_);
}

得到的kernel是个函数指针,指向JIT Code入口地址,kernel只负责计算如下大小的矩阵

nrmrkc


本页内容