gemm_?8?8s32_compute
对pack后的A、B矩阵进行一般矩阵乘。
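即:
$C := \alpha\,(\mathrm{op}(A) + A_{offset})\,(\mathrm{op}(B) + B_{offset}) + \beta\,C + C_{offset}$,其中$A_{offset}$、$B_{offset}$由oa、ob给出,$C_{offset}$由oc按offsetc指定的方式给出。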
op(X)可取值:$\mathrm{op}(X) = X$、$\mathrm{op}(X) = X^T$ 或 $\mathrm{op}(X) = X^H$。
其中alpha、beta为乘法系数,op(A)为m*k矩阵,op(B)为k*n矩阵,C为m*n矩阵。
接口定义
C interface:
/*
 * Compute routines of the 8-bit integer GEMM family with an int32_t result C.
 * The two letters after "gemm_" name the element types of A and B as declared
 * below: s8 = BLASINT8, u8 = BLASUINT8. alpha and beta are float; oa and ob
 * are BLASINT8 offsets for A and B; oc points to int32_t offset value(s) for C
 * whose use is selected by offsetc (presumably fixed/row/column offset as in
 * CBLAS_OFFSET -- confirm against the library header).
 */
/* A: BLASINT8, B: BLASINT8. */
void cblas_gemm_s8s8s32_compute(const CBLAS_LAYOUT layout, const CBLAS_TRANSPOSE transa, const CBLAS_TRANSPOSE transb,
const CBLAS_OFFSET offsetc, const BLASINT m, const BLASINT n, const BLASINT k, const float alpha, const BLASINT8 *a, const BLASINT lda,
const BLASINT8 oa,const BLASINT8 *b, const BLASINT ldb, const BLASINT8 ob, const float beta, int32_t *c,const BLASINT ldc, const int32_t *oc);
/* A: BLASINT8, B: BLASUINT8. */
void cblas_gemm_s8u8s32_compute(const CBLAS_LAYOUT layout, const CBLAS_TRANSPOSE transa, const CBLAS_TRANSPOSE transb,
const CBLAS_OFFSET offsetc, const BLASINT m, const BLASINT n, const BLASINT k, const float alpha, const BLASINT8 *a, const BLASINT lda,
const BLASINT8 oa, const BLASUINT8 *b, const BLASINT ldb, const BLASINT8 ob, const float beta, int32_t *c, const BLASINT ldc, const int32_t *oc);
/* A: BLASUINT8, B: BLASUINT8. */
void cblas_gemm_u8u8s32_compute(const CBLAS_LAYOUT layout, const CBLAS_TRANSPOSE transa, const CBLAS_TRANSPOSE transb,
const CBLAS_OFFSET offsetc, const BLASINT m, const BLASINT n, const BLASINT k,const float alpha, const BLASUINT8 *a, const BLASINT lda,
const BLASINT8 oa,const BLASUINT8 *b, const BLASINT ldb, const BLASINT8 ob, const float beta, int32_t *c,const BLASINT ldc, const int32_t *oc);
/* A: BLASUINT8, B: BLASINT8. */
void cblas_gemm_u8s8s32_compute(const CBLAS_LAYOUT layout, const CBLAS_TRANSPOSE transa, const CBLAS_TRANSPOSE transb,
const CBLAS_OFFSET offsetc, const BLASINT m, const BLASINT n, const BLASINT k,const float alpha, const BLASUINT8 *a, const BLASINT lda,
const BLASINT8 oa,const BLASINT8 *b, const BLASINT ldb, const BLASINT8 ob, const float beta, int32_t *c,const BLASINT ldc, const int32_t *oc);
参数
参数名 |
类型 |
描述 |
输入/输出 |
|---|---|---|---|
layout |
枚举类型CBLAS_LAYOUT |
表示矩阵是行主序或列主序。 |
输入 |
TransA |
枚举类型CBLAS_TRANSPOSE |
矩阵A为常规矩阵,转置矩阵或共轭矩阵。
|
输入 |
TransB |
枚举类型CBLAS_TRANSPOSE |
矩阵B为常规矩阵,转置矩阵或共轭矩阵。
|
输入 |
M |
整型数 |
矩阵op(A)和矩阵C的行数。 |
输入 |
N |
整型数 |
矩阵op(B)和矩阵C的列数。 |
输入 |
K |
整型数 |
矩阵op(A)的列数和矩阵op(B)的行数。 |
输入 |
alpha |
单精度浮点型 |
乘法系数。 |
输入 |
A |
BLASINT8或BLASUINT8类型指针 |
pack后的矩阵A。 |
输入 |
lda |
整型数 |
矩阵A的主维。列存时,A不转置则lda至少max(1, m),否则至少max(1, k);行存时,A不转置则lda至少max(1, k),否则至少max(1, m)。 |
输入 |
B |
BLASINT8或BLASUINT8类型指针 |
pack后的矩阵B。 |
输入 |
ldb |
整型数 |
矩阵B的主维。列存时,B不转置则ldb至少max(1, k),否则至少max(1, n);行存时,B不转置则ldb至少max(1, n),否则至少max(1, k)。 |
输入 |
beta |
单精度浮点型 |
乘法系数。 |
输入 |
C |
int32_t类型指针 |
矩阵C。 |
输入/输出 |
ldc |
整型数 |
矩阵C的主维。矩阵为列存时,ldc至少max(1, m);为行存时,ldc至少max(1, n)。 |
输入 |
依赖
#include "kblas.h"
示例
/*
 * Example: multiply packed matrices with cblas_gemm_u8s8s32_compute.
 *
 * Uses m = n = row and k = col. A (m x k) holds BLASUINT8 data, B (k x n)
 * holds BLASINT8 data and C (m x n) holds int32_t data. Both A and B are
 * packed. The compute routine is then called with alpha = 2, beta = 3,
 * A/B offsets of 0 and a C offset of 3 (CblasFixOffset).
 * Leading dimensions are derived from the layout and transpose flags.
 * If any allocation fails, the computation is skipped.
 */
void test_igemm_compute(const enum CBLAS_ORDER order, const enum CBLAS_TRANSPOSE transA, const enum CBLAS_TRANSPOSE transB, BLASINT row, BLASINT col)
{
    const BLASINT m = row;
    const BLASINT n = row;
    const BLASINT k = col;

    const int col_major = (order == CblasColMajor);
    const int a_trans = (transA != CblasNoTrans);
    const int b_trans = (transB != CblasNoTrans);

    /* Leading dimension of each stored array, as laid out in memory. */
    const BLASINT lda = col_major ? (a_trans ? k : m) : (a_trans ? m : k);
    const BLASINT ldb = col_major ? (b_trans ? n : k) : (b_trans ? k : n);
    /* C is m x n: column-major needs ldc >= m, row-major needs ldc >= n. */
    const BLASINT ldc = col_major ? m : n;

    /* Element counts computed in size_t so the products do not overflow BLASINT. */
    const size_t len_a = (size_t)m * (size_t)k;
    const size_t len_b = (size_t)k * (size_t)n;
    const size_t len_c = (size_t)m * (size_t)n;

    /*
     * NOTE(review): A holds unsigned and B signed 8-bit data, matching
     * cblas_gemm_u8s8s32_compute below, yet packing uses the s8u8s32
     * routines; confirm whether the u8s8s32 pack routines should be used.
     */
    const size_t size_a = cblas_gemm_s8u8s32_pack_get_size(CblasA, m, n, k);
    const size_t size_b = cblas_gemm_s8u8s32_pack_get_size(CblasB, m, n, k);

    BLASUINT8 *a = malloc(len_a * sizeof *a);
    BLASINT8 *b = malloc(len_b * sizeof *b);
    int32_t *c = malloc(len_c * sizeof *c);
    BLASUINT8 *sa = malloc(size_a);
    BLASINT8 *sb = malloc(size_b);

    if (a && b && c && sa && sb) {
        for (size_t i = 0; i < len_a; i++) {
            /* Wraps modulo 256, assuming BLASUINT8 is an unsigned 8-bit type. */
            a[i] = (BLASUINT8)i;
        }
        for (size_t i = 0; i < len_b; i++) {
            /* Kept in [0, 127] so the value is representable in a signed 8-bit type. */
            b[i] = (BLASINT8)(i % 128);
        }
        for (size_t i = 0; i < len_c; i++) {
            c[i] = (int32_t)i;
        }

        const float alpha = 2.0f;
        const float beta = 3.0f;
        const int32_t oc = 3; /* matches the const int32_t *oc parameter */

        cblas_gemm_s8u8s32_pack(order, CblasA, transA, m, n, k, a, lda, sa);
        cblas_gemm_s8u8s32_pack(order, CblasB, transB, m, n, k, b, ldb, sb);
        cblas_gemm_u8s8s32_compute(order, transA, transB, CblasFixOffset, m, n, k, alpha, sa, lda, 0, sb, ldb, 0, beta, c, ldc, &oc);
    }

    /* free(NULL) is a no-op, so this is safe after a partial allocation failure. */
    free(a);
    free(b);
    free(c);
    free(sa);
    free(sb);
}


