鲲鹏社区首页
中文
注册
开发者
我要评分
获取效率
正确性
完整性
易理解
在线提单
论坛求助

kutacc_af2_outer_product_mean_calc_left_and_right_mul

outer_product_mean中用于计算左右投影的计算函数

接口定义

void kutacc_af2_outer_product_mean_calc_left_and_right_mul(kutacc_af2_opm_act_inputs_t *opm_acts_ptr, kutacc_af2_opm_mask_inputs_t *opm_masks_ptr, kutacc_af2_opm_weights_t *opm_weights_ptr);

参数

表1 入参定义

参数名

类型

描述

输入/输出

opm_acts_ptr

kutacc_af2_opm_act_inputs_t *

kutacc_af2_opm_act_inputs_t类型的指针,具体数据结构定义见下方表2 kutacc_af2_opm_act_inputs_t 数据结构表定义

输入

opm_masks_ptr

kutacc_af2_opm_mask_inputs_t *

kutacc_af2_opm_mask_inputs_t类型的指针,具体数据结构定义见下方表3

kutacc_af2_opm_mask_inputs_t 数据结构表定义

输入

opm_weights_ptr

kutacc_af2_opm_weights_t *

kutacc_af2_opm_weights_t类型的指针,具体数据结构定义见下方表4

kutacc_af2_opm_weights_t数据结构表定义

输入

表2 kutacc_af2_opm_act_inputs_t 数据结构表定义

参数名

类型

描述

输入/输出

n_seq

int64_t

序列数量

输入

n_res

int64_t

残基数量

输入

input_act

kutacc_tensor_h

输入激活张量

输入

left_proj

kutacc_tensor_h

左投影

输入

right_proj

kutacc_tensor_h

右投影

输入

left_proj_

kutacc_tensor_h

经过掩码处理后的左投影

输入

right_proj_

kutacc_tensor_h

经过掩码处理后的右投影

输入

表3 kutacc_af2_opm_mask_inputs_t 数据结构表定义

参数名

类型

描述

输入/输出

n_res_gather

int64_t

int64_t 聚合后的残基数量

输入

mask_bias

int64_t

掩码张量地址偏移量

输入

mask

kutacc_tensor_h

掩码张量

输入

norm

kutacc_tensor_h

归一化因子张量

输入

表4 kutacc_af2_opm_weights_t 数据结构表定义

参数名

类型

描述

输入/输出

c_m

int64_t

输入特征维度

输入

c_i

int64_t

投影后的特征维度

输入

c_z

int64_t

输出特征维度

输入

left_proj_w

kutacc_tensor_h

左投影权重

输入

left_proj_b

kutacc_tensor_h

左投影偏移量

输入

right_proj_w

kutacc_tensor_h

右投影权重

输入

right_proj_b

kutacc_tensor_h

右投影偏移量

输入

outer_w

kutacc_tensor_h

输出权重

输入

outer_b

kutacc_tensor_h

输出偏移量

输入

outer_product_mean整数参数应满足的约束关系:

mask_bias >= 0,

c_i, c_m, n_res, n_res_gather, n_seq > 0,

n_seq * n_res <INT64_MAX

单进程时n_res = n_res_gather 多进程下不满足该条件