鲲鹏社区首页
中文
注册
开发者
我要评分
获取效率
正确性
完整性
易理解
在线提单
论坛求助

kutacc_af2_triangle_multiplication_calc_proj

用于计算triangle_multiplication过程中生成的中间值:左右投影

接口定义

void kutacc_af2_triangle_multiplication_calc_proj(kutacc_af2_tm_act_inputs_t *tm_acts_ptr, kutacc_tensor_h mask, kutacc_af2_tm_proj_weights_t *tm_weights_ptr, bool input_prepack);

参数

表1 入参定义

参数名

类型

描述

输入/输出

tm_acts_ptr

kutacc_af2_tm_act_inputs_t *

kutacc_af2_tm_act_inputs_t类型的指针,具体数据结构定义见下方表2 kutacc_af2_tm_act_inputs_t数据结构表定义

输入

mask

kutacc_tensor_h

掩码张量

输入

tm_weights_ptr

kutacc_af2_tm_proj_weights_t *

kutacc_af2_tm_proj_weights_t类型的指针,具体数据结构定义见下方表3 kutacc_af2_tm_proj_weights_t数据结构表定义

输入

input_prepack

bool

是否对输入张量进行prepack操作

输入

表2 kutacc_af2_tm_act_inputs_t数据结构表定义

参数名

类型

描述

输入/输出

n_res

int64_t

残基数

输入

n_res_gather

int64_t

残基数

输入

proj_act

kutacc_tensor_h

act的左右投影

输出

input_act

kutacc_tensor_h

输入值act经过layernorm归一化后的值

输入

proj_act_gate

kutacc_tensor_h

临时门控张量

输入

表3 kutacc_af2_tm_proj_weights_t数据结构表定义

参数名

类型

描述

输入/输出

c_o

int64_t

输入特征维度

输入

c_i

int64_t

输出特征维度

输入

proj_w

kutacc_tensor_h

生成投影所需的权重

输入

proj_b

kutacc_tensor_h

生成投影所需的偏置

输入

gate_w

kutacc_tensor_h

门控权重

输入

gate_b

kutacc_tensor_h

门控偏置

输入

triangle_multiplication的整数参数应满足的约束关系:

n_res, n_res_gather, c_i, c_o > 0,

n_res * n_res_gather < INT64_MAX,

c_z = c_o, c_i = c_o,

单进程情况下需要注意n_res = n_res_gather,多进程则不需要满足该条件