20 changes: 15 additions & 5 deletions csrc/kernels/intranode.cu
@@ -837,24 +837,34 @@ combine(dtype_t* recv_x, float* recv_topk_weights,
 out_dtypes[j] = static_cast<dtype_t>(values[j]);

 #ifndef DISABLE_SM90_FEATURES
+
+uint32_t warp_mask = 0xffffffff;
+bool const is_tail_loop = ((i - lane_id) + 32) > hidden_int4;
+if (is_tail_loop) {
+    static constexpr uint64_t one = 1;
+    int const num_active_threads = hidden_int4 - (i - lane_id);
+    warp_mask = static_cast<uint32_t>((one << num_active_threads) - one);
+}
+
+
 // Wait TMA arrival
-if (elect_one_sync())
+if (elect_one_sync(warp_mask))
     tma_store_wait<kNumStages - 1>();
-__syncwarp();
+__syncwarp(warp_mask);

 // Write into TMA buffer
 auto tma_stage_idx = (i / 32) % kNumStages;
 reinterpret_cast<int4*>(tma_buffer)[tma_stage_idx * 32 + lane_id] = out_int4;

 // Issue TMA
 tma_store_fence();
-__syncwarp();
-if (elect_one_sync()) {
+__syncwarp(warp_mask);
+if (elect_one_sync(warp_mask)) {
     auto tma_bytes = min(32, hidden_int4 - i) * static_cast<int>(sizeof(int4));
     tma_store_1d(reinterpret_cast<int4*>(tma_buffer) + tma_stage_idx * 32,
                  recv_int4 + token_idx * hidden_int4 + i, tma_bytes, false);
 }
-__syncwarp();
+__syncwarp(warp_mask);
 #else
 recv_int4[token_idx * hidden_int4 + i] = out_int4;
 #endif
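Note on the hunk above: the new mask arithmetic only matters on the final iteration of the warp-strided loop over hidden_int4. When fewer than 32 lanes pass the loop bound, the surviving lanes must agree on a reduced member mask before any warp-collective call. A minimal standalone sketch of the same computation (the helper name tail_mask is hypothetical; the kernel computes this inline as shown above):

    // For a warp-strided loop `for (int i = lane_id; i < n; i += 32)`, the last
    // iteration keeps only `n - (i - lane_id)` lanes alive, so warp collectives
    // issued there must name exactly those lanes.
    __device__ uint32_t tail_mask(int i, int lane_id, int n) {
        int const base = i - lane_id;        // first element this iteration covers
        if (base + 32 <= n)
            return 0xffffffffu;              // full warp still active
        int const num_active = n - base;     // 1..31 surviving lanes
        return static_cast<uint32_t>((1ull << num_active) - 1ull);
    }

For example, with hidden_int4 = 100 the last iteration starts at base 96, leaving 4 active lanes and a mask of 0xf.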
4 changes: 2 additions & 2 deletions csrc/kernels/utils.cuh
@@ -310,7 +310,7 @@ __forceinline__ __device__ int get_lane_id() {
     return lane_id;
 }

-__device__ __forceinline__ uint32_t elect_one_sync() {
+__device__ __forceinline__ uint32_t elect_one_sync(const uint32_t membermask=0xffffffff) {
 #ifndef DISABLE_SM90_FEATURES
     uint32_t pred = 0;
     asm volatile(
@@ -321,7 +321,7 @@ __device__ __forceinline__ uint32_t elect_one_sync() {
"@%%px mov.s32 %0, 1;\n"
"}\n"
: "+r"(pred)
: "r"(0xffffffff));
: "r"(membermask));
return pred;
#else
return get_lane_id() == 0;
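With the defaulted membermask, existing call sites keep full-warp behavior unchanged; tail iterations pass the reduced mask so elect.sync elects among exactly the lanes that reach the call. A hedged usage sketch (the function name tail_step is illustrative, not part of the patch):

    __device__ void tail_step(uint32_t warp_mask) {
        // Every lane named in warp_mask must reach this call; elect.sync
        // returns 1 in exactly one of those lanes and 0 in the rest.
        if (elect_one_sync(warp_mask)) {
            // one-lane work, e.g. issuing a TMA store
        }
        // The paired barrier must use the same mask, otherwise the warp
        // would wait on lanes that already exited the loop.
        __syncwarp(warp_mask);
    }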
7 changes: 5 additions & 2 deletions tests/test_intranode.py
@@ -25,8 +25,11 @@ def test_main(args: argparse.Namespace, num_sms: int, local_rank: int, num_ranks
 # Random data
 x = torch.ones((num_tokens, hidden), dtype=torch.bfloat16, device='cuda') * rank
 x_pure_rand = torch.randn((num_tokens, hidden), dtype=torch.bfloat16, device='cuda')
-x_e4m3 = per_token_cast_to_fp8(x) if deep_ep.Buffer.is_sm90_compiled() else None
-x_e4m3 = (x_e4m3[0], x_e4m3[1].T.contiguous().T) if x_e4m3 is not None else None
+
+x_e4m3 = None
+if hidden % 128 == 0:
+    x_e4m3 = per_token_cast_to_fp8(x) if deep_ep.Buffer.is_sm90_compiled() else None
+    x_e4m3 = (x_e4m3[0], x_e4m3[1].T.contiguous().T) if x_e4m3 is not None else None
 scores = torch.randn((num_tokens, num_experts), dtype=torch.float32, device='cuda').abs() + 1
 topk_idx = torch.topk(scores, num_topk, dim=-1, largest=True, sorted=False)[1]
 topk_weights = torch.ones((num_tokens, num_topk), dtype=torch.float32, device='cuda') * rank
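The new hidden % 128 == 0 guard matches the per-token FP8 layout: per_token_cast_to_fp8 scales each row in 128-value groups, one scale per group (an assumption consistent with the 128-alignment the guard enforces), so rows that do not split evenly cannot be cast and the FP8 path is skipped. A minimal sketch of the same precondition (helper name hypothetical):

    // One FP8 scale per 128-value group of a row: a row of length `hidden`
    // yields hidden / 128 scales, so the row length must divide evenly.
    constexpr int kFP8GroupSize = 128;

    inline bool can_cast_per_token_fp8(int hidden) {
        return hidden % kFP8GroupSize == 0;   // mirrors the guard in the test above
    }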