kutacc_af2_transition
The transition operator takes the act tensor as input, normalizes it with layer normalization, expands the features with a linear layer, applies a ReLU non-linear activation, then compresses the features with a second linear layer and returns the output.
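The computation can be sketched with standard PyTorch ops. The snippet below is illustrative only (the name transition_reference is not part of the interface, and the optimized kernel does not follow this code path); tensor names and shapes follow Table 4.
# Minimal PyTorch reference of the transition computation (illustrative only).
import torch.nn.functional as F

def transition_reference(act, input_ln_w, input_ln_b,
                         linear_1_w, linear_1_b, linear_2_w, linear_2_b):
    c_o = act.shape[-1]
    x = F.layer_norm(act, (c_o,), input_ln_w, input_ln_b)  # normalize over the last dimension
    x = F.relu(F.linear(x, linear_1_w, linear_1_b))        # expand c_o -> c_i, then ReLU
    return F.linear(x, linear_2_w, linear_2_b)             # compress c_i -> c_o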
Interface definition
kutacc_export void kutacc_af2_transition(kutacc_af2_trans_act_inputs_t *trans_inputs_ptr, kutacc_af2_trans_weights_t *trans_weights_ptr, kutacc_tensor_h out);
Parameters
Table 1 Parameter definitions
| Parameter | Type | Description | Input/Output |
|---|---|---|---|
| trans_inputs_ptr | kutacc_af2_trans_act_inputs_t * | Pointer of type kutacc_af2_trans_act_inputs_t; the structure is defined in Table 2 kutacc_af2_trans_act_inputs_t structure definition | Input |
| trans_weights_ptr | kutacc_af2_trans_weights_t * | Pointer of type kutacc_af2_trans_weights_t; the structure is defined in Table 3 kutacc_af2_trans_weights_t structure definition | Input |
| out | kutacc_tensor_h | Output data | Output |
Table 2 kutacc_af2_trans_act_inputs_t structure definition

| Parameter | Type | Description | Input/Output |
|---|---|---|---|
| batch | int64_t | Size of the first dimension of act | Input |
| n_res | int64_t | Size of the second dimension of act | Input |
| input_act | kutacc_tensor_h | Input matrix obtained by applying layernorm to act | Input |
| intermediate_act | kutacc_tensor_h | Intermediate tensor produced by the first linear transformation; used as the input of the second linear transformation | Input |
Table 3 kutacc_af2_trans_weights_t structure definition

| Parameter | Type | Description | Input/Output |
|---|---|---|---|
| c_o | int64_t | Size of the first dimension of input_ln_w (see input_ln_w in Table 4) | Input |
| c_i | int64_t | Size of the first dimension of linear1_w | Input |
| linear1_w | kutacc_tensor_h | Weight of the first linear transformation | Input |
| linear1_b | kutacc_tensor_h | Bias of the first linear transformation | Input |
| linear2_w | kutacc_tensor_h | Weight of the second linear transformation | Input |
| linear2_b | kutacc_tensor_h | Bias of the second linear transformation | Input |
The integer parameters of transition must satisfy the following constraints:
batch, n_res, c_o, c_i > 0
batch, n_res, c_o, c_i, batch * n_res < INT64_MAX
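For reference, a minimal Python check of these constraints (illustrative; the helper name transition_params_valid is not part of the interface):
# Sanity-check the integer parameters against the constraints above.
INT64_MAX = 2**63 - 1

def transition_params_valid(batch, n_res, c_o, c_i):
    if min(batch, n_res, c_o, c_i) <= 0:
        return False
    return max(batch, n_res, c_o, c_i, batch * n_res) < INT64_MAX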
Shape constraints for the parameters constructed at the KPEX layer (the full name is given in parentheses):
Table 4 KPEX transition input tensor shapes and descriptions

| Parameter | Shape | Description |
|---|---|---|
| input_ln_w (input_layer_norm_weight) | [c_o] | Weight used when normalizing act with layernorm |
| input_ln_b (input_layer_norm_bias) | [c_o] | Bias used when normalizing act with layernorm |
| linear_1_w (transition1_weight) | [c_i, c_o] | See linear1_w in Table 3 |
| linear_1_b (transition1_bias) | [c_i] | See linear1_b in Table 3 |
| linear_2_w (transition2_weight) | [c_o, c_i] | See linear2_w in Table 3 |
| linear_2_b (transition2_bias) | [c_o] | See linear2_b in Table 3 |
| act | [batch, n_res, c_o] | Input tensor |
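For example, the KPEX-layer tensors can be allocated with these shapes as follows; the concrete sizes and fill values are assumptions that mirror the C++ example below (bf16 activations and linear weights, fp32 layernorm parameters and biases):
import torch

batch, n_res, c_o, c_i = 1, 256, 128, 512  # assumed sizes satisfying the constraints above
act        = torch.full((batch, n_res, c_o), 0.2, dtype=torch.bfloat16)
input_ln_w = torch.ones(c_o)
input_ln_b = torch.zeros(c_o)
linear_1_w = torch.full((c_i, c_o), 0.1, dtype=torch.bfloat16)
linear_1_b = torch.full((c_i,), 0.5)
linear_2_w = torch.full((c_o, c_i), 0.1, dtype=torch.bfloat16)
linear_2_b = torch.zeros(c_o)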
Example
C++ interface:
// test_transition.h
#ifndef KPEX_TPP_ALPHAFOLD_TEST_TRANSITION_H
#define KPEX_TPP_ALPHAFOLD_TEST_TRANSITION_H
#include <ATen/core/Tensor.h>
#include <ATen/ops/empty.h>
#include <ATen/ops/ones.h>
#include <ATen/ops/zeros.h>
#include <ATen/ops/full.h>
#include <ATen/native/cpu/utils.h>
#include <c10/core/ScalarType.h>
namespace alphafold {
at::Tensor test_transition(int64_t batch, int64_t n_res, int64_t c_o, int64_t c_i);
}
#endif
// test_transition.cpp
#include "kutacc.h"
#include "test_transition.h"
#include "transition.h"
#include <utils/memory.h>
#include "utils/TensorWrapper.h"
#include "utils/layernorm.h"
namespace alphafold {
at::Tensor test_transition(int64_t batch, int64_t n_res, int64_t c_o, int64_t c_i)
{
    // Constant fill values for the synthetic test tensors.
    float a = 0.2f;
    float b = 0.1f;
    float c = 0.5f;
    // act: [batch, n_res, c_o] bf16 input; out is allocated with the same shape.
    at::Tensor act = at::full({batch, n_res, c_o}, a, at::TensorOptions().device(kpex::device()).dtype(at::kBFloat16));
    at::Tensor out = act.new_empty(act.sizes());
    // Layernorm parameters (see Table 4): input_ln_w [c_o], input_ln_b [c_o].
    at::Tensor input_ln_w = at::ones({c_o}, at::TensorOptions().device(kpex::device()).dtype(at::kFloat));
    at::Tensor input_ln_b = at::zeros({c_o}, input_ln_w.options());
    // Linear weights and biases (see Table 4); the weights are pre-packed for the kernel.
    at::Tensor linear1_w = at::full({c_i, c_o}, b, act.options());
    at::Tensor linear1_b = at::full({c_i}, c, input_ln_w.options());
    at::Tensor linear1_w_new = linear_weight_prepack(linear1_w);
    at::Tensor linear2_w = at::full({c_o, c_i}, b, act.options());
    at::Tensor linear2_w_new = linear_weight_prepack(linear2_w);
    at::Tensor linear2_b = at::zeros({c_o}, input_ln_w.options());
    // input_act is act after layernorm; intermediate_act is the scratch buffer
    // for the output of the first linear transformation (see Table 2).
    at::Tensor input_act = layernorm(act, input_ln_w, input_ln_b);
    at::Tensor intermediate_act = act.new_empty({batch * n_res, c_i});
    // Wrap the ATen tensors for the kutacc interface.
    kutacc::TensorWrapper input_act_tw = convert_to_tensor_wrapper(input_act);
    kutacc::TensorWrapper linear1_w_tw = convert_to_tensor_wrapper(linear1_w_new);
    kutacc::TensorWrapper linear1_b_tw = convert_to_tensor_wrapper(linear1_b);
    kutacc::TensorWrapper linear2_w_tw = convert_to_tensor_wrapper(linear2_w_new);
    kutacc::TensorWrapper linear2_b_tw = convert_to_tensor_wrapper(linear2_b);
    kutacc::TensorWrapper intermediate_act_tw = convert_to_tensor_wrapper(intermediate_act);
    kutacc::TensorWrapper out_tw = convert_to_tensor_wrapper(out);
    // Fill the input/weight structures and invoke the transition operator.
    kutacc_af2_trans_weights_t_wrapper *trans_weights_ptr = new kutacc_af2_trans_weights_t_wrapper(linear1_w_tw, linear1_b_tw, linear2_w_tw, linear2_b_tw, c_o, c_i);
    kutacc_af2_trans_act_inputs_t_wrapper *trans_inputs_ptr = new kutacc_af2_trans_act_inputs_t_wrapper(input_act_tw, intermediate_act_tw, batch, n_res);
    kutacc_af2_transition(trans_inputs_ptr, trans_weights_ptr, out_tw.get_tensor());
    delete trans_weights_ptr;
    delete trans_inputs_ptr;
    return out;
}
}
// bind.h
#include <torch/extension.h>
#include "test_transition.h"
namespace alphafold {
inline void bind(pybind11::module &m)
{
    // Expose the C++ test entry point under the "alphafold" submodule.
    auto submodule = m.def_submodule("alphafold");
    submodule.def("test_transition", &test_transition, py::arg("batch"), py::arg("n_res"), py::arg("c_o"), py::arg("c_i"));
}
}
// test.py
import torch
import kpex
import kpex._C as kernel


def test_transition(batch, n_res, c_o, c_i):
    out = kernel.alphafold.test_transition(batch, n_res, c_o, c_i)
    return out
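A possible invocation, with assumed sizes that satisfy the constraints above:
if __name__ == "__main__":
    out = test_transition(batch=1, n_res=256, c_o=128, c_i=512)  # assumed sizes
    print(out.shape)  # [batch, n_res, c_o], matching act in the C++ example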