鲲鹏社区首页
中文
注册
开发者
我要评分
获取效率
正确性
完整性
易理解
在线提单
论坛求助

kutacc_af2_triangle_multiplication_gate_and_out_linear

triangle_multiplication的中间步骤,基于input_act通过线性变化生成gate及基于center_act线性变化生成out

接口定义

void kutacc_af2_triangle_multiplication_gate_and_out_linear(kutacc_tensor_h gate, kutacc_tensor_h out, kutacc_af2_tm_act_inputs_t *tm_acts_ptr, kutacc_tensor_h center_act,

kutacc_af2_tm_linear_weights_t *tm_weights_ptr, bool input_prepack);

参数

表1 入参定义

参数名

类型

描述

输入/输出

gate

kutacc_tensor_h

门控张量

输入

out

kutacc_tensor_h

输出张量

输出

tm_acts_ptr

kutacc_af2_tm_act_inputs_t *

kutacc_af2_tm_act_inputs_t类型的指针,具体数据结构定义见表2 kutacc_af2_tm_act_inputs_t数据结构表定义

输入

center_act

kutacc_tensor_h

equation步骤获得center_act经过permute及layernorm变化后获得的输入

输入

tm_weights_ptr

kutacc_af2_tm_linear_weights_t *

kutacc_af2_tm_linear_weights_t类型的指针,具体数据结构定义见表3 kutacc_af2_tm_linear_weights_t数据结构表定义

输入

input_prepack

bool

是否prepack

输入

表2 kutacc_af2_tm_act_inputs_t数据结构表定义

参数名

类型

描述

输入/输出

n_res

int64_t

残基数

输入

n_res_gather

int64_t

残基数

输入

proj_act

kutacc_tensor_h

act的左右投影

输出

input_act

kutacc_tensor_h

输入值act经过layernorm归一化后的值

输入

proj_act_gate

kutacc_tensor_h

临时门控张量

输入

表3 kutacc_af2_tm_linear_weights_t数据结构表定义

参数名

类型

描述

输入/输出

c_o

int64_t

输出特征维度

输入

c_i

int64_t

输入特征维度

输入

gating_w

kutacc_tensor_h

生成gate用的权重

输出

gating_b

kutacc_tensor_h

生成gate用的偏置

输入

output_proj_w

kutacc_tensor_h

生成out用的权重

输入

output_proj_b

kutacc_tensor_h

生成out用的偏置

输入

triangle_multiplication的整数参数应满足的约束关系:

n_res, n_res_gather, c_o, c_i, c_z > 0,

n_res * n_res_gather < INT64_MAX,

c_z = c_o, c_i = c_o;

单进程情况下需要注意n_res = n_res_gather,多进程则不需要满足该条件