Skip to content

feat: integrate deepgemm into EPMoE #5805

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 10 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
212 changes: 212 additions & 0 deletions python/sglang/srt/layers/moe/ep_moe/kernels.py
Original file line number Diff line number Diff line change
Expand Up @@ -707,3 +707,215 @@ def grouped_gemm_triton(
**config,
)
return c


@triton.jit
def compute_masked_m_triton_kernel(seg_indptr, masked_m, num_experts, N):
expert_id = tl.program_id(0)
start = tl.load(seg_indptr + expert_id)
end = tl.load(seg_indptr + expert_id + 1)
tl.store(masked_m + expert_id, (end - start))


@triton.jit
def deepgemm_compute_src2dst_triton_kernel(
topk_ids,
reorder_ids,
seg_indptr,
src2dst,
max_m,
num_toks,
BLOCK_SIZE: tl.constexpr,
):
pid = tl.program_id(axis=0)
dst_id = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
mask = dst_id < num_toks
src_id = tl.load(reorder_ids + dst_id, mask=mask)
expert_id = tl.load(topk_ids + src_id, mask=(src_id < num_toks))
expert_dst_start = tl.load(seg_indptr + expert_id)
expert_dst_offset = dst_id - expert_dst_start
dst_id = expert_id * max_m + expert_dst_offset
tl.store(src2dst + src_id, dst_id, mask=mask)


@triton.jit
def fill_gateup_input_triton_kernel(
input_ptr,
scale_ptr,
gateup_input_ptr,
gateup_input_scale_ptr,
src2dst_ptr,
topk_ids_ptr,
start_expert_id,
end_expert_id,
topk,
m_max,
hidden_size,
scale_size,
BLOCK_SIZE: tl.constexpr,
):

src_idx = tl.program_id(0)
src2dst_ptr = src2dst_ptr + src_idx * topk
topk_ids_ptr = topk_ids_ptr + src_idx * topk
src_ptr = input_ptr + src_idx * hidden_size
scale_src_ptr = scale_ptr + src_idx * scale_size

for idx in range(topk):
expert_id = tl.load(topk_ids_ptr + idx)
if expert_id >= start_expert_id and expert_id <= end_expert_id:
dst_idx = tl.load(src2dst_ptr + idx)
dst_idx = dst_idx - start_expert_id * m_max
dst_ptr = gateup_input_ptr + dst_idx * hidden_size
for start_offset in tl.range(0, hidden_size, BLOCK_SIZE):
offset = start_offset + tl.arange(0, BLOCK_SIZE)
mask = offset < hidden_size
in_data = tl.load(src_ptr + offset, mask=mask)
tl.store(dst_ptr + offset, in_data, mask=mask)
scale_dst_ptr = gateup_input_scale_ptr + dst_idx * scale_size
for start_offset in tl.range(0, scale_size, BLOCK_SIZE):
offset = start_offset + tl.arange(0, BLOCK_SIZE)
mask = offset < scale_size
in_scale = tl.load(scale_src_ptr + offset, mask=mask)
tl.store(scale_dst_ptr + offset, in_scale, mask=mask)


def exp2_upper(num: int) -> int:
for i in range(2, 31):
value = pow(2, i)
if num <= value:
return value
return num


def moe_ep_deepgemm_preproess(
topk_ids: torch.Tensor,
num_experts: int,
hidden_states: torch.Tensor,
top_k: int,
start_expert_id,
end_expert_id,
block_shape,
output_dtype: torch.dtype = torch.float8_e4m3fn,
):
reorder_topk_ids, reorder_ids = torch.sort(topk_ids.view(-1), stable=True)
seg_indptr = torch.zeros(num_experts + 1, device=topk_ids.device, dtype=torch.int64)
src2dst = torch.empty(topk_ids.numel(), device=topk_ids.device, dtype=torch.int32)
masked_m = torch.zeros(num_experts, device=topk_ids.device, dtype=torch.int32)

compute_seg_indptr_triton_kernel[(num_experts,)](
reorder_topk_ids, seg_indptr, topk_ids.numel()
)

grid = lambda meta: (triton.cdiv(topk_ids.numel(), meta["BLOCK_SIZE"]),)
compute_masked_m_triton_kernel[(num_experts,)](
seg_indptr, masked_m, num_experts, reorder_topk_ids.numel()
)

