鲲鹏社区首页
中文
注册
开发者
我要评分
获取效率
正确性
完整性
易理解
在线提单
论坛求助

kutacc_core_bgemm_pack

对矩阵进行pack操作并将其存储到已分配的缓冲区中。

接口定义

void kutacc_core_bgemm_pack(char matrixIdentifier, char transA, char transB, const BLASINT m, const BLASINT n, const BLASINT k, const BLASINT lda, const BLASINT ldb, const __bf16 *src, __bf16 *dst);

参数

参数名

类型

描述

输入/输出

matrixIdentifier

字符型

指定要pack的矩阵,默认为矩阵B。

  • 如果identifier= 'A',pack 矩阵A。
  • 如果identifier= 'B',pack 矩阵B。

输入

transA

字符型

矩阵A的转置情况。

  • 如果transA= 'N',矩阵A不转置。
  • 如果transA= 'T',矩阵A转置。

输入

transB

字符型

矩阵B的转置情况。

  • 如果transB= 'N',矩阵B不转置。
  • 如果transB= 'T',矩阵B转置。

输入

m

整数型

矩阵A的行。

输入

n

整数型

矩阵B的列。

输入

k

整数型

矩阵A的列和B的行。

输入

lda

整数型

矩阵A的主维度。

输入

ldb

整数型

矩阵B的主维度。

输入

src

半精度浮点类型指针

输入矩阵。

输入

dst

半精度浮点类型指针

pack后的矩阵。

输入/输出

依赖

#include "kutacc_core.h"

示例

C interface:

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
    char transA = 'N', transB = 'N';
    BLASINT m = 4, k = 3, n = 4;
    BLASINT lda = m;
    BLASINT ldb = n;

    __bf16 a[12] = {
        0.1, 0.2, 0.3, 0.4,
        0.5, 0.6, 0.7, 0.8,
        0.9, 1.0, 1.1, 1.2
    };
    size_t size_a = kutacc_core_bgemm_pack_get_size('A', m, n, k);
    __bf16 *sa = (__bf16 *)malloc(size_a * sizeof(__bf16));
    kutacc_core_bgemm_pack('A', transA, transB, m, n, k, lda, ldb, a, sa);
    free(sa);
    /* 
     * Output dst: 
     *     0.100098       0.601562        0.400391        1.000000 
     *     0.500000       0.300781        0.800781        1.101562
     *     0.200195       0.699219        0.898438        1.203125
     */