Skip to content

Commit 9bfbcd6

Browse files
author
Varun Sundar Rabindranath
committed
Add warp mask for __syncwarp
Signed-off-by: Varun Sundar Rabindranath <vsundarr@redhat.com>
1 parent 9af0e0d commit 9bfbcd6

File tree

2 files changed

+15
-5
lines changed

csrc/kernels/intranode.cu

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -600,6 +600,10 @@ combine(dtype_t* recv_x, float* recv_topk_weights,
600600

601601
// TMA stuffs
602602
#ifndef DISABLE_SM90_FEATURES
603+
auto set_n_bits = [](const int x) -> uint32_t {
604+
static constexpr uint64_t one = 1;
605+
return static_cast<uint32_t>((one << x) - one);
606+
};
603607
extern __shared__ __align__(1024) uint8_t smem_buffer[];
604608
auto tma_buffer = smem_buffer + (thread_id / 32) * kNumTMABytesPerWarp;
605609
#endif
@@ -839,24 +843,27 @@ combine(dtype_t* recv_x, float* recv_topk_weights,
839843
out_dtypes[j] = static_cast<dtype_t>(values[j]);
840844

841845
#ifndef DISABLE_SM90_FEATURES
846+
const int num_participating_threads = min(32, hidden_int4 - i);
847+
const unsigned int warp_mask = set_n_bits(num_participating_threads);
848+
842849
// Wait TMA arrival
843850
if (lane_id == 0)
844851
tma_store_wait<kNumStages - 1>();
845-
__syncwarp();
852+
__syncwarp(warp_mask);
846853

847854
// Write into TMA buffer
848855
auto tma_stage_idx = (i / 32) % kNumStages;
849856
reinterpret_cast<int4*>(tma_buffer)[tma_stage_idx * 32 + lane_id] = out_int4;
850857

851858
// Issue TMA
852859
tma_store_fence();
853-
__syncwarp();
860+
__syncwarp(warp_mask);
854861
if (lane_id == 0) {
855862
auto tma_bytes = min(32, hidden_int4 - i) * static_cast<int>(sizeof(int4));
856863
tma_store_1d(reinterpret_cast<int4*>(tma_buffer) + tma_stage_idx * 32,
857864
recv_int4 + token_idx * hidden_int4 + i, tma_bytes, false);
858865
}
859-
__syncwarp();
866+
__syncwarp(warp_mask);
860867
#else
861868
recv_int4[token_idx * hidden_int4 + i] = out_int4;
862869
#endif

tests/test_intranode.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,8 +25,11 @@ def test_main(args: argparse.Namespace, num_sms: int, local_rank: int, num_ranks
2525
# Random data
2626
x = torch.ones((num_tokens, hidden), dtype=torch.bfloat16, device='cuda') * rank
2727
x_pure_rand = torch.randn((num_tokens, hidden), dtype=torch.bfloat16, device='cuda')
28-
x_e4m3 = per_token_cast_to_fp8(x) if deep_ep.Buffer.is_sm90_compiled() else None
29-
x_e4m3 = (x_e4m3[0], x_e4m3[1].T.contiguous().T) if x_e4m3 is not None else None
28+
29+
x_e4m3 = None
30+
if hidden % 128 == 0:
31+
x_e4m3 = per_token_cast_to_fp8(x) if deep_ep.Buffer.is_sm90_compiled() else None
32+
x_e4m3 = (x_e4m3[0], x_e4m3[1].T.contiguous().T) if x_e4m3 is not None else None
3033
scores = torch.randn((num_tokens, num_experts), dtype=torch.float32, device='cuda').abs() + 1
3134
topk_idx = torch.topk(scores, num_topk, dim=-1, largest=True, sorted=False)[1]
3235
topk_weights = torch.ones((num_tokens, num_topk), dtype=torch.float32, device='cuda') * rank

0 commit comments

Comments
 (0)