kutacc_comm_init
Creates and initializes the communication-related variables.
Interface Definition
void kutacc_comm_init(int64_t world_size, int64_t rank, int64_t buffer_size);
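Parameters
| Parameter | Type | Description | Input/Output |
|---|---|---|---|
| world_size | int64_t | Total number of processes | Input |
| rank | int64_t | Rank of the current process | Input |
| buffer_size | int64_t | Size of the shared memory space | Input |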
Example
C++ interface:
#include <cstdint>
#include <vector>
#include <mpi.h>
#include "kutacc.h"

int main(int argc, char **argv)
{
    int world_size = 2, rank = 0;
    int m = 4, n = 5, len = 1;
    int64_t dim = 3, buffer_size = m * n * len;

    // Query the actual process count and rank from MPI.
    MPI_Init(&argc, &argv);
    MPI_Comm_size(MPI_COMM_WORLD, &world_size);
    MPI_Comm_rank(MPI_COMM_WORLD, &rank);

    // Create and initialize the communication-related variables.
    kutacc_comm_init(world_size, rank, buffer_size);

    // Input tensor: m x n x len, contiguous row-major layout.
    std::vector<int64_t> sizes = {m, n, len};
    std::vector<int64_t> strides = {n * len, len, 1};
    std::vector<int64_t> data(m * n * len);

    // Fill the input with consecutive values.
    // Rank 1 starts at 32; every other rank starts at 0.
    int val = (rank == 1) ? 32 : 0;
    for (int i = 0; i < m; ++i) {
        for (int j = 0; j < n; ++j) {
            data[i * n * len + j * len] = val++;
        }
    }

    // Output tensor: holds m rows from each of the world_size processes.
    std::vector<int64_t> out_sizes = {world_size * m, n, len};
    std::vector<int64_t> out_strides = {n * len, len, 1};
    std::vector<int64_t> out_data(world_size * m * n * len);

    kutacc::TensorWrapper in(data.data(), sizes, strides, dim, kutacc::kBF16);
    kutacc::TensorWrapper out(out_data.data(), out_sizes, out_strides, dim, kutacc::kBF16);
    kutacc_af2_all_gather(in.get_tensor(), out.get_tensor());

    // Release the communication resources before shutting down MPI.
    kutacc_comm_fini();
    MPI_Finalize();
    return 0;
}
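The strides in the example above, {n * len, len, 1}, are the contiguous row-major strides for the sizes {m, n, len}, and m * n * len is the element count of that tensor. As a minimal sketch, the two quantities could be derived from the sizes as shown here. The helpers `row_major_strides` and `element_count` are hypothetical names used only for illustration and are not part of kutacc.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical helper (not part of kutacc): contiguous row-major strides,
// counted in elements. The last dimension always has stride 1.
std::vector<int64_t> row_major_strides(const std::vector<int64_t>& sizes)
{
    std::vector<int64_t> strides(sizes.size(), 1);
    // Walk from the last dimension toward the first, accumulating the product.
    for (std::size_t i = sizes.size(); i-- > 1;) {
        assert(sizes[i] >= 0);
        strides[i - 1] = strides[i] * sizes[i];
    }
    return strides;
}

// Hypothetical helper (not part of kutacc): total number of elements
// described by the sizes.
int64_t element_count(const std::vector<int64_t>& sizes)
{
    int64_t count = 1;
    for (int64_t s : sizes) {
        assert(s >= 0);
        count *= s;
    }
    return count;
}
```

With m = 4, n = 5, len = 1, `row_major_strides({m, n, len})` yields the same strides the example writes by hand, and `element_count({m, n, len})` yields the value the example passes as buffer_size. The source does not state whether buffer_size is measured in elements or bytes, so treat that correspondence as an observation about the example, not a rule.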