22 changes: 21 additions & 1 deletion vllm/model_executor/layers/fused_moe/config.py
@@ -288,7 +288,11 @@ def use_int4_w4a16(self) -> bool:

@property
def use_mxfp4_w4a4(self) -> bool:
return self.quant_dtype == "mxfp4"
return (self._a1.dtype == "mxfp4" and self._w1.dtype == "mxfp4")

@property
def use_mxfp4_w4a16(self) -> bool:
return (self._a1.dtype is None and self._w1.dtype == "mxfp4")

@property
def use_nvfp4_w4a4(self) -> bool:
@@ -453,6 +457,22 @@ def int8_w8a8_moe_quant_config(
)


def mxfp4_w4a16_moe_quant_config(
w1_scale: Union[torch.Tensor, "PrecisionConfig"],
w2_scale: Union[torch.Tensor, "PrecisionConfig"],
w1_bias: Optional[torch.Tensor] = None,
w2_bias: Optional[torch.Tensor] = None) -> FusedMoEQuantConfig:
"""
Construct a quant config for unquantized activations and mxfp4 weights.
"""
return FusedMoEQuantConfig(
_a1=FusedMoEQuantDesc(),
_a2=FusedMoEQuantDesc(),
_w1=FusedMoEQuantDesc("mxfp4", None, w1_scale, None, None, w1_bias),
_w2=FusedMoEQuantDesc("mxfp4", None, w2_scale, None, None, w2_bias),
)
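
A minimal usage sketch of the new helper (a sketch only; the expert count and scale shapes below are illustrative assumptions, not values from this PR). The activations are left unquantized while the weights carry mxfp4 block scales, so the resulting config reports use_mxfp4_w4a16 rather than use_mxfp4_w4a4:

import torch
from vllm.model_executor.layers.fused_moe.config import (
    mxfp4_w4a16_moe_quant_config)

num_experts, hidden, intermediate = 8, 2880, 2880  # hypothetical sizes
# Hypothetical per-expert block scales (one scale per 32-element mxfp4 group).
w1_scale = torch.empty(num_experts, 2 * intermediate, hidden // 32,
                       dtype=torch.uint8)
w2_scale = torch.empty(num_experts, hidden, intermediate // 32,
                       dtype=torch.uint8)

cfg = mxfp4_w4a16_moe_quant_config(w1_scale=w1_scale, w2_scale=w2_scale)
assert cfg.use_mxfp4_w4a16 and not cfg.use_mxfp4_w4a4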


def mxfp4_w4a4_moe_quant_config(
w1_scale: Union[torch.Tensor, "PrecisionConfig"],
w2_scale: Union[torch.Tensor, "PrecisionConfig"],
@@ -11,13 +11,31 @@
TopKWeightAndReduceContiguous, TopKWeightAndReduceDelegate)
from vllm.model_executor.layers.fused_moe.utils import (
moe_kernel_quantize_input)
from vllm.utils import round_up


class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
"""
Prepare/Finalize using DeepEP High-Throughput kernels.
"""

@staticmethod
def maybe_roundup_layer_hidden_size(hidden_size: int,
dtype: torch.dtype) -> int:
# Round up hidden size so it is compatible with DeepEP High Throughput
# kernels.
# DeepEP intranode kernels make copies in units of
# 32 (warp-size) int4 elements. Round up the hidden size to respect this.
# For example, an input hidden size of 2880 with dtype torch.bfloat16
# will be rounded up to 3072.
hidden_size_bytes = hidden_size * dtype.itemsize
xfer_atom_size = 512 # 32 * 16 (size(int4))
if hidden_size_bytes % xfer_atom_size == 0:
return hidden_size

hidden_size_bytes = round_up(hidden_size_bytes, xfer_atom_size)
return hidden_size_bytes // dtype.itemsize
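
As a sanity check, the same rounding arithmetic can be reproduced standalone (a sketch; the helper name below is mine, mirroring maybe_roundup_layer_hidden_size above), confirming the 2880 -> 3072 bfloat16 example from the comment:

import torch
from vllm.utils import round_up

def roundup_hidden_for_deepep_ht(hidden_size: int, dtype: torch.dtype) -> int:
    xfer_atom_size = 512  # 32 (warp-size) copies of 16-byte int4 elements
    hidden_size_bytes = hidden_size * dtype.itemsize
    if hidden_size_bytes % xfer_atom_size == 0:
        return hidden_size
    return round_up(hidden_size_bytes, xfer_atom_size) // dtype.itemsize

assert roundup_hidden_for_deepep_ht(2880, torch.bfloat16) == 3072
assert roundup_hidden_for_deepep_ht(4096, torch.bfloat16) == 4096  # already aligned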

def __init__(self, buffer: deep_ep.Buffer, num_dispatchers: int,
dp_size: int, rank_expert_offset: int):
super().__init__()
174 changes: 140 additions & 34 deletions vllm/model_executor/layers/fused_moe/gpt_oss_triton_kernels_moe.py
@@ -9,7 +9,8 @@
from vllm.model_executor.layers.fused_moe.config import (
FUSED_MOE_UNQUANTIZED_CONFIG, FusedMoEQuantConfig)
from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
TopKWeightAndReduceDelegate)
TopKWeightAndReduceNoOP)
from vllm.triton_utils import tl, triton
from vllm.utils import has_triton_kernels

logger = init_logger(__name__)
@@ -19,13 +20,55 @@
import triton_kernels.swiglu
from triton_kernels.matmul_ogs import (FnSpecs, FusedActivation,
matmul_ogs)
from triton_kernels.routing import routing
from triton_kernels.routing import (RoutingData, routing,
routing_from_bitmatrix)
from triton_kernels.tensor import Bitmatrix
except (ModuleNotFoundError, AttributeError) as e:
logger.error(
"Failed to import Triton kernels. Please make sure your triton "
"version is compatible. Error: %s", e)


@triton.jit
def pack_bitmatrix(
bitmatrix,
topk_ids,
n_rows, # n_rows in bitmatrix / topk_ids
bm_cols: tl.constexpr, # n int32_t bitpacks in bitmatrix
n_expts_act, # num_topk
BLOCK_SIZE_M: tl.constexpr,
BLOCK_SIZE_K: tl.constexpr,
):
"""
Packs topk_ids into a bitmatrix.
code reference:
https://github.com/triton-lang/triton/blob/dd1bbc52b34d202dfe5ffea1e04fb16166c5c04e/python/triton_kernels/bench/distributed.py#L264
"""
pid_m = tl.program_id(0)
offsets_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
offsets_k = tl.arange(0, BLOCK_SIZE_K)
offsets = offsets_m[:, None] * n_expts_act + offsets_k[None, :]
mask = (offsets_m < n_rows)[:, None] & (offsets_k < n_expts_act)[None, :]
indices = tl.load(topk_ids + offsets, mask=mask, other=-1)
div = indices // 32
rem = indices % 32
one = tl.cast(1, tl.uint32)

# Iterate through all the relevant bitmatrix columns.
for i in range(bm_cols):
# When BLOCK_SIZE_K=32, offs is just the column index.
offs = tl.arange(0, BLOCK_SIZE_K // 32) + i * (BLOCK_SIZE_K // 32)
# Each topk id that maps to this column has its bit set; all other bits
# are 0. x is a 3D tensor of shape (BLOCK_SIZE_M, BLOCK_SIZE_K, BLOCK_SIZE_K // 32).
x = tl.where(div[:, :, None] == offs[None, None, :],
(one << rem)[:, :, None], 0)
# Reduce x to get a single int32_t bitpack.
y = tl.reduce_or(x, axis=1)
bitmatrix_ptrs = bitmatrix + offsets_m[:,
None] * bm_cols + offs[None, :]
tl.store(bitmatrix_ptrs, y, mask=offsets_m[:, None] < n_rows)
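
For intuition, a pure-Python reference of the packing above (an illustrative sketch, not the Triton code path): expert id e sets bit e % 32 of 32-bit word e // 32 in that token's row, and -1 entries are skipped:

import torch

def pack_bitmatrix_ref(topk_ids: torch.Tensor, num_experts: int) -> torch.Tensor:
    n_rows, _ = topk_ids.shape
    bm_cols = (num_experts + 31) // 32
    out = torch.zeros((n_rows, bm_cols), dtype=torch.int64)
    for row in range(n_rows):
        for expert in topk_ids[row].tolist():
            if expert >= 0:  # -1 marks a masked / invalid slot
                out[row, expert // 32] |= 1 << (expert % 32)
    return out

# Token 0 routes to experts 1 and 34 -> bit 1 of word 0, bit 2 of word 1.
print(pack_bitmatrix_ref(torch.tensor([[1, 34]]), num_experts=64))
# tensor([[2, 4]])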


def triton_kernel_moe_forward(
hidden_states: torch.Tensor,
w1, # Tensor or triton_kernels.Tensor
@@ -124,48 +167,99 @@ def triton_kernel_fused_experts(
return intermediate_cache3


class BatchedOAITritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
def make_routing_data(
topk_ids: torch.Tensor,
topk_weights: torch.Tensor,
num_local_experts: int,
) -> tuple["RoutingData", torch.Tensor, torch.Tensor]:

topk_ids = topk_ids.to(torch.int16)
topk_weights = topk_weights.to(torch.bfloat16)

n_rows, num_topk = topk_ids.size()

BLOCK_SIZE_M = 512
BLOCK_SIZE_K = 32

bm_cols = triton.cdiv(num_local_experts, BLOCK_SIZE_K) # n_bitpacks
bitmatrix = torch.zeros((n_rows, bm_cols),
dtype=torch.uint32,
device=topk_ids.device)

grid = (triton.cdiv(n_rows, BLOCK_SIZE_M), )
pack_bitmatrix[grid](
bitmatrix,
topk_ids,
n_rows,
bm_cols,
num_topk,
BLOCK_SIZE_M=BLOCK_SIZE_M,
BLOCK_SIZE_K=BLOCK_SIZE_K,
)

bitmatrix_shape = [n_rows, bm_cols * 32]
bitmatrix_shape_max = [n_rows, None]
bitmatrix = Bitmatrix(bitmatrix,
shape=bitmatrix_shape,
shape_max=bitmatrix_shape_max,
scratchpad=None)

# matmul_ogs expects invalid topk_weights to be -1s
topk_weights = torch.where(topk_ids == -1, -1.0, topk_weights)
routing_data, gather_indx, scatter_indx = routing_from_bitmatrix(
bitmatrix, topk_weights, topk_ids, num_local_experts, num_topk)

return routing_data, gather_indx, scatter_indx


class BaseOAITritonExperts(mk.FusedMoEPermuteExpertsUnpermute):

def __init__(self, quant_config: FusedMoEQuantConfig):
super().__init__(quant_config)

def supports_expert_map(self) -> bool:
return True

def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce:
# Weight application and reduction happen in the fused_experts kernel.
return TopKWeightAndReduceNoOP()

def __init__(
def _make_routing_data(
self,
max_num_tokens: int,
num_dispatchers: int,
quant_config: FusedMoEQuantConfig,
):
topk_ids: torch.Tensor,
topk_weights: torch.Tensor,
num_local_experts: int,
) -> tuple["RoutingData", torch.Tensor, torch.Tensor]:
return make_routing_data(topk_ids, topk_weights, num_local_experts)


class OAITritonExperts(BaseOAITritonExperts):

def __init__(self, quant_config: FusedMoEQuantConfig):
# TODO (varun) : Enable activation quantization
assert quant_config.use_mxfp4_w4a16, "Supports only mxfp4_w4a16"
super().__init__(quant_config)
self.max_num_tokens = max_num_tokens
self.num_dispatchers = num_dispatchers

@property
def activation_formats(
self
) -> tuple[mk.FusedMoEActivationFormat, mk.FusedMoEActivationFormat]:
return (mk.FusedMoEActivationFormat.BatchedExperts,
mk.FusedMoEActivationFormat.BatchedExperts)
return (mk.FusedMoEActivationFormat.Standard,
mk.FusedMoEActivationFormat.Standard)

def supports_chunking(self) -> bool:
return False

def supports_expert_map(self) -> bool:
return False

def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce:
# Let PrepareAndFinalize::finalize() decide the impl.
return TopKWeightAndReduceDelegate()
return True

def workspace_shapes(
self, a: torch.Tensor, aq: torch.Tensor, M: int, N: int, K: int,
topk: int, global_num_experts: int, local_num_experts: int,
expert_tokens_meta: Optional[mk.ExpertTokensMetadata]
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], torch.dtype]:
# workspaces are allocated inside the kernel
assert a.dim() == 2
num_dp = self.num_dispatchers
num_experts = local_num_experts
max_num_tokens = self.max_num_tokens
workspace2 = (0, 0, 0)
output = (num_experts, max_num_tokens * num_dp, N)
return (output, workspace2, output, a.dtype)
workspace1 = (M, K)
workspace2 = (0, 0)
output = (M, K)
return (workspace1, workspace2, output, a.dtype)

def apply(
self,
@@ -185,17 +279,29 @@ def apply(
expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
apply_router_weight_on_input: bool,
):
return triton_kernel_fused_experts(
output,
if expert_map is not None:
topk_ids = expert_map[topk_ids]

local_num_experts = w1.size(0)
if global_num_experts == -1:
global_num_experts = local_num_experts

routing_data, gather_indx, scatter_indx = self._make_routing_data(
topk_ids, topk_weights, local_num_experts)

experts_output = triton_kernel_fused_experts(
None,
hidden_states,
w1,
w2,
routing_data=None,
gather_indx=None,
scatter_indx=None,
routing_data,
gather_indx,
scatter_indx,
activation=activation,
quant_config=self.quant_config,
apply_router_weight_on_input=False,
global_num_experts=global_num_experts,
expert_map=expert_map,
global_num_experts=local_num_experts,
expert_map=None, # applied already
a1q_scale=a1q_scale)

output.copy_(experts_output, non_blocking=True)
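
The expert_map indexing in apply() relies on the convention (assumed here) that expert_map[global_id] gives the local expert index, or -1 when this rank does not own the expert; the -1 slots are then treated as invalid by make_routing_data, which sets their topk_weights to -1. A tiny sketch of that remapping:

import torch

expert_map = torch.tensor([-1, -1, 0, 1])  # this rank owns global experts 2 and 3
topk_ids = torch.tensor([[2, 0], [3, 2]])  # global ids from routing
print(expert_map[topk_ids])
# tensor([[ 0, -1],
#         [ 1,  0]])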