kutacc_af2_invariant_point
This module uses an attention mechanism to preserve the invariance of feature points, improving model performance in protein structure prediction. By applying weighted aggregation to the input features, the module effectively captures the relationships among features and precisely adjusts the coordinates of protein atoms, improving both the accuracy and the stability of structure prediction.
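For reference, the attention weights in AlphaFold2's invariant point attention (Algorithm 22 of the AlphaFold2 supplementary information) take the following form; that this kernel implements exactly this formulation, including the weighting constants, is an assumption based on the standard algorithm:

$$
a_{ij}^{h} = \operatorname{softmax}_j\!\left( w_L \left( \frac{1}{\sqrt{c}}\, \mathbf{q}_i^{h\top} \mathbf{k}_j^{h} + b_{ij}^{h} - \frac{\gamma^{h} w_C}{2} \sum_{p} \left\lVert T_i \circ \vec{q}_i^{\,hp} - T_j \circ \vec{k}_j^{\,hp} \right\rVert^2 \right) \right),
\qquad w_C = \sqrt{\tfrac{2}{9\,N_{\text{qk}}}},\; w_L = \sqrt{\tfrac{1}{3}}
$$

Here $T_i = (R_i, \vec{t}_i)$ is the per-residue rigid transform supplied via rigid_rot_mats and rigid_trans, $b_{ij}^{h}$ is the pair bias produced from z by linear_b_w/linear_b_b, and $\gamma^{h}$ corresponds to the per-head weights in head_weights.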
Interface Definition
void kutacc_af2_invariant_point(kutacc_af2_ipa_s_inputs_t *ipa_s_ptrs, kutacc_af2_ipa_o_inputs_t *ipa_o_ptrs, kutacc_tensor_h z, kutacc_tensor_h rigid_rot_mats, kutacc_tensor_h rigid_trans, kutacc_tensor_h mask, kutacc_af2_ipa_weights_t *ipa_weight_ptrs);
Parameters
Table 1 Input parameters
| Parameter | Type | Description | Input/Output |
|---|---|---|---|
| ipa_s_ptrs | kutacc_af2_ipa_s_inputs_t * | Pointer to a kutacc_af2_ipa_s_inputs_t; see Table 2 for the structure definition | Input |
| ipa_o_ptrs | kutacc_af2_ipa_o_inputs_t * | Pointer to a kutacc_af2_ipa_o_inputs_t; see Table 3 for the structure definition | Input |
| z | kutacc_tensor_h | Intermediate variable (the pair representation) | Input |
| rigid_rot_mats | kutacc_tensor_h | Rigid transform: rotation matrices | Input |
| rigid_trans | kutacc_tensor_h | Rigid transform: translation vectors | Input |
| mask | kutacc_tensor_h | Mask used to ignore invalid residues | Input |
| ipa_weight_ptrs | kutacc_af2_ipa_weights_t * | Pointer to a kutacc_af2_ipa_weights_t; see Table 4 for the structure definition | Input |
Table 2 Definition of the kutacc_af2_ipa_s_inputs_t data structure

| Parameter | Type | Description | Input/Output |
|---|---|---|---|
| n_res | int64_t | Length of dimension 0 of the input tensor s | Input |
| a | kutacc_tensor_h | Temporary buffer for the result of the matrix product qᵀk after softmax | Input |
| b | kutacc_tensor_h | Temporary buffer for the result of the linear transform of z | Input |
| q | kutacc_tensor_h | Q matrix generated from the input s by a linear transform | Input |
| k | kutacc_tensor_h | K matrix | Input |
| v | kutacc_tensor_h | V matrix | Input |
| q_pts | kutacc_tensor_h | 3D points of the query matrix | Input |
| k_pts | kutacc_tensor_h | 3D points of the key matrix | Input |
| v_pts | kutacc_tensor_h | 3D points of the value matrix | Input |
Table 3 Definition of the kutacc_af2_ipa_o_inputs_t data structure

| Parameter | Type | Description | Input/Output |
|---|---|---|---|
| o | kutacc_tensor_h | Buffer for the attention result computed over q, k and v | Input |
| o_pt | kutacc_tensor_h | Buffer for the attention result computed over the point (pts) tensors | Input |
| o_pt_norm | kutacc_tensor_h | Buffer for the norms of the o_pt point vectors | Input |
| o_pair | kutacc_tensor_h | Residue-pair features | Input |
Table 4 Definition of the kutacc_af2_ipa_weights_t data structure

| Parameter | Type | Description | Input/Output |
|---|---|---|---|
| c_z | int64_t | Length of the pair representation (z) feature dimension | Input |
| c_hidden | int64_t | Hidden-layer length | Input |
| no_heads | int64_t | Number of attention heads | Input |
| no_qk_points | int64_t | Number of q/k points to generate | Input |
| no_v_points | int64_t | Number of v points to generate | Input |
| head_weights | kutacc_tensor_h | Per-head attention weights | Input |
| weights_head_weights | kutacc_tensor_h | Buffer for the transformed per-head attention weights | Input |
| linear_b_w | kutacc_tensor_h | Weight of the linear transform that generates b from z | Input |
| linear_b_b | kutacc_tensor_h | Bias of the linear transform that generates b from z | Input |
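For orientation, plain-C sketches of the three structures are shown below. The member sets come from Tables 2 through 4; the field order is an assumption inferred from the *_wrapper constructor argument order in the example at the end of this section, and the actual definitions shipped with kutacc may differ:

/* Sketches only: member sets come from Tables 2-4; field order is an
 * assumption inferred from the *_wrapper constructors in the example below. */
typedef struct {
    kutacc_tensor_h a;                    /* temporary: softmax(q^T k)      */
    kutacc_tensor_h b;                    /* temporary: linear transform of z */
    kutacc_tensor_h q, k, v;              /* scalar Q/K/V projections of s  */
    kutacc_tensor_h q_pts, k_pts, v_pts;  /* 3D point projections of s      */
    int64_t n_res;                        /* length of dim 0 of s           */
} kutacc_af2_ipa_s_inputs_t;

typedef struct {
    kutacc_tensor_h o;          /* attention result over q/k/v   */
    kutacc_tensor_h o_pt;       /* attention result over points  */
    kutacc_tensor_h o_pt_norm;  /* norms of the o_pt vectors     */
    kutacc_tensor_h o_pair;     /* residue-pair features         */
} kutacc_af2_ipa_o_inputs_t;

typedef struct {
    kutacc_tensor_h head_weights;          /* per-head attention weights   */
    kutacc_tensor_h weights_head_weights;  /* transformed per-head weights */
    kutacc_tensor_h linear_b_w;            /* weight generating b from z   */
    kutacc_tensor_h linear_b_b;            /* bias generating b from z     */
    int64_t c_z, c_hidden, no_heads, no_qk_points, no_v_points;
} kutacc_af2_ipa_weights_t;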
The integer parameters must satisfy the following constraints (a worked check against the sample input follows the list):

- n_res, c_z, c_hidden, no_heads, no_qk_points, no_v_points > 0
- no_heads * c_hidden < INT64_MAX
- no_heads * no_v_points * 3 < INT64_MAX
- no_heads * (c_hidden + no_v_points * 3) < INT64_MAX
- no_heads * (c_hidden + no_v_points * 4) < INT64_MAX
- no_heads * c_z < INT64_MAX
- no_heads * c_hidden * 2 * c_s < INT64_MAX
- no_heads * 3 * no_qk_points * c_s < INT64_MAX
- 3 * no_heads * (no_qk_points + no_v_points) * c_s < INT64_MAX
- no_heads * 2 * c_hidden = c_s
- no_qk_points * 3 < 16
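For the sample input used in the example at the end of this section (no_heads = 2, c_hidden = 2, no_qk_points = 3, no_v_points = 8, c_s = 8), the equality constraint holds (2 * 2 * 2 = 8 = c_s) and the point bound holds (3 * 3 = 9 < 16). A minimal pre-call validation sketch follows; it is not part of the kutacc API, and it uses the GCC/Clang __int128 extension so that the products themselves cannot overflow for parameters below 2^31:

#include <cstdint>

// Returns true iff v < INT64_MAX; v is evaluated in 128-bit arithmetic.
static bool lt_max(__int128 v) { return v < INT64_MAX; }

// Checks the integer-parameter constraints listed above.
bool ipa_params_valid(int64_t n_res, int64_t c_z, int64_t c_hidden, int64_t no_heads,
                      int64_t no_qk_points, int64_t no_v_points, int64_t c_s)
{
    if (n_res <= 0 || c_z <= 0 || c_hidden <= 0 ||
        no_heads <= 0 || no_qk_points <= 0 || no_v_points <= 0) {
        return false;
    }
    __int128 h = no_heads, ch = c_hidden, qk = no_qk_points,
             vp = no_v_points, cz = c_z, cs = c_s;
    return lt_max(h * ch) && lt_max(h * vp * 3) &&
           lt_max(h * (ch + vp * 3)) && lt_max(h * (ch + vp * 4)) &&
           lt_max(h * cz) && lt_max(h * ch * 2 * cs) &&
           lt_max(h * 3 * qk * cs) && lt_max(3 * h * (qk + vp) * cs) &&
           h * 2 * ch == cs && qk * 3 < 16;
}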
The tensor parameters required to build the KPEX layer must satisfy the following constraints:

| Tensor | Shape | Description |
|---|---|---|
| s | [n_res, c_s] | Input tensor (single representation) |
| z | [n_res, n_res, c_z] | Input tensor; see parameter z in Table 1 |
| rigid_trans | [n_res, 3] | Input tensor; see parameter rigid_trans in Table 1 |
| rigid_rot_mats | [n_res, 3, 3] | Input tensor; see parameter rigid_rot_mats in Table 1 |
| mask | [n_res] | Input tensor; see parameter mask in Table 1 |
| linear_q_w | [no_heads * c_hidden, c_s] | Weight of the linear transform generating q from s |
| linear_q_b | [no_heads * c_hidden] | Bias of the linear transform generating q from s |
| linear_kv_w | [no_heads * 2 * c_hidden, c_s] | Weight of the linear transform generating k and v from s |
| linear_kv_b | [no_heads * 2 * c_hidden] | Bias of the linear transform generating k and v from s |
| linear_q_points_w | [no_heads * no_qk_points * 3, c_s] | Weight of the linear transform generating q_pts from s |
| linear_q_points_b | [no_heads * no_qk_points * 3] | Bias of the linear transform generating q_pts from s |
| linear_kv_points_w | [3 * no_heads * (no_qk_points + no_v_points), c_s] | Weight of the linear transform generating k_pts and v_pts from s |
| linear_kv_points_b | [3 * no_heads * (no_qk_points + no_v_points)] | Bias of the linear transform generating k_pts and v_pts from s |
| linear_b_w | [no_heads, c_z] | See parameter linear_b_w in Table 4 |
| linear_b_b | [no_heads] | See parameter linear_b_b in Table 4 |
| head_weights | [no_heads] | See parameter head_weights in Table 4 |
| linear_out_w | [c_s, no_heads * (c_hidden + no_v_points * 4 + c_z)] | Weight of the output linear projection applied to the concatenated results |
| linear_out_b | [c_s] | Bias of the output linear projection |
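As a worked example with the sample input from the example below (n_res = 4, no_heads = 2, c_hidden = 2, no_qk_points = 3, no_v_points = 8, c_z = 2, c_s = 8): the concatenated output width is no_heads * (c_hidden + no_v_points * 4 + c_z) = 2 * (2 + 32 + 2) = 72, so linear_out_w has shape [8, 72], and the collect buffer in the example is sliced into o (no_heads * c_hidden = 4 columns), o_pt (no_heads * no_v_points * 3 = 48), o_pt_norm (no_heads * no_v_points = 16) and o_pair (no_heads * c_z = 4), which indeed sum to 72.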
Example
C++ interface:
// test_invariant_point.h
#ifndef KPEX_TPP_ALPHAFOLD_TEST_INVARIANT_POINT_H
#define KPEX_TPP_ALPHAFOLD_TEST_INVARIANT_POINT_H
#include <ATen/core/Tensor.h>
#include <ATen/ops/empty.h>
#include <ATen/ops/ones.h>
#include <ATen/ops/zeros.h>
#include <ATen/ops/full.h>
#include <ATen/native/cpu/utils.h>
#include <c10/core/ScalarType.h>
namespace alphafold {
at::Tensor test_invariant_point_attention(int64_t n_res, int64_t no_heads, int64_t c_hidden, int64_t no_qk_points,
                                          int64_t no_v_points, int64_t c_z, int64_t c_s);
}
#endif
// bind.h
#include <pybind11/pybind11.h>
#include "test_invariant_point.h"

namespace py = pybind11;

namespace alphafold {
inline void bind(pybind11::module &m)
{
    auto submodule = m.def_submodule("alphafold");
    submodule.def("test_invariant_point_attention", &test_invariant_point_attention,
                  py::arg("n_res"), py::arg("no_heads"), py::arg("c_hidden"), py::arg("no_qk_points"),
                  py::arg("no_v_points"), py::arg("c_z"), py::arg("c_s"));
}
}
// test_invariant_point.cpp
#include "test_invariant_point.h"
#include "utils/linear.h"
#include "rigid.h"
#include "kutacc.h"
#include "utils/memory.h"
#include "invariant_point.h"
namespace alphafold {
at::Tensor test_invariant_point_attention(int64_t n_res, int64_t no_heads, int64_t c_hidden, int64_t no_qk_points, int64_t no_v_points, int64_t c_z, int64_t c_s)
{
    // Constant fill values, chosen so the kernel output is deterministic and checkable.
    float a = 0.6f;
    float b = 0.25f;
    float c = 0.1f;
    float d = 0.2f;
    float e = 0.5f;
    // Single representation s and the rigid transforms (rotation matrices + translations).
    at::Tensor s = at::full({n_res, c_s}, a, at::TensorOptions().device(kpex::device()).dtype(at::kBFloat16));
    at::Tensor out = at::empty(s.sizes(), s.options());
    at::Tensor rigid_trans = at::ones({n_res, 1, 1, 3}, s.options()).to(at::kFloat);
    at::Tensor rigid_rot_mats = at::ones({n_res, 1, 1, 3, 3}, s.options()).to(at::kFloat);
    // Weights and biases of the q/k/v projections and the q/k/v point projections.
    at::Tensor linear_q_w = at::full({no_heads, c_hidden, c_s}, b, s.options());
    at::Tensor linear_q_b = at::full({no_heads, c_hidden}, c, rigid_rot_mats.options());
    at::Tensor linear_k_w = at::ones({no_heads, c_hidden, c_s}, s.options());
    at::Tensor linear_k_b = at::zeros({no_heads, c_hidden}, rigid_rot_mats.options());
    at::Tensor linear_v_w = at::ones({no_heads, c_hidden, c_s}, s.options());
    at::Tensor linear_v_b = at::zeros({no_heads, c_hidden}, rigid_rot_mats.options());
    at::Tensor linear_q_points_w = at::full({no_heads, no_qk_points, 3, c_s}, c, s.options());
    at::Tensor linear_q_points_b = at::ones({no_heads, no_qk_points, 3}, rigid_rot_mats.options());
    at::Tensor linear_k_points_w = at::full({no_heads, no_qk_points, 3, c_s}, d, s.options());
    at::Tensor linear_k_points_b = at::full({no_heads, no_qk_points, 3}, d, rigid_rot_mats.options());
    at::Tensor linear_v_points_w = at::ones({no_heads, no_v_points, 3, c_s}, s.options());
    at::Tensor linear_v_points_b = at::zeros({no_heads, no_v_points, 3}, rigid_rot_mats.options());
    at::Tensor linear_b_w = at::ones({no_heads, c_z}, s.options());
    at::Tensor linear_b_b = at::ones({no_heads}, rigid_rot_mats.options());
    at::Tensor head_weights = at::ones({no_heads}, rigid_rot_mats.options());
    at::Tensor linear_out_w = at::ones({c_s, no_heads * (c_hidden + no_v_points * 4 + c_z)}, s.options());
    at::Tensor linear_out_b = at::full({c_s}, e, rigid_rot_mats.options());
    at::Tensor mask = at::ones({n_res}, s.options());
    at::Tensor z = at::ones({n_res, n_res, c_z}, s.options());
    // Project s to q/k/v and to the 3D points; apply the rigid transforms to the points.
    at::Tensor q = linear(s, linear_q_w, linear_q_b);
    at::Tensor k = linear(s, linear_k_w, linear_k_b);
    at::Tensor v = linear(s, linear_v_w, linear_v_b);
    at::Tensor q_pts = linear(s, linear_q_points_w, linear_q_points_b);
    q_pts = rigid_rot_vec_mul(q_pts, rigid_rot_mats, rigid_trans);
    at::Tensor k_pts = linear(s, linear_k_points_w, linear_k_points_b);
    k_pts = rigid_rot_vec_mul(k_pts, rigid_rot_mats, rigid_trans);
    at::Tensor v_pts = linear(s, linear_v_points_w, linear_v_points_b);
    v_pts = rigid_rot_vec_mul(v_pts, rigid_rot_mats, rigid_trans);
    v_pts = v_pts.permute({1, 2, 3, 0}).contiguous();
    // Scratch buffers for the Table 2 fields "b" (linear transform of z) and "a" (softmax(q^T k)).
    at::Tensor m = at::empty({no_heads, n_res, n_res}, s.options()); // field "b"
    at::Tensor n = at::empty({no_heads, n_res, n_res}, q.options()); // field "a"
    at::Tensor head_weights_2 = at::empty(head_weights.sizes(), head_weights.options());
    // collect concatenates the four outputs; o, o_pt, o_pt_norm and o_pair are views into it.
    at::Tensor collect = at::empty({n_res, no_heads * (c_hidden + no_v_points * 4 + c_z)}, s.options());
    at::Tensor o = collect.narrow(1, 0, no_heads * c_hidden).view({n_res, no_heads, c_hidden});
    at::Tensor o_pt = collect.narrow(1, no_heads * c_hidden, no_heads * no_v_points * 3).view({n_res, 3, no_heads, no_v_points});
    at::Tensor o_pt_norm = collect.narrow(1, no_heads * (c_hidden + no_v_points * 3), no_heads * no_v_points).view({n_res, no_heads, no_v_points});
    at::Tensor o_pair = collect.narrow(1, no_heads * (c_hidden + no_v_points * 4), no_heads * c_z).view({n_res, no_heads, c_z});
    // Wrap the ATen tensors as kutacc tensor handles.
    auto q_tw = convert_to_tensor_wrapper(q);
    auto k_tw = convert_to_tensor_wrapper(k);
    auto v_tw = convert_to_tensor_wrapper(v);
    auto q_pts_tw = convert_to_tensor_wrapper(q_pts);
    auto k_pts_tw = convert_to_tensor_wrapper(k_pts);
    auto v_pts_tw = convert_to_tensor_wrapper(v_pts);
    auto m_tw = convert_to_tensor_wrapper(m);
    auto n_tw = convert_to_tensor_wrapper(n);
    auto head_weights_tw = convert_to_tensor_wrapper(head_weights);
    auto weights_head_weights_tw = convert_to_tensor_wrapper(head_weights_2);
    auto o_tw = convert_to_tensor_wrapper(o);
    auto o_pt_tw = convert_to_tensor_wrapper(o_pt);
    auto o_pt_norm_tw = convert_to_tensor_wrapper(o_pt_norm);
    auto o_pair_tw = convert_to_tensor_wrapper(o_pair);
    auto z_tw = convert_to_tensor_wrapper(z);
    auto rigid_rot_mats_tw = convert_to_tensor_wrapper(rigid_rot_mats);
    auto rigid_trans_tw = convert_to_tensor_wrapper(rigid_trans);
    auto mask_tw = convert_to_tensor_wrapper(mask);
    auto linear_b_w_tw = convert_to_tensor_wrapper(linear_b_w);
    auto linear_b_b_tw = convert_to_tensor_wrapper(linear_b_b);
    // Assemble the three parameter structs expected by the kernel.
    kutacc_af2_ipa_weights_t_wrapper *ipa_weight_ptr = new kutacc_af2_ipa_weights_t_wrapper(head_weights_tw, weights_head_weights_tw, linear_b_w_tw, linear_b_b_tw,
                                                                                            c_z, c_hidden, no_heads, no_qk_points, no_v_points);
    kutacc_af2_ipa_s_inputs_t_wrapper *ipa_s_ptrs = new kutacc_af2_ipa_s_inputs_t_wrapper(n_tw, m_tw, q_tw, k_tw, v_tw, q_pts_tw, k_pts_tw, v_pts_tw, n_res);
    kutacc_af2_ipa_o_inputs_t_wrapper *ipa_o_ptrs = new kutacc_af2_ipa_o_inputs_t_wrapper(o_tw, o_pt_tw, o_pt_norm_tw, o_pair_tw);
    if (unlikely(ipa_s_ptrs == nullptr || ipa_o_ptrs == nullptr || ipa_weight_ptr == nullptr)) {
        return out;
    }
    // Run the invariant point attention kernel, then apply the final output projection.
    kutacc_af2_invariant_point(ipa_s_ptrs, ipa_o_ptrs, z_tw.get_tensor(), rigid_rot_mats_tw.get_tensor(), rigid_trans_tw.get_tensor(), mask_tw.get_tensor(), ipa_weight_ptr);
    out = linear(collect, linear_out_w, linear_out_b);
    delete ipa_weight_ptr;
    delete ipa_s_ptrs;
    delete ipa_o_ptrs;
    return out;
}
}
// test.py
import torch
import kpex._C as kernel

def test_invariant_point_attention(n_res, no_heads, c_hidden, no_qk_points, no_v_points, c_z, c_s):
    out = kernel.alphafold.test_invariant_point_attention(n_res, no_heads, c_hidden, no_qk_points, no_v_points, c_z, c_s)
    return out
// input: 4, 2, 2, 3, 8, 2, 8
// output:
>>> kpex.tpp.alphafold.alphafold.test_invariant_point_attention(4, 2, 2, 3, 8, 2, 8)
tensor([[3296., 3296., 3296., 3296., 3296., 3296., 3296., 3296.],
[3296., 3296., 3296., 3296., 3296., 3296., 3296., 3296.],
[3296., 3296., 3296., 3296., 3296., 3296., 3296., 3296.],
[3296., 3296., 3296., 3296., 3296., 3296., 3296., 3296.]],
dtype=torch.bfloat16)