
Commit fa914be

dsikka, ElizaWszola, and mgoin authored and committed
[BugFix] Fix parameter names and process_after_weight_loading for W4A16 MoE Group Act Order (vllm-project#11528)
Signed-off-by: ElizaWszola <eliza@neuralmagic.com>
Co-authored-by: ElizaWszola <eliza@neuralmagic.com>
Co-authored-by: Michael Goin <michael@neuralmagic.com>
1 parent 00dbfa7 commit fa914be

File tree: 8 files changed, +243 −148 lines

vllm/model_executor/layers/fused_moe/layer.py

Lines changed: 60 additions & 34 deletions
@@ -38,7 +38,7 @@ class FusedMoEMethodBase(QuantizeMethodBase):
 
     @abstractmethod
     def create_weights(self, layer: torch.nn.Module, num_experts: int,
-                       hidden_size: int, intermediate_size: int,
+                       hidden_size: int, intermediate_size_per_partition: int,
                        params_dtype: torch.dtype, **extra_weight_attrs):
         raise NotImplementedError
 

@@ -65,22 +65,24 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
     """MoE method without quantization."""
 
     def create_weights(self, layer: torch.nn.Module, num_experts: int,
-                       hidden_size: int, intermediate_size: int,
+                       hidden_size: int, intermediate_size_per_partition: int,
                        params_dtype: torch.dtype, **extra_weight_attrs):
         # Fused gate_up_proj (column parallel)
-        w13_weight = torch.nn.Parameter(torch.empty(num_experts,
-                                                    2 * intermediate_size,
-                                                    hidden_size,
-                                                    dtype=params_dtype),
+        w13_weight = torch.nn.Parameter(torch.empty(
+            num_experts,
+            2 * intermediate_size_per_partition,
+            hidden_size,
+            dtype=params_dtype),
                                         requires_grad=False)
         layer.register_parameter("w13_weight", w13_weight)
         set_weight_attrs(w13_weight, extra_weight_attrs)
 
         # down_proj (row parallel)
-        w2_weight = torch.nn.Parameter(torch.empty(num_experts,
-                                                   hidden_size,
-                                                   intermediate_size,
-                                                   dtype=params_dtype),
+        w2_weight = torch.nn.Parameter(torch.empty(
+            num_experts,
+            hidden_size,
+            intermediate_size_per_partition,
+            dtype=params_dtype),
                                        requires_grad=False)
         layer.register_parameter("w2_weight", w2_weight)
         set_weight_attrs(w2_weight, extra_weight_attrs)
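
Note: `intermediate_size_per_partition` is the slice of the MLP intermediate dimension owned by one tensor-parallel rank, and the two fused parameters above are sized from it. A minimal shape sketch with made-up sizes (`E`, `H`, `I`, and `tp_size` are illustrative, not taken from the diff):

import torch

E, H, I = 4, 64, 256   # experts, hidden size, full intermediate size (made up)
tp_size = 4            # hypothetical tensor-parallel degree
I_p = I // tp_size     # intermediate_size_per_partition

# Fused gate_up_proj (column parallel): w1 and w3 stacked along dim 1.
w13_weight = torch.empty(E, 2 * I_p, H)

# down_proj (row parallel): sharded along its input dimension.
w2_weight = torch.empty(E, H, I_p)

print(w13_weight.shape, w2_weight.shape)   # (4, 128, 64) (4, 64, 64)
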
@@ -289,13 +291,20 @@ def __init__(
         self.quant_method = quant_config.get_quant_method(self, prefix)
         assert self.quant_method is not None
 
-        self.quant_method.create_weights(
-            layer=self,
-            num_experts=num_experts,
-            hidden_size=hidden_size,
-            intermediate_size=self.intermediate_size_per_partition,
-            params_dtype=params_dtype,
-            weight_loader=self.weight_loader)
+        moe_quant_params = {
+            "num_experts": num_experts,
+            "hidden_size": hidden_size,
+            "intermediate_size_per_partition":
+            self.intermediate_size_per_partition,
+            "params_dtype": params_dtype,
+            "weight_loader": self.weight_loader,
+        }
+        # need full intermediate size pre-sharding for WNA16 act order
+        if (self.quant_method.__class__.__name__ ==
+                "CompressedTensorsWNA16MoEMethod"):
+            moe_quant_params["intermediate_size_full"] = intermediate_size
+
+        self.quant_method.create_weights(layer=self, **moe_quant_params)
 
     def _load_per_tensor_weight_scale(self, shard_id: str,
                                       param: torch.nn.Parameter,
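
The new `moe_quant_params` dict forwards the full (pre-sharding) `intermediate_size` only to `CompressedTensorsWNA16MoEMethod`, because with group act-order the w2 scales are kept un-sharded (see `load_full_w2` below) and must therefore be sized from the full dimension. The toy class below is a hedged sketch of how such an extra keyword could be consumed; it is illustrative only, not the actual vLLM quant method:

class ToyWNA16MoEMethod:
    """Illustrative stand-in: shows only the w2 scale-group arithmetic."""

    def __init__(self, group_size: int = 128):
        self.group_size = group_size

    def create_weights(self, layer, num_experts, hidden_size,
                       intermediate_size_per_partition, params_dtype,
                       **extra_weight_attrs):
        # Present only when the layer detected this quant method (see diff above).
        full = extra_weight_attrs.pop("intermediate_size_full", None)
        # Act-order / g_idx: w2 scale groups span the full, pre-sharding dim.
        k_for_w2_scales = (full if full is not None
                           else intermediate_size_per_partition)
        return k_for_w2_scales // self.group_size


method = ToyWNA16MoEMethod()
print(method.create_weights(None, 8, 4096, 1792, None,
                            intermediate_size_full=7168))  # 56 groups, not 14
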
@@ -312,19 +321,30 @@ def _load_per_tensor_weight_scale(self, shard_id: str,
         elif shard_id == "w2":
             param_data[expert_id] = loaded_weight
 
-    def _load_model_weight_or_group_weight_scale(self, shard_dim: int,
+    def _load_model_weight_or_group_weight_scale(self,
+                                                 shard_dim: int,
                                                  expert_data: torch.Tensor,
                                                  shard_id: str,
                                                  loaded_weight: torch.Tensor,
-                                                 tp_rank: int):
-        # Load grouped weight scales for group quantization
-        # or model weights
+                                                 tp_rank: int,
+                                                 load_full_w2: bool = False):
+        """
+        Load grouped weight scales for group quantization or model weights
+            :param shard_dim: dimension to shard
+            :param expert_data: parameter for a particular expert
+            :param shard_id: either w1, w2, or w3
+            :param loaded_weight: checkpoint weight to load into the param
+            :param tp_rank: tensor parallel rank
+            :param load_full_w2: whether or not the w2 loaded should be sharded.
+        """
         if shard_id == "w2":
-            self._load_w2(shard_id=shard_id,
-                          shard_dim=shard_dim,
+            # In the case where we have actorder/g_idx, we do not partition the
+            # w2 scales, as indicated by `load_full` argument, for all tp cases
+            self._load_w2(shard_dim=shard_dim,
                           loaded_weight=loaded_weight,
                           expert_data=expert_data,
-                          tp_rank=tp_rank)
+                          tp_rank=tp_rank,
+                          load_full=load_full_w2)
         elif shard_id in ("w1", "w3"):
             self._load_w13(shard_id=shard_id,
                            shard_dim=shard_dim,
@@ -364,15 +384,21 @@ def _load_w13(self, expert_data: torch.Tensor, shard_dim: int,
         expert_data = expert_data.narrow(shard_dim, shard_size, shard_size)
         expert_data.copy_(loaded_weight)
 
-    def _load_w2(self, expert_data: torch.Tensor, shard_dim: int,
-                 shard_id: str, loaded_weight: torch.Tensor, tp_rank: int):
+    def _load_w2(self,
+                 expert_data: torch.Tensor,
+                 shard_dim: int,
+                 loaded_weight: torch.Tensor,
+                 tp_rank: int,
+                 load_full: bool = False):
 
         # Index the loaded weight for tp sharding.
         # down_proj: "RowParallel" so tp sharding on input_dim
         # Narrow parameter and load.
         shard_size = expert_data.shape[shard_dim]
-        loaded_weight = loaded_weight.narrow(shard_dim, shard_size * tp_rank,
-                                             shard_size)
+        if not load_full:
+            loaded_weight = loaded_weight.narrow(shard_dim,
+                                                 shard_size * tp_rank,
+                                                 shard_size)
         # w2, down_proj: Load into only logical weight of w2.
         expert_data.copy_(loaded_weight)
 
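
The narrowing behavior added above can be reproduced with plain tensor ops: when `load_full` is False, the checkpoint tensor is narrowed to this rank's shard before copying; when it is True (the act-order w2-scale case), the full tensor is copied as-is. A standalone sketch with made-up sizes:

import torch

def load_w2(expert_data: torch.Tensor, shard_dim: int,
            loaded_weight: torch.Tensor, tp_rank: int,
            load_full: bool = False) -> None:
    # Mirrors the narrowing logic above: slice out this rank's shard unless
    # the caller asked for the full (un-sharded) tensor.
    shard_size = expert_data.shape[shard_dim]
    if not load_full:
        loaded_weight = loaded_weight.narrow(shard_dim,
                                             shard_size * tp_rank,
                                             shard_size)
    expert_data.copy_(loaded_weight)

full = torch.arange(8.0).reshape(1, 8)   # checkpoint tensor, 8 columns
shard = torch.empty(1, 4)                # per-rank parameter, tp_size = 2
load_w2(shard, shard_dim=1, loaded_weight=full, tp_rank=1)
print(shard)                             # columns 4..7 of the checkpoint

whole = torch.empty(1, 8)                # act-order: keep w2 scales whole
load_w2(whole, shard_dim=1, loaded_weight=full, tp_rank=1, load_full=True)
print(whole)                             # identical to the checkpoint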

@@ -387,8 +413,7 @@ def _load_g_idx(self, shard_id: str, expert_data: torch.Tensor,
                     shard_dim: int, loaded_weight: torch.Tensor, tp_rank: int):
 
         if shard_id == "w2":
-            self._load_w2(shard_id=shard_id,
-                          shard_dim=shard_dim,
+            self._load_w2(shard_dim=shard_dim,
                           loaded_weight=loaded_weight,
                           expert_data=expert_data,
                           tp_rank=tp_rank)
@@ -416,19 +441,19 @@ def weight_loader(self, param: torch.nn.Parameter,
         ]
         # Fetch the dim to shard the parameter/loaded weight
         # based on the shard id. This will be whatever
-        # dimension intermediate_size is used.
+        # dimension intermediate_size_per_partition is used.
         SHARD_ID_TO_SHARDED_DIM = {"w1": 0, "w2": 1, "w3": 0}
 
         expert_data = param.data[expert_id]
         tp_rank = get_tensor_model_parallel_rank()
 
         # is_transposed: if the dim to shard the weight
         # should be flipped. Required by GPTQ, compressed-tensors
-        # should be whatever dimension intermediate_size is
+        # should be whatever dimension intermediate_size_per_partition is
         is_transposed = getattr(param, "is_transposed", False)
         shard_dim = SHARD_ID_TO_SHARDED_DIM[shard_id]
         if is_transposed:
-            shard_dim = ~shard_dim
+            shard_dim = int(not shard_dim)
 
         # Case input scale: input_scale loading is only supported for fp8
         if "input_scale" in weight_name:
@@ -480,7 +505,8 @@ def weight_loader(self, param: torch.nn.Parameter,
                     shard_dim=shard_dim,
                     loaded_weight=loaded_weight,
                     expert_data=expert_data,
-                    tp_rank=tp_rank)
+                    tp_rank=tp_rank,
+                    load_full_w2=getattr(param, "load_full_w2", False))
             elif quant_method == FusedMoeWeightScaleSupported.TENSOR.value:
                 self._load_per_tensor_weight_scale(shard_id=shard_id,
                                                    param=param,

vllm/model_executor/layers/quantization/awq_marlin.py

Lines changed: 19 additions & 16 deletions
@@ -303,7 +303,7 @@ def __init__(self, quant_config: AWQMarlinConfig):
         self.quant_config = quant_config
 
     def create_weights(self, layer: torch.nn.Module, num_experts: int,
-                       hidden_size: int, intermediate_size: int,
+                       hidden_size: int, intermediate_size_per_partition: int,
                        params_dtype: torch.dtype, **extra_weight_attrs):
         extra_weight_attrs.update({
             "is_transposed":
@@ -312,17 +312,18 @@ def create_weights(self, layer: torch.nn.Module, num_experts: int,
             FusedMoeWeightScaleSupported.GROUP.value,
         })
 
-        w13_qweight = Parameter(torch.empty(num_experts,
-                                            hidden_size,
-                                            2 * intermediate_size //
-                                            self.quant_config.pack_factor,
-                                            dtype=torch.int32),
-                                requires_grad=False)
+        w13_qweight = Parameter(
+            torch.empty(num_experts,
+                        hidden_size,
+                        2 * intermediate_size_per_partition //
+                        self.quant_config.pack_factor,
+                        dtype=torch.int32),
+            requires_grad=False)
         layer.register_parameter("w13_qweight", w13_qweight)
         set_weight_attrs(w13_qweight, extra_weight_attrs)
 
         w2_qweight = Parameter(torch.empty(num_experts,
-                                           intermediate_size,
+                                           intermediate_size_per_partition,
                                            hidden_size //
                                            self.quant_config.pack_factor,
                                            dtype=torch.int32),
@@ -331,13 +332,14 @@ def create_weights(self, layer: torch.nn.Module, num_experts: int,
         set_weight_attrs(w2_qweight, extra_weight_attrs)
 
         num_groups_w13 = hidden_size // self.quant_config.group_size
-        num_groups_w2 = intermediate_size // self.quant_config.group_size
+        num_groups_w2 = (intermediate_size_per_partition //
+                         self.quant_config.group_size)
 
         # WEIGHT_SCALES
         # Allocate 2 scales for w1 and w3 respectively.
         w13_scales = Parameter(torch.empty(num_experts,
                                            num_groups_w13,
-                                           intermediate_size * 2,
+                                           intermediate_size_per_partition * 2,
                                            dtype=params_dtype),
                                requires_grad=False)
         layer.register_parameter("w13_scales", w13_scales)
@@ -353,12 +355,13 @@ def create_weights(self, layer: torch.nn.Module, num_experts: int,
 
         # WEIGHT_ZERO_POINT
         # Allocate 2 zero points for w1 and w3 respectively.
-        w13_qzeros = Parameter(torch.empty(num_experts,
-                                           num_groups_w13,
-                                           2 * intermediate_size //
-                                           self.quant_config.pack_factor,
-                                           dtype=torch.int32),
-                               requires_grad=False)
+        w13_qzeros = Parameter(
+            torch.empty(num_experts,
+                        num_groups_w13,
+                        2 * intermediate_size_per_partition //
+                        self.quant_config.pack_factor,
+                        dtype=torch.int32),
+            requires_grad=False)
         layer.register_parameter("w13_qzeros", w13_qzeros)
         set_weight_attrs(w13_qzeros, extra_weight_attrs)
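
The packed shapes and group counts above come from two quantities: `pack_factor` (how many low-bit values fit in one int32) and `group_size` (how many input channels share a scale/zero point). A worked example with hypothetical W4A16 settings; the real values come from the quant config:

# Hypothetical W4A16 settings (4-bit weights packed into int32).
weight_bits = 4
pack_factor = 32 // weight_bits          # 8 int4 values per int32
group_size = 128

num_experts = 8
hidden_size = 4096
intermediate_size_per_partition = 1792   # e.g. full 7168 over tp_size = 4

# w13_qweight: [E, hidden, 2 * I_p / pack_factor] packed int32 columns.
w13_qweight_shape = (num_experts, hidden_size,
                     2 * intermediate_size_per_partition // pack_factor)

# Scale/zero-point group counts along each weight's input dimension.
num_groups_w13 = hidden_size // group_size                     # 32
num_groups_w2 = intermediate_size_per_partition // group_size  # 14

print(w13_qweight_shape, num_groups_w13, num_groups_w2)
# (8, 4096, 448) 32 14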