m_max = exp2_upper(hidden_states.size(0))
expected_m = (topk_ids.numel() + num_experts - 1) // num_experts
gateup_input = torch.empty(
(int(end_expert_id - start_expert_id + 1), m_max, hidden_states.size(1)),
device=hidden_states.device,
dtype=output_dtype,
)

deepgemm_compute_src2dst_triton_kernel[grid](
topk_ids,
reorder_ids,
seg_indptr,
src2dst,
m_max,
topk_ids.numel(),
BLOCK_SIZE=256,
)

if block_shape is not None:
assert len(block_shape) == 2
block_n, block_k = block_shape[0], block_shape[1]
if _is_cuda:
hidden_states, scale = sglang_per_token_group_quant_fp8(
hidden_states, block_k
)
else:
hidden_states, scale = per_token_group_quant_fp8(hidden_states, block_k)

gateup_input_scale = torch.empty(
(gateup_input.size(0), gateup_input.size(1), scale.size(1)),
device=hidden_states.device,
dtype=scale.dtype,
)

fill_gateup_input_triton_kernel[(hidden_states.shape[0],)](
hidden_states,
scale,
gateup_input,
gateup_input_scale,
src2dst,
topk_ids,
start_expert_id,
end_expert_id,
top_k,
m_max,
hidden_states.size(1),
scale.size(1),
BLOCK_SIZE=1024,
)

return (
m_max,
masked_m[start_expert_id : (end_expert_id + 1)],
expected_m,
src2dst,
gateup_input,
gateup_input_scale,
)


@triton.jit
def deepgemm_post_reorder_triton_kernel(
down_output_ptr,
output_ptr,
src2dst_ptr,
topk_ids_ptr,
topk_weights_ptr,
start_expert_id,
end_expert_id,
topk,
hidden_size,
max_m,
BLOCK_SIZE: tl.constexpr,
):
InDtype = down_output_ptr.dtype.element_ty

src_idx = tl.program_id(0)
src2dst_ptr = src2dst_ptr + src_idx * topk
topk_ids_ptr = topk_ids_ptr + src_idx * topk
topk_weights_ptr = topk_weights_ptr + src_idx * topk

computed = False
store_ptr = output_ptr + src_idx * hidden_size
for start_offset in tl.range(0, hidden_size, BLOCK_SIZE):
offset = start_offset + tl.arange(0, BLOCK_SIZE)
mask = offset < hidden_size

sum_vec = tl.zeros([BLOCK_SIZE], dtype=InDtype)
for idx in range(topk):
expert_id = tl.load(topk_ids_ptr + idx)
if expert_id >= start_expert_id and expert_id <= end_expert_id:
computed = True
dst_idx = tl.load(src2dst_ptr + idx)
dst_idx = dst_idx - start_expert_id * max_m
weigh_scale = tl.load(topk_weights_ptr + idx).to(InDtype)
load_ptr = down_output_ptr + dst_idx * hidden_size
in_data = tl.load(load_ptr + offset, mask=mask)
sum_vec += in_data * weigh_scale
tl.store(store_ptr + offset, sum_vec, mask=mask)

if computed == False:
for start_offset in tl.range(0, hidden_size, BLOCK_SIZE):
offset = start_offset + tl.arange(0, BLOCK_SIZE)
mask = offset < hidden_size
tl.store(
store_ptr + offset, tl.zeros([BLOCK_SIZE], dtype=InDtype), mask=mask
)
134 changes: 133 additions & 1 deletion python/sglang/srt/layers/moe/ep_moe/layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,10 @@
get_tensor_model_parallel_world_size,
)
from sglang.srt.layers.moe.ep_moe.kernels import (
deepgemm_post_reorder_triton_kernel,
gelu_and_mul_triton_kernel,
grouped_gemm_triton,
moe_ep_deepgemm_preproess,
post_reorder_triton_kernel,
pre_reorder_triton_kernel,
run_moe_ep_preproess,
Expand All @@ -35,10 +37,17 @@
QuantizationConfig,
QuantizeMethodBase,
)
from sglang.srt.layers.quantization.deep_gemm import _ENABLE_JIT_DEEPGEMM
from sglang.srt.layers.quantization.fp8 import Fp8Config, Fp8MoEMethod
from sglang.srt.layers.quantization.fp8_kernel import scaled_fp8_quant
from sglang.srt.model_executor.forward_batch_info import ForwardMode
from sglang.srt.utils import DeepEPMode, is_hip, set_weight_attrs
from sglang.srt.utils import (
DeepEPMode,
get_bool_env_var,
is_cuda,
is_hip,
set_weight_attrs,
)

