MatMul

场景说明

执行两个输入张量a,b以及参数结构体M,执行两个矩阵的乘法运算,将计算结果存储于输出张量output。例如:

1
2
3
4
5
6
a: [1,2,
    3,4]
b: [5,6,
    7,8]
output: [19,22,
         43,50]

代码示例

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
#include <cmath>
#include <random>
#include <cstdint>
#include <iostream>

#include "ktfop.h"
int main()
{
    using namespace ktfop;
    MatMulParams<float> params;
    params.order = CblasRowMajor;     //矩阵存储顺序(行优先or列优先)
    params.transA = CblasNoTrans;     //A矩阵是否转置
    params.transB = CblasNoTrans;     //B矩阵是否转置
    params.m = 2;                     //A矩阵的行数
    params.n = 2;                     //B矩阵的列数
    params.k = 2;                     //A矩阵的列数或B矩阵的行数    
    params.alpha = 1.0f;              //alpha参数
    params.lda = 2;                   //A矩阵存储步长
    params.ldb = 2;                   //B矩阵存储步长
    params.beta = 0.0f;               //beta参数
    params.ldc = 2;                   //C矩阵存储步长
    float a[4] = {1, 2, 3, 4};        //初始化参与计算的矩阵
    float b[4] = {5, 6, 7, 8};
    float c[4] = {0, 0, 0, 0};        //保存计算结果的矩阵
    int ret = -1;
    ret = Matmul(a, b, c, params);    //调用Matmul算子,结果存储于c数组
    std::cout << "c: [";
    for (int i = 0; i < 4; ++i) {
        std::cout << c[i];
        if (i < 3) {
            std::cout << ", ";
        }
    }
    std::cout << "]" << std::endl;
    ret = Matmul(static_cast<float *>(nullptr), b, c, params); //输入空指针,打印日志"ERROR Parameter verification failed for the MatMul Op."
    ret = Matmul(a, b, static_cast<float *>(nullptr), params); //输入空指针,打印日志"ERROR Parameter verification failed for the MatMul Op."
}