鲲鹏社区首页
中文
注册
我要评分
文档获取效率
文档正确性
内容完整性
文档易理解
在线提单
论坛求助

make_tensor

创建Tensor对象,包含源数据及内存布局。

接口定义

template<typename dtype, typename Layout>

Tensor<dtype, Layout> make_tensor(dtype *ptr, Layout layout);

模板参数

表1 模板参数定义

参数名

类型

描述

dtype

typename

精度类型

Layout

typename

布局类型

参数

表2 参数定义

参数名

类型

描述

输入/输出

ptr

dtype *

矩阵源数据指针

输入

layout

Layout

矩阵内存布局

输入

返回值

  • 返回Tensor<dtype, Layout>对象

示例

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
#include "stdlib.h"
#include "kupl.h"
using namespace kupl::tensor;

int main()
{
    constexpr int MATRIX_M  = 32;
    constexpr int MATRIX_N  = 16;
    constexpr int MATRIX_K = 512;
    double *data_a = (double *)malloc(sizeof(double) * MATRIX_M * MATRIX_K);
    double *data_b = (double *)malloc(sizeof(double) * MATRIX_K * MATRIX_N);
    double *data_c = (double *)malloc(sizeof(double) * MATRIX_M * MATRIX_N);

    auto shape_a = make_shape(Int<32>{}, Int<512>{});
    auto shape_b = make_shape(Int<512>{}, Int<16>{});
    auto shape_c = make_shape(Int<32>{}, Int<16>{});

    auto stride_a = make_stride(Int<1>{}, Int<32>{});
    auto stride_b = make_stride(Int<16>{}, Int<1>{});
    auto stride_c = make_stride(Int<16>{}, Int<1>{});

    auto layout_a = make_layout(shape_a, stride_a);
    auto layout_b = make_layout(shape_b, stride_b);
    auto layout_c = make_layout(shape_c, stride_c);

    auto mma_atom_shape = make_shape(Int<1>{}, Int<1>{}, Int<1>{});
    auto tiled_mma = make_tiled_mma(Ops<MMA_32x16x512_F64F64F64>{}, mma_atom_shape);
    auto store_atom_shape = make_shape(Int<1>{}, Int<1>{});
    auto tile_store = make_tiled_store(Ops<STORE_32x16_F64>{}, store_atom_shape);

    auto tensor_a = make_tensor(data_a, layout_a);
    auto tensor_b = make_tensor(data_b, layout_b);
    auto tensor_c = make_tensor(data_c, layout_c);

    tensor_tiled_mma(tiled_mma, tensor_c, tensor_a, tensor_b, tensor_c);
    tensor_tiled_store(tile_store, tensor_c);

    free(data_a);
    free(data_b);
    free(data_c);
    return 0;
}

上述示例演示了基于32*16*512_F64F64F64矩阵形状的mma流程,其中通过make_tensor创建矩阵对象tensor,作为后续mma/store接口参数。