开发者
资源
鲲鹏计算平台矩阵乘算子SGEMM优化实践——下篇

鲲鹏计算平台矩阵乘算子SGEMM优化实践——下篇

HPC

发表于 2026/06/18

0

鲲鹏计算平台矩阵乘算子SGEMM优化实践分为上下两篇,记录了在鲲鹏计算平台上,从最朴素的三重循环矩阵乘法出发,经过 6 个阶段逐步优化,最终实现 30+倍性能提升 的完整过程。所有代码基于 C 语言 + ARM intrinsics,矩阵存储采用列主序 (Column-Major),编译器使用华为毕昇编译器 (BiSheng)。本文为下篇,续写SGEMM优化实践内容,上篇请参考 鲲鹏计算平台矩阵乘算子SGEMM优化实践——上篇

1. sgemm实现之数据重排+SVE向量化

1.1 优化思路

前序的循环展开、Neon向量化、SVE向量化等优化手段都集中在计算侧,但瓶颈已经转移到访存侧。A 的跨步访问导致缓存利用率极差。我们无法改变矩阵 A 在内存中的原始布局,但可以在计算前将需要的数据重新排列为连续布局——这就是数据重排(Packing)。

1.2. 数据重排策略详解

A 的数据重排:列主序 → Panel 布局

原始 A 中,同一列的元素连续存储,但内层循环需要的是同一k 行、连续 i 列的元素。数据重排将 A 重排为 panel 布局:

原始 A (列主序, M=8, K=4):     数据重排后 packed_A (vl=4):
                                
A00 A01 A02 A03                Panel 0 (i=0~3):
A10 A11 A12 A13                  A00 A10 A20 A30 | A01 A11 A21 A31 | A02 A12 A22 A32 | A03 A13 A23 A33
A20 A21 A22 A23                  ← k=0 连续 ──→    ← k=1 连续 ──→
A30 A31 A32 A33
A40 A41 A42 A43                Panel 1 (i=4~7):
A50 A51 A52 A53                  A40 A50 A60 A70 | A41 A51 A61 A71 | ...
A60 A61 A62 A63                  ← 每个 k 的 vl 个元素连续 ──→
A70 A71 A72 A73

数据重排后,pan_A[k * vl] 就是 vl 个连续 float,一次 svld1_f32(svptrue_b32(), ...) 即可加载——不再需要谓词,因为 panel 内总是完整的 vl 个元素

B 的数据重排:列主序 → 行连续布局

原始 B 中,同一列的元素连续,但微内核需要的是同一 k 行、NR 列的元素:

原始 B (列主序, K=4, N=8):     数据重排后 packed_B (NR=4):

B00 B01 B02 B03 B04 B05 B06 B07   Block 0 (j=0~3):
B10 B11 B12 B13 B14 B15 B16 B17     B00 B01 B02 B03 | B10 B11 B12 B13 | B20 B21 B22 B23 | B30 B31 B32 B33
B20 B21 B22 B23 B24 B25 B26 B27     ← k=0, NR=4 连续 ──→
B30 B31 B32 B33 B34 B35 B36 B37
                                  Block 1 (j=4~7):
                                    B04 B05 B06 B07 | B14 B15 B16 B17 | ...

1.3. 微内核设计

数据重排后,微内核的内存访问模式发生了质变:

#define NR 4

static void pack_a(int M, int K, int vl, const float *A, float *packed_A) {
    for (int i = 0; i < M; i += vl) {
        int rows = (i + vl <= M) ? vl : (M - i);
        svbool_t pg = svwhilelt_b32((int64_t)0, (int64_t)rows);
        float *panel = packed_A + (i / vl) * vl * K;
        for (int k = 0; k < K; k++) {
            svfloat32_t a_vec = svld1_f32(pg, &A[i + k * M]);
            svst1_f32(svptrue_b32(), &panel[k * vl], a_vec);
        }
    }
}

