kutacc_af2_outer_product_mean_calc_left_and_right_mul
outer_product_mean中用于计算左右投影的计算函数
接口定义
void kutacc_af2_outer_product_mean_calc_left_and_right_mul(kutacc_af2_opm_act_inputs_t *opm_acts_ptr, kutacc_af2_opm_mask_inputs_t *opm_masks_ptr, kutacc_af2_opm_weights_t *opm_weights_ptr);
参数
表1 入参定义
参数名 |
类型 |
描述 |
输入/输出 |
|---|---|---|---|
opm_acts_ptr |
kutacc_af2_opm_act_inputs_t * |
kutacc_af2_opm_act_inputs_t类型的指针,具体数据结构定义见下方表2 kutacc_af2_opm_act_inputs_t 数据结构表定义 |
输入 |
opm_masks_ptr |
kutacc_af2_opm_mask_inputs_t * |
kutacc_af2_opm_mask_inputs_t类型的指针,具体数据结构定义见下方表3 kutacc_af2_opm_mask_inputs_t 数据结构表定义 |
输入 |
opm_weights_ptr |
kutacc_af2_opm_weights_t * |
kutacc_af2_opm_weights_t类型的指针,具体数据结构定义见下方表4 kutacc_af2_opm_weights_t数据结构表定义 |
输入 |
表2 kutacc_af2_opm_act_inputs_t 数据结构表定义
参数名 |
类型 |
描述 |
输入/输出 |
|---|---|---|---|
n_seq |
int64_t |
序列数量 |
输入 |
n_res |
int64_t |
残基数量 |
输入 |
input_act |
kutacc_tensor_h |
输入激活张量 |
输入 |
left_proj |
kutacc_tensor_h |
左投影 |
输入 |
right_proj |
kutacc_tensor_h |
右投影 |
输入 |
left_proj_ |
kutacc_tensor_h |
经过掩码处理后的左投影 |
输入 |
right_proj_ |
kutacc_tensor_h |
经过掩码处理后的右投影 |
输入 |
表3 kutacc_af2_opm_mask_inputs_t 数据结构表定义
参数名 |
类型 |
描述 |
输入/输出 |
|---|---|---|---|
n_res_gather |
int64_t |
int64_t 聚合后的残基数量 |
输入 |
mask_bias |
int64_t |
掩码张量地址偏移量 |
输入 |
mask |
kutacc_tensor_h |
掩码张量 |
输入 |
norm |
kutacc_tensor_h |
归一化因子张量 |
输入 |
表4 kutacc_af2_opm_weights_t 数据结构表定义
参数名 |
类型 |
描述 |
输入/输出 |
|---|---|---|---|
c_m |
int64_t |
输入特征维度 |
输入 |
c_i |
int64_t |
投影后的特征维度 |
输入 |
c_z |
int64_t |
输出特征维度 |
输入 |
left_proj_w |
kutacc_tensor_h |
左投影权重 |
输入 |
left_proj_b |
kutacc_tensor_h |
左投影偏移量 |
输入 |
right_proj_w |
kutacc_tensor_h |
右投影权重 |
输入 |
right_proj_b |
kutacc_tensor_h |
右投影偏移量 |
输入 |
outer_w |
kutacc_tensor_h |
输出权重 |
输入 |
outer_b |
kutacc_tensor_h |
输出偏移量 |
输入 |
outer_product_mean整数参数应满足的约束关系:
mask_bias >= 0,
c_i, c_m, n_res, n_res_gather, n_seq > 0,
n_seq * n_res <INT64_MAX
单进程时n_res = n_res_gather 多进程下不满足该条件