kutacc_af2_all_gather
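Performs an all_gather operation across processes: every process contributes its local data tensor, and every process receives the tensors of all processes in out.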
Interface definition
void kutacc_af2_all_gather(kutacc_tensor_h data, kutacc_tensor_h out);
Parameters
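| Parameter | Type | Description | Input/Output |
|---|---|---|---|
| data | kutacc_tensor_h | Input tensor holding this process's local data | Input |
| out | kutacc_tensor_h | Output tensor that receives the gathered data from all processes | Output |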
Example
C++ interface:
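#include <cstdint>
#include <cstring>
#include <vector>
#include <mpi.h>
#include "kutacc.h"
// Local helper (not part of KUTACC): encode a float as a raw bfloat16 bit pattern
// by keeping the upper 16 bits of its float32 representation (exact for 0..51).
static uint16_t to_bf16(float f)
{
    uint32_t bits;
    std::memcpy(&bits, &f, sizeof(bits));
    return static_cast<uint16_t>(bits >> 16);
}
int main(int argc, char **argv)
{
    int world_size = 2, rank = 0;
    int m = 4, n = 5, len = 1;
    int64_t dim = 3, buffer_size = m * n * len;
    MPI_Init(&argc, &argv);
    MPI_Comm_size(MPI_COMM_WORLD, &world_size);
    MPI_Comm_rank(MPI_COMM_WORLD, &rank);
    kutacc_initialize(world_size, rank, buffer_size);
    // Local input tensor of shape [m, n, len], contiguous row-major.
    std::vector<int64_t> sizes = {m, n, len};
    std::vector<int64_t> strides = {n * len, len, 1};
    // The tensors are tagged kutacc::kBF16, so the host buffers hold 16-bit bfloat16 values.
    std::vector<uint16_t> data(m * n * len);
    // Rank r starts counting at r * 32: rank 0 fills 0..19, rank 1 fills 32..51.
    int val = rank * 32;
    for (int i = 0; i < m; ++i) {
        for (int j = 0; j < n; ++j) {
            data[i * n * len + j * len] = to_bf16(static_cast<float>(val++));
        }
    }
    // Output tensor of shape [world_size * m, n, len] that receives every rank's data.
    std::vector<int64_t> out_sizes = {world_size * m, n, len};
    std::vector<int64_t> out_strides = {n * len, len, 1};
    std::vector<uint16_t> out_data(world_size * m * n * len);
    kutacc::TensorWrapper in(data.data(), sizes, strides, dim, kutacc::kBF16);
    kutacc::TensorWrapper out(out_data.data(), out_sizes, out_strides, dim, kutacc::kBF16);
    kutacc_af2_all_gather(in.get_tensor(), out.get_tensor());
    kutacc_finalize();
    MPI_Finalize();
    return 0;
}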
/**
 * rank0:                rank1:
 *  0  1  2  3  4        32 33 34 35 36
 *  5  6  7  8  9        37 38 39 40 41
 * 10 11 12 13 14        42 43 44 45 46
 * 15 16 17 18 19        47 48 49 50 51
 *
 * all_gather()
 *
 * rank0 && rank1 (out):
 *  0  1  2  3  4
 *  5  6  7  8  9
 * 10 11 12 13 14
 * 15 16 17 18 19
 * 32 33 34 35 36
 * 37 38 39 40 41
 * 42 43 44 45 46
 * 47 48 49 50 51
 **/