static void pack_b(int K, int N, const float *B, float *packed_B) {
    for (int j = 0; j < N; j += NR) {
        int cols = (j + NR <= N) ? NR : (N - j);
        float *block = packed_B + (j / NR) * NR * K;
        for (int k = 0; k < K; k++) {
            for (int c = 0; c < cols; c++) {
                block[k * NR + c] = B[k + (j + c) * K];
            }
            for (int c = cols; c < NR; c++) {
                block[k * NR + c] = 0.0f;
            }
        }
    }
}
void sgemm_pack_sve(int M, int N, int K, const float *A, const float *B, float *C) {
    int64_t vl = svcntw();
    int n_panels = (M + vl - 1) / vl;
    int n_blocks = (N + NR - 1) / NR;

    float *packed_A = (float *)aligned_alloc(64, (size_t)n_panels * vl * K * sizeof(float));
    float *packed_B = (float *)aligned_alloc(64, (size_t)n_blocks * NR * K * sizeof(float));

    pack_a(M, K, (int)vl, A, packed_A);
    pack_b(K, N, B, packed_B);

    for (int j = 0; j < N; j += NR) {
        int cols = (j + NR <= N) ? NR : (N - j);
        float *blk_B = packed_B + (j / NR) * NR * K;
        for (int i = 0; i < M; i += vl) {
            int rows = (i + vl <= M) ? vl : (M - i);
            svbool_t pg = svwhilelt_b32((int64_t)0, (int64_t)rows);
            float *pan_A = packed_A + (i / vl) * vl * K;

            svfloat32_t c0_col0 = svdup_n_f32(0.0f);
            svfloat32_t c1_col0 = svdup_n_f32(0.0f);
            svfloat32_t c2_col0 = svdup_n_f32(0.0f);
            svfloat32_t c3_col0 = svdup_n_f32(0.0f);

            svfloat32_t c0_col1 = svdup_n_f32(0.0f);
            svfloat32_t c1_col1 = svdup_n_f32(0.0f);
            svfloat32_t c2_col1 = svdup_n_f32(0.0f);
            svfloat32_t c3_col1 = svdup_n_f32(0.0f);

            svfloat32_t c0_col2 = svdup_n_f32(0.0f);
            svfloat32_t c1_col2 = svdup_n_f32(0.0f);
            svfloat32_t c2_col2 = svdup_n_f32(0.0f);
            svfloat32_t c3_col2 = svdup_n_f32(0.0f);

            svfloat32_t c0_col3 = svdup_n_f32(0.0f);
            svfloat32_t c1_col3 = svdup_n_f32(0.0f);
            svfloat32_t c2_col3 = svdup_n_f32(0.0f);
            svfloat32_t c3_col3 = svdup_n_f32(0.0f);

            int k = 0;
            for (; k + 3 < K; k += 4) {
                svfloat32_t a0 = svld1_f32(svptrue_b32(), &pan_A[(k+0)*vl]);
                svfloat32_t a1 = svld1_f32(svptrue_b32(), &pan_A[(k+1)*vl]);
                svfloat32_t a2 = svld1_f32(svptrue_b32(), &pan_A[(k+2)*vl]);
                svfloat32_t a3 = svld1_f32(svptrue_b32(), &pan_A[(k+3)*vl]);

                c0_col0 = svmla_f32_m(pg, c0_col0, a0, svdup_n_f32(blk_B[(k+0)*NR+0]));
                c1_col0 = svmla_f32_m(pg, c1_col0, a1, svdup_n_f32(blk_B[(k+1)*NR+0]));
                c2_col0 = svmla_f32_m(pg, c2_col0, a2, svdup_n_f32(blk_B[(k+2)*NR+0]));
                c3_col0 = svmla_f32_m(pg, c3_col0, a3, svdup_n_f32(blk_B[(k+3)*NR+0]));

                c0_col1 = svmla_f32_m(pg, c0_col1, a0, svdup_n_f32(blk_B[(k+0)*NR+1]));
                c1_col1 = svmla_f32_m(pg, c1_col1, a1, svdup_n_f32(blk_B[(k+1)*NR+1]));
                c2_col1 = svmla_f32_m(pg, c2_col1, a2, svdup_n_f32(blk_B[(k+2)*NR+1]));
                c3_col1 = svmla_f32_m(pg, c3_col1, a3, svdup_n_f32(blk_B[(k+3)*NR+1]));

                c0_col2 = svmla_f32_m(pg, c0_col2, a0, svdup_n_f32(blk_B[(k+0)*NR+2]));
                c1_col2 = svmla_f32_m(pg, c1_col2, a1, svdup_n_f32(blk_B[(k+1)*NR+2]));
                c2_col2 = svmla_f32_m(pg, c2_col2, a2, svdup_n_f32(blk_B[(k+2)*NR+2]));
                c3_col2 = svmla_f32_m(pg, c3_col2, a3, svdup_n_f32(blk_B[(k+3)*NR+2]));

                c0_col3 = svmla_f32_m(pg, c0_col3, a0, svdup_n_f32(blk_B[(k+0)*NR+3]));
                c1_col3 = svmla_f32_m(pg, c1_col3, a1, svdup_n_f32(blk_B[(k+1)*NR+3]));
                c2_col3 = svmla_f32_m(pg, c2_col3, a2, svdup_n_f32(blk_B[(k+2)*NR+3]));
                c3_col3 = svmla_f32_m(pg, c3_col3, a3, svdup_n_f32(blk_B[(k+3)*NR+3]));
            }

            svfloat32_t sum_col0 = svadd_f32_m(pg, svadd_f32_m(pg, c0_col0, c1_col0),
                                                 svadd_f32_m(pg, c2_col0, c3_col0));
            svfloat32_t sum_col1 = svadd_f32_m(pg, svadd_f32_m(pg, c0_col1, c1_col1),
                                                 svadd_f32_m(pg, c2_col1, c3_col1));
            svfloat32_t sum_col2 = svadd_f32_m(pg, svadd_f32_m(pg, c0_col2, c1_col2),
                                                 svadd_f32_m(pg, c2_col2, c3_col2));
            svfloat32_t sum_col3 = svadd_f32_m(pg, svadd_f32_m(pg, c0_col3, c1_col3),
                                                 svadd_f32_m(pg, c2_col3, c3_col3));

            for (; k < K; k++) {
                svfloat32_t a_vec = svld1_f32(svptrue_b32(), &pan_A[k*vl]);
                sum_col0 = svmla_f32_m(pg, sum_col0, a_vec, svdup_n_f32(blk_B[k*NR+0]));
                sum_col1 = svmla_f32_m(pg, sum_col1, a_vec, svdup_n_f32(blk_B[k*NR+1]));
                sum_col2 = svmla_f32_m(pg, sum_col2, a_vec, svdup_n_f32(blk_B[k*NR+2]));
                sum_col3 = svmla_f32_m(pg, sum_col3, a_vec, svdup_n_f32(blk_B[k*NR+3]));
            }

            svst1_f32(pg, &C[i + (j+0)*M], sum_col0);
            if (cols >= 2) svst1_f32(pg, &C[i + (j+1)*M], sum_col1);
            if (cols >= 3) svst1_f32(pg, &C[i + (j+2)*M], sum_col2);
            if (cols >= 4) svst1_f32(pg, &C[i + (j+3)*M], sum_col3);
        }
    }

    free(packed_A);
    free(packed_B);
}

