
Commit 0ba00ab

Commit message: PR comments
Parent commit: 0538dcc
File tree: 5 files changed (+27, -23 lines)


vllm/model_executor/layers/fused_moe/__init__.py (1 addition, 1 deletion)

@@ -18,8 +18,8 @@
 
 __all__ += [
     "fused_moe",
-    "fused_topk",
     "fused_experts",
+    "fused_topk",
     "get_config_file_name",
     "grouped_topk",
 ]

vllm/model_executor/layers/fused_moe/fused_moe.py (2 additions, 2 deletions)

@@ -442,8 +442,8 @@ def fused_experts(hidden_states: torch.Tensor,
     assert hidden_states.shape[1] == w1.shape[2], "Hidden size mismatch"
     assert topk_weights.shape == topk_ids.shape, "topk shape mismatch"
     assert hidden_states.is_contiguous(), "Hidden_states must be contiguous"
-    #assert w1.is_contiguous(), "Expert weights1 must be contiguous"
-    #assert w2.is_contiguous(), "Expert weights2 must be contiguous"
+    assert w1.is_contiguous(), "Expert weights1 must be contiguous"
+    assert w2.is_contiguous(), "Expert weights2 must be contiguous"
     assert hidden_states.dtype in [
         torch.float32, torch.float16, torch.bfloat16
     ]
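
The two re-enabled asserts require the expert weights to be contiguous in memory. A permuted tensor is only a strided view over the original storage, so it fails is_contiguous() until it is materialized with .contiguous(). A minimal standalone sketch (plain PyTorch; the shapes are illustrative, not vLLM's):

import torch

# Expert weights stored as [num_experts, hidden_size, 2 * intermediate_size],
# then permuted into the [num_experts, 2 * intermediate_size, hidden_size]
# layout that fused_experts indexes.
w1_src = torch.randn(4, 64, 256, dtype=torch.float16)
w1_view = w1_src.permute(0, 2, 1)

print(w1_view.is_contiguous())   # False -> would trip the restored assert
w1 = w1_view.contiguous()        # copies into a fresh row-major buffer
print(w1.is_contiguous())        # True  -> satisfies fused_experts
print(w1.shape)                  # torch.Size([4, 256, 64])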

vllm/model_executor/layers/fused_moe/fused_moe_awq.py (5 additions, 6 deletions)

@@ -3,8 +3,8 @@
 
 from vllm import _custom_ops as ops
 from vllm.logger import init_logger
-
-from .fused_moe import fused_experts, moe_align_block_size
+from vllm.model_executor.layers.fused_moe.fused_moe import (
+    fused_experts, moe_align_block_size)
 
 logger = init_logger(__name__)
 
@@ -43,12 +43,11 @@ def fused_experts_awq(
     # If large seq_len prefill, dequantize and use the fp16 MoE kernel.
     do_naive_dequant = hidden_states.shape[:-1].numel() >= NAIVE_THRESHOLD
     if do_naive_dequant:
-        # TODO: why is this not contiguous already?
-        # from @dsikka: because of the permutation operation
+        # NOTE: not contiguous because of the permutation operation
         dequant_w1 = ops.awq_dequantize(w1, w1_scales, w1_qzeros, 0, 0,
-                                        0).permute(0, 2, 1)
+                                        0).permute(0, 2, 1).contiguous()
         dequant_w2 = ops.awq_dequantize(w2, w2_scales, w2_qzeros, 0, 0,
-                                        0).permute(0, 2, 1)
+                                        0).permute(0, 2, 1).contiguous()
 
         return fused_experts(hidden_states, dequant_w1, dequant_w2,
                              topk_weights, topk_ids)
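
One small idiom in the hunk above: hidden_states.shape[:-1].numel() counts tokens whether the activations arrive flattened or batched, which is what the NAIVE_THRESHOLD comparison relies on. A quick standalone check (plain PyTorch, illustrative shapes only):

import torch

flat = torch.empty(2048, 4096)       # [num_tokens, hidden_size]
batched = torch.empty(4, 512, 4096)  # [batch, seq_len, hidden_size]

# shape[:-1] drops the hidden dim; numel() multiplies whatever is left.
print(flat.shape[:-1].numel())       # 2048
print(batched.shape[:-1].numel())    # 2048 (= 4 * 512)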

vllm/model_executor/layers/fused_moe/layer.py (3 additions, 3 deletions)

@@ -225,9 +225,9 @@ def weight_loader(self, param: torch.nn.Parameter,
 
         # If transposed, weight is saved as [input_dim, output_dim]
         # Otherwise, weight is saved as [output_dim, input_dim]
-        is_transposed = getattr(param, "is_transposed", False)
-        input_dim = 0 if is_transposed else 1
-        output_dim = 1 if is_transposed else 0
+        # Default is not transposed/input dim is dim 1
+        input_dim = getattr(param, "input_dim", 1)
+        output_dim = getattr(param, "output_dim", 0)
 
         # Index the loaded weight for tp sharding.
         # down_proj: "RowParallel" so tp sharding on input_dim
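
The loader now reads per-parameter input_dim/output_dim attributes with getattr defaults instead of a single is_transposed flag, so each layout can declare its own axes. A standalone sketch of that pattern (not vLLM's loader; the 2-D per-expert slices and sizes are assumptions for illustration):

import torch

# Default layout: [output_dim, input_dim]; no attributes attached.
w_default = torch.nn.Parameter(torch.empty(11008, 4096), requires_grad=False)

# Transposed (AWQ-style) layout: [input_dim, output_dim]; attributes attached
# explicitly, the way create_weights in awq.py now does via set_weight_attrs.
w_transposed = torch.nn.Parameter(torch.empty(4096, 1376), requires_grad=False)
w_transposed.input_dim = 0
w_transposed.output_dim = 1

for p in (w_default, w_transposed):
    input_dim = getattr(p, "input_dim", 1)    # dim 1 unless the param says otherwise
    output_dim = getattr(p, "output_dim", 0)  # dim 0 unless the param says otherwise
    print(p.shape[input_dim], p.shape[output_dim])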

vllm/model_executor/layers/quantization/awq.py (16 additions, 11 deletions)

@@ -1,4 +1,4 @@
-from typing import Any, Dict, List, Optional, Union
+from typing import Any, Dict, List, Optional
 
 import torch
 from torch.nn.parameter import Parameter
@@ -8,7 +8,7 @@
     fused_experts_awq)
 from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
 from vllm.model_executor.layers.quantization.base_config import (
-    QuantizationConfig)
+    QuantizationConfig, QuantizeMethodBase)
 from vllm.model_executor.utils import set_weight_attrs
 
 
@@ -65,9 +65,8 @@ def from_config(cls, config: Dict[str, Any]) -> "AWQConfig":
         zero_point = cls.get_from_keys(config, ["zero_point"])
         return cls(weight_bits, group_size, zero_point)
 
-    def get_quant_method(
-            self, layer: torch.nn.Module,
-            prefix: str) -> Optional[Union["AWQMoEMethod", "AWQLinearMethod"]]:
+    def get_quant_method(self, layer: torch.nn.Module,
+                         prefix: str) -> Optional["QuantizeMethodBase"]:
         if isinstance(layer, LinearBase):
             return AWQLinearMethod(self)
         elif isinstance(layer, FusedMoE):
@@ -202,7 +201,8 @@ def create_weights(self, layer: torch.nn.Module, num_experts: int,
             w13_qweight, {
                 "packed_dim": 1,
                 "pack_factor": self.quant_config.pack_factor,
-                "is_transposed": True,
+                "input_dim": 0,
+                "output_dim": 1,
                 **extra_weight_attrs
             })
 
@@ -217,7 +217,8 @@ def create_weights(self, layer: torch.nn.Module, num_experts: int,
             w2_qweight, {
                 "packed_dim": 1,
                 "pack_factor": self.quant_config.pack_factor,
-                "is_transposed": True,
+                "input_dim": 0,
+                "output_dim": 1,
                 **extra_weight_attrs
             })
 
@@ -231,7 +232,8 @@ def create_weights(self, layer: torch.nn.Module, num_experts: int,
                                         requires_grad=False)
         layer.register_parameter("w13_scales", w13_scales)
         set_weight_attrs(w13_scales, {
-            "is_transposed": True,
+            "input_dim": 0,
+            "output_dim": 1,
             **extra_weight_attrs
         })
 
@@ -243,7 +245,8 @@ def create_weights(self, layer: torch.nn.Module, num_experts: int,
                                        requires_grad=False)
         layer.register_parameter("w2_scales", w2_scales)
         set_weight_attrs(w2_scales, {
-            "is_transposed": True,
+            "input_dim": 0,
+            "output_dim": 1,
             **extra_weight_attrs
         })
 
@@ -260,7 +263,8 @@ def create_weights(self, layer: torch.nn.Module, num_experts: int,
             w13_qzeros, {
                 "packed_dim": 1,
                 "pack_factor": self.quant_config.pack_factor,
-                "is_transposed": True,
+                "input_dim": 0,
+                "output_dim": 1,
                 **extra_weight_attrs
             })
 
@@ -275,7 +279,8 @@ def create_weights(self, layer: torch.nn.Module, num_experts: int,
             w2_qzeros, {
                 "packed_dim": 1,
                 "pack_factor": self.quant_config.pack_factor,
-                "is_transposed": True,
+                "input_dim": 0,
+                "output_dim": 1,
                 **extra_weight_attrs
             })