AllGather

该样例工程位于examples\allgather文件夹下。

在该样例中,实现了一个在通信量较小(单PE通信量小于2MB)的情况下具有更低时延的AllGather纯通信算子。各PE首先将存放在本端input地址下的数据push到本PE的对称内存上;确认远端PE的任务完成后,再从远端PE的对称内存中拉取对应PE的数据,从而整体完成AllGather操作。这个样例展示了多种SHMEM API的用法,包括shmem_mte_put_mem_nbi、shmemx_signal_op以及shmem_mte_get_mem_nbi等,用于p2p通信以及任务同步。

核函数实现

#include "kernel_operator.h"
#include "acl/acl.h"
#include "shmem_api.h"
using namespace AscendC;

// Stride, counted in int32_t elements, between the per-core synchronization
// flags stored at the start of the gva buffer (core i uses slot
// i * SYNC_FLAG_INTERVAL).
constexpr int64_t SYNC_FLAG_INTERVAL = 16;
// Value passed as the `ub_size` argument of the shmem_mte_put/get calls.
// NOTE(review): presumably the staging-buffer capacity in bytes (190 KiB) - confirm.
constexpr int64_t UB_DMA_MAX_SIZE = 190 * 1024;
// Not referenced by the kernel below.
// NOTE(review): presumably the size limit of the gva buffer (100 MiB) - confirm against the host code.
constexpr int64_t GVA_BUFF_MAX_SIZE = 100 * 1024 * 1024;

/**
 * AllGather kernel body for small payloads, executed by every vector core.
 *
 * Layout of `gva`: the first aivNum * SYNC_FLAG_INTERVAL int32_t slots hold one
 * synchronization flag per core (core i uses slot i * SYNC_FLAG_INTERVAL); the
 * data region of type T starts right after them.
 *
 * Step 1: the local `input` is split evenly over all aivNum cores (the last
 *         core also takes the remainder) and each core copies its slice into
 *         the data region of this PE's `gva`.
 * Sync:   after shmem_quiet() and shmemx_barrier_all_vec(), each core sets its
 *         own flag to `magic` on this PE and then waits until the flag in the
 *         same slot on its source PE equals `magic`.
 * Step 2: the cores are grouped into pe_size groups of core_per_rank cores.
 *         Group x reads the data region of PE x, split evenly over the cores of
 *         the group (the last core of the group takes the remainder), and
 *         writes it to output[x * elements ...].
 *
 * The whole body is compiled only when __DAV_C220_VEC__ is defined.
 *
 * @param fftsAddr  Not referenced by this function.
 * @param input     Local source buffer holding `elements` values of T.
 * @param output    Local destination buffer; segment x (of `elements` values)
 *                  receives the data of PE x.
 * @param gva       Buffer holding the flags and the staged data; it is addressed
 *                  on remote PEs through shmem_ptr / shmem_mte_get_mem_nbi.
 * @param elements  Number of T values contributed by each PE.
 * @param magic     Flag value written and awaited for this invocation.
 *
 * @note Cores that do not map to an existing PE (aivNum not a multiple of
 *       pe_size, or aivNum < pe_size) take part in step 1 and in the
 *       synchronization but skip step 2. When aivNum < pe_size no core can be
 *       assigned to a PE, so `output` is left unwritten; launch at least
 *       pe_size vector cores.
 */
template<typename T>
SHMEM_DEVICE void all_gather_small_data(uint64_t fftsAddr, __gm__ T* input, __gm__ T* output, __gm__ T* gva, int elements, int magic)
{
#ifdef __DAV_C220_VEC__
    const int64_t aivNum = GetBlockNum() * 2;
    const int64_t aivIndex = GetBlockIdx();

    // Offsets are expressed in int32_t elements from the start of gva.
    const int64_t data_offset = aivNum * SYNC_FLAG_INTERVAL;
    const int64_t flag_offset = aivIndex * SYNC_FLAG_INTERVAL;

    int64_t my_rank = shmem_my_pe();
    int64_t pe_size = shmem_n_pes();

    __gm__ T *input_gm = (__gm__ T *)input;
    __gm__ T *output_gm = (__gm__ T *)output;

    // Flags live at the start of gva; the data region follows them.
    __gm__ T *gva_data_gm = (__gm__ T*)((__gm__ int32_t*)gva + data_offset);
    __gm__ int32_t *gva_sync_gm = (__gm__ int32_t *)gva;

    // Staging buffer handed to the MTE put/get calls, at fixed address 64.
    __ubuf__ T* tmp_buff = (__ubuf__ T*)(64);

    // data move parameters
    const uint32_t ub_size = UB_DMA_MAX_SIZE;
    uint32_t input_offset, output_offset, gva_offset, num_per_core;

    // [AllGather Step 1] local input gm -> symmetric mem.
    num_per_core = elements / aivNum;
    input_offset = aivIndex * num_per_core;
    gva_offset = aivIndex * num_per_core;
    if (aivIndex == aivNum - 1) {
        // The last core also moves the remainder of the division.
        num_per_core = elements - num_per_core * aivIndex;
    }
    shmem_mte_put_mem_nbi(gva_data_gm + gva_offset, input_gm + input_offset, tmp_buff, ub_size, num_per_core, my_rank, EVENT_ID0);

    // Sync Ensure Corresponding Tasks Done.
    shmem_quiet();
    shmemx_barrier_all_vec();

    // Every core raises its own flag, including cores that skip step 2 below.
    shmemx_signal_op(gva_sync_gm + flag_offset, magic, SHMEM_SIGNAL_SET, my_rank);

    // Map this core to a source PE (x) and a slice index inside that PE's data.
    const int64_t core_per_rank = aivNum / pe_size;
    if (core_per_rank == 0) {
        // Fewer cores than PEs: no valid mapping, and the divisions below
        // would divide by zero.
        return;
    }
    const int64_t core_rank_idx = aivIndex % core_per_rank;
    const int64_t x = aivIndex / core_per_rank;
    if (x >= pe_size) {
        // Surplus core when aivNum is not a multiple of pe_size: PE x does not
        // exist, and the other cores already cover every PE completely.
        return;
    }

    shmem_signal_wait_until((__gm__ int32_t *)shmem_ptr(gva_sync_gm, x) + flag_offset, SHMEM_CMP_EQ, magic);

    // [AllGather Step 2] symmetric mem -> local output.
    num_per_core = elements / core_per_rank;
    output_offset = x * elements + core_rank_idx * num_per_core;
    gva_offset = core_rank_idx * num_per_core;
    if (core_rank_idx == core_per_rank - 1) {
        // The last core of the group also moves the remainder.
        num_per_core = elements - num_per_core * core_rank_idx;
    }
    shmem_mte_get_mem_nbi(output_gm + output_offset, gva_data_gm + gva_offset, tmp_buff, ub_size, num_per_core, x, EVENT_ID0);
#endif
}