实测性能提升25.0×倍

1.4. 数据重排能带来性能提升原因分析

1) A 的连续加载——消除跨步访问

数据重排前:  svld1_f32(pg, &A[i + k*M])     步长 = M×4 = 2KB (跨缓存行)
数据重排后:  svld1_f32(svptrue_b32(), &pan_A[k*vl])  步长 = vl×4 = 64B (缓存行对齐!)

每次 svld1_f32 加载恰好一个缓存行(16×4=64B),且连续 k 之间的地址偏移也是 64B,硬件预取器可以高效工作。

2) A 的多列复用——加载一次用 NR 次

a0 = svld1_f32(pan_A[(k+0)*vl])   ← 加载一次
c0_col0 = svmla(c0_col0, a0, B[k+0,j+0])  ← 用于第 j+0 列
c0_col1 = svmla(c0_col1, a0, B[k+0,j+1])  ← 用于第 j+1 列
c0_col2 = svmla(c0_col2, a0, B[k+0,j+2])  ← 用于第 j+2 列
c0_col3 = svmla(c0_col3, a0, B[k+0,j+3])  ← 用于第 j+3 列

A 的数据加载 1 次,被 4 列 FMA 复用。算术强度从 0.25 提升到约 1.0 FLOP/Byte

3) 16 个独立累加器——充分利用 SVE 寄存器文件

SVE 有 32 个 Z 寄存器。微内核使用 16 个作为累加器(4×4),4 个暂存 A 向量,剩余 12 个给 B 广播和中间结果。16 条独立的 FMA 依赖链让 CPU 的乱序执行引擎充分忙碌。

4) 使用 svptrue_b32 替代谓词——减少谓词开销

数据重排后 panel 内数据总是 vl 对齐的,内层循环使用 svptrue_b32()(全真谓词),避免了 svwhilelt_b32 的计算开销和谓词对 FMA 吞吐的潜在影响。

5) 寄存器分配分析

