gemm_?8?8s32
一般矩阵乘矩阵。
即:。
op(X)可取值:,alpha,beta为乘法系数,op(A)为m*k矩阵,op(B)为k*n矩阵,C为m*n矩阵。
A_offset为一个标量,类型固定为int8_t,op(A)中所有元素加上A_offset;
B_offset为一个标量,类型固定为int8_t,op(B)中所有元素加上B_offset;
C_offset有三种定义(元素类型均为int32_t):
- 一个标量,C矩阵的每一个元素都加上C_offset
- 一个行向量,C矩阵的每一行都加上这个行向量C_offset
- 一个列向量,C矩阵的每一列都加上这个列向量C_offset
接口定义
C interface:
void cblas_gemm_s8s8s32(
const CBLAS_LAYOUT layout, const CBLAS_TRANSPOSE transa, const CBLAS_TRANSPOSE transb, const CBLAS_OFFSET offsetc,
const BLASINT m, const BLASINT n, const BLASINT k, const float alpha,
const BLASINT8 *a, const BLASINT lda, const BLASINT8 oa,
const BLASINT8 *b, const BLASINT ldb, const BLASINT8 ob,
const float beta, int32_t *c, const BLASINT ldc, const int32_t *oc);
void cblas_gemm_u8u8s32(
const CBLAS_LAYOUT layout, const CBLAS_TRANSPOSE transa, const CBLAS_TRANSPOSE transb, const CBLAS_OFFSET offsetc,
const BLASINT m, const BLASINT n, const BLASINT k, const float alpha,
const BLASUINT8 *a, const BLASINT lda, const BLASINT8 oa,
const BLASUINT8 *b, const BLASINT ldb, const BLASINT8 ob,
const float beta, int32_t *c, const BLASINT ldc, const int32_t *oc);
void cblas_gemm_s8u8s32(
const CBLAS_LAYOUT layout, const CBLAS_TRANSPOSE transa, const CBLAS_TRANSPOSE transb, const CBLAS_OFFSET offsetc,
const BLASINT m, const BLASINT n, const BLASINT k, const float alpha,
const BLASINT8 *a, const BLASINT lda, const BLASINT8 oa,
const BLASUINT8 *b, const BLASINT ldb, const BLASINT8 ob,
const float beta, int32_t *c, const BLASINT ldc, const int32_t *oc);
void cblas_gemm_u8s8s32(
const CBLAS_LAYOUT layout, const CBLAS_TRANSPOSE transa, const CBLAS_TRANSPOSE transb, const CBLAS_OFFSET offsetc,
const BLASINT m, const BLASINT n, const BLASINT k, const float alpha,
const BLASUINT8 *a, const BLASINT lda, const BLASINT8 oa,
const BLASINT8 *b, const BLASINT ldb, const BLASINT8 ob,
const float beta, int32_t *c, const BLASINT ldc, const int32_t *oc);
Fortran interface:
int8 gemm无Fortran接口。
参数
参数名 |
类型 |
描述 |
输入/输出 |
---|---|---|---|
layout |
枚举类型CBLAS_LAYOUT |
表示矩阵是行主序或列主序。
|
输入 |
transa |
枚举类型CBLAS_TRANSPOSE |
矩阵A为常规矩阵,转置矩阵。
|
输入 |
transb |
枚举类型CBLAS_TRANSPOSE |
矩阵B为常规矩阵,转置矩阵。
|
输入 |
offsetc |
枚举类型CBLAS_OFFSET |
表示参数oc(GEMM定义中的C_offset)为行向量/列向量/标量。
|
输入 |
M |
整型数 |
矩阵op(A)和矩阵C的行。 |
输入 |
N |
整型数 |
矩阵op(B)和矩阵C的列。 |
输入 |
K |
整型数 |
矩阵op(A)的列和矩阵op(B)的行。 |
输入 |
alpha |
单精度浮点类型。 |
乘法系数。 |
输入 |
A |
|
矩阵A。 |
输入 |
lda |
整型数 |
|
输入 |
oa |
int8_t类型 |
GEMM定义中的A_offset的值 |
输入 |
B |
|
矩阵B。 |
输入 |
ldb |
整型数 |
|
输入 |
ob |
int8_t类型 |
GEMM定义中的B_offset的值 |
输入 |
beta |
单精度浮点类型 |
乘法系数。 |
输入 |
C |
int32_t类型 |
矩阵C。 |
输入/输出 |
ldc |
整型数 |
矩阵为列存,ldc至少max(1, m),否则max(1, n)。 |
输入 |
oc |
int32_t类型 |
GEMM定义中的C_offset的值
|
输入 |
依赖
#include "kblas.h"
示例
int m = 4, k = 3, n = 4, lda = 4, ldb = 3, ldc = 4; float alpha = 1.0, beta = 2.0; int8_t oa=3, ob=5; CBLAS_OFFSET offsetc = CblasRowOffset; int32_t oc[] = {1, 2, 3, 4}; /* * A: * 35, 39, 27, * 38, 55, 40, * 46, 35, 41, * 54, 55, 24, * B: * 35, 54, 35, 40, * 38, 39, 55, 41, * 46, 55, 27, 24, * C: * 7, 7, 7, 7, * 7, 7, 7, 7, * 7, 7, 7, 7, * 7, 7, 7, 7, */ int8_t a[12] = {35, 38, 46, 54, 39, 55, 35, 55, 27, 40, 41, 24}; int8_t b[12] = {35, 38, 46, 54, 39, 55, 35, 55, 27, 40, 41, 24}; int32_t c[16] = {7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7}; cblas_gemm_s8s8s32(CblasColMajor, CblasNoTrans, CblasNoTrans, offsetc, m, n, k, alpha, a, lda, oa, b, ldb, ob, beta, c, ldc, oc); /* * Output C: * 4871 5906 5017 4530 * 6342 7567 6513 5778 * 5853 7219 5665 5247 * 6166 7551 6641 6034 * */