开发者
我要评分
获取效率
正确性
完整性
易理解
在线提单
论坛求助

gemm_batch_?8?8s32

批量一般矩阵乘矩阵,此函数描述如下:

op(X)可取值:,alpha,beta为乘法系数,op(A)为m*k矩阵,op(B)为k*n矩阵,C为m*n矩阵。

A_offset为一个标量,类型固定为int8_t,op(A)中所有元素加上A_offset;

B_offset为一个标量,类型固定为int8_t,op(B)中所有元素加上B_offset;

C_offset有三种定义(元素类型均为int32_t):

  • 一个标量,C矩阵的每一个元素都加上C_offset
  • 一个行向量,C矩阵的每一行都加上这个行向量C_offset
  • 一个列向量,C矩阵的每一列都加上这个列向量C_offset

:batchgemm适用于矩阵规格小且batch规模大的场景,不适用矩阵大且batch小的场景

接口定义

C interface:

void cblas_gemm_batch_s8s8s32(const CBLAS_LAYOUT layout, const CBLAS_TRANSPOSE *transA_array,

const CBLAS_TRANSPOSE *transB_array, const CBLAS_OFFSET *offsetc_array, const BLASINT *m_array,

const BLASINT *n_array, const BLASINT *k_array, const float *alpha_array, const BLASINT8 **a_array,

const BLASINT *lda_array, const BLASINT8 *oa_array, const BLASINT8 **b_array, const BLASINT *ldb_array,

const BLASINT8 *ob_array, const float *beta, int32_t **c_array, const BLASINT *ldc_array, const int32_t **oc_array,

const BLASINT group_count, const BLASINT *group_size);

void cblas_gemm_batch_u8u8s32(const CBLAS_LAYOUT layout, const CBLAS_TRANSPOSE *transA_array,

const CBLAS_TRANSPOSE *transB_array, const CBLAS_OFFSET *offsetc_array, const BLASINT *m_array,

const BLASINT *n_array, const BLASINT *k_array, const float *alpha_array, const BLASUINT8 **a_array,

const BLASINT *lda_array, const BLASINT8 *oa_array, const BLASUINT8 **b_array, const BLASINT *ldb_array,

const BLASINT8 *ob_array, const float *beta, int32_t **c_array, const BLASINT *ldc_array, const int32_t **oc_array,

const BLASINT group_count, const BLASINT *group_size);

void cblas_gemm_batch_s8u8s32(const CBLAS_LAYOUT layout, const CBLAS_TRANSPOSE *transA_array,

const CBLAS_TRANSPOSE *transB_array, const CBLAS_OFFSET *offsetc_array, const BLASINT *m_array,

const BLASINT *n_array, const BLASINT *k_array, const float *alpha_array, const BLASINT8 **a_array,

const BLASINT *lda_array, const BLASINT8 *oa_array, const BLASUINT8 **b_array, const BLASINT *ldb_array,

const BLASINT8 *ob_array, const float *beta, int32_t **c_array, const BLASINT *ldc_array, const int32_t **oc_array,

const BLASINT group_count, const BLASINT *group_size);

void cblas_gemm_batch_u8s8s32(const CBLAS_LAYOUT layout, const CBLAS_TRANSPOSE *transA_array,

const CBLAS_TRANSPOSE *transB_array, const CBLAS_OFFSET *offsetc_array, const BLASINT *m_array,

const BLASINT *n_array, const BLASINT *k_array, const float *alpha_array, const BLASUINT8 **a_array,

const BLASINT *lda_array, const BLASINT8 *oa_array, const BLASINT8 **b_array, const BLASINT *ldb_array,

const BLASINT8 *ob_array, const float *beta, int32_t **c_array, const BLASINT *ldc_array, const int32_t **oc_array,

const BLASINT group_count, const BLASINT *group_size);

参数

参数名

类型

描述

输入/输出

layout

枚举类型CBLAS_ORDER

表示矩阵是行主序或列主序。

输入

transA_array

枚举类型CBLAS_TRANSPOSE

表示矩阵A转置情况的序列。

矩阵A为常规矩阵,转置矩阵或共轭矩阵。

  • 如果TransA = CblasNoTrans,
  • 如果TransA = CblasTrans,
  • 如果TransA = CblasConjTrans,
  • 如果TransA = CblasConjTrans,

输入

transB_array

枚举类型CBLAS_TRANSPOSE

表示矩阵B转置情况的序列。