Z0~Z3:   a0, a1, a2, a3        (A 的 4 路 K 加载)
Z4~Z7:   c0_col0~c3_col0       (第 0 列的 4 路 K 累加器)
Z8~Z11:  c0_col1~c3_col1       (第 1 列)
Z12~Z15: c0_col2~c3_col2       (第 2 列)
Z16~Z19: c0_col3~c3_col3       (第 3 列)
Z20~Z31: 临时寄存器             (B 广播、归约中间值)

32 个 Z 寄存器恰好用满,没有溢出到栈。

1.5. 仍然存在的瓶颈

全量数据重排的缓存问题:当前优化方法对整个矩阵做全量数据重排。512×512 矩阵的 packed_A 大小为 512×512×4 = 1MB,远超 L2 缓存容量。当计算后半部分时,前半部分已被驱逐,需要重新从内存加载原始数据再数据重排。

更严重的问题:当矩阵规模增大到 4096×4096 时,packed_A 达到 64MB,packed_B 达到 64MB,完全无法放入任何缓存层次。这就是接下来要解决的问题。

2. sgemm实现之缓存分块+数据重排+SVE向量化

2.1. 优化思路

数据重排解决了单次访存的连续性问题,但没有解决工作集过大导致的缓存驱逐问题。核心问题是不能一次性数据重排整个矩阵,而是按缓存大小分块,只数据重排当前块需要的数据。

这就是经典的缓存分块(Cache Tiling / Blocking)技术,参考OpenBLAS 的算法框架。

2.2. 分块参数选择

分块参数的选择需要精确匹配 CPU 的缓存层次:

参数计算过程适配缓存
MC128packed_A_tile = MC × KC × 4B = 128×256×4 = 128KBL2 (通常 256KB~1MB)
KC256packed_B_tile = KC × NR × 4B = 256×4×4 = 4KBL1D (通常 64KB)
NR4C_tile = MC × NR × 4B = 128×4×4 = 2KB寄存器 + L1D

为什么是这些值?

  • KC=256:packed_B_tile 只有 4KB,确保驻留 L1。K 维度的分块使 B_tile 足够小,在 jb 循环中被反复命中。
  • MC=128:packed_A_tile 为 128KB,适配 L2。A_tile 在 jb 循环中被 NR 列复用,必须驻留在 L2 以避免重复从内存加载。
  • MC 必须是 vl 的倍数:128/16=8,确保 panel 内无需谓词处理。

2.3. 三层循环结构

for ib = 0 to M step MC:              // M 分块 (L2 粒度)
    for kb = 0 to K step KC:          // K 分块 (L1 粒度)
        pack_A_tile(ib, kb)           // 只数据重排 MC×KC 的小块
        for jb = 0 to N step NR:      // N 微内核 (寄存器粒度)
            pack_B_tile(kb, jb)       // 只数据重排 KC×NR 的小块
            micro_kernel()            // SVE 计算

2.4. 循环顺序的抉择

循环顺序 (ib, kb, jb) 是经过深思熟虑的。让我们对比两种方案:

方案 A:(jb, kb, ib) —— 非最优选择

for jb:  for kb:  pack_B  for ib:  pack_A  compute
         ↑ 每个 (jb, kb) 对,所有 ib 都要重新 pack_A
         ↑ N/NR = 128 次 ib 循环,A 被数据重排 128 次!

方案 B:(ib, kb, jb) —— 最优选择

for ib:  for kb:  pack_A  for jb:  pack_B  compute
         ↑ 每个 (ib, kb) 对,A 只数据重排一次
         ↑ 然后在所有 jb 上复用(N/NR = 128 次复用)

量化对比(512×512 矩阵):

方案A 数据重排次数B 数据重排次数
(jb, kb, ib)(N/NR) × (K/KC) × (M/MC) = 128×2×4 = 1024同左
(ib, kb, jb)(M/MC) × (K/KC) = 4×2 = 8(M/MC) × (K/KC) × (N/NR) = 1024

A 的数据重排从 1024 次降到 8 次——128 倍减少!因为 A_tile (128KB) 远大于 B_tile (4KB),减少 A 的数据重排开销收益巨大。

2.5. 首块优化

K 被分块后,C 矩阵的每个元素需要跨多个 kb 块累加:

C[i,j] = Σ_k A[i,k]×B[k,j]
       = Σ_{kb} Σ_{k in kb} A[i,k]×B[k,j]
       = C_tile_{kb=0} + C_tile_{kb=1} + ...

第一个 kb 块时,C 还未被写入,可以直接 store(省去一次 load):

