@@ -596,10 +596,6 @@ combine(dtype_t* recv_x, float* recv_topk_weights,
596596
597597 // TMA stuffs
598598#ifndef DISABLE_SM90_FEATURES
599- auto set_n_bits = [](const int x) -> uint32_t {
600- static constexpr uint64_t one = 1 ;
601- return static_cast <uint32_t >((one << x) - one);
602- };
603599 extern __shared__ __align__ (1024 ) uint8_t smem_buffer[];
604600 auto tma_buffer = smem_buffer + (thread_id / 32 ) * kNumTMABytesPerWarp ;
605601#endif
@@ -841,11 +837,10 @@ combine(dtype_t* recv_x, float* recv_topk_weights,
841837 out_dtypes[j] = static_cast <dtype_t >(values[j]);
842838
843839#ifndef DISABLE_SM90_FEATURES
844- const int num_participating_threads = min (32 , hidden_int4 - i);
845- const unsigned int warp_mask = set_n_bits (num_participating_threads);
840+ auto const warp_mask = __activemask ();
846841
847842 // Wait TMA arrival
848- if (elect_one_sync ())
843+ if (elect_one_sync (warp_mask ))
849844 tma_store_wait<kNumStages - 1 >();
850845 __syncwarp (warp_mask);
851846
@@ -856,7 +851,7 @@ combine(dtype_t* recv_x, float* recv_topk_weights,
856851 // Issue TMA
857852 tma_store_fence ();
858853 __syncwarp (warp_mask);
859- if (elect_one_sync ()) {
854+ if (elect_one_sync (warp_mask )) {
860855 auto tma_bytes = min (32 , hidden_int4 - i) * static_cast <int >(sizeof (int4 ));
861856 tma_store_1d (reinterpret_cast <int4 *>(tma_buffer) + tma_stage_idx * 32 ,
862857 recv_int4 + token_idx * hidden_int4 + i, tma_bytes, false );
0 commit comments