8181 __VA_ARGS__ \
8282 }
8383
84- #define DISPATCH_NUM_FRAGS_X (num_frags_x, NUM_FRAGS_X, ...) \
85- if (num_frags_x == 1 ) { \
86- constexpr size_t NUM_FRAGS_X = 1 ; \
87- __VA_ARGS__ \
88- } else if (num_frags_x == 2 ) { \
89- constexpr size_t NUM_FRAGS_X = 2 ; \
90- __VA_ARGS__ \
91- } else { \
92- std::cerr << " Unsupported num_frags_x: " << num_frags_x << std::endl; \
84+ #define DISPATCH_NUM_FRAGS_X (num_frags_x, NUM_FRAGS_X, ...) \
85+ if (num_frags_x == 1 ) { \
86+ constexpr size_t NUM_FRAGS_X = 1 ; \
87+ __VA_ARGS__ \
88+ } else if (num_frags_x == 2 ) { \
89+ constexpr size_t NUM_FRAGS_X = 2 ; \
90+ __VA_ARGS__ \
91+ } else { \
92+ std::ostringstream err_msg; \
93+ err_msg << " Unsupported num_frags_x: " << num_frags_x; \
94+ throw std::invalid_argument (err_msg.str ()); \
9395 }
9496
95- #define DISPATCH_NUM_FRAGS_Z (max_frags_z, NUM_FRAGS_Z, ...) \
96- if (max_frags_z == 4 ) { \
97- constexpr size_t NUM_FRAGS_Z = 4 ; \
98- __VA_ARGS__ \
99- } else if (max_frags_z == 2 ) { \
100- constexpr size_t NUM_FRAGS_Z = 2 ; \
101- __VA_ARGS__ \
102- } else { \
103- std::cerr << " Unsupported max_frags_z: " << max_frags_z << std::endl; \
97+ #define DISPATCH_NUM_FRAGS_Z (max_frags_z, NUM_FRAGS_Z, ...) \
98+ if (max_frags_z >= 4 ) { \
99+ constexpr size_t NUM_FRAGS_Z = 4 ; \
100+ __VA_ARGS__ \
101+ } else if (max_frags_z >= 2 ) { \
102+ constexpr size_t NUM_FRAGS_Z = 2 ; \
103+ __VA_ARGS__ \
104+ } else if (max_frags_z >= 1 ) { \
105+ constexpr size_t NUM_FRAGS_Z = 1 ; \
106+ __VA_ARGS__ \
107+ } else { \
108+ std::ostringstream err_msg; \
109+ err_msg << " Unsupported max_frags_z: " << max_frags_z; \
110+ throw std::invalid_argument (err_msg.str ()); \
104111 }
105112
106- #define DISPATCH_GQA_GROUP_SIZE (group_size, GROUP_SIZE, ...) \
107- if (group_size == 1 ) { \
108- constexpr size_t GROUP_SIZE = 1 ; \
109- __VA_ARGS__ \
110- } else if (group_size == 4 ) { \
111- constexpr size_t GROUP_SIZE = 4 ; \
112- __VA_ARGS__ \
113- } else if (group_size == 8 ) { \
114- constexpr size_t GROUP_SIZE = 8 ; \
115- __VA_ARGS__ \
116- } else { \
117- std::cerr << " Unsupported group_size: " << group_size << std::endl; \
113+ #define DISPATCH_GQA_GROUP_SIZE (group_size, GROUP_SIZE, ...) \
114+ if (group_size == 1 ) { \
115+ constexpr size_t GROUP_SIZE = 1 ; \
116+ __VA_ARGS__ \
117+ } else if (group_size == 4 ) { \
118+ constexpr size_t GROUP_SIZE = 4 ; \
119+ __VA_ARGS__ \
120+ } else if (group_size == 8 ) { \
121+ constexpr size_t GROUP_SIZE = 8 ; \
122+ __VA_ARGS__ \
123+ } else { \
124+ std::ostringstream err_msg; \
125+ err_msg << " Unsupported group_size: " << group_size; \
126+ throw std::invalid_argument (err_msg.str ()); \
118127 }
119128
120129#define DISPATCH_CAUSAL (causal, CAUSAL, ...) \
169178 } \
170179 }
171180
172- #define DISPATCH_HEAD_DIM_PREFILL (head_dim, HEAD_DIM, ...) \
173- switch (head_dim) { \
174- case 64 : { \
175- constexpr size_t HEAD_DIM = 64 ; \
176- __VA_ARGS__ \
177- break ; \
178- } \
179- case 128 : { \
180- constexpr size_t HEAD_DIM = 128 ; \
181- __VA_ARGS__ \
182- break ; \
183- } \
184- default : { \
185- std::ostringstream err_msg; \
186- err_msg << " Unsupported head_dim: " << head_dim; \
187- throw std::invalid_argument (err_msg.str ()); \
188- } \
189- }
190-
191181#define DISPATCH_ROTARY_MODE (rotary_mode, ROTARY_MODE, ...) \
192182 switch (rotary_mode) { \
193183 case RotaryMode::kNone : { \
@@ -222,7 +212,7 @@ __forceinline__ __device__ __host__ T1 ceil_div(const T1 x, const T2 y) {
222212
223213template <typename IdType>
224214std::tuple<IdType, IdType, std::vector<IdType>, std::vector<IdType>> split_qo_indptr (
225- IdType* qo_indptr, uint32_t batch_size, uint32_t gqa_group_size,
215+ IdType* qo_indptr, uint32_t batch_size, uint32_t gqa_group_size, uint32_t head_dim,
226216 cudaStream_t stream = nullptr ) {
227217 constexpr uint32_t num_warps = 4 ;
228218 std::vector<IdType> qo_indptr_h (batch_size + 1 ), request_indices, tile_indices;
@@ -235,7 +225,7 @@ std::tuple<IdType, IdType, std::vector<IdType>, std::vector<IdType>> split_qo_in
235225
236226 const uint32_t total_q_len = qo_indptr_h[batch_size];
237227 const bool avg_len_greater_than_64 = total_q_len * gqa_group_size > 64 * batch_size;
238- const uint32_t num_frags_x = avg_len_greater_than_64 ? 2 : 1 ;
228+ const uint32_t num_frags_x = (head_dim < 256 && avg_len_greater_than_64) ? 2 : 1 ;
239229 const uint32_t num_rows_per_cta = num_frags_x * num_warps * 16 ;
240230 uint32_t num_qo_tiles = 0 ;
241231
0 commit comments