鲲鹏社区首页
中文
注册
开发者
我要评分
获取效率
正确性
完整性
易理解
在线提单
论坛求助

kutacc_af2_triangle_multiplication_last

triangle_multiplication的最后步骤,基于out_linear步骤获得的out与gate相乘生成输出张量。

接口定义

void kutacc_af2_triangle_multiplication_last(kutacc_tensor_h out, kutacc_tensor_h gate, int64_t n_res, int64_t n_res_gather, int64_t c_o);

参数

表1 入参定义

参数名

类型

描述

输入/输出

out

kutacc_tensor_h

输入时为out_linear步骤得到的out,与gate相乘后的结果写回该张量

输入/输出

gate

kutacc_tensor_h

门控张量

输入

n_res

int64_t

残基数

输入

n_res_gather

int64_t

gather维度上的残基数(单进程情况下须与n_res相等,见下方约束关系)

输入

c_o

int64_t

输出特征维度

输入

triangle_multiplication的整数参数应满足的约束关系:

n_res, n_res_gather, c_o > 0,

n_res * n_res_gather < INT64_MAX,

c_z = c_o, c_i = c_o,

单进程情况下需满足n_res = n_res_gather;多进程情况下不需要满足该条件。

在构建用例及使用KPEX时应满足以下算子形状约束

表2 kpex triangle_multiplication入参形状约束

tensor

shape

描述

act

[n_res, n_res_gather, c_z]

KPEX层输入

mask

[n_res, n_res_gather]

KPEX层掩码输入

input_ln_w

[c_o]

对输入参数act进行layernorm变换时使用的权重

input_ln_b

[c_o]

对输入参数act进行layernorm变换时使用的偏置

left_proj_w

[c_i, c_o]

见8.1.2.2.2.1.1.2 表2kutacc_af2_tm_proj_weights_t数据结构表定义参数proj_w

left_proj_b

[c_i]

见8.1.2.2.2.1.1.2 表2kutacc_af2_tm_proj_weights_t数据结构表定义参数proj_b

right_proj_w

[c_i, c_o]

见8.1.2.2.2.1.1.2 表2kutacc_af2_tm_proj_weights_t数据结构表定义参数proj_w

right_proj_b

[c_i]

见8.1.2.2.2.1.1.2 表2kutacc_af2_tm_proj_weights_t数据结构表定义参数proj_b

left_gate_w

[c_i, c_o]

见8.1.2.2.2.1.1.2 表2kutacc_af2_tm_proj_weights_t数据结构表定义参数gate_w

left_gate_b

[c_i]

见8.1.2.2.2.1.1.2 表2kutacc_af2_tm_proj_weights_t数据结构表定义参数gate_b

right_gate_w

[c_i, c_o]

见8.1.2.2.2.1.1.2 表2kutacc_af2_tm_proj_weights_t数据结构表定义参数gate_w

right_gate_b

[c_i]

见8.1.2.2.2.1.1.2 表2kutacc_af2_tm_proj_weights_t数据结构表定义参数gate_b

gating_w

[c_i, c_o]

见上一节表3 kutacc_af2_tm_linear_weights_t数据结构表定义参数 gating_w

gating_b

[c_i]

见上一节表3 kutacc_af2_tm_linear_weights_t数据结构表定义参数 gating_b

center_ln_w

[c_i]

center_act layernorm归一化过程中的权重

center_ln_b

[c_i]

center_act layernorm归一化过程中的偏置

output_proj_w

[c_o, c_i]

见上一节表3 kutacc_af2_tm_linear_weights_t数据结构表定义参数 output_proj_w

output_proj_b

[c_o]

见上一节表3 kutacc_af2_tm_linear_weights_t数据结构表定义参数 output_proj_b

示例

C++ interface:
// test_triangle_multiplication.h
// Include guard: the opening #ifndef is required so that the closing #endif
// is matched and repeated inclusion of this header is a no-op.
#ifndef KPEX_TPP_ALPHAFOLD_TEST_TMP_H
#define KPEX_TPP_ALPHAFOLD_TEST_TMP_H
#include <ATen/core/Tensor.h>
#include <ATen/ops/empty.h>
#include <ATen/ops/ones.h>
#include <ATen/ops/zeros.h>
#include <ATen/ops/full.h>
#include <ATen/native/cpu/utils.h>
#include <c10/core/ScalarType.h>
namespace alphafold {
/// Example driver for the triangle_multiplication kernel sequence.
/// @param n_res        first dimension of the act/mask tensors built by the example
/// @param n_res_gather second dimension of the act/mask tensors built by the example
/// @param c_o          output feature dimension
/// @param c_i          intermediate feature dimension
/// @return the `out` tensor after the final kernel step has run
at::Tensor test_triangle_multiplication(int64_t n_res, int64_t n_res_gather, int64_t c_o, int64_t c_i);
}
#endif // KPEX_TPP_ALPHAFOLD_TEST_TMP_H

