kutacc_af2_global_attention

In AF2, global_attention is a global attention mechanism. By aggregating information about the target residue along the column direction of the multiple sequence alignment, it passes three-dimensional structural information from the first sequence to the other sequences. This gives the model a better understanding of the role that residue plays in different evolutionary contexts and of the co-evolutionary relationships between residues, thereby providing richer information for protein structure prediction.
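
For reference, the sketch below mirrors the published AlphaFold2 global-attention computation in PyTorch, using the tensor shapes listed in Table 4: the query is averaged over the sequence dimension, attention is computed once per head against that average, and a per-position sigmoid gate is applied before the output projection. It is an illustration of the math only; the function name is ours, and the exact numerics inside kutacc_af2_global_attention (scaling, epsilon, bf16 rounding, prepacked weight layout) may differ.

# Reference sketch (PyTorch), illustrative only -- not the kutacc implementation.
import torch

def reference_global_attention(q_data, m_data, q_mask,
                               query_w, key_w, value_w,
                               gating_w, gating_b, output_w, output_b):
    head_size = query_w.shape[1]
    # Masked average of the query activations over the sequence (column) dimension.
    q_avg = (q_mask * q_data).sum(dim=1) / (q_mask.sum(dim=1) + 1e-10)    # [batch, nchannels]
    q = torch.einsum('bc,hdc->bhd', q_avg, query_w) * head_size ** -0.5   # [batch, nheads, head_size]
    k = torch.einsum('bsc,dc->bsd', m_data, key_w)                        # [batch, seq_len, head_size]
    v = torch.einsum('bsc,dc->bsd', m_data, value_w)                      # [batch, seq_len, head_size]
    bias = 1e9 * (q_mask[:, None, :, 0] - 1.0)                            # mask out padded positions
    weights = torch.softmax(torch.einsum('bhd,bsd->bhs', q, k) + bias, dim=-1)
    weighted_avg = torch.einsum('bhs,bsd->bhd', weights, v)               # [batch, nheads, head_size]
    # Per-position sigmoid gate, then projection back to nchannels.
    gate = torch.sigmoid(torch.einsum('bsc,hdc->bshd', q_data, gating_w) + gating_b)
    gated = weighted_avg[:, None] * gate                                  # [batch, seq_len, nheads, head_size]
    return torch.einsum('bshd,chd->bsc', gated, output_w) + output_b      # [batch, seq_len, nchannels]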

Interface Definition

void kutacc_af2_global_attention(kutacc_af2_attention_inputs_t *q_based_ptr, kutacc_tensor_h q_data, kutacc_tensor_h q_mask, kutacc_af2_attention_weights_t *weight_ptr, kutacc_tensor_h out);

Parameters

Table 1 Input parameter definitions

| Parameter | Type | Description | Input/Output |
|---|---|---|---|
| q_based_ptr | kutacc_af2_attention_inputs_t * | Pointer of type kutacc_af2_attention_inputs_t; see Table 2 below for the structure definition | Input |
| q_data | kutacc_tensor_h | Q matrix data | Input |
| q_mask | kutacc_tensor_h | Q mask matrix | Input |
| weight_ptr | kutacc_af2_attention_weights_t * | Pointer of type kutacc_af2_attention_weights_t; see Table 3 below for the structure definition | Input |
| out | kutacc_tensor_h | Output data | Output |

Table 2 kutacc_af2_attention_inputs_t structure definition

| Field | Type | Description |
|---|---|---|
| gate | kutacc_tensor_h | Intermediate-layer gating tensor data |
| k | kutacc_tensor_h | Intermediate-layer k data |
| v | kutacc_tensor_h | Intermediate-layer v data |
| q | kutacc_tensor_h | Intermediate-layer q data |
| avg | kutacc_tensor_h | avg tensor data |
| batch | int64_t | Input batch size |
| seq_len | int64_t | Sequence length of the input data |
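
For orientation, the example at the end of this page allocates the intermediate buffers referenced by this structure with the shapes shown below. The concrete values (batch=2, seq_len=2, nchannels=64 and the nheads=8 / head_size=8 split) are assumptions chosen to satisfy the constraints later in this page; only the shapes themselves are taken from the example.

# Scratch-buffer shapes as used by the example at the end of this page.
import torch
batch, seq_len, nchannels, nheads, head_size = 2, 2, 64, 8, 8  # assumed values
q_avg = torch.empty(batch, nchannels)                   # avg
q     = torch.empty(batch, nheads, head_size)           # q
k     = torch.empty(batch, seq_len, head_size)          # k
v     = torch.empty(head_size, batch, seq_len)          # v
gate  = torch.empty(batch, seq_len, nheads, head_size)  # gate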

Table 3 kutacc_af2_attention_weights_t structure definition

| Field | Type | Description |
|---|---|---|
| nchannels | int64_t | Total number of input weight channels |
| nheads | int64_t | Number of heads in the input weights |
| head_size | int64_t | Number of elements per head in the input weights |
| query_w | kutacc_tensor_h | query_w weight data |
| key_w | kutacc_tensor_h | key_w weight data |
| gating_w | kutacc_tensor_h | gating_w weight data |
| gating_b | kutacc_tensor_h | Gating bias data |
| output_w | kutacc_tensor_h | output_w weight data |
| output_b | kutacc_tensor_h | Output matmul bias data |
| value_w | kutacc_tensor_h | value_w weight data |

The integer parameters of global_attention must satisfy the following constraints:

nchannels = nheads * head_size

batch, seq_len, nchannels, nheads, head_size > 0

In addition to the above, nchannels must equal 64.

When constructing test cases and using KPEX, the following operator shape constraints must also be satisfied.

Table 4 global attention input shape constraints

| Tensor | Shape |
|---|---|
| query_w | [nheads, head_size, nchannels] |
| key_w | [head_size, nchannels] |
| value_w | [head_size, nchannels] |
| gating_w | [nheads, head_size, nchannels] |
| gating_b | [nheads, head_size] |
| output_w | [nchannels, nheads, head_size] |
| output_b | [nchannels] |
| q_data | [batch, seq_len, nchannels] |
| m_data | [batch, seq_len, nchannels] |
| q_mask | [batch, seq_len, 1] |
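
The sketch below makes the integer constraints concrete. It is a hypothetical Python helper (check_global_attention_args is not part of KPEX or kutacc), and the example parameter values are simply ones that satisfy the rules above.

# Hypothetical validation helper -- not part of the KPEX/kutacc API.
def check_global_attention_args(batch, seq_len, nchannels, nheads, head_size):
    # All integer parameters must be strictly positive.
    assert min(batch, seq_len, nchannels, nheads, head_size) > 0
    # nchannels must factor exactly into nheads * head_size.
    assert nchannels == nheads * head_size, "nchannels must equal nheads * head_size"
    # The operator additionally requires nchannels == 64.
    assert nchannels == 64, "nchannels must be 64"

check_global_attention_args(batch=2, seq_len=2, nchannels=64, nheads=8, head_size=8)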

Example

C++ interface:

//test_global_attention.h
#ifndef KPEX_TPP_ALPHAFOLD_TEST_GLOBAL_ATTENTION_H
#define KPEX_TPP_ALPHAFOLD_TEST_GLOBAL_ATTENTION_H
#include <ATen/core/Tensor.h>
#include <ATen/ops/empty.h>
#include <ATen/ops/ones.h>
#include <ATen/ops/zeros.h>
#include <ATen/ops/full.h>
#include <ATen/native/cpu/utils.h>
#include <c10/core/ScalarType.h>
#include "common_header.h"
namespace alphafold {
at::Tensor test_global_attention(int64_t batch, int64_t seq_len, int64_t nchannels, int64_t nheads, int64_t head_size);
}
#endif

//bind.h
#include <torch/extension.h>
#include "test_global_attention.h"
namespace alphafold {
inline void bind(pybind11::module &m)
{
    auto submodule = m.def_submodule("alphafold");
    submodule.def("test_global_attention", &test_global_attention, py::arg("batch"), py::arg("seq_len"), py::arg("nchannels"), py::arg("nheads"), py::arg("head_size"));
}
} // namespace alphafold

//test_global_attention.cpp
#include "kutacc.h"
#include <utils/memory.h>
#include <utils/bf16.h>
#include <utils/TensorWrapper.h>
#include "test_global_attention.h"
namespace alphafold {
at::Tensor test_global_attention(int64_t batch, int64_t seq_len, int64_t nchannels, int64_t nheads, int64_t head_size) {
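    // Constant fill values for the weight and bias tensors.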
    float a = 0.75f;
    float b = 0.25f;
    float c = 0.5f;
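    // Allocate the operator input/output tensors and the intermediate buffers.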
    at::Tensor q_data = at::ones({batch, seq_len, nchannels}, at::TensorOptions().device(kpex::device()).dtype(at::kBFloat16));
    at::Tensor out = at::empty(q_data.sizes(), q_data.options());
    at::Tensor q_mask = at::ones({batch, seq_len, 1}, q_data.options());
    q_mask = q_mask.contiguous();
    at::Tensor q_avg = q_data.new_empty({batch, nchannels});
    at::Tensor q = q_data.new_empty({batch, nheads, head_size});
    at::Tensor k = q_data.new_empty({batch, seq_len, head_size});
    at::Tensor v = q_data.new_empty({head_size, batch, seq_len});
    at::Tensor gate = q_data.new_empty({batch, seq_len, nheads, head_size});
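    // Create the weight and bias tensors with the shapes listed in Table 4.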
    at::Tensor query_w = at::full({nheads, head_size, nchannels}, a, q_data.options()).to(at::kBFloat16);
    auto bf16_opt = query_w.options().device(kpex::device()).dtype(at::kBFloat16);
    auto float_opt = query_w.options().device(kpex::device()).dtype(at::kFloat);
    at::Tensor key_w = at::full({head_size, nchannels}, b, at::TensorOptions().device(kpex::device()).dtype(at::kBFloat16));
    at::Tensor value_w = at::full({head_size, nchannels}, a, at::TensorOptions().device(kpex::device()).dtype(at::kBFloat16));
    at::Tensor gating_w = at::full({nheads, head_size, nchannels}, a, at::TensorOptions().device(kpex::device()).dtype(at::kBFloat16));
    at::Tensor gating_b = at::full({nheads, head_size}, c, at::TensorOptions().device(kpex::device()).dtype(at::kFloat));
    at::Tensor output_w = at::full({nchannels, nheads, head_size}, b, at::TensorOptions().device(kpex::device()).dtype(at::kBFloat16));
    at::Tensor output_b = at::full({nchannels}, c, at::TensorOptions().device(kpex::device()).dtype(at::kFloat));
    query_w = query_w.to(bf16_opt).contiguous().view({nchannels, nchannels});
    key_w = key_w.to(bf16_opt).contiguous().view({head_size, nchannels});
    value_w = value_w.to(bf16_opt).contiguous().view({head_size, nchannels});
    gating_w = gating_w.to(bf16_opt).contiguous().view({nchannels, nchannels});
    output_w = output_w.to(bf16_opt).contiguous().view({nchannels, nchannels});
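    // Prepack the reshaped 2-D weight matrices for the linear/matmul kernels.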
    auto query_w_res = linear_weight_prepack(query_w);
    auto key_w_res = linear_weight_prepack(key_w);
    auto value_w_res = linear_weight_prepack(value_w);
    auto gating_w_res = linear_weight_prepack(gating_w);
    auto output_w_res = linear_weight_prepack(output_w);
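    // Wrap the ATen tensors as kutacc tensor handles.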
    kutacc::TensorWrapper q_avg_tw = convert_to_tensor_wrapper(q_avg);
    kutacc::TensorWrapper q_tw = convert_to_tensor_wrapper(q);
    kutacc::TensorWrapper k_tw = convert_to_tensor_wrapper(k);
    kutacc::TensorWrapper v_tw = convert_to_tensor_wrapper(v);
    kutacc::TensorWrapper gate_tw = convert_to_tensor_wrapper(gate);
    kutacc::TensorWrapper q_data_tw = convert_to_tensor_wrapper(q_data);
    kutacc::TensorWrapper q_mask_tw = convert_to_tensor_wrapper(q_mask);
    kutacc::TensorWrapper out_tw = convert_to_tensor_wrapper(out);
    kutacc::TensorWrapper query_w_tw = convert_to_tensor_wrapper(query_w_res);
    kutacc::TensorWrapper key_w_tw = convert_to_tensor_wrapper(key_w_res);
    kutacc::TensorWrapper value_w_tw = convert_to_tensor_wrapper(value_w_res);
    kutacc::TensorWrapper gating_w_tw = convert_to_tensor_wrapper(gating_w_res);
    kutacc::TensorWrapper gating_b_tw = convert_to_tensor_wrapper(gating_b);
    kutacc::TensorWrapper output_w_tw = convert_to_tensor_wrapper(output_w_res);
    kutacc::TensorWrapper output_b_tw = convert_to_tensor_wrapper(output_b);
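    // Assemble the weight/input descriptors and run the global-attention operator.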
    kutacc_af2_attention_weights_t_wrapper *global_attention_weight_ptr = new kutacc_af2_attention_weights_t_wrapper(query_w_tw, key_w_tw, value_w_tw, gating_w_tw, gating_b_tw,
        output_w_tw, output_b_tw, nchannels, nheads, head_size);
    kutacc_af2_attention_inputs_t_wrapper *global_attention_q_ptr = new kutacc_af2_attention_inputs_t_wrapper(q_tw, k_tw, v_tw, gate_tw, q_avg_tw, batch, seq_len);
    kutacc_af2_global_attention(global_attention_q_ptr, q_data_tw.get_tensor(), q_mask_tw.get_tensor(), global_attention_weight_ptr, out_tw.get_tensor());
    delete global_attention_weight_ptr;
    delete global_attention_q_ptr;
    return out;
}
}

//test.py
import torch
import kpex._C as kernel
def test_global_attention(batch, seq_len, nchannels, nheads, head_size):
    out = kernel.alphafold.test_global_attention(batch, seq_len, nchannels, nheads, head_size)
    return out
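
if __name__ == "__main__":
    # Hypothetical invocation: batch=2, seq_len=2 and nchannels=64 match the sample
    # output below, while the nheads=8 / head_size=8 split is an assumption that
    # satisfies nchannels == nheads * head_size.
    print(test_global_attention(2, 2, 64, 8, 8))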

//output:
tensor([[[768., 768., 768., 768., 768., 768., 768., 768., 768., 768., 768.,
          768., 768., 768., 768., 768., 768., 768., 768., 768., 768., 768.,
          768., 768., 768., 768., 768., 768., 768., 768., 768., 768., 768.,
          768., 768., 768., 768., 768., 768., 768., 768., 768., 768., 768.,
          768., 768., 768., 768., 768., 768., 768., 768., 768., 768., 768.,
          768., 768., 768., 768., 768., 768., 768., 768., 768.],
         [768., 768., 768., 768., 768., 768., 768., 768., 768., 768., 768.,
          768., 768., 768., 768., 768., 768., 768., 768., 768., 768., 768.,
          768., 768., 768., 768., 768., 768., 768., 768., 768., 768., 768.,
          768., 768., 768., 768., 768., 768., 768., 768., 768., 768., 768.,
          768., 768., 768., 768., 768., 768., 768., 768., 768., 768., 768.,
          768., 768., 768., 768., 768., 768., 768., 768., 768.]],

        [[768., 768., 768., 768., 768., 768., 768., 768., 768., 768., 768.,
          768., 768., 768., 768., 768., 768., 768., 768., 768., 768., 768.,
          768., 768., 768., 768., 768., 768., 768., 768., 768., 768., 768.,
          768., 768., 768., 768., 768., 768., 768., 768., 768., 768., 768.,
          768., 768., 768., 768., 768., 768., 768., 768., 768., 768., 768.,
          768., 768., 768., 768., 768., 768., 768., 768., 768.],
         [768., 768., 768., 768., 768., 768., 768., 768., 768., 768., 768.,
          768., 768., 768., 768., 768., 768., 768., 768., 768., 768., 768.,
          768., 768., 768., 768., 768., 768., 768., 768., 768., 768., 768.,
          768., 768., 768., 768., 768., 768., 768., 768., 768., 768., 768.,
          768., 768., 768., 768., 768., 768., 768., 768., 768., 768., 768.,
          768., 768., 768., 768., 768., 768., 768., 768., 768.]]],
       dtype=torch.bfloat16)