鲲鹏社区首页
中文
注册
我要评分
文档获取效率
文档正确性
内容完整性
文档易理解
在线提单
论坛求助

gemm_?8?8s32

一般矩阵乘矩阵。

即:

op(X)可取值:,alpha,beta为乘法系数,op(A)为m*k矩阵,op(B)为k*n矩阵,C为m*n矩阵。

A_offset为一个标量,类型固定为int8_t,op(A)中所有元素加上A_offset;

B_offset为一个标量,类型固定为int8_t,op(B)中所有元素加上B_offset;

C_offset有三种定义(元素类型均为int32_t):

  • 一个标量,C矩阵的每一个元素都加上C_offset
  • 一个行向量,C矩阵的每一行都加上这个行向量C_offset
  • 一个列向量,C矩阵的每一列都加上这个列向量C_offset

接口定义

C interface:

void cblas_gemm_s8s8s32(

const CBLAS_LAYOUT layout, const CBLAS_TRANSPOSE transa, const CBLAS_TRANSPOSE transb, const CBLAS_OFFSET offsetc,

const BLASINT m, const BLASINT n, const BLASINT k, const float alpha,

const BLASINT8 *a, const BLASINT lda, const BLASINT8 oa,

const BLASINT8 *b, const BLASINT ldb, const BLASINT8 ob,

const float beta, int32_t *c, const BLASINT ldc, const int32_t *oc);

void cblas_gemm_u8u8s32(

const CBLAS_LAYOUT layout, const CBLAS_TRANSPOSE transa, const CBLAS_TRANSPOSE transb, const CBLAS_OFFSET offsetc,

const BLASINT m, const BLASINT n, const BLASINT k, const float alpha,

const BLASUINT8 *a, const BLASINT lda, const BLASINT8 oa,

const BLASUINT8 *b, const BLASINT ldb, const BLASINT8 ob,

const float beta, int32_t *c, const BLASINT ldc, const int32_t *oc);

void cblas_gemm_s8u8s32(

const CBLAS_LAYOUT layout, const CBLAS_TRANSPOSE transa, const CBLAS_TRANSPOSE transb, const CBLAS_OFFSET offsetc,

const BLASINT m, const BLASINT n, const BLASINT k, const float alpha,

const BLASINT8 *a, const BLASINT lda, const BLASINT8 oa,

const BLASUINT8 *b, const BLASINT ldb, const BLASINT8 ob,

const float beta, int32_t *c, const BLASINT ldc, const int32_t *oc);

void cblas_gemm_u8s8s32(

const CBLAS_LAYOUT layout, const CBLAS_TRANSPOSE transa, const CBLAS_TRANSPOSE transb, const CBLAS_OFFSET offsetc,

const BLASINT m, const BLASINT n, const BLASINT k, const float alpha,

const BLASUINT8 *a, const BLASINT lda, const BLASINT8 oa,

const BLASINT8 *b, const BLASINT ldb, const BLASINT8 ob,

const float beta, int32_t *c, const BLASINT ldc, const int32_t *oc);

Fortran interface:

int8 gemm无Fortran接口。

参数

参数名

类型

描述

输入/输出

layout

枚举类型CBLAS_LAYOUT

表示矩阵是行主序或列主序。

  • 如果layout = CblasRowMajor,则代表行主序。
  • 如果layout = CblasColMajor,则代表列主序。

输入

transa

枚举类型CBLAS_TRANSPOSE

矩阵A为常规矩阵,转置矩阵。

  • 如果TransA = CblasNoTrans,
  • 如果TransA = CblasTrans,

输入

transb

枚举类型CBLAS_TRANSPOSE

矩阵B为常规矩阵,转置矩阵。

  • 如果TransB = CblasNoTrans,
  • 如果TransB = CblasTrans,

输入

offsetc

枚举类型CBLAS_OFFSET

