开发者
资源
我要评分
获取效率
正确性
完整性
易理解
在线提单
论坛求助

gemm_?8?8s32_compute

对pack后的A\B矩阵进行一般矩阵乘。

即:

op(X)可取值:,alpha,beta为乘法系数,op(A)为m*k矩阵,op(B)为k*n矩阵,C为m*n矩阵。

接口定义

C interface:

void cblas_gemm_s8s8s32_compute(const CBLAS_LAYOUT layout, const CBLAS_TRANSPOSE transa, const CBLAS_TRANSPOSE transb,

const CBLAS_OFFSET offsetc, const BLASINT m, const BLASINT n, const BLASINT k, const float alpha, const BLASINT8 *a, const BLASINT lda,

const BLASINT8 oa,const BLASINT8 *b, const BLASINT ldb, const BLASINT8 ob, const float beta, int32_t *c,const BLASINT ldc, const int32_t *oc);

void cblas_gemm_s8u8s32_compute(const CBLAS_LAYOUT layout, const CBLAS_TRANSPOSE transa, const CBLAS_TRANSPOSE transb,

const CBLAS_OFFSET offsetc, const BLASINT m, const BLASINT n, const BLASINT k, const float alpha, const BLASINT8 *a, const BLASINT lda,

const BLASINT8 oa, const BLASUINT8 *b, const BLASINT ldb, const BLASINT8 ob, const float beta, int32_t *c, const BLASINT ldc, const int32_t *oc);

void cblas_gemm_u8u8s32_compute(const CBLAS_LAYOUT layout, const CBLAS_TRANSPOSE transa, const CBLAS_TRANSPOSE transb,

const CBLAS_OFFSET offsetc, const BLASINT m, const BLASINT n, const BLASINT k,const float alpha, const BLASUINT8 *a, const BLASINT lda,

const BLASINT8 oa,const BLASUINT8 *b, const BLASINT ldb, const BLASINT8 ob, const float beta, int32_t *c,const BLASINT ldc, const int32_t *oc);

void cblas_gemm_u8s8s32_compute(const CBLAS_LAYOUT layout, const CBLAS_TRANSPOSE transa, const CBLAS_TRANSPOSE transb,

const CBLAS_OFFSET offsetc, const BLASINT m, const BLASINT n, const BLASINT k,const float alpha, const BLASUINT8 *a, const BLASINT lda,

const BLASINT8 oa,const BLASINT8 *b, const BLASINT ldb, const BLASINT8 ob, const float beta, int32_t *c,const BLASINT ldc, const int32_t *oc);

参数

参数名

类型

描述

输入/输出

order

枚举类型CBLAS_ORDER

表示矩阵是行主序或列主序。

输入

TransA

枚举类型CBLAS_TRANSPOSE

矩阵A为常规矩阵,转置矩阵或共轭矩阵。

  • 如果trans= CblasNoTrans,
  • 如果trans= CblasTrans,

输入

TransB

枚举类型CBLAS_TRANSPOSE

矩阵B为常规矩阵,转置矩阵或共轭矩阵。

  • 如果trans= CblasNoTrans,
  • 如果trans= CblasTrans,

输入

M

整型数

矩阵op(A)和矩阵C的行。

输入

N

整型数

矩阵op(B)和矩阵C的列。

输入

K

整型数

矩阵op(A)的列和矩阵op(B)的行。

输入

alpha

  • 单精度浮点类型。

乘法系数。

输入

A

  • 在gemm_s8?8s32是int8类型。
  • 在gemm_u8?8s32是uint8类型。

矩阵A。

输入

lda

整型数

  • 矩阵为列存,TransA = CblasNoTrans,lda至少max(1, m),否则max(1, k)。
  • 矩阵为行存,TransA = CblasNoTrans,lda至少max(1, k),否则max(1, m)。

输入

B

  • 在gemm_?8s8s32是int8类型。
  • 在gemm_?8u8s32是uint8类型。

矩阵B。

输入

ldb

整型数

  • 矩阵为列存,TransB = CblasNoTrans,ldb至少max(1, k),否则max(1, n)。
  • 矩阵为行存,TransB = CblasNoTrans,ldb至少max(1, n),否则max(1, k)。

输入

beta

  • 单精度浮点类型。

乘法系数。

输入

C

  • 是int类型。

矩阵C。

输入/输出

ldc

整型数

矩阵为列存,ldc至少max(1, m),否则max(1, n)。

输入

依赖

#include "kblas.h"

示例

void test_igemm_compute(const enum CBLAS_ORDER order, const enum CBLAS_TRANSPOSE transA, const enum CBLAS_TRANSPOSE transB, BLASINT row, BLASINT col)
{
    BLASUINT8 *a = (BLASUINT8 *)malloc(sizeof(BLASUINT8) * row * col);
    BLASINT8 *b = (BLASINT8 *)malloc(sizeof(BLASINT8) * col * row);
    int32_t *c = (int32_t *)malloc(sizeof(int32_t) * row * row);
    for (int i = 0; i < row * col; i++) {
        a[i] = i;
        b[i] = i;
    }
    for (int i = 0; i < row * row; i++) {
        c[i] = i;
    }
    BLASINT m = row;
    BLASINT n = row;
    BLASINT k = col;
    BLASINT lda, ldb, ldc = m;
    size_t size_a = cblas_gemm_s8u8s32_pack_get_size(CblasA, m, n, k);
    size_t size_b = cblas_gemm_s8u8s32_pack_get_size(CblasB, m, n, k);
    BLASUINT8 *sa = (BLASUINT8 *)malloc(size_a);
    BLASINT8 *sb = (BLASINT8 *)malloc(size_b);
    if (order == CblasColMajor) {
        if (transA == CblasNoTrans) {
            lda = m;
        } else {
            lda = k;
        }
        if (transB == CblasNoTrans) {
            ldb = k;
        } else {
            ldb = n;
        }
    } else { // CblasRowMajor
        if (transA == CblasNoTrans) {
            lda = k;
        } else {
            lda = m;
        }
        if (transB == CblasNoTrans) {
            ldb = n;
        } else {
            ldb = k;
        }
    }
    float alpha = 2.0;
    float beta = 3.0;
    cblas_gemm_s8u8s32_pack(order, CblasA, transA, m, n, k, a, lda, sa);
    cblas_gemm_s8u8s32_pack(order, CblasB, transB, m, n, k, b, ldb, sb);
    int oc = 3;
    cblas_gemm_u8s8s32_compute(order, transA, transB, CblasFixOffset, m, n, k, alpha, sa, lda, 0, sb, ldb, 0, beta, c, ldc, &oc);
 
    free(a);
    free(b);
    free(c);
    free(sa);
    free(sb);
}