Skip to content

Commit

Permalink
Deepseek V3 fix
Browse files Browse the repository at this point in the history
  • Loading branch information
root committed Jan 24, 2025
1 parent c5a9406 commit 682cb2b
Showing 1 changed file with 19 additions and 20 deletions.
39 changes: 19 additions & 20 deletions vllm/model_executor/layers/fused_moe/fused_moe.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,26 +33,26 @@ def fused_moe_kernel(
expert_ids_ptr,
num_tokens_post_padded_ptr,
# Matrix dimensions
N,
K,
EM,
num_valid_tokens,
N: tl.int64,
K: tl.int64,
EM: tl.int64,
num_valid_tokens: tl.int64,
# The stride variables represent how much to increase the ptr by when
# moving by 1 element in a particular dimension. E.g. `stride_am` is
# how much to increase `a_ptr` by to get the element one row down
# (A has M rows).
stride_am,
stride_ak,
stride_be,
stride_bk,
stride_bn,
stride_cm,
stride_cn,
stride_asm,
stride_ask,
stride_bse,
stride_bsk,
stride_bsn,
stride_am: tl.int64,
stride_ak: tl.int64,
stride_be: tl.int64,
stride_bk: tl.int64,
stride_bn: tl.int64,
stride_cm: tl.int64,
stride_cn: tl.int64,
stride_asm: tl.int64,
stride_ask: tl.int64,
stride_bse: tl.int64,
stride_bsk: tl.int64,
stride_bsn: tl.int64,
# Block size for block-wise quantization
group_n: tl.constexpr,
group_k: tl.constexpr,
Expand Down Expand Up @@ -114,18 +114,17 @@ def fused_moe_kernel(
num_tokens_post_padded = tl.load(num_tokens_post_padded_ptr)
if pid_m * BLOCK_SIZE_M >= num_tokens_post_padded:
return
offs_token_id = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M).to(
tl.int64)
offs_token_id = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
offs_token = tl.load(sorted_token_ids_ptr + offs_token_id)
token_mask = offs_token < num_valid_tokens

offs_bn = (pid_n * BLOCK_SIZE_N +
tl.arange(0, BLOCK_SIZE_N).to(tl.int64)) % N
tl.arange(0, BLOCK_SIZE_N)) % N
offs_k = tl.arange(0, BLOCK_SIZE_K)
a_ptrs = a_ptr + (offs_token[:, None] // top_k * stride_am +
offs_k[None, :] * stride_ak)

off_experts = tl.load(expert_ids_ptr + pid_m).to(tl.int64)
off_experts = tl.load(expert_ids_ptr + pid_m)
b_ptrs = b_ptr + off_experts * stride_be + (offs_k[:, None] * stride_bk +
offs_bn[None, :] * stride_bn)
if use_int8_w8a16:
Expand Down

0 comments on commit 682cb2b

Please sign in to comment.