kutacc_af2_global_attention
In AF2, global_attention is a global attention mechanism. By aggregating information about the target residue along the column direction of the multiple sequence alignment (MSA), it passes the three-dimensional structural information from the first sequence to the other sequences. This helps the model capture the role the residue plays in different evolutionary contexts as well as the co-evolutionary relationships between residues, providing richer information for protein structure prediction.
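As a rough sketch of the underlying computation (this assumes the operator follows the reference GlobalAttention module used for AlphaFold2's MSA column global attention; the authoritative behaviour is defined by the kutacc implementation), for each batch element b:

$$
\bar{q}_{b} = \frac{\sum_{s} \text{q\_mask}_{b,s}\,\text{q\_data}_{b,s,:}}{\sum_{s} \text{q\_mask}_{b,s} + \varepsilon},
\qquad
q_{b,h,:} = \frac{1}{\sqrt{\text{head\_size}}}\,\bar{q}_{b}\,W^{Q}_{h}
$$

$$
k_{b,s,:} = \text{q\_data}_{b,s,:}\,W^{K},
\qquad
v_{b,s,:} = \text{q\_data}_{b,s,:}\,W^{V},
\qquad
a_{b,h,s} = \operatorname{softmax}_{s}\!\left(q_{b,h,:}\cdot k_{b,s,:} + \text{bias}(\text{q\_mask}_{b,s})\right)
$$

$$
o_{b,h,:} = \sum_{s} a_{b,h,s}\,v_{b,s,:},
\qquad
g_{b,s,h,:} = \sigma\!\left(\text{q\_data}_{b,s,:}\,W^{G}_{h} + b^{G}_{h}\right),
\qquad
\text{out}_{b,s,:} = \sum_{h}\left(g_{b,s,h,:}\odot o_{b,h,:}\right)W^{O}_{h} + b^{O}
$$

Here $W^{Q}$, $W^{K}$, $W^{V}$, $W^{G}$, $b^{G}$, $W^{O}$, $b^{O}$ correspond to query_w, key_w, value_w, gating_w, gating_b, output_w, output_b in Table 3. Averaging the query over the sequence (column) dimension and sharing a single key/value head across heads is what makes the attention "global" and keeps its cost linear in seq_len.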
Interface Definition
void kutacc_af2_global_attention(kutacc_af2_attention_inputs_t *q_based_ptr, kutacc_tensor_h q_data, kutacc_tensor_h q_mask, kutacc_af2_attention_weights_t *weight_ptr, kutacc_tensor_h out);
Parameters
Table 1 Input parameter definitions

| Parameter | Type | Description | Input/Output |
|---|---|---|---|
| q_based_ptr | kutacc_af2_attention_inputs_t * | Pointer of type kutacc_af2_attention_inputs_t; see Table 2 below for the structure definition | Input |
| q_data | kutacc_tensor_h | Q matrix data | Input |
| q_mask | kutacc_tensor_h | Q mask matrix | Input |
| weight_ptr | kutacc_af2_attention_weights_t * | Pointer of type kutacc_af2_attention_weights_t; see Table 3 below for the structure definition | Input |
| out | kutacc_tensor_h | Output data | Output |
Table 2 kutacc_af2_attention_inputs_t structure definition

| Parameter | Type | Description |
|---|---|---|
| gate | kutacc_tensor_h | Intermediate-layer gating tensor data |
| k | kutacc_tensor_h | Intermediate-layer k data |
| v | kutacc_tensor_h | Intermediate-layer v data |
| q | kutacc_tensor_h | Intermediate-layer q data |
| avg | kutacc_tensor_h | avg weight data |
| batch | int64_t | Input batch size |
| seq_len | int64_t | Sequence length of the input data |
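For orientation, a plausible declaration consistent with Table 2 is sketched below. The member names and types follow the table and the order mirrors the wrapper constructor used in the example at the end of this section; the layout is illustrative only, and the authoritative definition is the one shipped with kutacc.h.

/* Illustrative sketch only: members follow Table 2; see kutacc.h for the authoritative definition. */
typedef struct {
    kutacc_tensor_h q;        /* intermediate q tensor */
    kutacc_tensor_h k;        /* intermediate k tensor */
    kutacc_tensor_h v;        /* intermediate v tensor */
    kutacc_tensor_h gate;     /* intermediate gating tensor */
    kutacc_tensor_h avg;      /* avg buffer */
    int64_t batch;            /* input batch size */
    int64_t seq_len;          /* input sequence length */
} kutacc_af2_attention_inputs_t;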
Table 3 kutacc_af2_attention_weights_t structure definition

| Parameter | Type | Description |
|---|---|---|
| nchannels | int64_t | Total number of input weight channels |
| nheads | int64_t | Number of heads in the input weights |
| head_size | int64_t | Number of elements per head in the input weights |
| query_w | kutacc_tensor_h | query_w weight data |
| key_w | kutacc_tensor_h | key_w weight data |
| gating_w | kutacc_tensor_h | gating_w weight data |
| gating_b | kutacc_tensor_h | Gating bias data |
| output_w | kutacc_tensor_h | output_w weight data |
| output_b | kutacc_tensor_h | Output matrix-multiplication bias data |
| value_w | kutacc_tensor_h | value_w weight data |
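Similarly, a sketch consistent with Table 3 is shown below. Member names and types follow the table and the order mirrors the wrapper constructor used in the example at the end of this section; the layout is illustrative only, not the authoritative kutacc.h definition.

/* Illustrative sketch only: members follow Table 3; see kutacc.h for the authoritative definition. */
typedef struct {
    kutacc_tensor_h query_w;  /* query projection weights */
    kutacc_tensor_h key_w;    /* key projection weights */
    kutacc_tensor_h value_w;  /* value projection weights */
    kutacc_tensor_h gating_w; /* gating projection weights */
    kutacc_tensor_h gating_b; /* gating bias */
    kutacc_tensor_h output_w; /* output projection weights */
    kutacc_tensor_h output_b; /* output projection bias */
    int64_t nchannels;        /* total channels (= nheads * head_size) */
    int64_t nheads;           /* number of attention heads */
    int64_t head_size;        /* elements per head */
} kutacc_af2_attention_weights_t;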
The integer parameters of global_attention must satisfy the following constraints:
nchannels = nheads * head_size
batch, seq_len, nchannels, nheads, head_size > 0
In addition, nchannels must equal 64.
When constructing test cases and when using KPEX, the operator shape constraints in Table 4 must also be satisfied (a minimal validation sketch follows the table).
Table 4 global attention input shape constraints

| Tensor | Shape |
|---|---|
| query_w | [nheads, head_size, nchannels] |
| key_w | [head_size, nchannels] |
| value_w | [head_size, nchannels] |
| gating_w | [nheads, head_size, nchannels] |
| gating_b | [nheads, head_size] |
| output_w | [nchannels, nheads, head_size] |
| output_b | [nchannels] |
| q_data | [batch, seq_len, nchannels] |
| m_data | [batch, seq_len, nchannels] |
| q_mask | [batch, seq_len, 1] |
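The scalar constraints above can be validated before building a test case. The helper below is a minimal sketch based only on the constraints documented in this section; check_global_attention_args is a hypothetical name and is not part of the kutacc or KPEX API.

#include <cstdint>

// Hypothetical helper (not part of kutacc/KPEX): returns true when the scalar
// constraints documented above hold, i.e. all sizes are positive,
// nchannels == nheads * head_size, and nchannels is fixed to 64.
static bool check_global_attention_args(int64_t batch, int64_t seq_len,
                                        int64_t nchannels, int64_t nheads, int64_t head_size)
{
    if (batch <= 0 || seq_len <= 0 || nchannels <= 0 || nheads <= 0 || head_size <= 0) {
        return false;
    }
    return nchannels == nheads * head_size && nchannels == 64;
}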
Example
C++ interface:
//test_global_attention.h
#ifndef KPEX_TPP_ALPHAFOLD_TEST_GLOBAL_ATTENTION_H
#define KPEX_TPP_ALPHAFOLD_TEST_GLOBAL_ATTENTION_H
#include <ATen/core/Tensor.h>
#include <ATen/ops/empty.h>
#include <ATen/ops/ones.h>
#include <ATen/ops/zeros.h>
#include <ATen/ops/full.h>
#include <ATen/native/cpu/utils.h>
#include <c10/core/ScalarType.h>
#include "common_header.h"
namespace alphafold {
at::Tensor test_global_attention(int64_t batch, int64_t seq_len, int64_t nchannels, int64_t nheads, int64_t head_size);
}
#endif
//bind.h
#include <torch/extension.h>
#include "test_global_attention.h"
namespace alphafold {
inline void bind(pybind11::module &m)
{
    auto submodule = m.def_submodule("alphafold");
    submodule.def("test_global_attention", &test_global_attention, py::arg("batch"), py::arg("seq_len"), py::arg("nchannels"), py::arg("nheads"), py::arg("head_size"));
}
} // namespace alphafold
//test_global_attention.cpp
#include "kutacc.h"
#include <utils/memory.h>
#include <utils/bf16.h>
#include <utils/TensorWrapper.h>
#include "test_global_attention.h"
namespace alphafold {
at::Tensor test_global_attention(int64_t batch, int64_t seq_len, int64_t nchannels, int64_t nheads, int64_t head_size) {
    // Constant fill values used to build deterministic test weights.
    float a = 0.75f;
    float b = 0.25f;
    float c = 0.5f;
    // Inputs and output buffer (see Table 4 for the expected shapes).
    at::Tensor q_data = at::ones({batch, seq_len, nchannels}, at::TensorOptions().device(kpex::device()).dtype(at::kBFloat16));
    at::Tensor out = at::empty(q_data.sizes(), q_data.options());
    at::Tensor q_mask = at::ones({batch, seq_len, 1}, q_data.options());
    q_mask = q_mask.contiguous();
    // Intermediate buffers referenced through kutacc_af2_attention_inputs_t.
    at::Tensor q_avg = q_data.new_empty({batch, nchannels});
    at::Tensor q = q_data.new_empty({batch, nheads, head_size});
    at::Tensor k = q_data.new_empty({batch, seq_len, head_size});
    at::Tensor v = q_data.new_empty({head_size, batch, seq_len});
    at::Tensor gate = q_data.new_empty({batch, seq_len, nheads, head_size});
    // Weight tensors filled with constants, shapes as listed in Table 4.
    at::Tensor query_w = at::full({nheads, head_size, nchannels}, a, q_data.options()).to(at::kBFloat16);
    auto bf16_opt = query_w.options().device(kpex::device()).dtype(at::kBFloat16);
    auto float_opt = query_w.options().device(kpex::device()).dtype(at::kFloat);
    at::Tensor key_w = at::full({head_size, nchannels}, b, at::TensorOptions().device(kpex::device()).dtype(at::kBFloat16));
    at::Tensor value_w = at::full({head_size, nchannels}, a, at::TensorOptions().device(kpex::device()).dtype(at::kBFloat16));
    at::Tensor gating_w = at::full({nheads, head_size, nchannels}, a, at::TensorOptions().device(kpex::device()).dtype(at::kBFloat16));
    at::Tensor gating_b = at::full({nheads, head_size}, c, at::TensorOptions().device(kpex::device()).dtype(at::kFloat));
    at::Tensor output_w = at::full({nchannels, nheads, head_size}, b, at::TensorOptions().device(kpex::device()).dtype(at::kBFloat16));
    at::Tensor output_b = at::full({nchannels}, c, at::TensorOptions().device(kpex::device()).dtype(at::kFloat));
    // Flatten the 3-D weights to 2-D matrices before prepacking
    // (valid because nchannels = nheads * head_size); key_w/value_w are already 2-D.
    query_w = query_w.to(bf16_opt).contiguous().view({nchannels, nchannels});
    key_w = key_w.to(bf16_opt).contiguous().view({head_size, nchannels});
    value_w = value_w.to(bf16_opt).contiguous().view({head_size, nchannels});
    gating_w = gating_w.to(bf16_opt).contiguous().view({nchannels, nchannels});
    output_w = output_w.to(bf16_opt).contiguous().view({nchannels, nchannels});
    auto query_w_res = linear_weight_prepack(query_w);
    auto key_w_res = linear_weight_prepack(key_w);
    auto value_w_res = linear_weight_prepack(value_w);
    auto gating_w_res = linear_weight_prepack(gating_w);
    auto output_w_res = linear_weight_prepack(output_w);
    // Wrap the ATen tensors as kutacc tensor handles.
    kutacc::TensorWrapper q_avg_tw = convert_to_tensor_wrapper(q_avg);
    kutacc::TensorWrapper q_tw = convert_to_tensor_wrapper(q);
    kutacc::TensorWrapper k_tw = convert_to_tensor_wrapper(k);
    kutacc::TensorWrapper v_tw = convert_to_tensor_wrapper(v);
    kutacc::TensorWrapper gate_tw = convert_to_tensor_wrapper(gate);
    kutacc::TensorWrapper q_data_tw = convert_to_tensor_wrapper(q_data);
    kutacc::TensorWrapper q_mask_tw = convert_to_tensor_wrapper(q_mask);
    kutacc::TensorWrapper out_tw = convert_to_tensor_wrapper(out);
    kutacc::TensorWrapper query_w_tw = convert_to_tensor_wrapper(query_w_res);
    kutacc::TensorWrapper key_w_tw = convert_to_tensor_wrapper(key_w_res);
    kutacc::TensorWrapper value_w_tw = convert_to_tensor_wrapper(value_w_res);
    kutacc::TensorWrapper gating_w_tw = convert_to_tensor_wrapper(gating_w_res);
    kutacc::TensorWrapper gating_b_tw = convert_to_tensor_wrapper(gating_b);
    kutacc::TensorWrapper output_w_tw = convert_to_tensor_wrapper(output_w_res);
    kutacc::TensorWrapper output_b_tw = convert_to_tensor_wrapper(output_b);
    // Assemble the weight and input descriptors, run the operator, then release the descriptors.
    kutacc_af2_attention_weights_t_wrapper *global_attention_weight_ptr = new kutacc_af2_attention_weights_t_wrapper(query_w_tw, key_w_tw, value_w_tw, gating_w_tw, gating_b_tw,
        output_w_tw, output_b_tw, nchannels, nheads, head_size);
    kutacc_af2_attention_inputs_t_wrapper *global_attention_q_ptr = new kutacc_af2_attention_inputs_t_wrapper(q_tw, k_tw, v_tw, gate_tw, q_avg_tw, batch, seq_len);
    kutacc_af2_global_attention(global_attention_q_ptr, q_data_tw.get_tensor(), q_mask_tw.get_tensor(), global_attention_weight_ptr, out_tw.get_tensor());
    delete global_attention_weight_ptr;
    delete global_attention_q_ptr;
    return out;
}
} // namespace alphafold
# test.py
import torch
import kpex._C as kernel

def test_global_attention(batch, seq_len, nchannels, nheads, head_size):
    out = kernel.alphafold.test_global_attention(batch, seq_len, nchannels, nheads, head_size)
    return out
Output (shape [2, 2, 64], i.e. batch = 2, seq_len = 2, nchannels = 64):
tensor([[[768., 768., 768., 768., 768., 768., 768., 768., 768., 768., 768.,
768., 768., 768., 768., 768., 768., 768., 768., 768., 768., 768.,
768., 768., 768., 768., 768., 768., 768., 768., 768., 768., 768.,
768., 768., 768., 768., 768., 768., 768., 768., 768., 768., 768.,
768., 768., 768., 768., 768., 768., 768., 768., 768., 768., 768.,
768., 768., 768., 768., 768., 768., 768., 768., 768.],
[768., 768., 768., 768., 768., 768., 768., 768., 768., 768., 768.,
768., 768., 768., 768., 768., 768., 768., 768., 768., 768., 768.,
768., 768., 768., 768., 768., 768., 768., 768., 768., 768., 768.,
768., 768., 768., 768., 768., 768., 768., 768., 768., 768., 768.,
768., 768., 768., 768., 768., 768., 768., 768., 768., 768., 768.,
768., 768., 768., 768., 768., 768., 768., 768., 768.]],
[[768., 768., 768., 768., 768., 768., 768., 768., 768., 768., 768.,
768., 768., 768., 768., 768., 768., 768., 768., 768., 768., 768.,
768., 768., 768., 768., 768., 768., 768., 768., 768., 768., 768.,
768., 768., 768., 768., 768., 768., 768., 768., 768., 768., 768.,
768., 768., 768., 768., 768., 768., 768., 768., 768., 768., 768.,
768., 768., 768., 768., 768., 768., 768., 768., 768.],
[768., 768., 768., 768., 768., 768., 768., 768., 768., 768., 768.,
768., 768., 768., 768., 768., 768., 768., 768., 768., 768., 768.,
768., 768., 768., 768., 768., 768., 768., 768., 768., 768., 768.,
768., 768., 768., 768., 768., 768., 768., 768., 768., 768., 768.,
768., 768., 768., 768., 768., 768., 768., 768., 768., 768., 768.,
768., 768., 768., 768., 768., 768., 768., 768., 768.]]],
dtype=torch.bfloat16)