鲲鹏社区首页
中文
注册
我要评分
文档获取效率
文档正确性
内容完整性
文档易理解
在线提单
论坛求助

gating_attention

gating_attention函数通过在标准多头自注意力(multi-self-attention)输出上施加可学习的门控(gating),来动态地筛选和控制信息流,从而提升模型对关键残基或序列位置的敏感性与表达能力。

接口定义

void gating_attention(int64_t batch, int64_t seq_len, int64_t nchannels, int64_t nheads, int64_t head_size, int64_t block_size, void* bias, std::vector<int64_t> bias_strides, void* nonbatched_bias, std::vector<int64_t> nonbatched_bias_sizes, std::vector<int64_t> nonbatched_bias_strides, void* input, void* out, std::vector<int64_t> out_strides, void* gate, void* k, void* v, void* q, std::vector<int64_t> q_strides, std::vector<int64_t> gate_strides, std::vector<int64_t> v_strides, std::vector<int64_t> k_strides, void* value_w, void* weighted_avg, std::vector<int64_t> weighted_avg_strides, void* query_w, void* key_w, void* gating_w, void* gating_b, void* output_w, void* output_b);

参数

参数名

类型

描述

输入/输出

batch

int64_t

输入批次

输入

seq_len

int64_t

输入数据的序列长度

输入

nchannels

int64_t

输入权重的总数

输入

nheads

int64_t

输入权重的head数

输入

head_size

int64_t

输入权重每个head中的数据数量

输入

block_size

int64_t

并行操作中的数据块切分大小

输入

bias

void*

神经网络中的attention偏移量数据

输入

bias_strides

std::vector<int64_t>

神经网络中的attention偏移量数据的步长

输入

nonbatched_bias

void*

神经网络中的通用偏移量数据

输入

nonbatched_bias_sizes

std::vector<int64_t>

神经网络中的通用偏移量数据的大小

输入

nonbatched_bias_strides

std::vector<int64_t>

神经网络中的通用偏移量数据的步长

输入

input

void*

输入数据

输入

out

void*

输出数据

输出

out_strides

std::vector<int64_t>

输出数据的步长

输入

gate

void*

输入中间层gate数据

输入

k

void*

输入中间层k数据

输入

v

void*

输入中间层v数据

输入

q

void*

输入中间层q数据

输入

q_strides

std::vector<int64_t>

输入中间层q数据步长

输入

gate_strides

std::vector<int64_t>

输入中间层gate数据步长

输入

v_strides

std::vector<int64_t>

输入中间层v数据步长

输入

k_strides

std::vector<int64_t>

输入中间层k数据步长

输入

value_w

void*

权重value_w数据

输入

weighted_avg

void*

权重weighted_avg数据

输入

weighted_avg_strides

std::vector<int64_t>

权重weighted_avg数据步长

输入

query_w

void*

权重query_w数据

输入

key_w

void*

权重key_w数据

输入

gating_w

void*

权重gating_w数据

输入

gating_b

void*

门控偏移量数据

输入

output_w

void*

权重output_w数据

输入

output_b

void*

输出矩阵乘法偏移量数据

输入

示例

C++ interface:
#include "kutacc.h"
#include "arm_neon.h"
  
    __bf16 A[2][2];
    __bf16 B[2][2];
 
    __bf16 C[2][2];
    __bf16 D[2][2];
 
    __bf16 output[2][2];
    __bf16 output_w[2][2];
    __bf16 output_b[2][2];
 
    B[0][0] = vcvth_bf16_f32((float)1.0);
    B[0][1] = vcvth_bf16_f32((float)2.0);
    B[1][0] = vcvth_bf16_f32((float)5.0);
    B[1][1] = vcvth_bf16_f32((float)6.0);
 
    C[0][0] = vcvth_bf16_f32((float)1.0);
    C[0][1] = vcvth_bf16_f32((float)2.0);
    C[1][0] = vcvth_bf16_f32((float)3.0);
    C[1][1] = vcvth_bf16_f32((float)4.0);
 
    D[0][0] = vcvth_bf16_f32((float)4.0);
    D[0][1] = vcvth_bf16_f32((float)3.0);
    D[1][0] = vcvth_bf16_f32((float)2.0);
    D[1][1] = vcvth_bf16_f32((float)1.0);
 
    __bf16 bias[4];
 
    bias[0] = vcvth_bf16_f32((float)1.0);
    bias[1] = vcvth_bf16_f32((float)2.0);
    bias[2] = vcvth_bf16_f32((float)3.0);
    bias[3] = vcvth_bf16_f32((float)4.0);
 
    kutacc::gating_attention(1, 1, 1, 1, 1, 1,
                             bias, {2,2}, bias, {2,2}, {2,2},
                             B, output, {2,2}, C, D, B, C, {2,2},
                             {2,2}, {2,2}, {2,2}, D,
                             B, {2,2}, C, D, B, C, output_w, output_b);