
Commit 0538dcc

update to comply with main

committed · 1 parent 7325e78 · commit 0538dcc

File tree

4 files changed: +21, -29 lines changed


vllm/model_executor/layers/fused_moe/layer.py

Lines changed: 16 additions & 24 deletions
@@ -28,8 +28,8 @@ def create_weights(self, layer: torch.nn.Module, num_experts: int,
 
     @abstractmethod
     def apply(self, layer: torch.nn.Module, x: torch.Tensor,
-              topk_weights: torch.Tensor,
-              topk_ids: torch.Tensor) -> torch.Tensor:
+              topk_weights: torch.Tensor, topk_ids: torch.Tensor,
+              **kwargs) -> torch.Tensor:
         raise NotImplementedError
 
 
@@ -59,21 +59,18 @@ def create_weights(self, layer: torch.nn.Module, num_experts: int,
         set_weight_attrs(w2_weight, extra_weight_attrs)
 
     def apply(self, layer: torch.nn.Module, x: torch.Tensor,
-              topk_weights: torch.Tensor,
-              topk_ids: torch.Tensor) -> torch.Tensor:
+              topk_weights: torch.Tensor, topk_ids: torch.Tensor,
+              **kwargs) -> torch.Tensor:
 
         return self.forward(x=x,
                             layer=layer,
                             topk_weights=topk_weights,
-                            topk_ids=topk_ids)
+                            topk_ids=topk_ids,
+                            **kwargs)
 
-    def forward_cuda(
-        self,
-        layer: torch.nn.Module,
-        x: torch.Tensor,
-        topk_weights: torch.Tensor,
-        topk_ids: torch.Tensor,
-    ) -> torch.Tensor:
+    def forward_cuda(self, layer: torch.nn.Module, x: torch.Tensor,
+                     topk_weights: torch.Tensor, topk_ids: torch.Tensor,
+                     **kwargs) -> torch.Tensor:
         return fused_experts(hidden_states=x,
                              w1=layer.w13_weight,
                              w2=layer.w2_weight,
@@ -85,17 +82,11 @@ def forward_cpu(self, *args, **kwargs):
         raise NotImplementedError(
             "The CPU backend currently does not support MoE.")
 
-    def forward_tpu(
-        self,
-        layer: torch.nn.Module,
-        x: torch.Tensor,
-        topk_weights: torch.Tensor,
-        topk_ids: torch.Tensor,
-    ) -> torch.Tensor:
-
-        #assert not use_grouped_topk
-        #assert num_expert_group is None
-        #assert topk_group is None
+    def forward_tpu(self, layer: torch.nn.Module, x: torch.Tensor,
+                    topk_weights: torch.Tensor, topk_ids: torch.Tensor,
+                    use_grouped_topk: bool) -> torch.Tensor:
+
+        assert not use_grouped_topk
         return fused_experts(hidden_states=x,
                              w1=layer.w13_weight,
                              w2=layer.w2_weight,
@@ -294,7 +285,8 @@ def forward(self, hidden_states: torch.Tensor,
             layer=self,
             x=hidden_states,
             topk_weights=topk_weights,
-            topk_ids=topk_ids)
+            topk_ids=topk_ids,
+            use_grouped_topk=self.use_grouped_topk)
 
         # Optionally reduce.
         if self.reduce_results and self.tp_size > 1:
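
The net effect of these hunks is that FusedMoE.forward can thread an extra keyword argument (use_grouped_topk) through quant_method.apply down to whichever backend implementation runs, without every backend having to declare it. Below is a minimal sketch of that pass-through pattern, using simplified, hypothetical class names rather than the actual vLLM classes:

# Minimal sketch of the **kwargs pass-through above (simplified names,
# hypothetical stand-ins for the vLLM classes).
class MoEMethod:
    def apply(self, x, topk_weights, topk_ids, **kwargs):
        # Extra keyword arguments are forwarded untouched to the backend.
        return self.forward_cuda(x, topk_weights, topk_ids, **kwargs)

    def forward_cuda(self, x, topk_weights, topk_ids, **kwargs):
        # The CUDA path does not need the flag; **kwargs silently absorbs it.
        return x  # stand-in for fused_experts(...)

    def forward_tpu(self, x, topk_weights, topk_ids, use_grouped_topk: bool):
        # The TPU path names the flag explicitly and rejects unsupported values.
        assert not use_grouped_topk
        return x  # stand-in for fused_experts(...)

class MoELayer:
    def __init__(self, method, use_grouped_topk=False):
        self.method = method
        self.use_grouped_topk = use_grouped_topk

    def forward(self, hidden_states, topk_weights, topk_ids):
        # The layer always passes the flag; each backend opts in or ignores it.
        return self.method.apply(hidden_states, topk_weights, topk_ids,
                                 use_grouped_topk=self.use_grouped_topk)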

vllm/model_executor/layers/quantization/awq.py

Lines changed: 2 additions & 2 deletions
@@ -280,8 +280,8 @@ def create_weights(self, layer: torch.nn.Module, num_experts: int,
             })
 
     def apply(self, layer: torch.nn.Module, x: torch.Tensor,
-              topk_weights: torch.Tensor,
-              topk_ids: torch.Tensor) -> torch.Tensor:
+              topk_weights: torch.Tensor, topk_ids: torch.Tensor,
+              **kwargs) -> torch.Tensor:
 
         return fused_experts_awq(hidden_states=x,
                                  w1=layer.w13_qweight,

vllm/model_executor/layers/quantization/fp8.py

Lines changed: 2 additions & 2 deletions
@@ -408,8 +408,8 @@ def process_weights_after_loading(self, layer: Module) -> None:
             return
 
     def apply(self, layer: torch.nn.Module, x: torch.Tensor,
-              topk_weights: torch.Tensor,
-              topk_ids: torch.Tensor) -> torch.Tensor:
+              topk_weights: torch.Tensor, topk_ids: torch.Tensor,
+              **kwargs) -> torch.Tensor:
 
         return fused_experts(x,
                              layer.w13_weight,
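
The AWQ and FP8 apply() methods do not use the new flag themselves; the **kwargs slot is only there so the extra keyword forwarded by FusedMoE.forward does not raise. A tiny illustrative example (not vLLM code) of why that matters:

def apply_old(x, topk_weights, topk_ids):
    return x

def apply_new(x, topk_weights, topk_ids, **kwargs):
    return x  # any extra keyword, e.g. use_grouped_topk, is simply ignored

apply_new("x", "w", "ids", use_grouped_topk=True)    # fine
# apply_old("x", "w", "ids", use_grouped_topk=True)  # TypeError: unexpected keyword argument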

vllm/model_executor/models/deepseek_v2.py

Lines changed: 1 addition & 1 deletion
@@ -115,7 +115,7 @@ def __init__(
             hidden_size=config.hidden_size,
             intermediate_size=config.moe_intermediate_size,
             reduce_results=False,
-            renormalize=False,
+            renormalize=True,
             quant_config=quant_config,
             use_grouped_topk=True,
             num_expert_group=config.n_group,
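
For context, renormalize controls whether the selected top-k routing weights are rescaled to sum to 1 per token before the expert outputs are combined. A rough, self-contained illustration of that behaviour (a plain softmax/top-k stand-in, not the grouped-top-k routing DeepSeek-V2 actually uses):

import torch

def topk_routing(gating_logits: torch.Tensor, top_k: int, renormalize: bool):
    scores = torch.softmax(gating_logits, dim=-1)
    topk_weights, topk_ids = torch.topk(scores, top_k, dim=-1)
    if renormalize:
        # Rescale the chosen weights so they sum to 1 for every token.
        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
    return topk_weights, topk_ids

logits = torch.randn(4, 8)                       # 4 tokens, 8 experts
weights, ids = topk_routing(logits, top_k=2, renormalize=True)
print(weights.sum(dim=-1))                       # ~1.0 per token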