int first_k = (kb == 0);
if (first_k) {
    // C 还没写入,直接存储,省去 load
    svst1_f32(pg, &C[...], sum_col0);
} else {
    // C 已有部分累加值,需要 load-add-store
    svfloat32_t old = svld1_f32(pg, &C[...]);
    svst1_f32(pg, &C[...], svadd_f32_m(pg, old, sum_col0));
}

这减少了 K/KC - 1 = 1 次 C 的 load(对于 K=512, KC=256),效果不大,但对大 K 值更显著。

完整代码结构

#define TILED_MC 128
#define TILED_KC 256
#define TILED_NR 4

void sgemm_tiled_sve(int M, int N, int K, const float *A, const float *B, float *C) {
    int64_t vl = svcntw();
    int vl_int = (int)vl;
    int MC = (TILED_MC / vl_int) * vl_int;
    if (MC == 0) MC = vl_int;
    int KC = TILED_KC;
    int n_panels_mc = MC / vl_int;

    float *packed_A = (float *)aligned_alloc(64, (size_t)n_panels_mc * vl_int * KC * sizeof(float));
    float *packed_B = (float *)aligned_alloc(64, (size_t)KC * TILED_NR * sizeof(float));

    memset(C, 0, (size_t)M * N * sizeof(float));
    for (int ib = 0; ib < M; ib += MC) {
        int ib_len = (ib + MC <= M) ? MC : (M - ib);

        for (int kb = 0; kb < K; kb += KC) {
            int kb_len = (kb + KC <= K) ? KC : (K - kb);
            int first_k = (kb == 0);

            for (int i = 0; i < ib_len; i += vl_int) {
                int rows = (i + vl_int <= ib_len) ? vl_int : (ib_len - i);
                svbool_t pg = svwhilelt_b32((int64_t)0, (int64_t)rows);
                float *panel = packed_A + (i / vl_int) * vl_int * kb_len;
                for (int k = 0; k < kb_len; k++) {
                    svfloat32_t a_vec = svld1_f32(pg, &A[(ib + i) + (kb + k) * M]);
                    svst1_f32(svptrue_b32(), &panel[k * vl_int], a_vec);
                }
            }

            for (int jb = 0; jb < N; jb += TILED_NR) {
                int cols = (jb + TILED_NR <= N) ? TILED_NR : (N - jb);

                for (int k = 0; k < kb_len; k++) {
                    for (int c = 0; c < cols; c++) {
                        packed_B[k * TILED_NR + c] = B[(kb + k) + (jb + c) * K];
                    }
                    for (int c = cols; c < TILED_NR; c++) {
                        packed_B[k * TILED_NR + c] = 0.0f;
                    }
                }

                for (int i = 0; i < ib_len; i += vl_int) {
                    int rows = (i + vl_int <= ib_len) ? vl_int : (ib_len - i);
                    svbool_t pg = svwhilelt_b32((int64_t)0, (int64_t)rows);
                    float *pan_A = packed_A + (i / vl_int) * vl_int * kb_len;

                    svfloat32_t c0_col0 = svdup_n_f32(0.0f);
                    svfloat32_t c1_col0 = svdup_n_f32(0.0f);
                    svfloat32_t c2_col0 = svdup_n_f32(0.0f);
                    svfloat32_t c3_col0 = svdup_n_f32(0.0f);

                    svfloat32_t c0_col1 = svdup_n_f32(0.0f);
                    svfloat32_t c1_col1 = svdup_n_f32(0.0f);
                    svfloat32_t c2_col1 = svdup_n_f32(0.0f);
                    svfloat32_t c3_col1 = svdup_n_f32(0.0f);

                    svfloat32_t c0_col2 = svdup_n_f32(0.0f);
                    svfloat32_t c1_col2 = svdup_n_f32(0.0f);
                    svfloat32_t c2_col2 = svdup_n_f32(0.0f);
                    svfloat32_t c3_col2 = svdup_n_f32(0.0f);

                    svfloat32_t c0_col3 = svdup_n_f32(0.0f);
                    svfloat32_t c1_col3 = svdup_n_f32(0.0f);
                    svfloat32_t c2_col3 = svdup_n_f32(0.0f);
                    svfloat32_t c3_col3 = svdup_n_f32(0.0f);

                    int k = 0;
                    for (; k + 3 < kb_len; k += 4) {
                        svfloat32_t a0 = svld1_f32(svptrue_b32(), &pan_A[(k+0)*vl_int]);
                        svfloat32_t a1 = svld1_f32(svptrue_b32(), &pan_A[(k+1)*vl_int]);
                        svfloat32_t a2 = svld1_f32(svptrue_b32(), &pan_A[(k+2)*vl_int]);
                        svfloat32_t a3 = svld1_f32(svptrue_b32(), &pan_A[(k+3)*vl_int]);

                        c0_col0 = svmla_f32_m(pg, c0_col0, a0, svdup_n_f32(packed_B[(k+0)*TILED_NR+0]));
                        c1_col0 = svmla_f32_m(pg, c1_col0, a1, svdup_n_f32(packed_B[(k+1)*TILED_NR+0]));
                        c2_col0 = svmla_f32_m(pg, c2_col0, a2, svdup_n_f32(packed_B[(k+2)*TILED_NR+0]));
                        c3_col0 = svmla_f32_m(pg, c3_col0, a3, svdup_n_f32(packed_B[(k+3)*TILED_NR+0]));

                        c0_col1 = svmla_f32_m(pg, c0_col1, a0, svdup_n_f32(packed_B[(k+0)*TILED_NR+1]));
                        c1_col1 = svmla_f32_m(pg, c1_col1, a1, svdup_n_f32(packed_B[(k+1)*TILED_NR+1]));
                        c2_col1 = svmla_f32_m(pg, c2_col1, a2, svdup_n_f32(packed_B[(k+2)*TILED_NR+1]));
                        c3_col1 = svmla_f32_m(pg, c3_col1, a3, svdup_n_f32(packed_B[(k+3)*TILED_NR+1]));

                        c0_col2 = svmla_f32_m(pg, c0_col2, a0, svdup_n_f32(packed_B[(k+0)*TILED_NR+2]));
                        c1_col2 = svmla_f32_m(pg, c1_col2, a1, svdup_n_f32(packed_B[(k+1)*TILED_NR+2]));
                        c2_col2 = svmla_f32_m(pg, c2_col2, a2, svdup_n_f32(packed_B[(k+2)*TILED_NR+2]));
                        c3_col2 = svmla_f32_m(pg, c3_col2, a3, svdup_n_f32(packed_B[(k+3)*TILED_NR+2]));

                        c0_col3 = svmla_f32_m(pg, c0_col3, a0, svdup_n_f32(packed_B[(k+0)*TILED_NR+3]));
                        c1_col3 = svmla_f32_m(pg, c1_col3, a1, svdup_n_f32(packed_B[(k+1)*TILED_NR+3]));
                        c2_col3 = svmla_f32_m(pg, c2_col3, a2, svdup_n_f32(packed_B[(k+2)*TILED_NR+3]));
                        c3_col3 = svmla_f32_m(pg, c3_col3, a3, svdup_n_f32(packed_B[(k+3)*TILED_NR+3]));
                    }

                    svfloat32_t sum_col0 = svadd_f32_m(pg, svadd_f32_m(pg, c0_col0, c1_col0),
                                                         svadd_f32_m(pg, c2_col0, c3_col0));
                    svfloat32_t sum_col1 = svadd_f32_m(pg, svadd_f32_m(pg, c0_col1, c1_col1),
                                                         svadd_f32_m(pg, c2_col1, c3_col1));
                    svfloat32_t sum_col2 = svadd_f32_m(pg, svadd_f32_m(pg, c0_col2, c1_col2),
                                                         svadd_f32_m(pg, c2_col2, c3_col2));
                    svfloat32_t sum_col3 = svadd_f32_m(pg, svadd_f32_m(pg, c0_col3, c1_col3),
                                                         svadd_f32_m(pg, c2_col3, c3_col3));

                    for (; k < kb_len; k++) {
                        svfloat32_t a_vec = svld1_f32(svptrue_b32(), &pan_A[k*vl_int]);
                        sum_col0 = svmla_f32_m(pg, sum_col0, a_vec, svdup_n_f32(packed_B[k*TILED_NR+0]));
                        sum_col1 = svmla_f32_m(pg, sum_col1, a_vec, svdup_n_f32(packed_B[k*TILED_NR+1]));
                        sum_col2 = svmla_f32_m(pg, sum_col2, a_vec, svdup_n_f32(packed_B[k*TILED_NR+2]));
                        sum_col3 = svmla_f32_m(pg, sum_col3, a_vec, svdup_n_f32(packed_B[k*TILED_NR+3]));
                    }

                    if (first_k) {
                        svst1_f32(pg, &C[(ib + i) + (jb + 0) * M], sum_col0);
                        if (cols >= 2) svst1_f32(pg, &C[(ib + i) + (jb + 1) * M], sum_col1);
                        if (cols >= 3) svst1_f32(pg, &C[(ib + i) + (jb + 2) * M], sum_col2);
                        if (cols >= 4) svst1_f32(pg, &C[(ib + i) + (jb + 3) * M], sum_col3);
                    } else {
                        svfloat32_t old0 = svld1_f32(pg, &C[(ib + i) + (jb + 0) * M]);
                        svst1_f32(pg, &C[(ib + i) + (jb + 0) * M], svadd_f32_m(pg, old0, sum_col0));
                        if (cols >= 2) {
                            svfloat32_t old1 = svld1_f32(pg, &C[(ib + i) + (jb + 1) * M]);
                            svst1_f32(pg, &C[(ib + i) + (jb + 1) * M], svadd_f32_m(pg, old1, sum_col1));
                        }
                        if (cols >= 3) {
                            svfloat32_t old2 = svld1_f32(pg, &C[(ib + i) + (jb + 2) * M]);
                            svst1_f32(pg, &C[(ib + i) + (jb + 2) * M], svadd_f32_m(pg, old2, sum_col2));
                        }
                        if (cols >= 4) {
                            svfloat32_t old3 = svld1_f32(pg, &C[(ib + i) + (jb + 3) * M]);
                            svst1_f32(pg, &C[(ib + i) + (jb + 3) * M], svadd_f32_m(pg, old3, sum_col3));
                        }
                    }
                }
            }
        }
    }
    free(packed_A);
    free(packed_B);
}

