kutacc_af2_gating_attention
The gating_attention function applies a learnable gate to the output of standard multi-head self-attention, dynamically filtering and controlling the information flow. This improves the model's sensitivity to key residues or sequence positions and strengthens its expressive power.
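For reference, the sketch below spells out the gated multi-head attention that this kernel fuses, written with plain ATen ops and the tensor shapes listed in Table 4 below. The helper name gating_attention_reference is illustrative; the scaling factor, bias handling, and einsum conventions follow the public AlphaFold2 attention module and are assumptions about the semantics, not a description of the kutacc implementation.
// Rough ATen reference of gated multi-head attention (assumption: mirrors the
// AlphaFold2 Attention module; not the kutacc kernel itself).
#include <ATen/ATen.h>
#include <cmath>
at::Tensor gating_attention_reference(
    const at::Tensor &q_data,          // [batch, seq_len, nchannels]
    const at::Tensor &m_data,          // [batch, seq_len, nchannels]
    const at::Tensor &bias,            // [batch, 1, 1, seq_len]
    const at::Tensor &nonbatched_bias, // [nheads, seq_len, seq_len] or empty
    const at::Tensor &query_w,         // [nheads, head_size, nchannels]
    const at::Tensor &key_w,           // [nheads, head_size, nchannels]
    const at::Tensor &value_w,         // [nheads, head_size, nchannels]
    const at::Tensor &gating_w,        // [nchannels, nheads, head_size]
    const at::Tensor &gating_b,        // [nheads, head_size]
    const at::Tensor &output_w,        // [nchannels, nheads, head_size]
    const at::Tensor &output_b,        // [nchannels]
    int64_t head_size)
{
    // Per-head q/k/v projections; q is scaled by 1/sqrt(head_size).
    auto q = at::einsum("bqc,hdc->bqhd", {q_data, query_w}) / std::sqrt(static_cast<double>(head_size));
    auto k = at::einsum("bkc,hdc->bkhd", {m_data, key_w});
    auto v = at::einsum("bkc,hdc->bkhd", {m_data, value_w});
    // Attention logits plus the per-batch bias and the optional non-batched bias.
    auto logits = at::einsum("bqhd,bkhd->bhqk", {q, k}) + bias;
    if (nonbatched_bias.numel() > 0) {
        logits = logits + nonbatched_bias.unsqueeze(0);
    }
    auto weights = at::softmax(logits, -1);
    auto weighted_avg = at::einsum("bhqk,bkhd->bqhd", {weights, v});
    // A sigmoid gate computed from q_data modulates the attention output per head.
    auto gate = at::sigmoid(at::einsum("bqc,chd->bqhd", {q_data, gating_w}) + gating_b);
    weighted_avg = weighted_avg * gate;
    // Project back to nchannels.
    return at::einsum("bqhd,chd->bqc", {weighted_avg, output_w}) + output_b;
}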
Interface Definition
void kutacc_af2_gating_attention(kutacc_tensor_h input, kutacc_af2_attention_inputs_t *q_based_ptr, kutacc_tensor_h bias, kutacc_tensor_h nonbatched_bias, kutacc_af2_attention_weights_t *weight_ptr, kutacc_tensor_h out, int64_t block_size);
Parameters
Table 1 Input parameter definitions
| Parameter | Type | Description | Input/Output |
|---|---|---|---|
| input | kutacc_tensor_h | Input data | Input |
| q_based_ptr | kutacc_af2_attention_inputs_t * | Pointer to a kutacc_af2_attention_inputs_t structure that bundles the multiple input tensors; see the kutacc_af2_attention_inputs_t structure table below | Input |
| bias | kutacc_tensor_h | Attention bias data of the network | Input |
| nonbatched_bias | kutacc_tensor_h | Non-batched (shared) bias data of the network | Input |
| weight_ptr | kutacc_af2_attention_weights_t * | Pointer to a kutacc_af2_attention_weights_t structure that bundles the multiple weight tensors; see the kutacc_af2_attention_weights_t structure table below | Input |
| out | kutacc_tensor_h | Output data | Output |
| block_size | int64_t | Block size used to split the data for parallel execution | Input |
Table 2 kutacc_af2_attention_inputs_t structure definition
| Field | Type | Description |
|---|---|---|
| gate | kutacc_tensor_h | Intermediate gate data |
| k | kutacc_tensor_h | Intermediate key (k) data |
| v | kutacc_tensor_h | Intermediate value (v) data |
| q | kutacc_tensor_h | Intermediate query (q) data |
| avg | kutacc_tensor_h | Weighted-average (avg) data |
| batch | int64_t | Input batch size |
| seq_len | int64_t | Sequence length of the input data |
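For orientation only, a sketch of this descriptor as a plain struct. The field names and types are taken from Table 2; the field order is an assumption, and the authoritative definition lives in the kutacc headers.
// Sketch only: field set from Table 2; field order is a guess, do not redefine in client code.
typedef struct {
    kutacc_tensor_h gate;   // intermediate gate data
    kutacc_tensor_h k;      // intermediate key data
    kutacc_tensor_h v;      // intermediate value data
    kutacc_tensor_h q;      // intermediate query data
    kutacc_tensor_h avg;    // weighted-average data
    int64_t batch;          // input batch size
    int64_t seq_len;        // input sequence length
} kutacc_af2_attention_inputs_t;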
Table 3 kutacc_af2_attention_weights_t structure
| Field | Type | Description |
|---|---|---|
| nchannels | int64_t | Total number of channels in the weights |
| nheads | int64_t | Number of attention heads in the weights |
| head_size | int64_t | Number of channels per head |
| query_w | kutacc_tensor_h | query_w weight data |
| key_w | kutacc_tensor_h | key_w weight data |
| gating_w | kutacc_tensor_h | gating_w weight data |
| gating_b | kutacc_tensor_h | Gating bias data |
| output_w | kutacc_tensor_h | output_w weight data |
| output_b | kutacc_tensor_h | Bias data for the output matrix multiplication |
| value_w | kutacc_tensor_h | value_w weight data |
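Likewise, a hedged sketch of the weights descriptor; the field names and types follow Table 3, while the exact layout is defined in the kutacc headers.
// Sketch only: field set from Table 3; field order is a guess, do not redefine in client code.
typedef struct {
    int64_t nchannels;        // total number of channels (nheads * head_size)
    int64_t nheads;           // number of attention heads
    int64_t head_size;        // channels per head
    kutacc_tensor_h query_w;  // query projection weight
    kutacc_tensor_h key_w;    // key projection weight
    kutacc_tensor_h gating_w; // gating projection weight
    kutacc_tensor_h gating_b; // gating bias
    kutacc_tensor_h output_w; // output projection weight
    kutacc_tensor_h output_b; // output projection bias
    kutacc_tensor_h value_w;  // value projection weight
} kutacc_af2_attention_weights_t;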
The integer parameters must satisfy the following constraints:
nchannels = nheads * head_size;
batch, seq_len, nchannels, nheads, head_size, block_size > 0;
nchannels < INT64_MAX
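A hypothetical helper (not part of the kutacc API) that validates these constraints before calling the kernel could look as follows:
#include <cstdint>
#include <stdexcept>
// Hypothetical validation helper; kutacc itself may report violations differently.
inline void check_gating_attention_params(int64_t batch, int64_t seq_len, int64_t nchannels,
                                          int64_t nheads, int64_t head_size, int64_t block_size)
{
    if (batch <= 0 || seq_len <= 0 || nchannels <= 0 ||
        nheads <= 0 || head_size <= 0 || block_size <= 0) {
        throw std::invalid_argument("all integer parameters must be positive");
    }
    if (nchannels != nheads * head_size) {
        throw std::invalid_argument("nchannels must equal nheads * head_size");
    }
}
// For example, the test at the end of this section uses nheads = 4, head_size = 2, nchannels = 8.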
When constructing tensors for a kpex interface call, the following shape constraints must be satisfied.
Table 4 Shape constraints for kpex gating_attention input parameters
| Parameter | Shape |
|---|---|
| query_w | [nheads, head_size, nchannels] |
| key_w | [nheads, head_size, nchannels] |
| value_w | [nheads, head_size, nchannels] |
| gating_w | [nchannels, nheads, head_size] |
| gating_b | [nheads, head_size] |
| output_w | [nchannels, nheads, head_size] |
| output_b | [nchannels] |
| q_data | [batch, seq_len, nchannels] |
| m_data | [batch, seq_len, nchannels] |
| bias | [batch, 1, 1, seq_len] |
| nonbatched_bias | [nheads, seq_len, seq_len] or [0] |
Example
//test_gating_attention.h
#include <ATen/core/Tensor.h>
#include <ATen/ops/empty.h>
#include <ATen/ops/ones.h>
#include <ATen/ops/zeros.h>
#include <ATen/ops/full.h>
#include <ATen/native/cpu/utils.h>
#include <c10/core/ScalarType.h>
#include "common_header.h"
namespace alphafold {
at::Tensor test_gating_attention(int64_t batch, int64_t seq_len, int64_t nchannels, int64_t nheads, int64_t head_size, int64_t block_size);
}
//bind.h
#include "test_gating_attention.h"
#include <torch/extension.h>
namespace alphafold {
inline void bind(pybind11::module &m)
{
auto submodule = m.def_submodule("alphafold");
submodule.def("test_gating_attention", &test_gating_attention, py::arg("batch"), py::arg("seq_len"), py::arg("nchannels"), py::arg("nheads"), py::arg("head_size"), py::arg("block_size"));
}
//test_gating_attention.cpp
#include "kutacc.h"
#include <utils/memory.h>
#include <utils/bf16.h>
#include <utils/TensorWrapper.h>
#include "test_gating_attention.h"
namespace alphafold {
at::Tensor test_gating_attention(int64_t batch, int64_t seq_len, int64_t nchannels, int64_t nheads, int64_t head_size, int64_t block_size) {
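// Constant fill values for the weight and bias tensors.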
float a = 3.0f;
float b = 2.5f;
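// Input activations, bias tensors, and per-head intermediate buffers in bf16 on the kpex device.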
at::Tensor q_data = at::ones({batch, seq_len, nchannels}, at::TensorOptions().device(kpex::device()).dtype(at::kBFloat16));
auto q = q_data.new_empty({batch, seq_len, nheads, head_size});
auto k = q_data.new_empty({batch, seq_len, nheads, head_size});
auto v = q_data.new_empty({nheads, head_size, batch, seq_len});
auto gate = q_data.new_empty({batch, seq_len, nheads, head_size});
auto weighted_avg = q_data.new_empty({batch, seq_len, nheads, head_size});
at::Tensor input = at::ones({batch * seq_len, nchannels}, q_data.options());
at::Tensor bias = at::ones({batch, 1, 1, seq_len}, q_data.options());
at::Tensor nonbatched_bias = at::zeros({nheads, seq_len, seq_len}, q_data.options());
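// Weight and bias tensors: projection weights in bf16, biases in float32.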
at::Tensor query_w = at::full({nheads, head_size, nchannels}, a, q_data.options()).to(at::kBFloat16);
auto bf16_opt = query_w.options().device(kpex::device()).dtype(at::kBFloat16);
auto float_opt = query_w.options().device(kpex::device()).dtype(at::kFloat);
at::Tensor key_w = at::full({nheads, head_size, nchannels}, b, q_data.options()).to(bf16_opt);
at::Tensor value_w = at::full({nheads, head_size, nchannels}, a, q_data.options()).to(bf16_opt);
at::Tensor gating_w = at::full({nheads, head_size, nchannels}, a, q_data.options()).to(bf16_opt);
at::Tensor gating_b = at::full({nheads, head_size}, b, q_data.options()).to(float_opt);
at::Tensor output_w = at::full({nchannels, nheads, head_size}, b, q_data.options()).to(bf16_opt);
at::Tensor output_b = at::full({nchannels}, a, q_data.options()).to(float_opt);
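// Flatten each projection weight to [nchannels, nchannels] and pre-pack it via linear_weight_prepack.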
query_w = query_w.to(bf16_opt).contiguous().view({nchannels, nchannels});
key_w = key_w.to(bf16_opt).contiguous().view({nchannels, nchannels});
value_w = value_w.to(bf16_opt).contiguous().view({nchannels, nchannels});
gating_w = gating_w.to(bf16_opt).contiguous().view({nchannels, nchannels});
output_w = output_w.to(bf16_opt).contiguous().view({nchannels, nchannels});
auto query_w_res = linear_weight_prepack(query_w);
auto key_w_res = linear_weight_prepack(key_w);
auto value_w_res = linear_weight_prepack(value_w);
auto gating_w_res = linear_weight_prepack(gating_w);
auto output_w_res = linear_weight_prepack(output_w);
at::Tensor out = at::empty(q_data.sizes(), q_data.options());
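// Wrap the ATen tensors so they can be passed through the kutacc C interface.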
auto input_tw = convert_to_tensor_wrapper(input);
auto q_tw = convert_to_tensor_wrapper(q);
auto k_tw = convert_to_tensor_wrapper(k);
auto v_tw = convert_to_tensor_wrapper(v);
auto gate_tw = convert_to_tensor_wrapper(gate);
auto weighted_avg_tw = convert_to_tensor_wrapper(weighted_avg);
auto bias_tw = convert_to_tensor_wrapper(bias);
auto nonbatched_bias_tw = convert_to_tensor_wrapper(nonbatched_bias);
auto query_w_tw = convert_to_tensor_wrapper(query_w_res);
auto key_w_tw = convert_to_tensor_wrapper(key_w_res);
auto value_w_tw = convert_to_tensor_wrapper(value_w_res);
auto gating_w_tw = convert_to_tensor_wrapper(gating_w_res);
auto gating_b_tw = convert_to_tensor_wrapper(gating_b);
auto output_w_tw = convert_to_tensor_wrapper(output_w_res);
auto output_b_tw = convert_to_tensor_wrapper(output_b);
auto out_tw = convert_to_tensor_wrapper(out);
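// Bundle the wrapped weights and inputs into the descriptor structs expected by the kernel.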
kutacc_af2_attention_weights_t_wrapper *gating_attention_weight_ptr = new kutacc_af2_attention_weights_t_wrapper(query_w_tw, key_w_tw, value_w_tw, gating_w_tw, gating_b_tw,
output_w_tw, output_b_tw, nchannels, nheads, head_size);
kutacc_af2_attention_inputs_t_wrapper *gating_attention_q_ptr = new kutacc_af2_attention_inputs_t_wrapper(q_tw, k_tw, v_tw, gate_tw, weighted_avg_tw, batch, seq_len);
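// Invoke the fused gating-attention kernel, then release the descriptors.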
kutacc_af2_gating_attention(input_tw.get_tensor(), gating_attention_q_ptr, bias_tw.get_tensor(), nonbatched_bias_tw.get_tensor(), gating_attention_weight_ptr, out_tw.get_tensor(), block_size);
delete gating_attention_weight_ptr;
delete gating_attention_q_ptr;
return out;
}
}
//test.py
import torch
import kpex._C as kernel
def test_gating_attention(batch, seq_len, nchannels, nheads, head_size, block_size):
    out = kernel.alphafold.test_gating_attention(batch, seq_len, nchannels, nheads, head_size, block_size)
    print(out)
    return out
/**
* input: batch=2, seq_len=2, nchannels=8, nheads=4, head_size=2, block_size=2
* output:
tensor([[[484., 484., 484., 484., 484., 484., 484., 484.],
[484., 484., 484., 484., 484., 484., 484., 484.]],
[[484., 484., 484., 484., 484., 484., 484., 484.],
[484., 484., 484., 484., 484., 484., 484., 484.]]],
dtype=torch.bfloat16)
*
*/