
kutacc_af2_all_gather

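Performs an all_gather operation across processes: every process contributes its input tensor and every process receives the combined result.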

Interface Definition

void kutacc_af2_all_gather(kutacc_tensor_h data, kutacc_tensor_h out);

Parameters

Parameter    Type               Description      Input/Output
data         kutacc_tensor_h    Input tensor     Input
out          kutacc_tensor_h    Output tensor    Output
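The example below passes an output tensor whose first dimension is world_size times that of the input, with contiguous row-major strides. That relationship can be sketched as follows. Both helpers, contiguous_strides and gathered_sizes, are hypothetical names introduced for illustration and are not part of kutacc. The shape rule is inferred from the example, not from a documented contract.

```cpp
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical helper (not part of kutacc): row-major strides for a
// contiguous tensor, e.g. sizes {m, n, len} -> strides {n * len, len, 1}.
std::vector<int64_t> contiguous_strides(const std::vector<int64_t> &sizes)
{
    std::vector<int64_t> strides(sizes.size(), 1);
    // Walk from the innermost dimension outwards.
    for (std::size_t i = sizes.size(); i-- > 1;) {
        strides[i - 1] = strides[i] * sizes[i];
    }
    return strides;
}

// Hypothetical helper (not part of kutacc): output sizes as used in the
// example, where only the first dimension grows by a factor of world_size.
std::vector<int64_t> gathered_sizes(std::vector<int64_t> sizes, int64_t world_size)
{
    if (!sizes.empty()) {
        sizes[0] *= world_size;
    }
    return sizes;
}
```

With sizes {4, 5, 1} and two processes, these helpers give input strides {5, 1, 1} and output sizes {8, 5, 1}, matching the values written out by hand in the example.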

Example

C++ interface:
#include <cstdint>
#include <vector>
#include <mpi.h>
#include "kutacc.h"

int main(int argc, char **argv)
{
    int world_size = 2, rank = 0;
    int m = 4, n = 5, len = 1;
    int64_t dim = 3, buffer_size = m * n * len;

    MPI_Init(&argc, &argv);
    MPI_Comm_size(MPI_COMM_WORLD, &world_size);
    MPI_Comm_rank(MPI_COMM_WORLD, &rank);
    kutacc_initialize(world_size, rank, buffer_size);

    std::vector<int64_t> sizes = {m, n, len};
    std::vector<int64_t> strides = {n * len, len, 1};
    std::vector<int64_t> data(m * n * len);

    // Each rank fills its block starting from a distinct value:
    // 0 on rank 0, 32 on rank 1.
    int val = rank * 32;
    for (int i = 0; i < m; ++i) {
        for (int j = 0; j < n; ++j) {
            data[i * n * len + j * len] = val++;
        }
    }
    std::vector<int64_t> out_sizes = {world_size * m, n, len};
    std::vector<int64_t> out_strides = {n * len, len, 1};
    std::vector<int64_t> out_data(world_size * m * n * len);
    
    kutacc::TensorWrapper in(data.data(), sizes, strides, dim, kutacc::kBF16);
    kutacc::TensorWrapper out(out_data.data(), out_sizes, out_strides, dim, kutacc::kBF16);
    kutacc_af2_all_gather(in.get_tensor(), out.get_tensor());

    kutacc_finalize();
    MPI_Finalize();

    return 0;
}

/**
 * rank0:              rank1:
 * 0 1 2 3 4           32 33 34 35 36
 * 5 6 7 8 9           37 38 39 40 41
 * 10 11 12 13 14      42 43 44 45 46
 * 15 16 17 18 19      47 48 49 50 51
 *
 * all_gather()
 *
 * rank0 && rank1:
 * 0 1 2 3 4
 * 5 6 7 8 9
 * 10 11 12 13 14
 * 15 16 17 18 19
 * 32 33 34 35 36
 * 37 38 39 40 41
 * 42 43 44 45 46
 * 47 48 49 50 51
**/
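To check the expected result without an MPI launch, the gather can be modelled on a single process. This is a minimal sketch, assuming the operation concatenates the per-rank tensors along the first dimension in rank order, as the comment above shows. all_gather_reference and make_block are hypothetical helpers, not kutacc functions.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical reference model (not part of kutacc): concatenates the
// per-rank buffers in rank order, which is the result each rank is
// expected to hold after gathering along the first dimension.
std::vector<int64_t> all_gather_reference(
    const std::vector<std::vector<int64_t>> &per_rank)
{
    std::vector<int64_t> gathered;
    for (const auto &chunk : per_rank) {
        // Every rank must contribute the same number of elements.
        assert(chunk.size() == per_rank.front().size());
        gathered.insert(gathered.end(), chunk.begin(), chunk.end());
    }
    return gathered;
}

// Hypothetical helper: builds the contiguous m x n block a rank fills in
// the example, holding consecutive values starting at `first`.
std::vector<int64_t> make_block(int64_t m, int64_t n, int64_t first)
{
    std::vector<int64_t> block(static_cast<std::size_t>(m * n));
    for (auto &v : block) {
        v = first++;
    }
    return block;
}
```

Calling all_gather_reference({make_block(4, 5, 0), make_block(4, 5, 32)}) yields 40 elements: 0 through 19 followed by 32 through 51, the same 8 x 5 layout listed for rank0 and rank1.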