[BugFix] Fix parameter names and process_after_weight_loading for W4A16 MoE Group Act Order #11528

Merged (8 commits) on Jan 23, 2025
Don't partition w2 when we use group quantization
Signed-off-by: ElizaWszola <eliza@neuralmagic.com>
ElizaWszola authored and dsikka committed Jan 17, 2025
commit ddfac980600c287ec7b6501a7f0132944e902616
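
The gist of this commit: with group quantization plus activation reordering (act_order), the w2 (down_proj) scales must be kept whole on every tensor-parallel rank instead of being narrowed to that rank's partition, so a load_full flag is threaded from weight_loader down to _load_w2. Below is a minimal, self-contained sketch of that loading decision; the function name and toy shapes are illustrative only, not the vLLM API.

import torch


def load_w2(expert_data: torch.Tensor, loaded_weight: torch.Tensor,
            shard_dim: int, tp_rank: int, load_full: bool) -> None:
    # Narrow the checkpoint tensor to this rank's shard along the input dim
    # (RowParallel for down_proj), unless the caller asks for the full tensor.
    shard_size = expert_data.shape[shard_dim]
    if not load_full:
        loaded_weight = loaded_weight.narrow(shard_dim, shard_size * tp_rank,
                                             shard_size)
    expert_data.copy_(loaded_weight)


# Toy demonstration with tp_size = 2 and a full intermediate dim of 8:
full = torch.arange(8, dtype=torch.float32).reshape(1, 8)  # checkpoint tensor
shard = torch.empty(1, 4)                                   # rank 1's buffer
load_w2(shard, full, shard_dim=1, tp_rank=1, load_full=False)
print(shard)  # tensor([[4., 5., 6., 7.]]) -- narrowed to rank 1's half
whole = torch.empty(1, 8)
load_w2(whole, full, shard_dim=1, tp_rank=1, load_full=True)
print(whole)  # full tensor, kept unpartitioned (group scales under act_order)

In the diff below, only the grouped weight-scale path sets load_full_w2=True; model weights and g_idx loading keep the per-rank narrowing.
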
vllm/model_executor/layers/fused_moe/layer.py (33 changes: 20 additions & 13 deletions)
@@ -295,7 +295,8 @@ def __init__(
             hidden_size=hidden_size,
             intermediate_size=self.intermediate_size_per_partition,
             params_dtype=params_dtype,
-            weight_loader=self.weight_loader)
+            weight_loader=self.weight_loader,
+            intermediate_full=intermediate_size)
 
     def _load_per_tensor_weight_scale(self, shard_id: str,
                                       param: torch.nn.Parameter,
@@ -312,19 +313,19 @@ def _load_per_tensor_weight_scale(self, shard_id: str,
         elif shard_id == "w2":
             param_data[expert_id] = loaded_weight
 
-    def _load_model_weight_or_group_weight_scale(self, shard_dim: int,
-                                                 expert_data: torch.Tensor,
-                                                 shard_id: str,
-                                                 loaded_weight: torch.Tensor,
-                                                 tp_rank: int):
+    def _load_model_weight_or_group_weight_scale(
+            self, shard_dim: int, expert_data: torch.Tensor, shard_id: str,
+            loaded_weight: torch.Tensor, tp_rank: int, load_full_w2: bool):
         # Load grouped weight scales for group quantization
         # or model weights
+        # In act_order scenario, we need to load full w2 scales
         if shard_id == "w2":
             self._load_w2(shard_id=shard_id,
                           shard_dim=shard_dim,
                           loaded_weight=loaded_weight,
                           expert_data=expert_data,
-                          tp_rank=tp_rank)
+                          tp_rank=tp_rank,
+                          load_full=load_full_w2)
         elif shard_id in ("w1", "w3"):
             self._load_w13(shard_id=shard_id,
                            shard_dim=shard_dim,
@@ -365,14 +366,17 @@ def _load_w13(self, expert_data: torch.Tensor, shard_dim: int,
         expert_data.copy_(loaded_weight)
 
     def _load_w2(self, expert_data: torch.Tensor, shard_dim: int,
-                 shard_id: str, loaded_weight: torch.Tensor, tp_rank: int):
+                 shard_id: str, loaded_weight: torch.Tensor, tp_rank: int,
+                 load_full: bool):
 
         # Index the loaded weight for tp sharding.
         # down_proj: "RowParallel" so tp sharding on input_dim
         # Narrow parameter and load.
         shard_size = expert_data.shape[shard_dim]
-        loaded_weight = loaded_weight.narrow(shard_dim, shard_size * tp_rank,
-                                             shard_size)
+        if not load_full:
+            loaded_weight = loaded_weight.narrow(shard_dim,
+                                                 shard_size * tp_rank,
+                                                 shard_size)
         # w2, down_proj: Load into only logical weight of w2.
         expert_data.copy_(loaded_weight)

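Why the full copy matters for w2 scales: with act_order, a GPTQ-style g_idx assigns each input channel of down_proj to a quantization group over the full intermediate dimension, so the contiguous slice of input channels owned by one tensor-parallel rank can reference any scale group. A toy illustration (the group assignments are made up, assuming GPTQ-style g_idx semantics):

import torch

# Toy sizes: 8 input channels, group_size 2 -> 4 scale groups for w2.
# With act_order the channel-to-group assignment is permuted/scattered.
g_idx = torch.tensor([3, 0, 2, 1, 1, 3, 0, 2])  # group id of each input channel
rank1_channels = slice(4, 8)                    # rank 1 owns channels 4..7
print(g_idx[rank1_channels].unique())           # tensor([0, 1, 2, 3])
# Rank 1's shard references every scale group, so it needs the full w2 scales.
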
@@ -391,7 +395,8 @@ def _load_g_idx(self, shard_id: str, expert_data: torch.Tensor,
                           shard_dim=shard_dim,
                           loaded_weight=loaded_weight,
                           expert_data=expert_data,
-                          tp_rank=tp_rank)
+                          tp_rank=tp_rank,
+                          load_full=False)
         else:
             assert shard_id in ("w1", "w3")
             expert_data.copy_(loaded_weight)
@@ -480,7 +485,8 @@ def weight_loader(self, param: torch.nn.Parameter,
                     shard_dim=shard_dim,
                     loaded_weight=loaded_weight,
                     expert_data=expert_data,
-                    tp_rank=tp_rank)
+                    tp_rank=tp_rank,
+                    load_full_w2=True)
             elif quant_method == FusedMoeWeightScaleSupported.TENSOR.value:
                 self._load_per_tensor_weight_scale(shard_id=shard_id,
                                                    param=param,
@@ -506,7 +512,8 @@ def weight_loader(self, param: torch.nn.Parameter,
                 shard_dim=shard_dim,
                 loaded_weight=loaded_weight,
                 expert_data=expert_data,
-                tp_rank=tp_rank)
+                tp_rank=tp_rank,
+                load_full_w2=False)
             return
 
     @staticmethod
[second file changed by this commit]
@@ -269,7 +269,8 @@ def __init__(
 
     def create_weights(self, layer: torch.nn.Module, num_experts: int,
                        hidden_size: int, intermediate_size: int,
-                       params_dtype: torch.dtype, **extra_weight_attrs):
+                       intermediate_full: int, params_dtype: torch.dtype,
+                       **extra_weight_attrs):
 
         # Will transpose the loaded weight along the
         # intermediate and hidden dim sizes. Will
@@ -296,11 +297,15 @@ def create_weights(self, layer: torch.nn.Module, num_experts: int,
         layer.register_parameter("w2_weight_packed", w2_weight)
         set_weight_attrs(w2_weight, extra_weight_attrs)
 
+        self.is_k_full = (intermediate_full == intermediate_size)
+        scales_size = (intermediate_full if self.actorder
+                       and self.group_size != -1 else intermediate_size)
+
         if self.strategy == "channel":
             num_groups_w2 = num_groups_w13 = 1
             self.group_size = -1
         else:
-            num_groups_w2 = intermediate_size // self.group_size
+            num_groups_w2 = scales_size // self.group_size
             num_groups_w13 = hidden_size // self.group_size
 
         w13_scale = torch.nn.Parameter(torch.ones(num_experts,
@@ -547,4 +552,5 @@ def apply(
             sort_indices1=layer.w13_g_idx_sort_indices,
             sort_indices2=layer.w2_g_idx_sort_indices,
             num_bits=self.num_bits,
+            is_k_full=self.is_k_full,
         )
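
A rough worked example of the sizing logic added in create_weights above (the concrete numbers are illustrative, not taken from this PR): with the full intermediate dimension split across two tensor-parallel ranks and act_order grouping in use, the w2 scales are sized for all groups of the full dimension, and is_k_full comes out False, which the kernel call in apply() then receives.

# Illustrative numbers mirroring the create_weights arithmetic above.
hidden_size = 4096
intermediate_full = 2048                  # full intermediate dim of the model
tp_size = 2
intermediate_size = intermediate_full // tp_size   # per-partition: 1024
group_size = 128
actorder = True

is_k_full = (intermediate_full == intermediate_size)        # False under TP
scales_size = (intermediate_full
               if actorder and group_size != -1 else intermediate_size)
num_groups_w2 = scales_size // group_size     # 16 (full), not 8 (per-rank)
num_groups_w13 = hidden_size // group_size    # 32
print(is_k_full, num_groups_w2, num_groups_w13)   # False 16 32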