gemm_?8?8s32
一般矩阵乘矩阵。
即:
。
op(X)可取值:
,alpha,beta为乘法系数,op(A)为m*k矩阵,op(B)为k*n矩阵,C为m*n矩阵。
A_offset为一个标量,类型固定为int8_t,op(A)中所有元素加上A_offset;
B_offset为一个标量,类型固定为int8_t,op(B)中所有元素加上B_offset;
C_offset有三种定义(元素类型均为int32_t):
- 一个标量,C矩阵的每一个元素都加上C_offset
- 一个行向量,C矩阵的每一行都加上这个行向量C_offset
- 一个列向量,C矩阵的每一列都加上这个列向量C_offset
接口定义
C interface:
void cblas_gemm_s8s8s32(
const CBLAS_LAYOUT layout, const CBLAS_TRANSPOSE transa, const CBLAS_TRANSPOSE transb, const CBLAS_OFFSET offsetc,
const BLASINT m, const BLASINT n, const BLASINT k, const float alpha,
const BLASINT8 *a, const BLASINT lda, const BLASINT8 oa,
const BLASINT8 *b, const BLASINT ldb, const BLASINT8 ob,
const float beta, int32_t *c, const BLASINT ldc, const int32_t *oc);
void cblas_gemm_u8u8s32(
const CBLAS_LAYOUT layout, const CBLAS_TRANSPOSE transa, const CBLAS_TRANSPOSE transb, const CBLAS_OFFSET offsetc,
const BLASINT m, const BLASINT n, const BLASINT k, const float alpha,
const BLASUINT8 *a, const BLASINT lda, const BLASINT8 oa,
const BLASUINT8 *b, const BLASINT ldb, const BLASINT8 ob,
const float beta, int32_t *c, const BLASINT ldc, const int32_t *oc);
void cblas_gemm_s8u8s32(
const CBLAS_LAYOUT layout, const CBLAS_TRANSPOSE transa, const CBLAS_TRANSPOSE transb, const CBLAS_OFFSET offsetc,
const BLASINT m, const BLASINT n, const BLASINT k, const float alpha,
const BLASINT8 *a, const BLASINT lda, const BLASINT8 oa,
const BLASUINT8 *b, const BLASINT ldb, const BLASINT8 ob,
const float beta, int32_t *c, const BLASINT ldc, const int32_t *oc);
void cblas_gemm_u8s8s32(
const CBLAS_LAYOUT layout, const CBLAS_TRANSPOSE transa, const CBLAS_TRANSPOSE transb, const CBLAS_OFFSET offsetc,
const BLASINT m, const BLASINT n, const BLASINT k, const float alpha,
const BLASUINT8 *a, const BLASINT lda, const BLASINT8 oa,
const BLASINT8 *b, const BLASINT ldb, const BLASINT8 ob,
const float beta, int32_t *c, const BLASINT ldc, const int32_t *oc);
Fortran interface:
int8 gemm无Fortran接口。
参数
参数名 |
类型 |
描述 |
输入/输出 |
|---|---|---|---|
layout |
枚举类型CBLAS_LAYOUT |
表示矩阵是行主序或列主序。
|
输入 |
transa |
枚举类型CBLAS_TRANSPOSE |
矩阵A为常规矩阵,转置矩阵。
|
输入 |
transb |
枚举类型CBLAS_TRANSPOSE |
矩阵B为常规矩阵,转置矩阵。
|
输入 |
offsetc |
枚举类型CBLAS_OFFSET |
表示参数oc(GEMM定义中的C_offset)为行向量/列向量/标量。
|
输入 |
M |
整型数 |
矩阵op(A)和矩阵C的行。 |
输入 |
N |
整型数 |
矩阵op(B)和矩阵C的列。 |
输入 |
K |
整型数 |
矩阵op(A)的列和矩阵op(B)的行。 |
输入 |
alpha |
单精度浮点类型。 |
乘法系数。 |
输入 |
A |
|
矩阵A。 |
输入 |
lda |
整型数 |
|
输入 |
oa |
int8_t类型 |
GEMM定义中的A_offset的值 |
输入 |
B |
|
矩阵B。 |
输入 |
ldb |
整型数 |
|
输入 |
ob |
int8_t类型 |
GEMM定义中的B_offset的值 |
输入 |
beta |
单精度浮点类型 |
乘法系数。 |
输入 |
C |
int32_t类型 |
矩阵C。 |
输入/输出 |
ldc |
整型数 |
矩阵为列存,ldc至少max(1, m),否则max(1, n)。 |
输入 |
oc |
int32_t类型 |
GEMM定义中的C_offset的值
|
输入 |
依赖
#include "kblas.h"
示例
int m = 4, k = 3, n = 4, lda = 4, ldb = 3, ldc = 4;
float alpha = 1.0, beta = 2.0;
int8_t oa=3, ob=5;
CBLAS_OFFSET offsetc = CblasRowOffset;
int32_t oc[] = {1, 2, 3, 4};
/*
* A:
* 35, 39, 27,
* 38, 55, 40,
* 46, 35, 41,
* 54, 55, 24,
* B:
* 35, 54, 35, 40,
* 38, 39, 55, 41,
* 46, 55, 27, 24,
* C:
* 7, 7, 7, 7,
* 7, 7, 7, 7,
* 7, 7, 7, 7,
* 7, 7, 7, 7,
*/
int8_t a[12] = {35, 38, 46, 54,
39, 55, 35, 55,
27, 40, 41, 24};
int8_t b[12] = {35, 38, 46,
54, 39, 55,
35, 55, 27,
40, 41, 24};
int32_t c[16] = {7, 7, 7, 7,
7, 7, 7, 7,
7, 7, 7, 7,
7, 7, 7, 7};
cblas_gemm_s8s8s32(CblasColMajor, CblasNoTrans, CblasNoTrans, offsetc,
m, n, k, alpha,
a, lda, oa,
b, ldb, ob,
beta, c, ldc, oc);
/*
* Output C:
* 4871 5906 5017 4530
* 6342 7567 6513 5778
* 5853 7219 5665 5247
* 6166 7551 6641 6034
*
*/



