-
-
Notifications
You must be signed in to change notification settings - Fork 2
[cuda]Add d-qd mxfp8 support #46
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: yi/nvfp4-moe
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -294,7 +294,8 @@ def _is_fp8_w8a8(self, weight_quant: BaseModel, | |
| is_symmetric_weight = weight_quant.symmetric | ||
| is_static_weight = not weight_quant.dynamic | ||
| is_per_tensor_or_channel_weight = (weight_quant.strategy in [ | ||
| QuantizationStrategy.TENSOR, QuantizationStrategy.CHANNEL | ||
| QuantizationStrategy.TENSOR, QuantizationStrategy.CHANNEL, | ||
| QuantizationStrategy.TENSOR_GROUP | ||
| ]) | ||
| if not (is_floating_point and is_symmetric_weight and is_static_weight | ||
| and is_per_tensor_or_channel_weight): | ||
|
|
@@ -310,6 +311,40 @@ def _is_fp8_w8a8(self, weight_quant: BaseModel, | |
| input_quant.strategy == QuantizationStrategy.TENSOR) | ||
| return is_symmetric_activation and is_per_tensor_activation | ||
|
|
||
| def _is_mxfp8_w8a8(self, weight_quant: BaseModel, | ||
| input_quant: BaseModel) -> bool: | ||
| # FIXME: (Yi) enhance check | ||
| # Confirm weights and activations quantized. | ||
| if weight_quant is None or input_quant is None: | ||
| return False | ||
|
|
||
| # Confirm weight scheme is supported. | ||
| is_floating_point = (weight_quant.type == QuantizationType.FLOAT | ||
| and input_quant.type == QuantizationType.FLOAT) | ||
| is_symmetric_weight = weight_quant.symmetric | ||
| is_static_weight = not weight_quant.dynamic | ||
| is_per_tensor_group_weight = (weight_quant.strategy in [ | ||
|
|
||
| QuantizationStrategy.TENSOR_GROUP | ||
| ]) | ||
| if not ( | ||
| is_floating_point | ||
| and is_symmetric_weight | ||
| and is_static_weight | ||
| and is_per_tensor_group_weight | ||
| ): | ||
| return False | ||
|
|
||
| # Dynamic quantization is always supported if weights supported. | ||
| if input_quant.dynamic: | ||
| return True | ||
|
|
||
| # Confirm activation scheme is supported. | ||
| is_symmetric_activation = input_quant.symmetric | ||
| is_per_tensor_activation = ( | ||
| input_quant.strategy == QuantizationStrategy.TENSOR) | ||
| return is_symmetric_activation and is_per_tensor_activation | ||
|
|
||
| def _is_fp8_w8a8_sm90(self, weight_quant: BaseModel, | ||
| input_quant: BaseModel) -> bool: | ||
| return (self._check_scheme_supported(90, error=False, match_exact=True) | ||
|
|
@@ -351,7 +386,7 @@ def _is_wNa16_group_channel(self, weight_quant: BaseModel, | |
| def _get_scheme_from_parts( | ||
| self, weight_quant: BaseModel, | ||
| input_quant: BaseModel) -> "CompressedTensorsScheme": | ||
|
|
||
| # breakpoint() | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. |
||
| # Detect If Mixed Precision | ||
| if self._is_fp4a16_nvfp4(weight_quant, input_quant): | ||
| return CompressedTensorsW4A16Fp4() | ||
|
|
@@ -385,9 +420,8 @@ def _get_scheme_from_parts( | |
| return CompressedTensorsW4A16Fp4( | ||
| has_input_global_scale=True) | ||
|
|
||
| if self._is_fp8_w8a8(weight_quant, input_quant): | ||
| is_fp8_w8a8_supported = self._check_scheme_supported( | ||
| CompressedTensorsW8A8Fp8.get_min_capability(), error=False) | ||
| if self._is_fp8_w8a8(weight_quant, input_quant=input_quant): | ||
| is_fp8_w8a8_supported = self._check_scheme_supported(CompressedTensorsW8A8Fp8.get_min_capability(), error=False) | ||
| if is_fp8_w8a8_supported: | ||
| return CompressedTensorsW8A8Fp8( | ||
| strategy=weight_quant.strategy, | ||
|
|
@@ -400,6 +434,18 @@ def _get_scheme_from_parts( | |
| strategy=weight_quant.strategy, | ||
| is_static_input_scheme=not input_quant.dynamic) | ||
|
|
||
| if self._is_mxfp8_w8a8(weight_quant, input_quant): | ||
| from vllm.model_executor.layers.quantization.compressed_tensors.schemes import ( | ||
| CompressedTensorsW8A8MXFp8, | ||
| ) | ||
|
Comment on lines
+438
to
+440
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. |
||
|
|
||
| return CompressedTensorsW8A8MXFp8( | ||
| strategy=weight_quant.strategy, | ||
| is_static_input_scheme=( | ||
| input_quant and not input_quant.dynamic | ||
| ), | ||
| ) | ||
|
|
||
| # note: input_quant can be None | ||
| if self._is_fp8_w8a16(weight_quant, input_quant): | ||
| is_static_input_scheme = (input_quant | ||
|
|
||
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| @@ -0,0 +1,180 @@ | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # SPDX-License-Identifier: Apache-2.0 | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # SPDX-FileCopyrightText: Copyright contributors to the vLLM project | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| from typing import Callable, Optional | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| import torch | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| from compressed_tensors.quantization import QuantizationStrategy | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| from torch.nn import Parameter | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| from vllm.model_executor.layers.quantization.compressed_tensors.schemes import ( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| CompressedTensorsScheme, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| Fp8LinearOp, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| maybe_create_device_identity, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| normalize_e4m3fn_to_e4m3fnuz, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| requantize_with_max_scale, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| from vllm.model_executor.parameter import ( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ChannelQuantScaleParameter, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ModelWeightParameter, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| PerTensorScaleParameter, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| from vllm.model_executor.parameter import ( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| GroupQuantScaleParameter, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ModelWeightParameter, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| PerTensorScaleParameter, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
Comment on lines
+19
to
+28
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| from vllm.platforms import current_platform | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| __all__ = ["CompressedTensorsW8A8MXFp8"] | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| def get_fp_scale(scale_e8m0): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # https://github.com/pytorch/ao/blob/994a4ba6c869854fcaa6ca7e118fcbd75e6c28cc/torchao/prototype/mx_formats/mx_tensor.py#L337 | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| from torchao.prototype.mx_formats.constants import ( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| E8M0_EXPONENT_BIAS, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| E8M0_EXPONENT_NAN_VAL, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| scale_e8m0 = scale_e8m0.view(torch.uint8) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| s_offset = scale_e8m0.to(torch.int16) - E8M0_EXPONENT_BIAS | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # TODO(later): it would be nice if there was a way to do the 2^x operation | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # in PyTorch without creating a tensor of twos | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| two = torch.full(s_offset.size(), 2.0, device=scale_e8m0.device) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # pow(two, s_offset) can be out of range of floating point formats. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # TODO(later): handle this for float16 if we decide to support float16 | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # scales. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| s_fp = torch.pow(two, s_offset) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| return s_fp | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| def dequant_mx_fp8(weight_fp8, scale_e8m0, block_size): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| scale_float = get_fp_scale(scale_e8m0) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| weight_bf16 = weight_fp8.to(torch.bfloat16) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| weight_original_shape = weight_bf16.shape | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| weight_bf16 = weight_bf16.reshape(-1, block_size) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| scale_float = scale_float.reshape(-1, 1) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| dequant_weight = weight_bf16 * scale_float | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| dequant_weight = dequant_weight.reshape(weight_original_shape) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| return dequant_weight | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| def quant_mx_fp8(tensor): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| from torchao.prototype.mx_formats.mx_tensor import ( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| to_mx, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ScaleCalculationMode, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
Comment on lines
+36
to
+68
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. |
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| scale_e8m0_biased, data_lp = to_mx( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| data_hp=tensor, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| elem_dtype=torch.float8_e4m3fn, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| block_size=32, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| scaling_mode=ScaleCalculationMode.FLOOR, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| pack_fp6=False, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| return scale_e8m0_biased, data_lp | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| class CompressedTensorsW8A8MXFp8(CompressedTensorsScheme): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| def __init__(self, strategy: str, is_static_input_scheme: bool): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| self.strategy = strategy | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| self.out_dtype = torch.get_default_dtype() | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| self.is_static_input_scheme = is_static_input_scheme | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # self.fp8_linear = Fp8LinearOp(use_per_token_if_dynamic=True) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. |
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| self.group_size = 32 | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| @classmethod | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| def get_min_capability(cls) -> int: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # lovelace and up | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # FIXME: (Yi) correct the minimum capability | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. |
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| return 80 | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| def process_weights_after_loading(self, layer) -> None: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| return | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| def create_weights(self, layer: torch.nn.Module, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| output_partition_sizes: list[int], | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| input_size_per_partition: int, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| params_dtype: torch.dtype, weight_loader: Callable, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| **kwargs): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # maybe_create_device_identity() | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. |
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| output_size_per_partition = sum(output_partition_sizes) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| layer.logical_widths = output_partition_sizes | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # WEIGHT | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| weight = ModelWeightParameter(data=torch.empty( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| output_size_per_partition, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| input_size_per_partition, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| dtype=torch.float8_e4m3fn), | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| input_dim=1, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| output_dim=0, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| weight_loader=weight_loader) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| layer.register_parameter("weight", weight) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # WEIGHT SCALE | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # TODO: update create_xxx_parameter functions to return | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # the newly added parameters | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if self.strategy == QuantizationStrategy.TENSOR_GROUP: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # Per Group Weight Scale | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| weight_scale = GroupQuantScaleParameter( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| data=torch.empty( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| sum(output_partition_sizes), | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| input_size_per_partition // self.group_size, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| dtype=torch.uint8, # E8M0 for MXFP8 scale | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ), | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| input_dim=1, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| output_dim=0, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| weight_loader=weight_loader, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| else: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| raise NotImplementedError(f"Strategy {self.strategy} is not supported for W8A8-MXFp8") | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # min requirement for fp8 kernels | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # weight_scale[:] = torch.finfo(torch.float32).min | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # weight_scale.fill_(torch.finfo(torch.float32).min) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
Comment on lines
+138
to
+139
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. |
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| layer.register_parameter("weight_scale", weight_scale) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # INPUT SCALE | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if self.is_static_input_scheme: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| input_scale = PerTensorScaleParameter(data=torch.empty( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| len(output_partition_sizes), dtype=torch.float32), | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| weight_loader=weight_loader) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| input_scale[:] = torch.finfo(torch.float32).min | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| layer.register_parameter("input_scale", input_scale) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| def apply_weights(self, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| layer: torch.nn.Module, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| x: torch.Tensor, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| bias: Optional[torch.Tensor] = None) -> torch.Tensor: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # dequant weight | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| weight = layer.weight | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| weight_scale = layer.weight_scale | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| dequnat_weight = dequant_mx_fp8( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| weight_fp8=weight.data, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| scale_e8m0=weight_scale.data, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| block_size=self.group_size | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| dequnat_weight = dequnat_weight.to(x.dtype) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # q-dq input | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| x_scale, x_quant = quant_mx_fp8(x) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| dequant_x = dequant_mx_fp8( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| weight_fp8=x_quant, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| scale_e8m0=x_scale, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| block_size=self.group_size | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| x = dequant_x.to(x.dtype) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| out = x @ dequnat_weight.t() | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
Comment on lines
+158
to
+172
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. There's a recurring typo
Suggested change
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| return out.to(x.dtype) + (bias if bias is not None else 0) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| return self.fp8_linear.apply(input=x, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| weight=layer.weight, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| weight_scale=layer.weight_scale, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| out_dtype=self.out_dtype, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| input_scale=layer.input_scale, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| bias=bias) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This
FIXMEsuggests the check is incomplete. Could you please add more details on what needs to be enhanced, or address it if possible? LeavingFIXMEs without context can lead to technical debt.