表示参数oc(GEMM定义中的C_offset)为行向量/列向量/标量。

  • 如果offsetc = CblasRowOffset,则代表C_offset是一个行向量。
  • 如果offsetc = CblasColOffset,则代表C_offset是一个列向量。
  • 如果offsetc = CblasFixOffset,则代表C_offset是一个标量。

输入

M

整型数

矩阵op(A)和矩阵C的行。

输入

N

整型数

矩阵op(B)和矩阵C的列。

输入

K

整型数

矩阵op(A)的列和矩阵op(B)的行。

输入

alpha

单精度浮点类型。

乘法系数。

输入

A

  • 在gemm_s8s8s32中是int8_t类型。
  • 在gemm_u8u8s32中是uint8_t类型。
  • 在gemm_s8u8s32中是int8_t类型。
  • 在gemm_u8s8s32中是uint8_t类型。

矩阵A。

输入

lda

整型数

  • 矩阵为列存,TransA = CblasNoTrans,lda至少max(1, m),否则max(1, k)。
  • 矩阵为行存,TransA = CblasNoTrans,lda至少max(1, k),否则max(1, m)。

输入

oa

int8_t类型

GEMM定义中的A_offset的值

输入

B

  • 在gemm_s8s8s32中是int8_t类型。
  • 在gemm_u8u8s32中是uint8_t类型。
  • 在gemm_s8u8s32中是uint8_t类型。
  • 在gemm_u8s8s32中是int8_t类型。

矩阵B。

输入

ldb

整型数

  • 矩阵为列存,TransB = CblasNoTrans,ldb至少max(1, k),否则max(1, n)。
  • 矩阵为行存,TransB = CblasNoTrans,ldb至少max(1, n),否则max(1, k)。

输入

ob

int8_t类型

GEMM定义中的B_offset的值

输入

beta

单精度浮点类型

乘法系数。

输入

C

int32_t类型

矩阵C。

输入/输出

ldc

整型数

矩阵为列存,ldc至少max(1, m),否则max(1, n)。

输入

oc

int32_t类型

GEMM定义中的C_offset的值

  • 当offsetc的值为CblasRowOffset时,此输入代表行向量。
  • 当offsetc的值为CblasColOffset时,此输入代表列向量。
  • 当offsetc的值为CblasFixOffset时,此输入代表标量。

输入

依赖

#include "kblas.h"

示例

C interface:
    int m = 4, k = 3, n = 4, lda = 4, ldb = 3, ldc = 4; 
    float alpha = 1.0, beta = 2.0;
 
    int8_t oa=3, ob=5;
    CBLAS_OFFSET offsetc = CblasRowOffset;
    int32_t oc[] = {1, 2, 3, 4};

     /* 
     * A: 
     *     35,      39,     27, 
     *     38,      55,     40, 
     *     46,      35,     41, 
     *     54,      55,     24, 
     * B: 
     *     35,       54,       35,       40, 
     *     38,       39,       55,       41, 
     *     46,       55,       27,       24, 
     * C: 
     *     7,      7,      7,      7, 
     *     7,      7,      7,      7, 
     *     7,      7,      7,      7, 
     *     7,      7,      7,      7, 
     */ 
    int8_t a[12] = {35, 38, 46, 54,
                    39, 55, 35, 55,
                    27, 40, 41, 24};
    int8_t b[12] = {35, 38, 46,
                    54, 39, 55,
                    35, 55, 27,
                    40, 41, 24};
    int32_t c[16] = {7, 7, 7, 7, 
                    7, 7, 7, 7, 
                    7, 7, 7, 7, 
                    7, 7, 7, 7}; 

    cblas_gemm_s8s8s32(CblasColMajor, CblasNoTrans, CblasNoTrans, offsetc,
        m, n, k, alpha, 
        a, lda, oa,
        b, ldb, ob,
        beta, c, ldc, oc);
    /* 
     * Output C: 
     *     4871 5906 5017 4530
     *     6342 7567 6513 5778
     *     5853 7219 5665 5247
     *     6166 7551 6641 6034
     * 
     */