Skip to content

Commit f7fb36f

Browse files
author
Varun Sundar Rabindranath
committed
pass membermask to elect_one_sync
Signed-off-by: Varun Sundar Rabindranath <vsundarr@redhat.com>
1 parent 73ec503 commit f7fb36f

File tree

2 files changed

+5
-10
lines changed

2 files changed

+5
-10
lines changed

csrc/kernels/intranode.cu

Lines changed: 3 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -596,10 +596,6 @@ combine(dtype_t* recv_x, float* recv_topk_weights,
596596

597597
// TMA stuffs
598598
#ifndef DISABLE_SM90_FEATURES
599-
auto set_n_bits = [](const int x) -> uint32_t {
600-
static constexpr uint64_t one = 1;
601-
return static_cast<uint32_t>((one << x) - one);
602-
};
603599
extern __shared__ __align__(1024) uint8_t smem_buffer[];
604600
auto tma_buffer = smem_buffer + (thread_id / 32) * kNumTMABytesPerWarp;
605601
#endif
@@ -841,11 +837,10 @@ combine(dtype_t* recv_x, float* recv_topk_weights,
841837
out_dtypes[j] = static_cast<dtype_t>(values[j]);
842838

843839
#ifndef DISABLE_SM90_FEATURES
844-
const int num_participating_threads = min(32, hidden_int4 - i);
845-
const unsigned int warp_mask = set_n_bits(num_participating_threads);
840+
auto const warp_mask = __activemask();
846841

847842
// Wait TMA arrival
848-
if (elect_one_sync())
843+
if (elect_one_sync(warp_mask))
849844
tma_store_wait<kNumStages - 1>();
850845
__syncwarp(warp_mask);
851846

@@ -856,7 +851,7 @@ combine(dtype_t* recv_x, float* recv_topk_weights,
856851
// Issue TMA
857852
tma_store_fence();
858853
__syncwarp(warp_mask);
859-
if (elect_one_sync()) {
854+
if (elect_one_sync(warp_mask)) {
860855
auto tma_bytes = min(32, hidden_int4 - i) * static_cast<int>(sizeof(int4));
861856
tma_store_1d(reinterpret_cast<int4*>(tma_buffer) + tma_stage_idx * 32,
862857
recv_int4 + token_idx * hidden_int4 + i, tma_bytes, false);

csrc/kernels/utils.cuh

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -310,7 +310,7 @@ __forceinline__ __device__ int get_lane_id() {
310310
return lane_id;
311311
}
312312

313-
__device__ __forceinline__ uint32_t elect_one_sync() {
313+
__device__ __forceinline__ uint32_t elect_one_sync(const uint32_t membermask=0xffffffff) {
314314
#ifndef DISABLE_SM90_FEATURES
315315
uint32_t pred = 0;
316316
asm volatile(
@@ -321,7 +321,7 @@ __device__ __forceinline__ uint32_t elect_one_sync() {
321321
"@%%px mov.s32 %0, 1;\n"
322322
"}\n"
323323
: "+r"(pred)
324-
: "r"(0xffffffff));
324+
: "r"(membermask));
325325
return pred;
326326
#else
327327
return get_lane_id() == 0;

0 commit comments

Comments
 (0)