
kutacc_af2_invariant_point

This interface uses an attention mechanism to preserve the invariance of feature points, improving model performance in protein structure prediction. By applying weighted aggregation to the input features, the module effectively captures the relationships between features and precisely adjusts the coordinates of protein atoms, improving the accuracy and stability of structure prediction.
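For reference, this operator corresponds to the invariant point attention (IPA) step of AlphaFold2 (Jumper et al., 2021, Supplementary Algorithm 22), whose attention logits take roughly the following form. The mapping onto the parameters documented below is an assumption for orientation only, not taken from the kutacc sources: b_ij is the pair bias produced from z via linear_b_w/linear_b_b, gamma^h corresponds to head_weights, and T_i = (R_i, t_i) to rigid_rot_mats/rigid_trans.

a_{ij}^{h} = \operatorname{softmax}_{j}\Big( w_L \Big( \tfrac{1}{\sqrt{c}}\, q_i^{h\top} k_j^{h} + b_{ij}^{h} - \tfrac{\gamma^{h} w_C}{2} \sum_{p} \big\lVert T_i \circ \vec{q}_i^{\,hp} - T_j \circ \vec{k}_j^{\,hp} \big\rVert^{2} \Big) \Big)

where w_L and w_C are fixed scaling constants from the paper, and the final output concatenates the attended values, the transformed points, their norms, and the attended pair features (the o, o_pt, o_pt_norm and o_pair outputs in Table 3).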

Interface Definition

void kutacc_af2_invariant_point(kutacc_af2_ipa_s_inputs_t *ipa_s_ptrs, kutacc_af2_ipa_o_inputs_t *ipa_o_ptrs, kutacc_tensor_h z, kutacc_tensor_h rigid_rot_mats, kutacc_tensor_h rigid_trans, kutacc_tensor_h mask, kutacc_af2_ipa_weights_t *ipa_weight_ptrs);

Parameters

Table 1 Input parameter description

| Parameter | Type | Description | Input/Output |
| --- | --- | --- | --- |
| ipa_s_ptrs | kutacc_af2_ipa_s_inputs_t * | Pointer to a kutacc_af2_ipa_s_inputs_t; see Table 2 for the structure definition | Input |
| ipa_o_ptrs | kutacc_af2_ipa_o_inputs_t * | Pointer to a kutacc_af2_ipa_o_inputs_t; see Table 3 for the structure definition | Input |
| z | kutacc_tensor_h | Intermediate variable | Input |
| rigid_rot_mats | kutacc_tensor_h | Rigid transform: rotation parameters | Input |
| rigid_trans | kutacc_tensor_h | Rigid transform: translation parameters | Input |
| mask | kutacc_tensor_h | Mask used to ignore invalid residues | Input |
| ipa_weight_ptrs | kutacc_af2_ipa_weights_t * | Pointer to a kutacc_af2_ipa_weights_t; see Table 4 for the structure definition | Input |

Table 2 kutacc_af2_ipa_s_inputs_t structure definition

| Parameter | Type | Description | Input/Output |
| --- | --- | --- | --- |
| n_res | int64_t | Length of dimension 0 of the input vector s | Input |
| a | kutacc_tensor_h | Holds the temporary result of softmax applied to the matrix product of q-transpose and k | Input |
| b | kutacc_tensor_h | Holds the temporary result of the linear transformation of the z vector | Input |
| q | kutacc_tensor_h | Q matrix produced by a linear transformation of the input s | Input |
| k | kutacc_tensor_h | K matrix | Input |
| v | kutacc_tensor_h | V matrix | Input |
| q_pts | kutacc_tensor_h | 3D points of the query matrix | Input |
| k_pts | kutacc_tensor_h | 3D points of the key matrix | Input |
| v_pts | kutacc_tensor_h | 3D points of the value matrix | Input |

Table 3 kutacc_af2_ipa_o_inputs_t structure definition

| Parameter | Type | Description | Input/Output |
| --- | --- | --- | --- |
| o | kutacc_tensor_h | Holds the result of the q/k/v attention computation | Input |
| o_pt | kutacc_tensor_h | Holds the result of the attention computation over the point (pts) vectors | Input |
| o_pt_norm | kutacc_tensor_h | Holds the vector that undergoes the layernorm transformation | Input |
| o_pair | kutacc_tensor_h | Residue-pair features | Input |

Table 4 kutacc_af2_ipa_weights_t structure definition

| Parameter | Type | Description | Input/Output |
| --- | --- | --- | --- |
| c_z | int64_t | Channel length of the pair representation z (last dimension of z) | Input |
| c_hidden | int64_t | Hidden-layer length | Input |
| no_heads | int64_t | Number of attention heads | Input |
| no_qk_points | int64_t | Number of q/k points to generate | Input |
| no_v_points | int64_t | Number of v points to generate | Input |
| head_weights | kutacc_tensor_h | Attention-head weights | Input |
| weights_head_weights | kutacc_tensor_h | Attention-head weights | Input |
| linear_b_w | kutacc_tensor_h | Weights of the linear transformation that produces b | Input |
| linear_b_b | kutacc_tensor_h | Bias of the linear transformation that produces b | Input |

Constraints that the integer parameters must satisfy:

- n_res, c_z, c_hidden, no_heads, no_qk_points, no_v_points > 0
- no_heads * c_hidden < INT64_MAX
- no_heads * no_v_points * 3 < INT64_MAX
- no_heads * (c_hidden + no_v_points * 3) < INT64_MAX
- no_heads * (c_hidden + no_v_points * 4) < INT64_MAX
- no_heads * c_z < INT64_MAX
- no_heads * c_hidden * 2 * c_s < INT64_MAX
- no_heads * 3 * no_qk_points * c_s < INT64_MAX
- 3 * no_heads * (no_qk_points + no_v_points) * c_s < INT64_MAX
- no_heads * 2 * c_hidden = c_s
- no_qk_points * 3 < 16

Constraints on the tensor parameters required to build the KPEX layer are as follows:

| tensor | shape | Description |
| --- | --- | --- |
| s | [n_res, c_s] | Input tensor |
| z | [n_res, n_res, c_z] | Input tensor; see parameter z in Table 1 |
| rigid_trans | [n_res, 3] | Input tensor; see parameter rigid_trans in Table 1 |
| rigid_rot_mats | [n_res, 3, 3] | Input tensor; see parameter rigid_rot_mats in Table 1 |
| mask | [n_res] | Input tensor; see parameter mask in Table 1 |
| linear_q_w | [no_heads * c_hidden, c_s] | Weights of the linear transformation that generates q from s |
| linear_q_b | [no_heads * c_hidden] | Bias of the linear transformation that generates q from s |
| linear_kv_w | [no_heads * 2 * c_hidden, c_s] | Weights of the linear transformation that generates k and v from s |
| linear_kv_b | [no_heads * 2 * c_hidden] | Bias of the linear transformation that generates k and v from s |
| linear_q_points_w | [no_heads * no_qk_points * 3, c_s] | Weights of the linear transformation that generates q_pts from s |
| linear_q_points_b | [no_heads * no_qk_points * 3] | Bias of the linear transformation that generates q_pts from s |
| linear_kv_points_w | [3 * no_heads * (no_qk_points + no_v_points), c_s] | Weights of the linear transformation that generates k_pts and v_pts from s |
| linear_kv_points_b | [3 * no_heads * (no_qk_points + no_v_points)] | Bias of the linear transformation that generates k_pts and v_pts from s |
| linear_b_w | [no_heads, c_z] | See parameter linear_b_w in Table 4 |
| linear_b_b | [no_heads] | See parameter linear_b_b in Table 4 |
| head_weights | [no_heads] | See parameter head_weights in Table 4 |
| linear_out_w | [c_s, no_heads * (c_hidden + no_v_points * 4 + c_z)] | Weights of the final output linear transformation (see the example below) |
| linear_out_b | [c_s] | Bias of the final output linear transformation (see the example below) |
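Note that the widths of the attention outputs add up to the input width of linear_out_w: per residue, o contributes no_heads * c_hidden columns, o_pt contributes no_heads * no_v_points * 3, o_pt_norm contributes no_heads * no_v_points, and o_pair contributes no_heads * c_z. The example below exploits this by packing all four outputs as views into one contiguous collect buffer. A small standalone sketch of the resulting per-row offsets (derived from the narrow() calls in the example, not from a separate kutacc specification):

//collect_layout.cpp (illustrative only)
#include <cstdint>
#include <cstdio>

int main()
{
    // Parameter values from the sample invocation at the end of this page:
    // n_res = 4, no_heads = 2, c_hidden = 2, no_qk_points = 3, no_v_points = 8, c_z = 2, c_s = 8
    const int64_t no_heads = 2, c_hidden = 2, no_v_points = 8, c_z = 2;

    const int64_t o_offset      = 0;                                       // width: no_heads * c_hidden
    const int64_t o_pt_offset   = no_heads * c_hidden;                     // width: no_heads * no_v_points * 3
    const int64_t o_pt_n_offset = no_heads * (c_hidden + no_v_points * 3); // width: no_heads * no_v_points
    const int64_t o_pair_offset = no_heads * (c_hidden + no_v_points * 4); // width: no_heads * c_z
    const int64_t total_width   = no_heads * (c_hidden + no_v_points * 4 + c_z);

    std::printf("o@%lld o_pt@%lld o_pt_norm@%lld o_pair@%lld total=%lld\n",
                (long long)o_offset, (long long)o_pt_offset, (long long)o_pt_n_offset,
                (long long)o_pair_offset, (long long)total_width);
    // Prints: o@0 o_pt@4 o_pt_norm@52 o_pair@68 total=72
    return 0;
}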

Example

C++ interface:

//test_invariant_point.h
#ifndef KPEX_TPP_ALPHAFOLD_TEST_INVARIANT_POINT_H
#define KPEX_TPP_ALPHAFOLD_TEST_INVARIANT_POINT_H
#include <cstdint>
#include <ATen/core/Tensor.h>
#include <ATen/ops/empty.h>
#include <ATen/ops/ones.h>
#include <ATen/ops/zeros.h>
#include <ATen/ops/full.h>
#include <ATen/native/cpu/utils.h>
#include <c10/core/ScalarType.h>
namespace alphafold { 
at::Tensor test_invariant_point_attention(int64_t n_res, int64_t no_heads, int64_t c_hidden, int64_t no_qk_points, int64_t no_v_points, int64_t c_z, int64_t c_s); 
} 
#endif

//bind.h
#include <pybind11/pybind11.h>
#include "test_invariant_point.h"
namespace py = pybind11;
namespace alphafold { 

inline void bind(pybind11::module &m) 
{ 
    auto submodule = m.def_submodule("alphafold");
    submodule.def("test_invariant_point_attention", &test_invariant_point_attention, py::arg("n_res"), py::arg("no_heads"), py::arg("c_hidden"), py::arg("no_qk_points"), 
        py::arg("no_v_points"), py::arg("c_z"), py::arg("c_s"));
}
}

//test_invariant_point.cpp
#include "test_invariant_point.h" 

#include "utils/linear.h" 
#include "rigid.h" 
#include "kutacc.h" 
#include "utils/memory.h" 
#include "invariant_point.h" 
namespace alphafold { 
at::Tensor test_invariant_point_attention(int64_t n_res, int64_t no_heads, int64_t c_hidden, int64_t no_qk_points, int64_t no_v_points, int64_t c_z, int64_t c_s) 
{ 
    float a = 0.6f; 
    float b = 0.25f; 
    float c = 0.1f; 
    float d = 0.2f; 
    float e = 0.5f; 
    at::Tensor s = at::full({n_res, c_s}, a, at::TensorOptions().device(kpex::device()).dtype(at::kBFloat16)); 
    at::Tensor out = at::empty(s.sizes(), s.options()); 
    at::Tensor rigid_trans = at::ones({n_res, 1, 1, 3}, s.options()).to(at::kFloat); 
    at::Tensor rigid_rot_mats = at::ones({n_res, 1, 1, 3, 3}, s.options()).to(at::kFloat); 
    at::Tensor linear_q_w = at::full({no_heads, c_hidden, c_s}, b, s.options()); 
    at::Tensor linear_q_b = at::full({no_heads, c_hidden}, c, rigid_rot_mats.options()); 
    at::Tensor linear_k_w = at::ones({no_heads, c_hidden, c_s}, s.options()); 
    at::Tensor linear_k_b = at::zeros({no_heads, c_hidden}, rigid_rot_mats.options()); 
    at::Tensor linear_v_w = at::ones({no_heads, c_hidden, c_s}, s.options()); 
    at::Tensor linear_v_b = at::zeros({no_heads, c_hidden}, rigid_rot_mats.options()); 
    at::Tensor linear_q_points_w = at::full({no_heads, no_qk_points, 3, c_s}, c, s.options()); 
    at::Tensor linear_q_points_b = at::ones({no_heads, no_qk_points, 3}, rigid_rot_mats.options()); 
    at::Tensor linear_k_points_w = at::full({no_heads, no_qk_points, 3, c_s}, d, s.options()); 
    at::Tensor linear_k_points_b = at::full({no_heads, no_qk_points, 3}, d, rigid_rot_mats.options()); 
    at::Tensor linear_v_points_w = at::ones({no_heads, no_v_points, 3, c_s}, s.options()); 
    at::Tensor linear_v_points_b = at::zeros({no_heads, no_v_points, 3}, rigid_rot_mats.options()); 
    at::Tensor linear_b_w = at::ones({no_heads, c_z}, s.options()); 
    at::Tensor linear_b_b = at::ones({no_heads}, rigid_rot_mats.options()); 
    at::Tensor head_weights = at::ones({no_heads}, rigid_rot_mats.options()); 
    at::Tensor linear_out_w = at::ones({c_s, no_heads * (c_hidden + no_v_points * 4 + c_z)}, s.options()); 
    at::Tensor linear_out_b = at::full({c_s}, e, rigid_rot_mats.options()); 
    at::Tensor mask = at::ones({n_res}, s.options()); 
    at::Tensor z = at::ones({n_res, n_res, c_z}, s.options()); 
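    // Project the single representation s into q/k/v and the 3D query/key/value point sets.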
    at::Tensor q = linear(s, linear_q_w, linear_q_b); 
    at::Tensor k = linear(s, linear_k_w, linear_k_b); 
    at::Tensor v = linear(s, linear_v_w, linear_v_b); 
    at::Tensor q_pts = linear(s, linear_q_points_w, linear_q_points_b); 
    q_pts = rigid_rot_vec_mul(q_pts, rigid_rot_mats, rigid_trans); 
    at::Tensor k_pts = linear(s, linear_k_points_w, linear_k_points_b); 
    k_pts = rigid_rot_vec_mul(k_pts, rigid_rot_mats, rigid_trans); 
    at::Tensor v_pts = linear(s, linear_v_points_w, linear_v_points_b); 
    v_pts = rigid_rot_vec_mul(v_pts, rigid_rot_mats, rigid_trans); 
    v_pts = v_pts.permute({1, 2, 3, 0}).contiguous(); 
    at::Tensor m = at::empty({no_heads, n_res, n_res}, s.options()); // b: temporary for the linear transformation of z (Table 2)
    at::Tensor n = at::empty({no_heads, n_res, n_res}, q.options()); // a: temporary for softmax(q-transpose * k) (Table 2)
    at::Tensor head_weights_2 = at::empty(head_weights.sizes(), head_weights.options()); 
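    // Pack o, o_pt, o_pt_norm and o_pair as narrow views into one contiguous buffer
    // so the final output linear layer can consume them in a single matmul.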
    at::Tensor collect = at::empty({n_res, no_heads * (c_hidden + no_v_points * 4 + c_z)}, s.options()); 
    at::Tensor o = collect.narrow(1, 0, no_heads * c_hidden).view({n_res, no_heads, c_hidden}); 
    at::Tensor o_pt = collect.narrow(1, no_heads * c_hidden, no_heads * no_v_points * 3).view({n_res, 3, no_heads, no_v_points}); 
    at::Tensor o_pt_norm = collect.narrow(1, no_heads * (c_hidden + no_v_points * 3), no_heads * no_v_points).view({n_res, no_heads, no_v_points}); 
    at::Tensor o_pair = collect.narrow(1, no_heads * (c_hidden + no_v_points * 4), no_heads * c_z).view({n_res, no_heads, c_z}); 
    auto q_tw = convert_to_tensor_wrapper(q); 
    auto k_tw = convert_to_tensor_wrapper(k); 
    auto v_tw = convert_to_tensor_wrapper(v); 
    auto q_pts_tw = convert_to_tensor_wrapper(q_pts); 
    auto k_pts_tw = convert_to_tensor_wrapper(k_pts); 
    auto v_pts_tw = convert_to_tensor_wrapper(v_pts); 
    auto m_tw = convert_to_tensor_wrapper(m); 
    auto n_tw = convert_to_tensor_wrapper(n); 
    auto head_weights_tw = convert_to_tensor_wrapper(head_weights); 
    auto weights_head_weights_tw = convert_to_tensor_wrapper(head_weights_2); 
    auto o_tw = convert_to_tensor_wrapper(o); 
    auto o_pt_tw = convert_to_tensor_wrapper(o_pt); 
    auto o_pt_norm_tw = convert_to_tensor_wrapper(o_pt_norm); 
    auto o_pair_tw = convert_to_tensor_wrapper(o_pair); 
    auto z_tw = convert_to_tensor_wrapper(z); 
    auto rigid_rot_mats_tw = convert_to_tensor_wrapper(rigid_rot_mats); 
    auto rigid_trans_tw = convert_to_tensor_wrapper(rigid_trans); 
    auto mask_tw = convert_to_tensor_wrapper(mask); 
    auto linear_b_w_tw = convert_to_tensor_wrapper(linear_b_w); 
    auto linear_b_b_tw = convert_to_tensor_wrapper(linear_b_b); 
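    // Assemble the three input structs consumed by the kernel (Tables 2-4).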
    kutacc_af2_ipa_weights_t_wrapper *ipa_weight_ptr = new kutacc_af2_ipa_weights_t_wrapper(head_weights_tw, weights_head_weights_tw, linear_b_w_tw, linear_b_b_tw, c_z, c_hidden, no_heads, 
        no_qk_points, no_v_points); 
    kutacc_af2_ipa_s_inputs_t_wrapper *ipa_s_ptrs = new kutacc_af2_ipa_s_inputs_t_wrapper(n_tw, m_tw, q_tw, k_tw, v_tw, q_pts_tw, k_pts_tw, v_pts_tw, n_res); 
    kutacc_af2_ipa_o_inputs_t_wrapper *ipa_o_ptrs = new kutacc_af2_ipa_o_inputs_t_wrapper(o_tw, o_pt_tw, o_pt_norm_tw, o_pair_tw); 
    if (unlikely(ipa_s_ptrs == nullptr || ipa_o_ptrs == nullptr || ipa_weight_ptr == nullptr)) { 
        return out; 
    } 
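    // Run the fused IPA kernel, then apply the final output projection.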
    kutacc_af2_invariant_point(ipa_s_ptrs, ipa_o_ptrs, z_tw.get_tensor(), rigid_rot_mats_tw.get_tensor(), rigid_trans_tw.get_tensor(), mask_tw.get_tensor(), ipa_weight_ptr); 
    out = linear(collect, linear_out_w, linear_out_b); 
    delete ipa_weight_ptr; 
    delete ipa_s_ptrs; 
    delete ipa_o_ptrs; 
    return out; 
} 
}

//test.py
import torch
import kpex._C as kernel
def test_invariant_point_attention(n_res, no_heads, c_hidden, no_qk_points, no_v_points, c_z, c_s): 
    out = kernel.alphafold.test_invariant_point_attention(n_res, no_heads, c_hidden, no_qk_points, no_v_points, c_z, c_s) 
    return out

//input : 4, 2, 2, 3, 8, 2, 8
//output:
>>> kpex.tpp.alphafold.alphafold.test_invariant_point_attention(4, 2, 2, 3, 8, 2, 8)
tensor([[3296., 3296., 3296., 3296., 3296., 3296., 3296., 3296.],
        [3296., 3296., 3296., 3296., 3296., 3296., 3296., 3296.],
        [3296., 3296., 3296., 3296., 3296., 3296., 3296., 3296.],
        [3296., 3296., 3296., 3296., 3296., 3296., 3296., 3296.]],
       dtype=torch.bfloat16)