鲲鹏社区首页
中文
注册
开发者
我要评分
获取效率
正确性
完整性
易理解
在线提单
论坛求助

?gemm_compute

对pack后的A\B矩阵进行一般矩阵乘。

即:

op(X)可取值:,alpha,beta为乘法系数,op(A)为m*k矩阵,op(B)为k*n矩阵,C为m*n矩阵。

接口定义

C interface:

void cblas_sgemm_compute(const enum CBLAS_ORDER order, const enum CBLAS_TRANSPOSE transA, const enum CBLAS_TRANSPOSE transB, const BLASINT m, const BLASINT n, const BLASINT k, const float alpha, const float *a, const BLASINT lda, const float *b, const BLASINT ldb, const float beta, float *c, const BLASINT ldc);

void cblas_dgemm_compute(const enum CBLAS_ORDER order, const enum CBLAS_TRANSPOSE transA, const enum CBLAS_TRANSPOSE transB, const BLASINT m, const BLASINT n, const BLASINT k, const double alpha, const double *a, const BLASINT lda, const double *b, const BLASINT ldb, const double beta, double *c, const BLASINT ldc);

void cblas_cgemm_compute(const enum CBLAS_ORDER order, const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_TRANSPOSE TransB, const BLASINT m, const BLASINT n, const BLASINT k, const void *alpha, const void *a, const BLASINT lda, const void *b, const BLASINT ldb, const void *beta, void *c, const BLASINT ldc);

void cblas_zgemm_compute(const enum CBLAS_ORDER order, const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_TRANSPOSE TransB, const BLASINT m, const BLASINT n, const BLASINT k, const void *alpha, const void *a, const BLASINT lda, const void *b, const BLASINT ldb, const void *beta, void *c, const BLASINT ldc);

void cblas_bgemm_compute(const enum CBLAS_ORDER order, const enum CBLAS_TRANSPOSE transA, const enum CBLAS_TRANSPOSE transB, const BLASINT m, const BLASINT n, const BLASINT k, const __bf16 alpha, const __bf16 *a, const BLASINT lda, const __bf16 *b, const BLASINT ldb, const __bf16 beta, __bf16 *c, const BLASINT ldc);

参数

参数名

类型

描述

输入/输出

order

枚举类型CBLAS_ORDER

表示矩阵是行主序或列主序。

输入

TransA

枚举类型CBLAS_TRANSPOSE

矩阵A为常规矩阵,转置矩阵或共轭矩阵。

  • 如果trans= CblasNoTrans,
  • 如果trans= CblasTrans,

输入

TransB

枚举类型CBLAS_TRANSPOSE

矩阵B为常规矩阵,转置矩阵或共轭矩阵。

  • 如果trans= CblasNoTrans,
  • 如果trans= CblasTrans,

输入

M

整型数

矩阵op(A)和矩阵C的行。

输入

N

整型数

矩阵op(B)和矩阵C的列。

输入

K

整型数

矩阵op(A)的列和矩阵op(B)的行。

输入

alpha

  • 在sgemm中是单精度浮点类型。
  • 在dgemm中是双精度浮点类型。
  • 在cgemm中是单精度复数类型。
  • 在zgemm中是双精度复数类型。
  • 在bgemm中是半精度浮点类型。

乘法系数。

输入

A

  • 在sgemm中是单精度浮点类型。
  • 在dgemm中是双精度浮点类型。
  • 在cgemm中是单精度复数类型。
  • 在zgemm中是双精度复数类型。
  • 在bgemm中是半精度浮点类型。

矩阵A。

输入

lda

整型数

  • 矩阵为列存,TransA = CblasNoTrans,lda至少max(1, m),否则max(1, k)。
  • 矩阵为行存,TransA = CblasNoTrans,lda至少max(1, k),否则max(1, m)。

输入

