kutacc_af2_triangle_multiplication_calc_proj
用于计算triangle_multiplication过程中生成的中间值:左右投影
接口定义
void kutacc_af2_triangle_multiplication_calc_proj(kutacc_af2_tm_act_inputs_t *tm_acts_ptr, kutacc_tensor_h mask, kutacc_af2_tm_proj_weights_t *tm_weights_ptr, bool input_prepack);
参数
表1 入参定义
参数名 |
类型 |
描述 |
输入/输出 |
|---|---|---|---|
tm_acts_ptr |
kutacc_af2_tm_act_inputs_t * |
kutacc_af2_tm_act_inputs_t类型的指针,具体数据结构定义见下方表2 kutacc_af2_tm_act_inputs_t数据结构表定义 |
输入 |
mask |
kutacc_tensor_h |
掩码张量 |
输入 |
tm_weights_ptr |
kutacc_af2_tm_proj_weights_t * |
kutacc_af2_tm_proj_weights_t类型的指针,具体数据结构定义见下方表3 kutacc_af2_tm_proj_weights_t数据结构表定义 |
输入 |
input_prepack |
bool |
是否对输入张量进行prepack操作 |
输入 |
表2 kutacc_af2_tm_act_inputs_t数据结构表定义
参数名 |
类型 |
描述 |
输入/输出 |
|---|---|---|---|
n_res |
int64_t |
残基数 |
输入 |
n_res_gather |
int64_t |
残基数 |
输入 |
proj_act |
kutacc_tensor_h |
act的左右投影 |
输出 |
input_act |
kutacc_tensor_h |
输入值act经过layernorm归一化后的值 |
输入 |
proj_act_gate |
kutacc_tensor_h |
临时门控张量 |
输入 |
表3 kutacc_af2_tm_proj_weights_t数据结构表定义
参数名 |
类型 |
描述 |
输入/输出 |
|---|---|---|---|
c_o |
int64_t |
输入特征维度 |
输入 |
c_i |
int64_t |
输出特征维度 |
输入 |
proj_w |
kutacc_tensor_h |
生成投影所需的权重 |
输入 |
proj_b |
kutacc_tensor_h |
生成投影所需的偏置 |
输入 |
gate_w |
kutacc_tensor_h |
门控权重 |
输入 |
gate_b |
kutacc_tensor_h |
门控偏置 |
输入 |
triangle_multiplication的整数参数应满足的约束关系:
n_res, n_res_gather, c_i, c_o > 0,
n_res * n_res_gather < INT64_MAX,
c_z = c_o, c_i = c_o,
单进程情况下需要注意n_res = n_res_gather,多进程则不需要满足该条件