鲲鹏社区首页
中文
注册
开发者
我要评分
获取效率
正确性
完整性
易理解
在线提单
论坛求助

?gemm_batch

批量一般矩阵乘矩阵,此函数描述如下:

op(X)可取值:,alpha,beta为乘法系数,op(A)为m*k矩阵,op(B)为k*n矩阵,C为m*n矩阵。

:batchgemm适用于矩阵规格小且batch规模大的场景,不适用矩阵大且batch小的场景

接口定义

C interface:

void cblas_sgemm_batch(const enum CBLAS_ORDER order, const enum CBLAS_TRANSPOSE *transA_array,

const enum CBLAS_TRANSPOSE *transB_array, const BLASINT *m_array, const BLASINT *n_array, const BLASINT *k_array,

const float *alpha_array, const float **a_array, const BLASINT *lda_array, const float **b_array,

const BLASINT *ldb_array, const float *beta_array, float **c_array, const BLASINT *ldc_array,

const BLASINT group_count, const BLASINT *group_size);

void cblas_dgemm_batch(const enum CBLAS_ORDER order, const enum CBLAS_TRANSPOSE *transA_array,

const enum CBLAS_TRANSPOSE *transB_array, const BLASINT *m_array, const BLASINT *n_array, const BLASINT *k_array,

const double *alpha_array, const double **a_array, const BLASINT *lda_array, const double **b_array,

const BLASINT *ldb_array, const double *beta_array, double **c_array, const BLASINT *ldc_array,

const BLASINT group_count, const BLASINT *group_size);

void cblas_cgemm_batch(const enum CBLAS_ORDER order, const enum CBLAS_TRANSPOSE *transA_array,

const enum CBLAS_TRANSPOSE *transB_array, const BLASINT *m_array, const BLASINT *n_array, const BLASINT *k_array,

const float *alpha_array, const float **a_array, const BLASINT *lda_array, const float **b_array,

const BLASINT *ldb_array, const float *beta_array, float **c_array, const BLASINT *ldc_array,

const BLASINT group_count, const BLASINT *group_size);

void cblas_zgemm_batch(const enum CBLAS_ORDER order, const enum CBLAS_TRANSPOSE *transA_array,

const enum CBLAS_TRANSPOSE *transB_array, const BLASINT *m_array, const BLASINT *n_array, const BLASINT *k_array,

const double *alpha_array, const double **a_array, const BLASINT *lda_array, const double **b_array,

const BLASINT *ldb_array, const double *beta_array, double **c_array, const BLASINT *ldc_array,

const BLASINT group_count, const BLASINT *group_size);

void cblas_hgemm_batch(const enum CBLAS_ORDER order, const enum CBLAS_TRANSPOSE *transA_array,

const enum CBLAS_TRANSPOSE *transB_array, const BLASINT *m_array, const BLASINT *n_array, const BLASINT *k_array,

const __fp16 *alpha_array, const __fp16 **a_array, const BLASINT *lda_array, const __fp16 **b_array,

const BLASINT *ldb_array, const __fp16 *beta_array, __fp16 **c_array, const BLASINT *ldc_array,

const BLASINT group_count, const BLASINT *group_size);

void cblas_bgemm_batch(const enum CBLAS_ORDER order, const enum CBLAS_TRANSPOSE *transA_array,

const enum CBLAS_TRANSPOSE *transB_array, const BLASINT *m_array, const BLASINT *n_array, const BLASINT *k_array,

const __bf16 *alpha_array, const __bf16 **a_array, const BLASINT *lda_array, const __bf16 **b_array,

const BLASINT *ldb_array, const __bf16 *beta_array, __bf16 **c_array, const BLASINT *ldc_array,

const BLASINT group_count, const BLASINT *group_size);

参数

参数名

类型

描述

输入/输出

order

枚举类型CBLAS_ORDER

表示矩阵是行主序或列主序。

输入

transA_array

枚举类型CBLAS_TRANSPOSE

表示矩阵A转置情况的序列。

矩阵A为常规矩阵,转置矩阵或共轭矩阵。

  • 如果TransA = CblasNoTrans,
  • 如果TransA = CblasTrans,
  • 如果TransA = CblasConjTrans,
  • 如果TransA = CblasConjTrans,

输入

transB_array

枚举类型CBLAS_TRANSPOSE

表示矩阵B转置情况的序列。

矩阵B为常规矩阵,转置矩阵或共轭矩阵。

  • 如果TransB = CblasNoTrans,
  • 如果TransB = CblasTrans,
  • 如果TransB = CblasConjTrans,
  • 如果TransB = CblasConjTrans,

输入

m_array

整型数

矩阵op(A)和矩阵C的行数的序列。

输入

n_array

整型数

矩阵op(B)和矩阵C的列数的序列。

输入

k_array

整型数