实测性能提升35.4×倍

2.6. 性能提升原因分析

1) packed_A_tile 驻留 L2

无分块: packed_A 全量 1MB → 超出 L2,计算后半部分时前半部分已被驱逐
有分块: packed_A_tile 128KB → 完全放入 L2,在 jb 循环中 100% 命中

jb 循环迭代 N/NR = 128 次,A_tile 在这 128 次迭代中始终驻留 L2,每次访问都是 cache hit。

2) packed_B_tile 驻留 L1

4KB 的 B_tile 完全放入 L1D,微内核内对 B 的访问接近 100% L1 命中率。

3) 缓冲区大小大幅缩小

Stage 5: packed_A = M×K×4 = 1MB,  packed_B = K×N×4 = 1MB  → 总计 2MB
Stage 6: packed_A = MC×KC×4 = 128KB, packed_B = KC×NR×4 = 4KB → 总计 132KB

内存占用从 2MB 降到 132KB,减少 15 倍,也减少了 TLB 压力。

2.7. 缓存命中率对比

内存重排+SVE向量化 (无分块, 512×512):
  packed_A 1MB >> L2 → A 的复用需要从内存重新加载
  估计 L2 命中率: ~30%

缓存分块+内存重排+SVE向量化 (分块 MC=128, KC=256):
  packed_A_tile 128KB < L2 → A 在 jb 循环中完全命中
  packed_B_tile 4KB << L1D → B 在微内核中完全命中
  估计 L2 命中率: ~95%+

