From 77891b8ae6d233dd16c501220f49b3dc9fec955c Mon Sep 17 00:00:00 2001
From: Michael Goin
Date: Sat, 8 Jun 2024 13:54:05 -0400
Subject: [PATCH] [Misc][Breaking] Change FP8 checkpoint format from act_scale -> input_scale (#5353)

---
 .../model_executor/layers/quantization/fp8.py | 30 +++++++++----------
 vllm/model_executor/models/mixtral.py         | 16 +++++-----
 2 files changed, 23 insertions(+), 23 deletions(-)

diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py
index de94bad7c38e6..0cf2bd927a800 100644
--- a/vllm/model_executor/layers/quantization/fp8.py
+++ b/vllm/model_executor/layers/quantization/fp8.py
@@ -171,10 +171,10 @@ def create_weights(
             output_partition_sizes=output_partition_sizes,
             **extra_weight_attrs)
 
-        # ACTIVATION SCALE
+        # INPUT ACTIVATION SCALE
         if self.quant_config.activation_scheme == "static":
             self._create_scale_param(
-                scale_name="act_scale",
+                scale_name="input_scale",
                 layer=layer,
                 output_partition_sizes=output_partition_sizes,
                 **extra_weight_attrs)
@@ -207,7 +207,7 @@ def process_weights_after_loading(self, layer: Module) -> None:
             layer.weight = Parameter(qweight.t(), requires_grad=False)
             layer.weight_scale = Parameter(weight_scale, requires_grad=False)
             layer.logical_widths = None
-            layer.act_scale = None
+            layer.input_scale = None
             return
 
         # If checkpoint is fp8, requantize the separately quantized logical
@@ -232,18 +232,18 @@ def process_weights_after_loading(self, layer: Module) -> None:
         weight = layer.weight
         layer.weight = Parameter(weight.t(), requires_grad=False)
 
-        # ACT_SCALE
+        # INPUT ACTIVATION SCALE
         # Dynamic: set to None (required input to ops.scaled_fp8_quant).
-        # Static: set to max of the act_scales (since they are equal).
+        # Static: set to max of the input_scales (since they are equal).
         if self.quant_config.activation_scheme == "dynamic":
-            layer.act_scale = None
+            layer.input_scale = None
         elif self.quant_config.activation_scheme == "static":
-            if not all_close_1d(layer.act_scale):
+            if not all_close_1d(layer.input_scale):
                 raise ValueError(
-                    "All the act_scales for the logical weights of a layer "
-                    f"must be equal. But got {layer.act_scale}")
+                    "All the input_scales for the logical weights of a "
+                    f"layer must be equal. But got {layer.input_scale}")
-            layer.act_scale = Parameter(layer.act_scale.max(),
-                                        requires_grad=False)
+            layer.input_scale = Parameter(layer.input_scale.max(),
+                                          requires_grad=False)
         else:
             raise ValueError(
                 f"Unknown scheme {self.quant_config.activation_scheme}")
@@ -254,11 +254,11 @@ def apply(self,
               bias: Optional[torch.Tensor] = None) -> torch.Tensor:
 
         # ops.scaled_fp8_quant supports both dynamic and static quant.
-        # If dynamic, layer.act_scale is None and x_scale computed from x.
-        # If static, layer.act_scale is scalar and x_scale set to act_scale.
+        # If dynamic, layer.input_scale is None and x_scale computed from x.
+        # If static, layer.input_scale is scalar and x_scale is input_scale.
 
         if bias is None and self.cutlass_fp8_supported:
-            qinput, x_scale = ops.scaled_fp8_quant(x, layer.act_scale)
+            qinput, x_scale = ops.scaled_fp8_quant(x, layer.input_scale)
 
             # Fused GEMM_DQ
             output = ops.cutlass_scaled_mm_dq(
@@ -271,7 +271,7 @@ def apply(self,
 
         else:
             qinput, x_scale = ops.scaled_fp8_quant(x,
-                                                   layer.act_scale,
+                                                   layer.input_scale,
                                                    batch_dim_padding=17)
 
             # Fused GEMM_DQ -- note we padded the input above because
diff --git a/vllm/model_executor/models/mixtral.py b/vllm/model_executor/models/mixtral.py
index 0f82549780ba4..3faf54d292b99 100644
--- a/vllm/model_executor/models/mixtral.py
+++ b/vllm/model_executor/models/mixtral.py
@@ -147,7 +147,7 @@ def __init__(
             "weight_loader": self.weight_loader,
         })
 
-        # ACT_SCALE (for fp8)
+        # INPUT_SCALE (for fp8)
         if quant_config.activation_scheme == "static":
             if not quant_config.is_checkpoint_fp8_serialized:
                 raise ValueError(
@@ -182,11 +182,11 @@ def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor,
             param_data[expert_id, :, :] = loaded_weight[:, shard]
 
         # Loading scales
-        if "act_scale" in weight_name or "w2.weight_scale" in weight_name:
+        if "input_scale" in weight_name or "w2.weight_scale" in weight_name:
             if param_data[expert_id] != 1 and (param_data[expert_id] -
                                                loaded_weight).abs() > 1e-5:
                 raise ValueError(
-                    "act_scales of w1 and w3 of a layer "
+                    "input_scales of w1 and w3 of a layer "
                     f"must be equal. But got {param_data[expert_id]} "
                     f"vs. {loaded_weight}")
             param_data[expert_id] = loaded_weight
@@ -225,9 +225,9 @@ def process_weights_after_loading(self):
             self.w2_weight = nn.Parameter(w2_weight, requires_grad=False)
 
         else:
-            # If checkpoint is fp8 + static, cleanup act_scales.
-            # Since state_dict has an act_scale per expert but our kernels
-            # are passed one act_scale shared across all experts.
+            # If checkpoint is fp8 + static, cleanup input_scales.
+            # Since state_dict has an input_scale per expert but our kernels
+            # are passed one input_scale shared across all experts.
             if self.quant_config.activation_scheme == "static":
                 if self.a13_scale is None or self.a2_scale is None:
                     raise ValueError(
@@ -237,7 +237,7 @@ def process_weights_after_loading(self):
                 if (not all_close_1d(self.a13_scale) or
                         not all_close_1d(self.a2_scale)):
                     print_warning_once(
-                        "Found act_scales that are not equal for "
+                        "Found input_scales that are not equal for "
                         "fp8 MoE layer. Using the maximum across experts "
                         "for each layer. ")
 
@@ -576,7 +576,7 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
             # These are the activation scales for the experts
             # (param_name, weight_name, expert_id)
             ("a13_scale" if weight_name in ["w1", "w3"] else "a2_scale",
-             f"experts.{expert_id}.{weight_name}.act_scale", expert_id)
+             f"experts.{expert_id}.{weight_name}.input_scale", expert_id)
             for expert_id in range(self.config.num_local_experts)
             for weight_name in ["w1", "w2", "w3"]
         ]
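
Note (not part of the patch): because the serialized FP8 format changes in a breaking way, checkpoints exported before this commit still store their activation scales under act_scale and will no longer match the new input_scale parameters. Below is a minimal migration sketch, assuming a plain torch.save'd state dict and placeholder file names; real FP8 checkpoints are typically sharded safetensors files, so adapt the load/save step accordingly.

import torch


def rename_act_scales(state_dict: dict) -> dict:
    # Rename every "*.act_scale" tensor to "*.input_scale"; other keys pass through.
    renamed = {}
    for name, tensor in state_dict.items():
        if name.endswith(".act_scale"):
            name = name[:-len(".act_scale")] + ".input_scale"
        renamed[name] = tensor
    return renamed


old_state = torch.load("fp8_checkpoint_old.pt", map_location="cpu")  # placeholder path
torch.save(rename_act_scales(old_state), "fp8_checkpoint_new.pt")    # placeholder path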
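
Note (not part of the patch): the comment change in apply() above describes the two activation-quantization paths. With the static scheme, a single checkpoint-provided input_scale is reused for every forward pass; with the dynamic scheme, the scale is derived from each batch at runtime. The following is a rough pure-PyTorch sketch of that distinction, not vLLM's fused ops.scaled_fp8_quant kernel; the FP8 dtype (PyTorch >= 2.1) and the example scale value are assumptions.

import torch

FP8_DTYPE = torch.float8_e4m3fn  # assumes PyTorch >= 2.1


def static_fp8_quant(x: torch.Tensor, input_scale: torch.Tensor) -> torch.Tensor:
    # Static scheme: x_scale is simply the precomputed input_scale.
    finfo = torch.finfo(FP8_DTYPE)
    return (x / input_scale).clamp(min=finfo.min, max=finfo.max).to(FP8_DTYPE)


def dynamic_fp8_quant(x: torch.Tensor):
    # Dynamic scheme: the scale is computed from the current batch.
    finfo = torch.finfo(FP8_DTYPE)
    x_scale = x.abs().max() / finfo.max
    x_q = (x / x_scale).clamp(min=finfo.min, max=finfo.max).to(FP8_DTYPE)
    return x_q, x_scale


x = torch.randn(4, 8)
input_scale = torch.tensor(0.05)  # placeholder value that would come from the checkpoint
x_q_static = static_fp8_quant(x, input_scale)
x_q_dynamic, x_scale = dynamic_fp8_quant(x)
x_dequant = x_q_static.to(torch.float32) * input_scale  # approximate reconstruction

Reusing the checkpoint scale avoids a runtime max-reduction over the activations, which is why the static path only needs the single scalar that process_weights_after_loading keeps as input_scale.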