@@ -16,7 +16,6 @@ def fused_moe_kernel(
     expert_ids_ptr,
     num_tokens_post_padded_ptr,
     # Matrix dimensions
-    M,
     N,
     K,
     EM,
@@ -86,10 +85,9 @@ def fused_moe_kernel(
     a_ptrs = a_ptr + (offs_token[:, None] // top_k * stride_am +
                       offs_k[None, :] * stride_ak)

-    #
-    off_experts = tl.load(expert_ids_ptr + pid_m) * stride_be
-    b_ptrs = b_ptr + off_experts + (offs_k[:, None] * stride_bk +
-                                    offs_bn[None, :] * stride_bn)
+    off_experts = tl.load(expert_ids_ptr + pid_m)
+    b_ptrs = b_ptr + off_experts * stride_be + (offs_k[:, None] * stride_bk +
+                                                offs_bn[None, :] * stride_bn)

     # -----------------------------------------------------------
     # Iterate to compute a block of the C matrix.
@@ -129,7 +127,7 @@ def fused_moe_kernel(
     tl.store(c_ptrs, accumulator, mask=c_mask)


-def alig_block_size(
+def moe_align_block_size(
         topk_ids: torch.Tensor, block_size: int,
         num_experts: int) -> (torch.Tensor, torch.Tensor, torch.Tensor):
     """
@@ -169,11 +167,48 @@ def alig_block_size(
     num_tokens_post_pad = torch.empty((1),
                                       dtype=torch.int32,
                                       device=topk_ids.device)
-    ops.moe_alig_block_size(topk_ids, num_experts, block_size, sorted_ids,
-                            expert_ids, num_tokens_post_pad)
+    ops.moe_align_block_size(topk_ids, num_experts, block_size, sorted_ids,
+                             expert_ids, num_tokens_post_pad)
     return sorted_ids, expert_ids, num_tokens_post_pad


+def invoke_fused_moe_kernel(A: torch.Tensor, B: torch.Tensor, C: torch.Tensor,
+                            topk_weights: torch.Tensor, topk_ids: torch.Tensor,
+                            sorted_token_ids: torch.Tensor,
+                            expert_ids: torch.Tensor,
+                            num_tokens_post_padded: torch.Tensor,
+                            mul_routed_weight: bool, top_k: int, config: dict):
+    grid = lambda META: (triton.cdiv(sorted_token_ids.shape[0], META[
+        'BLOCK_SIZE_M']) * triton.cdiv(B.shape[1], META['BLOCK_SIZE_N']), )
+
+    fused_moe_kernel[grid](
+        A,
+        B,
+        C,
+        topk_weights,
+        sorted_token_ids,
+        expert_ids,
+        num_tokens_post_padded,
+        B.shape[1],
+        B.shape[2],
+        sorted_token_ids.shape[0],
+        topk_ids.numel(),
+        A.stride(0),
+        A.stride(1),
+        B.stride(0),
+        B.stride(2),
+        B.stride(1),
+        C.stride(1),
+        C.stride(2),
+        topk_weights.stride(1),
+        sorted_token_ids.stride(0),
+        MUL_ROUTED_WEIGHT=mul_routed_weight,
+        top_k=top_k,
+        compute_type=tl.bfloat16 if A.dtype == torch.bfloat16 else tl.float16,
+        **config,
+    )
+
+
 def fused_moe(hidden_states: torch.Tensor,
               w1: torch.Tensor,
               w2: torch.Tensor,
@@ -196,11 +231,12 @@ def fused_moe(hidden_states: torch.Tensor,
     """
     # Check constraints.
     assert hidden_states.shape[1] == w1.shape[2], "Incompatible dimensions"
-    assert hidden_states.is_contiguous(), "Matrix A must be contiguous"
-    assert w1.is_contiguous(), "Matrix B must be contiguous"
+    assert hidden_states.is_contiguous(), "Hidden_states must be contiguous"
+    assert w1.is_contiguous(), "Expert weights1 must be contiguous"
+    assert w2.is_contiguous(), "Expert weights2 must be contiguous"
     assert hidden_states.dtype in [torch.float16, torch.bfloat16]
-    M, K = hidden_states.shape
-    E, N, K = w1.shape
+    M, _ = hidden_states.shape
+    E, N, _ = w1.shape

     config = {
         'BLOCK_SIZE_M': 64,
@@ -227,73 +263,21 @@ def fused_moe(hidden_states: torch.Tensor,
                                       device=hidden_states.device,
                                       dtype=hidden_states.dtype)

-    sorted_token_ids, expert_ids, num_tokens_post_padded = alig_block_size(
+    sorted_token_ids, expert_ids, num_tokens_post_padded = moe_align_block_size(
         topk_ids, config['BLOCK_SIZE_M'], E)
-    # 1D launch kernel where each block gets its own program.
-    grid = lambda META: (triton.cdiv(sorted_token_ids.shape[0], META[
-        'BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']), )

-    fused_moe_kernel[grid](
-        hidden_states,
-        w1,
-        intermediate_cache1,
-        topk_weights,
-        sorted_token_ids,
-        expert_ids,
-        num_tokens_post_padded,
-        M,
-        N,
-        K,
-        sorted_token_ids.shape[0],
-        topk_ids.numel(),
-        hidden_states.stride(0),
-        hidden_states.stride(1),
-        w1.stride(0),
-        w1.stride(2),
-        w1.stride(1),
-        intermediate_cache1.stride(1),
-        intermediate_cache1.stride(2),
-        topk_weights.stride(1),
-        sorted_token_ids.stride(0),
-        MUL_ROUTED_WEIGHT=False,
-        top_k=topk_ids.shape[1],
-        compute_type=tl.bfloat16
-        if hidden_states.dtype == torch.bfloat16 else tl.float16,
-        **config,
-    )
+    invoke_fused_moe_kernel(hidden_states, w1, intermediate_cache1,
+                            topk_weights, topk_ids, sorted_token_ids,
+                            expert_ids, num_tokens_post_padded, False,
+                            topk_ids.shape[1], config)

     ops.silu_and_mul(intermediate_cache2, intermediate_cache1.view(-1, N))

-    grid = lambda META: (triton.cdiv(sorted_token_ids.shape[0], META[
-        'BLOCK_SIZE_M']) * triton.cdiv(w2.shape[1], META['BLOCK_SIZE_N']), )
-    fused_moe_kernel[grid](
-        intermediate_cache2,
-        w2,
-        intermediate_cache3,
-        topk_weights,
-        sorted_token_ids,
-        expert_ids,
-        num_tokens_post_padded,
-        M,
-        w2.shape[1],
-        w2.shape[2],
-        sorted_token_ids.shape[0],
-        topk_ids.numel(),
-        intermediate_cache2.stride(0),
-        intermediate_cache2.stride(1),
-        w2.stride(0),
-        w2.stride(2),
-        w2.stride(1),
-        intermediate_cache3.stride(1),
-        intermediate_cache3.stride(2),
-        topk_weights.stride(1),
-        sorted_token_ids.stride(0),
-        MUL_ROUTED_WEIGHT=True,
-        top_k=1,  #
-        compute_type=tl.bfloat16
-        if hidden_states.dtype == torch.bfloat16 else tl.float16,
-        **config,
-    )
+    invoke_fused_moe_kernel(intermediate_cache2, w2, intermediate_cache3,
+                            topk_weights, topk_ids, sorted_token_ids,
+                            expert_ids, num_tokens_post_padded, True, 1,
+                            config)
+
     if inplace:
         return torch.sum(intermediate_cache3.view(*intermediate_cache3.shape),
                          dim=1,
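
For readers without the CUDA op at hand, below is a rough pure-PyTorch sketch of the contract that moe_align_block_size relies on: token slots are grouped by expert and each group is padded up to a multiple of block_size, so that every BLOCK_SIZE_M block processed by fused_moe_kernel reads the weights of exactly one expert. This is an illustration only, not the actual ops.moe_align_block_size implementation; the padding sentinel value (topk_ids.numel()) and the compact output sizes are assumptions here, whereas the real op writes into the pre-allocated, fixed-size buffers passed to it.

import torch


def moe_align_block_size_reference(topk_ids: torch.Tensor, block_size: int,
                                   num_experts: int):
    # Flatten the (num_tokens, top_k) routing table into "token slots";
    # the kernel later maps a slot back to its token via `offs_token // top_k`.
    flat = topk_ids.flatten()
    num_slots = flat.numel()  # assumed padding sentinel: one past the last valid slot
    sorted_chunks, expert_chunks = [], []
    for e in range(num_experts):
        # Slots routed to expert e, in their original order.
        slots = torch.nonzero(flat == e, as_tuple=False).flatten()
        # Pad this expert's slot list to a multiple of block_size.
        pad = (-slots.numel()) % block_size
        slots = torch.cat([slots, slots.new_full((pad, ), num_slots)])
        sorted_chunks.append(slots)
        # One expert id per BLOCK_SIZE_M-sized block of slots.
        expert_chunks.append(
            torch.full((slots.numel() // block_size, ), e, dtype=torch.int32))
    sorted_ids = torch.cat(sorted_chunks).to(torch.int32)
    expert_ids = torch.cat(expert_chunks)
    num_tokens_post_pad = torch.tensor([sorted_ids.numel()], dtype=torch.int32)
    return sorted_ids, expert_ids, num_tokens_post_pad

For example, topk_ids = [[0, 2], [1, 2]] with block_size=4 and num_experts=3 yields one padded block per expert, and expert_ids then tells each kernel program which expert's weight slice to load via stride_be.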