鲲鹏社区首页
中文
注册
开发者
我要评分
获取效率
正确性
完整性
易理解
在线提单
论坛求助

kutacc_af2_rigid_rot_matmul

rigid_rot_mat_mul实现矩阵乘矩阵,输入为两个矩阵,输出为矩阵。

接口定义

void kutacc_af2_rigid_rot_matmul(kutacc_tensor_h a, kutacc_tensor_h b, kutacc_tensor_h out);

参数

参数名

类型

描述

输入/输出

tensor shape

a

kutacc_tensor_h

第一个旋转矩阵

输入

[..., 3, 3]

b

kutacc_tensor_h

第二个旋转矩阵

输入

[..., 3, 3]

out

kutacc_tensor_h

输出数据

输出

[..., 3, 3]

kutacc_af2_rigid_rot_matmul参数的约束关系:dim = a.dim - 2; a.strides[dim] = 3, a.strides[dim +1] = 1, b.strides[dim] = 3 b.strides[dim + 1] = 1

示例

rigid算子均用在invariant_point_attention算子内部,具体用例应参考8.1.2.2.2.1.1.9环节的invariant_point_attention算子的示例