?gemm_compute
对pack后的矩阵A、B进行一般矩阵乘法。
即:
$C := \alpha \cdot op(A) \cdot op(B) + \beta \cdot C$。
其中op(X)可取值:$op(X) = X$、$op(X) = X^T$ 或 $op(X) = X^H$。
alpha、beta为乘法系数,op(A)为m*k矩阵,op(B)为k*n矩阵,C为m*n矩阵。
接口定义
C interface:
/*
 * General matrix multiply on pre-packed operands:
 *   C := alpha * op(A) * op(B) + beta * C
 * where a and b are the buffers produced by packing A and B beforehand
 * (cblas_?gemm_pack), op(A) is m x k, op(B) is k x n and C is m x n.
 */
/* s variant: float scalars and float matrices. */
void cblas_sgemm_compute(const enum CBLAS_ORDER order, const enum CBLAS_TRANSPOSE transA, const enum CBLAS_TRANSPOSE transB, const BLASINT m, const BLASINT n, const BLASINT k, const float alpha, const float *a, const BLASINT lda, const float *b, const BLASINT ldb, const float beta, float *c, const BLASINT ldc);
/* d variant: double scalars and double matrices. */
void cblas_dgemm_compute(const enum CBLAS_ORDER order, const enum CBLAS_TRANSPOSE transA, const enum CBLAS_TRANSPOSE transB, const BLASINT m, const BLASINT n, const BLASINT k, const double alpha, const double *a, const BLASINT lda, const double *b, const BLASINT ldb, const double beta, double *c, const BLASINT ldc);
/* c variant: alpha, beta and the matrices are passed through void pointers
 * (presumably single-precision complex, per BLAS naming). */
void cblas_cgemm_compute(const enum CBLAS_ORDER order, const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_TRANSPOSE TransB, const BLASINT m, const BLASINT n, const BLASINT k, const void *alpha, const void *a, const BLASINT lda, const void *b, const BLASINT ldb, const void *beta, void *c, const BLASINT ldc);
/* z variant: alpha, beta and the matrices are passed through void pointers
 * (presumably double-precision complex, per BLAS naming). */
void cblas_zgemm_compute(const enum CBLAS_ORDER order, const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_TRANSPOSE TransB, const BLASINT m, const BLASINT n, const BLASINT k, const void *alpha, const void *a, const BLASINT lda, const void *b, const BLASINT ldb, const void *beta, void *c, const BLASINT ldc);
/* b variant: __bf16 scalars and __bf16 matrices. */
void cblas_bgemm_compute(const enum CBLAS_ORDER order, const enum CBLAS_TRANSPOSE transA, const enum CBLAS_TRANSPOSE transB, const BLASINT m, const BLASINT n, const BLASINT k, const __bf16 alpha, const __bf16 *a, const BLASINT lda, const __bf16 *b, const BLASINT ldb, const __bf16 beta, __bf16 *c, const BLASINT ldc);
参数
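| 参数名 | 类型 | 描述 | 输入/输出 |
|---|---|---|---|
| order | 枚举类型CBLAS_ORDER | 表示矩阵是行主序或列主序。 | 输入 |
| TransA | 枚举类型CBLAS_TRANSPOSE | 矩阵A为常规矩阵、转置矩阵或共轭转置矩阵。 | 输入 |
| TransB | 枚举类型CBLAS_TRANSPOSE | 矩阵B为常规矩阵、转置矩阵或共轭转置矩阵。 | 输入 |
| M | 整型数 | 矩阵op(A)和矩阵C的行数。 | 输入 |
| N | 整型数 | 矩阵op(B)和矩阵C的列数。 | 输入 |
| K | 整型数 | 矩阵op(A)的列数和矩阵op(B)的行数。 | 输入 |
| alpha | s精度时为float;d精度时为double;c、z精度时为const void指针;b精度时为__bf16。 | 乘法系数。 | 输入 |
| A | s精度时为const float指针;d精度时为const double指针;c、z精度时为const void指针;b精度时为const __bf16指针。 | pack后的矩阵A。 | 输入 |
| lda | 整型数 | 矩阵A的主维(leading dimension)。 | 输入 |
| B | s精度时为const float指针;d精度时为const double指针;c、z精度时为const void指针;b精度时为const __bf16指针。 | pack后的矩阵B。 | 输入 |
| ldb | 整型数 | 矩阵B的主维(leading dimension)。 | 输入 |
| beta | s精度时为float;d精度时为double;c、z精度时为const void指针;b精度时为__bf16。 | 乘法系数。 | 输入 |
| C | s精度时为float指针;d精度时为double指针;c、z精度时为void指针;b精度时为__bf16指针。 | 矩阵C。 | 输入/输出 |
| ldc | 整型数 | 矩阵C的主维。矩阵为列存时,ldc至少为max(1, m),否则ldc至少为max(1, n)。 | 输入 |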
依赖
#include "kblas.h"
示例
C interface:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 | int m = 4, k = 3, n = 4, lda = 4, ldb = 3, ldc = 4; float alpha = 1.0, beta = 2.0; /* * A: * 0.340188, 0.411647, -0.222225, * -0.105617, -0.302449, 0.053970, * 0.283099, -0.164777, -0.022603, * 0.298440, 0.268230, 0.128871, * B: * -0.135216, 0.416195, -0.358397, -0.257113, * 0.013401, 0.135712, 0.106969, -0.362768, * 0.452230, 0.217297, -0.483699, 0.304177, * C: * -0.343321, 0.498924, 0.112640, -0.006417, * -0.099056, -0.281743, -0.203968, 0.472775, * -0.370210, 0.012932, 0.137552, -0.207483, * -0.391191, 0.339112, 0.024287, 0.271358, */ float a[12] = {0.340188, -0.105617, 0.283099, 0.298440, 0.411647, -0.302449, -0.164777, 0.268230, -0.222225, 0.053970, -0.022603, 0.128871}; float b[12] = {-0.135216, 0.013401, 0.452230, 0.416195, 0.135712, 0.217297, -0.358397, 0.106969, -0.483699, -0.257113, -0.362768, 0.304177}; float c[16] = {-0.343321, -0.099056, -0.370210, -0.391191, 0.498924, -0.281743, 0.012932, 0.339112, 0.112640, -0.203968, 0.137552, 0.024287, -0.006417, 0.472775, -0.207483, 0.271358}; size_t size_a = cblas_sgemm_pack_get_size(CblasA, m, n, k); size_t size_b = cblas_sgemm_pack_get_size(CblasB, m, n, k); float *sa = (float *)malloc(size_a); float *sb = (float *)malloc(size_b); cblas_sgemm_pack(CblasColMajor, CblasA, CblasNoTrans, m, n, k, a, ld, sa); cblas_sgemm_pack(CblasColMajor, CblasB, CblasNoTrans, m, n, k, b, ld, sb); cblas_sgemm_compute(CblasColMajor,CblasNoTrans,CblasNoTrans, m, n, k, alpha, sa, lda, sb, ldb, beta, c, ldc); free(sa); free(sb); /* * Output C: * -0.827621 1.147010 0.254881 -0.317229 * -0.163476 -0.636762 -0.428542 1.098841 * -0.791128 0.116416 0.166949 -0.434854 * -0.760862 0.866839 -0.092028 0.407877 * */ |



