kutacc_af2_rigid_rot_vec_mul
igid_rot_vec_mul实现矩阵乘向量,输入为矩阵和向量,输出为向量。
接口定义
void kutacc_af2_rigid_rot_vec_mul(kutacc_tensor_h pts, kutacc_tensor_h rot_mats, kutacc_tensor_h out, kutacc_tensor_h trans);
参数
参数名 |
类型 |
描述 |
输入/输出 |
tensor shape |
|---|---|---|---|---|
pts |
kutacc_tensor_h |
输入数据 |
输入 |
[..., 3] |
rot_mats |
kutacc_tensor_h |
旋转矩阵 |
输入 |
[..., 3, 3] |
out |
kutacc_tensor_h |
输出数据 |
输出 |
[..., 3] |
trans |
kutacc_tensor_h |
平移向量 |
输入 |
[..., 3] |
kutacc_af2_rigid_rot_vec_mul参数的约束关系: dim = pts.dim - 1; pts.strides[dim] = 1, rot_mats.strides[dim] = 3, rot_mats.strides[dim + 1] = 1 trans.strides[dim] = 1
父主题: AlphaFold2