矩阵op(A)的列和矩阵op(B)的行数的序列。

输入

alpha_array

  • 在sgemm中是单精度浮点类型。
  • 在dgemm中是双精度浮点类型。
  • 在hgemm中是fp16类型。
  • 在bgemm中是bf16类型。

乘法系数的序列。

输入

a_array

  • 在sgemm中是单精度浮点类型。
  • 在dgemm中是双精度浮点类型。
  • 在hgemm中是fp16类型。
  • 在bgemm中是bf16类型。

矩阵A的序列。

输入

lda_array

整型数

表示矩阵A的leading dimension的序列。

  • 矩阵为列存,TransA = CblasNoTrans,lda至少max(1, m),否则max(1, k)。
  • 矩阵为行存,TransA = CblasNoTrans,lda至少max(1, k),否则max(1, m)。

输入

b_array

  • 在sgemm中是单精度浮点类型。
  • 在dgemm中是双精度浮点类型。
  • 在hgemm中是fp16类型。
  • 在bgemm中是bf16类型。

矩阵B的序列。

输入

ldb_array

整型数

表示矩阵A的leading dimension的序列。

  • 矩阵为列存,TransB = CblasNoTrans,ldb至少max(1, k),否则max(1, n)。
  • 矩阵为行存,TransB = CblasNoTrans,ldb至少max(1, n),否则max(1, k)。

输入

beta_array

  • 在sgemm中是单精度浮点类型。
  • 在dgemm中是双精度浮点类型。
  • 在hgemm中是fp16类型。
  • 在bgemm中是bf16类型。

乘法系数的序列。

输入

c_array

  • 在sgemm中是单精度浮点类型。
  • 在dgemm中是双精度浮点类型。
  • 在hgemm中是fp16类型。
  • 在bgemm中是bf16类型。

矩阵C的序列。

输入/输出

ldc_array

整型数

表示矩阵C的leading dimension的序列。

矩阵为列存,ldc至少max(1, m),否则max(1, n)。

输入

group_count

整型数

指定group的数量。必须至少为0。

输入

group_size

整型数

大小为 group_count 的数组。元素 group_size[i] 指定了第 i 组中的矩阵数量。

输入

依赖

#include "kblas.h"

示例

    int m = 2, k = 3, n = 2, lda = 2, ldb = 3, ldc = 2; 
    int group_count = 2;
    int group_size[2] = {1, 1};
    int m_array[2] = {2, 2};
    int k_array[2] = {3, 3};
    int n_array[2] = {2, 2};
    int lda_array[2] = {2, 2};
    int ldb_array[2] = {3, 3};
    int ldc_array[2] = {2, 2};
    CBLAS_TRANSPOSE transA_array[2] = {CblasNoTrans, CblasNoTrans};
    CBLAS_TRANSPOSE transB_array[2] = {CblasNoTrans, CblasNoTrans};
    float alpha_array[2] = {1.0, 1.0};
    float beta_array[2] = {2.0, 2.0}; 
     /* 
     * A: 
     *     0.340188,       0.411647,       -0.222225, 
     *     -0.105617,      -0.302449,      0.053970, 
     *     0.283099,       -0.164777,      -0.022603, 
     *     0.298440,       0.268230,       0.128871, 
     * B: 
     *     -0.135216,      0.416195,
     *     0.013401,       0.135712,
     *     0.452230,       0.217297,
     *     -0.358397,      -0.257113, 
     *      0.106969,      -0.362768,
     *     -0.483699,      0.304177
     * C: 
     *     -0.343321,      0.498924,       0.112640,       -0.006417, 
     *     -0.099056,      -0.281743,      -0.203968,      0.472775
     */
    float a_data[12] = { 0.340188,  -0.105617, 0.283099,  0.298440, 0.411647,  -0.302449,
                         -0.164777, 0.268230,  -0.222225, 0.053970, -0.022603, 0.128871 };
    float b_data[12] = { -0.135216, 0.416195,  0.013401, 0.135712,  0.452230,  0.217297,
                         -0.358397, -0.257113, 0.106969, -0.362768, -0.483699, 0.304177 };
    float c_data[8] = { -0.343321, -0.099056, -0.370210, -0.391191, 0.498924, -0.281743, 0.012932, 0.339112 };
    float* a_array[2];
    float* b_array[2];
    float* c_array[2];
    a_array[0] = a_data;
    a_array[1] = a_data + m*k*group_size[0];
    b_array[0] = b_data;
    b_array[1] = b_data + k*n*group_size[0];
    c_array[0] = c_data;
    c_array[1] = c_data + m*n*group_size[0];
    cblas_sgemm_batch(CblasColMajor, transA_array, transB_array, m_array, n_array, k_array, alpha_array, a_array,
                      lda_array, b_array, ldb_array, beta_array, c_array, ldc_array, group_count, group_size);