鲲鹏社区首页
中文
注册
开发者
我要评分
获取效率
正确性
完整性
易理解
在线提单
论坛求助

kutacc_core_bgemm_ex

一般矩阵乘矩阵。

即:

接口定义

void kutacc_core_bgemm_ex(char transA, char transB, const BLASINT m, const BLASINT n, const BLASINT k, const __bf16 alpha, const __bf16 *a, const BLASINT lda, const __bf16 *b, const BLASINT ldb, const __bf16 beta, __bf16 *c, const BLASINT ldc, const BlasExtendParam *blas_extend_param);

参数

参数名

类型

描述

输入/输出

transA

字符型

矩阵A的转置情况。

  • 如果transA= 'N',矩阵A不转置。
  • 如果transA= 'T',矩阵A转置。

输入

transB

字符型

矩阵B的转置情况。

  • 如果transB= 'N',矩阵B不转置。
  • 如果transB= 'T',矩阵B转置。

输入

m

整数型

矩阵A的行。

输入

n

整数型

矩阵B的列。

输入

k

整数型

矩阵A的列和B的行。

输入

alpha

半精度浮点类型

矩阵A、B乘法系数。

输入

a

半精度浮点类型指针

矩阵A。

输入

lda

整数型

矩阵A的主维度。

输入

b

半精度浮点类型指针

矩阵B。

输入

ldb

整数型

矩阵B的主维度。

输入

beta

半精度浮点类型

矩阵C乘法系数。

输入

c

半精度浮点类型指针

矩阵C。

输入

ldc

半精度浮点类型

矩阵C的主维度。

输入/输出

blas_extend_param

结构体指针类型

设置拓展操作,结构体中type为操作类型,extra为操作数据,next为指向下一个拓展操作结构体的指针。

  • 如果type= BLAS_EXTEND_TYPE_NUM_THREADS,操作为设置线程数。
  • 如果transB= BLAS_EXTEND_TYPE_PREPACK,操作为对prepack后的矩阵直接计算。
  • 如果transB= BLAS_EXTEND_TYPE_BIAS,操作为对结果矩阵添加bias。
  • 如果transB= BLAS_EXTEND_TYPE_ACTIVATION,操作为对结果函数进行激活函数运算。

输入

依赖

#include "kutacc_core.h"

示例

C interface:

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
    char transA = 'N', transB = 'N';
    BLASINT m = 4, k = 3, n = 4;
    BLASINT lda = m;
    BLASINT ldb = k;
    BLASINT ldc = m;
    __bf16 alpha = vcvth_bf16_f32(1.0);
    __bf16 beta = vcvth_bf16_f32(3.0);

    __bf16 a[12] = {
        3.4019, -1.0562, 2.8310, 2.9844,
        4.1165, -3.0245, -1.6478, 2.6829,
        -2.2223, 0.5397, -0.2260, 1.2887
    };

    __bf16 b[12] = {
        -1.3522, 0.1340, 4.5223,
        4.1612, 1.3571, 2.1730,
        -3.5840, 1.0697, -4.8370,
        -2.5711, -3.6277, 3.0418
    };

    __bf16 c[12] = {0};

    BlasExtendParam *extend_param = (BlasExtendParam *)malloc(sizeof(BlasExtendParam));
    if (extend_param == NULL) {
        return;
    }

    extend_param->type = BLAS_EXTEND_TYPE_ACTIVATION;
    extend_param->extra = BLAS_EXTEND_ACTIVATION_RELU;
    extend_param->next = NULL;

    kutacc_core_bgemm_ex(transA, transB, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, extend_param);
    free(extend_param);

    /* 
     * Output c: 
     *     0.000000       43.250000       2.968750        0.000000 
     *     10.437500      0.000000        0.000000        15.750000
     *     0.000000       32.500000       0.000000        11.625000
     *     6.531250       43.750000       0.000000        0.000000
     */