// bind.h
#include <torch/extension.h>
#include "test_triangle_multiplication.h"
namespace alphafold {
/**
 * Registers the example entry point with Python.
 *
 * Creates an "alphafold" submodule on @p m and exposes
 * test_triangle_multiplication on it with four keyword-capable arguments.
 *
 * @param m parent pybind11 module that receives the submodule
 */
inline void bind(pybind11::module &m)
{
    auto af_module = m.def_submodule("alphafold");
    af_module.def(
        "test_triangle_multiplication",
        &test_triangle_multiplication,
        py::arg("n_res"),
        py::arg("n_res_gather"),
        py::arg("c_o"),
        py::arg("c_i"));
}
}

// test.py
import torch
from torch import nn
import numpy as np
import torch.distributed as dist
import kpex._C as kernel
import kpex
import os
def test_triangle_multiplication(n_res, n_res_gather, c_o, c_i):
    """Run the native triangle_multiplication example and return its result.

    The four arguments are forwarded unchanged, in order, to
    ``kernel.alphafold.test_triangle_multiplication``.
    """
    return kernel.alphafold.test_triangle_multiplication(n_res, n_res_gather, c_o, c_i)

// test_triangle_multiplication.cpp
#include "test_triangle_multiplication.h"
#include "triangle_multiplication.h"
#include "kutacc.h"
#include "utils/memory.h"
#include "utils/layernorm.h"
#include "utils/TensorWrapper.h"
namespace alphafold {
/**
 * Example driver that runs the triangle_multiplication kernel sequence on
 * constant-filled tensors and returns the final output tensor.
 *
 * Call order: layernorm on act -> calc_proj (left) -> calc_proj (right) ->
 * equation -> permute + layernorm on center_act -> gate_and_out_linear -> last.
 *
 * The kernel argument wrappers are held in automatic storage (instead of raw
 * new/delete) so they are released on every exit path, including when one of
 * the tensor or kernel calls throws.
 *
 * @param n_res        first dimension of the act/mask tensors
 * @param n_res_gather second dimension of the act/mask tensors
 * @param c_o          output feature dimension (last dimension of act/out/gate)
 * @param c_i          intermediate feature dimension of the projections
 * @return the `out` tensor, shaped like act, after kutacc_af2_triangle_multiplication_last
 */
at::Tensor test_triangle_multiplication(int64_t n_res, int64_t n_res_gather, int64_t c_o, int64_t c_i)
{
    // Constant fill values used for the example weights.
    float a = 0.5f;
    float b = 0.2f;
    float c = 0.4f;
    float d = 1.25f;
    auto bf16_opt = at::TensorOptions().device(kpex::device()).dtype(at::kBFloat16);
    auto float_opt = at::TensorOptions().device(kpex::device()).dtype(at::kFloat);

    // Weights and biases; the matrix weights go through linear_weight_prepack.
    at::Tensor input_ln_w = at::full({c_o}, a, float_opt);
    at::Tensor input_ln_b = at::full({c_o}, b, float_opt);
    at::Tensor left_proj_w = linear_weight_prepack(at::ones({c_i, c_o}, bf16_opt));
    at::Tensor left_proj_b = at::zeros({c_i}, float_opt);
    at::Tensor right_proj_w = linear_weight_prepack(at::full({c_i, c_o}, c, bf16_opt));
    at::Tensor right_proj_b = at::ones({c_i}, float_opt);
    at::Tensor left_gate_w = linear_weight_prepack(at::ones({c_i, c_o}, bf16_opt));
    at::Tensor left_gate_b = at::ones({c_i}, float_opt);
    at::Tensor right_gate_w = linear_weight_prepack(at::zeros({c_i, c_o}, bf16_opt));
    at::Tensor right_gate_b = at::zeros({c_i}, float_opt);
    at::Tensor gating_w = linear_weight_prepack(at::full({c_i, c_o}, d, bf16_opt));
    at::Tensor gating_b = at::zeros({c_i}, float_opt);
    at::Tensor center_ln_w = at::ones({c_i}, float_opt);
    at::Tensor center_ln_b = at::zeros({c_i}, float_opt);
    at::Tensor output_proj_w = linear_weight_prepack(at::zeros({c_o, c_i}, bf16_opt));
    at::Tensor output_proj_b = at::zeros({c_o}, float_opt);

    // Inputs and output/intermediate buffers.
    float e = 1.2f;
    at::Tensor act = at::full({n_res, n_res_gather, c_o}, e, bf16_opt);
    at::Tensor mask = at::ones({n_res, n_res_gather}, bf16_opt);
    at::Tensor out = at::empty(act.sizes(), act.options());
    at::Tensor input_act = layernorm(act, input_ln_w, input_ln_b);
    at::Tensor left_proj_act = input_act.new_empty({c_i, n_res, n_res_gather});
    at::Tensor right_proj_act = input_act.new_empty({c_i, n_res, n_res_gather});
    at::Tensor gate = act.new_empty({n_res, n_res_gather, c_o});

    auto left_proj_act_tw = convert_to_tensor_wrapper(left_proj_act);
    auto right_proj_act_tw = convert_to_tensor_wrapper(right_proj_act);
    auto gate_tw = convert_to_tensor_wrapper(gate);
    auto input_act_tw = convert_to_tensor_wrapper(input_act);
    auto mask_tw = convert_to_tensor_wrapper(mask);
    auto left_proj_w_tw = convert_to_tensor_wrapper(left_proj_w);
    auto left_proj_b_tw = convert_to_tensor_wrapper(left_proj_b);
    auto left_gate_w_tw = convert_to_tensor_wrapper(left_gate_w);
    auto left_gate_b_tw = convert_to_tensor_wrapper(left_gate_b);
    auto right_proj_w_tw = convert_to_tensor_wrapper(right_proj_w);
    auto right_proj_b_tw = convert_to_tensor_wrapper(right_proj_b);
    auto right_gate_w_tw = convert_to_tensor_wrapper(right_gate_w);
    auto right_gate_b_tw = convert_to_tensor_wrapper(right_gate_b);

    // Left projection.
    at::Tensor gate_left = input_act.new_empty({c_i, n_res, n_res_gather});
    auto gate_left_tw = convert_to_tensor_wrapper(gate_left);
    kutacc_af2_tm_act_inputs_t_wrapper left_acts(left_proj_act_tw,
        input_act_tw, gate_left_tw, n_res, n_res_gather);
    kutacc_af2_tm_proj_weights_t_wrapper left_weights(left_proj_w_tw,
        left_proj_b_tw, left_gate_w_tw, left_gate_b_tw, c_o, c_i);
    kutacc_af2_triangle_multiplication_calc_proj(&left_acts, mask_tw.get_tensor(), &left_weights, false);

    // Right projection.
    at::Tensor gate_right = input_act.new_empty({c_i, n_res, n_res_gather});
    auto gate_right_tw = convert_to_tensor_wrapper(gate_right);
    kutacc_af2_tm_act_inputs_t_wrapper right_acts(right_proj_act_tw,
        input_act_tw, gate_right_tw, n_res, n_res_gather);
    kutacc_af2_tm_proj_weights_t_wrapper right_weights(right_proj_w_tw,
        right_proj_b_tw, right_gate_w_tw, right_gate_b_tw, c_o, c_i);
    kutacc_af2_triangle_multiplication_calc_proj(&right_acts, mask_tw.get_tensor(), &right_weights, false);

    // Combine the two projections into center_act.
    at::Tensor center_act = act.new_empty({left_proj_act.sizes()[0], n_res_gather, n_res_gather});
    auto center_act_tw = convert_to_tensor_wrapper(center_act);
    kutacc_af2_triangle_multiplication_equation(center_act_tw.get_tensor(), left_proj_act_tw.get_tensor(),
        right_proj_act_tw.get_tensor(), n_res_gather, true);

    // Move the leading dimension to the end before normalizing; the result is a
    // new tensor, so it needs its own wrapper.
    center_act = center_act.permute({1, 2, 0}).contiguous();
    center_act = layernorm(center_act, center_ln_w, center_ln_b);
    auto center_act_new_tw = convert_to_tensor_wrapper(center_act);

    // Gating and output linear step, then the final step under documentation.
    auto out_tw = convert_to_tensor_wrapper(out);
    auto gating_w_tw = convert_to_tensor_wrapper(gating_w);
    auto gating_b_tw = convert_to_tensor_wrapper(gating_b);
    auto output_proj_w_tw = convert_to_tensor_wrapper(output_proj_w);
    auto output_proj_b_tw = convert_to_tensor_wrapper(output_proj_b);
    kutacc_af2_tm_linear_weights_t_wrapper linear_weights(gating_w_tw,
        gating_b_tw, output_proj_w_tw, output_proj_b_tw, c_o, c_i);
    kutacc_af2_triangle_multiplication_gate_and_out_linear(gate_tw.get_tensor(), out_tw.get_tensor(), &left_acts,
        center_act_new_tw.get_tensor(), &linear_weights, true);
    kutacc_af2_triangle_multiplication_last(out_tw.get_tensor(), gate_tw.get_tensor(), n_res, n_res_gather, c_o);
    return out;
}
}