
Commit 0d4e26b

Add GeGLU support to trtllm-gen NVFP4 Fused MoE Kernel (#1525)
## πŸ“Œ Description

Added GeGLU support to the trtllm-gen NVFP4 Fused MoE kernels. Also added TopK routing.

## πŸš€ Pull Request Checklist

### βœ… Pre-commit Checks

- [x] I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method).
- [x] I have installed the hooks with `pre-commit install`.
- [x] I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues.

## πŸ§ͺ Tests

- [x] Tests have been added or updated as needed.
- [x] All tests are passing (`unittest`, etc.).

## Collaborator

@azhurkevich
1 parent 669ff33 commit 0d4e26b

23 files changed (+497 / -229 lines)

benchmarks/README.md

Lines changed: 4 additions & 0 deletions

@@ -45,6 +45,9 @@ python3 flashinfer_benchmark.py --routine mm_fp4 --m 8192 --n 4096 --k 16384 --o
 # MOE FP4 Block Scale (DeepSeekV3 routing)
 python3 flashinfer_benchmark.py --routine trtllm_fp4_block_scale_moe --num_tokens 1024 --hidden_size 1024 --intermediate_size 1024 --num_experts 128 --top_k 8 --n_group 8 --topk_group 4 --routed_scaling_factor 2.5 --use_routing_bias --routing_method deepseek_v3 --use_shuffled_weight --verbose --generate_repro_command

+# MOE FP4 Block Scale (topk routing, GeGlu gated act)
+python3 flashinfer_benchmark.py --routine trtllm_fp4_block_scale_moe --num_tokens 1024 --hidden_size 1024 --intermediate_size 1024 --num_experts 128 --top_k 8 --routing_method topk --use_shuffled_weight --gated_act geglu --verbose --generate_repro_command
+
 # MOE FP8 Block Scale with DeepSeekV3 routing
 python3 flashinfer_benchmark.py --routine trtllm_fp8_block_scale_moe --num_tokens 1024 --hidden_size 1024 --intermediate_size 1024 --num_experts 128 --top_k 8 --n_group 8 --topk_group 4 --routed_scaling_factor 2.5 --use_routing_bias --routing_method deepseek_v3 --use_shuffled_weight --verbose --generate_repro_command

@@ -148,6 +151,7 @@ The output CSV will contain detailed metrics including:
 | `--tp_rank` | Tensor-parallel rank |
 | `--ep_size` | Expert-parallel world size |
 | `--ep_rank` | Expert-parallel rank |
+| `--gated_act` | Gated activation function: `swiglu` (default) or `geglu` |

 ### MOE Routing Method Compatibility

benchmarks/bench_trtllm_gen_fused_moe_autotuner.py

Lines changed: 4 additions & 1 deletion

@@ -3,6 +3,8 @@
 import torch
 import numpy as np
 from flashinfer import (
+    RoutingMethodType,
+    GatedActType,
     fp4_quantize,
     mxfp8_quantize,
     next_positive_power_of_2,

@@ -156,9 +158,10 @@ def bench_trtllm_gen_fused_moe_autotuner(
         num_experts,
         None,  # routed_scaling_factor
         tile_tokens_dim,
-        1,
+        RoutingMethodType.Renormalize.value[0],
         True,
         enable_pdl,
+        GatedActType.SwiGlu.value,  # gated_act_type
         None,
         num_tokens if tune_max_num_tokens is None else tune_max_num_tokens,
     )

benchmarks/routines/flashinfer_benchmark_utils.py

Lines changed: 1 addition & 0 deletions

@@ -59,6 +59,7 @@
     "use_routing_scales_on_input",
     "input_dtype",
     "weight_dtype",
+    "gated_act",
     # CUTLASS fused MoE specific
     "cutlass_variant",
     "quantized_input",

benchmarks/routines/moe.py

Lines changed: 25 additions & 5 deletions

@@ -132,9 +132,10 @@ def parse_moe_args(line, parser):
             "deepseek_v3",
             "llama4",
             "renormalize_naive",
+            "topk",
         ],
         help=(
-            "Routing method: renormalize | deepseek_v3 | llama4 | renormalize_naive."
+            "Routing method: renormalize | deepseek_v3 | llama4 | renormalize_naive | topk."
         ),
     )
     parser.add_argument(

@@ -177,6 +178,14 @@ def parse_moe_args(line, parser):
         default="bfloat16",
         help="Data type of the weights (before quantization).",
     )
+    parser.add_argument(
+        "--gated_act",
+        type=str,
+        required=False,
+        default="swiglu",
+        choices=["swiglu", "geglu"],
+        help="Type of gated activation function: swiglu | geglu.",
+    )

     # CUTLASS fused MoE specific
     parser.add_argument(

@@ -225,13 +234,22 @@ def parse_moe_args(line, parser):
     args = parser.parse_args(line)

     # Normalize routing method (map string to internal int expected by kernels)
-    name_to_type = {
+    routing_method_name_to_type = {
         "renormalize": 1,
         "deepseek_v3": 2,
         "llama4": 3,
         "renormalize_naive": 4,
+        "topk": 5,
     }
-    args.routing_method_type = name_to_type[args.routing_method]
+    args.routing_method_type = routing_method_name_to_type[args.routing_method]
+
+    # Normalize gated act type (map string to internal int expected by kernels)
+    gated_act_name_to_type = {
+        "swiglu": 0,
+        "geglu": 1,
+    }
+    args.gated_act_type = gated_act_name_to_type[args.gated_act]
+
     if args.verbose >= 1:
         print(f"[INFO] {args = }")
     return args

@@ -451,8 +469,7 @@ def get_effective_bytes(dtype: torch.dtype, fmt: Optional[str]) -> float:
     if active_experts is not None:
         num_active_experts = active_experts
     else:
-        # CUTLASS MoE does not support active_experts, so we return -1
-        return -1
+        num_active_experts = min(num_experts, top_k * num_tokens)
     weight_bytes = num_active_experts * weight_bytes_per_expert

     # Output memory (typically full precision)

@@ -539,6 +556,7 @@ def testTrtllmFp4BlockScaleMoe(args):
     use_shuffled_weight = args.use_shuffled_weight
     weight_layout = args.weight_layout
     is_cuda_graph_compatible = not args.no_cuda_graph
+    gated_act_type = args.gated_act_type

     if args.verbose >= 1:
         print(

@@ -669,6 +687,7 @@ def run_fp4_moe():
             routed_scaling_factor=routed_scaling_factor,
             tile_tokens_dim=tile_tokens_dim,
             routing_method_type=routing_method_type,
+            gated_act_type=gated_act_type,
             do_finalize=True,
         )

@@ -745,6 +764,7 @@ def run_fp4_moe():
     cur_res["use_routing_scales_on_input"] = args.use_routing_scales_on_input
     cur_res["input_dtype"] = input_dtype
     cur_res["weight_dtype"] = weight_dtype
+    cur_res["gated_act"] = args.gated_act
     res.append(cur_res)

     return res
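For readers skimming the benchmark changes, the new plumbing reduces to two string-to-integer lookup tables plus a simple active-expert estimate. A minimal standalone sketch of that logic (names and layout are illustrative, not the benchmark's actual structure):

```python
# Sketch of the benchmark-side normalization added above (illustrative only).
ROUTING_METHOD_NAME_TO_TYPE = {
    "renormalize": 1,
    "deepseek_v3": 2,
    "llama4": 3,
    "renormalize_naive": 4,
    "topk": 5,  # new in this change
}
GATED_ACT_NAME_TO_TYPE = {"swiglu": 0, "geglu": 1}


def estimate_active_experts(num_experts: int, top_k: int, num_tokens: int) -> int:
    # Each token activates top_k experts, so at most top_k * num_tokens distinct
    # experts receive work, capped by the total expert count.
    return min(num_experts, top_k * num_tokens)


assert ROUTING_METHOD_NAME_TO_TYPE["topk"] == 5
assert GATED_ACT_NAME_TO_TYPE["geglu"] == 1
assert estimate_active_experts(num_experts=128, top_k=8, num_tokens=4) == 32
```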

csrc/trtllm_batched_gemm_runner.cu

Lines changed: 6 additions & 4 deletions

@@ -103,8 +103,10 @@ TrtllmGenBatchedGemmRunner::TrtllmGenBatchedGemmRunner(
         tileSize == mOptions.tileSize &&
         options.mUseShuffledMatrixA == mOptions.useShuffledMatrixA &&
         options.mLayoutA == mOptions.weightLayout) {
-      // FIXME: Disable split-k for now.
-      if (options.mClusterDimZ != 1) {
+      // FIXME: Disable split-k for swiglu for now.
+      if (static_cast<batchedGemm::gemmGatedAct::ActType>(mOptions.actType) ==
+              batchedGemm::gemmGatedAct::ActType::SwiGlu &&
+          options.mClusterDimZ != 1) {
        continue;
      }

@@ -213,8 +215,8 @@ void TrtllmGenBatchedGemmRunner::run(
  gemmData.mInputBuffers.mPtrPerTokenSfB =
      mOptions.transposeMmaOutput ? perTokensSfA : perTokensSfB;
  gemmData.mInputBuffers.mPtrBias = ptrBias;
-  gemmData.mInputBuffers.mPtrSwiGluAlpha = ptrAlpha;
-  gemmData.mInputBuffers.mPtrSwiGluBeta = ptrBeta;
+  gemmData.mInputBuffers.mPtrGatedActAlpha = ptrAlpha;
+  gemmData.mInputBuffers.mPtrGatedActBeta = ptrBeta;
  gemmData.mInputBuffers.mPtrClampLimit = ptrClampLimit;

  gemmData.mInputBuffers.mPtrRouteMap = routeMap;

csrc/trtllm_fused_moe_kernel_launcher.cu

Lines changed: 13 additions & 9 deletions

@@ -39,6 +39,7 @@
 namespace flashinfer {

 namespace btg = batchedGemm::trtllm::gen;
+using tensorrt_llm::kernels::trtllmgen_moe::MoE::GatedActType;
 using tensorrt_llm::kernels::trtllmgen_moe::Routing::RoutingMethodType;

 at::Tensor trtllm_fp8_per_tensor_scale_moe_launcher(

@@ -732,10 +733,11 @@ std::vector<at::Tensor> trtllm_fp4_block_scale_moe_launcher(
   } else if (static_cast<RoutingMethodType>(routing_method_type) ==
                  RoutingMethodType::Renormalize ||
              static_cast<RoutingMethodType>(routing_method_type) ==
-                 RoutingMethodType::RenormalizeNaive) {
+                 RoutingMethodType::RenormalizeNaive ||
+             static_cast<RoutingMethodType>(routing_method_type) == RoutingMethodType::TopK) {
     TORCH_CHECK(
         top_k <= 8 && top_k > 0,
-        "Current routing kernel (no groups, renormalize) only supports top_k<=8 && top_k>0.");
+        "Current routing kernel (no groups, renormalize/topk) only supports top_k<=8 && top_k>0.");
   } else if (static_cast<RoutingMethodType>(routing_method_type) == RoutingMethodType::Llama4) {
     TORCH_CHECK(top_k == 1, "Current routing kernel (no groups, Llama4) only supports top_k=1.");
   }

@@ -1058,8 +1060,8 @@ std::vector<at::Tensor> trtllm_fp4_block_scale_moe(
     std::optional<int64_t> n_group, std::optional<int64_t> topk_group, int64_t intermediate_size,
     int64_t local_expert_offset, int64_t local_num_experts,
     std::optional<double> routed_scaling_factor, int64_t tile_tokens_dim,
-    int64_t routing_method_type, bool do_finalize, bool enable_pdl, at::Tensor& output,
-    int64_t config_index) {
+    int64_t routing_method_type, bool do_finalize, bool enable_pdl, int64_t gated_act_type,
+    at::Tensor& output, int64_t config_index) {
   using RunnerType = tensorrt_llm::kernels::trtllmgen_moe::MoE::Runner;

   int const num_tokens = hidden_states.sizes()[0];

@@ -1110,7 +1112,7 @@ std::vector<at::Tensor> trtllm_fp4_block_scale_moe(
   // Properly initialize the runner using make_unique like in the original code
   auto mRunner = std::make_unique<RunnerType>(
       mDtypeAct, mDtypeWeights, mUseDeepSeekFp8, (int32_t)tile_tokens_dim,
-      tensorrt_llm::kernels::ActType::SwiGlu, /*useShuffledMatrixA*/ true);
+      static_cast<GatedActType>(gated_act_type), /*useShuffledMatrixA*/ true);

   if (config_index == -1) {
     config_index = mRunner->getDefaultValidConfigIndex(top_k, hidden_size, intermediate_size,

@@ -1131,25 +1133,27 @@ int64_t trtllm_get_default_moe_configs(int64_t const tile_tokens_dim, int64_t co
                                        int64_t const dtype_weights_, bool const useDeepSeekFp8,
                                        int64_t const top_k, int64_t const hidden_size,
                                        int64_t const intermediate_size,
-                                       int64_t const num_local_experts, int64_t const num_tokens) {
+                                       int64_t const num_local_experts,
+                                       int64_t const gated_act_type, int64_t const num_tokens) {
   auto dtype_act = static_cast<btg::Dtype>(dtype_act_);
   auto dtype_weights = static_cast<btg::Dtype>(dtype_weights_);
   tensorrt_llm::kernels::trtllmgen_moe::MoE::Runner moe_runner(
       dtype_act, dtype_weights, useDeepSeekFp8, (int32_t)tile_tokens_dim,
-      tensorrt_llm::kernels::ActType::SwiGlu, /*useShuffledMatrixA*/ true);
+      static_cast<GatedActType>(gated_act_type), /*useShuffledMatrixA*/ true);
   return moe_runner.getDefaultValidConfigIndex(top_k, hidden_size, intermediate_size,
                                                num_local_experts, num_tokens);
 }

 std::vector<int64_t> trtllm_get_valid_moe_configs(
     int64_t const tile_tokens_dim, int64_t const dtype_act_, int64_t const dtype_weights_,
     bool const useDeepSeekFp8, int64_t const top_k, int64_t const hidden_size,
-    int64_t const intermediate_size, int64_t const num_local_experts, int64_t const num_tokens) {
+    int64_t const intermediate_size, int64_t const num_local_experts, int64_t const gated_act_type,
+    int64_t const num_tokens) {
   auto dtype_act = static_cast<btg::Dtype>(dtype_act_);
   auto dtype_weights = static_cast<btg::Dtype>(dtype_weights_);
   tensorrt_llm::kernels::trtllmgen_moe::MoE::Runner moe_runner(
       dtype_act, dtype_weights, useDeepSeekFp8, (int32_t)tile_tokens_dim,
-      tensorrt_llm::kernels::ActType::SwiGlu, /*useShuffledMatrixA*/ true);
+      static_cast<GatedActType>(gated_act_type), /*useShuffledMatrixA*/ true);
   return moe_runner.getValidConfigIndices(top_k, hidden_size, intermediate_size, num_local_experts,
                                           num_tokens);
 }
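The validation change above is easy to miss: the group-free top_k check now also covers the new TopK method. A rough Python restatement of the branches visible in this hunk (illustrative only; the real checks are the TORCH_CHECK calls in the C++ launcher, and the DeepSeekV3 branch is not shown here):

```python
def check_top_k(routing_method: str, top_k: int) -> None:
    # Mirrors the launcher's validation for the group-free routing kernels.
    if routing_method in ("renormalize", "renormalize_naive", "topk"):
        assert 0 < top_k <= 8, "renormalize/topk routing supports only 0 < top_k <= 8"
    elif routing_method == "llama4":
        assert top_k == 1, "Llama4 routing supports only top_k == 1"


check_top_k("topk", 8)    # passes
check_top_k("llama4", 1)  # passes
```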

csrc/trtllm_fused_moe_routing_renormalize.cu

Lines changed: 13 additions & 7 deletions

@@ -32,7 +32,8 @@ __forceinline__ __device__ void routingTopKExperts(
     cg::thread_block_tile<WarpSize> const& warp, DataType (&score)[VecSize],
     int32_t (&idx)[VecSize], DataType (&warpTopKScore)[MaxNumTopExperts],
     int32_t (&warpTopKExpertIdx)[MaxNumTopExperts], int32_t const laneIdx, int32_t const numExperts,
-    int32_t topK, InputType const* ptrScores, bool const normTopkProb) {
+    int32_t topK, InputType const* ptrScores, bool const normTopkProb,
+    bool const applySoftmaxAfterTopK) {
   DataType minScore = DataType{-INFINITY};

   for (int i = 0; i < VecSize; i++) {

@@ -59,11 +60,14 @@
       warpTopKScore[laneIdx] = warpTopKScore[laneIdx] / sum;
     }
   } else {
-    auto softmaxScore =
-        calcSoftmax(warp, laneIdx < topK ? warpTopKScore[laneIdx] : minScore, laneIdx, topK);
-    if (laneIdx < topK) {
-      warpTopKScore[laneIdx] = softmaxScore;
+    if (applySoftmaxAfterTopK) {
+      auto softmaxScore =
+          calcSoftmax(warp, laneIdx < topK ? warpTopKScore[laneIdx] : minScore, laneIdx, topK);
+      if (laneIdx < topK) {
+        warpTopKScore[laneIdx] = softmaxScore;
+      }
     }
+    // If applySoftmaxAfterTopK is false, we keep the raw TopK values without softmax
   }
 }

@@ -113,7 +117,8 @@ __global__ void __cluster_dims__(NumBlocksPerCluster, 1, 1) __launch_bounds__(Nu
     if (validToken) {
       routingTopKExperts<BaseType, InputT, VecSize, KernelParams::DoSoftmaxBeforeTopK>(
           warp, score, idx, warpTopKScore, warpTopKExpertIdx, laneIdx, params.mNumExperts,
-          params.mTopK, params.mPtrScores + scoreOffset, params.mNormTopkProb);
+          params.mTopK, params.mPtrScores + scoreOffset, params.mNormTopkProb,
+          params.mApplySoftmaxAfterTopK);

       if (laneIdx < params.mTopK) {
         smemPackedScoreIdx[warpIdx * params.mTopK + laneIdx] =

@@ -205,7 +210,8 @@ __global__ void __launch_bounds__(NumThreadsHist)

     routingTopKExperts<BaseType, InputT, VecSize, KernelParams::DoSoftmaxBeforeTopK>(
         warp, allScores, allExpertIdx, warpTopKScore, warpTopKExpertIdx, laneIdx,
-        params.mNumExperts, params.mTopK, params.mPtrScores + scoreOffset, params.mNormTopkProb);
+        params.mNumExperts, params.mTopK, params.mPtrScores + scoreOffset, params.mNormTopkProb,
+        params.mApplySoftmaxAfterTopK);

     if (laneIdx < params.mTopK) {
       PackedScoreIdx<OutputT> packedScore{static_cast<OutputT>(warpTopKScore[laneIdx]),
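Taken together with the runner change below, the kernel now distinguishes three behaviors after the warp-level top-k: RenormalizeNaive normalizes softmax probabilities over the selected experts, Renormalize applies a softmax over just the selected scores, and the new TopK mode keeps the raw scores. A host-side PyTorch reference of those semantics as they read from this diff (a sketch, not the warp-level kernel):

```python
import torch


def reference_routing(logits: torch.Tensor, top_k: int, method: str):
    """Unfused reference of the renormalize-family routing variants.

    logits: [num_tokens, num_experts] router scores.
    """
    if method == "renormalize_naive":  # DoSoftmaxBeforeTopK + NormTopkProb
        probs = torch.softmax(logits, dim=-1)
        vals, idx = torch.topk(probs, top_k, dim=-1)
        vals = vals / vals.sum(dim=-1, keepdim=True)  # renormalize selected probs
    elif method == "renormalize":      # ApplySoftmaxAfterTopK
        vals, idx = torch.topk(logits, top_k, dim=-1)
        vals = torch.softmax(vals, dim=-1)            # softmax over the k winners only
    elif method == "topk":             # new: raw top-k scores, no softmax
        vals, idx = torch.topk(logits, top_k, dim=-1)
    else:
        raise ValueError(f"unsupported method: {method}")
    return vals, idx


weights, experts = reference_routing(torch.randn(4, 128), top_k=8, method="topk")
```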

csrc/trtllm_fused_moe_runner.cu

Lines changed: 35 additions & 25 deletions

@@ -122,8 +122,9 @@ void Runner::run(void* routingLogits, void* routingBias, int32_t numTokens, int3
     routingData.mLocalExpertsStrideLog2 = 0;
     routingData.mNumLocalExperts = localNumExperts;
     moe::dev::routing::routingLlama4::run(routingData, stream);
-  } else if (routingMethodType == RoutingMethodType::Renormalize /* default */
-             || routingMethodType == RoutingMethodType::RenormalizeNaive /* Softmax -> TopK */) {
+  } else if (routingMethodType == RoutingMethodType::Renormalize /* default */
+             || routingMethodType == RoutingMethodType::RenormalizeNaive /* Softmax -> TopK */
+             || routingMethodType == RoutingMethodType::TopK /* TopK only (no softmax) */) {
     moe::dev::routing::routingRenormalize::Data routingData;

     //

@@ -135,6 +136,7 @@ void Runner::run(void* routingLogits, void* routingBias, int32_t numTokens, int3
     routingData.mUsePdl = true;
     routingData.mDoSoftmaxBeforeTopK = routingMethodType == RoutingMethodType::RenormalizeNaive;
     routingData.mNormTopkProb = routingMethodType == RoutingMethodType::RenormalizeNaive;
+    routingData.mApplySoftmaxAfterTopK = routingMethodType == RoutingMethodType::Renormalize;

     routingData.mPtrScores = routingLogits;

@@ -178,33 +180,41 @@ namespace PermuteGemm1 {

 tensorrt_llm::kernels::TrtllmGenBatchedGemmRunnerOptions getOptions(
     btg::Dtype dtypeAct, btg::Dtype dtypeWeights, int32_t tileTokensDim, bool useDeepSeekFp8,
-    ActType actType, bool useShuffledMatrixA, batchedGemm::gemm::MatrixLayout weightLayout) {
-  tensorrt_llm::kernels::TrtllmGenBatchedGemmRunnerOptions options = {
-      // Swap A and B dtypes because transposeMmaOutput is hardcoded to true
-      .dtypeA = dtypeWeights,
-      .dtypeB = dtypeAct,
-      .dtypeC = dtypeAct,
-      .actType = actType,
-      .deepSeekFp8 = useDeepSeekFp8,
-      .fusedAct = !useDeepSeekFp8,
-      .routeAct = true,
-      .staticBatch = false,
-      .transposeMmaOutput = true,
-      .tileSize = tileTokensDim,
-      .epilogueTileM = useDeepSeekFp8 ? 64 : 128,
-      .useShuffledMatrixA = useShuffledMatrixA,
-      .weightLayout = weightLayout};
-  return options;
+    MoE::GatedActType gatedActType, bool useShuffledMatrixA,
+    batchedGemm::gemm::MatrixLayout weightLayout) {
+  if (gatedActType == MoE::GatedActType::SwiGlu || gatedActType == MoE::GatedActType::GeGlu) {
+    ActType actType =
+        (gatedActType == MoE::GatedActType::SwiGlu) ? ActType::SwiGlu : ActType::GeGlu;
+    tensorrt_llm::kernels::TrtllmGenBatchedGemmRunnerOptions options = {
+        // Swap A and B dtypes because transposeMmaOutput is hardcoded to true
+        .dtypeA = dtypeWeights,
+        .dtypeB = dtypeAct,
+        .dtypeC = dtypeAct,
+        .actType = actType,
+        .deepSeekFp8 = useDeepSeekFp8,
+        .fusedAct = !useDeepSeekFp8,
+        .routeAct = true,
+        .staticBatch = false,
+        .transposeMmaOutput = true,
+        .tileSize = tileTokensDim,
+        .epilogueTileM = useDeepSeekFp8 ? 64 : 128,
+        .useShuffledMatrixA = useShuffledMatrixA,
+        .weightLayout = weightLayout};
+    return options;
+  } else {
+    TORCH_CHECK(false, "Unimplemented gated act type %s of enum %d",
+                MoE::serializeGatedActType(gatedActType).c_str(), (int)gatedActType);
+  }
 }

 Runner::Runner(btg::Dtype dtypeAct, btg::Dtype dtypeWeights, bool useDeepSeekFp8, int tileTokensDim,
-               ActType actType, bool useShuffledMatrixA,
+               MoE::GatedActType gatedActType, bool useShuffledMatrixA,
                batchedGemm::gemm::MatrixLayout weightLayout)
     : mDtypeAct(dtypeAct),
       mDtypeWeights(dtypeWeights),
       mTileTokensDim(tileTokensDim),
       mRunner(tensorrt_llm::kernels::TrtllmGenBatchedGemmRunner(
-          getOptions(mDtypeAct, mDtypeWeights, mTileTokensDim, useDeepSeekFp8, actType,
+          getOptions(mDtypeAct, mDtypeWeights, mTileTokensDim, useDeepSeekFp8, gatedActType,
                      useShuffledMatrixA, weightLayout))) {}

 void Runner::run(void* hiddenState, void* hiddenStateScale, void* weights, void* weightsScale,

@@ -352,10 +362,10 @@ std::vector<int64_t> Runner::getPassingConfigIndices() const {

 namespace MoE {
 Runner::Runner(btg::Dtype dtypeAct, btg::Dtype dtypeWeights, bool useDeepSeekFp8,
-               int32_t tileTokensDim, ActType actType, bool useShuffledMatrixA,
+               int32_t tileTokensDim, GatedActType gatedActType, bool useShuffledMatrixA,
                batchedGemm::gemm::MatrixLayout weightLayout)
     : mPermuteGemm1(PermuteGemm1::Runner(dtypeAct, dtypeWeights, useDeepSeekFp8, tileTokensDim,
-                                         actType, useShuffledMatrixA, weightLayout)),
+                                         gatedActType, useShuffledMatrixA, weightLayout)),
       mGemm2(Gemm2::Runner(dtypeAct, dtypeWeights, btg::Dtype::Bfloat16, useDeepSeekFp8,
                            tileTokensDim, useShuffledMatrixA, weightLayout)) {
   auto const& gemm1PassingIndices = mPermuteGemm1.getPassingConfigIndices();

@@ -375,8 +385,8 @@ Runner::Runner(btg::Dtype dtypeAct, btg::Dtype dtypeWeights, bool useDeepSeekFp8

 Runner::Runner(btg::Dtype dtypeElt, bool useDeepSeekFp8, int32_t tileTokensDim,
                bool useShuffledMatrixA, batchedGemm::gemm::MatrixLayout weightLayout)
-    : Runner(dtypeElt, dtypeElt, useDeepSeekFp8, tileTokensDim, ActType::SwiGlu, useShuffledMatrixA,
-             weightLayout) {}
+    : Runner(dtypeElt, dtypeElt, useDeepSeekFp8, tileTokensDim, GatedActType::SwiGlu,
+             useShuffledMatrixA, weightLayout) {}

 void Runner::setOpsData(MoERunnerArgs const& args, MoEWorkspace const& workspace,
                         moe::dev::convertsf::Data& convertSfData,
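getOptions now translates MoE::GatedActType into the batched-GEMM ActType, so GeGlu flows through the same fused epilogue path as SwiGlu. For reference, the standard (unfused) definitions of the two gated activations are sketched below; which half of the GEMM output acts as the gate, and the alpha/beta/clamp parameters threaded through the gated-act buffers, are kernel details assumed away here:

```python
import torch
import torch.nn.functional as F


def gated_act(x2: torch.Tensor, act: str = "swiglu") -> torch.Tensor:
    """Reference gated activation: the first expert GEMM produces 2 * intermediate_size
    channels per token, split into a gate half and a linear half (ordering assumed)."""
    gate, linear = x2.chunk(2, dim=-1)
    if act == "swiglu":
        return F.silu(gate) * linear  # SwiGLU: SiLU(gate) * linear
    if act == "geglu":
        return F.gelu(gate) * linear  # GeGLU: GELU(gate) * linear
    raise ValueError(f"unsupported gated act: {act}")


y = gated_act(torch.randn(16, 2 * 1024), act="geglu")  # -> [16, 1024]
```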

flashinfer/__init__.py

Lines changed: 1 addition & 0 deletions

@@ -64,6 +64,7 @@
 from .fp8_quantization import mxfp8_dequantize_host, mxfp8_quantize
 from .fused_moe import (
     RoutingMethodType,
+    GatedActType,
     cutlass_fused_moe,
     reorder_rows_for_gated_act_gemm,
     trtllm_fp4_block_scale_moe,
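With GatedActType re-exported at the package root, callers can pick the activation alongside the routing method. A minimal usage sketch, assuming the Python enum mirrors the C++ GatedActType members (SwiGlu, GeGlu) and eliding the tensor arguments of trtllm_fp4_block_scale_moe:

```python
from flashinfer import GatedActType, trtllm_fp4_block_scale_moe

moe_kwargs = dict(
    routing_method_type=5,                    # "topk" in the benchmark mapping above
    gated_act_type=GatedActType.GeGlu.value,  # 1 == geglu, 0 == swiglu (member name assumed)
    do_finalize=True,
)
# output = trtllm_fp4_block_scale_moe(<hidden_states, weights, scales, ...>, **moe_kwargs)
```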
