
kutacc_comm_init

Creates and initializes the communication-related variables.

Interface definition

void kutacc_comm_init(int64_t world_size, int64_t rank, int64_t buffer_size);

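Parameters

| Parameter | Type | Description | Input/Output |
|---|---|---|---|
| world_size | int64_t | Total number of processes | Input |
| rank | int64_t | Rank (ID) of the current process | Input |
| buffer_size | int64_t | Size of the shared memory space | Input |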
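
The table does not state the unit of buffer_size. The example on this page passes the element count of one rank's input tensor (m * n * len). A minimal sketch of computing that count from a shape follows; `element_count` is a hypothetical helper for illustration and is not part of kutacc.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Hypothetical helper (not part of kutacc): number of elements in a tensor
// with the given shape, i.e. the product of all dimension sizes.
int64_t element_count(const std::vector<int64_t>& sizes)
{
    int64_t count = 1;
    for (int64_t s : sizes) {
        assert(s >= 0);  // a dimension size cannot be negative
        count *= s;
    }
    return count;
}
```

With the shape {4, 5, 1} used in the example, this gives 20, the same value the example computes as m * n * len. Whether a given workload needs a larger buffer is not covered by this page.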

Example

C++ interface:
#include <cstdint>
#include <vector>
#include <mpi.h>
#include "kutacc.h"

int main(int argc, char** argv)
{
    int world_size = 2, rank = 0;
    int m = 4, n = 5, len = 1;
    int64_t dim = 3, buffer_size = m * n * len;

    MPI_Init(&argc, &argv);
    MPI_Comm_size(MPI_COMM_WORLD, &world_size);
    MPI_Comm_rank(MPI_COMM_WORLD, &rank);
    kutacc_comm_init(world_size, rank, buffer_size);

    std::vector<int64_t> sizes = {m, n, len};
    std::vector<int64_t> strides = {n * len, len, 1};
    std::vector<int64_t> data(m * n * len);

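    // Each rank fills its block starting from a different value.
    // Initialized up front so ranks other than 0 and 1 do not read garbage.
    int val = 0;
    if (rank == 1) {
        val = 32;
    }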
    for (int i = 0; i < m; ++i) {
        for (int j = 0; j < n; ++j) {
            data[i * n * len + j * len] = val++;
        }
    }
    std::vector<int64_t> out_sizes = {world_size * m, n, len};
    std::vector<int64_t> out_strides = {n * len, len, 1};
    std::vector<int64_t> out_data(world_size * m * n * len);
    
    kutacc::TensorWrapper in(data.data(), sizes, strides, dim, kutacc::kBF16);
    kutacc::TensorWrapper out(out_data.data(), out_sizes, out_strides, dim, kutacc::kBF16);
    kutacc_af2_all_gather(in.get_tensor(), out.get_tensor());

    kutacc_comm_fini();
    MPI_Finalize();

    return 0;
}
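
To make the expected result of the example easier to picture, the following plain C++ sketch models the gather step without MPI or kutacc. It assumes that kutacc_af2_all_gather concatenates the per-rank blocks along the first dimension in rank order, which is what the output shape {world_size * m, n, len} suggests but is not confirmed by this page. `make_rank_data` and `simulate_all_gather` are hypothetical names, and the sketch ignores the tensor element type.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical stand-in for one rank's input buffer in the example:
// the same fill loop, counting up from `start`.
std::vector<int64_t> make_rank_data(int64_t start, int64_t m, int64_t n, int64_t len)
{
    std::vector<int64_t> data(m * n * len);
    int64_t val = start;
    for (int64_t i = 0; i < m; ++i) {
        for (int64_t j = 0; j < n; ++j) {
            data[i * n * len + j * len] = val++;
        }
    }
    return data;
}

// Plain C++ model of an all-gather along dimension 0 (assumed semantics):
// the block of rank r is copied to offset r * block_size of the output.
std::vector<int64_t> simulate_all_gather(const std::vector<std::vector<int64_t>>& per_rank)
{
    assert(!per_rank.empty());
    const std::size_t block = per_rank[0].size();
    std::vector<int64_t> out;
    out.reserve(block * per_rank.size());
    for (const auto& data : per_rank) {
        assert(data.size() == block);  // every rank contributes the same shape
        out.insert(out.end(), data.begin(), data.end());
    }
    return out;
}
```

Under that assumption, with two ranks and m = 4, n = 5, len = 1, the first 20 output elements hold rank 0's values 0 to 19 and the next 20 hold rank 1's values 32 to 51.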