
kutacc_af2_outer_product_mean_chunk

The computation function used in outer_product_mean to compute the result in chunks.

Interface Definition

kutacc_export void kutacc_af2_outer_product_mean_chunk(kutacc_af2_opm_act_inputs_t *opm_acts_ptr, kutacc_af2_opm_mask_inputs_t *opm_masks_ptr, kutacc_af2_opm_weights_t *opm_weights_ptr, kutacc_tensor_h out, int64_t left_block_size, int64_t right_block_size);

Parameters

Table 1: Input parameter definitions

| Parameter | Type | Description | Input/Output |
| --- | --- | --- | --- |
| opm_acts_ptr | kutacc_af2_opm_act_inputs_t * | Pointer to a kutacc_af2_opm_act_inputs_t; see Table 2 below for the structure definition | Input |
| opm_masks_ptr | kutacc_af2_opm_mask_inputs_t * | Pointer to a kutacc_af2_opm_mask_inputs_t; see Table 3 below for the structure definition | Input |
| opm_weights_ptr | kutacc_af2_opm_weights_t * | Pointer to a kutacc_af2_opm_weights_t; see Table 4 below for the structure definition | Input |
| out | kutacc_tensor_h | Output data | Output |
| left_block_size | int64_t | Left block size | Input |
| right_block_size | int64_t | Right block size | Input |

Table 2: kutacc_af2_opm_act_inputs_t structure definition

| Parameter | Type | Description | Input/Output |
| --- | --- | --- | --- |
| n_seq | int64_t | Number of sequences | Input |
| n_res | int64_t | Number of residues | Input |
| input_act | kutacc_tensor_h | Input activation tensor | Input |
| left_proj | kutacc_tensor_h | Left projection | Input |
| right_proj | kutacc_tensor_h | Right projection | Input |
| left_proj_ | kutacc_tensor_h | Left projection after masking | Input |
| right_proj_ | kutacc_tensor_h | Right projection after masking | Input |

Table 3: kutacc_af2_opm_mask_inputs_t structure definition

| Parameter | Type | Description | Input/Output |
| --- | --- | --- | --- |
| n_res_gather | int64_t | Number of residues after gathering | Input |
| mask_bias | int64_t | Address offset into the mask tensor | Input |
| mask | kutacc_tensor_h | Mask tensor | Input |
| norm | kutacc_tensor_h | Normalization factor tensor | Input |

Table 4: kutacc_af2_opm_weights_t structure definition

| Parameter | Type | Description | Input/Output |
| --- | --- | --- | --- |
| c_m | int64_t | Input feature dimension | Input |
| c_i | int64_t | Feature dimension after projection | Input |
| c_z | int64_t | Output feature dimension | Input |
| left_proj_w | kutacc_tensor_h | Left projection weight | Input |
| left_proj_b | kutacc_tensor_h | Left projection bias | Input |
| right_proj_w | kutacc_tensor_h | Right projection weight | Input |
| right_proj_b | kutacc_tensor_h | Right projection bias | Input |
| outer_w | kutacc_tensor_h | Output weight | Input |
| outer_b | kutacc_tensor_h | Output bias | Input |

The integer parameters of outer_product_mean must satisfy the following constraints:

n_res, n_res_gather, c_i, c_z, left_block_size, right_block_size > 0;

n_seq * n_res < INT64_MAX;

left_block_size * right_block_size * c_i * c_i < INT64_MAX;

left_block_size * right_block_size * c_z < INT64_MAX;

In single-process mode, n_res = n_res_gather; in multi-process mode this equality does not hold.
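
As a minimal sketch (not part of the KUTACC API; opm_params_valid is a hypothetical helper), the constraints above can be checked without overflowing int64_t by rewriting each product bound a * b < INT64_MAX as a <= (INT64_MAX - 1) / b:

// Hypothetical helper: overflow-safe validation of the constraints above.
#include <cstdint>
#include <limits>

static bool opm_params_valid(int64_t n_seq, int64_t n_res, int64_t n_res_gather,
                             int64_t c_i, int64_t c_z,
                             int64_t left_block_size, int64_t right_block_size)
{
    // Positivity (n_seq is assumed positive as well, so the divisions below are safe).
    if (n_seq <= 0 || n_res <= 0 || n_res_gather <= 0 || c_i <= 0 || c_z <= 0 ||
        left_block_size <= 0 || right_block_size <= 0) {
        return false;
    }
    const int64_t kMax = std::numeric_limits<int64_t>::max();
    if (n_seq > (kMax - 1) / n_res) {
        return false; // n_seq * n_res < INT64_MAX
    }
    if (left_block_size > (kMax - 1) / right_block_size) {
        return false; // left_block_size * right_block_size must not overflow
    }
    const int64_t lr = left_block_size * right_block_size;
    if (lr > (kMax - 1) / c_i || lr * c_i > (kMax - 1) / c_i) {
        return false; // left_block_size * right_block_size * c_i * c_i < INT64_MAX
    }
    if (lr > (kMax - 1) / c_z) {
        return false; // left_block_size * right_block_size * c_z < INT64_MAX
    }
    return true;
}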

When constructing test cases and when using KPEX, the following operator shape constraints must be satisfied.

Table 5: KPEX outer_product_mean input shape constraints

| Tensor/Param | Shape/Value | Description |
| --- | --- | --- |
| input_ln_w | [c_m] | Weight of the layernorm that produces input_act |
| input_ln_b | [c_m] | Bias of the layernorm that produces input_act |
| left_proj_w | [c_i, c_m] | See left_proj_w in Table 4 |
| left_proj_b | [c_i] | See left_proj_b in Table 4 |
| right_proj_w | [c_i, c_m] | See right_proj_w in Table 4 |
| right_proj_b | [c_i] | See right_proj_b in Table 4 |
| output_w | [c_z, c_i, c_i] | See outer_w in Table 4 |
| output_b | [c_z] | See outer_b in Table 4 |
| act | [n_seq, n_res, c_m] | KPEX input; input_act is produced from it via layernorm |
| mask | [n_seq, n_res_gather] | See mask in Table 3 |
| left_block_size | Greater than 0, or None | See left_block_size in Table 1 |
| right_block_size | Greater than 0, or None | See right_block_size in Table 1 |
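
To make the shape table concrete, here is a minimal sketch (illustration only; check_opm_shapes is a hypothetical helper, and the weights are assumed to be in their Table 5 layout, i.e. before any prepacking) that asserts the constraints with ATen:

// Hypothetical helper: assert the Table 5 shapes before invoking the operator.
#include <ATen/core/Tensor.h>
#include <c10/util/Exception.h>

inline void check_opm_shapes(const at::Tensor &act, const at::Tensor &mask,
                             const at::Tensor &left_proj_w, const at::Tensor &left_proj_b,
                             const at::Tensor &output_w, const at::Tensor &output_b,
                             int64_t c_m, int64_t c_i, int64_t c_z,
                             int64_t n_seq, int64_t n_res, int64_t n_res_gather)
{
    TORCH_CHECK(act.sizes() == at::IntArrayRef({n_seq, n_res, c_m}), "act must be [n_seq, n_res, c_m]");
    TORCH_CHECK(mask.sizes() == at::IntArrayRef({n_seq, n_res_gather}), "mask must be [n_seq, n_res_gather]");
    TORCH_CHECK(left_proj_w.sizes() == at::IntArrayRef({c_i, c_m}), "left_proj_w must be [c_i, c_m]");
    TORCH_CHECK(left_proj_b.sizes() == at::IntArrayRef({c_i}), "left_proj_b must be [c_i]");
    TORCH_CHECK(output_w.sizes() == at::IntArrayRef({c_z, c_i, c_i}), "output_w must be [c_z, c_i, c_i]");
    TORCH_CHECK(output_b.sizes() == at::IntArrayRef({c_z}), "output_b must be [c_z]");
}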

Example

C++ interface:

// test_outer_product_mean.h
#ifndef KPEX_TPP_ALPHAFOLD_TEST_OPM_H
#define KPEX_TPP_ALPHAFOLD_TEST_OPM_H
#include <ATen/core/Tensor.h>
#include <ATen/ops/empty.h>
#include <ATen/ops/ones.h>
#include <ATen/ops/zeros.h>
#include <ATen/ops/full.h>
#include <ATen/native/cpu/utils.h>
#include <c10/core/ScalarType.h>
namespace alphafold {
at::Tensor test_outer_product_mean(int64_t c_i, int64_t c_m, int64_t c_z, int64_t n_seq, int64_t n_res, int64_t n_res_gather);
}
#endif

// bind.h
#include <torch/extension.h>
#include "test_outer_product_mean.h"
namespace alphafold {
inline void bind(pybind11::module &m)
{
    auto submodule = m.def_submodule("alphafold");
    submodule.def("test_outer_product_mean", &test_outer_product_mean, py::arg("c_i"), py::arg("c_m"), py::arg("c_z"), py::arg("n_seq"), py::arg("n_res"), py::arg("n_res_gather"));
}
}

# test.py
import copy
import time
import types
import torch
from torch import nn
import numpy as np
import torch.distributed as dist
import kpex._C as kernel
import kpex
import os

def test_outer_product_mean(c_i, c_m, c_z, n_seq, n_res, n_res_gather):
    out = kernel.alphafold.test_outer_product_mean(c_i, c_m, c_z, n_seq, n_res, n_res_gather)
    return out

// test_outer_product_mean.cpp
#include "test_outer_product_mean.h"
#include "kutacc.h"
#include "outer_product_mean.h"
#include "utils/memory.h"
#include "utils/layernorm.h"
#include <utils/TensorWrapper.h>
namespace alphafold {
at::Tensor test_outer_product_mean(int64_t c_i, int64_t c_m, int64_t c_z, int64_t n_seq, int64_t n_res, int64_t n_res_gather)
{
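    // Arbitrary fill values used to build deterministic synthetic test inputs.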
    float a = 0.2f;
    float b = 0.5f;
    float c = 1.5f;
    float d = 2.0f;
    at::Tensor act = at::full({n_seq, n_res, c_m}, d, at::TensorOptions().device(kpex::device()).dtype(at::kBFloat16));
    at::Tensor mask = at::ones({n_res_gather, n_seq}, at::TensorOptions().device(kpex::device()).dtype(at::kBFloat16));
    at::Tensor left_proj = act.new_empty({c_i, n_res, n_seq});
    at::Tensor right_proj = act.new_empty({c_i, n_res, n_seq});
    at::Tensor left_proj_ = act.new_empty({n_res, c_i, n_seq});
    at::Tensor right_proj_ = act.new_empty({n_res, c_i, n_seq});
    at::Tensor norm = mask.new_empty({n_res, n_res_gather});
    int64_t mask_bias = 0;
    at::Tensor input_ln_w = at::full({c_m}, a, at::TensorOptions().device(kpex::device()).dtype(at::kFloat));
    at::Tensor input_ln_b = at::full({c_m}, b, at::TensorOptions().device(kpex::device()).dtype(at::kFloat));
    at::Tensor left_proj_w = linear_weight_prepack(at::full({c_i, c_m}, c, at::TensorOptions().device(kpex::device()).dtype(at::kBFloat16)));
    at::Tensor left_proj_b = at::zeros({c_i}, at::TensorOptions().device(kpex::device()).dtype(at::kFloat));
    at::Tensor right_proj_w = linear_weight_prepack(at::ones({c_i, c_m}, at::TensorOptions().device(kpex::device()).dtype(at::kBFloat16)));
    at::Tensor right_proj_b = at::zeros({c_i}, at::TensorOptions().device(kpex::device()).dtype(at::kFloat));
    at::Tensor output_w = linear_weight_prepack(at::ones({c_z, c_i * c_i}, at::TensorOptions().device(kpex::device()).dtype(at::kBFloat16)));
    at::Tensor output_b = at::zeros({c_z}, at::TensorOptions().device(kpex::device()).dtype(at::kFloat));
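    // Produce input_act from act via layernorm (see Table 5); transpose(0, 1) moves n_res to the leading dimension.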
    at::Tensor input_act = layernorm(act.transpose(0, 1), input_ln_w, input_ln_b);
    at::Tensor out = act.new_empty({n_res, n_res_gather, c_z});
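    // Wrap the ATen tensors as KUTACC tensor handles for the C interface.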
    kutacc::TensorWrapper input_act_tw = convert_to_tensor_wrapper(input_act);
    kutacc::TensorWrapper mask_tw = convert_to_tensor_wrapper(mask);
    kutacc::TensorWrapper left_proj_w_tw = convert_to_tensor_wrapper(left_proj_w);
    kutacc::TensorWrapper left_proj_b_tw = convert_to_tensor_wrapper(left_proj_b);
    kutacc::TensorWrapper right_proj_w_tw = convert_to_tensor_wrapper(right_proj_w);
    kutacc::TensorWrapper right_proj_b_tw = convert_to_tensor_wrapper(right_proj_b);
    kutacc::TensorWrapper left_proj_tw = convert_to_tensor_wrapper(left_proj);
    kutacc::TensorWrapper right_proj_tw = convert_to_tensor_wrapper(right_proj);
    kutacc::TensorWrapper left_proj_tw_ = convert_to_tensor_wrapper(left_proj_);
    kutacc::TensorWrapper right_proj_tw_ = convert_to_tensor_wrapper(right_proj_);
    kutacc::TensorWrapper norm_tw = convert_to_tensor_wrapper(norm);
    kutacc::TensorWrapper output_w_tw = convert_to_tensor_wrapper(output_w);
    kutacc::TensorWrapper output_b_tw = convert_to_tensor_wrapper(output_b);
    kutacc::TensorWrapper out_tw = convert_to_tensor_wrapper(out);
    int64_t left_block_size = 1024;
    int64_t right_block_size = 1024;
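    // Package the weights, activations, and mask inputs into the structs the kernel expects.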
    kutacc_af2_opm_weights_t_wrapper *opm_weights_ptr = new kutacc_af2_opm_weights_t_wrapper(left_proj_w_tw, left_proj_b_tw, right_proj_w_tw, right_proj_b_tw,
        output_w_tw, output_b_tw, c_m, c_i, c_z);
    kutacc_af2_opm_act_inputs_t_wrapper *opm_inputs_ptr = new kutacc_af2_opm_act_inputs_t_wrapper(input_act_tw, left_proj_tw, right_proj_tw, left_proj_tw_,
        right_proj_tw_, n_seq, n_res);
    kutacc_af2_opm_mask_inputs_t_wrapper *opm_mask_ptr = new kutacc_af2_opm_mask_inputs_t_wrapper(mask_tw, norm_tw, n_res_gather, mask_bias);
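    // Compute the masked left/right projections, then the chunked outer-product mean.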
    kutacc_af2_outer_product_mean_calc_left_and_right_mul(opm_inputs_ptr, opm_mask_ptr, opm_weights_ptr);
    kutacc_af2_outer_product_mean_chunk(opm_inputs_ptr, opm_mask_ptr, opm_weights_ptr, out_tw.get_tensor(), left_block_size, right_block_size);
    delete opm_inputs_ptr;
    delete opm_weights_ptr;
    delete opm_mask_ptr;
    return out;
}
}