1
1
# SPDX-License-Identifier: Apache-2.0
2
2
""" CUTLASS based Fused MoE kernels."""
3
- import os
4
3
from typing import Optional
5
4
6
5
import torch
@@ -271,8 +270,6 @@ def cutlass_moe_fp8(
271
270
272
271
FLOAT4_E2M1_MAX = scalar_types .float4_e2m1f .max ()
273
272
FLOAT8_E4M3_MAX = torch .finfo (torch .float8_e4m3fn ).max
274
- MAX_TOKENS_PER_EXPERT = int (
275
- os .environ .get ('VLLM_MODELOPT_MAX_TOKENS_PER_EXPERT' , '65536' ))
276
273
277
274
278
275
def cutlass_moe_fp4 (a : torch .Tensor , a1_gscale : torch .Tensor ,
@@ -330,10 +327,7 @@ def cutlass_moe_fp4(a: torch.Tensor, a1_gscale: torch.Tensor,
330
327
assert a .dtype in [torch .half , torch .bfloat16 ], "Invalid input dtype"
331
328
assert (topk_weights .shape [0 ] == m and topk_ids .shape [0 ]
332
329
== m ), ("topk must be provided for each row of a" )
333
- assert (m <= MAX_TOKENS_PER_EXPERT ), (
334
- f"m must be less than MAX_TOKENS_PER_EXPERT({ MAX_TOKENS_PER_EXPERT } )"
335
- f" for cutlass_moe_fp4, observed m = { m } . Use"
336
- f" VLLM_MODELOPT_MAX_TOKENS_PER_EXPERT to set this value." )
330
+
337
331
out_dtype = a .dtype
338
332
num_topk = topk_ids .shape [1 ]
339
333
@@ -362,8 +356,7 @@ def cutlass_moe_fp4(a: torch.Tensor, a1_gscale: torch.Tensor,
362
356
expert_offsets ,
363
357
blockscale_offsets ,
364
358
num_topk ,
365
- expert_map = a_map ,
366
- MAX_TOKENS_PER_EXPERT = MAX_TOKENS_PER_EXPERT )
359
+ expert_map = a_map )
367
360
368
361
c1 = ops .cutlass_fp4_moe_mm (rep_a_fp4 , w1_fp4 , rep_a_blockscale ,
369
362
w1_blockscale , w1_alphas , problem_sizes1 ,
@@ -378,12 +371,7 @@ def cutlass_moe_fp4(a: torch.Tensor, a1_gscale: torch.Tensor,
378
371
torch .ops ._C .silu_and_mul (intermediate , c1 )
379
372
380
373
int_fp4 , int_blockscale = ops .scaled_fp4_experts_quant (
381
- intermediate ,
382
- a2_gscale ,
383
- expert_offsets ,
384
- blockscale_offsets ,
385
- num_topk ,
386
- MAX_TOKENS_PER_EXPERT = MAX_TOKENS_PER_EXPERT )
374
+ intermediate , a2_gscale , expert_offsets , blockscale_offsets , num_topk )
387
375
388
376
c2 = ops .cutlass_fp4_moe_mm (int_fp4 , w2_fp4 , int_blockscale , w2_blockscale ,
389
377
w2_alphas , problem_sizes2 , expert_offsets [:- 1 ],
0 commit comments