Commit 0bb745f

esmeetu authored and zwd003 committed
fix up
1 parent 56f7d45 commit 0bb745f

5 files changed: +12 −222 lines changed


csrc/moe_align_block_size_kernels.cu

Lines changed: 1 addition & 1 deletion

@@ -19,7 +19,7 @@ __global__ void moe_align_block_size_kernel(scalar_t *__restrict__ topk_ids,
                                             int32_t num_experts,
                                             int32_t block_size,
                                             size_t numel) {
-  const size_t tokens_per_thread = ((numel + blockDim.x - 1) / blockDim.x);
+  const size_t tokens_per_thread = CEILDIV(numel, blockDim.x);
   const size_t start_idx = threadIdx.x * tokens_per_thread;
   __shared__ int32_t tokens_cnts[NUM_MAX_EXPERTS + 1][NUM_MAX_EXPERTS];
   __shared__ int32_t cumsum[NUM_MAX_EXPERTS + 1];
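
Note: this hunk only swaps the hand-written ceiling division for a CEILDIV macro defined elsewhere in vLLM's csrc, which is assumed to expand to the usual (a + b - 1) / b form. A minimal Python sketch checking that identity against math.ceil:

    # Sketch: verify the ceiling-division identity that CEILDIV is
    # assumed to implement, i.e. (a + b - 1) // b == ceil(a / b).
    import math

    def ceildiv(numel: int, block_dim: int) -> int:
        return (numel + block_dim - 1) // block_dim

    for numel in range(1, 1000):
        for block_dim in (32, 64, 128, 256):
            assert ceildiv(numel, block_dim) == math.ceil(numel / block_dim)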

vllm/model_executor/layers/fused_moe.py

Lines changed: 5 additions & 5 deletions

@@ -30,8 +30,6 @@ def fused_moe_kernel(
     stride_bn,
     stride_cm,
     stride_cn,
-    stride_weight,
-    stride_token_id,
     # Meta-parameters
     BLOCK_SIZE_M: tl.constexpr,
     BLOCK_SIZE_N: tl.constexpr,
@@ -112,7 +110,7 @@ def fused_moe_kernel(
         b_ptrs += BLOCK_SIZE_K * stride_bk

     if MUL_ROUTED_WEIGHT:
-        moe_weight = tl.load(topk_weights_ptr + offs_token * stride_weight,
+        moe_weight = tl.load(topk_weights_ptr + offs_token,
                              mask=token_mask,
                              other=0)
         accumulator = accumulator * moe_weight[:, None]
@@ -178,6 +176,10 @@ def invoke_fused_moe_kernel(A: torch.Tensor, B: torch.Tensor, C: torch.Tensor,
                            expert_ids: torch.Tensor,
                            num_tokens_post_padded: torch.Tensor,
                            mul_routed_weight: bool, top_k: int, config: dict):
+
+    assert topk_weights.stride(1) == 1
+    assert sorted_token_ids.stride(0) == 1
+
     grid = lambda META: (triton.cdiv(sorted_token_ids.shape[0], META[
         'BLOCK_SIZE_M']) * triton.cdiv(B.shape[1], META['BLOCK_SIZE_N']), )

@@ -200,8 +202,6 @@ def invoke_fused_moe_kernel(A: torch.Tensor, B: torch.Tensor, C: torch.Tensor,
        B.stride(1),
        C.stride(1),
        C.stride(2),
-       topk_weights.stride(1),
-       sorted_token_ids.stride(0),
        MUL_ROUTED_WEIGHT=mul_routed_weight,
        top_k=top_k,
        compute_type=tl.bfloat16 if A.dtype == torch.bfloat16 else tl.float16,
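
Note: the two new asserts pin the indexed dimensions to unit stride, which is what lets the kernel drop the stride_weight and stride_token_id parameters and address topk_weights_ptr + offs_token directly. A minimal sketch (not the vLLM code) of the invariant being encoded:

    # If the last dimension of topk_weights has stride 1, element (i, j)
    # lives at flat offset i * stride(0) + j, so no stride multiply on j
    # is needed inside the kernel.
    import torch

    topk_weights = torch.randn(16, 2)           # [num_tokens, top_k]
    assert topk_weights.stride(1) == 1          # contiguous along top_k

    i, j = 5, 1
    flat = topk_weights.view(-1)
    offset = i * topk_weights.stride(0) + j     # j used without a stride factor
    assert flat[offset] == topk_weights[i, j]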

vllm/model_executor/models/deepseek.py

Lines changed: 6 additions & 6 deletions

@@ -2,7 +2,7 @@
 # Adapted from
 # https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py
 # Copyright 2023 The vLLM team.
-# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
+# Copyright 2023 DeepSeek-AI and the HuggingFace Inc. team. All rights reserved.
 #
 # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
 # and OPT implementations in this library. It has been modified from its
@@ -22,12 +22,12 @@
 # limitations under the License.
 """Inference-only Deepseek model."""
 from typing import Any, Dict, List, Optional, Tuple
+from transformers import PretrainedConfig

 import torch
 import torch.nn.functional as F

 from torch import nn
-from vllm.transformers_utils.configs.deepseek import DeepseekConfig

 from vllm.model_executor.input_metadata import InputMetadata
 from vllm.model_executor.layers.activation import SiluAndMul
@@ -91,7 +91,7 @@ class DeepseekMoE(nn.Module):

     def __init__(
         self,
-        config: DeepseekConfig,
+        config: PretrainedConfig,
         linear_method: Optional[LinearMethodBase] = None,
     ):
         super().__init__()
@@ -264,7 +264,7 @@ class DeepseekDecoderLayer(nn.Module):

     def __init__(
         self,
-        config: DeepseekConfig,
+        config: PretrainedConfig,
         layer_idx: int,
         linear_method: Optional[LinearMethodBase] = None,
     ) -> None:
@@ -331,7 +331,7 @@ class DeepseekModel(nn.Module):

     def __init__(
         self,
-        config: DeepseekConfig,
+        config: PretrainedConfig,
         linear_method: Optional[LinearMethodBase] = None,
     ) -> None:
         super().__init__()
@@ -372,7 +372,7 @@ class DeepseekForCausalLM(nn.Module):

     def __init__(
         self,
-        config: DeepseekConfig,
+        config: PretrainedConfig,
         linear_method: Optional[LinearMethodBase] = None,
     ) -> None:
         super().__init__()
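
Note: with the vLLM-local DeepseekConfig class deleted (see the file removal below), the model annotates its config as the generic transformers.PretrainedConfig and takes whatever config the checkpoint ships. A hedged sketch of how such a config would be obtained; the checkpoint name and attribute reads are illustrative, not taken from this commit:

    # Load the config straight from the Hub instead of a vLLM-local class.
    from transformers import AutoConfig, PretrainedConfig

    config = AutoConfig.from_pretrained(
        "deepseek-ai/deepseek-moe-16b-base",  # example checkpoint (assumption)
        trust_remote_code=True,               # config class ships with the repo
    )
    assert isinstance(config, PretrainedConfig)
    print(config.num_hidden_layers, config.hidden_size)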

vllm/transformers_utils/configs/__init__.py

Lines changed: 0 additions & 2 deletions

@@ -1,7 +1,6 @@
 from vllm.transformers_utils.configs.aquila import AquilaConfig
 from vllm.transformers_utils.configs.baichuan import BaiChuanConfig
 from vllm.transformers_utils.configs.chatglm import ChatGLMConfig
-from vllm.transformers_utils.configs.deepseek import DeepseekConfig
 from vllm.transformers_utils.configs.mpt import MPTConfig
 from vllm.transformers_utils.configs.qwen import QWenConfig
 # RWConfig is for the original tiiuae/falcon-40b(-instruct) and
@@ -14,7 +13,6 @@
     "AquilaConfig",
     "BaiChuanConfig",
     "ChatGLMConfig",
-    "DeepseekConfig",
     "MPTConfig",
     "QWenConfig",
     "RWConfig",

vllm/transformers_utils/configs/deepseek.py

Lines changed: 0 additions & 208 deletions
This file was deleted.
