鲲鹏社区首页
中文
注册
开发者
我要评分
获取效率
正确性
完整性
易理解
在线提单
论坛求助

?gemm_pack

对矩阵进行pack操作并将其存储到已分配的缓冲区中。

即:

op(X)可取值:,alpha,beta为乘法系数,op(A)为m*k矩阵,op(B)为k*n矩阵,C为m*n矩阵。

接口定义

C interface:

void cblas_sgemm_pack(const enum CBLAS_ORDER order, const enum CBLAS_IDENTIFIER identifier, const enum CBLAS_TRANSPOSE trans, const BLASINT m, const BLASINT n, const BLASINT k, const float *src, const BLASINT ld, float *dst);

void cblas_dgemm_pack(const enum CBLAS_ORDER order, const enum CBLAS_IDENTIFIER identifier, const enum CBLAS_TRANSPOSE trans, const BLASINT m, const BLASINT n, const BLASINT k, const double *src, const BLASINT ld, double *dst);

void cblas_cgemm_pack(const enum CBLAS_ORDER order, const enum CBLAS_IDENTIFIER identifier, const enum CBLAS_TRANSPOSE trans, const BLASINT m, const BLASINT n, const BLASINT k, const void *src, const BLASINT ld, void *dst);

void cblas_zgemm_pack(const enum CBLAS_ORDER order, const enum CBLAS_IDENTIFIER identifier, const enum CBLAS_TRANSPOSE trans, const BLASINT m, const BLASINT n, const BLASINT k, const void *src, const BLASINT ld, void *dst);

void cblas_bgemm_pack(const enum CBLAS_ORDER order, const enum CBLAS_IDENTIFIER identifier, const enum CBLAS_TRANSPOSE trans, const BLASINT m, const BLASINT n, const BLASINT k, const __bf16 *src, const BLASINT ld, __bf16 *dst);

参数

参数名

类型

描述

输入/输出

order

枚举类型CBLAS_ORDER

表示矩阵是行主序或列主序。

输入

identifier

枚举类型CBLAS_IDENTIFIER

指定要pack的矩阵。

  • 如果identifier= CblasA,pack A矩阵。
  • 如果identifier= CblasB,pack B矩阵。

输入

trans

枚举类型CBLAS_TRANSPOSE

矩阵A为常规矩阵,转置矩阵。

  • 如果trans= CblasNoTrans,
  • 如果trans= CblasTrans,

矩阵B为常规矩阵,转置矩阵。

  • 如果trans= CblasNoTrans,
  • 如果trans= CblasTrans,

输入

m

整型数

矩阵op(A)和矩阵C的行。

输入

n

整型数

矩阵op(B)和矩阵C的列。

输入

k

整型数

矩阵op(A)的列和矩阵op(B)的行。

输入

src

  • 在sgemm中是单精度浮点类型。
  • 在dgemm中是双精度浮点类型。
  • 在cgemm中是单精度复数类型。
  • 在zgemm中是双精度复数类型。
  • 在bgemm中是半精度浮点类型。

矩阵A\B。

输入

ld

整型数

  • 矩阵为列存,identifier = CblasA, trans = CblasNoTrans,ld至少max(1, m),否则max(1, k)。
  • 矩阵为行存,identifier = CblasA, trans = CblasNoTrans,ld至少max(1, k),否则max(1, m)。
  • 矩阵为列存,identifier = CblasB, trans = CblasNoTrans,ld至少max(1, k),否则max(1, n)。
  • 矩阵为行存,identifier = CblasB, trans = CblasNoTrans,ld至少max(1, n),否则max(1, k)。

输入

dst

  • 在sgemm中是单精度浮点类型。
  • 在dgemm中是双精度浮点类型。
  • 在cgemm中是单精度复数类型。
  • 在zgemm中是双精度复数类型。
  • 在bgemm中是半精度浮点类型。

pack后的矩阵A\B。

输入/输出

依赖

#include "kblas.h"

示例

C interface:

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
    int m = 4, k = 3, n = 4, ld = 4; 
     /* 
     * src: 
     *     0.340188,       0.411647,       -0.222225, 
     *     -0.105617,      -0.302449,      0.053970, 
     *     0.283099,       -0.164777,      -0.022603, 
     *     0.298440,       0.268230,       0.128871, 
     */ 
    float src[12] = {0.340188, -0.105617, 0.283099, 
                    0.298440, 0.411647, -0.302449, 
                    -0.164777, 0.268230, -0.222225, 
                    0.053970, -0.022603, 0.128871}; 
    size_t size = cblas_sgemm_pack_get_size(CblasA, m, n, k);
    float *dst = (float *)malloc(size);
    cblas_sgemm_pack(CblasColMajor, CblasA, CblasTrans, m, n, k, src, ld, dst); 
    free(dst);
    /* 
     * Output dst: 
     *     0.340188       -0.105617        0.283099        0.298440 
     *     0.411647       -0.302449       -0.164777        0.268230 
     *    -0.222225        0.053970       -0.022603        0.128871
     */