kutacc_af2_outer_product_mean_chunk
outer_product_mean中用于分块计算结果的计算函数。
接口定义
kutacc_export void kutacc_af2_outer_product_mean_chunk(kutacc_af2_opm_act_inputs_t *opm_acts_ptr, kutacc_af2_opm_mask_inputs_t *opm_masks_ptr, kutacc_af2_opm_weights_t *opm_weights_ptr, kutacc_tensor_h out, int64_t left_block_size, int64_t right_block_size);
参数
表1 入参定义
参数名 |
类型 |
描述 |
输入/输出 |
|---|---|---|---|
opm_acts_ptr |
kutacc_af2_opm_act_inputs_t * |
kutacc_af2_opm_act_inputs_t类型的指针,具体数据结构定义见下方表2 kutacc_af2_opm_act_inputs_t 数据结构表定义 |
输入 |
opm_masks_ptr |
kutacc_af2_opm_mask_inputs_t * |
kutacc_af2_opm_mask_inputs_t类型的指针,具体数据结构定义见下方表3 kutacc_af2_opm_mask_inputs_t 数据结构表定义 |
输入 |
opm_weights_ptr |
kutacc_af2_opm_weights_t * |
kutacc_af2_opm_weights_t类型的指针,具体数据结构定义见下方表4 kutacc_af2_opm_weights_t数据结构表定义 |
输入 |
out |
kutacc_tensor_h |
输出数据 |
输出 |
left_block_size |
int64_t |
左分块大小 |
输入 |
right_block_size |
int64_t |
右分块大小 |
输入 |
表2 kutacc_af2_opm_act_inputs_t 数据结构表定义
参数名 |
类型 |
描述 |
输入/输出 |
|---|---|---|---|
n_seq |
int64_t |
序列数量 |
输入 |
n_res |
int64_t |
残基数量 |
输入 |
input_act |
kutacc_tensor_h |
输入激活张量 |
输入 |
left_proj |
kutacc_tensor_h |
左投影 |
输入 |
right_proj |
kutacc_tensor_h |
右投影 |
输入 |
left_proj_ |
kutacc_tensor_h |
经过掩码处理后的左投影 |
输入 |
right_proj_ |
kutacc_tensor_h |
经过掩码处理后的右投影 |
输入 |
表3 kutacc_af2_opm_mask_inputs_t 数据结构表定义
参数名 |
类型 |
描述 |
输入/输出 |
|---|---|---|---|
n_res_gather |
int64_t |
聚合后的残基数量 |
输入 |
mask_bias |
int64_t |
掩码张量地址偏移量 |
输入 |
mask |
kutacc_tensor_h |
掩码张量 |
输入 |
norm |
kutacc_tensor_h |
归一化因子张量 |
输入 |
表4 kutacc_af2_opm_weights_t 数据结构表定义
参数名 |
类型 |
描述 |
输入/输出 |
|---|---|---|---|
c_m |
int64_t |
输入特征维度 |
输入 |
c_i |
int64_t |
投影后的特征维度 |
输入 |
c_z |
int64_t |
输出特征维度 |
输入 |
left_proj_w |
kutacc_tensor_h |
左投影权重 |
输入 |
left_proj_b |
kutacc_tensor_h |
左投影偏移量 |
输入 |
right_proj_w |
kutacc_tensor_h |
右投影权重 |
输入 |
right_proj_b |
kutacc_tensor_h |
右投影偏移量 |
输入 |
outer_w |
kutacc_tensor_h |
输出权重 |
输入 |
outer_b |
kutacc_tensor_h |
输出偏移量 |
输入 |
outer_product_mean整数参数应满足的约束关系:
n_res, n_res_gather, c_i, c_z, left_block_size, right_block_size > 0;
n_seq * n_res < INT64_MAX,
left_block_size * right_block_size * c_i * c_i < INT64_MAX,
left_block_size * right_block_size * c_z < INT64_MAX,
单进程时 n_res = n_res_gather;多进程下不满足该条件。
在构建用例及使用KPEX时应满足以下算子形状约束
表5 KPEX outer_product_mean入参形状约束
tensor/param |
shape/value |
描述 |
|---|---|---|
input_ln_w |
[c_m] |
通过layernorm生成input_act所需的权重参数 |
input_ln_b |
[c_m] |
通过layernorm生成input_act所需的偏置参数 |
left_proj_w |
[c_i, c_m] |
见表4参数 left_proj_w |
left_proj_b |
[c_i] |
见表4参数 left_proj_b |
right_proj_w |
[c_i, c_m] |
见表4参数 right_proj_w |
right_proj_b |
[c_i] |
见表4参数 right_proj_b |
output_w |
[c_z, c_i, c_i] |
见表4参数 outer_w |
output_b |
[c_z] |
见表4参数 outer_b |
act |
[n_seq, n_res, c_m] |
KPEX输入,经过layernorm生成input_act |
mask |
[n_seq, n_res_gather] |
见表3参数 mask |
left_block_size |
greater than 0 or None |
见表1参数 left_block_size |
right_block_size |
greater than 0 or None |
见表1参数 right_block_size |
示例
C++ interface:
// test_outer_product_mean.h
#ifndef KPEX_TPP_ALPHAFOLD_TEST_OPM_H
#define KPEX_TPP_ALPHAFOLD_TEST_OPM_H
#include <ATen/core/Tensor.h>
#include <ATen/ops/empty.h>
#include <ATen/ops/ones.h>
#include <ATen/ops/zeros.h>
#include <ATen/ops/full.h>
#include <ATen/native/cpu/utils.h>
#include <c10/core/ScalarType.h>
namespace alphafold {
/// @brief Builds synthetic inputs and runs the outer_product_mean chunked kernel once.
/// @param c_i projected feature dimension
/// @param c_m input feature dimension
/// @param c_z output feature dimension
/// @param n_seq number of sequences
/// @param n_res number of residues
/// @param n_res_gather number of residues after gather (equals n_res in single-process runs)
/// @return output tensor; shape [n_res, n_res_gather, c_z] per the accompanying .cpp example
at::Tensor test_outer_product_mean(int64_t c_i, int64_t c_m, int64_t c_z, int64_t n_seq, int64_t n_res, int64_t n_res_gather);
}
#endif
// bind.h
#include <torch/extension.h>
#include "test_outer_product_mean.h"
namespace alphafold {
/// @brief Registers the alphafold test bindings on the given pybind11 module.
/// @param m parent extension module; an "alphafold" submodule is created on it.
inline void bind(pybind11::module &m)
{
    // Fix: the original read "autosubmodule = m.def_submodule(...)" (missing
    // space), which declares no variable and does not compile; the next line
    // then referenced an undeclared name "submodule".
    auto submodule = m.def_submodule("alphafold");
    submodule.def("test_outer_product_mean", &test_outer_product_mean, py::arg("c_i"), py::arg("c_m"), py::arg("c_z"), py::arg("n_seq"), py::arg("n_res"), py::arg("n_res_gather"));
}
}
// test.py
import copy
import time
import types
import torch
from torch import nn
import numpy as np
import torch.distributed as dist
import kpex._C as kernel
import kpex
import os
def test_outer_product_mean(c_i, c_m, c_z, n_seq, n_res, n_res_gather):
    """Invoke the C++ outer_product_mean smoke test exposed via kpex._C.

    Fix: the original example defined and called
    ``kernel.alphafold.test_triangle_multiplication`` (copy-pasted from another
    operator's documentation); the binding registered in bind.h of this very
    document is ``test_outer_product_mean`` with the six parameters below.
    """
    out = kernel.alphafold.test_outer_product_mean(c_i, c_m, c_z, n_seq, n_res, n_res_gather)
    return out
