kutacc_af2_transpose
进程间的transpose操作
接口定义
void kutacc_af2_transpose(kutacc_tensor_h data, kutacc_tensor_h out);
参数
| 参数名 | 类型 | 描述 | 输入/输出 |
|---|---|---|---|
| data | kutacc_tensor_h | 输入张量 | 输入 |
| out | kutacc_tensor_h | 输出张量 | 输出 |
示例
C++ interface:
#include <vector>
#include <mpi.h>
#include "kutacc.h"
/// Example driver: every MPI rank fills a local m x n x len tensor and calls
/// kutacc_af2_transpose to exchange it with the other ranks.
///
/// @param argc Argument count, forwarded to MPI_Init.
/// @param argv Argument vector, forwarded to MPI_Init.
/// @return 0 on completion.
int main(int argc, char** argv)
{
    // Placeholder values; both are overwritten by the MPI queries below.
    int world_size = 2;
    int rank = 0;

    int m = 5, n = 6, len = 1;
    int64_t dim = 3, buffer_size = m * n * len;

    MPI_Init(&argc, &argv);
    MPI_Comm_size(MPI_COMM_WORLD, &world_size);
    MPI_Comm_rank(MPI_COMM_WORLD, &rank);

    // Computed only after the real communicator size is known, so the output
    // shape is based on the actual process count rather than the placeholder.
    int64_t block_n = (n + world_size - 1) / world_size;

    kutacc_initialize(world_size, rank, buffer_size);

    // Input tensor: shape {m, n, len}, row-major contiguous strides.
    std::vector<int64_t> sizes = {m, n, len};
    std::vector<int64_t> strides = {n * len, len, 1};
    std::vector<int64_t> data(m * n * len);

    // First value written by this rank: 0 on rank 0 and 32 on rank 1.
    // Deriving it from the rank keeps it well-defined for any further rank
    // instead of leaving it uninitialised.
    int val = rank * 32;
    for (int i = 0; i < m; ++i) {
        for (int j = 0; j < n; ++j) {
            data[i * n * len + j * len] = val++;
        }
    }

    // Output tensor: shape {world_size * m, block_n, len}, contiguous strides.
    std::vector<int64_t> out_sizes = {world_size * m, block_n, len};
    std::vector<int64_t> out_strides = {block_n * len, len, 1};
    std::vector<int64_t> out_data(world_size * m * block_n * len);

    // NOTE(review): the buffers hold int64_t elements while both tensors are
    // declared as kutacc::kBF16 - confirm the intended element type.
    kutacc::TensorWrapper in(data.data(), sizes, strides, dim, kutacc::kBF16);
    kutacc::TensorWrapper out(out_data.data(), out_sizes, out_strides, dim, kutacc::kBF16);

    kutacc_af2_transpose(in.get_tensor(), out.get_tensor());

    kutacc_finalize();
    MPI_Finalize();
    return 0;
}
/**
 * rank:0                  rank:1
 *  0  1  2  3  4  5       32 33 34 35 36 37
 *  6  7  8  9 10 11       38 39 40 41 42 43
 * 12 13 14 15 16 17       44 45 46 47 48 49
 * 18 19 20 21 22 23       50 51 52 53 54 55
 * 24 25 26 27 28 29       56 57 58 59 60 61
 *
 * transpose()
 *
 * rank:0        rank:1
 *  0  1  2       3  4  5
 *  6  7  8       9 10 11
 * 12 13 14      15 16 17
 * 18 19 20      21 22 23
 * 24 25 26      27 28 29
 * 32 33 34      35 36 37
 * 38 39 40      41 42 43
 * 44 45 46      47 48 49
 * 50 51 52      53 54 55
 * 56 57 58      59 60 61
 **/
父主题: 通信算子