gemm_batch_?8?8s32
批量一般矩阵乘矩阵,此函数描述如下:

op(X)可取值:
,alpha,beta为乘法系数,op(A)为m*k矩阵,op(B)为k*n矩阵,C为m*n矩阵。
A_offset为一个标量,类型固定为int8_t,op(A)中所有元素加上A_offset;
B_offset为一个标量,类型固定为int8_t,op(B)中所有元素加上B_offset;
C_offset有三种定义(元素类型均为int32_t):
- 一个标量,C矩阵的每一个元素都加上C_offset
- 一个行向量,C矩阵的每一行都加上这个行向量C_offset
- 一个列向量,C矩阵的每一列都加上这个列向量C_offset
注:batchgemm适用于矩阵规格小且batch规模大的场景,不适用矩阵大且batch小的场景
接口定义
C interface:
void cblas_gemm_batch_s8s8s32(const CBLAS_LAYOUT layout, const CBLAS_TRANSPOSE *transA_array,
const CBLAS_TRANSPOSE *transB_array, const CBLAS_OFFSET *offsetc_array, const BLASINT *m_array,
const BLASINT *n_array, const BLASINT *k_array, const float *alpha_array, const BLASINT8 **a_array,
const BLASINT *lda_array, const BLASINT8 *oa_array, const BLASINT8 **b_array, const BLASINT *ldb_array,
const BLASINT8 *ob_array, const float *beta, int32_t **c_array, const BLASINT *ldc_array, const int32_t **oc_array,
const BLASINT group_count, const BLASINT *group_size);
void cblas_gemm_batch_u8u8s32(const CBLAS_LAYOUT layout, const CBLAS_TRANSPOSE *transA_array,
const CBLAS_TRANSPOSE *transB_array, const CBLAS_OFFSET *offsetc_array, const BLASINT *m_array,
const BLASINT *n_array, const BLASINT *k_array, const float *alpha_array, const BLASUINT8 **a_array,
const BLASINT *lda_array, const BLASINT8 *oa_array, const BLASUINT8 **b_array, const BLASINT *ldb_array,
const BLASINT8 *ob_array, const float *beta, int32_t **c_array, const BLASINT *ldc_array, const int32_t **oc_array,
const BLASINT group_count, const BLASINT *group_size);
void cblas_gemm_batch_s8u8s32(const CBLAS_LAYOUT layout, const CBLAS_TRANSPOSE *transA_array,
const CBLAS_TRANSPOSE *transB_array, const CBLAS_OFFSET *offsetc_array, const BLASINT *m_array,
const BLASINT *n_array, const BLASINT *k_array, const float *alpha_array, const BLASINT8 **a_array,
const BLASINT *lda_array, const BLASINT8 *oa_array, const BLASUINT8 **b_array, const BLASINT *ldb_array,
const BLASINT8 *ob_array, const float *beta, int32_t **c_array, const BLASINT *ldc_array, const int32_t **oc_array,
const BLASINT group_count, const BLASINT *group_size);
void cblas_gemm_batch_u8s8s32(const CBLAS_LAYOUT layout, const CBLAS_TRANSPOSE *transA_array,
const CBLAS_TRANSPOSE *transB_array, const CBLAS_OFFSET *offsetc_array, const BLASINT *m_array,
const BLASINT *n_array, const BLASINT *k_array, const float *alpha_array, const BLASUINT8 **a_array,
const BLASINT *lda_array, const BLASINT8 *oa_array, const BLASINT8 **b_array, const BLASINT *ldb_array,
const BLASINT8 *ob_array, const float *beta, int32_t **c_array, const BLASINT *ldc_array, const int32_t **oc_array,
const BLASINT group_count, const BLASINT *group_size);
参数
参数名 |
类型 |
描述 |
输入/输出 |
|---|---|---|---|
layout |
枚举类型CBLAS_ORDER |
表示矩阵是行主序或列主序。 |
输入 |
transA_array |
枚举类型CBLAS_TRANSPOSE |
表示矩阵A转置情况的序列。 矩阵A为常规矩阵,转置矩阵或共轭矩阵。
|
输入 |
transB_array |
枚举类型CBLAS_TRANSPOSE |
表示矩阵B转置情况的序列。 矩阵B为常规矩阵,转置矩阵或共轭矩阵。
|
输入 |
offsetc_array |
枚举类型CBLAS_OFFSET |
表示参数oc(GEMM定义中的C_offset)为行向量/列向量/标量的序列。
|
输入 |
m_array |
整型数 |
矩阵op(A)和矩阵C的行数的序列。 |
输入 |
n_array |
整型数 |
矩阵op(B)和矩阵C的列数的序列。 |
输入 |
k_array |
整型数 |
矩阵op(A)的列和矩阵op(B)的行数的序列。 |
输入 |
alpha_array |
单精度浮点类型。 |
乘法系数的序列。 |
输入 |
a_array |
|
矩阵A的序列。 |
输入 |
lda_array |
整型数 |
表示矩阵A的leading dimension的序列。
|
输入 |
oa_array |
整型数 |
GEMM定义中的A_offset的值的序列 |
输入 |
b_array |
|
矩阵B的序列。 |
输入 |
ldb_array |
整型数 |
表示矩阵A的leading dimension的序列。
|
输入 |
ob_array |
整型数 |
GEMM定义中的B_offset的值的序列 |
输入 |
beta_array |
|
乘法系数的序列。 |
输入 |
c_array |
整型数 |
矩阵C的序列。 |
输入/输出 |
ldc_array |
整型数 |
表示矩阵C的leading dimension的序列。 矩阵为列存,ldc至少max(1, m),否则max(1, n)。 |
输入 |
oc_array |
整型数 |
GEMM定义中的C_offset的值的序列 |
输入 |
group_count |
整型数 |
指定group的数量。必须至少为0。 |
输入 |
group_size |
整型数 |
大小为 group_count 的数组。元素 group_size[i] 指定了第 i 组中的矩阵数量。 |
输入 |
依赖
#include "kblas.h"
示例
int m = 2, k = 3, n = 2, lda = 2, ldb = 3, ldc = 2;
int group_count = 2;
int group_size[2] = {1, 1};
int m_array[2] = {2, 2};
int k_array[2] = {3, 3};
int n_array[2] = {2, 2};
int lda_array[2] = {2, 2};
int ldb_array[2] = {3, 3};
int ldc_array[2] = {2, 2};
CBLAS_TRANSPOSE transA_array[2] = {CblasNoTrans, CblasNoTrans};
CBLAS_TRANSPOSE transB_array[2] = {CblasNoTrans, CblasNoTrans};
float alpha_array[2] = {1.0, 1.0};
float beta_array[2] = {1.0, 1.0};
CBLAS_OFFSET offestc_array[2] = {CblasFixOffset, CblasFixOffset};
int32_t oc_array[2] = {0, 0};
int8_t oa_array[2] = {0, 0};
int8_t ob_array[2] = {0, 0};
int8_t a_data[12] = {1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1};
int8_t b_data[12] = {1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1};
int32_t c_data[8] = {1, 1, 1, 1, 1, 1, 1, 1};
int8_t* a_array[2];
int8_t* b_array[2];
int32_t* c_array[2];
a_array[0] = a_data;
a_array[1] = a_data + m*k*group_size[0];
b_array[0] = b_data;
b_array[1] = b_data + k*n*group_size[0];
c_array[0] = c_data;
c_array[1] = c_data + m*n*group_size[0];
cblas_gemm_batch_s8s8s32(CblasColMajor, transA_array, transB_array, offestc_array, m_array, n_array, k_array, alpha_array, a_array,
lda_array, oa_array, b_array, ldb_array, ob_array, beta_array, c_array, ldc_array, oc_array, group_count, group_size);






