鲲鹏社区首页
中文
注册
开发者
我要评分
获取效率
正确性
完整性
易理解
在线提单
论坛求助

kutacc_af2_transition

transition算子:输入act张量,使用layernorm将输入归一化,在进行linear特性扩展,使用relu进行非线性激活,最后再进行linear特征压缩,然后返回输出

接口定义

kutacc_export void kutacc_af2_transition(kutacc_af2_trans_act_inputs_t *trans_inputs_ptr, kutacc_af2_trans_weights_t * trans_weights_ptr, kutacc_tensor_hout);

参数

表1 参数定义

参数名

类型

描述

输入/输出

trans_inputs_ptr

kutacc_af2_trans_act_inputs_t *

kutacc_af2_trans_act_inputs_t类型的指针,结构体定义见表2 kutacc_af2_trans_act_inputs_t数据结构表定义

输入

trans_weights_ptr

kutacc_af2_trans_weights_t *

kutacc_af2_trans_weights_t类型的指针,结构体定义见表3 kutacc_af2_trans_weights_t数据结构表定义

输入

out

kutacc_tensor_h

输出数据

输出

表2 kutacc_af2_trans_act_inputs_t数据结构表定义

参数名

类型

描述

输入/输出

batch

int64_t

act的第一维度大小

输入

n_res

int64_t

act的第二维度大小

输入

input_act

kutacc_tensor_h

act经过layernorm后获得的输入矩阵

输入

intermediate_act

kutacc_tensor_h

第一次线性变化生成的中间变量,用于第二次线性变化作为输入

输入

表3 kutacc_af2_trans_weights_t数据结构表定义

参数名

类型

描述

输入/输出

c_o

int64_t

input_ln_w(见表4参数input_ln_w)第一维度大小

输入

c_i

int64_t

linear1_w第一维度大小

输入

linear1_w

kutacc_tensor_h

第一次线性变化权重

输入

linear1_b

kutacc_tensor_h

第一次线性变化偏置

输入

linear2_w

kutacc_tensor_h

第二次线性变化权重

输入

lienar2_b

kutacc_tensor_h

第二次线性变化偏置

输入

transition整数参数应满足的约束关系:

batch, n_res, c_o, c_i > 0,

batch, n_res, c_o, c_i, batch * n_res < INT64_MAX

在KPEX层构造参数时各参数的形状约束表(括号之后为全名):

表4 KPEX transition入参张量形状及描述

参数名

形状

描述

input_ln_w(input_layer_norm_weight)

[c_o]

act layernorm归一化过程中的权重参数

input_ln_b(input_layer_norm_bias)

[c_o]

act layernorm归一化过程中的偏置参数

linear_1_w(transition1_weight)

[c_i, c_o]

见表3参数linear1_w

linear_1_b(transition1_bias)

[c_i]

见表3参数linear1_b

linear_2_w(transition2_weight)

[c_o, c_i]

见表3参数linear2_w

linear_2_b(transition2_bias)

[c_o]

见表3参数linear2_b

act

[batch, n_res, c_o]

输入张量

示例

C++ interface:

// test_transition.h
#ifndef KPEX_TPP_ALPHAFOLD_TEST_TRANSITION_H
#define KPEX_TPP_ALPHAFOLD_TEST_TRANSITION_H
#include <ATen/core/Tensor.h>
#include <ATen/ops/empty.h>
#include <ATen/ops/ones.h>
#include <ATen/ops/zeros.h>
#include <ATen/ops/full.h>
#include <ATen/native/cpu/utils.h>
#include <c10/core/ScalarType.h>
namespace alphafold {
at::Tensor test_transition(int64_t batch, int64_t n_res, int64_t c_o, int64_t c_i);
}
#endif

// test_transition.cpp
#include "kutacc.h"
#include "test_transition.h"
#include "transition.h"
#include <utils/memory.h>
#include "utils/TensorWrapper.h"
#include "utils/layernorm.h"
namespace alphafold {
at::Tensor test_transition(int64_t batch, int64_t n_res, int64_t c_o, int64_t c_i)
{
    float a = 0.2f;
    at::Tensor act = at::full({batch, n_res, c_o}, a, at::TensorOptions().device(kpex::device()).dtype(at::kBFloat16));
    at::Tensor out = act.new_empty(act.sizes());
    float b = 0.1f;
    float c = 0.5f;
    at::Tensor input_ln_w = at::ones({c_o}, at::TensorOptions().device(kpex::device()).dtype(at::kFloat));
    at::Tensor input_ln_b = at::zeros({c_o}, input_ln_w.options());
    at::Tensor linear1_w = at::full({c_i, c_o}, b, act.options());
    at::Tensor linear1_b = at::full({c_i}, c, input_ln_w.options());
    at::Tensor linear1_w_new = linear_weight_prepack(linear1_w);
    at::Tensor linear2_w = at::full({c_o, c_i}, b, act.options());
    at::Tensor linear2_w_new = linear_weight_prepack(linear2_w);
    at::Tensor linear_2_b = at::zeros({c_o}, input_ln_w.options());
    at::Tensor input_act = layernorm(act, input_ln_w, input_ln_b);
    at::Tensor intermediate_act = act.new_empty({batch * n_res, c_i});
    kutacc::TensorWrapper input_act_tw = convert_to_tensor_wrapper(input_act);
    kutacc::TensorWrapper linear1_w_tw = convert_to_tensor_wrapper(linear1_w_new);
    kutacc::TensorWrapper linear1_b_tw = convert_to_tensor_wrapper(linear1_b);
    kutacc::TensorWrapper linear2_w_tw = convert_to_tensor_wrapper(linear2_w_new);
    kutacc::TensorWrapper linear2_b_tw = convert_to_tensor_wrapper(linear2_b);
    kutacc::TensorWrapper intermediate_act_tw = convert_to_tensor_wrapper(intermediate_act);
    kutacc::TensorWrapper out_tw = convert_to_tensor_wrapper(out);
    
    kutacc_af2_trans_weights_t_wrapper *trans_weights_ptr = new kutacc_af2_trans_weights_t_wrapper(linear1_w_tw, linear1_b_tw, linear2_w_tw, linear2_b_tw, c_o, c_i);
    kutacc_af2_trans_act_inputs_t_wrapper *trans_inputs_ptr = new kutacc_af2_trans_act_inputs_t_wrapper(input_act_tw, intermediate_act_tw, batch, n_res);
    kutacc_af2_transition(trans_inputs_ptr, trans_weights_ptr, out_tw.get_tensor());
    delete trans_weights_ptr;
    delete trans_inputs_ptr;
    return out;
}
}

// bind.h
#include <torch/extension.h>
#include "test_transition.h"

namespace alphafold {
inline void bind(pybind11::module &m)
{
    auto submodule = m.def_submodule("alphafold");
    submodule.def("test_transition", &test_transition, py::arg("batch"), py::arg("n_res"), py::arg("c_o"), py::arg("c_i"));
}
}

// test.py
import torch
from torch import nn
import numpy as np
import torch.distributed as dist
import kpex._C as kernel
import kpex
import os

def test_transition(batch, n_res, c_o, c_i):
    out = kernel.alphafold.test_transition(batch, n_res, c_o, c_i)
    return out