// test_outer_product_mean.cpp
#include "test_outer_product_mean.h"
#include <memory>
#include "kutacc.h"
#include "outer_product_mean.h"
#include "utils/memory.h"
#include "utils/layernorm.h"
#include <utils/TensorWrapper.h>
namespace alphafold {
/// @brief Builds synthetic inputs for the outer_product_mean operator and runs
///        the chunked kernel once.
/// @param c_i projected feature dimension
/// @param c_m input feature dimension
/// @param c_z output feature dimension
/// @param n_seq number of sequences
/// @param n_res number of residues
/// @param n_res_gather number of residues after gather
/// @return output tensor of shape [n_res, n_res_gather, c_z]
at::Tensor test_outer_product_mean(int64_t c_i, int64_t c_m, int64_t c_z, int64_t n_seq, int64_t n_res, int64_t n_res_gather)
{
    // Constant fill values so the run is deterministic.
    float a = 0.2f;
    float b = 0.5f;
    float c = 1.5f;
    float d = 2.0f;
    at::Tensor act = at::full({n_seq, n_res, c_m}, d, at::TensorOptions().device(kpex::device()).dtype(at::kBFloat16));
    // NOTE(review): table 5 documents mask as [n_seq, n_res_gather]; this
    // example allocates {n_res_gather, n_seq} — confirm the expected layout.
    at::Tensor mask = at::ones({n_res_gather, n_seq}, at::TensorOptions().device(kpex::device()).dtype(at::kBFloat16));
    // Scratch buffers the kernel fills with raw and masked projections.
    at::Tensor left_proj = act.new_empty({c_i, n_res, n_seq});
    at::Tensor right_proj = act.new_empty({c_i, n_res, n_seq});
    at::Tensor left_proj_ = act.new_empty({n_res, c_i, n_seq});
    at::Tensor right_proj_ = act.new_empty({n_res, c_i, n_seq});
    at::Tensor norm = mask.new_empty({n_res, n_res_gather});
    int64_t mask_bias = 0;
    // LayerNorm parameters and prepacked linear weights (see table 5 shapes).
    at::Tensor input_ln_w = at::full({c_m}, a, at::TensorOptions().device(kpex::device()).dtype(at::kFloat));
    at::Tensor input_ln_b = at::full({c_m}, b, at::TensorOptions().device(kpex::device()).dtype(at::kFloat));
    at::Tensor left_proj_w = linear_weight_prepack(at::full({c_i, c_m}, c, at::TensorOptions().device(kpex::device()).dtype(at::kBFloat16)));
    at::Tensor left_proj_b = at::zeros({c_i}, at::TensorOptions().device(kpex::device()).dtype(at::kFloat));
    at::Tensor right_proj_w = linear_weight_prepack(at::ones({c_i, c_m}, at::TensorOptions().device(kpex::device()).dtype(at::kBFloat16)));
    at::Tensor right_proj_b = at::zeros({c_i}, at::TensorOptions().device(kpex::device()).dtype(at::kFloat));
    at::Tensor output_w = linear_weight_prepack(at::ones({c_z, c_i * c_i}, at::TensorOptions().device(kpex::device()).dtype(at::kBFloat16)));
    at::Tensor output_b = at::zeros({c_z}, at::TensorOptions().device(kpex::device()).dtype(at::kFloat));
    // input_act is generated from act via layernorm, matching the KPEX flow.
    at::Tensor input_act = layernorm(act.transpose(0, 1), input_ln_w, input_ln_b);
    at::Tensor out = act.new_empty({n_res, n_res_gather, c_z});
    // Wrap every tensor for the kutacc C interface.
    kutacc::TensorWrapper input_act_tw = convert_to_tensor_wrapper(input_act);
    kutacc::TensorWrapper mask_tw = convert_to_tensor_wrapper(mask);
    kutacc::TensorWrapper left_proj_w_tw = convert_to_tensor_wrapper(left_proj_w);
    kutacc::TensorWrapper left_proj_b_tw = convert_to_tensor_wrapper(left_proj_b);
    kutacc::TensorWrapper right_proj_w_tw = convert_to_tensor_wrapper(right_proj_w);
    kutacc::TensorWrapper right_proj_b_tw = convert_to_tensor_wrapper(right_proj_b);
    kutacc::TensorWrapper left_proj_tw = convert_to_tensor_wrapper(left_proj);
    kutacc::TensorWrapper right_proj_tw = convert_to_tensor_wrapper(right_proj);
    kutacc::TensorWrapper left_proj_tw_ = convert_to_tensor_wrapper(left_proj_);
    kutacc::TensorWrapper right_proj_tw_ = convert_to_tensor_wrapper(right_proj_);
    kutacc::TensorWrapper norm_tw = convert_to_tensor_wrapper(norm);
    kutacc::TensorWrapper output_w_tw = convert_to_tensor_wrapper(output_w);
    kutacc::TensorWrapper output_b_tw = convert_to_tensor_wrapper(output_b);
    kutacc::TensorWrapper out_tw = convert_to_tensor_wrapper(out);
    int64_t left_block_size = 1024;
    int64_t right_block_size = 1024;
    // Fix: the original allocated the three wrapper objects with raw new and
    // freed them with delete after the kernel calls, leaking all three if
    // either kutacc call throws. unique_ptr releases them on every exit path.
    auto opm_weights_ptr = std::make_unique<kutacc_af2_opm_weights_t_wrapper>(left_proj_w_tw, left_proj_b_tw, right_proj_w_tw, right_proj_b_tw,
        output_w_tw, output_b_tw, c_m, c_i, c_z);
    auto opm_inputs_ptr = std::make_unique<kutacc_af2_opm_act_inputs_t_wrapper>(input_act_tw, left_proj_tw, right_proj_tw, left_proj_tw_,
        right_proj_tw_, n_seq, n_res);
    auto opm_mask_ptr = std::make_unique<kutacc_af2_opm_mask_inputs_t_wrapper>(mask_tw, norm_tw, n_res_gather, mask_bias);
    // First compute the (masked) left/right projections, then the chunked result.
    kutacc_af2_outer_product_mean_calc_left_and_right_mul(opm_inputs_ptr.get(), opm_mask_ptr.get(), opm_weights_ptr.get());
    kutacc_af2_outer_product_mean_chunk(opm_inputs_ptr.get(), opm_mask_ptr.get(), opm_weights_ptr.get(), out_tw.get_tensor(), left_block_size, right_block_size);
    return out;
}
}