@@ -294,7 +294,8 @@ def _is_fp8_w8a8(self, weight_quant: BaseModel,
is_symmetric_weight = weight_quant.symmetric
is_static_weight = not weight_quant.dynamic
is_per_tensor_or_channel_weight = (weight_quant.strategy in [
QuantizationStrategy.TENSOR, QuantizationStrategy.CHANNEL
QuantizationStrategy.TENSOR, QuantizationStrategy.CHANNEL,
QuantizationStrategy.TENSOR_GROUP
])
if not (is_floating_point and is_symmetric_weight and is_static_weight
and is_per_tensor_or_channel_weight):
@@ -310,6 +311,40 @@ def _is_fp8_w8a8(self, weight_quant: BaseModel,
input_quant.strategy == QuantizationStrategy.TENSOR)
return is_symmetric_activation and is_per_tensor_activation

def _is_mxfp8_w8a8(self, weight_quant: BaseModel,
input_quant: BaseModel) -> bool:
# FIXME: (Yi) enhance check

medium

This FIXME suggests the check is incomplete. Could you please add more details on what needs to be enhanced, or address it if possible? Leaving FIXMEs without context can lead to technical debt.
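
For illustration, a stricter gate might also verify the bit width and the MX group size; this is only a sketch under the assumption that MXFP8 here means 8-bit floating point with 32-element groups (the group size CompressedTensorsW8A8MXFp8 hard-codes):

# Sketch of extra conditions the FIXME could cover (assumed constraints).
is_8_bit = weight_quant.num_bits == 8 and input_quant.num_bits == 8
is_mx_group_size = weight_quant.group_size == 32
if not (is_8_bit and is_mx_group_size):
    return False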

# Confirm weights and activations quantized.
if weight_quant is None or input_quant is None:
return False

# Confirm weight scheme is supported.
is_floating_point = (weight_quant.type == QuantizationType.FLOAT
and input_quant.type == QuantizationType.FLOAT)
is_symmetric_weight = weight_quant.symmetric
is_static_weight = not weight_quant.dynamic
is_per_tensor_group_weight = (weight_quant.strategy in [
QuantizationStrategy.TENSOR_GROUP
])
if not (
is_floating_point
and is_symmetric_weight
and is_static_weight
and is_per_tensor_group_weight
):
return False

# Dynamic quantization is always supported if weights supported.
if input_quant.dynamic:
return True

# Confirm activation scheme is supported.
is_symmetric_activation = input_quant.symmetric
is_per_tensor_activation = (
input_quant.strategy == QuantizationStrategy.TENSOR)
return is_symmetric_activation and is_per_tensor_activation

def _is_fp8_w8a8_sm90(self, weight_quant: BaseModel,
input_quant: BaseModel) -> bool:
return (self._check_scheme_supported(90, error=False, match_exact=True)
@@ -351,7 +386,7 @@ def _is_wNa16_group_channel(self, weight_quant: BaseModel,
def _get_scheme_from_parts(
self, weight_quant: BaseModel,
input_quant: BaseModel) -> "CompressedTensorsScheme":

# breakpoint()

critical

This breakpoint() call appears to be a debugging artifact and must be removed before merging.

# Detect If Mixed Precision
if self._is_fp4a16_nvfp4(weight_quant, input_quant):
return CompressedTensorsW4A16Fp4()
@@ -385,9 +420,8 @@ def _get_scheme_from_parts(
return CompressedTensorsW4A16Fp4(
has_input_global_scale=True)

if self._is_fp8_w8a8(weight_quant, input_quant):
is_fp8_w8a8_supported = self._check_scheme_supported(
CompressedTensorsW8A8Fp8.get_min_capability(), error=False)
if self._is_fp8_w8a8(weight_quant, input_quant=input_quant):
is_fp8_w8a8_supported = self._check_scheme_supported(CompressedTensorsW8A8Fp8.get_min_capability(), error=False)
if is_fp8_w8a8_supported:
return CompressedTensorsW8A8Fp8(
strategy=weight_quant.strategy,
@@ -400,6 +434,18 @@ def _get_scheme_from_parts(
strategy=weight_quant.strategy,
is_static_input_scheme=not input_quant.dynamic)

if self._is_mxfp8_w8a8(weight_quant, input_quant):
from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
CompressedTensorsW8A8MXFp8,
)
Comment on lines +438 to +440

medium

Imports should generally be at the top of the file to improve readability and avoid potential circular import issues. Please move this import to the top-level imports section of the file.
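
For example, the import could join the existing scheme imports at the top of compressed_tensors.py; a sketch only, with the other name shown purely as a placeholder for whatever that import block already contains:

# Top-level imports of compressed_tensors.py (only the MXFp8 entry is the actual addition):
from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
    CompressedTensorsW8A8Fp8, CompressedTensorsW8A8MXFp8)

The local import inside _get_scheme_from_parts can then be dropped.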


return CompressedTensorsW8A8MXFp8(
strategy=weight_quant.strategy,
is_static_input_scheme=(
input_quant and not input_quant.dynamic
),
)

# note: input_quant can be None
if self._is_fp8_w8a16(weight_quant, input_quant):
is_static_input_scheme = (input_quant
@@ -11,6 +11,7 @@
from .compressed_tensors_w8a16_fp8 import CompressedTensorsW8A16Fp8
from .compressed_tensors_wNa16 import (WNA16_SUPPORTED_BITS,
CompressedTensorsWNA16)
from .compressed_tensors_w8a8_mxfp8 import CompressedTensorsW8A8MXFp8

from .compressed_tensors_24 import CompressedTensors24 # isort: skip

@@ -20,5 +21,6 @@
"CompressedTensorsW8A8Int8", "CompressedTensorsW8A8Fp8",
"WNA16_SUPPORTED_BITS", "W4A16SPARSE24_SUPPORTED_BITS",
"CompressedTensors24", "CompressedTensorsW4A16Fp4",
"CompressedTensorsW4A4Fp4"
"CompressedTensorsW4A4Fp4",
"CompressedTensorsW8A8MXFp8"
]
@@ -0,0 +1,180 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from typing import Callable, Optional

import torch
from compressed_tensors.quantization import QuantizationStrategy
from torch.nn import Parameter

from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
CompressedTensorsScheme,
)
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
Fp8LinearOp,
maybe_create_device_identity,
normalize_e4m3fn_to_e4m3fnuz,
requantize_with_max_scale,
)
from vllm.model_executor.parameter import (
ChannelQuantScaleParameter,
ModelWeightParameter,
PerTensorScaleParameter,
)
from vllm.model_executor.parameter import (
GroupQuantScaleParameter,
ModelWeightParameter,
PerTensorScaleParameter,
)
Comment on lines +19 to +28

medium

ModelWeightParameter and PerTensorScaleParameter are imported twice. This can be consolidated into a single import statement. Also, ChannelQuantScaleParameter is imported but not used in this file and can be removed.

Suggested change
from vllm.model_executor.parameter import (
ChannelQuantScaleParameter,
ModelWeightParameter,
PerTensorScaleParameter,
)
from vllm.model_executor.parameter import (
GroupQuantScaleParameter,
ModelWeightParameter,
PerTensorScaleParameter,
)
from vllm.model_executor.parameter import (
GroupQuantScaleParameter,
ModelWeightParameter,
PerTensorScaleParameter,
)

from vllm.platforms import current_platform

__all__ = ["CompressedTensorsW8A8MXFp8"]


def get_fp_scale(scale_e8m0):
# https://github.com/pytorch/ao/blob/994a4ba6c869854fcaa6ca7e118fcbd75e6c28cc/torchao/prototype/mx_formats/mx_tensor.py#L337
from torchao.prototype.mx_formats.constants import (
E8M0_EXPONENT_BIAS,
E8M0_EXPONENT_NAN_VAL,
)

scale_e8m0 = scale_e8m0.view(torch.uint8)
s_offset = scale_e8m0.to(torch.int16) - E8M0_EXPONENT_BIAS
# TODO(later): it would be nice if there was a way to do the 2^x operation
# in PyTorch without creating a tensor of twos
two = torch.full(s_offset.size(), 2.0, device=scale_e8m0.device)
# pow(two, s_offset) can be out of range of floating point formats.
# TODO(later): handle this for float16 if we decide to support float16
# scales.
s_fp = torch.pow(two, s_offset)

return s_fp


def dequant_mx_fp8(weight_fp8, scale_e8m0, block_size):
scale_float = get_fp_scale(scale_e8m0)
weight_bf16 = weight_fp8.to(torch.bfloat16)
weight_original_shape = weight_bf16.shape
weight_bf16 = weight_bf16.reshape(-1, block_size)
scale_float = scale_float.reshape(-1, 1)
dequant_weight = weight_bf16 * scale_float
dequant_weight = dequant_weight.reshape(weight_original_shape)
return dequant_weight

def quant_mx_fp8(tensor):
from torchao.prototype.mx_formats.mx_tensor import (
to_mx,
ScaleCalculationMode,
)
Comment on lines +36 to +68

medium

Local imports, such as those for torchao on lines 36 and 65, should be moved to the top of the file. This improves code organization, readability, and helps prevent potential issues like circular dependencies.
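
If torchao is acceptable as an import-time dependency of this module (today it only needs to be installed when an MXFP8 layer is actually exercised), the two local imports could be lifted to the top of compressed_tensors_w8a8_mxfp8.py; a sketch using the same import paths as the diff:

from torchao.prototype.mx_formats.constants import (E8M0_EXPONENT_BIAS,
                                                     E8M0_EXPONENT_NAN_VAL)
from torchao.prototype.mx_formats.mx_tensor import ScaleCalculationMode, to_mx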


scale_e8m0_biased, data_lp = to_mx(
data_hp=tensor,
elem_dtype=torch.float8_e4m3fn,
block_size=32,
scaling_mode=ScaleCalculationMode.FLOOR,
pack_fp6=False,
)
return scale_e8m0_biased, data_lp



class CompressedTensorsW8A8MXFp8(CompressedTensorsScheme):

def __init__(self, strategy: str, is_static_input_scheme: bool):
self.strategy = strategy
self.out_dtype = torch.get_default_dtype()
self.is_static_input_scheme = is_static_input_scheme
# self.fp8_linear = Fp8LinearOp(use_per_token_if_dynamic=True)

medium

This line is commented out. If it's no longer needed, please remove it to keep the code clean. If it's part of a future implementation, consider adding a TODO with more context.

self.group_size = 32

@classmethod
def get_min_capability(cls) -> int:
# lovelace and up
# FIXME: (Yi) correct the minimum capability

medium

This FIXME indicates that the minimum capability might be incorrect. Please verify and update this value to ensure the feature is correctly gated for supported hardware.
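
If the "# lovelace and up" comment reflects the intent, the value would be 89 rather than 80: vLLM encodes compute capability as major * 10 + minor, and Ada Lovelace is SM 8.9. A sketch, assuming Lovelace really is the floor:

@classmethod
def get_min_capability(cls) -> int:
    # Ada Lovelace (SM 8.9) and newer; capability is encoded as major * 10 + minor.
    return 89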

return 80

def process_weights_after_loading(self, layer) -> None:
return

def create_weights(self, layer: torch.nn.Module,
output_partition_sizes: list[int],
input_size_per_partition: int,
params_dtype: torch.dtype, weight_loader: Callable,
**kwargs):
# maybe_create_device_identity()

medium

This commented-out function call should be removed if it's not intended to be used.


output_size_per_partition = sum(output_partition_sizes)
layer.logical_widths = output_partition_sizes

# WEIGHT
weight = ModelWeightParameter(data=torch.empty(
output_size_per_partition,
input_size_per_partition,
dtype=torch.float8_e4m3fn),
input_dim=1,
output_dim=0,
weight_loader=weight_loader)
layer.register_parameter("weight", weight)

# WEIGHT SCALE
# TODO: update create_xxx_parameter functions to return
# the newly added parameters
if self.strategy == QuantizationStrategy.TENSOR_GROUP:
# Per Group Weight Scale
weight_scale = GroupQuantScaleParameter(
data=torch.empty(
sum(output_partition_sizes),
input_size_per_partition // self.group_size,
dtype=torch.uint8, # E8M0 for MXFP8 scale
),
input_dim=1,
output_dim=0,
weight_loader=weight_loader,
)
else:
raise NotImplementedError(f"Strategy {self.strategy} is not supported for W8A8-MXFp8")

# min requirement for fp8 kernels
# weight_scale[:] = torch.finfo(torch.float32).min
# weight_scale.fill_(torch.finfo(torch.float32).min)
Comment on lines +138 to +139

medium

These commented-out lines should be removed to improve code clarity.

layer.register_parameter("weight_scale", weight_scale)

# INPUT SCALE
if self.is_static_input_scheme:
input_scale = PerTensorScaleParameter(data=torch.empty(
len(output_partition_sizes), dtype=torch.float32),
weight_loader=weight_loader)
input_scale[:] = torch.finfo(torch.float32).min
layer.register_parameter("input_scale", input_scale)

def apply_weights(self,
layer: torch.nn.Module,
x: torch.Tensor,
bias: Optional[torch.Tensor] = None) -> torch.Tensor:

# dequant weight
weight = layer.weight
weight_scale = layer.weight_scale
dequnat_weight = dequant_mx_fp8(
weight_fp8=weight.data,
scale_e8m0=weight_scale.data,
block_size=self.group_size
)
dequnat_weight = dequnat_weight.to(x.dtype)
# q-dq input
x_scale, x_quant = quant_mx_fp8(x)
dequant_x = dequant_mx_fp8(
weight_fp8=x_quant,
scale_e8m0=x_scale,
block_size=self.group_size
)
x = dequant_x.to(x.dtype)
out = x @ dequnat_weight.t()
Comment on lines +158 to +172

medium

There's a recurring typo dequnat_weight which should be dequant_weight.

Suggested change
dequnat_weight = dequant_mx_fp8(
weight_fp8=weight.data,
scale_e8m0=weight_scale.data,
block_size=self.group_size
)
dequnat_weight = dequnat_weight.to(x.dtype)
# q-dq input
x_scale, x_quant = quant_mx_fp8(x)
dequant_x = dequant_mx_fp8(
weight_fp8=x_quant,
scale_e8m0=x_scale,
block_size=self.group_size
)
x = dequant_x.to(x.dtype)
out = x @ dequnat_weight.t()
dequant_weight = dequant_mx_fp8(
weight_fp8=weight.data,
scale_e8m0=weight_scale.data,
block_size=self.group_size
)
dequant_weight = dequant_weight.to(x.dtype)
# q-dq input
x_scale, x_quant = quant_mx_fp8(x)
dequant_x = dequant_mx_fp8(
weight_fp8=x_quant,
scale_e8m0=x_scale,
block_size=self.group_size
)
x = dequant_x.to(x.dtype)
out = x @ dequant_weight.t()

return out.to(x.dtype) + (bias if bias is not None else 0)

return self.fp8_linear.apply(input=x,
weight=layer.weight,
weight_scale=layer.weight_scale,
out_dtype=self.out_dtype,
input_scale=layer.input_scale,
bias=bias)