3. sgemm性能优化总结

鲲鹏计算平台矩阵乘算子SGEMM优化实践上下两篇介绍了sgemm从朴素实现->循环展开->Neon向量化->SVE向量化->数据重排+SVE向量化->缓存分块 + 数据重排 + SVE向量化共5种优化手段,不同阶段性能表现情况如下(以下数据为笔者在实验环境下测得,仅供参考,读者可自行验证不同阶段的性能):

优化方法相对加速核心突破
Naive 三重循环1.0×基线
循环展开2.3×指令级并行 + 数据复用
NEON 向量化2.5×128-bit SIMD
SVE 向量化5.5×512-bit SIMD + 谓词
数据重排 + SVE25.0×连续访存 + 寄存器复用
缓存分块 + 数据重排 + SVE35.4×工作集适配缓存层次

测试程序中的main方法参考:

int main(int argc, char *argv[]) {
    int M = 512, N = 512, K = 512;
    int stage = 0;
    if (argc >= 4) {
        M = atoi(argv[1]);
        N = atoi(argv[2]);
        K = atoi(argv[3]);
    }
    if (argc >= 5) {
        stage = atoi(argv[4]);
    }

    printf("SGEMM: C(%dx%d) = A(%dx%d) x B(%dx%d)\n", M, N, M, K, K, N);
    printf("SVE vector length: %d bits (%d floats)\n\n", (int)(svcntw() * 32), (int)svcntw());

    size_t size_A = (size_t)M * K;
    size_t size_B = (size_t)K * N;
    size_t size_C = (size_t)M * N;

    float *A = (float *)aligned_alloc(64, size_A * sizeof(float));
    float *B = (float *)aligned_alloc(64, size_B * sizeof(float));
    float *C = (float *)aligned_alloc(64, size_C * sizeof(float));
    float *C_ref = (float *)aligned_alloc(64, size_C * sizeof(float));

    srand(42);
    init_matrix(A, size_A);
    init_matrix(B, size_B);

    double gflops = 2.0 * M * N * K / 1e9;

    if (stage == 0 || stage == 1) {
        printf("--- Stage 1: Naive ---\n");
        memset(C_ref, 0, size_C * sizeof(float));
        double t0 = get_time();
        sgemm_naive(M, N, K, A, B, C_ref);
        double t1 = get_time();
        printf("  Time: %.3f s, GFLOPS: %.2f\n\n", t1 - t0, gflops / (t1 - t0));
    }

    if (stage == 1) {
        free(A); free(B); free(C); free(C_ref);
        return 0;
    }

    if (stage == 0) {
    } else if (stage != 1) {
        printf("--- Stage 1: Naive (computing reference) ---\n");
        memset(C_ref, 0, size_C * sizeof(float));
        sgemm_naive(M, N, K, A, B, C_ref);
        printf("  Reference computed.\n\n");
    }

    if (stage == 0 || stage == 2) {
        printf("--- Stage 2: Loop Unroll ---\n");
        memset(C, 0, size_C * sizeof(float));
        double t0 = get_time();
        sgemm_unroll(M, N, K, A, B, C);
        double t1 = get_time();
        verify(C_ref, C, size_C, 1e-3f);
        printf("  Time: %.3f s, GFLOPS: %.2f\n\n", t1 - t0, gflops / (t1 - t0));
    }

    if (stage == 0 || stage == 3) {
        printf("--- Stage 3: NEON ---\n");
        memset(C, 0, size_C * sizeof(float));
        double t0 = get_time();
        sgemm_neon(M, N, K, A, B, C);
        double t1 = get_time();
        verify(C_ref, C, size_C, 1e-3f);
        printf("  Time: %.3f s, GFLOPS: %.2f\n\n", t1 - t0, gflops / (t1 - t0));
    }

    if (stage == 0 || stage == 4) {
        printf("--- Stage 4: SVE ---\n");
        memset(C, 0, size_C * sizeof(float));
        double t0 = get_time();
        sgemm_sve(M, N, K, A, B, C);
        double t1 = get_time();
        verify(C_ref, C, size_C, 1e-3f);
        printf("  Time: %.3f s, GFLOPS: %.2f\n\n", t1 - t0, gflops / (t1 - t0));
    }

    if (stage == 0 || stage == 5) {
        printf("--- Stage 5: Pack + SVE ---\n");
        memset(C, 0, size_C * sizeof(float));
        double t0 = get_time();
        sgemm_pack_sve(M, N, K, A, B, C);
        double t1 = get_time();
        verify(C_ref, C, size_C, 1e-3f);
        printf("  Time: %.3f s, GFLOPS: %.2f\n\n", t1 - t0, gflops / (t1 - t0));
    }

    if (stage == 0 || stage == 6) {
        printf("--- Stage 6: Cache Tiling + Pack + SVE ---\n");
        memset(C, 0, size_C * sizeof(float));
        double t0 = get_time();
        sgemm_tiled_sve(M, N, K, A, B, C);
        double t1 = get_time();
        verify(C_ref, C, size_C, 1e-3f);
        printf("  Time: %.3f s, GFLOPS: %.2f\n\n", t1 - t0, gflops / (t1 - t0));
    }

    free(A);
    free(B);
    free(C);
    free(C_ref);

    return 0;
}

编译运行方法:

# 加载毕昇编译器环境(HPCkit安装可参见:https://www.hikunpeng.com/document/detail/zh/kunpenghpcs/hpckit/instg/KunpengHPCKit_install_024.html)
source /opt/HPCKit/latest/setvars.sh

# 编译
clang -O3 -march=armv8-a+sve -o sgemm sgemm.c -lm

# 运行全部阶段(矩阵 512×512)
./sgemm 512 512 512

# 只运行指定阶段
./sgemm 512 512 512 6


本页内容