kutacc_af2_triangle_multiplication_last
The final step of triangle_multiplication: it multiplies the out tensor produced by the out_linear step element-wise with gate to generate the output tensor.
Interface Definition
void kutacc_af2_triangle_multiplication_last(kutacc_tensor_h out, kutacc_tensor_h gate, int64_t n_res, int64_t n_res_gather, int64_t c_o);
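Conceptually, this step multiplies out (obtained from the out_linear step) element-wise with gate, with both tensors of shape [n_res, n_res_gather, c_o]. The following is a minimal ATen sketch of the equivalent computation for illustration only; the actual operator works on kutacc_tensor_h handles and may fuse this step differently:
#include <ATen/ATen.h>
// Reference semantics sketch (not the KUTACC implementation):
// out[i, j, k] *= gate[i, j, k]
at::Tensor triangle_multiplication_last_reference(at::Tensor out, const at::Tensor &gate)
{
    out.mul_(gate); // element-wise multiplication in place, written back to out
    return out;
}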
Parameters
Table 1 Input parameter definitions
| Parameter | Type | Description | Input/Output |
|---|---|---|---|
| out | kutacc_tensor_h | Output data | Output |
| gate | kutacc_tensor_h | Gating tensor | Input |
| n_res | int64_t | Number of residues | Input |
| n_res_gather | int64_t | Number of residues (gather dimension) | Input |
| c_o | int64_t | Output feature dimension | Input |
The integer parameters of triangle_multiplication must satisfy the following constraints:
- n_res, n_res_gather, c_o > 0
- n_res * n_res_gather < INT64_MAX
- c_z = c_o, c_i = c_o
- In the single-process case, n_res = n_res_gather must hold; in the multi-process case this condition is not required.
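A caller may want to validate these constraints before invoking the operator chain, for example as in the sketch below; check_tm_params is a hypothetical helper, not part of the KUTACC API:
#include <cstdint>
#include <stdexcept>
// Hypothetical parameter check mirroring the constraints listed above.
inline void check_tm_params(int64_t n_res, int64_t n_res_gather, int64_t c_o, bool single_process)
{
    if (n_res <= 0 || n_res_gather <= 0 || c_o <= 0) {
        throw std::invalid_argument("n_res, n_res_gather and c_o must be positive");
    }
    if (n_res > INT64_MAX / n_res_gather) { // guards n_res * n_res_gather < INT64_MAX
        throw std::invalid_argument("n_res * n_res_gather must stay below INT64_MAX");
    }
    if (single_process && n_res != n_res_gather) {
        throw std::invalid_argument("single-process runs require n_res == n_res_gather");
    }
}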
When constructing test cases and using KPEX, the following operator shape constraints must be satisfied.
Table 2 KPEX triangle_multiplication input shape constraints
| tensor | shape | Description |
|---|---|---|
| act | [n_res, n_res_gather, c_z] | KPEX-layer input |
| mask | [n_res, n_res_gather] | KPEX-layer mask input |
| input_ln_w | [c_o] | Weight of the layernorm applied to the input act |
| input_ln_b | [c_o] | Bias of the layernorm applied to the input act |
| left_proj_w | [c_i, c_o] | See parameter proj_w defined in Table 2 of Section 8.1.2.2.2.1.1.2 (kutacc_af2_tm_proj_weights_t data structure) |
| left_proj_b | [c_i] | See parameter proj_b defined in Table 2 of Section 8.1.2.2.2.1.1.2 (kutacc_af2_tm_proj_weights_t data structure) |
| right_proj_w | [c_i, c_o] | See parameter proj_w defined in Table 2 of Section 8.1.2.2.2.1.1.2 (kutacc_af2_tm_proj_weights_t data structure) |
| right_proj_b | [c_i] | See parameter proj_b defined in Table 2 of Section 8.1.2.2.2.1.1.2 (kutacc_af2_tm_proj_weights_t data structure) |
| left_gate_w | [c_i, c_o] | See parameter gate_w defined in Table 2 of Section 8.1.2.2.2.1.1.2 (kutacc_af2_tm_proj_weights_t data structure) |
| left_gate_b | [c_i] | See parameter gate_b defined in Table 2 of Section 8.1.2.2.2.1.1.2 (kutacc_af2_tm_proj_weights_t data structure) |
| right_gate_w | [c_i, c_o] | See parameter gate_w defined in Table 2 of Section 8.1.2.2.2.1.1.2 (kutacc_af2_tm_proj_weights_t data structure) |
| right_gate_b | [c_i] | See parameter gate_b defined in Table 2 of Section 8.1.2.2.2.1.1.2 (kutacc_af2_tm_proj_weights_t data structure) |
| gating_w | [c_i, c_o] | See parameter gating_w defined in Table 3 of the previous section (kutacc_af2_tm_linear_weights_t data structure) |
| gating_b | [c_i] | See parameter gating_b defined in Table 3 of the previous section (kutacc_af2_tm_linear_weights_t data structure) |
| center_ln_w | [c_i] | Weight used in the layernorm normalization of center_act |
| center_ln_b | [c_i] | Bias used in the layernorm normalization of center_act |
| output_proj_w | [c_o, c_i] | See parameter output_proj_w defined in Table 3 of the previous section (kutacc_af2_tm_linear_weights_t data structure) |
| output_proj_b | [c_o] | See parameter output_proj_b defined in Table 3 of the previous section (kutacc_af2_tm_linear_weights_t data structure) |
Example
// test_triangle_multiplication.h
#ifndef KPEX_TPP_ALPHAFOLD_TEST_TMP_H
#define KPEX_TPP_ALPHAFOLD_TEST_TMP_H
#include <ATen/core/Tensor.h>
#include <ATen/ops/empty.h>
#include <ATen/ops/ones.h>
#include <ATen/ops/zeros.h>
#include <ATen/ops/full.h>
#include <ATen/native/cpu/utils.h>
#include <c10/core/ScalarType.h>
namespace alphafold {
at::Tensor test_triangle_multiplication(int64_t n_res, int64_t n_res_gather, int64_t c_o, int64_t c_i);
}
#endif
// bind.h
#include <torch/extension.h>
#include "test_triangle_multiplication.h"
namespace alphafold {
inline void bind(pybind11::module &m)
{
auto submodule = m.def_submodule("alphafold");
submodule.def("test_triangle_multiplication", &test_triangle_multiplication,
py::arg("n_res"), py::arg("n_res_gather"), py::arg("c_o"), py::arg("c_i"));
}
}
# test.py
import torch
from torch import nn
import numpy as np
import torch.distributed as dist
import kpex._C as kernel
import kpex
import os
def test_triangle_multiplication(n_res, n_res_gather, c_o, c_i):
out = kernel.alphafold.test_triangle_multiplication(n_res, n_res_gather, c_o, c_i)
return out
// test_triangle_multiplication.cpp
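// End-to-end example: runs the full triangle_multiplication pipeline (projection,
// triangle equation, gating and output linear) and finishes with
// kutacc_af2_triangle_multiplication_last.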
#include "test_triangle_multiplication.h"
#include "triangle_multiplication.h"
#include "kutacc.h"
#include "utils/memory.h"
#include "utils/layernorm.h"
#include "utils/TensorWrapper.h"
namespace alphafold {
at::Tensor test_triangle_multiplication(int64_t n_res, int64_t n_res_gather, int64_t c_o, int64_t c_i)
{
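    // Constant fill values for the weight and bias tensors constructed below.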
float a = 0.5f;
float b = 0.2f;
float c = 0.4f;
float d = 1.25f;
auto bf16_opt = at::TensorOptions().device(kpex::device()).dtype(at::kBFloat16);
auto float_opt = at::TensorOptions().device(kpex::device()).dtype(at::kFloat);
at::Tensor input_ln_w = at::full({c_o}, a, float_opt);
at::Tensor input_ln_b = at::full({c_o}, b, float_opt);
at::Tensor left_proj_w = linear_weight_prepack(at::ones({c_i, c_o}, bf16_opt));
at::Tensor left_proj_b = at::zeros({c_i}, float_opt);
at::Tensor right_proj_w = linear_weight_prepack(at::full({c_i, c_o}, c, bf16_opt));
at::Tensor right_proj_b = at::ones({c_i}, float_opt);
at::Tensor left_gate_w = linear_weight_prepack(at::ones({c_i, c_o}, bf16_opt));
at::Tensor left_gate_b = at::ones({c_i}, float_opt);
at::Tensor right_gate_w = linear_weight_prepack(at::zeros({c_i, c_o}, bf16_opt));
at::Tensor right_gate_b = at::zeros({c_i}, float_opt);
at::Tensor gating_w = linear_weight_prepack(at::full({c_i, c_o}, d, bf16_opt));
at::Tensor gating_b = at::zeros({c_i}, float_opt);
at::Tensor center_ln_w = at::ones({c_i}, float_opt);
at::Tensor center_ln_b = at::zeros({c_i}, float_opt);
at::Tensor output_proj_w = linear_weight_prepack(at::zeros({c_o, c_i}, bf16_opt));
at::Tensor output_proj_b = at::zeros({c_o}, float_opt);
float e = 1.2f;
at::Tensor act = at::full({n_res, n_res_gather, c_o}, e, bf16_opt);
at::Tensor mask = at::ones({n_res, n_res_gather}, bf16_opt);
at::Tensor out = at::empty(act.sizes(), act.options());
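// Apply the input layernorm to act before the left/right projection branches.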
at::Tensor input_act = layernorm(act, input_ln_w, input_ln_b);
at::Tensor center_act;
at::Tensor left_proj_act = input_act.new_empty({c_i, n_res, n_res_gather});
at::Tensor right_proj_act = input_act.new_empty({c_i, n_res, n_res_gather});
at::Tensor gate = act.new_empty({n_res, n_res_gather, c_o});
auto left_proj_act_tw = convert_to_tensor_wrapper(left_proj_act);
auto right_proj_act_tw = convert_to_tensor_wrapper(right_proj_act);
auto gate_tw = convert_to_tensor_wrapper(gate);
auto input_act_tw = convert_to_tensor_wrapper(input_act);
auto mask_tw = convert_to_tensor_wrapper(mask);
auto left_proj_w_tw = convert_to_tensor_wrapper(left_proj_w);
auto left_proj_b_tw = convert_to_tensor_wrapper(left_proj_b);
auto left_gate_w_tw = convert_to_tensor_wrapper(left_gate_w);
auto left_gate_b_tw = convert_to_tensor_wrapper(left_gate_b);
auto right_proj_w_tw = convert_to_tensor_wrapper(right_proj_w);
auto right_proj_b_tw = convert_to_tensor_wrapper(right_proj_b);
auto right_gate_w_tw = convert_to_tensor_wrapper(right_gate_w);
auto right_gate_b_tw = convert_to_tensor_wrapper(right_gate_b);
at::Tensor gate_left = input_act.new_empty({c_i, n_res, n_res_gather});
auto gate_left_tw = convert_to_tensor_wrapper(gate_left);
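// Left branch: projection and gating of the normalized input using the left
// proj/gate weights (kutacc_af2_tm_proj_weights_t).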
kutacc_af2_tm_act_inputs_t_wrapper *left_acts_ptr = new kutacc_af2_tm_act_inputs_t_wrapper(left_proj_act_tw,
input_act_tw, gate_left_tw, n_res, n_res_gather);
kutacc_af2_tm_proj_weights_t_wrapper *left_weights_ptr = new kutacc_af2_tm_proj_weights_t_wrapper(left_proj_w_tw,
left_proj_b_tw, left_gate_w_tw, left_gate_b_tw, c_o, c_i);
kutacc_af2_triangle_multiplication_calc_proj(left_acts_ptr, mask_tw.get_tensor(), left_weights_ptr, false);
at::Tensor gate_right = input_act.new_empty({c_i, n_res, n_res_gather});
auto gate_right_tw = convert_to_tensor_wrapper(gate_right);
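// Right branch: the same projection and gating, with the right-hand weights.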
kutacc_af2_tm_act_inputs_t_wrapper *right_acts_ptr = new kutacc_af2_tm_act_inputs_t_wrapper(right_proj_act_tw,
input_act_tw, gate_right_tw, n_res, n_res_gather);
kutacc_af2_tm_proj_weights_t_wrapper *right_weights_ptr = new kutacc_af2_tm_proj_weights_t_wrapper(right_proj_w_tw,
right_proj_b_tw, right_gate_w_tw, right_gate_b_tw, c_o, c_i);
kutacc_af2_triangle_multiplication_calc_proj(right_acts_ptr, mask_tw.get_tensor(), right_weights_ptr, false);
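// Triangle multiplication equation: combine the left and right projected
// activations into center_act.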
center_act = act.new_empty({left_proj_act.sizes()[0], n_res_gather, n_res_gather});
auto center_act_tw = convert_to_tensor_wrapper(center_act);
kutacc_af2_triangle_multiplication_equation(center_act_tw.get_tensor(), left_proj_act_tw.get_tensor(),
right_proj_act_tw.get_tensor(), n_res_gather, true);
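// Reorder center_act to [n_res_gather, n_res_gather, c_i] and apply the center layernorm.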
center_act = center_act.permute({1, 2, 0}).contiguous();
center_act = layernorm(center_act, center_ln_w, center_ln_b);
auto center_act_new_tw = convert_to_tensor_wrapper(center_act);
auto out_tw = convert_to_tensor_wrapper(out);
auto gating_w_tw = convert_to_tensor_wrapper(gating_w);
auto gating_b_tw = convert_to_tensor_wrapper(gating_b);
auto output_proj_w_tw = convert_to_tensor_wrapper(output_proj_w);
auto output_proj_b_tw = convert_to_tensor_wrapper(output_proj_b);
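// Gating and output linear projection (kutacc_af2_tm_linear_weights_t) fill gate and out.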
kutacc_af2_tm_linear_weights_t_wrapper *linear_weights_ptr = new kutacc_af2_tm_linear_weights_t_wrapper(gating_w_tw,
gating_b_tw, output_proj_w_tw, output_proj_b_tw, c_o, c_i);
kutacc_af2_triangle_multiplication_gate_and_out_linear(gate_tw.get_tensor(), out_tw.get_tensor(), left_acts_ptr,
center_act_new_tw.get_tensor(), linear_weights_ptr, true);
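// Final step under test: multiply out element-wise by gate.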
kutacc_af2_triangle_multiplication_last(out_tw.get_tensor(), gate_tw.get_tensor(), n_res, n_res_gather, c_o);
delete left_acts_ptr;
delete left_weights_ptr;
delete right_acts_ptr;
delete right_weights_ptr;
delete linear_weights_ptr;
return out;
}
}