From d3bdfd3ab9bac6bf1a88f717869bf9c06683d4b4 Mon Sep 17 00:00:00 2001 From: Dipika Sikka Date: Tue, 13 Aug 2024 14:57:45 -0400 Subject: [PATCH] [Misc] Update Fused MoE weight loading (#7334) --- vllm/model_executor/layers/fused_moe/layer.py | 316 ++++++++++-------- .../model_executor/layers/quantization/fp8.py | 141 ++++---- vllm/model_executor/models/deepseek_v2.py | 2 +- vllm/model_executor/models/jamba.py | 2 +- vllm/model_executor/models/mixtral.py | 2 +- vllm/model_executor/models/qwen2_moe.py | 2 +- 6 files changed, 264 insertions(+), 201 deletions(-) diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index a0dc4c94744a8..4e29ab701b937 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -24,15 +24,9 @@ def create_weights(self, layer: torch.nn.Module, num_experts: int, raise NotImplementedError @abstractmethod - def apply(self, - layer: torch.nn.Module, - x: torch.Tensor, - router_logits: torch.Tensor, - top_k: int, - renormalize: bool = True, - use_grouped_topk: bool = False, - num_expert_group: Optional[int] = None, - topk_group: Optional[int] = None) -> torch.Tensor: + def apply(self, layer: torch.nn.Module, x: torch.Tensor, + router_logits: torch.Tensor, top_k: int, renormalize: bool, + use_grouped_topk: bool) -> torch.Tensor: raise NotImplementedError @@ -61,66 +55,78 @@ def create_weights(self, layer: torch.nn.Module, num_experts: int, layer.register_parameter("w2_weight", w2_weight) set_weight_attrs(w2_weight, extra_weight_attrs) - def apply( - self, - layer: torch.nn.Module, - x: torch.Tensor, - router_logits: torch.Tensor, - top_k: int, - renormalize: bool = True, - use_grouped_topk: bool = False, - num_expert_group: Optional[int] = None, - topk_group: Optional[int] = None, - ) -> torch.Tensor: - return self.forward(x, layer.w13_weight, layer.w2_weight, - router_logits, top_k, renormalize, - use_grouped_topk, num_expert_group, topk_group) - - def forward_cuda( - self, - x: torch.Tensor, - w1: torch.Tensor, - w2: torch.Tensor, - router_logits: torch.Tensor, - top_k: int, - renormalize: bool, - use_grouped_topk: bool, - num_expert_group: Optional[int], - topk_group: Optional[int], - ) -> torch.Tensor: - from vllm.model_executor.layers.fused_moe.fused_moe import fused_moe - return fused_moe(x, - w1, - w2, - router_logits, - top_k, - renormalize=renormalize, - inplace=True, - use_grouped_topk=use_grouped_topk, - num_expert_group=num_expert_group, - topk_group=topk_group) + def apply(self, + layer: torch.nn.Module, + x: torch.Tensor, + router_logits: torch.Tensor, + top_k: int, + renormalize: bool, + use_grouped_topk: bool, + topk_group: Optional[int] = None, + num_expert_group: Optional[int] = None) -> torch.Tensor: + + return self.forward(x=x, + layer=layer, + router_logits=router_logits, + top_k=top_k, + renormalize=renormalize, + use_grouped_topk=use_grouped_topk, + topk_group=topk_group, + num_expert_group=num_expert_group) + + def forward_cuda(self, + layer: torch.nn.Module, + x: torch.Tensor, + use_grouped_topk: bool, + top_k: int, + router_logits: torch.Tensor, + renormalize: bool, + topk_group: Optional[int] = None, + num_expert_group: Optional[int] = None) -> torch.Tensor: + + from vllm.model_executor.layers.fused_moe.fused_moe import ( + fused_experts) + + topk_weights, topk_ids = FusedMoE.select_experts( + hidden_states=x, + router_logits=router_logits, + use_grouped_topk=use_grouped_topk, + top_k=top_k, + renormalize=renormalize, + 
topk_group=topk_group, + num_expert_group=num_expert_group) + + return fused_experts(hidden_states=x, + w1=layer.w13_weight, + w2=layer.w2_weight, + topk_weights=topk_weights, + topk_ids=topk_ids, + inplace=True) def forward_cpu(self, *args, **kwargs): raise NotImplementedError( "The CPU backend currently does not support MoE.") - def forward_tpu( - self, - x: torch.Tensor, - w1: torch.Tensor, - w2: torch.Tensor, - router_logits: torch.Tensor, - top_k: int, - renormalize: bool, - use_grouped_topk: bool, - num_expert_group: Optional[int], - topk_group: Optional[int], - ) -> torch.Tensor: + def forward_tpu(self, + layer: torch.nn.Module, + x: torch.Tensor, + use_grouped_topk: bool, + top_k: int, + router_logits: torch.Tensor, + renormalize: bool, + topk_group: Optional[int] = None, + num_expert_group: Optional[int] = None) -> torch.Tensor: + from vllm.model_executor.layers.fused_moe.moe_pallas import fused_moe assert not use_grouped_topk assert num_expert_group is None assert topk_group is None - return fused_moe(x, w1, w2, router_logits, top_k, renormalize) + return fused_moe(hidden_states=x, + w1=layer.w13_weight, + w2=layer.w2_weight, + topk=top_k, + gating_output=router_logits, + renormalize=renormalize) class FusedMoE(torch.nn.Module): @@ -195,52 +201,83 @@ def __init__( def weight_loader(self, param: torch.nn.Parameter, loaded_weight: torch.Tensor, weight_name: str, - shard_id: int, expert_id: int): - param_data = param.data - - # Input scales can be loaded directly and should be equal. - if "input_scale" in weight_name: - if param_data[expert_id] != 1 and (param_data[expert_id] - - loaded_weight).abs() > 1e-5: - raise ValueError( - "input_scales of w1 and w3 of a layer " - f"must be equal. But got {param_data[expert_id]} " - f"vs. {loaded_weight}") - param_data[expert_id] = loaded_weight - # Weight scales - elif "weight_scale" in weight_name: - # If we are in merged column case (gate_up_proj) - # shard_id 0 == gate_proj / w1 - # shard_id 2 == up_proj / w3 - if shard_id == 0 or shard_id == 2: - # We have to keep the weight scales of w1 and w3 because - # we need to re-quantize w1/w3 weights after weight loading. - idx = 0 if shard_id == 0 else 1 - param_data[expert_id][idx] = loaded_weight - # If we are in the row parallel case (down_proj) - # shard_id 1 == down_proj / w2 - else: - param_data[expert_id] = loaded_weight - # Weights + shard_id: str, expert_id: int) -> None: + if shard_id not in ("w1", "w2", "w3"): + raise ValueError(f"shard_id must be ['w1','w2','w3'] but " + f"got {shard_id}.") + + # Special case for fp8 scales. + if getattr(param, "is_fp8_scale", False): + self._load_fp8_scale(param.data, loaded_weight, weight_name, + shard_id, expert_id) + return + + expert_data = param.data[expert_id] + tp_rank = get_tensor_model_parallel_rank() + + # If transposed, weight is saved as [input_dim, output_dim] + # Otherwise, weight is saved as [output_dim, input_dim] + # Default is not transposed/input dim is dim 1 + input_dim = getattr(param, "input_dim", 1) + output_dim = getattr(param, "output_dim", 0) + + # Index the loaded weight for tp sharding. 
+ # down_proj: "RowParallel" so tp sharding on input_dim + if shard_id == "w2": + shard_dim = input_dim + shard_size = expert_data.shape[shard_dim] + # gate_up_proj: "MergedColumnParallel", so tp sharding on output_dim + elif shard_id in ("w1", "w3"): + shard_dim = output_dim + shard_size = expert_data.shape[output_dim] // 2 + offset = shard_size * tp_rank + loaded_weight = loaded_weight.narrow(shard_dim, offset, shard_size) + + # Narrow parameter and load. + # w1, gate_proj: Load into first logical weight of w13. + if shard_id == "w1": + expert_data = expert_data.narrow(shard_dim, 0, shard_size) + expert_data.copy_(loaded_weight) + # w3, up_proj: Load into second logical weight of w13. + elif shard_id == "w3": + expert_data = expert_data.narrow(shard_dim, shard_size, shard_size) + expert_data.copy_(loaded_weight) + # w2, down_proj: Load into only logical weight of w2. + elif shard_id == "w2": + expert_data.copy_(loaded_weight) else: - tp_rank = get_tensor_model_parallel_rank() - shard_size = self.intermediate_size_per_partition - shard = slice(tp_rank * shard_size, (tp_rank + 1) * shard_size) - - # w1, gate_proj case: Load into first shard of w13. - if shard_id == 0: - param_data[expert_id, - 0:shard_size, :] = loaded_weight[shard, :] - # w3, up_proj case: Load into second shard of w13. - elif shard_id == 2: - param_data[expert_id, shard_size:2 * - shard_size, :] = loaded_weight[shard, :] - # w2, down_proj case: Load into only shard of w2. - elif shard_id == 1: - param_data[expert_id, :, :] = loaded_weight[:, shard] - else: - raise ValueError( - f"Shard id must be in [0,1,2] but got {shard_id}") + raise ValueError( + f"Expected shard_id w1,w2 or w3 but got {shard_id}") + + @staticmethod + def select_experts(hidden_states: torch.Tensor, + router_logits: torch.Tensor, + top_k: int, + use_grouped_topk: bool, + renormalize: bool, + topk_group: Optional[int] = None, + num_expert_group: Optional[int] = None): + from vllm.model_executor.layers.fused_moe.fused_moe import ( + fused_topk, grouped_topk) + + # DeekSeekv2 uses grouped_top_k + if use_grouped_topk: + assert topk_group is not None + assert num_expert_group is not None + topk_weights, topk_ids = grouped_topk( + hidden_states=hidden_states, + gating_output=router_logits, + topk=top_k, + renormalize=renormalize, + num_expert_group=num_expert_group, + topk_group=topk_group) + else: + topk_weights, topk_ids = fused_topk(hidden_states=hidden_states, + gating_output=router_logits, + topk=top_k, + renormalize=renormalize) + + return topk_weights, topk_ids def forward(self, hidden_states: torch.Tensor, router_logits: torch.Tensor): @@ -248,14 +285,14 @@ def forward(self, hidden_states: torch.Tensor, # Matrix multiply. 
final_hidden_states = self.quant_method.apply( - self, + layer=self, x=hidden_states, router_logits=router_logits, top_k=self.top_k, renormalize=self.renormalize, use_grouped_topk=self.use_grouped_topk, - num_expert_group=self.num_expert_group, - topk_group=self.topk_group) + topk_group=self.topk_group, + num_expert_group=self.num_expert_group) if self.reduce_results and self.tp_size > 1: final_hidden_states = tensor_model_parallel_all_reduce( @@ -267,35 +304,42 @@ def forward(self, hidden_states: torch.Tensor, def make_expert_params_mapping( cls, ckpt_gate_proj_name: str, ckpt_down_proj_name: str, ckpt_up_proj_name: str, - num_experts: int) -> List[Tuple[str, str, int, int]]: - - gate_up = [ckpt_gate_proj_name, ckpt_up_proj_name] - gate_down_up = [ - ckpt_gate_proj_name, ckpt_down_proj_name, ckpt_up_proj_name - ] + num_experts: int) -> List[Tuple[str, str, int, str]]: return [ - # These are the weight scales for the experts - # (param_name, weight_name, expert_id, shard_id) - ("experts.w13_scale" - if weight_name in gate_up else "experts.w2_scale", - f"experts.{expert_id}.{weight_name}.weight_scale", expert_id, - shard_id) for expert_id in range(num_experts) - for shard_id, weight_name in enumerate(gate_down_up) - ] + [ - # These are the weights for the experts # (param_name, weight_name, expert_id, shard_id) - ("experts.w13_weight" - if weight_name in gate_up else "experts.w2_weight", - f"experts.{expert_id}.{weight_name}.weight", expert_id, shard_id) - for expert_id in range(num_experts) - for shard_id, weight_name in enumerate(gate_down_up) - ] + [ - # These are the weight scales for the experts - # (param_name, weight_name, expert_id, shard_id) - ("experts.a13_scale" - if weight_name in gate_up else "experts.a2_scale", - f"experts.{expert_id}.{weight_name}.input_scale", expert_id, - shard_id) for expert_id in range(num_experts) - for shard_id, weight_name in enumerate(gate_down_up) + ("experts.w13_" if weight_name + in [ckpt_gate_proj_name, ckpt_up_proj_name] else "experts.w2_", + f"experts.{expert_id}.{weight_name}.", expert_id, shard_id) + for expert_id in range(num_experts) for shard_id, weight_name in [ + ("w1", ckpt_gate_proj_name), + ("w2", ckpt_down_proj_name), + ("w3", ckpt_up_proj_name), + ] ] + + def _load_fp8_scale(self, param: torch.nn.Parameter, + loaded_weight: torch.Tensor, weight_name: str, + shard_id: str, expert_id: int) -> None: + param_data = param.data + + # Input scales can be loaded directly and should be equal. + if "input_scale" in weight_name: + if param_data[expert_id] != 1 and (param_data[expert_id] - + loaded_weight).abs() > 1e-5: + raise ValueError( + "input_scales of w1 and w3 of a layer " + f"must be equal. But got {param_data[expert_id]} " + f"vs. {loaded_weight}") + param_data[expert_id] = loaded_weight + # Weight scales + elif "weight_scale" in weight_name: + # If we are in merged column case (gate_up_proj) + if shard_id in ("w1", "w3"): + # We have to keep the weight scales of w1 and w3 because + # we need to re-quantize w1/w3 weights after weight loading. 
+ idx = 0 if shard_id == "w1" else 1 + param_data[expert_id][idx] = loaded_weight + # If we are in the row parallel case (down_proj) + else: + param_data[expert_id] = loaded_weight diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py index cdd2413f5b2c4..8b8cf41cdfb3d 100644 --- a/vllm/model_executor/layers/quantization/fp8.py +++ b/vllm/model_executor/layers/quantization/fp8.py @@ -290,23 +290,29 @@ def create_weights(self, layer: Module, num_experts: int, hidden_size: int, # WEIGHT_SCALES # Allocate 2 scales for w1 and w3 respectively. # They will be combined to a single scale after weight loading. - w13_scale = torch.nn.Parameter(torch.ones(num_experts, - 2, - dtype=torch.float32), - requires_grad=False) - layer.register_parameter("w13_scale", w13_scale) + w13_weight_scale = torch.nn.Parameter(torch.ones(num_experts, + 2, + dtype=torch.float32), + requires_grad=False) + layer.register_parameter("w13_weight_scale", w13_weight_scale) - w2_scale = torch.nn.Parameter(torch.ones(num_experts, - dtype=torch.float32), - requires_grad=False) - layer.register_parameter("w2_scale", w2_scale) + w2_weight_scale = torch.nn.Parameter(torch.ones(num_experts, + dtype=torch.float32), + requires_grad=False) + layer.register_parameter("w2_weight_scale", w2_weight_scale) # If loading fp8 checkpoint, pass the weight loaders. # If loading an fp16 checkpoint, do not (we will quantize in # process_weights_after_loading() if self.quant_config.is_checkpoint_fp8_serialized: - set_weight_attrs(w13_scale, extra_weight_attrs) - set_weight_attrs(w2_scale, extra_weight_attrs) + set_weight_attrs(w13_weight_scale, { + "is_fp8_scale": True, + **extra_weight_attrs + }) + set_weight_attrs(w2_weight_scale, { + "is_fp8_scale": True, + **extra_weight_attrs + }) # INPUT_SCALES if self.quant_config.activation_scheme == "static": @@ -315,20 +321,26 @@ def create_weights(self, layer: Module, num_experts: int, hidden_size: int, "Found static activation scheme for checkpoint that " "was not serialized fp8.") - a13_scale = torch.nn.Parameter(torch.ones(num_experts, - dtype=torch.float32), - requires_grad=False) - layer.register_parameter("a13_scale", a13_scale) - set_weight_attrs(a13_scale, extra_weight_attrs) - - a2_scale = torch.nn.Parameter(torch.ones(num_experts, - dtype=torch.float32), - requires_grad=False) - layer.register_parameter("a2_scale", a2_scale) - set_weight_attrs(a2_scale, extra_weight_attrs) + w13_input_scale = torch.nn.Parameter(torch.ones( + num_experts, dtype=torch.float32), + requires_grad=False) + layer.register_parameter("w13_input_scale", w13_input_scale) + set_weight_attrs(w13_input_scale, { + "is_fp8_scale": True, + **extra_weight_attrs + }) + + w2_input_scale = torch.nn.Parameter(torch.ones( + num_experts, dtype=torch.float32), + requires_grad=False) + layer.register_parameter("w2_input_scale", w2_input_scale) + set_weight_attrs(w2_input_scale, { + "is_fp8_scale": True, + **extra_weight_attrs + }) else: - layer.a13_scale = None - layer.a2_scale = None + layer.w13_input_scale = None + layer.w2_input_scale = None def process_weights_after_loading(self, layer: Module) -> None: @@ -341,16 +353,16 @@ def process_weights_after_loading(self, layer: Module) -> None: # Re-initialize w13_scale because we directly quantize # merged w13 weights and generate a single scaling factor. 
- layer.w13_scale = torch.nn.Parameter(torch.ones( + layer.w13_weight_scale = torch.nn.Parameter(torch.ones( layer.num_experts, dtype=torch.float32, device=w13_weight.device), - requires_grad=False) + requires_grad=False) for expert in range(layer.num_experts): - w13_weight[expert, :, :], layer.w13_scale[ + w13_weight[expert, :, :], layer.w13_weight_scale[ expert] = ops.scaled_fp8_quant( layer.w13_weight.data[expert, :, :]) - w2_weight[expert, :, :], layer.w2_scale[ + w2_weight[expert, :, :], layer.w2_weight_scale[ expert] = ops.scaled_fp8_quant( layer.w2_weight.data[expert, :, :]) layer.w13_weight = torch.nn.Parameter(w13_weight, @@ -366,40 +378,41 @@ def process_weights_after_loading(self, layer: Module) -> None: # Fp8 moe kernels require a single activation scale. # We take the max of all the scales in case they differ. if self.quant_config.activation_scheme == "static": - if layer.a13_scale is None or layer.a2_scale is None: + if (layer.w13_input_scale is None + or layer.w2_input_scale is None): raise ValueError( "QuantConfig has static quantization, but found " "activation scales are None.") - if (not all_close_1d(layer.a13_scale) - or not all_close_1d(layer.a2_scale)): + if (not all_close_1d(layer.w13_input_scale) + or not all_close_1d(layer.w2_input_scale)): print_warning_once( "Found input_scales that are not equal for " "fp8 MoE layer. Using the maximum across experts " "for each layer. ") - layer.a13_scale = torch.nn.Parameter(layer.a13_scale.max(), - requires_grad=False) - layer.a2_scale = torch.nn.Parameter(layer.a2_scale.max(), - requires_grad=False) + layer.w13_input_scale = torch.nn.Parameter( + layer.w13_input_scale.max(), requires_grad=False) + layer.w2_input_scale = torch.nn.Parameter( + layer.w2_input_scale.max(), requires_grad=False) # Fp8 moe kernel needs single weight scale for w13 per expert. # We take the max then dequant and requant each expert. 
- assert layer.w13_scale is not None + assert layer.w13_weight_scale is not None shard_size = layer.intermediate_size_per_partition - max_w13_scales = layer.w13_scale.max(dim=1).values + max_w13_scales = layer.w13_weight_scale.max(dim=1).values for expert_id in range(layer.num_experts): start = 0 for shard_id in range(2): dq_weight = per_tensor_dequantize( layer.w13_weight[expert_id][start:start + shard_size, :], - layer.w13_scale[expert_id][shard_id]) + layer.w13_weight_scale[expert_id][shard_id]) layer.w13_weight[expert_id][ start:start + shard_size, :], _ = ops.scaled_fp8_quant( dq_weight, max_w13_scales[expert_id]) start += shard_size - layer.w13_scale = torch.nn.Parameter(max_w13_scales, - requires_grad=False) + layer.w13_weight_scale = torch.nn.Parameter(max_w13_scales, + requires_grad=False) return def apply(self, @@ -407,27 +420,33 @@ def apply(self, x: torch.Tensor, router_logits: torch.Tensor, top_k: int, - renormalize: bool = True, - use_grouped_topk: bool = False, - num_expert_group: Optional[int] = None, - topk_group: Optional[int] = None) -> torch.Tensor: - - from vllm.model_executor.layers.fused_moe import fused_moe - return fused_moe(x, - layer.w13_weight, - layer.w2_weight, - router_logits, - top_k, - renormalize=renormalize, - inplace=True, - use_fp8=True, - w1_scale=layer.w13_scale, - w2_scale=layer.w2_scale, - a1_scale=layer.a13_scale, - a2_scale=layer.a2_scale, - use_grouped_topk=use_grouped_topk, - num_expert_group=num_expert_group, - topk_group=topk_group) + renormalize: bool, + use_grouped_topk: bool, + topk_group: Optional[int] = None, + num_expert_group: Optional[int] = None) -> torch.Tensor: + + from vllm.model_executor.layers.fused_moe import fused_experts + + topk_weights, topk_ids = FusedMoE.select_experts( + hidden_states=x, + router_logits=router_logits, + use_grouped_topk=use_grouped_topk, + top_k=top_k, + renormalize=renormalize, + topk_group=topk_group, + num_expert_group=num_expert_group) + + return fused_experts(x, + layer.w13_weight, + layer.w2_weight, + topk_weights=topk_weights, + topk_ids=topk_ids, + inplace=True, + use_fp8=True, + w1_scale=layer.w13_weight_scale, + w2_scale=layer.w2_weight_scale, + a1_scale=layer.w13_input_scale, + a2_scale=layer.w2_input_scale) class Fp8KVCacheMethod(BaseKVCacheMethod): diff --git a/vllm/model_executor/models/deepseek_v2.py b/vllm/model_executor/models/deepseek_v2.py index 1ac15cefb5e29..c7f3af0ccb266 100644 --- a/vllm/model_executor/models/deepseek_v2.py +++ b/vllm/model_executor/models/deepseek_v2.py @@ -593,7 +593,7 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): weight_loader = param.weight_loader weight_loader(param, loaded_weight, - weight_name, + name, shard_id=shard_id, expert_id=expert_id) break diff --git a/vllm/model_executor/models/jamba.py b/vllm/model_executor/models/jamba.py index 6296cd502b1e1..dd4d63661a692 100644 --- a/vllm/model_executor/models/jamba.py +++ b/vllm/model_executor/models/jamba.py @@ -930,7 +930,7 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): weight_loader = param.weight_loader weight_loader(param, loaded_weight, - weight_name, + name, shard_id=shard_id, expert_id=expert_id) break diff --git a/vllm/model_executor/models/mixtral.py b/vllm/model_executor/models/mixtral.py index 34c21350dbc60..587d2f26a2d5e 100644 --- a/vllm/model_executor/models/mixtral.py +++ b/vllm/model_executor/models/mixtral.py @@ -455,7 +455,7 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): weight_loader = param.weight_loader 
weight_loader(param, loaded_weight, - weight_name, + name, shard_id=shard_id, expert_id=expert_id) break diff --git a/vllm/model_executor/models/qwen2_moe.py b/vllm/model_executor/models/qwen2_moe.py index b85512095622f..e160c9a320820 100644 --- a/vllm/model_executor/models/qwen2_moe.py +++ b/vllm/model_executor/models/qwen2_moe.py @@ -492,7 +492,7 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): weight_loader = param.weight_loader weight_loader(param, loaded_weight, - weight_name, + name, shard_id=shard_id, expert_id=expert_id) break
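
For readers tracing how checkpoints map onto the new layout: the loader above keys expert shards by the strings "w1" (gate_proj), "w2" (down_proj) and "w3" (up_proj) rather than the old integer ids, and make_expert_params_mapping now yields (param_name, weight_name, expert_id, shard_id) tuples built from those strings. The sketch below is a minimal toy mirror of that narrow-and-copy convention, assuming tensor-parallel size 1, unquantized weights, and made-up shapes; the helper name and dimensions are hypothetical and this is not vLLM's actual layer code.

# Illustrative only: a toy mirror of the string-based shard_id convention
# introduced in this patch, assuming tp_size == 1 and small made-up shapes.
import torch

num_experts, hidden_size, intermediate_size = 4, 8, 16

# Fused buffers shaped like the ones create_weights() registers:
# w13_weight stacks gate_proj (w1) and up_proj (w3) along the output dim.
w13_weight = torch.empty(num_experts, 2 * intermediate_size, hidden_size)
w2_weight = torch.empty(num_experts, hidden_size, intermediate_size)

def toy_weight_loader(param: torch.Tensor, loaded_weight: torch.Tensor,
                      shard_id: str, expert_id: int) -> None:
    # Same narrow-and-copy idea as FusedMoE.weight_loader, minus tp sharding
    # and the fp8-scale special case.
    if shard_id not in ("w1", "w2", "w3"):
        raise ValueError(
            f"shard_id must be ['w1','w2','w3'] but got {shard_id}.")
    expert_data = param[expert_id]
    if shard_id == "w1":    # gate_proj -> first half of w13
        expert_data.narrow(0, 0, intermediate_size).copy_(loaded_weight)
    elif shard_id == "w3":  # up_proj -> second half of w13
        expert_data.narrow(0, intermediate_size,
                           intermediate_size).copy_(loaded_weight)
    else:                   # down_proj -> the only logical weight of w2
        expert_data.copy_(loaded_weight)

# Checkpoint-side tensors for expert 0 (random stand-ins).
gate = torch.randn(intermediate_size, hidden_size)
up = torch.randn(intermediate_size, hidden_size)
down = torch.randn(hidden_size, intermediate_size)

toy_weight_loader(w13_weight, gate, "w1", expert_id=0)
toy_weight_loader(w13_weight, up, "w3", expert_id=0)
toy_weight_loader(w2_weight, down, "w2", expert_id=0)

assert torch.equal(w13_weight[0, :intermediate_size], gate)
assert torch.equal(w13_weight[0, intermediate_size:], up)
assert torch.equal(w2_weight[0], down)

Under tensor parallelism the real weight_loader additionally narrows loaded_weight by tp_rank before copying, along output_dim for w1/w3 (MergedColumnParallel) or input_dim for w2 (RowParallel), and fp8 scales take the separate _load_fp8_scale path shown in the diff.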