kutacc_core_bgemm_ex
一般矩阵乘矩阵。
即:
。
接口定义
void kutacc_core_bgemm_ex(char transA, char transB, const BLASINT m, const BLASINT n, const BLASINT k, const __bf16 alpha, const __bf16 *a, const BLASINT lda, const __bf16 *b, const BLASINT ldb, const __bf16 beta, __bf16 *c, const BLASINT ldc, const BlasExtendParam *blas_extend_param);
参数
参数名 |
类型 |
描述 |
输入/输出 |
|---|---|---|---|
transA |
字符型 |
矩阵A的转置情况。
|
输入 |
transB |
字符型 |
矩阵B的转置情况。
|
输入 |
m |
整数型 |
矩阵A的行。 |
输入 |
n |
整数型 |
矩阵B的列。 |
输入 |
k |
整数型 |
矩阵A的列和B的行。 |
输入 |
alpha |
半精度浮点类型 |
矩阵A、B乘法系数。 |
输入 |
a |
半精度浮点类型指针 |
矩阵A。 |
输入 |
lda |
整数型 |
矩阵A的主维度。 |
输入 |
b |
半精度浮点类型指针 |
矩阵B。 |
输入 |
ldb |
整数型 |
矩阵B的主维度。 |
输入 |
beta |
半精度浮点类型 |
矩阵C乘法系数。 |
输入 |
c |
半精度浮点类型指针 |
矩阵C。 |
输入 |
ldc |
半精度浮点类型 |
矩阵C的主维度。 |
输入/输出 |
blas_extend_param |
结构体指针类型 |
设置拓展操作,结构体中type为操作类型,extra为操作数据,next为指向下一个拓展操作结构体的指针。
|
输入 |
依赖
#include "kutacc_core.h"
示例
C interface:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 | char transA = 'N', transB = 'N'; BLASINT m = 4, k = 3, n = 4; BLASINT lda = m; BLASINT ldb = k; BLASINT ldc = m; __bf16 alpha = vcvth_bf16_f32(1.0); __bf16 beta = vcvth_bf16_f32(3.0); __bf16 a[12] = { 3.4019, -1.0562, 2.8310, 2.9844, 4.1165, -3.0245, -1.6478, 2.6829, -2.2223, 0.5397, -0.2260, 1.2887 }; __bf16 b[12] = { -1.3522, 0.1340, 4.5223, 4.1612, 1.3571, 2.1730, -3.5840, 1.0697, -4.8370, -2.5711, -3.6277, 3.0418 }; __bf16 c[12] = {0}; BlasExtendParam *extend_param = (BlasExtendParam *)malloc(sizeof(BlasExtendParam)); if (extend_param == NULL) { return; } extend_param->type = BLAS_EXTEND_TYPE_ACTIVATION; extend_param->extra = BLAS_EXTEND_ACTIVATION_RELU; extend_param->next = NULL; kutacc_core_bgemm_ex(transA, transB, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, extend_param); free(extend_param); /* * Output c: * 0.000000 43.250000 2.968750 0.000000 * 10.437500 0.000000 0.000000 15.750000 * 0.000000 32.500000 0.000000 11.625000 * 6.531250 43.750000 0.000000 0.000000 */ |
父主题: GEMM算子