矩阵B为常规矩阵,转置矩阵或共轭矩阵。

  • 如果TransB = CblasNoTrans,
  • 如果TransB = CblasTrans,
  • 如果TransB = CblasConjTrans,
  • 如果TransB = CblasConjTrans,

输入

offsetc_array

枚举类型CBLAS_OFFSET

表示参数oc(GEMM定义中的C_offset)为行向量/列向量/标量的序列。

  • 如果offsetc = CblasRowOffset,则代表C_offset是一个行向量。
  • 如果offsetc = CblasColOffset,则代表C_offset是一个列向量。
  • 如果offsetc = CblasFixOffset,则代表C_offset是一个标量。

输入

m_array

整型数

矩阵op(A)和矩阵C的行数的序列。

输入

n_array

整型数

矩阵op(B)和矩阵C的列数的序列。

输入

k_array

整型数

矩阵op(A)的列和矩阵op(B)的行数的序列。

输入

alpha_array

单精度浮点类型。

乘法系数的序列。

输入

a_array

  • 在*_s8*8s32中是int8_t类型。
  • 在*_u8*8s32中是uint8_t类型。

矩阵A的序列。

输入

lda_array

整型数

表示矩阵A的leading dimension的序列。

  • 矩阵为列存,TransA = CblasNoTrans,lda至少max(1, m),否则max(1, k)。
  • 矩阵为行存,TransA = CblasNoTrans,lda至少max(1, k),否则max(1, m)。

输入

oa_array

整型数

GEMM定义中的A_offset的值的序列

输入

b_array

  • 在*_*8s8s32中是int8_t类型。
  • 在*_*8u8s32中是uint8_t类型。

矩阵B的序列。

输入

ldb_array

整型数

表示矩阵A的leading dimension的序列。

  • 矩阵为列存,TransB = CblasNoTrans,ldb至少max(1, k),否则max(1, n)。
  • 矩阵为行存,TransB = CblasNoTrans,ldb至少max(1, n),否则max(1, k)。

输入

ob_array

整型数

GEMM定义中的B_offset的值的序列

输入

beta_array

  • 单精度浮点类型。

乘法系数的序列。

输入

c_array

整型数

矩阵C的序列。

输入/输出

ldc_array

整型数

表示矩阵C的leading dimension的序列。

矩阵为列存,ldc至少max(1, m),否则max(1, n)。

输入

oc_array

整型数

GEMM定义中的C_offset的值的序列

输入

group_count

整型数

指定group的数量。必须至少为0。

输入

group_size

整型数

大小为 group_count 的数组。元素 group_size[i] 指定了第 i 组中的矩阵数量。

输入

依赖

#include "kblas.h"

示例

int m = 2, k = 3, n = 2, lda = 2, ldb = 3, ldc = 2; 
    int group_count = 2;
    int group_size[2] = {1, 1};
    int m_array[2] = {2, 2};
    int k_array[2] = {3, 3};
    int n_array[2] = {2, 2};
    int lda_array[2] = {2, 2};
    int ldb_array[2] = {3, 3};
    int ldc_array[2] = {2, 2};
    CBLAS_TRANSPOSE transA_array[2] = {CblasNoTrans, CblasNoTrans};
    CBLAS_TRANSPOSE transB_array[2] = {CblasNoTrans, CblasNoTrans};
    float alpha_array[2] = {1.0, 1.0};
    float beta_array[2] = {1.0, 1.0}; 
    CBLAS_OFFSET offestc_array[2] = {CblasFixOffset, CblasFixOffset};
    int32_t oc_array[2] = {0, 0};
    int8_t oa_array[2] = {0, 0};
    int8_t ob_array[2] = {0, 0};
    int8_t a_data[12] = {1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1};
    int8_t b_data[12] = {1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1};
    int32_t c_data[8] = {1, 1, 1, 1, 1, 1, 1, 1};
    int8_t* a_array[2];
    int8_t* b_array[2];
    int32_t* c_array[2];
    a_array[0] = a_data;
    a_array[1] = a_data + m*k*group_size[0];
    b_array[0] = b_data;
    b_array[1] = b_data + k*n*group_size[0];
    c_array[0] = c_data;
    c_array[1] = c_data + m*n*group_size[0];
    cblas_gemm_batch_s8s8s32(CblasColMajor, transA_array, transB_array, offestc_array, m_array, n_array, k_array, alpha_array, a_array,
                      lda_array, oa_array, b_array, ldb_array, ob_array, beta_array, c_array, ldc_array, oc_array, group_count, group_size);