_is_hip = is_hip()

Expand Down Expand Up @@ -198,7 +207,130 @@ def __init__(

self.grouped_gemm_runner = None

self.w13_weight_fp8 = (
self.w13_weight,
(
self.w13_weight_scale_inv
if self.use_block_quant
else self.w13_weight_scale
),
)
self.w2_weight_fp8 = (
self.w2_weight,
self.w2_weight_scale_inv if self.use_block_quant else self.w2_weight_scale,
)

def forward(self, hidden_states: torch.Tensor, router_logits: torch.Tensor):
if use_deep_gemm and _ENABLE_JIT_DEEPGEMM:
return self.forward_deepgemm(hidden_states, router_logits)
else:
return self.forward_normal(hidden_states, router_logits)

def forward_deepgemm(
self, hidden_states: torch.Tensor, router_logits: torch.Tensor
):
assert self.quant_method is not None

topk_weights, topk_ids = select_experts(
hidden_states=hidden_states,
router_logits=router_logits,
top_k=self.top_k,
use_grouped_topk=self.use_grouped_topk,
renormalize=self.renormalize,
topk_group=self.topk_group,
num_expert_group=self.num_expert_group,
correction_bias=self.correction_bias,
custom_routing_function=self.custom_routing_function,
routed_scaling_factor=self.routed_scaling_factor,
)

# PreReorder
m_max, masked_m, expected_m, src2dst, gateup_input, gateup_input_scale = (
moe_ep_deepgemm_preproess(
topk_ids,
self.num_experts,
hidden_states,
self.top_k,
self.start_expert_id,
self.end_expert_id,
self.block_shape,
)
)
gateup_input_fp8 = (
gateup_input,
get_col_major_tma_aligned_tensor(gateup_input_scale),
)

# GroupGemm-0
num_groups, m, k = gateup_input_fp8[0].size()
n = self.w13_weight.size(1)
gateup_output = torch.empty(
(num_groups, m, n), device=hidden_states.device, dtype=torch.bfloat16
)

m_grouped_gemm_fp8_fp8_bf16_nt_masked(
gateup_input_fp8, self.w13_weight_fp8, gateup_output, masked_m, expected_m
)

# Act
down_input = torch.empty(
(
gateup_output.shape[0],
gateup_output.shape[1],
gateup_output.shape[2] // 2,
),
device=gateup_output.device,
dtype=self.fp8_dtype,
)
scale_block_size = 128
down_input_scale = torch.empty(
(
gateup_output.shape[0],
gateup_output.shape[1],
gateup_output.shape[2] // 2 // scale_block_size,
),
device=gateup_output.device,
dtype=torch.float32,
)
silu_and_mul_masked_post_quant_fwd(
gateup_output,
down_input,
down_input_scale,
scale_block_size,
masked_m,
)

# GroupGemm-1
n = self.w2_weight.size(1)
down_input_fp8 = (
down_input,
get_col_major_tma_aligned_tensor(down_input_scale),
)
down_output = torch.empty(
(num_groups, m, n), device=down_input.device, dtype=torch.bfloat16
)
m_grouped_gemm_fp8_fp8_bf16_nt_masked(
down_input_fp8, self.w2_weight_fp8, down_output, masked_m, expected_m
)

# PostReorder
output = torch.empty_like(hidden_states)
deepgemm_post_reorder_triton_kernel[(hidden_states.size(0),)](
down_output,
output,
src2dst,
topk_ids,
topk_weights,
self.start_expert_id,
self.end_expert_id,
self.top_k,
hidden_states.size(1),
m_max,
BLOCK_SIZE=512,
)
return output

def forward_normal(self, hidden_states: torch.Tensor, router_logits: torch.Tensor):
assert self.quant_method is not None

if self.grouped_gemm_runner is None:
Expand Down