gating_attention
gating_attention函数通过在标准多头自注意力(multi-self-attention)输出上施加可学习的门控(gating),来动态地筛选和控制信息流,从而提升模型对关键残基或序列位置的敏感性与表达能力。
接口定义
void gating_attention(int64_t batch, int64_t seq_len, int64_t nchannels, int64_t nheads, int64_t head_size, int64_t block_size, void* bias, std::vector<int64_t> bias_strides, void* nonbatched_bias, std::vector<int64_t> nonbatched_bias_sizes, std::vector<int64_t> nonbatched_bias_strides, void* input, void* out, std::vector<int64_t> out_strides, void* gate, void* k, void* v, void* q, std::vector<int64_t> q_strides, std::vector<int64_t> gate_strides, std::vector<int64_t> v_strides, std::vector<int64_t> k_strides, void* value_w, void* weighted_avg, std::vector<int64_t> weighted_avg_strides, void* query_w, void* key_w, void* gating_w, void* gating_b, void* output_w, void* output_b);
参数
参数名 |
类型 |
描述 |
输入/输出 |
---|---|---|---|
batch |
int64_t |
输入批次 |
输入 |
seq_len |
int64_t |
输入数据的序列长度 |
输入 |
nchannels |
int64_t |
输入权重的总数 |
输入 |
nheads |
int64_t |
输入权重的head数 |
输入 |
head_size |
int64_t |
输入权重每个head中的数据数量 |
输入 |
block_size |
int64_t |
并行操作中的数据块切分大小 |
输入 |
bias |
void* |
神经网络中的attention偏移量数据 |
输入 |
bias_strides |
std::vector<int64_t> |
神经网络中的attention偏移量数据的步长 |
输入 |
nonbatched_bias |
void* |
神经网络中的通用偏移量数据 |
输入 |
nonbatched_bias_sizes |
std::vector<int64_t> |
神经网络中的通用偏移量数据的大小 |
输入 |
nonbatched_bias_strides |
std::vector<int64_t> |
神经网络中的通用偏移量数据的步长 |
输入 |
input |
void* |
输入数据 |
输入 |
out |
void* |
输出数据 |
输出 |
out_strides |
std::vector<int64_t> |
输出数据的步长 |
输入 |
gate |
void* |
输入中间层gate数据 |
输入 |
k |
void* |
输入中间层k数据 |
输入 |
v |
void* |
输入中间层v数据 |
输入 |
q |
void* |
输入中间层q数据 |
输入 |
q_strides |
std::vector<int64_t> |
输入中间层q数据步长 |
输入 |
gate_strides |
std::vector<int64_t> |
输入中间层gate数据步长 |
输入 |
v_strides |
std::vector<int64_t> |
输入中间层v数据步长 |
输入 |
k_strides |
std::vector<int64_t> |
输入中间层k数据步长 |
输入 |
value_w |
void* |
权重value_w数据 |
输入 |
weighted_avg |
void* |
权重weighted_avg数据 |
输入 |
weighted_avg_strides |
std::vector<int64_t> |
权重weighted_avg数据步长 |
输入 |
query_w |
void* |
权重query_w数据 |
输入 |
key_w |
void* |
权重key_w数据 |
输入 |
gating_w |
void* |
权重gating_w数据 |
输入 |
gating_b |
void* |
门控偏移量数据 |
输入 |
output_w |
void* |
权重output_w数据 |
输入 |
output_b |
void* |
输出矩阵乘法偏移量数据 |
输入 |
示例
#include "kutacc.h" #include "arm_neon.h" __bf16 A[2][2]; __bf16 B[2][2]; __bf16 C[2][2]; __bf16 D[2][2]; __bf16 output[2][2]; __bf16 output_w[2][2]; __bf16 output_b[2][2]; B[0][0] = vcvth_bf16_f32((float)1.0); B[0][1] = vcvth_bf16_f32((float)2.0); B[1][0] = vcvth_bf16_f32((float)5.0); B[1][1] = vcvth_bf16_f32((float)6.0); C[0][0] = vcvth_bf16_f32((float)1.0); C[0][1] = vcvth_bf16_f32((float)2.0); C[1][0] = vcvth_bf16_f32((float)3.0); C[1][1] = vcvth_bf16_f32((float)4.0); D[0][0] = vcvth_bf16_f32((float)4.0); D[0][1] = vcvth_bf16_f32((float)3.0); D[1][0] = vcvth_bf16_f32((float)2.0); D[1][1] = vcvth_bf16_f32((float)1.0); __bf16 bias[4]; bias[0] = vcvth_bf16_f32((float)1.0); bias[1] = vcvth_bf16_f32((float)2.0); bias[2] = vcvth_bf16_f32((float)3.0); bias[3] = vcvth_bf16_f32((float)4.0); kutacc::gating_attention(1, 1, 1, 1, 1, 1, bias, {2,2}, bias, {2,2}, {2,2}, B, output, {2,2}, C, D, B, C, {2,2}, {2,2}, {2,2}, {2,2}, D, B, {2,2}, C, D, B, C, output_w, output_b);