kutacc_af2_gating_attention

The gating_attention function applies a learnable gate to the output of standard multi-head self-attention, dynamically filtering and controlling the information flow. This improves the model's sensitivity to key residues or sequence positions and strengthens its expressive power.
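
For reference, the following is a minimal ATen sketch of the assumed computation (AlphaFold-style gated attention). It illustrates the semantics only; the tensor shapes follow Table 4 below, while dtype handling, blocking, and weight prepacking used by the kutacc implementation are omitted.

// Minimal reference sketch (assumed semantics, not the kutacc implementation).
// Shapes follow Table 4; blocking and weight prepacking are omitted.
#include <ATen/ATen.h>
#include <cmath>

at::Tensor gating_attention_ref(const at::Tensor &q_data,   // [batch, seq_len, nchannels]
                                const at::Tensor &m_data,   // [batch, seq_len, nchannels]
                                const at::Tensor &bias,     // [batch, 1, 1, seq_len]
                                const at::Tensor &query_w,  // [nheads, head_size, nchannels]
                                const at::Tensor &key_w,    // [nheads, head_size, nchannels]
                                const at::Tensor &value_w,  // [nheads, head_size, nchannels]
                                const at::Tensor &gating_w, // [nchannels, nheads, head_size]
                                const at::Tensor &gating_b, // [nheads, head_size]
                                const at::Tensor &output_w, // [nchannels, nheads, head_size]
                                const at::Tensor &output_b) // [nchannels]
{
    // Per-head projections of the inputs: [batch, seq_len, nheads, head_size].
    auto q = at::einsum("bsc,hdc->bshd", {q_data, query_w});
    auto k = at::einsum("bsc,hdc->bshd", {m_data, key_w});
    auto v = at::einsum("bsc,hdc->bshd", {m_data, value_w});
    double scale = 1.0 / std::sqrt(static_cast<double>(q.size(-1)));
    // Scaled dot-product attention logits plus the additive bias: [batch, nheads, seq_len, seq_len].
    // A nonbatched_bias of shape [nheads, seq_len, seq_len], when present, would also be added here.
    auto logits = at::einsum("bqhd,bkhd->bhqk", {q * scale, k}) + bias;
    auto weights = at::softmax(logits, /*dim=*/-1);
    auto weighted_avg = at::einsum("bhqk,bkhd->bqhd", {weights, v});
    // Learnable sigmoid gate applied to the attention output, followed by the output projection.
    auto gate = at::sigmoid(at::einsum("bsc,chd->bshd", {q_data, gating_w}) + gating_b);
    return at::einsum("bshd,chd->bsc", {weighted_avg * gate, output_w}) + output_b;
}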

Interface Definition

void kutacc_af2_gating_attention(kutacc_tensor_h input, kutacc_af2_attention_inputs_t *q_based_ptr, kutacc_tensor_h bias, kutacc_tensor_h nonbatched_bias, kutacc_af2_attention_weights_t *weight_ptr, kutacc_tensor_h out, int64_t block_size);

Parameters

Table 1 Input parameter definitions

Parameter | Type | Description | Input/Output
input | kutacc_tensor_h | Input data | Input
q_based_ptr | kutacc_af2_attention_inputs_t * | Pointer of type kutacc_af2_attention_inputs_t to the collection of input tensors; see the kutacc_af2_attention_inputs_t structure table below | Input
bias | kutacc_tensor_h | Attention bias data in the network | Input
nonbatched_bias | kutacc_tensor_h | General (non-batched) bias data in the network | Input
weight_ptr | kutacc_af2_attention_weights_t * | Pointer of type kutacc_af2_attention_weights_t to the collection of weight tensors; see the kutacc_af2_attention_weights_t structure table below | Input
out | kutacc_tensor_h | Output data | Output
block_size | int64_t | Block size used to partition the data for parallel execution | Input

Table 2 kutacc_af2_attention_inputs_t structure definition

Parameter | Type | Description
gate | kutacc_tensor_h | Intermediate-layer gate tensor (input)
k | kutacc_tensor_h | Intermediate-layer key (k) tensor (input)
v | kutacc_tensor_h | Intermediate-layer value (v) tensor (input)
q | kutacc_tensor_h | Intermediate-layer query (q) tensor (input)
avg | kutacc_tensor_h | Weighted-average (avg) tensor
batch | int64_t | Batch size of the input
seq_len | int64_t | Sequence length of the input data

Table 3 kutacc_af2_attention_weights_t structure definition

Parameter | Type | Description
nchannels | int64_t | Total number of channels in the input weights
nheads | int64_t | Number of heads in the input weights
head_size | int64_t | Number of elements per head in the input weights
query_w | kutacc_tensor_h | query_w weight tensor
key_w | kutacc_tensor_h | key_w weight tensor
gating_w | kutacc_tensor_h | gating_w weight tensor
gating_b | kutacc_tensor_h | Gating bias tensor
output_w | kutacc_tensor_h | output_w weight tensor
output_b | kutacc_tensor_h | Bias tensor for the output matrix multiplication
value_w | kutacc_tensor_h | value_w weight tensor
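
Before the call, the two structures are populated with the tensor handles and sizes listed above. The following is a minimal sketch, assuming kutacc_af2_attention_inputs_t and kutacc_af2_attention_weights_t are plain structs whose members carry the names from Tables 2 and 3 (the tensor handles themselves are assumed to have been created beforehand):

// Hedged sketch only: member names follow Tables 2 and 3; the exact struct layout
// and the way the tensor handles are created depend on the kutacc headers.
kutacc_af2_attention_inputs_t inputs;
inputs.q = q;              // intermediate query tensor
inputs.k = k;              // intermediate key tensor
inputs.v = v;              // intermediate value tensor
inputs.gate = gate;        // intermediate gate tensor
inputs.avg = avg;          // weighted-average tensor
inputs.batch = batch;
inputs.seq_len = seq_len;

kutacc_af2_attention_weights_t weights;
weights.nchannels = nheads * head_size;
weights.nheads = nheads;
weights.head_size = head_size;
weights.query_w = query_w;
weights.key_w = key_w;
weights.value_w = value_w;
weights.gating_w = gating_w;
weights.gating_b = gating_b;
weights.output_w = output_w;
weights.output_b = output_b;

kutacc_af2_gating_attention(input, &inputs, bias, nonbatched_bias, &weights, out, block_size);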

The integer parameters must satisfy the following constraints:

nchannels = nheads * head_size

batch, seq_len, nchannels, nheads, head_size, block_size > 0

nchannels < INT64_MAX
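
Callers may want to validate these constraints before invoking the kernel. A hypothetical helper (not part of the kutacc API) could look like this:

// Hypothetical pre-check of the integer constraints above; not a kutacc API.
#include <cstdint>

bool check_gating_attention_sizes(int64_t batch, int64_t seq_len, int64_t nchannels,
                                  int64_t nheads, int64_t head_size, int64_t block_size)
{
    if (batch <= 0 || seq_len <= 0 || nchannels <= 0 ||
        nheads <= 0 || head_size <= 0 || block_size <= 0) {
        return false;                       // every size must be positive
    }
    return nchannels == nheads * head_size  // nchannels = nheads * head_size
           && nchannels < INT64_MAX;        // nchannels must stay below INT64_MAX
}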

When constructing tensors for a kpex interface call, the following shape constraints must be satisfied.

Table 4 Shape constraints for kpex gating_attention inputs

Parameter | Shape
query_w | [nheads, head_size, nchannels]
key_w | [nheads, head_size, nchannels]
value_w | [nheads, head_size, nchannels]
gating_w | [nchannels, nheads, head_size]
gating_b | [nheads, head_size]
output_w | [nchannels, nheads, head_size]
output_b | [nchannels]
q_data | [batch, seq_len, nchannels]
m_data | [batch, seq_len, nchannels]
bias | [batch, 1, 1, seq_len]
nonbatched_bias | [nheads, seq_len, seq_len] or [0]
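
When no non-batched bias is needed, Table 4 allows a zero-sized placeholder instead of a [nheads, seq_len, seq_len] tensor. A minimal ATen sketch (dtype chosen only to match the example below):

// Hedged sketch: empty [0] placeholder for nonbatched_bias, per the "or [0]" case in Table 4.
at::Tensor nonbatched_bias = at::zeros({0}, at::TensorOptions().dtype(at::kBFloat16));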

Example

C++ interface:
//test_gating_attention.h
#include <ATen/core/Tensor.h>
#include <ATen/ops/empty.h>
#include <ATen/ops/ones.h>
#include <ATen/ops/zeros.h>
#include <ATen/ops/full.h>
#include <ATen/native/cpu/utils.h>
#include <c10/core/ScalarType.h>
#include "common_header.h"

namespace alphafold {
at::Tensor test_gating_attention(int64_t batch, int64_t seq_len, int64_t nchannels, int64_t nheads, int64_t head_size, int64_t block_size);
}

//bind.h
#include "test_gating_attention.h"
#include <torch/extension.h>

namespace alphafold {
inline void bind(pybind11::module &m)
{
    auto submodule = m.def_submodule("alphafold");
    submodule.("test_gating_attention", &test_gating_attention,py::("batch"), py::("seq_len"), py::("nchannels"), py::("nheads"), py::("head_size"), py::("block_size");
}


//test_gating_attention.cpp
#include "kutacc.h"
#include <utils/memory.h>
#include <utils/bf16.h>
#include <utils/TensorWrapper.h>
#include "test_gating_attention.h"
namespace alphafold {
at::Tensor test_gating_attention(int64_t batch, int64_t seq_len, int64_t nchannels, int64_t nheads, int64_t head_size, int64_t block_size) {
    float a = 3.0f;
    float b = 2.5f;
    at::Tensor q_data = at::ones({batch, seq_len, nchannels}, at::TensorOptions().device(kpex::device()).dtype(at::kBFloat16));
    auto q = q_data.new_empty({batch, seq_len, nheads, head_size});
    auto k = q_data.new_empty({batch, seq_len, nheads, head_size});
    auto v = q_data.new_empty({nheads, head_size, batch, seq_len});
    auto gate = q_data.new_empty({batch, seq_len, nheads, head_size});
    auto weighted_avg = q_data.new_empty({batch, seq_len, nheads, head_size});
    at::Tensor input = at::ones({batch * seq_len, nchannels}, q_data.options());
    at::Tensor bias =  at::ones({batch, 1, 1, seq_len}, q_data.options());
    at::Tensor nonbatched_bias = at::zeros({nheads, seq_len, seq_len}, q_data.options());
    at::Tensor query_w = at::full({nheads, head_size, nchannels}, a, q_data.options()).to(at::kBFloat16);
    auto bf16_opt = query_w.options().device(kpex::device()).dtype(at::kBFloat16);
    auto float_opt = query_w.options().device(kpex::device()).dtype(at::kFloat);
    at::Tensor key_w = at::full({nheads, head_size, nchannels}, b, q_data.options()).to(bf16_opt);
    at::Tensor value_w = at::full({nheads, head_size, nchannels}, a, q_data.options()).to(bf16_opt);
    at::Tensor gating_w = at::full({nchannels, nheads, head_size}, a, q_data.options()).to(bf16_opt);
    at::Tensor gating_b = at::full({nheads, head_size}, b, q_data.options()).to(float_opt);
    at::Tensor output_w = at::full({nchannels, nheads, head_size}, b, q_data.options()).to(bf16_opt);
    at::Tensor output_b = at::full({nchannels}, a, q_data.options()).to(float_opt);
    query_w = query_w.to(bf16_opt).contiguous().view({nchannels, nchannels});
    key_w = key_w.to(bf16_opt).contiguous().view({nchannels, nchannels});
    value_w = value_w.to(bf16_opt).contiguous().view({nchannels, nchannels});
    gating_w = gating_w.to(bf16_opt).contiguous().view({nchannels, nchannels});
    output_w = output_w.to(bf16_opt).contiguous().view({nchannels, nchannels});
    auto query_w_res = linear_weight_prepack(query_w);
    auto key_w_res = linear_weight_prepack(key_w);
    auto value_w_res = linear_weight_prepack(value_w);
    auto gating_w_res = linear_weight_prepack(gating_w);
    auto output_w_res = linear_weight_prepack(output_w);
    at::Tensor out = at::empty(q_data.sizes(), q_data.options());
    auto input_tw = convert_to_tensor_wrapper(input);
    auto q_tw = convert_to_tensor_wrapper(q);
    auto k_tw = convert_to_tensor_wrapper(k);
    auto v_tw = convert_to_tensor_wrapper(v);
    auto gate_tw = convert_to_tensor_wrapper(gate);
    auto weighted_avg_tw = convert_to_tensor_wrapper(weighted_avg);
    auto bias_tw = convert_to_tensor_wrapper(bias);
    auto nonbatched_bias_tw = convert_to_tensor_wrapper(nonbatched_bias);
    auto query_w_tw = convert_to_tensor_wrapper(query_w_res);
    auto key_w_tw = convert_to_tensor_wrapper(key_w_res);
    auto value_w_tw = convert_to_tensor_wrapper(value_w_res);
    auto gating_w_tw = convert_to_tensor_wrapper(gating_w_res);
    auto gating_b_tw = convert_to_tensor_wrapper(gating_b);
    auto output_w_tw = convert_to_tensor_wrapper(output_w_res);
    auto output_b_tw = convert_to_tensor_wrapper(output_b);
    auto out_tw = convert_to_tensor_wrapper(out);
    kutacc_af2_attention_weights_t_wrapper *gating_attention_weight_ptr = new kutacc_af2_attention_weights_t_wrapper(query_w_tw, key_w_tw, value_w_tw, gating_w_tw, gating_b_tw,
        output_w_tw, output_b_tw, nchannels, nheads, head_size);
    kutacc_af2_attention_inputs_t_wrapper *gating_attention_q_ptr = new kutacc_af2_attention_inputs_t_wrapper(q_tw, k_tw, v_tw, gate_tw, weighted_avg_tw, batch, seq_len);
    kutacc_af2_gating_attention(input_tw.get_tensor(), gating_attention_q_ptr, bias_tw.get_tensor(), nonbatched_bias_tw.get_tensor(), gating_attention_weight_ptr, out_tw.get_tensor(), block_size);
    delete gating_attention_weight_ptr;
    delete gating_attention_q_ptr;
    return out;
}
}

//test.py
import torch
import kpex._C as kernel
def test_gating_attention(batch, seq_len, nchannels, nheads, head_size, block_size):
    out = kernel.alphafold.test_gating_attention(batch, seq_len, nchannels, nheads, head_size, block_size)
    print(out)
    return out
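
# Hypothetical entry point (not part of the original test.py) reproducing the sample
# run documented below: batch=2, seq_len=2, nchannels=8, nheads=4, head_size=2, block_size=2.
if __name__ == "__main__":
    test_gating_attention(2, 2, 8, 4, 2, 2)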

/**
 * input: batch=2, seq_len=2, nchannels=8, nheads=4, head_size=2, block_size=2
 * output:
tensor([[[484., 484., 484., 484., 484., 484., 484., 484.],
         [484., 484., 484., 484., 484., 484., 484., 484.]],

        [[484., 484., 484., 484., 484., 484., 484., 484.],
         [484., 484., 484., 484., 484., 484., 484., 484.]]],
       dtype=torch.bfloat16)
 *
 */