kutacc_af2_triangle_multiplication_gate_and_out_linear
triangle_multiplication的中间步骤,基于input_act通过线性变化生成gate及基于center_act线性变化生成out
接口定义
void kutacc_af2_triangle_multiplication_gate_and_out_linear(kutacc_tensor_h gate, kutacc_tensor_h out, kutacc_af2_tm_act_inputs_t *tm_acts_ptr, kutacc_tensor_h center_act,
kutacc_af2_tm_linear_weights_t *tm_weights_ptr, bool input_prepack);
参数
表1 入参定义
参数名 |
类型 |
描述 |
输入/输出 |
|---|---|---|---|
gate |
kutacc_tensor_h |
门控张量 |
输入 |
out |
kutacc_tensor_h |
输出张量 |
输出 |
tm_acts_ptr |
kutacc_af2_tm_act_inputs_t * |
kutacc_af2_tm_act_inputs_t类型的指针,具体数据结构定义见表2 kutacc_af2_tm_act_inputs_t数据结构表定义 |
输入 |
center_act |
kutacc_tensor_h |
equation步骤获得center_act经过permute及layernorm变化后获得的输入 |
输入 |
tm_weights_ptr |
kutacc_af2_tm_linear_weights_t * |
kutacc_af2_tm_linear_weights_t类型的指针,具体数据结构定义见表3 kutacc_af2_tm_linear_weights_t数据结构表定义 |
输入 |
input_prepack |
bool |
是否prepack |
输入 |
表2 kutacc_af2_tm_act_inputs_t数据结构表定义
参数名 |
类型 |
描述 |
输入/输出 |
|---|---|---|---|
n_res |
int64_t |
残基数 |
输入 |
n_res_gather |
int64_t |
残基数 |
输入 |
proj_act |
kutacc_tensor_h |
act的左右投影 |
输出 |
input_act |
kutacc_tensor_h |
输入值act经过layernorm归一化后的值 |
输入 |
proj_act_gate |
kutacc_tensor_h |
临时门控张量 |
输入 |
表3 kutacc_af2_tm_linear_weights_t数据结构表定义
参数名 |
类型 |
描述 |
输入/输出 |
|---|---|---|---|
c_o |
int64_t |
输出特征维度 |
输入 |
c_i |
int64_t |
输入特征维度 |
输入 |
gating_w |
kutacc_tensor_h |
生成gate用的权重 |
输出 |
gating_b |
kutacc_tensor_h |
生成gate用的偏置 |
输入 |
output_proj_w |
kutacc_tensor_h |
生成out用的权重 |
输入 |
output_proj_b |
kutacc_tensor_h |
生成out用的偏置 |
输入 |
triangle_multiplication的整数参数应满足的约束关系:
n_res, n_res_gather, c_o, c_i, c_z > 0,
n_res * n_res_gather < INT64_MAX,
c_z = c_o, c_i = c_o;
单进程情况下需要注意n_res = n_res_gather,多进程则不需要满足该条件