gating_attention
gating_attention函数通过在标准多头自注意力(multi-self-attention)输出上施加可学习的门控(gating),来动态地筛选和控制信息流,从而提升模型对关键残基或序列位置的敏感性与表达能力。
接口定义
void gating_attention(int64_t batch, int64_t seq_len, int64_t nchannels, int64_t nheads, int64_t head_size, int64_t block_size, void* bias, std::vector<int64_t> bias_strides, void* nonbatched_bias, std::vector<int64_t> nonbatched_bias_sizes, std::vector<int64_t> nonbatched_bias_strides, void* input, void* out, std::vector<int64_t> out_strides, void* gate, void* k, void* v, void* q, std::vector<int64_t> q_strides, std::vector<int64_t> gate_strides, std::vector<int64_t> v_strides, std::vector<int64_t> k_strides, void* value_w, void* weighted_avg, std::vector<int64_t> weighted_avg_strides, void* query_w, void* key_w, void* gating_w, void* gating_b, void* output_w, void* output_b);
参数
参数名  | 
类型  | 
描述  | 
输入/输出  | 
|---|---|---|---|
batch  | 
int64_t  | 
输入批次  | 
输入  | 
seq_len  | 
int64_t  | 
输入数据的序列长度  | 
输入  | 
nchannels  | 
int64_t  | 
输入权重的总数  | 
输入  | 
nheads  | 
int64_t  | 
输入权重的head数  | 
输入  | 
head_size  | 
int64_t  | 
输入权重每个head中的数据数量  | 
输入  | 
block_size  | 
int64_t  | 
并行操作中的数据块切分大小  | 
输入  | 
bias  | 
void*  | 
神经网络中的attention偏移量数据  | 
输入  | 
bias_strides  | 
std::vector<int64_t>  | 
神经网络中的attention偏移量数据的步长  | 
输入  | 
nonbatched_bias  | 
void*  | 
神经网络中的通用偏移量数据  | 
输入  | 
nonbatched_bias_sizes  | 
std::vector<int64_t>  | 
神经网络中的通用偏移量数据的大小  | 
输入  | 
nonbatched_bias_strides  | 
std::vector<int64_t>  | 
神经网络中的通用偏移量数据的步长  | 
输入  | 
input  | 
void*  | 
输入数据  | 
输入  | 
out  | 
void*  | 
输出数据  | 
输出  | 
out_strides  | 
std::vector<int64_t>  | 
输出数据的步长  | 
输入  | 
gate  | 
void*  | 
输入中间层gate数据  | 
输入  | 
k  | 
void*  | 
输入中间层k数据  | 
输入  | 
v  | 
void*  | 
输入中间层v数据  | 
输入  | 
q  | 
void*  | 
输入中间层q数据  | 
输入  | 
q_strides  | 
std::vector<int64_t>  | 
输入中间层q数据步长  | 
输入  | 
gate_strides  | 
std::vector<int64_t>  | 
输入中间层gate数据步长  | 
输入  | 
v_strides  | 
std::vector<int64_t>  | 
输入中间层v数据步长  | 
输入  | 
k_strides  | 
std::vector<int64_t>  | 
输入中间层k数据步长  | 
输入  | 
value_w  | 
void*  | 
权重value_w数据  | 
输入  | 
weighted_avg  | 
void*  | 
权重weighted_avg数据  | 
输入  | 
weighted_avg_strides  | 
std::vector<int64_t>  | 
权重weighted_avg数据步长  | 
输入  | 
query_w  | 
void*  | 
权重query_w数据  | 
输入  | 
key_w  | 
void*  | 
权重key_w数据  | 
输入  | 
gating_w  | 
void*  | 
权重gating_w数据  | 
输入  | 
gating_b  | 
void*  | 
门控偏移量数据  | 
输入  | 
output_w  | 
void*  | 
权重output_w数据  | 
输入  | 
output_b  | 
void*  | 
输出矩阵乘法偏移量数据  | 
输入  | 
示例
#include "kutacc.h"
#include "arm_neon.h"
  
    __bf16 A[2][2];
    __bf16 B[2][2];
 
    __bf16 C[2][2];
    __bf16 D[2][2];
 
    __bf16 output[2][2];
    __bf16 output_w[2][2];
    __bf16 output_b[2][2];
 
    B[0][0] = vcvth_bf16_f32((float)1.0);
    B[0][1] = vcvth_bf16_f32((float)2.0);
    B[1][0] = vcvth_bf16_f32((float)5.0);
    B[1][1] = vcvth_bf16_f32((float)6.0);
 
    C[0][0] = vcvth_bf16_f32((float)1.0);
    C[0][1] = vcvth_bf16_f32((float)2.0);
    C[1][0] = vcvth_bf16_f32((float)3.0);
    C[1][1] = vcvth_bf16_f32((float)4.0);
 
    D[0][0] = vcvth_bf16_f32((float)4.0);
    D[0][1] = vcvth_bf16_f32((float)3.0);
    D[1][0] = vcvth_bf16_f32((float)2.0);
    D[1][1] = vcvth_bf16_f32((float)1.0);
 
    __bf16 bias[4];
 
    bias[0] = vcvth_bf16_f32((float)1.0);
    bias[1] = vcvth_bf16_f32((float)2.0);
    bias[2] = vcvth_bf16_f32((float)3.0);
    bias[3] = vcvth_bf16_f32((float)4.0);
 
    kutacc::gating_attention(1, 1, 1, 1, 1, 1,
                             bias, {2,2}, bias, {2,2}, {2,2},
                             B, output, {2,2}, C, D, B, C, {2,2},
                             {2,2}, {2,2}, {2,2}, D,
                             B, {2,2}, C, D, B, C, output_w, output_b);