?gemm_batch
批量一般矩阵乘矩阵,此函数描述如下:

op(X)可取值:
,alpha,beta为乘法系数,op(A)为m*k矩阵,op(B)为k*n矩阵,C为m*n矩阵。
注:batchgemm适用于矩阵规格小且batch规模大的场景,不适用矩阵大且batch小的场景
接口定义
C interface:
void cblas_sgemm_batch(const enum CBLAS_ORDER order, const enum CBLAS_TRANSPOSE *transA_array,
const enum CBLAS_TRANSPOSE *transB_array, const BLASINT *m_array, const BLASINT *n_array, const BLASINT *k_array,
const float *alpha_array, const float **a_array, const BLASINT *lda_array, const float **b_array,
const BLASINT *ldb_array, const float *beta_array, float **c_array, const BLASINT *ldc_array,
const BLASINT group_count, const BLASINT *group_size);
void cblas_dgemm_batch(const enum CBLAS_ORDER order, const enum CBLAS_TRANSPOSE *transA_array,
const enum CBLAS_TRANSPOSE *transB_array, const BLASINT *m_array, const BLASINT *n_array, const BLASINT *k_array,
const double *alpha_array, const double **a_array, const BLASINT *lda_array, const double **b_array,
const BLASINT *ldb_array, const double *beta_array, double **c_array, const BLASINT *ldc_array,
const BLASINT group_count, const BLASINT *group_size);
void cblas_cgemm_batch(const enum CBLAS_ORDER order, const enum CBLAS_TRANSPOSE *transA_array,
const enum CBLAS_TRANSPOSE *transB_array, const BLASINT *m_array, const BLASINT *n_array, const BLASINT *k_array,
const float *alpha_array, const float **a_array, const BLASINT *lda_array, const float **b_array,
const BLASINT *ldb_array, const float *beta_array, float **c_array, const BLASINT *ldc_array,
const BLASINT group_count, const BLASINT *group_size);
void cblas_zgemm_batch(const enum CBLAS_ORDER order, const enum CBLAS_TRANSPOSE *transA_array,
const enum CBLAS_TRANSPOSE *transB_array, const BLASINT *m_array, const BLASINT *n_array, const BLASINT *k_array,
const double *alpha_array, const double **a_array, const BLASINT *lda_array, const double **b_array,
const BLASINT *ldb_array, const double *beta_array, double **c_array, const BLASINT *ldc_array,
const BLASINT group_count, const BLASINT *group_size);
void cblas_hgemm_batch(const enum CBLAS_ORDER order, const enum CBLAS_TRANSPOSE *transA_array,
const enum CBLAS_TRANSPOSE *transB_array, const BLASINT *m_array, const BLASINT *n_array, const BLASINT *k_array,
const __fp16 *alpha_array, const __fp16 **a_array, const BLASINT *lda_array, const __fp16 **b_array,
const BLASINT *ldb_array, const __fp16 *beta_array, __fp16 **c_array, const BLASINT *ldc_array,
const BLASINT group_count, const BLASINT *group_size);
void cblas_bgemm_batch(const enum CBLAS_ORDER order, const enum CBLAS_TRANSPOSE *transA_array,
const enum CBLAS_TRANSPOSE *transB_array, const BLASINT *m_array, const BLASINT *n_array, const BLASINT *k_array,
const __bf16 *alpha_array, const __bf16 **a_array, const BLASINT *lda_array, const __bf16 **b_array,
const BLASINT *ldb_array, const __bf16 *beta_array, __bf16 **c_array, const BLASINT *ldc_array,
const BLASINT group_count, const BLASINT *group_size);
参数
参数名 |
类型 |
描述 |
输入/输出 |
|---|---|---|---|
order |
枚举类型CBLAS_ORDER |
表示矩阵是行主序或列主序。 |
输入 |
transA_array |
枚举类型CBLAS_TRANSPOSE |
表示矩阵A转置情况的序列。 矩阵A为常规矩阵,转置矩阵或共轭矩阵。
|
输入 |
transB_array |
枚举类型CBLAS_TRANSPOSE |
表示矩阵B转置情况的序列。 矩阵B为常规矩阵,转置矩阵或共轭矩阵。
|
输入 |
m_array |
整型数 |
矩阵op(A)和矩阵C的行数的序列。 |
输入 |
n_array |
整型数 |
矩阵op(B)和矩阵C的列数的序列。 |
输入 |
k_array |
整型数 |
矩阵op(A)的列和矩阵op(B)的行数的序列。 |
输入 |
alpha_array |
|
乘法系数的序列。 |
输入 |
a_array |
|
矩阵A的序列。 |
输入 |
lda_array |
整型数 |
表示矩阵A的leading dimension的序列。
|
输入 |
b_array |
|
矩阵B的序列。 |
输入 |
ldb_array |
整型数 |
表示矩阵A的leading dimension的序列。
|
输入 |
beta_array |
|
乘法系数的序列。 |
输入 |
c_array |
|
矩阵C的序列。 |
输入/输出 |
ldc_array |
整型数 |
表示矩阵C的leading dimension的序列。 矩阵为列存,ldc至少max(1, m),否则max(1, n)。 |
输入 |
group_count |
整型数 |
指定group的数量。必须至少为0。 |
输入 |
group_size |
整型数 |
大小为 group_count 的数组。元素 group_size[i] 指定了第 i 组中的矩阵数量。 |
输入 |
依赖
#include "kblas.h"
示例
int m = 2, k = 3, n = 2, lda = 2, ldb = 3, ldc = 2;
int group_count = 2;
int group_size[2] = {1, 1};
int m_array[2] = {2, 2};
int k_array[2] = {3, 3};
int n_array[2] = {2, 2};
int lda_array[2] = {2, 2};
int ldb_array[2] = {3, 3};
int ldc_array[2] = {2, 2};
CBLAS_TRANSPOSE transA_array[2] = {CblasNoTrans, CblasNoTrans};
CBLAS_TRANSPOSE transB_array[2] = {CblasNoTrans, CblasNoTrans};
float alpha_array[2] = {1.0, 1.0};
float beta_array[2] = {2.0, 2.0};
/*
* A:
* 0.340188, 0.411647, -0.222225,
* -0.105617, -0.302449, 0.053970,
* 0.283099, -0.164777, -0.022603,
* 0.298440, 0.268230, 0.128871,
* B:
* -0.135216, 0.416195,
* 0.013401, 0.135712,
* 0.452230, 0.217297,
* -0.358397, -0.257113,
* 0.106969, -0.362768,
* -0.483699, 0.304177
* C:
* -0.343321, 0.498924, 0.112640, -0.006417,
* -0.099056, -0.281743, -0.203968, 0.472775
*/
float a_data[12] = { 0.340188, -0.105617, 0.283099, 0.298440, 0.411647, -0.302449,
-0.164777, 0.268230, -0.222225, 0.053970, -0.022603, 0.128871 };
float b_data[12] = { -0.135216, 0.416195, 0.013401, 0.135712, 0.452230, 0.217297,
-0.358397, -0.257113, 0.106969, -0.362768, -0.483699, 0.304177 };
float c_data[8] = { -0.343321, -0.099056, -0.370210, -0.391191, 0.498924, -0.281743, 0.012932, 0.339112 };
float* a_array[2];
float* b_array[2];
float* c_array[2];
a_array[0] = a_data;
a_array[1] = a_data + m*k*group_size[0];
b_array[0] = b_data;
b_array[1] = b_data + k*n*group_size[0];
c_array[0] = c_data;
c_array[1] = c_data + m*n*group_size[0];
cblas_sgemm_batch(CblasColMajor, transA_array, transB_array, m_array, n_array, k_array, alpha_array, a_array,
lda_array, b_array, ldb_array, beta_array, c_array, ldc_array, group_count, group_size);