B

  • 在sgemm中是单精度浮点类型。
  • 在dgemm中是双精度浮点类型。
  • 在cgemm中是单精度复数类型。
  • 在zgemm中是双精度复数类型。
  • 在bgemm中是半精度浮点类型。

矩阵B。

输入

ldb

整型数

  • 矩阵为列存,TransB = CblasNoTrans,ldb至少max(1, k),否则max(1, n)。
  • 矩阵为行存,TransB = CblasNoTrans,ldb至少max(1, n),否则max(1, k)。

输入

beta

  • 在sgemm中是单精度浮点类型。
  • 在dgemm中是双精度浮点类型。
  • 在cgemm中是单精度复数类型。
  • 在zgemm中是双精度复数类型。
  • 在bgemm中是半精度浮点类型。

乘法系数。

输入

C

  • 在sgemm中是单精度浮点类型。
  • 在dgemm中是双精度浮点类型。
  • 在cgemm中是单精度复数类型。
  • 在zgemm中是双精度复数类型。
  • 在bgemm中是半精度浮点类型。

矩阵C。

输入/输出

ldc

整型数

矩阵为列存,ldc至少max(1, m),否则max(1, n)。

输入

依赖

#include "kblas.h"

示例

C interface:

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
    int m = 4, k = 3, n = 4, lda = 4, ldb = 3, ldc = 4; 
    float alpha = 1.0, beta = 2.0; 
     /* 
     * A: 
     *     0.340188,       0.411647,       -0.222225, 
     *     -0.105617,      -0.302449,      0.053970, 
     *     0.283099,       -0.164777,      -0.022603, 
     *     0.298440,       0.268230,       0.128871, 
     * B: 
     *     -0.135216,      0.416195,       -0.358397,      -0.257113, 
     *     0.013401,       0.135712,       0.106969,       -0.362768, 
     *     0.452230,       0.217297,       -0.483699,      0.304177, 
     * C: 
     *     -0.343321,      0.498924,       0.112640,       -0.006417, 
     *     -0.099056,      -0.281743,      -0.203968,      0.472775, 
     *     -0.370210,      0.012932,       0.137552,       -0.207483, 
     *     -0.391191,      0.339112,       0.024287,       0.271358, 
     */ 
    float a[12] = {0.340188, -0.105617, 0.283099, 
                    0.298440, 0.411647, -0.302449, 
                    -0.164777, 0.268230, -0.222225, 
                    0.053970, -0.022603, 0.128871}; 
    float b[12] = {-0.135216, 0.013401, 0.452230, 0.416195, 
                    0.135712, 0.217297, -0.358397, 0.106969, 
                    -0.483699, -0.257113, -0.362768, 0.304177}; 
    float c[16] = {-0.343321, -0.099056, -0.370210, -0.391191, 
                    0.498924, -0.281743, 0.012932, 0.339112, 
                    0.112640, -0.203968, 0.137552, 0.024287, 
                    -0.006417, 0.472775, -0.207483, 0.271358}; 
    size_t size_a = cblas_sgemm_pack_get_size(CblasA, m, n, k);
    size_t size_b = cblas_sgemm_pack_get_size(CblasB, m, n, k);
    float *sa = (float *)malloc(size_a);
    float *sb = (float *)malloc(size_b);
    cblas_sgemm_pack(CblasColMajor, CblasA, CblasNoTrans, m, n, k, a, ld, sa);
    cblas_sgemm_pack(CblasColMajor, CblasB, CblasNoTrans, m, n, k, b, ld, sb);
    cblas_sgemm_compute(CblasColMajor,CblasNoTrans,CblasNoTrans, m, n, k, alpha, sa, lda, sb, ldb, beta, c, ldc); 
    free(sa);
    free(sb);
    /* 
     * Output C: 
     *     -0.827621       1.147010        0.254881        -0.317229 
     *     -0.163476       -0.636762       -0.428542       1.098841 
     *     -0.791128       0.116416        0.166949        -0.434854 
     *     -0.760862       0.866839        -0.092028       0.407877 
     * 
     */