kutacc_af2_transpose

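Performs a transpose across processes. As the example below shows, each process splits its local input tensor into world_size column blocks along dimension 1, and process r receives block r from every process, stacked along dimension 0 in rank order.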

Interface definition

void kutacc_af2_transpose(kutacc_tensor_h data, kutacc_tensor_h out);

Parameters

Parameter  Type             Description    Input/Output
data       kutacc_tensor_h  Input tensor   Input
out        kutacc_tensor_h  Output tensor  Output
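The example below sizes out from the input shape {m, n, len} as {world_size × m, ⌈n / world_size⌉, len} with contiguous row-major strides. The following is a minimal sketch of that arithmetic, assuming out must use the same layout as in the example (this page does not state it as a formal requirement). expected_out_sizes and contiguous_strides are hypothetical helpers, not part of KUTACC.

```cpp
#include <array>
#include <cassert>
#include <cstdint>

// Hypothetical helper: per-process out shape as used in the example,
// out_sizes = {world_size * m, block_n, len} with block_n = ceil(n / world_size).
std::array<int64_t, 3> expected_out_sizes(const std::array<int64_t, 3>& in_sizes,
                                          int64_t world_size)
{
    assert(world_size > 0);
    const int64_t m = in_sizes[0], n = in_sizes[1], len = in_sizes[2];
    const int64_t block_n = (n + world_size - 1) / world_size;  // ceiling division
    return {world_size * m, block_n, len};
}

// Hypothetical helper: row-major strides for a contiguous 3-D tensor,
// e.g. {block_n * len, len, 1} for out.
std::array<int64_t, 3> contiguous_strides(const std::array<int64_t, 3>& sizes)
{
    return {sizes[1] * sizes[2], sizes[2], 1};
}
```

For the example's m = 5, n = 6, len = 1 on two processes, this gives sizes {10, 3, 1} and strides {3, 1, 1}. How the library treats the extra padded column when n is not divisible by world_size is not documented here.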

Example

C++ interface:
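#include <cstdint>
#include <cstring>
#include <vector>
#include <mpi.h>
#include "kutacc.h"

// Convert a float to its bfloat16 bit pattern by keeping the upper 16 bits
// (truncation). Exact for the small integers used in this example.
static uint16_t float_to_bf16_bits(float f)
{
    uint32_t bits;
    std::memcpy(&bits, &f, sizeof(bits));
    return static_cast<uint16_t>(bits >> 16);
}

// Run with two processes (e.g. mpirun -np 2 ./a.out) to reproduce the diagram below.
int main(int argc, char *argv[])
{
    int world_size = 2, rank = 0;
    int m = 5, n = 6, len = 1;
    int64_t dim = 3, buffer_size = m * n * len;

    MPI_Init(&argc, &argv);
    MPI_Comm_size(MPI_COMM_WORLD, &world_size);
    MPI_Comm_rank(MPI_COMM_WORLD, &rank);
    kutacc_initialize(world_size, rank, buffer_size);

    // Columns received by each process: ceil(n / world_size).
    // Computed after MPI_Comm_size so it uses the actual process count.
    int64_t block_n = (n + world_size - 1) / world_size;

    // Local input: an m x n x len row-major tensor.
    std::vector<int64_t> sizes = {m, n, len};
    std::vector<int64_t> strides = {n * len, len, 1};
    // The tensors are declared kBF16, so the host buffers hold 16-bit bfloat16 values.
    std::vector<uint16_t> data(m * n * len);

    // rank 0 fills 0..29, rank 1 fills 32..61.
    int val = rank * 32;
    for (int i = 0; i < m; ++i) {
        for (int j = 0; j < n; ++j) {
            data[i * n * len + j * len] = float_to_bf16_bits(static_cast<float>(val++));
        }
    }

    // Local output: column block `rank` from every process, stacked along dimension 0.
    std::vector<int64_t> out_sizes = {world_size * m, block_n, len};
    std::vector<int64_t> out_strides = {block_n * len, len, 1};
    std::vector<uint16_t> out_data(world_size * m * block_n * len);

    kutacc::TensorWrapper in(data.data(), sizes, strides, dim, kutacc::kBF16);
    kutacc::TensorWrapper out(out_data.data(), out_sizes, out_strides, dim, kutacc::kBF16);

    kutacc_af2_transpose(in.get_tensor(), out.get_tensor());

    kutacc_finalize();
    MPI_Finalize();

    return 0;
}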

/**
 * Input on each process (m = 5, n = 6):
 *
 * rank:0                 rank:1
 * 0 1 2 3 4 5            32 33 34 35 36 37
 * 6 7 8 9 10 11          38 39 40 41 42 43
 * 12 13 14 15 16 17      44 45 46 47 48 49
 * 18 19 20 21 22 23      50 51 52 53 54 55
 * 24 25 26 27 28 29      56 57 58 59 60 61
 *
 * kutacc_af2_transpose()
 *
 * Output on each process (world_size * m = 10 rows, block_n = 3 columns):
 *
 * rank:0                 rank:1
 * 0 1 2                  3 4 5
 * 6 7 8                  9 10 11
 * 12 13 14               15 16 17
 * 18 19 20               21 22 23
 * 24 25 26               27 28 29
 * 32 33 34               35 36 37
 * 38 39 40               41 42 43
 * 44 45 46               47 48 49
 * 50 51 52               53 54 55
 * 56 57 58               59 60 61
 */
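To make the data movement in the diagram concrete, here is a single-process reference sketch that simulates all ranks without MPI or KUTACC. reference_transpose is a hypothetical helper, not a library function. It assumes row-major contiguous tensors and that n is divisible by the number of ranks, and it stores plain int64_t values instead of bfloat16 for readability.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Hypothetical reference for the exchange shown in the diagram.
// inputs[r] is rank r's row-major m x n x len tensor.
// Returns outputs[r]: a (world_size * m) x block_n x len tensor holding
// column block r of every rank's input, stacked in rank order.
std::vector<std::vector<int64_t>> reference_transpose(
    const std::vector<std::vector<int64_t>>& inputs, int64_t m, int64_t n, int64_t len)
{
    const int64_t world_size = static_cast<int64_t>(inputs.size());
    // Assumption: an even split, so the padding case never arises.
    assert(world_size > 0 && n % world_size == 0);
    const int64_t block_n = n / world_size;

    std::vector<std::vector<int64_t>> outputs(
        world_size, std::vector<int64_t>(world_size * m * block_n * len));

    for (int64_t dst = 0; dst < world_size; ++dst)           // receiving rank
        for (int64_t src = 0; src < world_size; ++src)       // sending rank
            for (int64_t i = 0; i < m; ++i)                  // row in the sender's input
                for (int64_t j = 0; j < block_n; ++j)        // column inside block dst
                    for (int64_t k = 0; k < len; ++k)
                        outputs[dst][((src * m + i) * block_n + j) * len + k] =
                            inputs[src][(i * n + dst * block_n + j) * len + k];
    return outputs;
}
```

With two ranks, m = 5, n = 6, len = 1 and the inputs filled from 0 and 32 as in the example, outputs[0] and outputs[1] contain exactly the rank:0 and rank:1 columns shown in the diagram.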