kutacc_core_bgemm_pack
对矩阵进行pack操作并将其存储到已分配的缓冲区中。
接口定义
void kutacc_core_bgemm_pack(char matrixIdentifier, char transA, char transB, const BLASINT m, const BLASINT n, const BLASINT k, const BLASINT lda, const BLASINT ldb, const __bf16 *src, __bf16 *dst);
参数
参数名 |
类型 |
描述 |
输入/输出 |
|---|---|---|---|
matrixIdentifier |
字符型 |
指定要pack的矩阵,默认为矩阵B。
|
输入 |
transA |
字符型 |
矩阵A的转置情况。
|
输入 |
transB |
字符型 |
矩阵B的转置情况。
|
输入 |
m |
整数型 |
矩阵A的行。 |
输入 |
n |
整数型 |
矩阵B的列。 |
输入 |
k |
整数型 |
矩阵A的列和B的行。 |
输入 |
lda |
整数型 |
矩阵A的主维度。 |
输入 |
ldb |
整数型 |
矩阵B的主维度。 |
输入 |
src |
半精度浮点类型指针 |
输入矩阵。 |
输入 |
dst |
半精度浮点类型指针 |
pack后的矩阵。 |
输入/输出 |
依赖
#include "kutacc_core.h"
示例
C interface:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 | char transA = 'N', transB = 'N'; BLASINT m = 4, k = 3, n = 4; BLASINT lda = m; BLASINT ldb = n; __bf16 a[12] = { 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0, 1.1, 1.2 }; size_t size_a = kutacc_core_bgemm_pack_get_size('A', m, n, k); __bf16 *sa = (__bf16 *)malloc(size_a * sizeof(__bf16)); kutacc_core_bgemm_pack('A', transA, transB, m, n, k, lda, ldb, a, sa); free(sa); /* * Output dst: * 0.100098 0.601562 0.400391 1.000000 * 0.500000 0.300781 0.800781 1.101562 * 0.200195 0.699219 0.898438 1.203125 */ |
父主题: GEMM算子